In [1]:
import rclpy, asyncio
import matplotlib.pyplot as plt
import numpy as np
from subscriber import Subscriber, get_observation, wait_for_message
from sensor_msgs.msg import Image
from std_msgs.msg import String, Float32MultiArray
from time import sleep

In [2]:
import os
# Set before the remaining imports in this cell so the variable is already present
# in the process environment when they load.
# NOTE(review): presumably silences the HuggingFace tokenizers parallelism/fork
# warning — confirm against the library that reads it.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import cv2
import jax
import tensorflow_datasets as tfds
import tqdm
import mediapy
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Root directory that holds checkpoints, datasets and inference results.
# Set the OCTO_DATA_ROOT environment variable to run the notebook on another
# machine; when it is unset the default reproduces the original absolute paths.
DATA_ROOT = os.environ.get("OCTO_DATA_ROOT", "/media/irobotics/Transcend").rstrip("/")

# Every path keeps its trailing "/" exactly as in the original values, so code
# that appends a file name directly to one of them keeps working.
PATH_CHECKPOINTS = f"{DATA_ROOT}/finetuned_checkpoints/v4_checkpoints/"
PATH_DATASET_ROSBAG = f"{DATA_ROOT}/isaacsim_data/v4_test/"
PATH_DATASET_TFDS = f"{DATA_ROOT}/tensorflow_datasets/v4_test/example_dataset/1.0.0/"
PATH_INFERENCE_RESULTS = f"{DATA_ROOT}/inference_result/"

In [4]:
from octo.model.octo_model import OctoModel

# Load the Octo model from PATH_CHECKPOINTS (the fine-tuned checkpoint directory
# configured in the paths cell).
model = OctoModel.load_pretrained(PATH_CHECKPOINTS)

2024-09-25 14:43:20.017393: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-25 14:43:20.017434: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-25 14:43:20.018018: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
# create RLDS dataset builder
builder = tfds.builder_from_directory(builder_dir=PATH_DATASET_TFDS)
ds = builder.as_dataset(split='train[:2]')
iterator = iter(ds)
# only the first episode of the split is used below
episode = next(iterator)
# sample episode + resize to 256x256 (default third-person cam resolution)
steps = list(episode['steps'])
images = [cv2.resize(np.array(step['observation']['image']), (256, 256)) for step in steps]
# extract goal image & language instruction
goal_image = images[-1]  # last frame of the episode serves as the goal image
# The instruction is read from step 100; the index is clamped to the last step so
# that an episode with 100 or fewer steps does not raise IndexError.
instruction_step = min(100, len(steps) - 1)
language_instruction = steps[instruction_step]['language_instruction'].numpy().decode()

2024-09-25 14:43:56.455551: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


### Inference begins

In [6]:
# create `task` dict
# Exactly one conditioning mode is built. Language conditioning is the active one;
# to switch to goal conditioning, use the commented line instead (note the `[None]`,
# which adds the leading batch axis to the single goal image).
# task = model.create_tasks(goals={"image_primary": goal_image[None]})   # for goal-conditioned
task = model.create_tasks(texts=[language_instruction])                  # for language conditioned



In [None]:
# Start the ROS 2 client library and create the node that the camera/effort
# helpers (get_observation, wait_for_message) are called with.
rclpy.init(args=None)
node = Subscriber()
# Publisher for the commanded action on 'online_eff_topic' (QoS history depth 10).
publisher = node.create_publisher(Float32MultiArray, 'online_eff_topic', 10)
# One message object shared by all publishing code; its `data` is overwritten per command.
pub_msg = Float32MultiArray()

In [21]:
# Buffers for the online rollout.
pred_actions = []  # the inference loop appends one predicted action per step
# true_actions = steps[:]['action']
window = 2  # number of frames kept in each image history
# Pre-fill each history with (window - 1) all-zero frames that stand in for the
# observations preceding the first real one.
input_images_stack = [np.zeros((256, 256, 3)) for _pad in range(window - 1)]
input_images_wrist_stack = [np.zeros((512, 512, 3)) for _pad in range(window - 1)]

# Read the current frame from each camera, then place it at the end of its history.
input_images = get_observation(node, Image, '/Camera_rgb', (256, 256, 3))
input_images_wrist = get_observation(node, Image, '/Camera_wrist_rgb', (512, 512, 3))
input_images_stack.extend([input_images])
input_images_wrist_stack.extend([input_images_wrist])

# cur_eff = np.array([0.26, 0.0, 0.4])

