# load the trajectory dataset

# In [1]:
# %matplotlib inline
import h5py
import numpy as np
from libero.lifelong.datasets import *
from libero.libero.utils.dataset_utils import get_dataset_info
from IPython.display import HTML
import imageio
from libero.libero import benchmark, get_libero_path, set_libero_default_path
import os
from termcolor import colored
import torch
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import cv2
from libero.lifelong.models import *
from libero.lifelong.utils import *
# Checkpoint of the pre-trained multitask policy whose attention is visualised.
model_path = '/home/ruiqi/projects/meta_adapt/scripts/experiments/LIBERO_OBJECT/PreTrainMultitask/BCViLTPolicy_seed10000/run_012/multitask_model_ep10.pth'
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# machine without one; the frames fed to the model below are moved to CPU too.
checkpoint = torch.load(model_path, map_location='cpu')
sd = checkpoint['state_dict']
cfg = checkpoint['cfg']
# Rebuild the policy from the config stored in the checkpoint, then restore its
# weights.
model = get_policy_class(cfg.policy.policy_type)(cfg, cfg.shape_meta)
model.load_state_dict(sd)
# Inference only: switch off train-mode behaviour (e.g. dropout) so the
# attention maps are deterministic.
# NOTE(review): assumes the policy is a torch.nn.Module — confirm.
model.eval()
# Name of the benchmark suite; used for the benchmark lookup and output paths.
TASK_SUITE = cfg['task_creation']['task_suite']
def construct_data(task_emb, agent_view_im, eye_in_hand_im):
    """Bundle one observation into the nested dict passed to the policy.

    Args:
        task_emb: Task embedding, stored under ``'task_emb'``.
        agent_view_im: Agent-view image, stored under
            ``['obs']['agentview_rgb']``.
        eye_in_hand_im: Wrist-camera image, stored under
            ``['obs']['eye_in_hand_rgb']``.

    Returns:
        A new dict ``{'obs': {...}, 'task_emb': task_emb}``; the inputs are
        referenced, not copied.
    """
    observation = {
        'agentview_rgb': agent_view_im,
        'eye_in_hand_rgb': eye_in_hand_im,
    }
    return {'obs': observation, 'task_emb': task_emb}


def fetch_joint_att(data, model):
    """Compute per-layer joint attention (attention rollout) for one sample.

    The per-layer attention maps returned by ``model.get_spatial_summary``
    are averaged over heads, mixed with an identity matrix to account for
    residual connections, re-normalised row-wise, and multiplied cumulatively
    from the first layer onwards.

    Args:
        data: Dict as built by ``construct_data``; ``data['obs']`` must hold
            ``'agentview_rgb'`` and ``'eye_in_hand_rgb'`` tensors indexable
            as ``[0, 0]`` to give a channel-first image.
        model: Object exposing ``get_spatial_summary(data)``.
            NOTE(review): assumes it returns one map per layer, each shaped
            (1, heads, tokens, tokens) — confirm against the model.

    Returns:
        Tuple ``(joint_attentions, agentview_im, eye_in_hand_im)``:
        ``joint_attentions`` is a float32 CPU tensor of shape
        (layers, tokens, tokens) detached from autograd, where entry ``n`` is
        the rollout through layers ``0..n``; the two images are channel-last
        NumPy arrays.
    """
    # Visualisation only, so no autograd graph is needed.
    with torch.no_grad():
        att_maps = model.get_spatial_summary(data)
        # (layers, 1, heads, T, T) -> (layers, heads, T, T). Casting to
        # float32 keeps a single dtype for the matmuls below.
        att_mat = torch.stack(list(att_maps)).squeeze(1).float()

        # Average the attention weights across all heads.
        att_mat = att_mat.mean(dim=1)

        # To account for residual connections, add an identity matrix to the
        # attention matrix and re-normalise each row. The identity is created
        # on the same device as the attention so GPU models also work.
        identity = torch.eye(
            att_mat.size(1), device=att_mat.device, dtype=att_mat.dtype
        )
        aug_att_mat = att_mat + identity
        aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

        # Cumulatively multiply the per-layer matrices.
        joint_attentions = torch.zeros_like(aug_att_mat)
        joint_attentions[0] = aug_att_mat[0]
        for n in range(1, aug_att_mat.size(0)):
            joint_attentions[n] = torch.matmul(
                aug_att_mat[n], joint_attentions[n - 1]
            )

    # First batch element / first time step, converted from (C, H, W) to
    # (H, W, C). detach()/cpu() make the NumPy conversion safe for GPU or
    # grad-tracking tensors.
    obs = data['obs']
    agentview_im = (
        obs['agentview_rgb'][0, 0, :].detach().cpu().numpy().transpose(1, 2, 0)
    )
    eye_in_hand_im = (
        obs['eye_in_hand_rgb'][0, 0, :].detach().cpu().numpy().transpose(1, 2, 0)
    )
    # Returned on CPU so downstream NumPy conversion keeps working.
    return joint_attentions.cpu(), agentview_im, eye_in_hand_im


