In [65]:
import os
import cv2
import numpy as np
import tensorflow_datasets as tfds
import tqdm
import jax
import imageio
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
from octo.model.octo_model import OctoModel

In [66]:
# Load Octo model
print("Loading Octo model checkpoint...")
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")

Loading Octo model checkpoint...


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

In [84]:
# Load multiple episodes from BRIDGE dataset
dataset_split = "train[:100]"  # Adjust number of episodes as needed
builder = tfds.builder_from_directory(builder_dir="gs://gresearch/robotics/bridge/0.1.0/")
ds = builder.as_dataset(split=dataset_split)

In [68]:
WINDOW_SIZE = 2
num_episodes = 20

all_pred_actions = []
all_true_actions = []
all_attention_maps = []
all_images = []

output_dir = "training_videos_1-10"
os.makedirs(output_dir, exist_ok=True)

for episode_idx, episode in enumerate(tfds.as_numpy(ds.take(num_episodes))):
    print(f"Processing Episode {episode_idx+1}/{num_episodes}")

    steps = list(episode["steps"])
    images = [cv2.resize(np.array(step["observation"]["image"]), (256, 256)) for step in steps]
    goal_image = images[-1]
    language_instruction = steps[0]["observation"]["natural_language_instruction"].decode()
    all_images.append(images)

    task = model.create_tasks(goals={"image_primary": goal_image[None]})
    pred_actions, true_actions, attention_maps_per_step = [], [], []

    video_filename = os.path.join(output_dir, f"episode_{episode_idx+1}.mp4")
    video_writer = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*"mp4v"), 5, (256, 256))

    with open(os.path.join(output_dir, f"episode_{episode_idx+1}_instruction.txt"), "w") as f:
        f.write(language_instruction)

    for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):
        input_images = np.stack(images[step:step+WINDOW_SIZE])[None]
        observation = {
            'image_primary': input_images,
            'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool)
        }

        actions, attention_maps = model.sample_actions(
            observation,
            task,
            unnormalization_statistics=model.dataset_statistics["bridge_dataset"]["action"],
            rng=jax.random.PRNGKey(0)
        )

        pred_actions.append(actions[0])
        attention_maps_per_step.append(attention_maps)

        final_window_step = step + WINDOW_SIZE - 1
        true_actions.append(np.concatenate(
            (
                steps[final_window_step]['action']['world_vector'],
                steps[final_window_step]['action']['rotation_delta'],
                np.array(steps[final_window_step]['action']['open_gripper']).astype(np.float32)[None]
            ), axis=-1
        ))

        frame = images[step]
        video_writer.write(frame)

    video_writer.release()
    print(f"Saved training video: {video_filename}")
    print(f"Saved instruction: episode_{episode_idx+1}_instruction.txt")

    all_pred_actions.append(pred_actions)
    all_true_actions.append(true_actions)
    all_attention_maps.append(attention_maps_per_step)




Processing Episode 1/20


100%|██████████| 37/37 [00:24<00:00,  1.51it/s]


Saved training video: training_videos/episode_1.mp4
Saved instruction: episode_1_instruction.txt
Processing Episode 2/20


100%|██████████| 31/31 [00:17<00:00,  1.77it/s]


Saved training video: training_videos/episode_2.mp4
Saved instruction: episode_2_instruction.txt
Processing Episode 3/20


100%|██████████| 31/31 [00:22<00:00,  1.40it/s]


Saved training video: training_videos/episode_3.mp4
Saved instruction: episode_3_instruction.txt
Processing Episode 4/20


100%|██████████| 21/21 [00:14<00:00,  1.48it/s]


Saved training video: training_videos/episode_4.mp4
Saved instruction: episode_4_instruction.txt
Processing Episode 5/20


100%|██████████| 47/47 [00:25<00:00,  1.85it/s]


Saved training video: training_videos/episode_5.mp4
Saved instruction: episode_5_instruction.txt
Processing Episode 6/20


100%|██████████| 26/26 [00:12<00:00,  2.10it/s]


Saved training video: training_videos/episode_6.mp4
Saved instruction: episode_6_instruction.txt
Processing Episode 7/20


100%|██████████| 31/31 [00:13<00:00,  2.28it/s]


Saved training video: training_videos/episode_7.mp4
Saved instruction: episode_7_instruction.txt
Processing Episode 8/20


100%|██████████| 46/46 [00:21<00:00,  2.12it/s]


Saved training video: training_videos/episode_8.mp4
Saved instruction: episode_8_instruction.txt
Processing Episode 9/20


