In [2]:
import os
# Set before the Octo import below so the variable is already in the
# environment when that import runs. Presumably this turns off the Hugging
# Face tokenizers library's internal parallelism -- confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from octo.model.octo_model import OctoModel
# NOTE(review): `model` is rebound to the "octo-base-1.5" checkpoint further
# down, before this "octo-small-1.5" instance is ever used, so this load looks
# redundant. Confirm nothing depends on it (e.g. warming the download cache or
# initialising JAX before TensorFlow is imported) before removing it.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow_datasets as tfds
import cv2
import jax
from PIL import Image
import mediapy as mp
import tensorflow as tf
import tqdm

### Load the BRIDGE Dataset
# Reads the first episode of the BRIDGE TFDS build, resizes every frame, picks
# the goal image and language instruction, then previews the frames in a window.
print("Loading BRIDGE dataset...")
builder = tfds.builder_from_directory(builder_dir="gs://gresearch/robotics/bridge/0.1.0/")
ds = builder.as_dataset(split="train[:1]")  # Load first episode

# Extract a single episode
episode = next(iter(ds))
steps = list(episode["steps"])

# Resize images to 256x256 (default for Octo model)
images = [cv2.resize(np.array(step["observation"]["image"]), (256, 256)) for step in steps]

# Extract goal image (last frame) & language instruction
goal_image = images[-1]
language_instruction = steps[0]["observation"]["natural_language_instruction"].numpy().decode()

print(f"Instruction: {language_instruction}")

# cv2.imshow interprets 3-channel arrays as BGR, while the dataset frames are
# RGB, so showing them directly swaps red and blue. Convert a display-only copy
# and leave `images` untouched, because those arrays are fed to the model later.
try:
    for img in images:
        cv2.imshow("Episode Frame", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        cv2.waitKey(100)  # Wait 100ms per frame
finally:
    # Close the preview window even if playback is interrupted part-way.
    cv2.destroyAllWindows()

### Load the Pretrained Octo Model
print("Loading Octo model checkpoint...")
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")

# Number of consecutive frames passed to the model per prediction.
WINDOW_SIZE = 2

# Choose how the policy is conditioned. Previously both tasks were built and
# the goal-conditioned one was immediately overwritten, so its construction was
# wasted work. Only the selected task is built now; the default (language)
# matches the task that was effectively used before.
USE_LANGUAGE_CONDITIONING = True
if USE_LANGUAGE_CONDITIONING:
    task = model.create_tasks(texts=[language_instruction])                # for language conditioned
else:
    task = model.create_tasks(goals={"image_primary": goal_image[None]})   # for goal-conditioned

# Slide a WINDOW_SIZE-frame window over the episode. For each window, sample
# actions from the model and record them next to the logged action of the
# window's last frame, together with whatever attention data the model returns.
pred_actions = []
true_actions = []
attention_maps_per_step = []

# The same fixed key is used for every call, so build it once.
sample_key = jax.random.PRNGKey(0)
num_windows = len(images) - (WINDOW_SIZE - 1)

for step in tqdm.trange(num_windows):
    # Stack the window's frames, then add a leading axis of size 1.
    input_images = np.stack(images[step:step+WINDOW_SIZE])[None]
    observation = {
        'image_primary': input_images,
        # All-True mask with one entry per frame in the window.
        'timestep_pad_mask': np.ones((1, input_images.shape[1]), dtype=bool),
    }

    # NOTE(review): two values are unpacked from sample_actions here. As far as
    # I know the upstream Octo release returns only the action array, so this
    # presumably runs against a build patched to also return attention maps --
    # confirm which Octo version is installed.
    action_stats = model.dataset_statistics["bridge_dataset"]["action"]
    actions, attention_maps = model.sample_actions(
        observation,
        task,
        unnormalization_statistics=action_stats,
        rng=sample_key,
    )

    # Keep the first (only) batch entry of the prediction and the raw maps.
    pred_actions.append(actions[0])
    attention_maps_per_step.append(attention_maps)

    # Logged action for the last frame of the window:
    # world_vector, rotation_delta, then open_gripper cast to float32.
    final_window_step = step + WINDOW_SIZE - 1
    logged_action = steps[final_window_step]['action']
    gripper = np.array(logged_action['open_gripper']).astype(np.float32)[None]
    true_actions.append(
        np.concatenate(
            [logged_action['world_vector'], logged_action['rotation_delta'], gripper],
            axis=-1,
        )
    )


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

Loading BRIDGE dataset...


2025-02-25 05:08:26.099086: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


Instruction: Place the can to the left of the pot.


2025-02-25 05:08:30.654 python[3727:70242] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-25 05:08:30.654 python[3727:70242] +[IMKInputSession subclass]: chose IMKInputSession_Modern


Loading Octo model checkpoint...


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

100%|██████████| 37/37 [00:20<00:00,  1.83it/s]
