In [1]:
import rclpy
import matplotlib.pyplot as plt
import numpy as np
from subscriber import Subscriber, sub_call

In [None]:
# Smoke-test the ROS topics: grab one frame from each camera plus the latest
# language instruction, display them side by side, then tear the node down.
# Guarded init so the cell can be re-run (rclpy.init raises if already initialized).
if not rclpy.ok():
    rclpy.init(args=None)
node = Subscriber()
try:
    fig, (ax_camera, ax_wrist) = plt.subplots(1, 2, figsize=(10, 5))

    # Compare against None instead of relying on truthiness: bool() of a
    # multi-element ndarray raises ValueError.
    # NOTE(review): assumes sub_call returns None when no message arrives — confirm in subscriber.py.
    camera_image_array = sub_call(node, '/Camera_rgb', (512, 512, 3))
    if camera_image_array is not None:
        node.get_logger().info('Successfully received the latest image!')
        ax_camera.imshow(camera_image_array)
    ax_camera.set_title('/Camera_rgb')

    wrist_image_array = sub_call(node, '/Camera_wrist_rgb', (512, 512, 3))
    if wrist_image_array is not None:  # previously re-checked camera_image_array by mistake
        node.get_logger().info('Successfully received the latest image!')
        ax_wrist.imshow(wrist_image_array)
    ax_wrist.set_title('/Camera_wrist_rgb')

    language_msg = sub_call(node, '/language_topic', None)
    if language_msg is not None:
        node.get_logger().info('Successfully received the latest language!')
        print(language_msg.data)

    plt.show()
finally:
    # Always release the node and the rclpy context, even if a read fails.
    node.destroy_node()
    rclpy.shutdown()

In [None]:
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import cv2
import jax
import tensorflow_datasets as tfds
import tqdm
import mediapy
import numpy as np

In [None]:
# All artefacts live under one mount point; change it here to relocate every path.
_TRANSCEND_ROOT = "/media/irobotics/Transcend"

PATH_CHECKPOINTS = f"{_TRANSCEND_ROOT}/finetuned_checkpoints/v4_checkpoints/"           # fine-tuned model checkpoint dir
PATH_DATASET_ROSBAG = f"{_TRANSCEND_ROOT}/isaacsim_data/v4_test/"                       # raw recorded data
PATH_DATASET_TFDS = f'{_TRANSCEND_ROOT}/tensorflow_datasets/v4_test/example_dataset/1.0.0/'  # TFDS builder dir
PATH_INFERENCE_RESULTS = f"{_TRANSCEND_ROOT}/inference_result/"                         # output directory

In [None]:
from octo.model.octo_model import OctoModel

# Load the fine-tuned Octo checkpoint stored under PATH_CHECKPOINTS.
# The returned `model` is used below for create_tasks / sample_actions and its dataset_statistics.
model = OctoModel.load_pretrained(PATH_CHECKPOINTS)

In [None]:
# Load one episode from the converted RLDS/TFDS dataset. Its frames supply a goal
# image; its steps supply the language instruction and ground-truth actions.
builder = tfds.builder_from_directory(builder_dir=PATH_DATASET_TFDS)
ds = builder.as_dataset(split='train[:2]')
iterator = iter(ds)
episode = next(iterator)

# Materialize the steps and resize frames to 256x256 (default third-person cam resolution).
steps = list(episode['steps'])
if not steps:
    raise ValueError(f'First episode in {PATH_DATASET_TFDS} contains no steps.')
images = [cv2.resize(np.array(step['observation']['image']), (256, 256)) for step in steps]

# Goal image is the final frame. The instruction is read from step
# LANGUAGE_INSTRUCTION_STEP, clamped so shorter episodes do not raise IndexError.
LANGUAGE_INSTRUCTION_STEP = 100
goal_image = images[-1]
instruction_step = steps[min(LANGUAGE_INSTRUCTION_STEP, len(steps) - 1)]
language_instruction = instruction_step['language_instruction'].numpy().decode()

### Inference