100%|██████████| 21/21 [00:09<00:00,  2.18it/s]


Saved training video: training_videos/episode_9.mp4
Saved instruction: episode_9_instruction.txt
Processing Episode 10/20


100%|██████████| 48/48 [00:23<00:00,  2.08it/s]

Saved training video: training_videos/episode_10.mp4
Saved instruction: episode_10_instruction.txt





In [69]:
def overlay_attention_on_image(image, attn_map, patch_size=16):
    num_heads, num_patches = attn_map.shape
    grid_size = int(np.sqrt(num_patches))
    overlayed_images = []

    for head in range(num_heads):
        attn_grid = attn_map[head].reshape((grid_size, grid_size))
        attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())
        attn_heatmap = zoom(attn_grid, (patch_size, patch_size), order=1)
        attn_colormap = cv2.applyColorMap((attn_heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
        attn_colormap = cv2.resize(attn_colormap, (image.shape[1], image.shape[0]))
        overlayed_image = cv2.addWeighted(image, 0.6, attn_colormap, 0.4, 0)
        overlayed_images.append(overlayed_image)
    return overlayed_images

Interpolation (Smoothing):

zoom(attn_grid, (patch_size, patch_size), order=1) upscales the coarse attention map using bilinear interpolation (order=1), which makes the attention map appear smoother rather than blocky.


In [70]:
def visualize_attention_episodes(all_attention_maps, all_images, save_path="attention_results", last_n_layers=5):
    os.makedirs(save_path, exist_ok=True)

    for episode_idx, attention_maps_per_step in enumerate(all_attention_maps):
        episode_save_path = os.path.join(save_path, f"episode_{episode_idx+1}")
        os.makedirs(episode_save_path, exist_ok=True)  # Create a directory for each episode

        images = all_images[episode_idx]
        num_steps = len(attention_maps_per_step)
        num_layers = len(attention_maps_per_step[0]["all"])

        for step in tqdm.trange(num_steps, desc=f"Visualizing Episode {episode_idx+1}"):
            image = images[step]

            for layer_idx in range(num_layers - last_n_layers, num_layers):
                layer_save_path = os.path.join(episode_save_path, f"layer_{layer_idx}")
                os.makedirs(layer_save_path, exist_ok=True)  # Create sub-folder for each layer in the episode

                attn_map = attention_maps_per_step[step]["all"][layer_idx][0]
                if attn_map.ndim != 2:
                    continue

                overlayed_images = overlay_attention_on_image(image, attn_map)
                num_heads = attn_map.shape[0]
                fig, axes = plt.subplots(1, num_heads, figsize=(20, 5))

                for head_idx in range(num_heads):
                    axes[head_idx].imshow(overlayed_images[head_idx])
                    axes[head_idx].set_title(f"Ep {episode_idx+1}, Step {step}, L{layer_idx}, H{head_idx}")
                    axes[head_idx].axis("off")

                plt.tight_layout()
                plt.savefig(os.path.join(layer_save_path, f"step_{step}.png"))
                plt.close()

visualize_attention_episodes(all_attention_maps, all_images)

Visualizing Episode 1: 100%|██████████| 37/37 [01:03<00:00,  1.72s/it]
Visualizing Episode 2: 100%|██████████| 31/31 [01:37<00:00,  3.15s/it]
Visualizing Episode 3: 100%|██████████| 31/31 [00:54<00:00,  1.75s/it]
Visualizing Episode 4: 100%|██████████| 21/21 [00:37<00:00,  1.76s/it]
Visualizing Episode 5: 100%|██████████| 47/47 [01:23<00:00,  1.78s/it]
Visualizing Episode 6: 100%|██████████| 26/26 [00:50<00:00,  1.95s/it]
Visualizing Episode 7: 100%|██████████| 31/31 [01:05<00:00,  2.12s/it]
Visualizing Episode 8: 100%|██████████| 46/46 [01:17<00:00,  1.68s/it]
Visualizing Episode 9: 100%|██████████| 21/21 [00:51<00:00,  2.46s/it]
Visualizing Episode 10: 100%|██████████| 48/48 [01:18<00:00,  1.63s/it]


In [71]:
import cv2
import os
import tqdm

def create_video_from_images(image_folder, video_path, fps=5):
    """
    Creates a video from images stored in a folder.

    Args:
        image_folder (str): Path to the folder containing images.
        video_path (str): Path to save the output video file.
        fps (int): Frames per second.
    """
    images = sorted([img for img in os.listdir(image_folder) if img.endswith(".png")], key=lambda x: int(x.split('_')[1].split('.')[0]))

    if not images:
        print(f"No images found in {image_folder}, skipping video creation.")
        return

    first_image = cv2.imread(os.path.join(image_folder, images[0]))
    height, width, _ = first_image.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for .mp4
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

    for image in images:
        img_path = os.path.join(image_folder, image)
        frame = cv2.imread(img_path)
        video.write(frame)

    video.release()
    print(f"Video saved: {video_path}")

def generate_videos_for_layers(save_path="attention_results"):
    """
    Generates videos for each layer across all episodes.

    Args:
        save_path (str): The root folder containing episode directories.
    """
    for episode_folder in sorted(os.listdir(save_path)):
        episode_path = os.path.join(save_path, episode_folder)

        if not os.path.isdir(episode_path):
            continue

        print(f"Generating videos for {episode_folder}...")

        for layer_folder in sorted(os.listdir(episode_path)):
            layer_path = os.path.join(episode_path, layer_folder)

            if not os.path.isdir(layer_path):
                continue

            video_path = os.path.join(episode_path, f"{layer_folder}new.mp4")
            create_video_from_images(layer_path, video_path)

generate_videos_for_layers()

Generating videos for episode_1...
Video saved: attention_results/episode_1/layer_10new.mp4
Video saved: attention_results/episode_1/layer_11new.mp4
Video saved: attention_results/episode_1/layer_7new.mp4
Video saved: attention_results/episode_1/layer_8new.mp4
Video saved: attention_results/episode_1/layer_9new.mp4
Generating videos for episode_10...
Video saved: attention_results/episode_10/layer_10new.mp4
Video saved: attention_results/episode_10/layer_11new.mp4
Video saved: attention_results/episode_10/layer_7new.mp4
Video saved: attention_results/episode_10/layer_8new.mp4
Video saved: attention_results/episode_10/layer_9new.mp4
Generating videos for episode_2...
Video saved: attention_results/episode_2/layer_10new.mp4
Video saved: attention_results/episode_2/layer_11new.mp4
Video saved: attention_results/episode_2/layer_7new.mp4
Video saved: attention_results/episode_2/layer_8new.mp4
Video saved: attention_results/episode_2/layer_9new.mp4
Generating videos for episode_3...
Video sa

In [72]:
def combine_videos_for_episode(episode_path, output_video_path, fps=2):
    """
    Combines multiple layer videos into a single episode video by stacking frames vertically.

    Args:
        episode_path (str): Path to the episode folder containing layer videos.
        output_video_path (str): Path to save the final combined episode video.
        fps (int): Frames per second.
    """
    layer_videos = sorted([vid for vid in os.listdir(episode_path) if vid.endswith(".mp4")])

    if not layer_videos:
        print(f"No layer videos found in {episode_path}, skipping.")
        return

    # Open video captures for each layer
    video_caps = [cv2.VideoCapture(os.path.join(episode_path, vid)) for vid in layer_videos]

    # Get frame width & height from the first video
    success, first_frame = video_caps[0].read()
    if not success:
        print(f"Error reading first frame from {layer_videos[0]}")
        return

    frame_height, frame_width, _ = first_frame.shape
    num_layers = len(video_caps)

    # Define output video size (stacking layers vertically)
    combined_height = frame_height * num_layers
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_video = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, combined_height))

    print(f"Creating combined video: {output_video_path}")

    while True:
        frames = []
        for cap in video_caps:
            success, frame = cap.read()
            if not success:
                break
            frames.append(frame)

        if len(frames) != num_layers:  # Stop when any layer runs out of frames
            break

        combined_frame = np.vstack(frames)  # Stack frames vertically
        output_video.write(combined_frame)

    # Release all resources
    for cap in video_caps:
        cap.release()
    output_video.release()
    print(f"Combined episode video saved: {output_video_path}")

