In [1]:
import os
import pickle

import numpy as np
import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import seaborn as sns

from policy import config
from policy.dataset.ms2dataset import get_MS_loaders
from policy.checkpoints import CheckpointIO

# Trained run to inspect; its config and the split indices used for training live here.
model_dir = "/home/mrl/Documents/Projects/tskill/out/PegInsertion/031"

cfg_path = os.path.join(model_dir, "config.yaml")
# cfg_path = "/home/mrl/Documents/Projects/tskill/assets/skill/default.yaml"
cfg = config.load_config(cfg_path, None)

# Train/val indices saved alongside the run.
# NOTE: pickle.load can execute arbitrary code -- only point this at trusted run dirs.
index_path = os.path.join(model_dir, "data_info.pickle")
with open(index_path, 'rb') as f:
    data_info = pickle.load(f)

# Dataset overrides for inspection: padding on, augmentation off, full_seq off,
# and an explicit dataset file.
data_cfg = cfg["data"]
data_cfg["pad"] = True
cfg["training"]["n_workers"] = 2
data_cfg["augment"] = False
aug_cfg = data_cfg["augmentation"]
aug_cfg["image_aug"] = False
aug_cfg["subsequence_rate"] = 1
data_cfg["full_seq"] = False
data_cfg["dataset"] = "/home/mrl/Documents/Projects/tskill/data/demos/v0/rigid_body/PegInsertionSide-v0/trajectory.rgbd.pd_joint_delta_pos3.h5"

# Prefer the per-episode split when the run recorded one; otherwise fall back
# to the plain train/val indices.
if "train_ep_indices" in data_info:
    train_idx, val_idx = data_info["train_ep_indices"], data_info["val_ep_indices"]
else:
    train_idx, val_idx = data_info["train_indices"], data_info["val_indices"]

train_dataset, val_dataset = get_MS_loaders(
    cfg,
    return_datasets=True,
    indices=(train_idx, val_idx),
    save_override=True,
)
    

Found existing data info file
Using override indices
Loading action and state scaling from file
Adding batch dimension to returned data!


In [2]:
# Model: build on CPU, restore the best checkpoint saved in model_dir and
# switch the module to evaluation mode.
model = config.get_model(cfg, device="cpu")
checkpoint_io = CheckpointIO(model_dir, model=model)
load_dict = checkpoint_io.load("model_best.pt")
model.eval()

# Split inspected by the following cells: True -> training, False -> validation.
train = True
dataset, idxs = (train_dataset, train_idx) if train else (val_dataset, val_idx)
print("Using Training Dataset" if train else "Using Validation Dataset")


class AttentionHook:
    """Forward hook that records what an attention module returns.

    Register an instance with ``module.register_forward_hook(hook)``. On every
    forward call of that module it appends ``output[0]`` to ``attention``,
    ``output[1]`` to ``attention_weights`` and the positional-input tuple to
    ``inputs``. For ``torch.nn.MultiheadAttention`` the output is
    ``(attn_output, attn_weights)``; ``attn_weights`` is ``None`` when the module
    was called with ``need_weights=False`` -- confirm for custom layers.

    The lists grow on every forward pass. Call :meth:`clear` before running a
    new sample if index ``0`` is meant to refer to that sample.
    """

    def __init__(self, name=None):
        # Plot label such as "SA" / "MHA"; callers may also assign it afterwards.
        self.name = name
        self.attention_weights = []
        self.attention = []
        self.inputs = []

    def __call__(self, module, input, output):
        """Record one forward call. Returns None so the module output is unchanged."""
        self.attention.append(output[0])
        self.attention_weights.append(output[1])
        self.inputs.append(input)

    def clear(self):
        """Forget all recorded calls, releasing the references they hold."""
        self.attention_weights.clear()
        self.attention.clear()
        self.inputs.clear()