In [None]:
# Build the `task` that conditions the policy: either the episode's goal image
# or its language instruction. Exactly one is used; language is the default.
USE_GOAL_IMAGE = False  # True -> goal-conditioned on the final dataset frame

if USE_GOAL_IMAGE:
    # create_tasks expects a leading batch dimension, hence goal_image[None].
    task = model.create_tasks(goals={"image_primary": goal_image[None]})
else:
    task = model.create_tasks(texts=[language_instruction])

In [None]:
# Run closed-loop inference on the live ROS camera streams and collect predicted
# action chunks. Ground-truth actions from the sampled dataset episode are kept
# alongside for offline comparison. Stop with Kernel -> Interrupt, or set
# MAX_INFERENCE_STEPS to stop automatically.
MAX_INFERENCE_STEPS = None  # int: stop after that many predictions; None: run until interrupted

# `steps` is a Python list, so `steps[:]['action']` raised TypeError; collect per step.
true_actions = [step['action'] for step in steps]
pred_actions = []

# Read the observation layout from the checkpoint rather than hard-coding it.
# NOTE(review): assumes example_batch images are (batch, window, H, W, C), as in Octo — verify.
_example_obs = model.example_batch['observation']
window_size = _example_obs['image_primary'].shape[1]
primary_hw = tuple(_example_obs['image_primary'].shape[-3:-1])
wrist_hw = tuple(_example_obs['image_wrist'].shape[-3:-1])


def _resize_to(image, hw):
    """Resize an (H, W, C) image to `hw` = (height, width); cv2.resize takes (width, height)."""
    return cv2.resize(np.asarray(image), (int(hw[1]), int(hw[0])))


def _stack_window(frames, size):
    """Left-pad `frames` to `size` by repeating the oldest frame.

    Returns a (1, size, H, W, C) array of frames and a (1, size) boolean mask that
    is False for the padded (repeated) entries and True for real observations.
    """
    n_pad = size - len(frames)
    padded = [frames[0]] * n_pad + list(frames)
    mask = np.array([False] * n_pad + [True] * len(frames), dtype=bool)
    return np.stack(padded)[None], mask[None]


# The topic-check cell destroys its node and shuts rclpy down, so start a fresh one here.
if not rclpy.ok():
    rclpy.init(args=None)
node = Subscriber()

rng = jax.random.PRNGKey(0)  # fixed key, as before; split it per step if varied samples are wanted
primary_history, wrist_history = [], []
try:
    while MAX_INFERENCE_STEPS is None or len(pred_actions) < MAX_INFERENCE_STEPS:
        input_images = sub_call(node, '/Camera_rgb', (512, 512, 3))
        input_images_wrist = sub_call(node, '/Camera_wrist_rgb', (512, 512, 3))
        # NOTE(review): assumes sub_call returns None when no frame arrives — confirm in subscriber.py.
        if input_images is None or input_images_wrist is None:
            node.get_logger().warning('Missing camera frame; retrying.')
            continue

        # Keep only the most recent `window_size` frames per camera, at checkpoint resolution.
        primary_history = (primary_history + [_resize_to(input_images, primary_hw)])[-window_size:]
        wrist_history = (wrist_history + [_resize_to(input_images_wrist, wrist_hw)])[-window_size:]
        image_primary, timestep_pad_mask = _stack_window(primary_history, window_size)
        image_wrist, _wrist_mask = _stack_window(wrist_history, window_size)
        observation = {
            'image_primary': image_primary,
            'image_wrist': image_wrist,
            'timestep_pad_mask': timestep_pad_mask,
        }

        # Passing the dataset action statistics makes sample_actions return unnormalized actions.
        actions = model.sample_actions(
            observation,
            task,
            unnormalization_statistics=model.dataset_statistics["action"],
            rng=rng,
        )
        actions = actions[0]  # remove batch dim

        pred_actions.append(actions)
        # TODO: publish actions to robot
except KeyboardInterrupt:
    node.get_logger().info(f'Stopped after {len(pred_actions)} inference steps.')
finally:
    # Release the node and rclpy context however the loop ends.
    node.destroy_node()
    rclpy.shutdown()
