In [1]:
# Show this machine's GPU, driver and CUDA version status (IPython shell escape).
!nvidia-smi

Fri Jun 23 10:38:51 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   63C    P8     9W /  N/A |    475MiB /  6069MiB |     15%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [178]:
from addict import Dict

import pandas as pd

import os.path as osp

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from torchinfo import summary

import torchvision.models as torchvision_models

import sys


import IPython
%load_ext autoreload

%autoreload 2


sys.path.append("../src")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [97]:
import models
from utils.parser import load_config, parse_args

from dataloader import get_dataloaders

In [98]:
# Load the experiment configuration (model hyper-parameters under cfg.AEJEPS,
# dataset settings under cfg.DATASET). The path is relative to the notebook's
# working directory.
cfg = load_config(config_file_path="../configs/aejeps_cfg.yaml")

In [99]:
# Inspect the AEJEPS model hyper-parameters.
cfg.AEJEPS

{'EMBEDDING_DIM': 256,
 'HIDDEN_DIM': 256,
 'NUM_LAYERS_ENCODER': 2,
 'BATCH_FIRST': True,
 'ENCODER_DROPOUT': 0.35,
 'IS_BIDIRECTIONAL': True,
 'NUM_LAYERS_MOTOR': 1,
 'ACTIVATION_MOTOR': 'Gelu',
 'MOTOR_DROPOUT': 0.2,
 'NUM_LAYERS_LANG': 2,
 'ACTIVATION_LANG': 'Gelu',
 'LANG_DROPOUT': 0.2}

### Get a batch of data

In [101]:
# Dataset root directory; the training CSV below is read relative to it.
cfg.DATASET.PATH

'../../dataset/'

In [102]:
# Training-split metadata: one row per sample (ids, image states, action
# description, motor command and their lengths).
tdf = pd.read_csv(filepath_or_buffer=osp.join(cfg.DATASET.PATH, "updated_train.csv"))
tdf.head()

Unnamed: 0,sample_ID,in_state,goal_state,validator,action_description,motor_cmd,len_action_desc,len_motor_cmd
0,1005,0,9,amihretu,put the :BOTTLE to the left of :BOTTLE,:BOTTLE BLUE POSE-9 :BOTTLE RED POSE-2 :BOTTLE...,38,84
1,1011,0,9,amihretu,move the :BOTTLE left,:BOTTLE BLUE POSE-3 :BOTTLE #'*leftward-trans...,21,65
2,1012,0,9,amihretu,put the :BOTTLE to the right of :MUG,:BOTTLE BLUE POSE-7 :MUG RED POSE-3 :BOTTLE #...,36,79
3,1013,0,9,amihretu,shift the :CUP backwards,:CUP RED POSE-4 :CUP #'*backward-transformati...,24,55
4,1015,0,9,amihretu,shift the :BOTTLE forwards,:BOTTLE GREEN POSE-3 :BOTTLE #'*forward-trans...,26,65


In [103]:
# get_dataloaders returns two loaders: keep the first one as the training
# loader and discard the second (presumably the validation loader -- confirm).
train_dl, _ = get_dataloaders(
    train_df=tdf,
    cfg=cfg
)


INFO:root:Prepared 1463 training samples and 509 validation samples 


In [193]:
# Grab a single batch from the training loader; it drives the shape checks
# and the encoder smoke test below.
data = next(iter(train_dl))

In [194]:
# Token ids keep a singleton axis: (batch, 1, max_len).
data["action_desc"]["ids"].size()

torch.Size([4, 1, 64])

## Model refactoring

### Encoder

#### CNN backbone builder

In [195]:
def get_cnn_backbone(
    cfg: Dict,
    backbone_name: str = "resnet50",
    freeze: bool = True,
    fc_out: int = None,
):
    """Build a torchvision CNN to be used as an image feature extractor.

    Parameters
    ----------
    cfg : Dict
        Experiment config; ``cfg.MODEL.CNN_BACKBONES[backbone_name]`` is passed
        as the ``weights`` argument of the torchvision constructor.
    backbone_name : str, default "resnet50"
        Name of the model constructor in ``torchvision.models``.
    freeze : bool, default True
        If True, every parameter of the backbone gets ``requires_grad = False``.
        A head created through ``fc_out`` is added *after* freezing and
        therefore stays trainable.
    fc_out : int, optional
        If given, the backbone's linear ``fc`` head is replaced by a fresh
        ``nn.Linear(fc.in_features, fc_out)`` so the backbone outputs
        ``fc_out`` features per image.

    Returns
    -------
    nn.Module
        The (optionally frozen and re-headed) backbone.

    Raises
    ------
    NotImplementedError
        If ``fc_out`` is given but the backbone has no linear ``fc`` head.
    """
    weights = cfg.MODEL.CNN_BACKBONES[backbone_name]
    backbone = getattr(torchvision_models, backbone_name)(weights=weights)

    # freeze backbone if specified
    if freeze:
        for param in backbone.parameters():
            param.requires_grad = False

    if fc_out is not None:
        # Detect the head structurally rather than by name. ResNeXt
        # ("resnext50_32x4d" does not contain "resnet") is handled as well, and
        # an unsupported backbone fails loudly instead of silently keeping its
        # original output width.
        head = getattr(backbone, "fc", None)
        if not isinstance(head, nn.Linear):
            raise NotImplementedError(
                f"Cannot replace the head of '{backbone_name}': only backbones "
                "with a linear `fc` layer (e.g. ResNet / ResNeXt) are supported."
            )
        backbone.fc = nn.Linear(in_features=head.in_features, out_features=fc_out)

    return backbone

