In [1]:
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, Dataset, Subset, ConcatDataset
from pathlib import Path
import numpy as np
from tqdm import tqdm
import abc

from scipy.spatial.transform import Rotation as R
from PIL import Image
import torchvision
import json


# create abstract Dataset class called StickDataset
import cv2


from utils.r3D_semantic_dataset import load_depth
from utils.metrics import get_act_mean_std
from utils.traverse_data import iter_dir_for_traj_pths


class BaseStickDataset(Dataset, abc.ABC):
    """Base dataset for a single recorded trajectory directory.

    On construction it reads ``<traj_path>/labels.json``, sorts the image
    names, flattens every label into one vector and then sub-samples the
    sequence in time. Sub-classes are expected to provide ``__getitem__``.

    Args:
        traj_path: Directory holding ``labels.json``. The ``images``,
            ``depths`` and ``confs`` sub-paths are only recorded as
            attributes here; nothing is read from them in this class.
        time_skip: Stride applied when sub-sampling the frames.
        time_offset: Index (after trimming) of the first frame that is kept.
        time_trim: Number of frames removed from the end of the trajectory
            before sub-sampling. ``0`` keeps every frame.

    Attributes:
        img_keys: Sorted, sub-sampled list of image names.
        labels: ``np.ndarray`` with one flattened label per entry of
            ``img_keys`` (see ``flatten_label``).
    """

    def __init__(self, traj_path, time_skip, time_offset, time_trim):
        super().__init__()
        self.traj_path = Path(traj_path)
        self.time_skip = time_skip
        self.time_offset = time_offset
        self.time_trim = time_trim
        self.img_pth = self.traj_path / "images"
        self.depth_pth = self.traj_path / "depths"
        self.conf_pth = self.traj_path / "confs"
        self.labels_pth = self.traj_path / "labels.json"

        # label structure:
        # {image_name: {'xyz': [x, y, z], 'rpy': [r, p, y], 'gripper': gripper}, ...}
        # Use a context manager so the file handle is closed deterministically.
        with self.labels_pth.open("r") as labels_file:
            raw_labels = json.load(labels_file)

        all_keys = sorted(raw_labels.keys())
        all_labels = np.array([self.flatten_label(raw_labels[k]) for k in all_keys])

        # Drop the last `time_trim` frames, then start at `time_offset` and keep
        # every `time_skip`-th frame. A plain `[: -time_trim]` would select
        # nothing for time_trim == 0 (because -0 == 0), so use an open-ended
        # slice in that case.
        trim_stop = -self.time_trim if self.time_trim else None
        self.labels = all_labels[:trim_stop][self.time_offset :: self.time_skip]
        self.img_keys = all_keys[:trim_stop][self.time_offset :: self.time_skip]

    def flatten_label(self, label):
        """Concatenate one label dict into a flat ``[xyz, rpy, gripper]`` array.

        Args:
            label: Mapping with the keys ``"xyz"``, ``"rpy"`` and ``"gripper"``.

        Returns:
            1-D ``np.ndarray``: the ``xyz`` values, then the ``rpy`` values,
            then the single ``gripper`` value.
        """
        xyz = label["xyz"]
        rpy = label["rpy"]
        gripper = label["gripper"]
        return np.concatenate((xyz, rpy, np.array([gripper])))

    def __len__(self):
        """Return the number of frames kept after temporal sub-sampling."""
        return len(self.img_keys)

    def __getitem__(self, idx):
        """Not implemented in the base class; sub-classes must override.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


class StickDataset(BaseStickDataset, abc.ABC):
    """Trajectory dataset that yields windows of relative-motion labels.

    After the base class has loaded and sub-sampled the absolute labels,
    ``reformat_labels`` converts them into per-step deltas and drops steps in
    which nothing moved. Items are windows of ``traj_len`` consecutive labels
    whose start advances by ``traj_skip`` per index.

    Args:
        traj_path: Trajectory directory (forwarded to ``BaseStickDataset``).
        traj_len: Number of consecutive labels in one item.
        time_skip: Forwarded to ``BaseStickDataset``.
        time_offset: Forwarded to ``BaseStickDataset``.
        time_trim: Forwarded to ``BaseStickDataset``.
        traj_skip: Step between the start positions of consecutive windows.
    """

    # A step is dropped only if it stays below all three thresholds.
    # NOTE(review): units are whatever labels.json uses — presumably metres for
    # xyz and radians for the rotation angle; confirm against the recorder.
    MIN_DELTA_XYZ = 0.01
    MIN_DELTA_ROT = 0.008
    MIN_DELTA_GRIPPER = 0.05

    # xyz (3) + rpy (3) + gripper (1); used for the shape of an empty result.
    LABEL_DIM = 7

    def __init__(
        self, traj_path, traj_len, time_skip, time_offset, time_trim, traj_skip
    ):
        super().__init__(traj_path, time_skip, time_offset, time_trim)
        self.traj_len = traj_len
        self.traj_skip = traj_skip
        self.reformat_labels(self.labels)
        self.act_metrics = None

    def set_act_metrics(self, act_metrics):
        """Store normalisation statistics used by ``load_labels``.

        Args:
            act_metrics: Mapping with ``"mean"`` and ``"std"`` entries; each
                must expose ``.numpy()`` (that is how ``load_labels`` reads
                them). ``None`` disables normalisation.
        """
        self.act_metrics = act_metrics

    @staticmethod
    def _pose_matrix(label):
        """Build the 4x4 homogeneous pose from a flat ``[xyz, rpy, ...]`` label."""
        pose = np.eye(4)
        pose[:3, :3] = R.from_euler("xyz", label[3:6], degrees=False).as_matrix()
        pose[:3, 3] = label[:3]
        return pose

    def reformat_labels(self, labels):
        """Replace absolute labels by ``[delta xyz, delta rpy, next gripper]``.

        Each delta is the transform from the last *kept* pose to the next
        frame's pose, expressed in the frame of the last kept pose. A step is
        skipped when translation, rotation angle and gripper change are all
        below their thresholds; the reference pose then stays unchanged so
        small motions accumulate until they exceed a threshold.

        Kept rows are collected in a list instead of being written into a
        zero-initialised array and filtered by ``row.sum() != 0``. That filter
        also discarded valid rows whose components happen to sum to zero
        (tripping the length assertion), failed on empty input, and silently
        truncated deltas when ``labels`` had an integer dtype.

        Args:
            labels: Array of absolute labels, one ``[x, y, z, r, p, y, gripper]``
                row per entry of ``self.img_keys``.

        Side effects:
            Overwrites ``self.labels`` and ``self.img_keys`` with the kept
            deltas and the image keys they belong to. The final frame never
            produces a label because it has no successor.
        """
        kept_labels = []
        kept_img_keys = []

        if len(labels) > 0:
            current_label = labels[0]

        for i in range(len(labels) - 1):
            next_label = labels[i + 1]

            delta_matrix = np.linalg.inv(
                self._pose_matrix(current_label)
            ) @ self._pose_matrix(next_label)
            delta_xyz = delta_matrix[:3, 3]
            delta_r = R.from_matrix(delta_matrix[:3, :3])
            delta_rpy = delta_r.as_euler("xyz", degrees=False)

            del_gripper = next_label[6] - current_label[6]
            xyz_norm = np.linalg.norm(delta_xyz)
            rpy_norm = np.linalg.norm(delta_r.as_rotvec())

            if (
                xyz_norm < self.MIN_DELTA_XYZ
                and rpy_norm < self.MIN_DELTA_ROT
                and abs(del_gripper) < self.MIN_DELTA_GRIPPER
            ):
                # drop this label and its image key: the delta is too small
                # (basically the same image)
                continue

            kept_labels.append(
                np.concatenate((delta_xyz, delta_rpy, np.array([next_label[6]])))
            )
            kept_img_keys.append(self.img_keys[i])
            current_label = next_label

        assert len(kept_labels) == len(kept_img_keys)
        if kept_labels:
            self.labels = np.stack(kept_labels)
        else:
            # Keep a well-formed 2-D array so slicing/normalising still works.
            self.labels = np.zeros((0, self.LABEL_DIM))
        self.img_keys = kept_img_keys

    def load_labels(self, idx):
        """Return the label window for item ``idx``.

        The window covers ``traj_len`` labels starting at ``idx * traj_skip``.
        If action metrics were set, the labels are standardised with them.
        """
        start = idx * self.traj_skip
        labels = self.labels[start : start + self.traj_len]
        # normalize labels
        if self.act_metrics is not None:
            labels = (labels - self.act_metrics["mean"].numpy()) / self.act_metrics[
                "std"
            ].numpy()
        return labels

    def get_img_pths(self, idx):
        """Return the image paths belonging to the label window of item ``idx``."""
        start = idx * self.traj_skip
        img_keys = self.img_keys[start : start + self.traj_len]
        img_pths = [self.img_pth / k for k in img_keys]
        return img_pths

    def __len__(self):
        """Number of complete windows; 0 if the trajectory is shorter than one."""
        # Clamp at zero: a negative value would make len() raise ValueError.
        return max(0, (len(self.img_keys) - self.traj_len) // self.traj_skip + 1)

    def __getitem__(self, idx):
        """Return ``(None, labels)`` for window ``idx``.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than ``len(self)``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError()
        return None, self.load_labels(idx)


