In [3]:
!nvidia-smi

Thu Jun 22 06:48:58 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   67C    P0    25W /  N/A |    531MiB /  6069MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [58]:
from addict import Dict

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

from torchinfo import summary

import torchvision.models as torchvision_models

import sys


import IPython
%load_ext autoreload

%autoreload 2


sys.path.append("../src")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import models
from utils.parser import load_config

In [67]:
cfg = load_config(config_file_path="../configs/aejeps_cfg.yaml")

In [68]:
cfg.AEJEPS

{'EMBEDDING_DIM': 128,
 'HIDDEN_DIM': 512,
 'NUM_LAYERS_ENCODER': 2,
 'BATCH_FIRST': True,
 'ENCODER_DROPOUT': 0.35,
 'IS_BIDIRECTIONAL': True,
 'NUM_LAYERS_MOTOR': 1,
 'ACTIVATION_MOTOR': 'Gelu',
 'MOTOR_DROPOUT': 0.2,
 'NUM_LAYERS_LANG': 2,
 'ACTIVATION_LANG': 'Gelu',
 'LANG_DROPOUT': 0.2}

## Model refectoring

### Encoder

#### CNN backbone builder

In [42]:
def get_cnn_backbone(
    cfg:Dict, 
    backbone_name:str="resnet50", 
    freeze:bool=True,
    fc_out:int=None
):
    backbone = getattr(torchvision_models, backbone_name)(weights=cfg.MODEL.CNN_BACKBONES[backbone_name])
    
    # freeze backbone if specified
    if freeze:
        for param in backbone.parameters():
            param.requires_grad = False
    
    if fc_out is not None:
        # resnet-based models
        if "resnet" in backbone_name.lower():
            backbone.fc = nn.Linear(in_features=backbone.fc.in_features, out_features=fc_out)
            
    return backbone

In [43]:
b = get_cnn_backbone(cfg=cfg, fc_out=512)
b

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [50]:
xx = torch.randn((1, 3, 224, 224))

xx.shape

torch.Size([1, 3, 224, 224])

In [51]:
out = b(xx)

out.shape

torch.Size([1, 512])

#### JEPSAM encoder

In [86]:
class JEPSAMEncoder(nn.Module):
    def __init__(
        self, 
        cfg: Dict, 
        cnn_backbone_name:str="resnet50", 
        cnn_fc_out:int=512
    ):
        super().__init__()
        
        self.embedding = nn.Embedding(cfg.DATASET.VOCABULARY_SIZE, cfg.AEJEPS.EMBEDDING_DIM)
        
        self.image_feature_extractor = get_cnn_backbone(
            cfg=cfg,
            backbone_name=cnn_backbone_name,
            fc_out=cnn_fc_out
        )
        
        # features mixer
        encoder_input_dim = cfg.AEJEPS.EMBEDDING_DIM + 2 * self.image_feature_extractor.fc.in_features + cfg.DATASET.NUM_COMMANDS

        
        self.feature_mixing = nn.LSTM(
            input_size=encoder_input_dim, 
            hidden_size=cfg.AEJEPS.HIDDEN_DIM, 
            num_layers=cfg.AEJEPS.NUM_LAYERS_ENCODER,
            dropout=cfg.AEJEPS.ENCODER_DROPOUT, 
            bidirectional=cfg.AEJEPS.IS_BIDIRECTIONAL
        )
        
    
    def forward(self, inp:dict, mode:str='train'):
        """
        
        """
        B, max_len, *_ = inp["action_desc"].shape

        # 1. Image feature extraction
        feats_per = self.image_feature_extractor(inp["in_state"])
        feats_goal = self.image_feature_extractor(inp["goal_state"])
        ## Batch size x feat_dim -> Batch_size x (max_len x feat_dim) -> Batch_size x max_len x feat_dim
        feats_per = feats_per.repeat((1, max_len)).reshape((B, max_len, -1))
        feats_goal = feats_goal.repeat((1, max_len)).reshape((B, max_len, -1))
        
        # 2. Text feature extraction
        print(inp["action_desc"]["ids"])
        action_desc_emb = self.embedding(inp["action_desc"]["ids"])
        motor_cmd_emb = self.embedding(inp["motor_cmd"]["ids"])
        # For each batch entry determine the length of the longest of the text sequence
        lengths_max = [max(ltext, lcmd)
                       for ltext, lcmd in zip(inp["action_desc"]["length"], inp["motor_cmd"]["length"])]
        
        # 3. Feature Fusion
        concat_feats = torch.cat((feats_per, feats_goal, action_desc_emb, motor_cmd_emb), dim=2)
        
        # 4. Feature mixing
        packed_input = pack_padded_sequence(input=concat_feats, lengths=lengths_max, enforce_sorted=False)
        
        output, (hidden, carousel) = self.feature_mixing(packed_input)
        
        return output, (hidden, carousel)

