In [1]:
# Objective: compute the attention value at every timestep and save the values as a separate dataset
# 1. Load the dataset
# 2. Load the model
# 3. Compute the attention values
# 4. Save the attention values: add the keys "attention" and "previous_obs" to the dataset

In [1]:
from copy import deepcopy
import h5py

In [2]:
# Source low-dim Lift dataset, and the sibling file that will hold the attention values.
# (The copy itself is made with shutil.copy further down in the notebook.)
dataset_path = "../data/robomimic/datasets/lift/ph/low_dim_abs.hdf5"
new_dataset_path = dataset_path.replace("low_dim_abs.hdf5", "low_dim_abs_with_attention.hdf5")

In [None]:
import h5py
import numpy as np
import shutil

# Paths: the original low-dim Lift dataset and the copy that will receive attention values.
dataset_path = "../data/robomimic/datasets/lift/ph/low_dim_abs.hdf5"
new_dataset_path = "../data/robomimic/datasets/lift/ph/low_dim_abs_with_attention.hdf5"

# Copy the entire file first so every group, dataset and attribute is preserved.
shutil.copy(dataset_path, new_dataset_path)

# Inspect the copy read-only. Observation keys are printed once per demo; printing them
# inside a per-sample loop would just repeat the identical line num_samples times.
# Iterating the group's items uses the real demo names instead of assuming demo_0..demo_{n-1}.
with h5py.File(new_dataset_path, 'r') as f:
    for demo_name, demo in f['data'].items():
        num_samples = demo.attrs['num_samples']
        print(demo_name, num_samples, list(demo['obs'].keys()))

In [None]:
# open file
f = h5py.File(dataset_path, "r")

# each demonstration is a group under "data"
demos = list(f["data"].keys())
num_demos = len(demos)
print("hdf5 file {} has {} demonstrations".format(dataset_path, num_demos))

In [None]:
# each demonstration is a group under "data"
demos_list = list(demos.keys())
num_demos = len(demos_list)
print("hdf5 file {} has {} demonstrations".format(dataset_path, num_demos))

In [1]:
import os
import sys
sys.path.append("../")
import hydra
import numpy as np
from tqdm import tqdm
from moviepy.editor import ImageSequenceClip
from PIL import Image
from matplotlib import pyplot as plt
import torch
import dill
from torch.utils.data import DataLoader
import h5py
import shutil

from diffusion_policy.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.common.pytorch_util import dict_apply

In [None]:
# Image-based Lift dataset (the zarr-cache alternative is kept for reference).
# zarr_path = os.path.expanduser('../data/robomimic/datasets/lift/ph/image_abs.hdf5.zarr.zip')
# replay_buffer = ReplayBuffer.copy_from_path(zarr_path, keys=None)
dataset_path = os.path.expanduser('../data/robomimic/datasets/lift/ph/image_abs.hdf5')

# Shapes of the action and of each observation key to load. Entries tagged type='rgb'
# are images; the rest are flat vectors. Key order is kept identical to the original listing.
shape_meta = dict(
    action=dict(shape=[7]),
    obs=dict(
        object=dict(shape=[10]),
        agentview_image=dict(shape=[3, 84, 84], type='rgb'),
        robot0_eef_pos=dict(shape=[3]),
        robot0_eef_quat=dict(shape=[4]),
        robot0_eye_in_hand_image=dict(shape=[3, 84, 84], type='rgb'),
        robot0_gripper_qpos=dict(shape=[2]),
    ),
)

# 2-step windows padded by one step on each side; val_ratio=0.0 keeps every episode in
# this dataset. The loader is unshuffled so batches follow the file's demo order.
dataset = RobomimicReplayImageDataset(
    shape_meta=shape_meta,
    dataset_path=dataset_path,
    horizon=2,
    pad_before=1,
    pad_after=1,
    val_ratio=0.0,
    seed=42,
    rotation_rep='rotation_6d',
    use_legacy_normalizer=False,
)

dataloader = DataLoader(dataset, shuffle=False, batch_size=1)
iterator = iter(dataloader)

In [3]:
# Source low-dim dataset, and the copy that will receive the attention values.
dataset_path, new_dataset_path = (
    "../data/robomimic/datasets/lift/ph/low_dim_abs.hdf5",
    "../data/robomimic/datasets/lift/ph/low_dim_abs_with_attention.hdf5",
)

# Start from a fresh full copy so all structure is preserved; this also discards any
# attention values written to the copy by a previous run.
shutil.copy(src=dataset_path, dst=new_dataset_path)

# Keep the copy open read-write for the cells below; it is closed explicitly later.
file = h5py.File(new_dataset_path, mode='r+')


In [None]:
num_demos = len(file['data'])
print(f"Number of demonstrations: {num_demos}")

# Sample count of every demo, read from each demo group's 'num_samples' attribute.
length_of_each_demo = np.array([
    file[f'data/demo_{demo_idx}'].attrs['num_samples']
    for demo_idx in tqdm(range(num_demos))
])

# The dataloader yields one window more per demo than the demo has samples, so the
# totals must differ by exactly num_demos; otherwise the two sources are misaligned.
assert length_of_each_demo.sum() + num_demos == len(dataset)


