In [1]:
!nvidia-smi

Wed Jul 26 18:44:41 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   71C    P8     9W /  N/A |    444MiB /  6069MiB |     34%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Libs

In [2]:
# reload sccripts on change
import IPython
%load_ext autoreload

%autoreload 2

In [3]:
import logging
logging.basicConfig(level="INFO")
import os 
import os.path as osp

import pandas as pd

import sys

import torch

from torch.utils.data import Dataset, DataLoader

In [4]:
# import scripts
sys.path.append("../src/")

In [5]:
from config import Config as cfg
from jepsam_tokenizer import SimpleTokenizer
import vocabulary as vocab

In [None]:
os.listdir(cfg.DATASET["PATH"])

['vocab.txt', '.ipynb_checkpoints', 'vocab_C.txt', 'v1', 'v2', 'vocab_W.txt']

## Dataset class

In [11]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np
from PIL import Image

from torch.nn.utils.rnn import pad_sequence


class SimpleJEPSAMDataset(Dataset):
    def __init__(
        self,
        csv:str = None,
        df: pd.DataFrame = None,
        apply_transforms: bool = True,
        task: str = "train"    
    ):
        super().__init__()

        self.vocab = vocab
        self.dataset_directory = osp.join(cfg.DATASET["PATH"], "v1/JEPS_data")
        self.tokenizer = SimpleTokenizer(vocab)
        self.cfg = cfg
        self.TOKENS_MAPPING, self.REVERSE_TOKENS_MAPPING  = self.vocab.load_vocab() 
        
        if df is not None:
            self.dataset_points = df.copy()
        else:
            self.dataset_points = pd.read_csv(csv)
        
        # image transforms
        self.apply_transforms = apply_transforms
        if task == "train":
            tfms = [
                getattr(A, tfms)(**params) for tfms, params in self.cfg.DATASET["TRAIN_TFMS"].items()
            ]
        else:
            tfms = [
                getattr(A, tfms)(**params) for tfms, params in self.cfg.DATASET["TEST_TFMS"].items()
            ]

        tfms.append(A.Normalize())
        tfms.append(ToTensorV2())
        self.transforms = A.Compose(tfms) if self.apply_transforms else None

    def __len__(self):
        return len(self.dataset_points)

    def __getitem__(self, idx):
        
        data_point = self.dataset_points.iloc[idx]
        
        # visual inputs
        ## in state
        in_state = np.array(Image.open(os.path.join(
            self.dataset_directory, str(data_point.sample_ID), str(data_point.in_state)
        )))

        goal_state = np.array(Image.open(os.path.join(
            self.dataset_directory, str(data_point.sample_ID), str(data_point.goal_state)
        )))
        # apply image treansforms
        if self.apply_transforms:
            # apply transforms
            in_state = self.transforms(image=in_state)["image"]
            goal_state = self.transforms(image=goal_state)["image"]        
        
        # Language modalities
        ## action desc
        action_description = self.tokenizer.encode(input=data_point.action_description)

        ## motor cmd
        motor_command = self.tokenizer.encode(input=data_point.motor_cmd)
        
        
        sample = {
            "sample_id": data_point.sample_ID,
            "in_state": in_state.float(),
            "goal_state": goal_state.float(),
            "action_desc": {
                "raw"   : data_point.action_description,
                "ids"   : action_description.long(),
                "length": data_point.len_action_desc
            },
            "motor_cmd": {
                "raw"   : data_point.motor_cmd,
                "ids"   : motor_command.long(),
                "length": data_point.len_motor_cmd
            }
        }

        return sample

    def collate_fn(self, batch):
        
        # imgs
        batch_input_state = [b["in_state"] for b in batch]
        batch_input_state_stack = torch.stack(batch_input_state)
        
        batch_goal_state = [b["goal_state"] for b in batch]
        batch_goal_state_stack = torch.stack(batch_goal_state)

        # ad
        batch_action_desc = [b["action_desc"]["ids"] for b in batch]
        # print(batch_action_desc)
        batch_action_description = pad_sequence(
            batch_action_desc, 
            batch_first=True, 
            padding_value=self.tokenizer.TOKENS_MAPPING["[PAD]"]
        ).unsqueeze(1)
        # print(batch_action_description)
        
        batch_action_desc_lens = torch.as_tensor([b["action_desc"]["length"] for b in batch])
        # batch_action_desc_lens_stack = torch.tensor(batch_action_desc_lens)
        # print(batch_action_desc_lens_stack)
        
        #cmd
        batch_motor_commands = [b["motor_cmd"]["ids"] for b in batch]
        batch_motor_commands = pad_sequence(
            batch_motor_commands, 
            batch_first=True, 
            padding_value=self.tokenizer.TOKENS_MAPPING["[PAD]"]
        ).unsqueeze(1)
        batch_motor_commands_lens = torch.as_tensor([b["motor_cmd"]["length"] for b in batch])
        
        out = (
            batch_input_state_stack, 
            batch_goal_state_stack, 
            batch_action_description, 
            batch_motor_commands, 
            batch_action_desc_lens, 
            batch_motor_commands_lens
        )
        return out

