In [1]:
import os
import pickle

import numpy as np
import torch
from tqdm.auto import tqdm

from policy import config
from policy.dataset.dataset_loaders import dataset_loader
from policy.checkpoints import CheckpointIO

model_dir = "/home/mrl/Documents/Projects/tskill/out/PegInsertion/VAE/004"
cfg_path = os.path.join(model_dir, "config.yaml")
cfg = config.load_config(cfg_path, None)

# Dataset
cfg["data"]["pad"] = False
cfg["data"]["augment"] = False
cfg["data"]["full_seq"] = False
cfg["data"]["max_count"] = 100
cfg["data"]["val_split"] = 0
# cfg["data"]["dataset"] = "/home/mrl/Documents/Projects/tskill/data/demos/v0/rigid_body/PegInsertionSide-v0/trajectory.rgbd.pd_joint_delta_pos_c512.h5"

# Load only the full episode version of the dataset
train_dataset, val_dataset = dataset_loader(cfg, return_datasets=True, 
                                            save_override=False,
                                            preshuffle=False,
                                            fullseq_override=True,
                                            )
print(len(train_dataset), len(val_dataset))
# Model
model = config.get_model(cfg, device="cuda")
checkpoint_io = CheckpointIO(model_dir, model=model)
load_dict = checkpoint_io.load("model_best.pt")
stt_encoder = model.stt_encoder
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


Found existing multitask data info file
>>> Loading multitask dataset from file: out/PegInsertion/VAE/004/data_info.pickle
Loading dataset: /home/mrl/Documents/Projects/tskill/LIBERO/libero/datasets/libero_90/KITCHEN_SCENE6_close_the_microwave_demo.hdf5
Loading indices from file: out/PegInsertion/VAE/004/data_info.pickle
Overriding full seq config!
Adding batch dimension to returned data!
Loading dataset: /home/mrl/Documents/Projects/tskill/LIBERO/libero/datasets/libero_90/LIVING_ROOM_SCENE6_put_the_chocolate_pudding_to_the_left_of_the_plate_demo.hdf5
Loading indices from file: out/PegInsertion/VAE/004/data_info.pickle
Overriding full seq config!
Adding batch dimension to returned data!
Loading dataset: /home/mrl/Documents/Projects/tskill/LIBERO/libero/datasets/libero_90/KITCHEN_SCENE10_close_the_top_drawer_of_the_cabinet_demo.hdf5
Loading indices from file: out/PegInsertion/VAE/004/data_info.pickle
Overriding full seq config!
Adding batch dimension to returned data!
Loading dataset: /



/home/mrl/Documents/Projects/tskill/out/PegInsertion/VAE/004/model_best.pt
=> Loading checkpoint from local file...
load state dict: <All keys matched successfully>


  state_dict = torch.load(filename, map_location="cpu")


TSkillCVAE(
  (decoder): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
          (linear2): Linear(in_features=512, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.2, inplace=False)
          (dropout2): Dropout(p=0.2, inplace=False)
        )
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
       

In [10]:
import h5py

for dataset in val_dataset.sequence_datasets:
    dataset.data.close()

j = 0
for dataset in train_dataset.sequence_datasets:
    print(f"##################### {j}")
    dataset_path = dataset.dataset_file
    print(f"==>> dataset_file: {dataset_path}")
    dataset.data.close()
    dataset_file = h5py.File(dataset_path, "r")
    dataset_file.close()
    dataset_file = h5py.File(dataset_path, "r+")
    dataset.data = dataset_file
    dataset.episodes = dataset_file["data"]
    for i in range(len(dataset)):
        eps = dataset.owned_indices[i]
        print(f"==>> eps: {eps}")
        trajectory = dataset.episodes[f"demo_{eps}"]
        print(f"==>> trajectory: {trajectory}")
        print(f"==>> trajectory: {trajectory.keys()}")
        if "resnet18" in trajectory["obs"].keys():
            continue
        rgb = dataset[i]["rgb"].to(model._device)
        print(f"==>> rgb.shape: {rgb.shape}")
        with torch.no_grad():
            img_feat, img_pe = stt_encoder(rgb)
        img_feat, img_pe = img_feat[:,0,...].detach().cpu().numpy(), img_pe[:,0,...].detach().cpu().numpy()
        print(f"==>> img_feat.shape: {img_feat.shape}")
        print(f"==>> img_pe.shape: {img_pe.shape}")
        # raise ValueError
        # # del trajectory["obs"]["image"]
        trajectory.create_dataset("obs/resnet18/img_feat",
                                data=img_feat,
                                dtype=img_feat.dtype,
                                compression="gzip",
                                compression_opts=5,),
        trajectory.create_dataset("obs/resnet18/img_pe",
                                data=img_pe,
                                dtype=img_pe.dtype,
                                compression="gzip",
                                compression_opts=5,)
        
    j += 1

    dataset_file.close()


##################### 0
==>> dataset_file: /home/mrl/Documents/Projects/tskill/LIBERO/libero/datasets/libero_90/KITCHEN_SCENE6_close_the_microwave_demo.hdf5
==>> eps: 0
==>> trajectory: <HDF5 group "/data/demo_0" (6 members)>
==>> trajectory: <KeysViewHDF5 ['actions', 'dones', 'obs', 'rewards', 'robot_states', 'states']>
==>> eps: 1
==>> trajectory: <HDF5 group "/data/demo_1" (6 members)>
==>> trajectory: <KeysViewHDF5 ['actions', 'dones', 'obs', 'rewards', 'robot_states', 'states']>
==>> eps: 2
==>> trajectory: <HDF5 group "/data/demo_2" (6 members)>
==>> trajectory: <KeysViewHDF5 ['actions', 'dones', 'obs', 'rewards', 'robot_states', 'states']>
==>> eps: 3
==>> trajectory: <HDF5 group "/data/demo_3" (6 members)>
==>> trajectory: <KeysViewHDF5 ['actions', 'dones', 'obs', 'rewards', 'robot_states', 'states']>
==>> eps: 4
==>> trajectory: <HDF5 group "/data/demo_4" (6 members)>
==>> trajectory: <KeysViewHDF5 ['actions', 'dones', 'obs', 'rewards', 'robot_states', 'states']>
==>> eps: 5
=