In [196]:
# b = get_cnn_backbone(cfg=cfg, fc_out=512)
# b

In [197]:
# xx = torch.randn((1, 3, 224, 224))

# xx.shape

In [198]:
# 


#### JEPSAM encoder

In [199]:
class JEPSAMEncoder(nn.Module):
    """Multimodal encoder of the AEJEPS autoencoder.

    At every time step of the padded token sequences, the encoder concatenates
    the following along the feature axis:

    * the CNN features of the initial image ``inp["in_state"]``,
    * the CNN features of the goal image ``inp["goal_state"]``,
    * the embedding of the action-description token,
    * the embedding of the motor-command token.

    It then mixes the resulting sequence with an LSTM (bidirectional if
    configured).

    Parameters
    ----------
    cfg : Dict
        Experiment config. The encoder reads ``cfg.DATASET.VOCABULARY_SIZE``
        and ``cfg.AEJEPS.{EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS_ENCODER,
        ENCODER_DROPOUT, IS_BIDIRECTIONAL}``. ``get_cnn_backbone`` reads
        ``cfg.MODEL``.
    cnn_backbone_name : str, default "resnet50"
        torchvision model name used as the image feature extractor.
    cnn_fc_out : int, default 512
        Width of the image feature vector (output of the backbone's new head).
    """

    def __init__(
        self,
        cfg: Dict,
        cnn_backbone_name: str = "resnet50",
        cnn_fc_out: int = 512,
    ):
        super().__init__()

        # Shared token embedding for action descriptions and motor commands.
        self.embedding = nn.Embedding(cfg.DATASET.VOCABULARY_SIZE, cfg.AEJEPS.EMBEDDING_DIM)

        self.image_feature_extractor = get_cnn_backbone(
            cfg=cfg,
            backbone_name=cnn_backbone_name,
            fc_out=cnn_fc_out,
        )

        # Per-step LSTM input is
        # [in_state feats | goal feats | action-desc emb | motor-cmd emb].
        # The image width is read from the head itself, so it is also correct
        # when `cnn_fc_out` is None (original head kept).
        image_feat_dim = self.image_feature_extractor.fc.out_features
        encoder_input_dim = 2 * image_feat_dim + 2 * cfg.AEJEPS.EMBEDDING_DIM

        self.feature_mixing = nn.LSTM(
            input_size=encoder_input_dim,
            hidden_size=cfg.AEJEPS.HIDDEN_DIM,
            num_layers=cfg.AEJEPS.NUM_LAYERS_ENCODER,
            dropout=cfg.AEJEPS.ENCODER_DROPOUT,
            bidirectional=cfg.AEJEPS.IS_BIDIRECTIONAL,
        )

    def forward(self, inp: dict, mode: str = "train"):
        """Encode one batch.

        Parameters
        ----------
        inp : dict
            Batch with keys ``"in_state"`` and ``"goal_state"`` (images fed to
            the CNN), plus ``"action_desc"`` and ``"motor_cmd"``. The latter two
            are each a dict holding token ``"ids"`` of shape (B, L) or
            (B, 1, L) and per-sample ``"length"``. Both id tensors are assumed
            to share the same padded length L. ``"goal_state"`` and
            ``"motor_cmd"`` are only read when ``mode == "train"``.
        mode : str, default "train"
            In any other mode, the goal-image features and the motor-command
            embeddings are replaced by zeros. Packing then uses only the
            action-description lengths.

        Returns
        -------
        output : Tensor
            Padded, batch-first LSTM outputs of shape
            (B, T, num_directions * HIDDEN_DIM), where T is the longest packed
            length in the batch.
        len_output : Tensor
            Valid length of each sequence in ``output``, shape (B,).
        hidden, carousel : Tensor
            Final LSTM hidden and cell states, each of shape
            (num_layers * num_directions, B, HIDDEN_DIM).
        """
        is_train = mode == "train"

        action_ids = self._flatten_ids(inp["action_desc"]["ids"])
        max_len = action_ids.shape[1]

        # 1. Image features, broadcast to every time step: (B, F) -> (B, L, F).
        feats_per = self._per_step(self.image_feature_extractor(inp["in_state"]), max_len)
        if is_train:
            feats_goal = self._per_step(self.image_feature_extractor(inp["goal_state"]), max_len)
        else:
            # NOTE(review): zero-filling the absent modality at inference keeps
            # the LSTM input width fixed; confirm this is the intended behaviour.
            feats_goal = torch.zeros_like(feats_per)

        # 2. Token embeddings: (B, L) -> (B, L, E).
        action_desc_emb = self.embedding(action_ids)
        action_len = self._as_lengths(inp["action_desc"]["length"])
        if is_train:
            motor_cmd_emb = self.embedding(self._flatten_ids(inp["motor_cmd"]["ids"]))
            # Pack each sample up to the longer of its two token sequences.
            lengths = torch.maximum(action_len, self._as_lengths(inp["motor_cmd"]["length"]))
        else:
            motor_cmd_emb = torch.zeros_like(action_desc_emb)
            lengths = action_len

        # 3. Feature fusion along the *feature* axis -> (B, L, 2F + 2E).
        fused = torch.cat((feats_per, feats_goal, action_desc_emb, motor_cmd_emb), dim=-1)

        # 4. Feature mixing over the variable-length (packed) sequences.
        packed_input = pack_padded_sequence(
            input=fused,
            lengths=lengths,
            enforce_sorted=False,
            batch_first=True,
        )
        packed_output, (hidden, carousel) = self.feature_mixing(packed_input)
        output, len_output = pad_packed_sequence(packed_output, batch_first=True)

        return output, len_output, hidden, carousel

    @staticmethod
    def _flatten_ids(ids: torch.Tensor) -> torch.Tensor:
        """Collapse token ids of shape (B, L) or (B, 1, L) to (B, L)."""
        return ids.reshape(ids.shape[0], -1)

    @staticmethod
    def _per_step(feats: torch.Tensor, steps: int) -> torch.Tensor:
        """Repeat per-sample features (B, F) over ``steps`` time steps -> (B, steps, F)."""
        return feats.unsqueeze(1).expand(-1, steps, -1)

    @staticmethod
    def _as_lengths(lengths) -> torch.Tensor:
        """Return sequence lengths as the 1-D int64 CPU tensor ``pack_padded_sequence`` expects."""
        return torch.as_tensor(lengths).reshape(-1).to(device="cpu", dtype=torch.int64)

