In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to use only GPU 2
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import copy
import os.path
import pickle
from os.path import join
import sys

import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from PIL import ImageDraw
from matplotlib import cm

import jax
import jax.experimental
import jax.numpy as jnp
from PIL import Image

# Import from openpi package (already in Python path)
from openpi import EXP_DATA_PATH
from openpi.models.pi0 import Pi0
import openpi.models.model as _model
import openpi.training.config as _config
from openpi.shared.download import maybe_download

# Import from scripts/text_latent.py
import sys
scripts_path = os.path.abspath('.')
if scripts_path not in sys.path:
    sys.path.append(scripts_path)
    
from text_latent import Checkpoint, Args, create_dataloader


def restore_img(img) -> Image.Image:
    """Convert an image array scaled to [-1, 1] back into an 8-bit PIL image.

    This inverts the mapping ``x -> x / 255 * 2 - 1``.

    Args:
        img: Array-like image with values nominally in [-1, 1]. Assumed to be
            laid out as ``(H, W, C)`` or ``(H, W)`` — TODO confirm against the
            data loader.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 pixel array.
    """
    # Convert to float64 before the arithmetic so low-precision inputs
    # (e.g. bfloat16) are not rounded while shifting.
    float_img = (np.asarray(img, dtype=np.float64) + 1.0) / 2.0 * 255.0
    # Clip before casting: values slightly outside [-1, 1] would otherwise
    # wrap around when converted to uint8 (e.g. 256 -> 0).
    int_img = np.clip(float_img, 0.0, 255.0).astype(np.uint8)
    return Image.fromarray(int_img)


def normalize(vectors):
    """Scale every vector along the last axis to unit L2 length.

    Args:
        vectors: Array whose last axis holds the vector components.

    Returns:
        Array of the same shape, with each last-axis slice divided by its L2
        norm. No epsilon is applied, so zero vectors produce NaN.
    """
    lengths = jnp.linalg.norm(vectors, axis=-1, keepdims=True)
    unit_vectors = vectors / lengths
    return unit_vectors


def patch_idx_to_hw(patch_idx, num_patch=16, patch_size=14):
    """Map a flat, row-major patch index to the pixel offset of that patch.

    Args:
        patch_idx: Flat index into a grid with ``num_patch`` patches per row.
        num_patch: Number of patches per grid row.
        patch_size: Side length of one patch in pixels.

    Returns:
        ``(row_px, col_px)`` as Python ints: the top-left pixel of the patch.
    """
    grid_row, grid_col = patch_idx // num_patch, patch_idx % num_patch
    offsets = (grid_row * patch_size, grid_col * patch_size)
    return tuple(int(offset) for offset in offsets)


def draw_box_on_image(image, top_left_corner, box_width=14, box_height=14, color='red', thickness=1):
    """Draw a rectangular outline on ``image`` in place and return it.

    Args:
        image: PIL image to draw on; it is modified in place.
        top_left_corner: ``(row, col)`` pixel coordinates of the box's top-left
            corner, the same order ``patch_idx_to_hw`` returns.
        box_width: Width of the box in pixels (14 matches the default patch size).
        box_height: Height of the box in pixels.
        color: Outline colour accepted by ``ImageDraw``.
        thickness: Outline width in pixels.

    Returns:
        The same ``image`` object, to allow chaining.
    """
    top, left = top_left_corner[0], top_left_corner[1]
    # ImageDraw.rectangle includes both corner pixels, so a box spanning
    # ``box_width`` pixels ends at ``left + box_width - 1``. Using
    # ``left + box_width`` would overlap the neighbouring patch by one pixel.
    # The max() keeps x1 >= x0 and y1 >= y0 (required by Pillow) for empty sizes.
    right = max(left, left + box_width - 1)
    bottom = max(top, top + box_height - 1)

    ImageDraw.Draw(image).rectangle((left, top, right, bottom), outline=color, width=thickness)
    return image

In [None]:
# Create the experiment-data root if it is missing.
if not os.path.exists(EXP_DATA_PATH):
    os.makedirs(EXP_DATA_PATH, exist_ok=True)
    print(f"Created directory: {EXP_DATA_PATH}")