In [16]:
async def publish_efforts(pub_msg, trigger):
    """Publish a command repeatedly, then return a fresh pair of camera frames.

    Relies on the notebook-level ``publisher`` and ``node`` objects.

    Args:
        pub_msg: message whose ``data`` holds ``[x, y, z, gripper]``. When
            ``trigger`` is true, ``data`` is overwritten with every intermediate
            command, so on return it holds the last command that was sent rather
            than the value passed in.
        trigger: if false, ``pub_msg`` is published unchanged 200 times at 0.1 s
            intervals. If true, the current value is read from ``/eff_topic`` and
            the command is stepped from it in 4 increments, each increment being
            published 50 times at 0.1 s intervals.

    Returns:
        ``(input_images, input_images_wrist)`` as returned by ``get_observation``
        for ``/Camera_rgb`` and ``/Camera_wrist_rgb`` after publishing finishes.
    """
    if not trigger:
        for _ in range(200):
            publisher.publish(pub_msg)
            await asyncio.sleep(0.1)
    else:
        # Copy the first three values into an ndarray so the in-place `+=` below is
        # element-wise arithmetic whatever container the helper returns in `.data`
        # (presumably array.array or ndarray — not visible from this notebook).
        cur_eff = np.array(wait_for_message(node, Float32MultiArray, '/eff_topic').data[0:3])
        delta = np.array(pub_msg.data[0:3]) - cur_eff
        gripper_eff = pub_msg.data[3]
        # NOTE(review): 4 increments of delta/2 finish at start + 2*delta, i.e. beyond
        # the requested target — confirm this overshoot is intentional.
        for _step in range(4):
            cur_eff += delta/2
            pub_msg.data = cur_eff.tolist()
            pub_msg.data.append(gripper_eff)
            # print(f"{_step}th triggering: ", pub_msg.data)
            for _ in range(50):
                publisher.publish(pub_msg)
                await asyncio.sleep(0.1)

    # Both modes end by grabbing a new frame from each camera.
    input_images = get_observation(node, Image, '/Camera_rgb', (256, 256, 3))
    input_images_wrist = get_observation(node, Image, '/Camera_wrist_rgb', (512, 512, 3))
    return input_images, input_images_wrist

In [22]:
# run inference loop, this model only uses 3rd person image observations for bridge
import warnings
warnings.filterwarnings("ignore")
trigger = True
while True:

    observation = {
        'image_primary': np.stack(input_images_stack)[None],
        'image_wrist': np.stack(input_images_wrist_stack)[None],
        'timestep_pad_mask': np.full((1, np.stack(input_images_stack)[None].shape[1]), True, dtype=bool)
    }

    # one step actions
    actions = model.sample_actions(
        observation, 
        task, 
        unnormalization_statistics=model.dataset_statistics["action"], 
        rng=jax.random.PRNGKey(0)
    )
    actions = actions[0] # remove batch dim
    print("Predicted actions: ", actions[0,:4].tolist())
    pred_actions.append(actions[0,:4].tolist())

    # publish actions to robot
    pred_eff = actions[0,:3].tolist()
    pred_gripper = 1 if actions[0,3] > 0.5 else 0
    # pub_msg.data = actions[0,:].tolist()

    #### manual triggering
    # delta = np.array(pred_eff) - cur_eff
    # for i in range(4):
    #     cur_eff += delta/2
    #     pub_msg.data = cur_eff.tolist()
    #     pub_msg.data.append(pred_gripper)
    #     # print(f"{i}th triggering: ", pub_msg.data)
    #     for _ in range(20):
    #         publisher.publish(pub_msg)
    #         sleep(0.1)
    # cur_eff = np.array(pred_eff)

    pub_msg.data = pred_eff
    pub_msg.data.append(pred_gripper)
    input_images, input_images_wrist = asyncio.run(publish_efforts(pub_msg, trigger))

    if input_images is not None and input_images_wrist is not None:
        print("New images received")
        # input stack pop front
        input_images_stack.pop(0)
        input_images_wrist_stack.pop(0)
        input_images_stack.append(input_images)
        input_images_wrist_stack.append(input_images_wrist)
    else:
        print("No new images received")
        continue


    # TODO: how to check if episode is done, then break


Predicted actions:  [0.2688314914703369, 0.06284713000059128, 0.4025217294692993, 0.019915001466870308]
New images received
Predicted actions:  [0.25711506605148315, 0.06557381898164749, 0.40063047409057617, -0.001128721283748746]
New images received
Predicted actions:  [0.2579149603843689, 0.06228829175233841, 0.40244895219802856, -0.0034990611020475626]
New images received
Predicted actions:  [0.256365031003952, 0.0593552328646183, 0.40309444069862366, -0.0033494585659354925]


KeyboardInterrupt: 

In [20]:
### publish a fixed pose (the initial point), then refresh the image histories
pub_msg.data = [0.26, 0.0, 0.4, 0.01] # initial point
# pub_msg.data = [0.5, 0.5, 0.2, 0.01]


# while rclpy.ok():
#     publisher.publish(pub_msg)
# hold the pose: 300 publishes, 0.1 s apart
for i in range(300):
    publisher.publish(pub_msg)
    sleep(0.1)

# drop the oldest frame of each history, then append a freshly read observation
del input_images_stack[0]
del input_images_wrist_stack[0]
input_images = get_observation(node, Image, '/Camera_rgb', (256, 256, 3))
input_images_wrist = get_observation(node, Image, '/Camera_wrist_rgb', (512, 512, 3))
input_images_stack.extend([input_images])
input_images_wrist_stack.extend([input_images_wrist])

In [11]:
# Tear down the ROS node first, then shut the rclpy client library down.
node.destroy_node()
rclpy.shutdown()

In [17]:
# Persist the predicted actions collected by the inference loop.
# os.path.join keeps the target path valid whether or not PATH_INFERENCE_RESULTS
# ends with a path separator.
np.save(os.path.join(PATH_INFERENCE_RESULTS, "0923_online_pred_actions.npy"), pred_actions)