def _attention_to_mask(att, image, grid_size):
    """Turn a flat patch-attention vector into an (H, W, 1) mask for ``image``."""
    att = att / att.sum()
    grid = att.reshape(grid_size, grid_size).detach().cpu().numpy()
    height, width = image.shape[:2]
    # cv2.resize expects dsize as (width, height), i.e. the reverse of the
    # NumPy shape order; passing (height, width) only works for square images.
    mask = cv2.resize(grid / grid.max(), (width, height))
    return mask[..., np.newaxis]


def show_att_im(att_map, agentview_im, eye_in_hand_im, grid_size=14):
    """Overlay the summary token's attention on both camera images.

    Args:
        att_map: 2-D (tokens, tokens) attention matrix (torch tensor). Row 0
            is read as the spatial-summary token; columns ``1 .. N`` are
            taken as the agent-view patches and the following ``N`` columns
            as the eye-in-hand patches, with ``N = grid_size ** 2``.
        agentview_im: (H, W, C) NumPy image for the agent view.
        eye_in_hand_im: (H, W, C) NumPy image for the wrist camera.
        grid_size: Side length of the square patch grid per image. The
            default of 14 (196 patches) matches the previously hard-coded
            layout.

    Returns:
        Tuple ``(masked_agentview, masked_eye_in_hand)``: each image
        multiplied by its peak-normalised attention mask and with its rows
        reversed (vertical flip).
    """
    num_patches = grid_size * grid_size

    # Attention of the spatial summary token (index 0) over the patch tokens.
    summary_row = att_map[0]
    agentview_att = summary_row[1:1 + num_patches]
    eye_in_hand_att = summary_row[1 + num_patches:1 + 2 * num_patches]

    masked_im_a = agentview_im * _attention_to_mask(
        agentview_att, agentview_im, grid_size
    )
    masked_im_e = eye_in_hand_im * _attention_to_mask(
        eye_in_hand_att, eye_in_hand_im, grid_size
    )
    # Rows are reversed for display.
    # NOTE(review): presumably the stored frames are upside down — confirm.
    return masked_im_a[::-1], masked_im_e[::-1]