class ImageStickDataset(StickDataset):
    """``StickDataset`` that also returns the images of each label window.

    Args:
        traj_path, traj_len, time_skip, time_offset, time_trim, traj_skip:
            Forwarded unchanged to ``StickDataset``.
        img_size: Size handed to ``torchvision.transforms.Resize``.
        pre_load: If true, every kept image is decoded once in ``__init__``
            and held in memory; otherwise images are read on each access.
        transforms: Optional callable applied to the stacked image tensor of
            one window before it is returned.
    """

    def __init__(
        self,
        traj_path,
        traj_len,
        time_skip,
        time_offset,
        time_trim,
        traj_skip,
        img_size,
        pre_load=False,
        transforms=None,
    ):
        super().__init__(
            traj_path, traj_len, time_skip, time_offset, time_trim, traj_skip
        )
        self.img_size = img_size
        self.pre_load = pre_load
        self.transforms = transforms
        self.preprocess_img_transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(self.img_size),
                torchvision.transforms.ToTensor(),
            ]
        )
        if self.pre_load:
            self.imgs = self.load_imgs()

    def _load_img(self, img_path):
        """Open one image file and return it resized and converted to a tensor."""
        # The transforms fully decode the image, so the file can be closed as
        # soon as they return.
        with Image.open(str(img_path)) as img:
            return self.preprocess_img_transforms(img)

    def load_imgs(self):
        """Decode every kept image of the trajectory into one stacked tensor.

        Returns:
            Tensor with the images stacked along a new leading axis, in the
            order of ``self.img_keys``. An empty tensor is returned when the
            trajectory has no kept frames (``torch.stack`` rejects an empty
            list); such a dataset has length 0, so it is never indexed.
        """
        imgs = [self._load_img(self.img_pth / key) for key in tqdm(self.img_keys)]
        if not imgs:
            return torch.empty(0)
        # add a new axis at the beginning
        return torch.stack(imgs, dim=0)

    def __getitem__(self, idx):
        """Return ``(imgs, labels)`` for window ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range (raised by the parent class).
        """
        _, labels = super().__getitem__(idx)

        if self.pre_load:
            start = idx * self.traj_skip
            imgs = self.imgs[start : start + self.traj_len]
        else:
            # Reuse the parent's window logic instead of re-slicing img_keys.
            # add a new axis at the beginning
            imgs = torch.stack(
                [self._load_img(pth) for pth in self.get_img_pths(idx)], dim=0
            )

        if self.transforms:
            imgs = self.transforms(imgs)

        return imgs, labels


