In [1]:
import sys
sys.path.append('../')

In [2]:

import os
import hydra
import numpy as np
from tqdm import tqdm
from moviepy.editor import ImageSequenceClip
from PIL import Image
from matplotlib import pyplot as plt
import torch
import dill
from torch.utils.data import DataLoader

from diffusion_policy.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset # Need this to operate ReplayBuffer.copy_from_path. If not it will raise a codec error
from diffusion_policy.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.common.pytorch_util import dict_apply

In [3]:
# Lift / proficient-human image dataset: the source HDF5 file and (presumably) its
# pre-converted zarr archive, which is what the replay buffer is read from.
dataset_path = os.path.expanduser('../data/robomimic/datasets/lift/ph/image_abs.hdf5')
zarr_path = os.path.expanduser('../data/robomimic/datasets/lift/ph/image_abs.hdf5.zarr.zip')

# Raw episodes (camera images included) used later to iterate episodes and their lengths.
replay_buffer = ReplayBuffer.copy_from_path(zarr_path, keys=None)

In [4]:
# Action/observation shape metadata for the image dataset. RGB entries are (C, H, W).
# Key order is preserved exactly as originally written in case consumers iterate it.
shape_meta = dict(
    action=dict(shape=[7]),
    obs=dict(
        object=dict(shape=[10]),
        agentview_image=dict(shape=[3, 84, 84], type='rgb'),
        robot0_eef_pos=dict(shape=[3]),
        robot0_eef_quat=dict(shape=[4]),
        robot0_eye_in_hand_image=dict(shape=[3, 84, 84], type='rgb'),
        robot0_gripper_qpos=dict(shape=[2]),
    ),
)

In [None]:
# Image dataset over the HDF5 file above. horizon/pad_before/pad_after control the
# per-sample window; the video cell below assumes this yields T+1 samples per episode
# (NOTE(review): confirm against the dataset's sequence sampler).
dataset = RobomimicReplayImageDataset(
    shape_meta=shape_meta,
    dataset_path=dataset_path,
    # windowing
    horizon=2,
    pad_before=1,
    pad_after=1,
    # presumably no validation hold-out, so every episode is iterated
    val_ratio=0.0,
    seed=42,
    rotation_rep='rotation_6d',
    use_legacy_normalizer=False,
)

In [None]:
# Low-dim counterpart of the image dataset, same windowing, used to sanity-check that the
# concatenated low-dim keys of the image dataset match this dataset's flat 'obs' vector.
# obs_keys order matters: it defines the layout of the concatenated observation.
low_dim_dataset = RobomimicReplayLowdimDataset(
    obs_keys=['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'],
    dataset_path="../data/robomimic/datasets/lift/ph/low_dim_abs.hdf5",
    abs_action=True,
    horizon=2,
    pad_before=1,
    pad_after=1,
)

In [7]:
# Unshuffled, single-sample loaders so batches come out in dataset order.
image_dataloader = DataLoader(dataset, shuffle=False, batch_size=1)
low_dim_dataloader = DataLoader(low_dim_dataset, shuffle=False, batch_size=1)

In [8]:
# First batch of each loader (both unshuffled, so these correspond to the same window).
image_sample = next(iter(image_dataloader))
low_dim_sample = next(iter(low_dim_dataloader))


In [None]:
# Reference low-dim observation to compare against the image dataset's low-dim keys.
low_dim_obs_np = low_dim_sample['obs'].numpy()
print(low_dim_obs_np)

In [None]:
# Rebuild the flat low-dim observation from the image dataset's sample so it can be compared
# with `low_dim_sample['obs']` printed above. Use float32, as the policy-inference cell does:
# float16 keeps only ~3 significant digits, which hides or fakes differences in positions
# and quaternions.
n_obs_dict = {
    'obs': np.concatenate(
        [image_sample['obs'][key]
         for key in ('object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos')],
        axis=-1,
    ).astype(np.float32)
}
print(n_obs_dict['obs'])

In [6]:
# Restore the trained low-dim diffusion policy from its training-run directory.
run_dir = "../data/outputs/lift_lowdim_ph_reproduction/horizon_16/2025.03.11/10.57.22_train_diffusion_unet_lowdim_lift_lowdim_transformer_128"
checkpoint = os.path.join(run_dir, "checkpoints", "epoch=0300-test_mean_score=1.000.ckpt")
output_dir = os.path.join(run_dir, "dummy")

# SECURITY: torch.load with dill unpickles arbitrary objects; only load trusted checkpoints.
# The context manager closes the file handle (the original `open(...)` was never closed).
with open(checkpoint, 'rb') as checkpoint_file:
    payload = torch.load(checkpoint_file, pickle_module=dill)
cfg = payload['cfg']
# Override the noise-scheduler class recorded in the checkpoint config before the
# workspace (and its policy) is instantiated from it.
cfg.policy.noise_scheduler._target_ = 'diffusion_policy.schedulers.scheduling_ddpm.DDPMScheduler'

cls = hydra.utils.get_class(cfg._target_)
workspace: BaseWorkspace = cls(cfg, output_dir=output_dir)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# get policy from workspace
policy = workspace.model

In [None]:
# For every episode, run the policy's `kl_divergence_drop` on each dataset window and write
# a video: camera views (agentview | eye-in-hand) next to a growing scatter plot of the score.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
policy.to(device)
policy.eval()

video_dir = "../data/robomimic/datasets/lift/ph/videos_with_spatial_attention"
os.makedirs(video_dir, exist_ok=True)