# Load the per-task text latents for the LIBERO-Goal tasks (task ids 10..39).
# Each pickle must contain a "hidden_states_avg" entry; the files are generated
# by text_latent.py. NOTE: pickle.load can execute arbitrary code, so only load
# files you generated yourself.
text_latent_path = join(EXP_DATA_PATH, "pi0")
libero_goal_task_id = list(range(10, 40))
text_latents = {}

if os.path.exists(text_latent_path):
    # Don't name the loop variable `id`: in a notebook that would shadow the
    # builtin for every later cell in the kernel.
    for goal_task_id in libero_goal_task_id:
        file_path = join(text_latent_path, f"avg_states_{goal_task_id}_{goal_task_id + 1}_frame_0_119.pkl")
        if os.path.exists(file_path):
            with open(file_path, "rb") as f:
                data = pickle.load(f)
            text_latents[goal_task_id] = data["hidden_states_avg"]
        else:
            print(f"Warning: Text latent file not found: {file_path}")
else:
    print(f"Warning: Text latent directory not found: {text_latent_path}")
    print("Please run text_latent.py first to generate text latents")

# Load the pi0 LIBERO checkpoint (config, data config, params) and build a
# jitted prefix-embedding function used by the visualization cells.
policy = "pi0"
args = Args()
args.policy = Checkpoint(config=f"{policy}_libero", dir=f"s3://openpi-assets/checkpoints/{policy}_libero")

try:
    train_config = _config.get_config(args.policy.config)
    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    ckpt_dir = maybe_download(args.policy.dir)

    # Load the model with its parameters restored in bfloat16.
    model: Pi0 = train_config.model.load(_model.restore_params(ckpt_dir / "params", dtype=jnp.bfloat16))
    embed_prefix_jit = jax.jit(model.embed_prefix)

    def embed_prefix(observation):
        """Preprocess ``observation`` (train=False) and embed it with the jitted prefix embedder."""
        observation = _model.preprocess_observation(None, observation, train=False)
        return embed_prefix_jit(observation)

    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    # Reset every handle this cell defines, including embed_prefix. Otherwise a
    # failed re-run could leave a stale embedder from an earlier kernel state.
    model = None
    embed_prefix_jit = None
    embed_prefix = None

In [None]:
# Load one demonstration episode per LIBERO-Goal task for visualization.
# Later cells index all_obs by position (all_obs[task_id - 10]), so every task
# in range(10, 40) gets exactly one entry, in order. A task whose data fails to
# load gets an empty placeholder instead of being skipped; skipping it would
# shift every later task onto the wrong index.
all_obs = []
num_loaded_tasks = 0

# Only proceed if model is loaded successfully
if model is not None and embed_prefix_jit is not None:
    for task_id in range(10, 40):
        try:
            task_range = (task_id, task_id + 1)
            episode_to_use_for_collection = 1
            data_loader, dataset_meta = create_dataloader(train_config, data_config, ckpt_dir, task_range, 1,
                                                          episode_to_use_for_collection, True)
            obs = [_model.Observation.from_dict(element) for element in data_loader]
            base_images = [step_obs.images["base_0_rgb"][0] for step_obs in obs]
            print(f"Task {task_id}: {dataset_meta.tasks[task_id]}")

            def decode(obs):
                """Decode a list of token ids (or an Observation's first prompt) to text."""
                if not isinstance(obs, list):
                    obs = obs.tokenized_prompt[0].tolist()
                return data_loader.torch_loader.dataset._transform.transforms[-1].tokenizer._tokenizer.decode(obs)

            # Decode each valid prompt token on its own so the per-token pieces
            # can be inspected. As before, this assumes the mask marks a prefix.
            num_prompt_tokens = int(obs[0].tokenized_prompt_mask.sum())
            prompt_ids = obs[0].tokenized_prompt[0][:num_prompt_tokens].tolist()
            prompt = [decode([token_id]) for token_id in prompt_ids]
            print(f"Tokenized prompt: {prompt}")
            print(f"Token ids: {prompt_ids}")

            all_obs.append((obs, len(prompt), prompt, len(obs)))
            num_loaded_tasks += 1

        except Exception as e:
            print(f"Error loading data for task {task_id}: {e}")
            # Empty entry: consumers see zero observations and report "out of
            # range" instead of reading a neighbouring task's data.
            all_obs.append(([], 0, [], 0))

    print(f"Successfully loaded {num_loaded_tasks} tasks")
else:
    print("Model not loaded. Cannot proceed with data loading.")

