In [1]:
import os
# Must be assigned before the octo import below runs.
# NOTE(review): presumably this silences the HuggingFace tokenizers
# parallelism warning - confirm against the tokenizers documentation.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from octo.model.octo_model import OctoModel
# Fetch the pretrained "octo-small-1.5" checkpoint from the given hf:// location.
# NOTE(review): `model` is reassigned to the "octo-base-1.5" checkpoint in a
# later cell, so this smaller checkpoint is not the one used for inference.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-small-1.5")

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow_datasets as tfds
import cv2
import jax
from PIL import Image
import mediapy as mp
import tensorflow as tf
import tqdm

### Load the BRIDGE Dataset
print("Loading BRIDGE dataset...")
builder = tfds.builder_from_directory(builder_dir="gs://gresearch/robotics/bridge/0.1.0/")
ds = builder.as_dataset(split="train[:1]")  # Only the first episode of the train split

# Materialise the single episode and its per-step records as a Python list
episode = next(iter(ds))
steps = list(episode["steps"])

# Resize every frame to 256x256 (default input resolution for the Octo model)
images = [cv2.resize(np.array(step["observation"]["image"]), (256, 256)) for step in steps]

# The last frame doubles as the goal image; the instruction is read from the first step
goal_image = images[-1]
language_instruction = steps[0]["observation"]["natural_language_instruction"].numpy().decode()

print(f"Instruction: {language_instruction}")

# Preview the episode inline in the notebook at 10 fps (same pace as the old
# 100 ms-per-frame loop). cv2.imshow was dropped because it opens a native GUI
# window from the kernel process (needs a display server, can hang a notebook)
# and interprets arrays as BGR, whereas mediapy displays them as RGB.
# NOTE(review): assumes the TFDS-decoded frames are RGB - confirm.
mp.show_video(images, fps=10)

Loading BRIDGE dataset...


2025-02-23 16:27:03.125628: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


Instruction: Place the can to the left of the pot.


2025-02-23 16:27:11.592 python[34428:2606592] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-23 16:27:11.592 python[34428:2606592] +[IMKInputSession subclass]: chose IMKInputSession_Modern


In [3]:
### Load the Pretrained Octo Model
print("Loading Octo model checkpoint...")
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")

# Number of consecutive frames stacked into each observation window.
WINDOW_SIZE = 2

# Two alternative ways of specifying the task. Both calls run, and the second
# assignment overwrites the first, so the language-conditioned task is the one
# used afterwards. Swap the order (or drop one line) to use the goal image.
task = model.create_tasks(
    goals=dict(image_primary=goal_image[np.newaxis])  # goal-conditioned (adds a leading axis)
)
task = model.create_tasks(
    texts=[language_instruction]                      # language-conditioned
)

Loading Octo model checkpoint...


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]



In [4]:
# Slide a WINDOW_SIZE-frame window over the episode. For every window, collect
# the model's predicted actions, the attention maps returned alongside them,
# and the action logged in the dataset at the window's last frame.
pred_actions, true_actions, attention_maps_per_step = [], [], []

# Loop-invariant lookups, hoisted out of the loop.
action_stats = model.dataset_statistics["bridge_dataset"]["action"]
num_windows = len(images) - (WINDOW_SIZE - 1)

for step in tqdm.trange(num_windows):
    window_end = step + WINDOW_SIZE

    # Stack the window's frames and add a leading axis of size 1.
    input_images = np.stack(images[step:window_end])[None]
    # Every position in the window is marked True in the pad mask.
    pad_mask = np.full((1, input_images.shape[1]), True, dtype=bool)
    observation = {
        'image_primary': input_images,
        'timestep_pad_mask': pad_mask,
    }

    # sample_actions is unpacked into (actions, attention maps) here.
    # The same PRNG key is used for every window.
    actions, attention_maps = model.sample_actions(
        observation,
        task,
        unnormalization_statistics=action_stats,
        rng=jax.random.PRNGKey(0),
    )

    # Keep the first entry along the leading axis of the prediction.
    pred_actions.append(actions[0])
    # Keep the attention maps exactly as returned for this window.
    attention_maps_per_step.append(attention_maps)

    # Ground truth: world_vector, rotation_delta and open_gripper of the
    # window's last frame, concatenated into a single vector.
    final_window_step = window_end - 1
    logged_action = steps[final_window_step]['action']
    gripper = np.array(logged_action['open_gripper']).astype(np.float32)[None]
    true_actions.append(np.concatenate(
        (logged_action['world_vector'], logged_action['rotation_delta'], gripper),
        axis=-1,
    ))


100%|██████████| 37/37 [00:26<00:00,  1.40it/s]


In [5]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.ndimage import zoom