In [2]:
def get_image_stick_dataset(
    data_path,
    traj_len=1,
    traj_skip=1,
    time_skip=4,
    time_offset=5,
    time_trim=5,
    img_size=224,
    pre_load=True,
    apply_transforms=True,
    val_mask=None,
    mask_texts=None,
    cfg=None,
):
    """Build train/val/test ``ImageStickDataset`` concatenations for a data root.

    Trajectory directories are discovered and split by
    ``iter_dir_for_traj_pths(data_path, val_mask, mask_texts)``. Every train
    trajectory is instantiated twice, once with ``time_offset`` and once with
    ``time_offset + 2``; val and test trajectories use ``time_offset`` only.

    Args:
        data_path: Root directory to search, as ``str`` or ``Path``.
        traj_len, traj_skip, time_skip, time_offset, time_trim, img_size,
            pre_load: Forwarded to every ``ImageStickDataset``.
        apply_transforms: If true, image windows are normalised with the
            mean/std constants below.
        val_mask: Forwarded to ``iter_dir_for_traj_pths``.
        mask_texts: Forwarded to ``iter_dir_for_traj_pths``.
        cfg: Unused; kept so existing callers keep working.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``. Each element is a
        ``ConcatDataset``, or ``None`` when the split has no trajectories.
    """
    if isinstance(data_path, str):
        data_path = Path(data_path)

    # add transforms for normalization of the already-float image tensors
    if apply_transforms:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                )
            ]
        )
    else:
        transforms = None

    train_traj_paths, val_traj_paths, test_traj_paths = iter_dir_for_traj_pths(
        data_path, val_mask, mask_texts
    )

    def build_split(traj_paths, offsets):
        """Concatenate one dataset per (trajectory, offset) pair, or return None."""
        datasets = [
            ImageStickDataset(
                traj_path,
                traj_len,
                time_skip,
                offset,
                time_trim,
                traj_skip,
                img_size,
                pre_load=pre_load,
                transforms=transforms,
            )
            for traj_path, offset in itertools.product(traj_paths, offsets)
        ]
        # ConcatDataset refuses an empty list, so report an empty split as None.
        return ConcatDataset(datasets) if datasets else None

    # Two temporal offsets per train trajectory give two differently-phased
    # sub-samplings of the same recording.
    train_dataset = build_split(train_traj_paths, [time_offset, time_offset + 2])
    val_dataset = build_split(val_traj_paths, [time_offset])
    test_dataset = build_split(test_traj_paths, [time_offset])

    return train_dataset, val_dataset, test_dataset