In [5]:
# Prepare model: load the trained low-dim checkpoint and rebuild its workspace/policy.
checkpoint = "../data/outputs/lift_lowdim_ph_reproduction/horizon_16/2025.03.11/10.57.22_train_diffusion_unet_lowdim_lift_lowdim_transformer_128/checkpoints/epoch=0200-test_mean_score=1.000.ckpt"
output_dir = "../data/outputs/lift_lowdim_ph_reproduction/horizon_16/2025.03.11/10.57.22_train_diffusion_unet_lowdim_lift_lowdim_transformer_128/dummy"
# The checkpoint is unpickled with dill, which can execute arbitrary code: only load trusted files.
# Using `with` closes the file handle (torch.load(open(...)) leaked it).
with open(checkpoint, 'rb') as checkpoint_file:
    payload = torch.load(checkpoint_file, pickle_module=dill)
cfg = payload['cfg']
# Override the checkpoint's configured scheduler class with this repo's DDPMScheduler.
cfg.policy.noise_scheduler._target_ = 'diffusion_policy.schedulers.scheduling_ddpm.DDPMScheduler'

cls = hydra.utils.get_class(cfg._target_)
workspace = cls(cfg, output_dir=output_dir)
workspace: BaseWorkspace
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Get policy from workspace
policy = workspace.model

# Use the first GPU when one is present; otherwise fall back to CPU so the cell still runs.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
policy.to(device);
policy.eval();

In [None]:
# Compute one scalar "spatial attention" value per timestep with policy.kl_divergence_drop
# and store the values as a 'spatial_attention' dataset inside each processed demo group.
# The (unshuffled) dataloader is consumed in lockstep with the HDF5 demos, so a fresh
# iterator is created here; re-running the cell restarts from the first demo.
# NOTE(review): range(1) processes only demo_0; use range(num_demos) for the whole dataset.
iterator = iter(dataloader)
for i in range(1):
    demo_key = f'data/demo_{i}'
    demo = file[demo_key]
    num_samples = demo.attrs['num_samples']
    
    spatial_attention = list()
    for sample_idx in tqdm(range(num_samples), leave=False):        
        sample = next(iterator)
        
        # Alignment check: time index 1 of the horizon-2 window must match the HDF5
        # observation at this timestep; otherwise the loader and the file are out of sync.
        assert np.linalg.norm(sample['obs']['object'][0, 1, :].numpy()-demo['obs']['object'][sample_idx]) < 1e-4
        
        # Concatenate the low-dim observation keys along the last axis into one 'obs' array.
        # NOTE(review): presumably this key order matches the observation layout the low-dim
        # policy was trained on; confirm against the checkpoint's config.
        n_obs_dict = {
                    'obs': np.concatenate([
                        sample['obs']['object'], 
                        sample['obs']['robot0_eef_pos'], 
                        sample['obs']['robot0_eef_quat'], 
                        sample['obs']['robot0_gripper_qpos']
                    ], axis=-1).astype(np.float32)
                }
        # Device transfer
        obs_dict = dict_apply(n_obs_dict, 
            lambda x: torch.from_numpy(x).to(device=device))
        with torch.no_grad():
            # .item() requires kl_divergence_drop to return a single-element tensor for this batch of 1.
            spatial_attention.append(policy.kl_divergence_drop(obs_dict).detach().cpu().numpy().item())
    spatial_attention = np.array(spatial_attention)
    
    # Each demo yields num_samples + 1 windows from the dataloader (see the length assert in an
    # earlier cell, presumably due to pad_after=1); skip the leftover window to stay in sync.
    next(iterator) # For syncing
    
    # Write the per-timestep values into the demo group as float32, or overwrite them in place if present.
    if 'spatial_attention' not in demo.keys():
        demo.create_dataset(
            'spatial_attention',
            shape=(num_samples,),   
            data=spatial_attention,
            dtype=np.float32
        )
    else:
        demo['spatial_attention'][:] = spatial_attention


In [7]:
# Close the read-write handle on the copied dataset so the written attention values are persisted.
file.close()

In [None]:
# Check that the data was written: read back demo_0's attention values and plot them over time.
# The file is only read here, so open it read-only and close it right away (it was
# previously reopened 'r+' and left open).
with h5py.File(new_dataset_path, 'r') as check_file:
    spatial_attention = check_file['data/demo_0/spatial_attention'][:]

fig, ax = plt.subplots(figsize=(4, 3))
time_steps = np.arange(len(spatial_attention))
ax.plot(time_steps, spatial_attention, 'b-', linewidth=1, label='Spatial Attention')
ax.set_title('demo_0 spatial attention')
ax.set_xlabel('Time Step')
ax.set_ylabel('Spatial Attention', color='b')
ax.tick_params(axis='y', labelcolor='b')
ax.grid(True)

In [None]:
# Inspect the copied dataset read-only. For each demo, print its sample count, its
# observation keys (once per demo, not once per sample) and whether a
# 'spatial_attention' dataset has been written to it.
with h5py.File(new_dataset_path, 'r') as f:
    for demo_name, demo in f['data'].items():
        num_samples = demo.attrs['num_samples']
        print(demo_name, num_samples, list(demo['obs'].keys()), 'spatial_attention' in demo)