def generate_combined_videos(save_path="attention_results"):
    """
    Generates a combined video for each episode by stacking layer videos vertically.

    Args:
        save_path (str): The root folder containing episode directories.
    """
    for episode_folder in sorted(os.listdir(save_path)):
        episode_path = os.path.join(save_path, episode_folder)

        if not os.path.isdir(episode_path):
            continue

        output_video_path = os.path.join(episode_path, f"{episode_folder}_combined.mp4")
        combine_videos_for_episode(episode_path, output_video_path)

# Run the function to generate combined videos
generate_combined_videos("attention_results")

Creating combined video: attention_results/episode_1/episode_1_combined.mp4
Combined episode video saved: attention_results/episode_1/episode_1_combined.mp4
Creating combined video: attention_results/episode_10/episode_10_combined.mp4
Combined episode video saved: attention_results/episode_10/episode_10_combined.mp4
Creating combined video: attention_results/episode_2/episode_2_combined.mp4
Combined episode video saved: attention_results/episode_2/episode_2_combined.mp4
Creating combined video: attention_results/episode_3/episode_3_combined.mp4
Combined episode video saved: attention_results/episode_3/episode_3_combined.mp4
Creating combined video: attention_results/episode_4/episode_4_combined.mp4
Combined episode video saved: attention_results/episode_4/episode_4_combined.mp4
Creating combined video: attention_results/episode_5/episode_5_combined.mp4
Combined episode video saved: attention_results/episode_5/episode_5_combined.mp4
Creating combined video: attention_results/episode_6/e