def overlay_attention_on_image(image, attn_map, patch_size=16):
    """
    Overlays per-head attention heatmaps on an input image.

    Args:
        image: (H, W, 3) array in RGB channel order - the original image.
            It is blended with an 8-bit heatmap, so it is expected to be uint8.
        attn_map: (num_heads, num_tokens) - Attention scores for image tokens.
            num_tokens must be a non-zero perfect square (e.g. 256 -> 16x16 grid).
        patch_size: Upsampling factor applied to the token grid before the
            final resize to the image resolution (default = 16).

    Returns:
        List of overlayed RGB images (one per head), each with the shape of `image`.

    Raises:
        ValueError: If `attn_map` is not 2-D or its token count is not a
            non-zero perfect square.
    """
    # Bring the scores to host memory once; also accepts non-NumPy array types.
    attn_map = np.asarray(attn_map, dtype=np.float32)
    if attn_map.ndim != 2:
        raise ValueError(
            f"attn_map must be 2-D (num_heads, num_tokens), got shape {attn_map.shape}"
        )

    num_heads, num_patches = attn_map.shape
    grid_size = int(round(np.sqrt(num_patches)))  # E.g., 256 tokens -> 16x16 grid
    if num_patches == 0 or grid_size * grid_size != num_patches:
        raise ValueError(
            f"num_tokens must be a non-zero perfect square, got {num_patches}"
        )

    overlayed_images = []
    for head in range(num_heads):
        attn_grid = attn_map[head].reshape((grid_size, grid_size))

        # Min-max normalise per head. A constant (or NaN) map has no usable
        # range; map it to all zeros instead of dividing by zero.
        lowest = float(attn_grid.min())
        span = float(attn_grid.max()) - lowest
        if span > 0:
            attn_grid = (attn_grid - lowest) / span
        else:
            attn_grid = np.zeros_like(attn_grid)

        # Upsample the token grid with linear interpolation, then clip so that
        # rounding error cannot push values outside the uint8-safe range.
        attn_heatmap = np.clip(zoom(attn_grid, (patch_size, patch_size), order=1), 0.0, 1.0)

        # Colour-map the heatmap. OpenCV produces BGR, while the image is RGB,
        # so convert before blending; otherwise red and blue are swapped and
        # high attention would render as blue instead of red.
        attn_colormap = cv2.applyColorMap((attn_heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
        attn_colormap = cv2.cvtColor(attn_colormap, cv2.COLOR_BGR2RGB)

        # Resize the heatmap to the image resolution (cv2 takes (width, height)).
        attn_colormap = cv2.resize(attn_colormap, (image.shape[1], image.shape[0]))

        # Blend: 60% original image, 40% heatmap.
        overlayed_images.append(cv2.addWeighted(image, 0.6, attn_colormap, 0.4, 0))

    return overlayed_images  # List of images, one per head

In [8]:
import os
import tqdm
import jax.numpy as jnp

def visualize_attention_maps(attention_maps_per_step, images, save_path="attention_visualization"):
    """
    Visualizes and saves attention maps overlaid on images.

    For every step and every layer, one PNG named
    ``step_{step}_layer_{layer_idx}.png`` is written to `save_path`, showing
    all attention heads of that layer side by side.

    Args:
        attention_maps_per_step: List with one entry per step. Each entry is
            indexed as ``entry["all"][layer_idx][0]`` and is expected to yield a
            2-D (num_heads, image_token_num) array; other shapes are reported
            and skipped.
        images: List of original images, indexed by the same step index.
        save_path: Directory to save visualized attention maps (created if missing).

    Returns:
        None. With an empty `attention_maps_per_step` only the directory is created.
    """
    os.makedirs(save_path, exist_ok=True)

    num_steps = len(attention_maps_per_step)
    if num_steps == 0:
        # Nothing to draw; avoids indexing into an empty list below.
        return

    num_layers = len(attention_maps_per_step[0]["all"])  # Assuming all steps have same layers

    for step in tqdm.trange(num_steps, desc="Visualizing Attention"):
        # NOTE(review): the image at the same index as the attention entry is
        # used as the backdrop - confirm this is the frame the tokens belong to.
        image = images[step]

        for layer_idx in range(num_layers):  # Iterate over all layers
            attn_map = attention_maps_per_step[step]["all"][layer_idx][0]  # Expected: (num_heads, image_token_num)

            # Skip (but report) anything that is not a 2-D map.
            if attn_map.ndim != 2:
                print(f"Unexpected shape for attention map at step {step}, layer {layer_idx}: {attn_map.shape}")
                continue

            # Get overlaid images per head
            overlayed_images = overlay_attention_on_image(image, attn_map)

            # One subplot per head. squeeze=False keeps `axes` a 2-D array even
            # for a single head; otherwise a bare Axes is returned and
            # indexing it fails.
            num_heads = attn_map.shape[0]
            fig, axes = plt.subplots(1, num_heads, figsize=(20, 5), squeeze=False)
            try:
                for head_idx, ax in enumerate(axes[0]):
                    ax.imshow(overlayed_images[head_idx])
                    ax.set_title(f"Layer {layer_idx} - Head {head_idx}")
                    ax.axis("off")

                fig.tight_layout()
                fig.savefig(os.path.join(save_path, f"step_{step}_layer_{layer_idx}.png"))
            finally:
                # Always release the figure, even if drawing or saving fails,
                # so figures do not pile up across the steps x layers loop.
                plt.close(fig)


In [9]:
# Render one PNG per (step, layer) pair into the "attention_results" directory.
visualize_attention_maps(
    attention_maps_per_step=attention_maps_per_step,
    images=images,
    save_path="attention_results",
)

Visualizing Attention: 100%|██████████| 37/37 [03:32<00:00,  5.76s/it]