def show_att_images(images_list, save_flag=False, view='a', task_id=0, traj_id=0):
    """Plot a list of images on a 7-column grid and show or save the figure.

    Args:
        images_list: Sequence of images accepted by ``Axes.imshow``.
        save_flag: When true, write the figure to
            ``figs1/<TASK_SUITE>/t_<task_id>_d_<traj_id>_<view>.png`` instead
            of displaying it.
        view: Tag used in the file name (the callers in this file pass
            ``'a'`` for agent view and ``'e'`` for eye-in-hand).
        task_id: Task index used in the file name.
        traj_id: Trajectory index used in the file name.
    """
    cols = 7
    # Keep the original 4x7 layout, but add rows when there are more than 28
    # images so that none are silently dropped.
    rows = max(4, -(-len(images_list) // cols))

    # 2.5 inches per row reproduces the original (18, 10) size for 4 rows.
    fig, axes = plt.subplots(rows, cols, figsize=(18, 2.5 * rows))

    # Fill the grid row by row; unused cells stay blank with axes hidden.
    for i, ax in enumerate(axes.flatten()):
        if i < len(images_list):
            ax.imshow(images_list[i])
        ax.axis('off')

    fig.subplots_adjust(wspace=0.1, hspace=0.1)

    if save_flag:
        out_dir = f'figs1/{TASK_SUITE}'
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f'{out_dir}/t_{task_id}_d_{traj_id}_{view}.png', dpi=300, bbox_inches='tight')
    else:
        plt.show()

    # Release the figure; pyplot keeps every figure alive until it is closed.
    plt.close(fig)
        


def fetch_att_im_entire_traj(agentview_data_list, eye_in_hand_data_list, task_embs, task_id=0, traj_id=0):
    """Render and save attention overlays for every selected frame.

    For each frame pair the module-level ``model`` is queried through
    ``fetch_joint_att``; the last layer's joint attention is turned into
    overlays with ``show_att_im``, and the overlays of each camera are saved
    as one grid figure through ``show_att_images``.

    Args:
        agentview_data_list: Per-frame agent-view inputs.
        eye_in_hand_data_list: Per-frame eye-in-hand inputs, indexed in step
            with ``agentview_data_list``.
        task_embs: Task embedding passed unchanged to ``construct_data``.
        task_id: Task index forwarded to the output file name.
        traj_id: Trajectory index forwarded to the output file name.
    """

    def _as_uint8(image):
        # NOTE(review): assumes overlay values lie in [0, 1] — confirm.
        return (image * 255).astype(np.uint8)

    # Overlays keyed by the view tag used in the saved file name.
    overlays_by_view = {'a': [], 'e': []}

    for frame_idx, agent_frame in enumerate(agentview_data_list):
        batch = construct_data(task_embs, agent_frame, eye_in_hand_data_list[frame_idx])
        rollout, agent_rgb, eye_rgb = fetch_joint_att(batch, model)
        # Only the final layer's joint attention is visualised.
        agent_overlay, eye_overlay = show_att_im(rollout[-1], agent_rgb, eye_rgb)
        overlays_by_view['a'].append(_as_uint8(agent_overlay))
        overlays_by_view['e'].append(_as_uint8(eye_overlay))

    for view_tag, overlays in overlays_by_view.items():
        show_att_images(overlays, save_flag=True, view=view_tag, task_id=task_id, traj_id=traj_id)


import pickle as pkl

TASK_ID = 0

# The recorded rollout observations are read from task_<id>.pkl in the same
# directory as the checkpoint.
online_traj_path = os.path.join(os.path.dirname(model_path), f'task_{TASK_ID}.pkl')

#online_traj_path = "/home/ruiqi/projects/meta_adapt/scripts/experiments/LIBERO_SPATIAL/PreTrainMultitask/BCViLTPolicy_seed10000/run_003/task_0.pkl"
# NOTE(review): pickle.load can execute arbitrary code; only load rollout
# files that you produced yourself.
with open(online_traj_path, 'rb') as f:
    task_obs = pkl.load(f)

TRAJ_ID = 1  # [0,4]

# Each entry of task_obs is indexed per camera and then by TRAJ_ID to get the
# frames of a single trajectory.
agentview_ims = [x['agentview_rgb'][TRAJ_ID] for x in task_obs]
eye_in_hand_ims = [x['eye_in_hand_rgb'][TRAJ_ID] for x in task_obs]

# Embed the language description of the chosen task.
descriptions = []
benchmark_dict = benchmark.get_benchmark_dict()
benchmark_instance = benchmark_dict[TASK_SUITE]()
task_description = benchmark_instance.get_task(TASK_ID).language
descriptions.append(task_description)
task_embs = get_task_embs(cfg, descriptions)

# Select roughly 20 evenly spaced frames and always include the last one.
num_frames = len(agentview_ims)
if num_frames == 0:
    raise ValueError(f'no frames found in {online_traj_path}')
# max(1, ...) keeps the step valid for trajectories shorter than 20 frames.
interval = max(1, num_frames // 20)
frame_indices = list(range(0, num_frames, interval))
# range() already starts at frame 0, so only the final frame may be missing;
# appending it conditionally avoids duplicated frames in the figure.
if frame_indices[-1] != num_frames - 1:
    frame_indices.append(num_frames - 1)

# [None, None, :] adds two leading singleton dimensions to every frame.
selected_agentview_ims = [
    agentview_ims[i][None, None, :].to('cpu') for i in frame_indices
]
selected_eye_in_hand_ims = [
    eye_in_hand_ims[i][None, None, :].to('cpu') for i in frame_indices
]
# Bare expression kept from the notebook cell (shows the frame count there).
len(selected_agentview_ims)
fetch_att_im_entire_traj(selected_agentview_ims,
                         selected_eye_in_hand_ims,
                         task_embs,
                         task_id=TASK_ID,
                         traj_id=TRAJ_ID)