In [None]:
# Attention visualization functions
def _blend_patch_heatmap(image_uint8, patch_scores, alpha=0.3, grid_size=16):
    """Overlay a jet-coloured heatmap of per-patch scores on a uint8 RGB image.

    Args:
        image_uint8: ``(H, W, 3)`` uint8 image.
        patch_scores: ``grid_size * grid_size`` scores in row-major patch order.
        alpha: Weight of the heatmap in the blend (the image gets ``1 - alpha``).
        grid_size: Number of patches per side of the square patch grid.

    Returns:
        ``(H, W, 3)`` uint8 blended image.
    """
    # Explicit float32: the embeddings may be bfloat16, which cv2 cannot resize.
    score_map = np.asarray(patch_scores, dtype=np.float32).reshape(grid_size, grid_size)
    score_range = score_map.max() - score_map.min()
    # Guard against a constant map, which would divide by zero and give NaN.
    if score_range > 0:
        norm_map = (score_map - score_map.min()) / score_range
    else:
        norm_map = np.zeros_like(score_map)
    height, width = image_uint8.shape[:2]
    # cv2.resize takes (width, height). Match the image so addWeighted gets equal shapes.
    resized_map = cv2.resize(norm_map, (width, height), interpolation=cv2.INTER_LINEAR)
    # cm.jet returns RGBA floats in [0, 1]. Keep the RGB channels to match the image.
    heatmap_rgb = (cm.jet(resized_map)[:, :, :3] * 255).astype(np.uint8)
    return cv2.addWeighted(heatmap_rgb, alpha, image_uint8, 1 - alpha, 0)


def get_max_sim_patch(task_start, task_id, all_obs, obs_idx=0):
    """Overlay text-latent / image-patch similarity on one observation's base image.

    Steps:
    1. Embed observation ``obs_idx`` of task ``task_id`` with the global
       ``embed_prefix`` and take the first 256 prefix embeddings as image patches.
    2. Cosine-compare the task's stored text latent (from the global
       ``text_latents``) with every patch.
    3. Take the per-patch maximum over all latent rows except the first.
    4. Blend the resulting 16x16 map onto the base camera image as a jet heatmap.

    Args:
        task_start: Task id stored at ``all_obs[0]``.
        task_id: Task to visualize; must be a key of ``text_latents``.
        all_obs: List of ``(observations, prompt_len, prompt_tokens, num_obs)``
            tuples, one per consecutive task id starting at ``task_start``.
        obs_idx: Index of the observation (timestep) within the episode.

    Returns:
        A ``PIL.Image.Image`` with the overlay, or ``None`` if the task or
        observation is unavailable or any step fails. The reason is printed.
    """
    if not text_latents or task_id not in text_latents:
        print(f"Text latent not found for task {task_id}")
        return None

    task_pos = task_id - task_start
    # Reject negative offsets too: Python indexing would wrap them to another task.
    if not 0 <= task_pos < len(all_obs):
        print(f"Task {task_id} not found in all_obs")
        return None

    task_obs = all_obs[task_pos]
    obs = task_obs[0]
    prompt_len = task_obs[1]

    if not 0 <= obs_idx < len(obs):
        print(f"Observation index {obs_idx} out of range for task {task_id}")
        return None

    try:
        # Copy only the observation being embedded, not the whole episode, so
        # the cached all_obs entry cannot be altered by preprocessing.
        observation = copy.deepcopy(obs[obs_idx])

        # Get image and text embeddings
        embeddings, _, _ = embed_prefix(observation)
        image_embedding = embeddings[0, :256]
        # NOTE(review): assumes 3 images x 256 tokens precede the prompt and skips
        # the prompt's first 3 positions. Confirm against text_latent.py.
        text_latent = text_latents[task_id][:, 256 * 3 + 3:256 * 3 + prompt_len].sum(-2)

        # Cosine similarity: one row per leading entry of the text latent
        # (presumably one per transformer layer), one column per image patch.
        similarity = jnp.dot(normalize(text_latent), normalize(image_embedding).T)
        # Drop the first row, then keep the strongest remaining row per patch.
        patch_similarity = jnp.max(similarity[1:], axis=0)

        # Undo the [-1, 1] image scaling. Clip so rounding cannot wrap uint8.
        float_img = (np.asarray(observation.images["base_0_rgb"][0], dtype=np.float64) + 1) / 2 * 255
        original_image = np.clip(float_img, 0, 255).astype(np.uint8)
        blended_image = _blend_patch_heatmap(original_image, patch_similarity, alpha=0.3)
        return Image.fromarray(blended_image)

    except Exception as e:
        print(f"Error in get_max_sim_patch for task {task_id}, obs {obs_idx}: {e}")
        return None


