In [1]:
import os

# Disable HuggingFace tokenizers' internal parallelism (silences its fork warning
# once other libraries start worker processes). Must be set before the tokenizer
# is imported/used.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from octo.model.octo_model import OctoModel

# Note: no checkpoint is loaded here. The model is loaded exactly once, in the
# "Load the Pretrained Octo Model" cell (octo-base-1.5), so every result below
# unambiguously comes from that checkpoint.

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow_datasets as tfds
import cv2
import jax
from PIL import Image
import mediapy as mp
import tensorflow as tf
import tqdm

### Load the BRIDGE Dataset
print("Loading BRIDGE dataset...")
builder = tfds.builder_from_directory(builder_dir="gs://gresearch/robotics/bridge/0.1.0/")
ds = builder.as_dataset(split="train[:1]")  # Load first episode

# Extract a single episode
episode = next(iter(ds))
steps = list(episode["steps"])

# Resize images to 256x256 (default for Octo model). Frames keep the RGB channel
# order TFDS decodes them in; downstream code (Octo inference, matplotlib) expects RGB.
images = [cv2.resize(np.array(step["observation"]["image"]), (256, 256)) for step in steps]

# Extract goal image (last frame) & language instruction
goal_image = images[-1]
language_instruction = steps[0]["observation"]["natural_language_instruction"].numpy().decode()

print(f"Instruction: {language_instruction}")

# Play the episode in an OpenCV window, 100 ms per frame. cv2.imshow interprets
# arrays as BGR, so convert from RGB first or red and blue appear swapped.
# try/finally guarantees the window is closed even if playback is interrupted.
try:
    for img in images:
        cv2.imshow("Episode Frame", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        cv2.waitKey(100)  # Wait 100ms per frame
finally:
    cv2.destroyAllWindows()

Loading BRIDGE dataset...


2025-02-25 04:16:57.876399: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".


Instruction: Place the can to the left of the pot.


2025-02-25 04:17:03.409 python[2988:22304] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-25 04:17:03.409 python[2988:22304] +[IMKInputSession subclass]: chose IMKInputSession_Modern


In [3]:
### Load the Pretrained Octo Model
print("Loading Octo model checkpoint...")
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")

WINDOW_SIZE = 2  # number of consecutive frames fed to the model per inference step

# How the policy is conditioned. Only one task is built, so there is no silently
# overwritten goal task. "goal" uses the episode's last frame; "language" uses
# the episode's natural-language instruction.
CONDITIONING = "language"  # "language" or "goal"

if CONDITIONING == "goal":
    task = model.create_tasks(goals={"image_primary": goal_image[None]})
elif CONDITIONING == "language":
    task = model.create_tasks(texts=[language_instruction])
else:
    raise ValueError(f"Unknown CONDITIONING mode: {CONDITIONING!r} (expected 'language' or 'goal')")

Loading Octo model checkpoint...


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]



In [4]:
# Local import so this cell keeps working even after a later cell rebinds the
# global name `tqdm` to the tqdm class (`from tqdm import tqdm`), which has no
# `.trange` attribute.
from tqdm import trange

# Loop invariants: the same unnormalization statistics and the same fixed PRNG
# key are used for every step, so sampling is deterministic across reruns.
action_stats = model.dataset_statistics["bridge_dataset"]["action"]
sample_rng = jax.random.PRNGKey(0)

pred_actions, true_actions, attention_maps_per_step = [], [], []
for step in trange(len(images) - (WINDOW_SIZE - 1)):
    # Sliding window of WINDOW_SIZE consecutive frames plus a leading batch dim:
    # (1, WINDOW_SIZE, 256, 256, 3).
    input_images = np.stack(images[step:step + WINDOW_SIZE])[None]
    observation = {
        'image_primary': input_images,
        # Every frame in the window is real data (no padding).
        'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool),
    }

    # This (modified) sample_actions returns both predicted actions and attention maps.
    actions, attention_maps = model.sample_actions(
        observation,
        task,
        unnormalization_statistics=action_stats,
        rng=sample_rng,
    )

    pred_actions.append(actions[0])  # drop the batch dim
    attention_maps_per_step.append(attention_maps)  # all layers & heads for this step

    # Ground truth: the dataset action at the last (current) frame of the window,
    # laid out as [world_vector, rotation_delta, open_gripper].
    final_window_step = step + WINDOW_SIZE - 1
    gt_action = steps[final_window_step]['action']
    true_actions.append(np.concatenate(
        (
            gt_action['world_vector'],
            gt_action['rotation_delta'],
            np.array(gt_action['open_gripper']).astype(np.float32)[None],
        ),
        axis=-1,
    ))


