In [1]:
import os
import pickle

import numpy as np
import torch
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from policy import config
from policy.dataset.ms2dataset import get_MS_loaders
from policy.checkpoints import CheckpointIO

# Run directory containing config.yaml, data_info.pickle and model_best.pt.
# Set the TSKILL_MODEL_DIR environment variable to inspect another run without
# editing the notebook; the original hard-coded path remains the default.
model_dir = os.environ.get(
    "TSKILL_MODEL_DIR",
    "/home/mrl/Documents/Projects/tskill/out/PegInsertion/023",
)

cfg_path = os.path.join(model_dir, "config.yaml")
cfg = config.load_config(cfg_path, None)

# Index metadata (train/val index lists) saved with the run.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only open
# data_info.pickle files produced by runs you trust.
index_path = os.path.join(model_dir, "data_info.pickle")
with open(index_path, "rb") as f:
    data_info = pickle.load(f)

# Dataset: disable padding (train and val), augmentation and full_seq for this
# inspection.
cfg["data"].update(pad_train=False, pad_val=False, augment=False, full_seq=False)

# Prefer the per-episode index split when data_info provides it; otherwise fall
# back to the flat train/val index lists.
if "train_ep_indices" in data_info:
    train_idx, val_idx = data_info["train_ep_indices"], data_info["val_ep_indices"]
else:
    train_idx, val_idx = data_info["train_indices"], data_info["val_indices"]
train_dataset, val_dataset = get_MS_loaders(
    cfg,
    return_datasets=True,
    indices=(train_idx, val_idx),
)

# Model: build on CPU, then restore the weights stored in model_best.pt.
model = config.get_model(cfg, device="cpu")
checkpoint_io = CheckpointIO(model_dir, model=model)
load_dict = checkpoint_io.load("model_best.pt")
# NOTE(review): train mode leaves the Dropout layers active, so forward passes
# are stochastic; use model.eval() instead if deterministic inference is wanted.
model.train(mode=True)

Found existing data info file
Using override indices
Loading action and state scaling from file
/home/mrl/Documents/Projects/tskill/out/PegInsertion/023/model_best.pt
=> Loading checkpoint from local file...




load state dict: <All keys matched successfully>


TSkillCVAE(
  (decoder): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=256, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
          (linear2): Linear(in_features=256, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.2, inplace=False)
          (dropout2): Dropout(p=0.2, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
       

In [2]:
# Choose which split to inspect; `idxs` holds the matching index list.
train = True
if train:
    dataset, idxs = train_dataset, train_idx
    print("Using Training Dataset")
else:
    dataset, idxs = val_dataset, val_idx
    print("Using Validation Dataset")

i = 1  # index of the sample to inspect
data = dataset[i]
# Duplicate the sample into a batch of two.
# NOTE(review): torch.vstack concatenates along dim 0, so this only produces a
# (2, ...) batch for 1-D values or values that already carry a leading dim of 1;
# for an unbatched (T, D) tensor it would give (2T, D) and torch.stack would be
# needed instead -- confirm against the dataset's __getitem__.
data = {k: torch.vstack((v, v)) for k, v in data.items()}
# Overwrite both pad masks of the first copy with all-ones; the second copy keeps
# the masks returned by the dataset, so the two outputs can be compared.
data["seq_pad_mask"][0, ...] = torch.ones_like(data["seq_pad_mask"][0, ...])
data["skill_pad_mask"][0, ...] = torch.ones_like(data["skill_pad_mask"][0, ...])
print(data["seq_pad_mask"].shape)

# Gradients are intentionally left enabled (no torch.no_grad()).
model.zero_grad()
# The model was put in train mode when it was loaded, so dropout (and any sampling
# inside the model) makes the forward pass stochastic; fix the RNG so re-running
# this cell reproduces the same output.
torch.manual_seed(0)
out = model(data)
print(f"==>> out: {out}")

Using Training Dataset
torch.Size([2, 152])
==>> batch_mask: tensor([False,  True])
==>> out: {'a_hat': tensor([[[-0.5221,  0.1630, -0.9107,  ...,  0.1944,  0.3796,  0.9750],
         [-0.5171, -0.0427, -0.8406,  ...,  0.3647,  0.2453,  0.9784],
         [-0.4795,  0.1366, -0.9926,  ...,  0.4508,  0.5319,  0.9671],
         ...,
         [ 0.7082,  0.6392,  0.4021,  ..., -0.9295,  0.6603, -0.9849],
         [ 0.5679,  0.6016,  0.6472,  ..., -0.6540,  0.4302, -0.9721],
         [ 0.6214,  0.4352,  0.6776,  ..., -0.6956,  0.4620, -0.9791]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       grad_