def plot_dynamics(task_id, task_start=10, max_timesteps=100, max_images=20):
    """Show an interactive slider over similarity heatmaps for one task's episode.

    Args:
        task_id: Task to visualize.
        task_start: Task id stored at ``all_obs[0]`` (the loading cell starts at 10).
        max_timesteps: Only the first ``max_timesteps`` observations are
            considered, to bound memory and compute.
        max_images: Maximum number of heatmaps to generate. Timesteps are
            sampled with a uniform stride so the count stays within this cap.

    Reads the globals ``all_obs`` and ``get_max_sim_patch``; returns ``None``.
    """
    task_pos = task_id - task_start
    # Negative offsets would wrap around to another task's entry.
    if not 0 <= task_pos < len(all_obs):
        print(f"Task {task_id} not found in loaded observations")
        return

    task_obs = all_obs[task_pos]
    print(f"Task {task_id}: {task_obs[2]}")

    num_timesteps = min(task_obs[-1], max_timesteps)  # Limit to avoid memory issues
    # Ceil division, so that at most max_images timesteps are sampled.
    step_size = max(1, -(-num_timesteps // max(1, max_images)))

    # Store each image with its real timestep. If one timestep fails, the
    # frames after it must keep their own timestep labels.
    frames = []
    for t in range(0, num_timesteps, step_size):
        img = get_max_sim_patch(task_start, task_id, all_obs, obs_idx=t)
        if img is not None:
            frames.append((t, img))
        else:
            print(f"Skipping timestep {t} due to error")

    if not frames:
        print("No valid images generated")
        return

    print(f"Generated {len(frames)} attention visualizations")

    def show_image(index):
        """Render the heatmap stored at slider position ``index``."""
        if 0 <= index < len(frames):
            timestep, image = frames[index]
            plt.figure(figsize=(8, 8))
            plt.imshow(image)
            plt.title(f"Task {task_id} - Timestep {timestep}")
            plt.axis('off')
            plt.show()
        else:
            print(f"Invalid index: {index}")

    # Create interactive slider
    slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(frames) - 1,
        step=1,
        description='Timestep:'
    )
    widgets.interact(show_image, index=slider)

In [None]:
# Similarity heatmaps over the episode of LIBERO-Goal task 22.
plot_dynamics(task_id=22)

In [None]:
# Similarity heatmaps over the episode of LIBERO-Goal task 29.
plot_dynamics(task_id=29)

In [None]:
# Similarity heatmaps over the episode of LIBERO-Goal task 23.
plot_dynamics(task_id=23)

In [None]:
# Similarity heatmaps over the episode of LIBERO-Goal task 27.
plot_dynamics(task_id=27)

In [None]:
# Smoke test: report what was loaded, then render one heatmap (task 22, first frame).
print("=== Testing VLM Attention Visualization ===")
print(f"EXP_DATA_PATH: {EXP_DATA_PATH}")
print(f"Text latents loaded for {len(text_latents)} tasks")
print(f"Observation data loaded for {len(all_obs) if 'all_obs' in globals() else 0} tasks")

print("✓ Model loaded successfully" if model is not None else "✗ Model not loaded")

if not text_latents:
    print("✗ No text latents loaded")
else:
    print(f"✓ Text latents available for tasks: {sorted(text_latents.keys())}")

# Render a single visualization only when every component is available.
if model is None or not text_latents or 'all_obs' not in globals() or not all_obs:
    print("\n✗ Cannot test attention visualization - missing components")
else:
    print("\n=== Testing attention visualization for Task 22 ===")
    try:
        test_img = get_max_sim_patch(10, 22, all_obs, obs_idx=0)
        if test_img is None:
            print("✗ Attention visualization test failed")
        else:
            print("✓ Attention visualization test successful!")
            plt.figure(figsize=(8, 8))
            plt.imshow(test_img)
            plt.title("Test: Task 22 Attention Visualization")
            plt.axis('off')
            plt.show()
    except Exception as e:
        print(f"✗ Attention visualization test error: {e}")