100%|██████████| 37/37 [00:24<00:00,  1.49it/s]


In [5]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.ndimage import zoom

def overlay_attention_on_image(image, attn_map, patch_size=16):
    """
    Overlay one attention heatmap per head on an image.

    Args:
        image: (H, W, 3) uint8 image in RGB order (as loaded from the dataset).
        attn_map: (num_heads, num_tokens) attention scores for image tokens.
            `num_tokens` must be a perfect square (tokens laid out on a square
            grid, e.g. 256 tokens -> 16x16).
        patch_size: Upsampling factor from the token grid to pixels (default = 16).

    Returns:
        List of (H, W, 3) uint8 RGB overlays, one per head.

    Raises:
        ValueError: if `attn_map` is not 2-D or `num_tokens` is not a perfect square.
    """
    # float32 copy: accepts jax/bf16 arrays and keeps the normalisation math in float.
    attn = np.asarray(attn_map, dtype=np.float32)
    if attn.ndim != 2:
        raise ValueError(f"attn_map must be (num_heads, num_tokens), got shape {attn.shape}")

    num_heads, num_patches = attn.shape
    grid_size = int(round(np.sqrt(num_patches)))
    if grid_size * grid_size != num_patches:
        raise ValueError(f"num_tokens={num_patches} is not a perfect square; cannot lay out a square grid")

    target_size = (image.shape[1], image.shape[0])  # cv2 wants (width, height)
    overlayed_images = []

    for head in range(num_heads):
        # Min-max normalise per head. A constant map would divide by zero (NaNs),
        # so map it to all zeros instead.
        attn_grid = attn[head].reshape(grid_size, grid_size)
        lo, hi = attn_grid.min(), attn_grid.max()
        attn_grid = (attn_grid - lo) / (hi - lo) if hi > lo else np.zeros_like(attn_grid)

        # Upsample the token grid towards pixel resolution.
        attn_heatmap = zoom(attn_grid, (patch_size, patch_size), order=1)

        # applyColorMap returns BGR; convert to RGB so it blends correctly with the
        # RGB frame and displays correctly with plt.imshow.
        heat_u8 = np.clip(attn_heatmap * 255.0, 0, 255).astype(np.uint8)
        attn_colormap = cv2.cvtColor(cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

        # Match the image resolution exactly, then blend 60% image / 40% heatmap.
        attn_colormap = cv2.resize(attn_colormap, target_size)
        overlayed_images.append(cv2.addWeighted(image, 0.6, attn_colormap, 0.4, 0))

    return overlayed_images  # List of images, one per head

In [6]:
import os
import tqdm
import jax.numpy as jnp

def visualize_attention_maps(attention_maps_per_step, images, save_path="attention_visualization"):
    """
    Render per-head attention overlays for every step and layer and save them as PNGs.

    One figure per (step, layer) is written to
    `{save_path}/step_{step}_layer_{layer_idx}.png`, with one panel per head.

    Args:
        attention_maps_per_step: List (one entry per inference step) of the attention
            dicts returned by the modified `OctoModel.sample_actions`. Each map is read
            as `entry["all"][layer_idx][0]` and must then be 2-D
            (num_heads, image_token_num); other shapes are reported and skipped.
        images: List of (H, W, 3) frames; `images[step]` is the background for `step`.
        save_path: Directory to save visualized attention maps (created if missing).
    """
    # Local import: a later cell rebinds the global `tqdm` to the tqdm class, which
    # has no `.trange`, and would break this function if it were re-run.
    from tqdm import trange

    os.makedirs(save_path, exist_ok=True)

    if not attention_maps_per_step:
        print("No attention maps to visualize.")
        return

    num_steps = len(attention_maps_per_step)
    num_layers = len(attention_maps_per_step[0]["all"])  # Assuming all steps have same layers

    for step in trange(num_steps, desc="Visualizing Attention"):
        # NOTE(review): images[step] is the FIRST frame of this step's inference window
        # (the window is images[step:step + WINDOW_SIZE]). If the maps refer to the most
        # recent frame, images[step + WINDOW_SIZE - 1] may be the right background - confirm.
        image = images[step]

        for layer_idx in range(num_layers):
            attn_map = attention_maps_per_step[step]["all"][layer_idx][0]  # expected (num_heads, image_token_num)

            if attn_map.ndim != 2:
                print(f"Unexpected shape for attention map at step {step}, layer {layer_idx}: {attn_map.shape}")
                continue

            overlayed_images = overlay_attention_on_image(image, attn_map)

            # squeeze=False keeps `axes` 2-D even for a single head, so indexing is uniform.
            num_heads = attn_map.shape[0]
            fig, axes = plt.subplots(1, num_heads, figsize=(20, 5), squeeze=False)
            for head_idx, ax in enumerate(axes[0]):
                ax.imshow(overlayed_images[head_idx])
                ax.set_title(f"Layer {layer_idx} - Head {head_idx}")
                ax.axis("off")

            fig.tight_layout()
            fig.savefig(os.path.join(save_path, f"step_{step}_layer_{layer_idx}.png"))
            # Close this specific figure; hundreds are created in this loop.
            plt.close(fig)


In [7]:
# Render per-head overlays for every (step, layer) into ./attention_results.
visualize_attention_maps(
    attention_maps_per_step,
    images,
    save_path="attention_results",
)

Visualizing Attention: 100%|██████████| 37/37 [03:39<00:00,  5.92s/it]


In [8]:
import os
import imageio
import cv2
import numpy as np
from tqdm import tqdm

def create_attention_video(image_folder="attention_results", output_dir="output", fps=5):
    """
    Creates a video and a GIF from saved attention map images.

    Frames are all `.png` files in `image_folder`, in natural order, so that
    `step_2_layer_0.png` precedes `step_10_layer_0.png`. Plain `sorted` orders
    names lexicographically and would put `step_10` before `step_2`.

    Args:
        image_folder (str): Path to the folder containing attention images.
        output_dir (str): Directory in which `attention_video.mp4` and
            `attention_video.gif` are written (created if missing).
        fps (int): Frames per second for both the video and the GIF.
    """
    import re

    def natural_key(name):
        # Split out digit runs so numeric parts compare as integers.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    output_video = os.path.join(output_dir, "attention_video.mp4")
    output_gif = os.path.join(output_dir, "attention_video.gif")

    frame_names = sorted((name for name in os.listdir(image_folder) if name.endswith(".png")), key=natural_key)
    if not frame_names:
        print("No images found in the folder! Check if visualize_attention_maps() was run.")
        return

    # The first frame fixes the output dimensions.
    sample_image = cv2.imread(os.path.join(image_folder, frame_names[0]))
    if sample_image is None:
        print("Error loading sample image. Check if images exist in the folder.")
        return
    height, width = sample_image.shape[:2]

    # Local import: other cells rebind the global `tqdm` to the module; we need the callable.
    from tqdm import tqdm as progress_bar

    video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    video_ok = video_writer.isOpened()
    if not video_ok:
        print(f"Warning: could not open video writer for {output_video}; only the GIF will be written.")

    gif_images = []

    print("🎥 Generating video and GIF...")
    try:
        for img_name in progress_bar(frame_names, desc="Processing Frames"):
            frame = cv2.imread(os.path.join(image_folder, img_name))
            if frame is None:
                print(f"Warning: Could not read image {img_name}. Skipping.")
                continue

            # VideoWriter does not reliably accept frames whose size differs from the
            # size it was opened with, so normalise every frame to the first one.
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))

            if video_ok:
                video_writer.write(frame)

            # cv2.imread returns BGR; imageio expects RGB.
            gif_images.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always finalise the MP4, even if a frame raises mid-loop.
        video_writer.release()

    if gif_images:
        imageio.mimsave(output_gif, gif_images, fps=fps)

    if video_ok:
        print(f"Video saved as: {output_video}")
    if gif_images:
        print(f"GIF saved as: {output_gif}")
    else:
        print("No readable frames; GIF not written.")