In [12]:
def get_dataloaders(
        train_df: pd.DataFrame, 
        val_df: pd.DataFrame = None,
        dataset_module:Dataset=SimpleJEPSAMDataset,
        ) -> tuple:

    val_pct = cfg.DATASET["VALIDATION_PCT"]

    if val_df is None:
        # 80/20 train/test split
        random_indices = np.random.rand(len(train_df)) < (1-val_pct)

        # Split the DataFrame into train and test sets
        train = train_df[random_indices]
        val = train_df[~random_indices]

    # datasets
    train_ds        =   dataset_module(df=train)
    val_ds          =   dataset_module(df=val)

    logging.info(
        f"Prepared {len(train_ds)} training samples and {len(val_ds)} validation samples ")
    # data loaders
    train_dl = DataLoader(
        dataset=train_ds,
        batch_size=cfg.TRAIN["TRAIN_BATCH_SIZE"],
        shuffle=True,
        num_workers=cfg.TRAIN["NUM_WORKERS"],
        pin_memory=True,
        collate_fn=train_ds.collate_fn if dataset_module == SimpleJEPSAMDataset else None,
        drop_last=True
    )

    val_dl = DataLoader(
        dataset=val_ds,
        batch_size=cfg.TRAIN["TEST_BATCH_SIZE"],
        shuffle=False,
        num_workers=cfg.TRAIN["NUM_WORKERS"],
        pin_memory=True,
        collate_fn=val_ds.collate_fn if dataset_module == SimpleJEPSAMDataset else None,
        drop_last=True

    )

    return (train_dl, val_dl)

In [13]:
tdf = pd.read_csv(
    osp.join(cfg.DATASET['PATH'], "v1/updated_train.csv")
)

tdf.head()

Unnamed: 0,sample_ID,in_state,goal_state,validator,action_description,motor_cmd,len_action_desc,len_motor_cmd
0,1005,0,9,amihretu,put the :BOTTLE to the left of :BOTTLE,:BOTTLE BLUE POSE-9 :BOTTLE RED POSE-2 :BOTTLE...,8,11
1,1011,0,9,amihretu,move the :BOTTLE left,:BOTTLE BLUE POSE-3 :BOTTLE #'*leftward-trans...,4,8
2,1012,0,9,amihretu,put the :BOTTLE to the right of :MUG,:BOTTLE BLUE POSE-7 :MUG RED POSE-3 :BOTTLE #...,8,11
3,1013,0,9,amihretu,shift the :CUP backwards,:CUP RED POSE-4 :CUP #'*backward-transformati...,4,8
4,1015,0,9,amihretu,shift the :BOTTLE forwards,:BOTTLE GREEN POSE-3 :BOTTLE #'*forward-trans...,4,8


In [16]:
train_dl, val_dl = get_dataloaders(
    train_df=tdf
)

INFO:root:Prepared 1442 training samples and 528 validation samples 


In [17]:
try:
    logging.info("\n>> train data loader")
    for data in train_dl:
        s_id, in_state, goal_state, ad, cmd = data['sample_id'], data[
            'in_state'], data['goal_state'], data['action_desc'], data["motor_cmd"]
        print("In\t\t\t:", in_state.shape)
        print("Goal\t\t\t:", goal_state.shape)
        print("Action desc\t\t:", ad["ids"].shape)
        print("Action desc (len)\t:", ad["length"].shape)

        print("CMD\t\t\t:", cmd["ids"].shape)
        print("CMD(len)\t\t:", cmd["length"].shape)
        break

    logging.info("\n\n>> val data loader")
    for data in val_dl:
        s_id, in_state, goal_state, ad, cmd = data['sample_id'], data[
            'in_state'], data['goal_state'], data['action_desc'], data["motor_cmd"]
        print("In\t\t:", in_state.shape)
        print("Goal\t\t:", goal_state.shape)
        print("Action desc\t:", ad["ids"].shape)
        print("Action desc (len)\t:", ad["length"].shape)

        print("CMD\t\t:", cmd["ids"].shape)
        print("CMD(len)\t\t:", cmd["length"].shape)
        break
except Exception as e:
    logging.error(f"\n\t>> {e}"+"...\n\t>> Looks like SimpleJEPSAMDataset is being used")
    logging.info("\n>> train data loader")
    print(f"# train batches\t\t: {len(train_dl)}")
    for data in train_dl:
        in_state, goal_state, ad, cmd, ad_lens, cmd_lens = data[0], data[1], data[2], data[3], data[4], data[5]
        print("In\t\t\t:", in_state.shape)
        print("Goal\t\t\t:", goal_state.shape)
        print("Action desc\t\t:", ad.shape)
        print("Action desc (len)\t:", ad_lens.shape)
        print("CMD\t\t\t:", cmd.shape)
        print("CMD(len)\t\t:", cmd_lens.shape)
        print()
        break

INFO:root:
>> train data loader
ERROR:root:
	>> list indices must be integers or slices, not str...
	>> Looks like SimpleJEPSAMDataset is being used
INFO:root:
>> train data loader


# train batches		: 90
In			: torch.Size([16, 3, 128, 128])
Goal			: torch.Size([16, 3, 128, 128])
Action desc		: torch.Size([16, 1, 10])
Action desc (len)	: torch.Size([16])
CMD			: torch.Size([16, 1, 11])
CMD(len)		: torch.Size([16])