def visualize_attention_weights(attention_weights, layer_name, layer_num, attn):
    """Show a heatmap of attention weights averaged over their leading dimension.

    Args:
        attention_weights: tensor whose dim 0 is averaged away, leaving a 2-D
            (rows = query positions, columns = key positions) matrix. Presumably
            (batch, L, S) as returned by nn.MultiheadAttention -- TODO confirm.
        layer_name: title prefix, e.g. "Encoder Encoder".
        layer_num: layer number shown in the title.
        attn: attention-kind label shown in the title, e.g. "SA" or "MHA".
    """
    avg_attention = attention_weights.mean(dim=0).abs().detach().cpu().numpy()
    print(avg_attention.shape)
    n_rows, n_cols = avg_attention.shape[0], avg_attention.shape[1]

    fig, ax = plt.subplots(figsize=(20, 20))
    sns.heatmap(avg_attention, annot=False, cmap='YlGnBu', ax=ax)
    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Input Sequence')
    ax.set_title(f'{layer_name} Layer {layer_num} {attn} Attention Weight Visualization')

    # About six labelled major ticks per axis. x spans columns and y spans rows;
    # these differ for cross-attention, so each axis uses its own length. Clamp to
    # 1 so tiny matrices don't get zero/fractional spacing (duplicate labels).
    ax.xaxis.set_major_locator(MultipleLocator(max(n_cols / 6, 1)))
    ax.xaxis.set_major_formatter('{x:.0f}')
    ax.xaxis.set_minor_locator(MultipleLocator(10))
    ax.yaxis.set_major_locator(MultipleLocator(max(n_rows / 6, 1)))
    ax.yaxis.set_major_formatter('{x:.0f}')
    ax.yaxis.set_minor_locator(MultipleLocator(10))

    fig.tight_layout()
    plt.show()

def visualize_attention(attention, input_len, layer_name, layer_num, attn, ):
    """Bar-plot |mean over the last dim| of an attention output, one bar per position.

    Args:
        attention: tensor averaged over its last dimension, then squeezed.
            Presumably (seq, batch=1, embed) or similar -- TODO confirm.
        input_len: number of x tick labels to draw (0 .. input_len-1). If None,
            the length of the averaged vector is used.
        layer_name: title prefix, e.g. "Encoder Encoder".
        layer_num: layer number shown in the title.
        attn: attention-kind label shown in the title, e.g. "SA" or "MHA".

    Nothing is drawn when the averaged result collapses to a scalar.
    """
    avg_attention = attention.mean(dim=-1).abs().squeeze().detach().cpu().numpy()
    if avg_attention.ndim == 0:
        return
    print(avg_attention.shape)
    if input_len is None:
        input_len = avg_attention.shape[0]

    fig, ax = plt.subplots(figsize=(20, 10))
    if avg_attention.ndim == 1:
        # Give seaborn explicit x positions: a bare 1-D vector is treated as flat
        # data by newer seaborn and aggregated into a single bar.
        sns.barplot(x=np.arange(avg_attention.shape[0]), y=avg_attention, ax=ax)
    else:
        sns.barplot(avg_attention, ax=ax)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Attention')
    ax.set_title(f'{layer_name} Layer {layer_num} {attn} Attention Visualization')
    plt.xticks(range(input_len), range(input_len), rotation=90)

    fig.tight_layout()
    plt.show()

# Register hooks on every attention module of the model's four transformer stacks.
# Each list holds AttentionHook objects in layer order. For decoder-style stacks the
# list alternates [self-attn ("SA"), cross-attn ("MHA")] per layer, which the plotting
# code relies on (np.ceil((i+1)/2) recovers the layer number).
# The returned handles are kept so hooks can be detached later with:
#     for handle in hook_handles: handle.remove()
hook_handles = []

_SELF_ATTN_ONLY = (("self_attn", "SA"),)
_SELF_AND_CROSS_ATTN = (("self_attn", "SA"), ("multihead_attn", "MHA"))


def _attach_attention_hooks(layers, attn_modules):
    """Attach one AttentionHook per (layer, attention module) and return them in order.

    attn_modules: sequence of (attribute name on the layer, hook label) pairs,
    visited in that order for each layer.
    """
    hooks = []
    for layer in layers:
        for attr_name, label in attn_modules:
            hook = AttentionHook()
            hook.name = label
            hook_handles.append(getattr(layer, attr_name).register_forward_hook(hook))
            hooks.append(hook)
    return hooks