In [87]:
encoder = JEPSAMEncoder(cfg=cfg)

# encoder

In [88]:
iinp = {
        "in_state": torch.randn((1, 3, 224, 224)),
        "goal_state": torch.randn((1, 3, 224, 224)),
        "action_desc": {
            torch.randn((1, 1, 64))
        },
        "motor_cmd": torch.randn((1, 1, 128))
}
encoder(iinp)

TypeError: new(): invalid data type 'str'

In [76]:
summary(
    encoder,
    [
        torch.randn((1, 3, 224, 224)),
        torch.randn((1, 3, 224, 224)),
        torch.randn((1, 1, 64)),
        torch.randn((1, 1, 128))
    ],
    
    dtypes=[
        torch.float,
        torch.float,
        torch.long,
        torch.long,
    ]
)

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

### Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, embedding_dim:):
        
        decoder_hidden_dim = self.num_directions * hidden_dim
        
        # motor command decoding layer
        self.motor_decoder = nn.LSTMCell(
            input_size=motor_dim, 
            hidden_size=decoder_hidden_dim, 
            num_layers=cfg.AEJEPS.NUM_LAYERS_MOTOR
        )
        # action desc. decoding layer
        self.lang_decoder = nn.LSTMCell(
            input_size=embedding_dim, 
            hidden_size=decoder_hidden_dim, 
            num_layers=num_layers_lang
        )
        
        # projection layers
        self.hidden_to_conv_in = nn.Linear(decoder_hidden_dim, 1024)
        self.lang_head = nn.Linear(
            in_features=decoder_hidden_dim, 
            out_features=vocabulary_size # To be discussed
        )
        self.motor_cmd_head = nn.Linear(
            in_features=decoder_hidden_dim, 
            out_features=vocabulary_size # To be discussed
        )

        self.hidden2img = self.__get_transposed_convs(
            decoder_hidden_dim, 
            image_size
        )
    
    def forward(self):
        pass
    
    def _decode_action_description(self):
        pass
    
    def _decode_image(self):
        pass
    
    def _decode_motor_command(self):
        pass
    
    def __get_transposed_convs(self, decoder_hidden_dim, image_size):
        tconv1 = nn.ConvTranspose2d(1, 4, 3, 2, 3, 0)
        tconv2 = nn.ConvTranspose2d(4, 8, 5, 2, 3, 0)
        tconv3 = nn.ConvTranspose2d(8, 16, 7, 2, 4, 1)
        tconv4 = nn.ConvTranspose2d(16, 3, 11, 1, 7, 0)

        return nn.Sequential(tconv1, tconv2, tconv3, tconv4)


### AEJEPS

In [None]:

class AutoencoderJEPS(nn.Module):
    """
    This class is an Autoencoder based deep learning implementation of a Joint Episdoic, Procedural, and Semantic Memory.

    Parameters
    ----------

    """
    def __init__(self, cfg: Dict):
        super().__init__()
        self.encoder = JEPSAMEncoder()
        self.decoder = JEPSAMDecoder()
        
    def forward(self, inp):
        enc = 