In [73]:
def evaluate_action_accuracy(pred_actions, true_actions):
    """
    Evaluates the quality of predicted actions using L2 distance and cosine similarity.

    Args:
        pred_actions (list of np.array): List of predicted actions.
        true_actions (list of np.array): List of ground truth actions.

    Returns:
        dict: Dictionary containing L2 distance and cosine similarity per episode.
    """
    results = {}

    for episode_idx, (pred, true) in enumerate(zip(pred_actions, true_actions)):
        pred = np.array(pred)  # Shape: (num_steps, action_horizon, action_dim)
        true = np.array(true)  # Shape: (num_steps, action_dim)

        if pred.ndim == 3:  # If prediction has action horizon, reduce dimension
            pred = pred[:, 0, :]  # Option 1: Take the first action per step
            # pred = pred.mean(axis=1)  # Option 2: Average over horizon

        if pred.shape != true.shape:
            print(f"Warning: Shape mismatch in Episode {episode_idx+1}. Pred: {pred.shape}, True: {true.shape}")
            continue  # Skip problematic episodes

        l2_dist = np.linalg.norm(pred - true, axis=1)  # L2 distance per step
        cos_sim = np.sum(pred * true, axis=1) / (np.linalg.norm(pred, axis=1) * np.linalg.norm(true, axis=1))  # Cosine similarity

        results[f"episode_{episode_idx+1}"] = {
            "l2_mean": np.mean(l2_dist),
            "l2_std": np.std(l2_dist),
            "cosine_mean": np.mean(cos_sim),
            "cosine_std": np.std(cos_sim),
        }

    return results

# Example usage
results = evaluate_action_accuracy(all_pred_actions, all_true_actions)
for ep, metrics in results.items():
    print(f"{ep}: L2 Distance: {metrics['l2_mean']:.3f}, Cosine Similarity: {metrics['cosine_mean']:.3f}")

SUCCESS_THRESHOLD_L2 = 0.1  # Set threshold based on data
SUCCESS_THRESHOLD_COS = 0.9

def evaluate_success(results):
    success_rates = []
    for ep, metrics in results.items():
        l2_score = metrics["l2_mean"]
        cos_score = metrics["cosine_mean"]

        # Weighted decision: if one is good, still count as success
        success = (l2_score < SUCCESS_THRESHOLD_L2) or (cos_score > SUCCESS_THRESHOLD_COS)
        success_rates.append(success)

        print(f"{ep}: {'Success' if success else 'Failure'}")

    overall_success = np.mean(success_rates) * 100
    print(f"Overall success rate: {overall_success:.2f}%")

evaluate_success(results)

episode_1: L2 Distance: 0.123, Cosine Similarity: 0.938
episode_2: L2 Distance: 0.085, Cosine Similarity: 0.600
episode_3: L2 Distance: 0.074, Cosine Similarity: 0.997
episode_4: L2 Distance: 0.100, Cosine Similarity: 0.679
episode_5: L2 Distance: 0.295, Cosine Similarity: 0.534
episode_6: L2 Distance: 0.301, Cosine Similarity: 0.445
episode_7: L2 Distance: 0.068, Cosine Similarity: 0.997
episode_8: L2 Distance: 0.173, Cosine Similarity: 0.628
episode_9: L2 Distance: 0.075, Cosine Similarity: 0.677
episode_10: L2 Distance: 0.273, Cosine Similarity: 0.504
episode_1: Success
episode_2: Success
episode_3: Success
episode_4: Success
episode_5: Failure
episode_6: Failure
episode_7: Success
episode_8: Failure
episode_9: Success
episode_10: Failure
Overall success rate: 60.00%


