In [3]:
import os
import pickle

import numpy as np
import torch
from tqdm.auto import tqdm

from policy import config
from policy.dataset.ms2dataset import get_MS_loaders
from policy.checkpoints import CheckpointIO

model_dir = "/home/mrl/Documents/Projects/tskill/out/PegInsertion/VAE/055"
cfg_path = os.path.join(model_dir, "config.yaml")
cfg = config.load_config(cfg_path, None)

# Dataset
cfg["data"]["pad"] = False
cfg["data"]["augment"] = False
cfg["data"]["full_seq"] = False
cfg["data"]["max_count"] = 100
cfg["data"]["val_split"] = 0
cfg["data"]["dataset"] = "/home/mrl/Documents/Projects/tskill/data/demos/v0/rigid_body/PegInsertionSide-v0/trajectory.rgbd.pd_joint_delta_pos_c256.h5"

# Load only the full episode version of the dataset
train_dataset, val_dataset = get_MS_loaders(cfg, return_datasets=True, 
                                            save_override=True,
                                            preshuffle=False,
                                            fullseq_override=True,
                                            )
print(len(train_dataset), len(val_dataset))
# Model
model = config.get_model(cfg, device="cuda")
checkpoint_io = CheckpointIO(model_dir, model=model)
load_dict = checkpoint_io.load("model_best.pt")
stt_encoder = model.stt_encoder
model.eval()

Updating train & val indices
Recomputing scaling functions...


Collecting all training data info:: 100%|██████████| 100/100 [00:23<00:00,  4.20it/s]

Computing seperate gripper scaling
Computing normal quantile transform
Computing linear scaling
Adding batch dimension to returned data!
100 0
freezing state encoder network!
/home/mrl/Documents/Projects/tskill/out/PegInsertion/VAE/055/model_best.pt
=> Loading checkpoint from local file...
load state dict: <All keys matched successfully>





TSkillCVAE(
  (decoder): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-4): 5 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
       

In [4]:
import h5py

dataset_path = cfg["data"]["dataset"]
train_dataset.data.close()
val_dataset.data.close()
dataset_file = h5py.File(dataset_path, "r+")
train_dataset.data = dataset_file
for i in range(len(train_dataset)):
    print(f"==>> train_dataset[i]: {train_dataset[i].keys()}")
    rgb = train_dataset[i]["rgb"].to(model._device)
    print(f"==>> rgb.shape: {rgb.shape}")
    with torch.no_grad():
        img_feat, img_pe = stt_encoder(rgb)
        print(f"==>> img_feat.shape: {img_feat.shape}")
        print(f"==>> img_pe.shape: {img_pe.shape}")
    img_feat, img_pe = img_feat[:,0,...].detach().cpu().numpy(), img_pe[:,0,...].detach().cpu().numpy()
    print(f"==>> img_feat.shape: {img_feat.shape}")
    print(f"==>> img_pe.shape: {img_pe.shape}")
    eps = train_dataset.episodes[train_dataset.owned_indices[i]]
    print(f"==>> eps: {eps}")
    trajectory = dataset_file[f"traj_{eps['episode_id']}"]
    print(f"==>> trajectory: {trajectory}")
    del trajectory["obs"]["image"]
    trajectory.create_dataset("obs/resnet18/img_feat",
                              data=img_feat,
                              dtype=img_feat.dtype,
                              compression="gzip",
                            compression_opts=5,),
    trajectory.create_dataset("obs/resnet18/img_pe",
                              data=img_pe,
                              dtype=img_pe.dtype,
                              compression="gzip",
                              compression_opts=5,)

dataset_file.close()


==>> train_dataset[i]: dict_keys(['state', 'seq_pad_mask', 'skill_pad_mask', 'actions', 'rgb', 'dec_src_mask', 'dec_mem_mask', 'dec_tgt_mask', 'enc_src_mask', 'enc_mem_mask', 'enc_tgt_mask'])
==>> rgb.shape: torch.Size([1, 137, 4, 3, 128, 128])
==>> img_feat.shape: torch.Size([137, 1, 4, 16, 512])
==>> img_pe.shape: torch.Size([137, 1, 4, 16, 256])
==>> img_feat.shape: (137, 4, 16, 512)
==>> img_pe.shape: (137, 4, 16, 256)
==>> eps: {'episode_id': 0, 'episode_seed': 0, 'reset_kwargs': {'seed': 0, 'options': {}}, 'control_mode': 'pd_joint_delta_pos', 'elapsed_steps': 137, 'info': {'elapsed_steps': 137, 'success': True, 'peg_head_pos_at_hole': [-0.008832097053527832, -0.0012859664857387543, -0.0007166117429733276]}}
==>> trajectory: <HDF5 group "/traj_0" (4 members)>
==>> train_dataset[i]: dict_keys(['state', 'seq_pad_mask', 'skill_pad_mask', 'actions', 'rgb', 'dec_src_mask', 'dec_mem_mask', 'dec_tgt_mask', 'enc_src_mask', 'enc_mem_mask', 'enc_tgt_mask'])
==>> rgb.shape: torch.Size([1, 1