### Encoder summary

In [None]:
encoder = JEPSAMEncoder(
    cfg=cfg,
    cnn_backbone_name="resnet18",
    cnn_fc_out=cfg.AEJEPS.HIDDEN_DIM,
)

# Smoke test: one forward pass on the sample batch.
o, lo, h, c = encoder(data)

print(o.shape)

# Model summary. torchinfo forwards unknown keyword arguments to
# `model.forward`, so the batch must be passed via `input_data`. Because
# `forward` takes the batch dict as its single argument, the dict is wrapped
# in a list.
summary(
    model=encoder,
    input_data=[data],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
)

### Decoder

In [None]:
class Decoder(nn.Module):
    """Decoder half of the AEJEPS autoencoder (work in progress).

    Builds the layers intended to reconstruct each modality from the encoder
    representation:

    * ``lang_decoder`` + ``lang_head``: action-description tokens,
    * ``motor_decoder`` + ``motor_cmd_head``: motor-command tokens,
    * ``hidden_to_conv_in`` + ``hidden2img``: the image.

    The decoding logic itself is not implemented yet. ``forward`` and the
    ``_decode_*`` helpers raise ``NotImplementedError``.

    Parameters
    ----------
    cfg : Dict
        Experiment config. The decoder reads ``cfg.AEJEPS.{EMBEDDING_DIM,
        HIDDEN_DIM, IS_BIDIRECTIONAL, NUM_LAYERS_MOTOR, NUM_LAYERS_LANG}`` and
        ``cfg.DATASET.VOCABULARY_SIZE``.
    embedding_dim : int, optional
        Input width of both recurrent decoders. Defaults to
        ``cfg.AEJEPS.EMBEDDING_DIM``.
    image_size : int, default 224
        Target image side length. NOTE(review): not used yet -- the
        transposed-conv stack is fixed.
    """

    def __init__(self, cfg: Dict, embedding_dim: int = None, image_size: int = 224):
        super().__init__()

        if embedding_dim is None:
            embedding_dim = cfg.AEJEPS.EMBEDDING_DIM
        vocabulary_size = cfg.DATASET.VOCABULARY_SIZE

        self.num_directions = 2 if cfg.AEJEPS.IS_BIDIRECTIONAL else 1
        # Same width as the encoder LSTM's per-step output (directions
        # concatenated). This is presumably so the decoder state can be seeded
        # from the encoder.
        decoder_hidden_dim = self.num_directions * cfg.AEJEPS.HIDDEN_DIM

        # nn.LSTMCell is always single-layer (it has no `num_layers` argument),
        # so the configured depth is obtained by stacking cells.
        # NOTE(review): motor commands are assumed to be fed as token
        # embeddings, as JEPSAMEncoder embeds them with the shared vocabulary.
        # motor command decoding layers
        self.motor_decoder = self._stacked_lstm_cells(
            embedding_dim, decoder_hidden_dim, cfg.AEJEPS.NUM_LAYERS_MOTOR
        )
        # action desc. decoding layers
        self.lang_decoder = self._stacked_lstm_cells(
            embedding_dim, decoder_hidden_dim, cfg.AEJEPS.NUM_LAYERS_LANG
        )

        # projection layers
        # 1024 = 1 x 32 x 32; presumably reshaped into the 1-channel map that
        # `hidden2img` expects (its first transposed conv takes 1 channel).
        self.hidden_to_conv_in = nn.Linear(decoder_hidden_dim, 1024)
        self.lang_head = nn.Linear(
            in_features=decoder_hidden_dim,
            out_features=vocabulary_size,  # To be discussed
        )
        self.motor_cmd_head = nn.Linear(
            in_features=decoder_hidden_dim,
            out_features=vocabulary_size,  # To be discussed
        )

        self.hidden2img = self.__get_transposed_convs(
            decoder_hidden_dim,
            image_size,
        )

    def forward(self, *encoder_outputs):
        """Decode all modalities from the encoder outputs (not implemented yet)."""
        raise NotImplementedError("Decoder.forward is not implemented yet.")

    def _decode_action_description(self):
        """Decode action-description tokens (not implemented yet)."""
        raise NotImplementedError

    def _decode_image(self):
        """Decode the image (not implemented yet)."""
        raise NotImplementedError

    def _decode_motor_command(self):
        """Decode motor-command tokens (not implemented yet)."""
        raise NotImplementedError

    @staticmethod
    def _stacked_lstm_cells(input_size: int, hidden_size: int, num_layers: int) -> nn.ModuleList:
        """Build ``num_layers`` LSTM cells.

        Layer 0 consumes ``input_size`` features; every following layer
        consumes the previous layer's ``hidden_size`` hidden state.
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        return nn.ModuleList(
            nn.LSTMCell(
                input_size=input_size if layer == 0 else hidden_size,
                hidden_size=hidden_size,
            )
            for layer in range(num_layers)
        )

    def __get_transposed_convs(self, decoder_hidden_dim, image_size):
        """Build the transposed-conv stack turning a 1-channel map into a 3-channel image.

        Both arguments are currently unused. By the ConvTranspose2d output-size
        formula, a (N, 1, 32, 32) input becomes (N, 3, 224, 224)
        (32 -> 59 -> 115 -> 228 -> 224).
        """
        tconv1 = nn.ConvTranspose2d(1, 4, 3, 2, 3, 0)
        tconv2 = nn.ConvTranspose2d(4, 8, 5, 2, 3, 0)
        tconv3 = nn.ConvTranspose2d(8, 16, 7, 2, 4, 1)
        tconv4 = nn.ConvTranspose2d(16, 3, 11, 1, 7, 0)

        return nn.Sequential(tconv1, tconv2, tconv3, tconv4)


# Alias matching the JEPSAMEncoder naming; AutoencoderJEPS refers to the
# decoder by this name.
JEPSAMDecoder = Decoder


### AEJEPS

In [None]:

class AutoencoderJEPS(nn.Module):
    """
    Autoencoder-based deep-learning implementation of a Joint Episodic,
    Procedural, and Semantic Memory (JEPS).

    Parameters
    ----------
    cfg : Dict
        Experiment configuration, passed to both the encoder and the decoder.
    **encoder_kwargs
        Extra keyword arguments forwarded unchanged to ``JEPSAMEncoder``
        (e.g. ``cnn_backbone_name``, ``cnn_fc_out``).
    """

    def __init__(self, cfg: Dict, **encoder_kwargs):
        super().__init__()
        # Both halves are configured from the same experiment config.
        self.encoder = JEPSAMEncoder(cfg=cfg, **encoder_kwargs)
        self.decoder = JEPSAMDecoder(cfg=cfg)

    def forward(self, inp, mode: str = "train"):
        """Encode ``inp`` and decode the result.

        ``mode`` is forwarded to the encoder. The decoder receives the
        encoder's return value unchanged, as a single positional argument.
        """
        enc = self.encoder(inp, mode=mode)
        # reconstructed_image, decoded_action_desc, decoded_cmd = self.decoder(enc)
        return self.decoder(enc)