dec_enc_hooks = _attach_attention_hooks(model.decoder.encoder.layers, _SELF_ATTN_ONLY)
dec_dec_hooks = _attach_attention_hooks(model.decoder.decoder.layers, _SELF_AND_CROSS_ATTN)
enc_enc_hooks = _attach_attention_hooks(model.encoder.encoder.layers, _SELF_ATTN_ONLY)
enc_dec_hooks = _attach_attention_hooks(model.encoder.decoder.layers, _SELF_AND_CROSS_ATTN)

freezing state encoder network!
/home/mrl/Documents/Projects/tskill/out/PegInsertion/031/model_best.pt
=> Loading checkpoint from local file...




load state dict: <All keys matched successfully>
Using Training Dataset


In [3]:
# Run one sample of the split selected in the model cell (`train` flag) through the
# model so the registered hooks capture its attention.
i = 20

# Hooks append on every forward pass; reset them so the `[0]` indexing in the
# plots below refers to this sample, not the first sample ever run in this kernel.
for hook in dec_enc_hooks + dec_dec_hooks + enc_enc_hooks + enc_dec_hooks:
    hook.attention.clear()
    hook.attention_weights.clear()
    hook.inputs.clear()

data = dataset[i]
print(data["actions"].shape)
print((~data["seq_pad_mask"]).to(torch.int))
# with torch.no_grad():
out = model(data, use_precalc=True)

print(data["actions"])

# # Visualize attention for each layer
# for i, hook in enumerate(enc_enc_hooks):
#     visualize_attention_weights(hook.attention_weights[0], "Encoder Encoder", i+1, hook.name)
#     # visualize_attention(hook.attention[0], len(hook.inputs[0][0]), "Encoder Encoder", i+1, hook.name)
# visualize_attention_weights(torch.mean(torch.stack([x.attention_weights[0] for x in enc_enc_hooks if x.name == "SA"], 0), 0), "Encoder Encoder Mean", 0, "SA")

# for i, hook in enumerate(enc_dec_hooks):
#     visualize_attention_weights(hook.attention_weights[0], "Encoder Decoder", np.ceil((i+1)/2), hook.name)
#     # visualize_attention(hook.attention[0], len(hook.inputs[0][0]), "Encoder Decoder", i+1, hook.name)
# visualize_attention_weights(torch.mean(torch.stack([x.attention_weights[0] for x in enc_dec_hooks if x.name == "SA"], 0), 0), "Encoder Decoder Mean", 0, "SA")
# visualize_attention_weights(torch.mean(torch.stack([x.attention_weights[0] for x in enc_dec_hooks if x.name == "MHA"], 0), 0), "Encoder Decoder Mean", 0, "MHA")

# for i, hook in enumerate(dec_enc_hooks):
#     visualize_attention_weights(hook.attention_weights[0], "Decoder Encoder", i+1, hook.name)
#     # visualize_attention(hook.attention[0], len(hook.inputs[0][0]), "Decoder Encoder", i+1, hook.name)
# visualize_attention_weights(torch.mean(torch.stack([x.attention_weights[0] for x in dec_enc_hooks if x.name == "SA"], 0), 0), "Decoder Encoder Mean", 0, "SA")

# for i, hook in enumerate(dec_dec_hooks):
#     visualize_attention_weights(hook.attention_weights[0], "Decoder Decoder", np.ceil((i+1)/2), hook.name)
#     # visualize_attention(hook.attention[0], len(hook.inputs[0][0]), "Decoder Decoder", i+1, hook.name)
# visualize_attention_weights(torch.mean(torch.stack([x.attention_weights[0] for x in dec_dec_hooks if x.name == "SA"], 0), 0), "Decoder Decoder Mean", 0, "SA")
# visualize_attention_weights(torch.mean(torch.stack([x.attention_weights[0] for x in dec_dec_hooks if x.name == "MHA"], 0), 0), "Decoder Decoder Mean", 0, "MHA")

torch.Size([1, 200, 8])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
2936563
2936565
tensor(0)
tensor([[[-1.4607,  1.4697,  1.3295,  ..., -0.5014, -2.0299,  1.0000],
         [-1.4541,  1.5268,  1.3590,  ..., -0.5748, -2.0734,  1.0000],
         [-1.4489,  1.5850,  1.3816,  ..., -0.6325, -2.0972,  1.0000],
         ...,
         [ 0.0000,  0