# Low-dim keys concatenated, in this order, into the flat observation the policy consumes.
LOWDIM_OBS_KEYS = ('object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos')
VIDEO_FPS = 30


def _to_uint8_image(chw):
    """Convert a (C, H, W) float tensor to an (H, W, C) uint8 array.

    Values are assumed to be in [0, 1] (the original code scaled them by 255). Rounds and
    clips rather than truncating, so e.g. 0.99999 * 255 becomes 255 instead of 254.
    """
    hwc = chw.permute(1, 2, 0).numpy() * 255
    return np.clip(np.rint(hwc), 0, 255).astype(np.uint8)


def _render_score_frames(scores, figsize=(4, 3)):
    """Return one RGB uint8 frame per step t showing a scatter of scores[0..t]."""
    fig, ax = plt.subplots(figsize=figsize)
    frames = []
    try:
        for t, score in enumerate(scores):
            ax.scatter(t, score, color='red', s=30)
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8) and always
            # has the true pixel size; np.array copies because the buffer is reused per draw.
            frames.append(np.array(fig.canvas.buffer_rgba())[..., :3])
    finally:
        # Without this, one figure leaked per episode and all were shown inline at cell end.
        plt.close(fig)
    return frames


def _fit_to_canvas(frame, height, width):
    """Resize `frame` to `height` (keeping aspect ratio) and center it in a height x width image.

    Narrower results are padded with black on both sides; wider results are center-cropped.
    """
    frame_height, frame_width = frame.shape[:2]
    new_width = int(height * (frame_width / frame_height))
    resized = np.array(Image.fromarray(frame).resize((new_width, height)))
    if new_width <= width:
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        x_offset = (width - new_width) // 2
        canvas[:, x_offset:x_offset + new_width] = resized
        return canvas
    # Center crop. The original used (-x_offset) // 2, which trimmed only half of the left
    # margin and therefore cropped off-center.
    crop_start = (new_width - width) // 2
    return resized[:, crop_start:crop_start + width]


# NOTE(review): a single iterator is shared across episodes, so this assumes the unshuffled
# dataset yields exactly T+1 windows per episode, in replay_buffer episode order.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
iterator = iter(dataloader)

for i in tqdm(range(replay_buffer.n_episodes)):
    epi = replay_buffer.get_episode(i)
    T, H, W, C = epi['agentview_image'].shape
    n_windows = T + 1
    tqdm.write(f"episode {i}| T: {T}")

    imgs = np.zeros((n_windows, H, 2 * W, C), dtype=epi['agentview_image'].dtype)
    spatial_attention = []
    for t in tqdm(range(n_windows), desc="Getting Image and Spatial Attention", leave=False):
        obs = next(iterator)['obs']
        # [0, 0]: batch element 0, first observation step of the window.
        imgs[t, :, :W] = _to_uint8_image(obs['agentview_image'][0, 0])
        imgs[t, :, W:] = _to_uint8_image(obs['robot0_eye_in_hand_image'][0, 0])
        n_obs_dict = {
            'obs': np.concatenate([obs[key] for key in LOWDIM_OBS_KEYS], axis=-1).astype(np.float32)
        }
        # device transfer
        obs_dict = dict_apply(n_obs_dict, lambda x: torch.from_numpy(x).to(device=device))
        with torch.no_grad():
            spatial_attention.append(policy.kl_divergence_drop(obs_dict).item())
    spatial_attention = np.array(spatial_attention)

    graph_frames = _render_score_frames(spatial_attention)
    graph_height, graph_width = graph_frames[0].shape[:2]
    combined_frames = [
        np.hstack([_fit_to_canvas(frame, graph_height, graph_width), graph])
        for frame, graph in zip(imgs, graph_frames)
    ]

    video_path = os.path.join(video_dir, f"episode_{i}.mp4")
    clip = ImageSequenceClip(combined_frames, fps=VIDEO_FPS)
    try:
        clip.write_videofile(video_path, codec='libx264')
    finally:
        clip.close()

In [None]:
# Write plain side-by-side camera videos (agentview | eye-in-hand) for every episode.
# Two fixes versus the previous version of this cell:
# - It called the undefined get_spatial_attention_from_episode (NameError on a fresh kernel)
#   and never used the result.
# - It wrote episode_{i}.mp4 into videos_with_spatial_attention, overwriting the attention
#   videos produced by the previous cell. It now writes to a separate directory.
video_dir = "../data/robomimic/datasets/lift/ph/videos"
VIDEO_SCALE = 3  # integer upscale factor applied to each frame's height and width

os.makedirs(video_dir, exist_ok=True)

for i in tqdm(range(replay_buffer.n_episodes)):
    epi = replay_buffer.get_episode(i)
    video_path = os.path.join(video_dir, f"episode_{i}.mp4")

    # Assumed (T, H, W, C) uint8 per camera; concatenating along width gives (T, H, 2W, C).
    imgs = np.concatenate([epi['agentview_image'], epi['robot0_eye_in_hand_image']], axis=2)

    h, w = imgs.shape[1:3]
    imgs_scaled = [
        np.array(Image.fromarray(img).resize((w * VIDEO_SCALE, h * VIDEO_SCALE)))
        for img in imgs
    ]

    clip = ImageSequenceClip(imgs_scaled, fps=30)
    try:
        clip.write_videofile(video_path, codec='libx264')
    finally:
        clip.close()