# Stitch the per-(step, layer) figures into output/attention_video.{mp4,gif} at 1 frame per second.
create_attention_video(
    image_folder="attention_results",
    output_dir="output",
    fps=1,
)


🎥 Generating video and GIF...


Processing Frames: 100%|██████████| 444/444 [00:07<00:00, 55.79it/s]


Video saved as: output/attention_video.mp4
GIF saved as: output/attention_video.gif


In [17]:
def overlay_attention_on_image(image, attn_map, patch_size=16, title=""):
    """
    Overlay a single attention heatmap on an image.

    NOTE: this redefines (and shadows) the per-head `overlay_attention_on_image`
    from an earlier cell, which returned a list of per-head overlays. Code that
    indexes that list (e.g. `visualize_attention_maps`) will not work after this
    cell has run.

    Args:
        image: (H, W, 3) uint8 image in RGB order.
        attn_map: (num_tokens,) scores, or (num_heads, num_tokens), which is
            averaged over heads. `num_tokens` must be a perfect square.
        patch_size: Upsampling factor from the token grid to pixels.
        title: Unused; accepted so existing call sites that pass it keep working.

    Returns:
        (H, W, 3) uint8 overlay in RGB order.

    Raises:
        ValueError: if `num_tokens` is not a perfect square.
    """
    attn = np.asarray(attn_map, dtype=np.float32)

    num_patches = attn.shape[-1]
    grid_size = int(round(np.sqrt(num_patches)))  # E.g., 256 tokens -> 16x16 grid
    if grid_size * grid_size != num_patches:
        raise ValueError(f"num_tokens={num_patches} is not a perfect square; cannot lay out a square grid")

    # Collapse heads so the map is (num_tokens,).
    if attn.ndim == 2:
        attn = attn.mean(axis=0)

    # Min-max normalise; a constant map would divide by zero (NaNs), so use zeros.
    attn_grid = attn.reshape((grid_size, grid_size))
    lo, hi = attn_grid.min(), attn_grid.max()
    attn_grid = (attn_grid - lo) / (hi - lo) if hi > lo else np.zeros_like(attn_grid)

    # Upsample the token grid towards pixel resolution.
    attn_heatmap = zoom(attn_grid, (patch_size, patch_size), order=1)

    # applyColorMap returns BGR; convert to RGB to match the RGB frame before blending.
    heat_u8 = np.clip(attn_heatmap * 255.0, 0, 255).astype(np.uint8)
    attn_colormap = cv2.cvtColor(cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

    # Resize to match image, then blend 60% image / 40% heatmap.
    attn_colormap = cv2.resize(attn_colormap, (image.shape[1], image.shape[0]))
    return cv2.addWeighted(image, 0.6, attn_colormap, 0.4, 0)

def visualize_mean_max_attention(attention_maps_per_step, images, save_path="attention_results_mean_max"):
    """
    Visualizes and saves mean and max attention maps overlaid on images.

    For every (step, layer), writes `step_{step}_layer_{layer}_mean.png` and
    `step_{step}_layer_{layer}_max.png` to `save_path`.

    Args:
        attention_maps_per_step: List of attention dicts (one per step); each map is
            read as `entry["all"][layer_idx]`. A leading batch dim of size 1 is
            dropped. Maps that are not then 2-D (num_heads, image_token_num) are
            reported and skipped.
        images: List of (H, W, 3) RGB frames; `images[step]` is the background.
        save_path: Directory to save visualized attention maps (created if missing).
    """
    # Local import: depending on cell order the global `tqdm` may be the module
    # rather than the callable class.
    from tqdm import tqdm as progress_bar

    os.makedirs(save_path, exist_ok=True)

    if not attention_maps_per_step:
        print("No attention maps to visualize.")
        return

    num_steps = len(attention_maps_per_step)
    num_layers = len(attention_maps_per_step[0]["all"])  # Assuming all steps have the same layers

    for step in progress_bar(range(num_steps), desc="Visualizing Mean & Max Attention"):
        image = images[step]

        for layer_idx in range(num_layers):
            attn_map = attention_maps_per_step[step]["all"][layer_idx]

            # Fix batch dimension issue if necessary
            if attn_map.ndim == 3 and attn_map.shape[0] == 1:
                attn_map = attn_map[0]  # now (num_heads, image_token_num)

            if attn_map.ndim != 2:
                print(f"Unexpected shape at step {step}, layer {layer_idx}: {attn_map.shape}")
                continue

            # Compute mean & max across heads -> (image_token_num,)
            mean_attn = attn_map.mean(axis=0)
            max_attn = attn_map.max(axis=0)

            mean_overlay = overlay_attention_on_image(image, mean_attn, title=f"Step {step} - Layer {layer_idx} (Mean)")
            max_overlay = overlay_attention_on_image(image, max_attn, title=f"Step {step} - Layer {layer_idx} (Max)")

            # The overlays are RGB; cv2.imwrite expects BGR.
            cv2.imwrite(os.path.join(save_path, f"step_{step}_layer_{layer_idx}_mean.png"),
                        cv2.cvtColor(mean_overlay, cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(save_path, f"step_{step}_layer_{layer_idx}_max.png"),
                        cv2.cvtColor(max_overlay, cv2.COLOR_RGB2BGR))

def create_attention_video(image_folder, output_video, fps=5, suffix=".png"):
    """
    Creates a video from saved attention map images.

    Frames are taken in natural order, so `step_2` precedes `step_10`. Plain
    `sorted` is lexicographic and would put `step_10` before `step_2`.

    NOTE: this redefines the earlier `create_attention_video(image_folder,
    output_dir, fps)` with a different signature. Calls written for that version
    fail after this cell runs.

    Args:
        image_folder (str): Path to the folder containing attention images.
        output_video (str): Output filename for the video.
        fps (int): Frames per second for the video.
        suffix (str): Only files whose names end with this are used. The default
            ".png" selects every PNG; e.g. "_mean.png" selects a single variant.
    """
    import re

    def natural_key(name):
        # Split out digit runs so numeric parts compare as integers.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]

    frame_names = sorted((name for name in os.listdir(image_folder) if name.endswith(suffix)), key=natural_key)
    if not frame_names:
        print(f"No images found in {image_folder}")
        return

    sample_image = cv2.imread(os.path.join(image_folder, frame_names[0]))
    if sample_image is None:
        print(f"Could not read {frame_names[0]} in {image_folder}; video not written.")
        return
    height, width = sample_image.shape[:2]

    video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    try:
        for img_name in frame_names:
            frame = cv2.imread(os.path.join(image_folder, img_name))
            if frame is None:
                print(f"Warning: Could not read image {img_name}. Skipping.")
                continue
            # VideoWriter does not reliably accept frames of a different size than it was opened with.
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            video_writer.write(frame)
    finally:
        # Always finalise the file, even if a frame raises mid-loop.
        video_writer.release()

    print(f"Video saved as: {output_video}")

def run_visualization(attention_maps_per_step, images):
    """
    Render mean/max attention overlays and encode one video per variant.

    Both variants are written to the same folder, so each video selects its frames
    by filename suffix. Otherwise both videos would interleave mean and max frames
    and be identical.
    """
    visualize_mean_max_attention(attention_maps_per_step, images, save_path="attention_results_mean_max")
    create_attention_video("attention_results_mean_max", "mean_attention_video.mp4", fps=2, suffix="_mean.png")
    create_attention_video("attention_results_mean_max", "max_attention_video.mp4", fps=2, suffix="_max.png")


Visualizing Mean Attention: 100%|██████████| 37/37 [00:01<00:00, 20.91it/s]


In [19]:
# def visualize_max_attention(attention_maps_per_step, images, save_path="attention_results_max"):
#     """
#     Visualizes and saves mean and max attention maps overlaid on images.
#
#     Args:
#         attention_maps_per_step: List of attention maps extracted from OctoModel.
#         images: List of original images used for inference.
#         save_path: Directory to save visualized attention maps.
#     """
#     os.makedirs(save_path, exist_ok=True)
#
#     num_steps = len(attention_maps_per_step)
#     num_layers = len(attention_maps_per_step[0]["all"])  # Assuming all steps have the same layers
#
#     for step in tqdm(range(num_steps), desc="Visualizing Max Attention"):
#         image = images[step]
#
#         for layer_idx in range(num_layers):  # Iterate over layers
#             attn_map = attention_maps_per_step[step]["all"][layer_idx]  # Shape: (num_heads, image_token_num)
#
#             # Fix batch dimension issue if necessary
#             if attn_map.ndim == 3 and attn_map.shape[0] == 1:
#                 attn_map = attn_map[0]  # Remove batch dim (now: (num_heads, image_token_num))
#
#             if attn_map.ndim != 2:
#                 print(f"Unexpected shape at step {step}, layer {layer_idx}: {attn_map.shape}")
#                 continue
#
#             max_attn = attn_map.max(axis=0)  # (image_token_num,)
#             max_overlay = overlay_attention_on_image(image, max_attn, title=f"Step {step} - Layer {layer_idx} (Max)")
#             cv2.imwrite(os.path.join(save_path, f"step_{step}_layer_{layer_idx}_max.png"), max_overlay)
#
# def create_attention_video(image_folder, output_video, fps=5):
#
#     images = sorted([img for img in os.listdir(image_folder) if img.endswith(".png")])
#     if not images:
#         print(f"No images found in {image_folder}")
#         return
#
#     sample_image = cv2.imread(os.path.join(image_folder, images[0]))
#     height, width, _ = sample_image.shape
#
#     video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
#
#     for img_name in images:
#         img_path = os.path.join(image_folder, img_name)
#         frame = cv2.imread(img_path)
#         video_writer.write(frame)
#
#     video_writer.release()
#     print(f"Video saved as: {output_video}")
#
# def run_visualization(attention_maps_per_step, images):
#     visualize_max_attention(attention_maps_per_step, images, save_path="attention_results_max")
#     create_attention_video("attention_results_max", "max_attention_video.mp4", fps=2)
#
# run_visualization(attention_maps_per_step, images)

Visualizing Max Attention: 100%|██████████| 37/37 [00:01<00:00, 20.87it/s]


Video saved as: max_attention_video.mp4


## Extract attention from task-related words to image tokens

In [None]:
# import os
# import numpy as np
# import cv2
# import matplotlib.pyplot as plt
# import seaborn as sns
# from tqdm import tqdm
# from scipy.ndimage import zoom
#
# def extract_task_to_image_attention(attention_maps_per_step, task_token_idx, image_token_indices):
#     """
#     Extracts attention weights from task-related words to image tokens.
#
#     Args:
#         attention_maps_per_step: List of attention maps extracted from OctoModel.
#         task_token_idx: Index of the task-related word token.
#         image_token_indices: Indices of the image tokens.
#
#     Returns:
#         A list of extracted attention scores mapping task words to image tokens.
#     """
#     extracted_attention = []
#
#     for step in tqdm(range(len(attention_maps_per_step)), desc="Extracting Task-Image Attention"):
#         step_attention = []
#         num_layers = len(attention_maps_per_step[step]["all"])
#
#         for layer_idx in range(num_layers):
#             attn_map = attention_maps_per_step[step]["all"][layer_idx]  # Shape: (num_heads, num_tokens)
#
#             # Fix batch dimension issue if necessary
#             if attn_map.ndim == 3 and attn_map.shape[0] == 1:
#                 attn_map = attn_map[0]  # Remove batch dim
#
#             if attn_map.ndim != 2:
#                 print(f"Unexpected shape at step {step}, layer {layer_idx}: {attn_map.shape}")
#                 continue
#
#             # Extract attention from task token to image tokens
#             task_to_image_attention = attn_map[:, task_token_idx, image_token_indices]  # Shape: (num_heads, image_token_count)
#             step_attention.append(task_to_image_attention)
#
#         extracted_attention.append(step_attention)
#
#     return extracted_attention
#
# def plot_task_image_attention(attention_scores, image, title="Task to Image Attention", save_path=None):
#     """
#     Visualizes attention from task tokens to image tokens as a heatmap.
#
#     Args:
#         attention_scores: (num_heads, image_token_count) - Extracted attention scores.
#         image: Original image.
#         title: Title for the visualization.
#         save_path: If provided, saves the heatmap.
#     """
#     avg_attention = attention_scores.mean(axis=0)  # Average across heads
#     grid_size = int(np.sqrt(len(avg_attention)))
#
#     attn_grid = avg_attention.reshape((grid_size, grid_size))
#     attn_grid = zoom(attn_grid, (image.shape[0] // grid_size, image.shape[1] // grid_size), order=1)
#
#     plt.figure(figsize=(8, 8))
#     plt.imshow(image)
#     sns.heatmap(attn_grid, alpha=0.6, cmap='jet', linewidths=0, linecolor='black')
#     plt.title(title)
#     plt.axis("off")
#
#     if save_path:
#         plt.savefig(save_path)
#     plt.show()
#
# def run_task_image_attention_analysis(attention_maps_per_step, images, task_token_idx, image_token_indices, save_dir="task_image_attention"):
#     """
#     Extracts and visualizes attention from task-related words to image tokens.
#
#     Args:
#         attention_maps_per_step: List of attention maps.
#         images: List of input images.
#         task_token_idx: Index of task-related word.
#         image_token_indices: Indices of image tokens.
#         save_dir: Directory to save attention visualizations.
#     """
#     os.makedirs(save_dir, exist_ok=True)
#     extracted_attention = extract_task_to_image_attention(attention_maps_per_step, task_token_idx, image_token_indices)
#
#     for step in range(len(images)):
#         save_path = os.path.join(save_dir, f"step_{step}_task_to_image.png")
#         plot_task_image_attention(extracted_attention[step][-1], images[step], title=f"Step {step} - Task to Image Attention", save_path=save_path)