In [3]:
# NOTE(review): this relative `%cd` is not idempotent. The recorded output shows
# it failing with "[Errno 2] No such file or directory" while the working
# directory was already .../hello-robot-hack, i.e. the cell had been run before.
%cd hello-robot-hack

[Errno 2] No such file or directory: 'hello-robot-hack'
/scratch/ar7420/VINN/hello-robot-hack


In [16]:
# NOTE(review): one-off repository bootstrap, unrelated to the dataset code above.
# `>>` appends, so each re-run adds another "# hello-robot-hack" line to README.md.
# In the recorded run the push stopped at the GitHub username prompt (^C).
!echo "# hello-robot-hack" >> README.md
!git init
!git add README.md
!git commit -m "first commit"
!git branch -M main
!git remote add origin https://github.com/RaiAnant/hello-robot-hack.git
!git push -u origin main

Initialized empty Git repository in /scratch/ar7420/VINN/hello-robot-hack/.git/
[master (root-commit) 25afa31] first commit
 1 file changed, 1 insertion(+)
 create mode 100644 README.md
Username for 'https://github.com': ^C


In [4]:
# Mask passed to get_image_stick_dataset(val_mask=...). Presumably the 'traj'
# entries name the held-out validation trajectories — confirm against
# iter_dir_for_traj_pths.
val_mask = {
    "home": [],
    "env": [],
    "traj": [
        "2023-04-11--23-20-07_0",
        "2023-04-11--23-20-22_0",
    ],
}

In [5]:
# Build the three splits for the Env1 door-opening recordings at 224x224.
train, val, test = get_image_stick_dataset(
    data_path="/vast/ar7420/iphone_data/Benchmarking_Export/Door_Opening/CDS_Home/Env1",
    img_size=[224, 224],
    val_mask=val_mask,
)

Iterating through:  /vast/ar7420/iphone_data/Benchmarking_Export/Door_Opening/CDS_Home/Env1
Total number of trajectories:  24
total number of train trajectories:  24
total number of test trajectories:  0
Total number of val trajectories:  2


100%|██████████| 34/34 [00:00<00:00, 127.04it/s]
100%|██████████| 33/33 [00:00<00:00, 141.12it/s]
100%|██████████| 33/33 [00:00<00:00, 142.83it/s]
100%|██████████| 33/33 [00:00<00:00, 143.55it/s]
100%|██████████| 33/33 [00:00<00:00, 140.01it/s]
100%|██████████| 32/32 [00:00<00:00, 136.56it/s]
100%|██████████| 37/37 [00:00<00:00, 139.20it/s]
100%|██████████| 37/37 [00:00<00:00, 142.64it/s]
100%|██████████| 37/37 [00:00<00:00, 142.13it/s]
100%|██████████| 35/35 [00:00<00:00, 142.70it/s]
100%|██████████| 35/35 [00:00<00:00, 142.07it/s]
100%|██████████| 36/36 [00:00<00:00, 143.56it/s]
100%|██████████| 35/35 [00:00<00:00, 140.29it/s]
100%|██████████| 35/35 [00:00<00:00, 144.38it/s]
100%|██████████| 36/36 [00:00<00:00, 142.54it/s]
100%|██████████| 35/35 [00:00<00:00, 144.28it/s]
100%|██████████| 42/42 [00:00<00:00, 140.20it/s]
100%|██████████| 40/40 [00:00<00:00, 139.24it/s]
100%|██████████| 33/33 [00:00<00:00, 140.94it/s]
100%|██████████| 34/34 [00:00<00:00, 142.33it/s]
100%|██████████| 35/