In [85]:
for episode_idx, episode in enumerate(tfds.as_numpy(ds.take(num_episodes))):
    print(f"Processing Episode {episode_idx+1}/{num_episodes}")
    language_instruction = steps[0]["observation"]["natural_language_instruction"].decode()
    print(language_instruction)


Processing Episode 1/20
Place the cheese wedge on the front right side of the table
Processing Episode 2/20
Place the cheese wedge on the front right side of the table
Processing Episode 3/20
Place the cheese wedge on the front right side of the table
Processing Episode 4/20
Place the cheese wedge on the front right side of the table
Processing Episode 5/20
Place the cheese wedge on the front right side of the table
Processing Episode 6/20
Place the cheese wedge on the front right side of the table
Processing Episode 7/20
Place the cheese wedge on the front right side of the table
Processing Episode 8/20
Place the cheese wedge on the front right side of the table
Processing Episode 9/20
Place the cheese wedge on the front right side of the table
Processing Episode 10/20
Place the cheese wedge on the front right side of the table
Processing Episode 11/20
Place the cheese wedge on the front right side of the table
Processing Episode 12/20
Place the cheese wedge on the front right side of

Before, we just do offline evaluation. Now, run on simpler_env to so that we can see the true simulation and results.

In [74]:
# import os
# import cv2
# import numpy as np
# import jax
# import tqdm
# import imageio
# import SimplerEnv  # Import SimplerEnv ?????????????
# from octo.model.octo_model import OctoModel
# from scipy.ndimage import zoom
# import matplotlib.pyplot as plt
#
# # Load Octo Model
# print("Loading Octo model checkpoint...")
# model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")
#
# # Initialize SimplerEnv
# env = SimplerEnv.make("SimplerEnv-v0")  # Change to match the correct environment name
# obs, info = env.reset()
#
# WINDOW_SIZE = 2  # Number of past frames as input
# fps = 5  # Video frame rate
# output_video = "simpler_env_simulation.mp4"
# gif_output = "simpler_env_simulation.gif"
# video_frames = []
# attention_frames = []
#
# # Initialize observation history
# observation_history = [obs] * WINDOW_SIZE
#
# def overlay_attention_on_image(image, attn_map, patch_size=16):
#     """
#     Overlays attention heatmap on an input image.
#     """
#     num_heads, num_patches = attn_map.shape
#     grid_size = int(np.sqrt(num_patches))
#     overlayed_images = []
#
#     for head in range(num_heads):
#         attn_grid = attn_map[head].reshape((grid_size, grid_size))
#         attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())
#         attn_heatmap = zoom(attn_grid, (patch_size, patch_size), order=1)
#         attn_colormap = cv2.applyColorMap((attn_heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
#         attn_colormap = cv2.resize(attn_colormap, (image.shape[1], image.shape[0]))
#         overlayed_image = cv2.addWeighted(image, 0.6, attn_colormap, 0.4, 0)
#         overlayed_images.append(overlayed_image)
#
#     return overlayed_images
#
# # Run Simulation Loop
# for step in tqdm.trange(100):  # Run for 100 timesteps
#     input_images = np.stack(observation_history)[None]
#
#     # Format input for Octo model
#     observation = {
#         'image_primary': input_images,
#         'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool)
#     }
#
#     # Predict action and attention map using Octo
#     actions, attention_maps = model.sample_actions(
#         observation,
#         None,  # No explicit goal input
#         unnormalization_statistics=None,  # Use default scaling
#         rng=jax.random.PRNGKey(0)
#     )
#
#     # Take action in SimplerEnv
#     obs, reward, done, truncated, info = env.step(actions[0])
#
#     # Update observation history
#     observation_history.pop(0)
#     observation_history.append(obs)
#
#     # Render environment and store frame
#     frame = env.render(mode="rgb_array")  # Ensure SimplerEnv supports rendering
#     video_frames.append(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
#
#     # Overlay attention maps on the frame
#     attn_map = attention_maps["all"][-1][0]  # Last layer, first head
#     attn_overlay = overlay_attention_on_image(frame, attn_map)[0]  # Use first head
#     attention_frames.append(attn_overlay)
#
#     if done:
#         print(f"Episode finished at step {step}")
#         break
#
# env.close()
#
# # Save video without attention maps
# imageio.mimsave(output_video, video_frames, fps=fps)
# print(f"✅ Saved simulation video: {output_video}")
#
# # Save video with attention maps overlay
# imageio.mimsave(gif_output, attention_frames, fps=fps)
# print(f"✅ Saved attention visualization video: {gif_output}")
