In [1]:
!nvidia-smi

Tue Jun 27 05:54:08 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
| N/A   62C    P0    25W /  N/A |    541MiB /  6069MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import IPython
%load_ext autoreload

%autoreload 2

In [3]:
from addict import Dict

import copy

from einops import rearrange

import numpy as np

import pandas as pd

import os.path as osp

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from torchinfo import summary
import torchvision.models as torchvision_models

from typing import List, Tuple, Union

import sys

In [4]:
%cd ..

sys.path.append("src")

/home/zeusdric/jepsam/repo/AEJEPS


In [5]:
import models
from utils.parser import load_config, parse_args
import utils.model_utils as model_utils
from utils.ae_resnet import get_configs, ResNetEncoder, ResNetDecoder

from dataloader import get_dataloaders, SimpleTokenizer, JEPSAMDataset
import vocabulary as vocab

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
cfg = load_config(config_file_path="configs/aejeps_cfg.yaml")

In [7]:
cfg.AEJEPS

{'EMBEDDING_DIM': 256,
 'HIDDEN_DIM': 256,
 'CNN_FC_DIM': 256,
 'NUM_LAYERS_ENCODER': 2,
 'BATCH_FIRST': True,
 'ENCODER_DROPOUT': 0.35,
 'IS_BIDIRECTIONAL': True,
 'NUM_LAYERS_MOTOR': 1,
 'ACTIVATION_MOTOR': 'LeakyReLU',
 'MOTOR_DROPOUT': 0.25,
 'NUM_LAYERS_LANG': 2,
 'ACTIVATION_LANG': 'LeakyReLU',
 'LANG_DROPOUT': 0.25,
 'HIDDEN_TO_CONV': 1024,
 'DECODER_ACTIVATION': 'Sigmoid'}

### Get batch of data

In [8]:
cfg.DATASET.PATH

'../dataset/'

In [9]:
tdf = pd.read_csv(
    osp.join(cfg.DATASET.PATH, "updated_train.csv")
)

tdf.head()

Unnamed: 0,sample_ID,in_state,goal_state,validator,action_description,motor_cmd,len_action_desc,len_motor_cmd
0,1005,0,9,amihretu,put the :BOTTLE to the left of :BOTTLE,:BOTTLE BLUE POSE-9 :BOTTLE RED POSE-2 :BOTTLE...,8,11
1,1011,0,9,amihretu,move the :BOTTLE left,:BOTTLE BLUE POSE-3 :BOTTLE #'*leftward-trans...,4,8
2,1012,0,9,amihretu,put the :BOTTLE to the right of :MUG,:BOTTLE BLUE POSE-7 :MUG RED POSE-3 :BOTTLE #...,8,11
3,1013,0,9,amihretu,shift the :CUP backwards,:CUP RED POSE-4 :CUP #'*backward-transformati...,4,8
4,1015,0,9,amihretu,shift the :BOTTLE forwards,:BOTTLE GREEN POSE-3 :BOTTLE #'*forward-trans...,4,8


In [10]:
train_dl, _ = get_dataloaders(
    train_df=tdf,
    cfg=cfg,
    # dataset_module=JEPSAMDataset
)


INFO:root:Prepared 1367 training samples and 605 validation samples 


In [11]:
for data in train_dl:
#     s_id, in_state, goal_state, ad, cmd = data['sample_id'], data[
#         'in_state'], data['goal_state'], data['action_desc'], data["motor_cmd"]
#     print("In\t\t:", in_state.shape)
#     print("Goal\t\t:", goal_state.shape)
#     print("Action desc\t:", ad["ids"].shape)
#     print("Action desc (len)\t:", ad["length"].shape)

#     print("CMD\t\t:", cmd["ids"].shape)
#     print("CMD(len)\t\t:", cmd["length"].shape)
    pass
    break

In [12]:
in_state, goal_state, ad, cmd, ad_lens, cmd_lens = data

In [13]:
in_state.shape, goal_state.shape, ad.shape, cmd.shape, ad_lens.shape, cmd_lens.shape

(torch.Size([2, 3, 224, 224]),
 torch.Size([2, 3, 224, 224]),
 torch.Size([2, 1, 7]),
 torch.Size([2, 1, 11]),
 torch.Size([2]),
 torch.Size([2]))

In [14]:
cmd

tensor([[[48,  5, 28, 42, 14, 27, 39,  5,  0, 14, 46]],

        [[48,  5, 28, 38,  5,  2,  5, 46, 47, 47, 47]]])

### Tokenizer

In [15]:
# from transformers import PreTrainedTokenizerFast

# tt = PreTrainedTokenizerFast(
#     tokenizer_file=cfg.DATASET.TOKENIZER_PATH, # You can load from the tokenizer file, alternatively
#     unk_token="[UNK]",
#     pad_token="[PAD]",
#     cls_token="[CLS]",
#     sep_token="[SEP]",
#     mask_token="[MASK]",
# )


tt = SimpleTokenizer(vocab=vocab)


In [16]:
# _, in_state, goal_state, ad, cmd = data['sample_id'], data['in_state'], data['goal_state'], data['action_desc'], data["motor_cmd"]

# ad["ids"]

## Model refectoring

### Encoder

#### CNN backbone builder

In [17]:
# b = model_utils.get_cnn_backbone(cfg=cfg, fc_out=512)
# b

In [18]:
# xx = torch.randn((1, 3, 224, 224))

# xx.shape

In [19]:
# 


#### JEPSAM encoder

In [20]:
class JEPSAMEncoder(nn.Module):
    def __init__(
        self, 
        cfg: Dict, 
        cnn_backbone_name:str="resnet50"
    ):
        super().__init__()
        
        self.cfg = cfg
        
        self.device = self.cfg.TRAIN.GPU_DEVICE if torch.cuda.is_available() else "cpu"
        
        # embedding layer
        self.embedding = nn.Embedding(cfg.DATASET.VOCABULARY_SIZE, cfg.AEJEPS.EMBEDDING_DIM)        
        
        # CNN ftr extractor
        configs, bottleneck = get_configs(cnn_backbone_name)

        self.image_feature_extractor = nn.Sequential(
            ResNetEncoder(configs, bottleneck),
            nn.Flatten(),
            nn.Linear(in_features = 2048*7*7, out_features=cfg.AEJEPS.CNN_FC_DIM)
        )
        
        # features mixer
        self.feature_mixing = nn.LSTM(
            input_size=cfg.AEJEPS.EMBEDDING_DIM, 
            hidden_size=cfg.AEJEPS.HIDDEN_DIM, 
            num_layers=cfg.AEJEPS.NUM_LAYERS_ENCODER,
            dropout=cfg.AEJEPS.ENCODER_DROPOUT, 
            bidirectional=cfg.AEJEPS.IS_BIDIRECTIONAL
        )
                
        self.to(self.device)
        
        
    def forward(self, inp:Union[List, list, dict], mode:str='train'):
        """
        
        """
        if isinstance(inp, list) or isinstance(inp, List):
            in_state, goal_state, ad, cmd, ad_lens, cmd_lens = inp
        else:
            _, in_state, goal_state, ad, cmd = inp['sample_id'], inp['in_state'], inp['goal_state'], inp['action_desc'], inp["motor_cmd"]
            ad = ad["ids"]
            # print(ad["length"])
            ad_lens = ad["length"]
            
            cmd = cmd["ids"]
            cmd_lens = cmd["length"]
        
        in_state, goal_state, ad, cmd, ad_lens, cmd_lens = in_state.to(self.device), goal_state.to(self.device), ad.to(self.device), cmd.to(self.device), ad_lens.to(self.device), cmd_lens.to(self.device)
        # print(in_state.device)
        B, _, max_len = ad.shape

        # 1. Image feature extraction
        feats_per = self.image_feature_extractor(in_state)
        # feats_per = self.img_projection(feats_per.view(B, -1))
        # print(feats_per.shape)
        feats_per = feats_per.repeat((1, max_len)).reshape((B, max_len, -1))
        
        if mode =="train":
            feats_goal = self.image_feature_extractor(goal_state)
            # feats_goal = self.img_projection(feats_goal.view(B, -1))
            # print(feats_goal.shape)
            
            feats_goal = feats_goal.repeat((1, max_len)).reshape((B, max_len, -1))
        else:
            pass
        
        # print(f"feats_per: {feats_per.shape}")
        # print(f"feats_goal: {feats_goal.shape}")
        
        # 2. Text feature extraction
        action_desc_emb = self.embedding(ad)#.squeeze(1)
        
        if mode =="train":
            motor_cmd_emb = self.embedding(cmd)#.squeeze(1)
            # For each batch entry determine the length of the longest of the text sequence
            lengths_max = [max(ltext, lcmd)
                           for ltext, lcmd in zip(ad_lens, cmd_lens)]        
            # print(lengths_max)
        else:
            lengths_max = [ltext for ltext in ad_lens]     
        # 3. Feature Fusion
        # Optional: add a projection layer that will 
        # print(feats_per.shape, feats_goal.shape, action_desc_emb.shape, motor_cmd_emb.shape)
        
        # print( 
        #     feats_per.shape, 
        #     feats_goal.shape, 
        #     action_desc_emb.squeeze(1).shape, 
        #     motor_cmd_emb.squeeze(1).shape
        #      )
        concat_feats = torch.cat((
            feats_per, 
            feats_goal, 
            action_desc_emb.squeeze(1), 
            motor_cmd_emb.squeeze(1)
        ), dim=-2)#.squeeze(1)
        
        # print(f"Fused feats: {concat_feats.shape}")
        
        # 4. Feature mixing
        # packed_input = concat_feats
        packed_input = pack_padded_sequence(
            input=concat_feats, 
            lengths=lengths_max, 
            enforce_sorted=False, 
            batch_first=True
        )
        
        output, (hidden, carousel) = self.feature_mixing(packed_input)
        # print(output.shape)
        output, len_output = pad_packed_sequence(output, batch_first= True)
        
        return output, len_output, hidden, carousel

### Encoder summary

In [21]:
encoder = JEPSAMEncoder(
    cnn_backbone_name="resnet50",
    cfg=cfg
)

o, lo, h, c = encoder(data)

print("enc out: ", o.shape)

# # model summary
# summary(
#     model=encoder,
#     input_dict=data,
#     col_names=["kernel_size", "num_params"],

# )

enc out:  torch.Size([2, 11, 512])


### Decoder

In [22]:
class JEPSAMDecoder(nn.Module):
    def __init__(
        self, 
        cfg: Dict,
        cnn_backbone_name:str="resnet50",
        
    ):
        super().__init__()
        
        # class attributes
        self.cfg = cfg
        self.cnn_backbone_name = cnn_backbone_name
        self.num_directions = 2 if cfg.AEJEPS.IS_BIDIRECTIONAL else 1
        decoder_hidden_dim = self.num_directions * cfg.AEJEPS.HIDDEN_DIM
        
        ## Layers
        # tokenizer 
        # self.tokenizer  = PreTrainedTokenizerFast(
        #     tokenizer_file= self.cfg.DATASET.TOKENIZER_PATH, # You can load from the tokenizer file, alternatively
        #     unk_token="[UNK]",
        #     pad_token="[PAD]",
        #     cls_token="[CLS]",
        #     sep_token="[SEP]",
        #     mask_token="[MASK]",
        # )
        
        self.tokenizer = SimpleTokenizer(vocab)

        # embedding layer - same as encoder embedding layer
        self.embedding = nn.Embedding(
            cfg.DATASET.VOCABULARY_SIZE, 
            cfg.AEJEPS.EMBEDDING_DIM,
            device=self.cfg.TRAIN.GPU_DEVICE
        )
        
        # image decoder
        configs, bottleneck = get_configs(cnn_backbone_name)
        self.img_projection = nn.Sequential(
            nn.Linear(in_features=decoder_hidden_dim, out_features=cfg.AEJEPS.CNN_FC_DIM ),
            nn.LeakyReLU(),
            nn.Linear(in_features=cfg.AEJEPS.CNN_FC_DIM , out_features=2048*7*7),
        ).to(self.cfg.TRAIN.GPU_DEVICE)
        
        self.img_decoder = ResNetDecoder(configs[::-1], bottleneck).to(self.cfg.TRAIN.GPU_DEVICE)
        
        # motor command decoding layer
        self.motor_decoder = nn.LSTMCell(
            input_size=cfg.AEJEPS.EMBEDDING_DIM, 
            hidden_size=decoder_hidden_dim 
        ).to(self.cfg.TRAIN.GPU_DEVICE)
        
        # action desc. decoding layer
        self.lang_decoder = nn.LSTMCell(
            input_size=cfg.AEJEPS.EMBEDDING_DIM, 
            hidden_size=decoder_hidden_dim
        ).to(self.cfg.TRAIN.GPU_DEVICE)
        
        # projection layers
        # self.hidden_to_conv_in = nn.Linear(
        #     in_features=decoder_hidden_dim, 
        #     out_features=self.cfg.AEJEPS.HIDDEN_TO_CONV
        # )
            
        self.lang_head = nn.Linear(
            in_features=decoder_hidden_dim, 
            out_features=cfg.DATASET.VOCABULARY_SIZE 
            # To be discussed use the individual vocabs or the merged one for the projection
        ).to(self.cfg.TRAIN.GPU_DEVICE)
        self.motor_cmd_head = nn.Linear(
            in_features=decoder_hidden_dim, 
            out_features=cfg.DATASET.VOCABULARY_SIZE 
            # To be discussed use the individual vocabs or the merged one for the projection
        ).to(self.cfg.TRAIN.GPU_DEVICE)
        
        self.device = self.cfg.TRAIN.GPU_DEVICE if torch.cuda.is_available() else "cpu"
        
    def forward(
        self,
        enc_output, 
        len_enc_output, 
        hidden, 
        carousel,
        mode:str="train"
    ):
            
        batch_size, max_len, num_ftrs = enc_output.shape
        
        # hidden
        # hidden = hidden.view(self.num_directions, self.num_layers, batch_size, -1)
        # hidden = hidden[:self.num_directions, self.num_layers - 1, :, :]  # Take the last forward direction hidden state for
        
        hidden, carousel = self._rearrange_states(hidden, carousel)

        cmd_h_t, lang_h_t = (hidden, carousel), (hidden, carousel)
            
        # Unsqueeze to match expected input by transposed convolutions
        self.hidden = hidden.unsqueeze(0)
        
        # run decoding steps
        # generate action desc from latent representation
        lang_out = self._decode_action_description(hidden=lang_h_t, batch_size=batch_size, max_len=max_len)
        # generate motor cmd from latent representation
        motor_out = self._decode_motor_command(hidden=cmd_h_t, batch_size=batch_size, max_len=max_len)
        # reconstruct from latent representation
        per_image_rec = self._reconstruct_image(hidden)
        # generate from latent representation
        goal_image = self._generate_goal_image(hidden)
        
        return per_image_rec, goal_image, lang_out, motor_out
    
    
    def _rearrange_states(self, hidden, carousel):
        """
        
        """
        # hidden
        hidden = rearrange(
            hidden, 
            '(d l) b h -> l b (d h)',
            d=self.num_directions, 
            l=self.cfg.AEJEPS.NUM_LAYERS_ENCODER
        )
        hidden = hidden[self.cfg.AEJEPS.NUM_LAYERS_ENCODER - 1, :, :]
        
        # carousel
        carousel = rearrange(
            carousel, 
            '(d l) b h -> l b (d h)',
             d=self.num_directions, 
            l=self.cfg.AEJEPS.NUM_LAYERS_ENCODER
        )
        carousel = carousel[self.cfg.AEJEPS.NUM_LAYERS_ENCODER - 1, :, :]
        
        return hidden, carousel
    
    
    def _decode_action_description(
        self, 
        hidden, 
        batch_size:int, 
        max_len:int
    ):
        """
        """
        # print(next(self.lang_decoder.parameters()).is_cuda)
        
        lang_out = []        
        # Initialize the predictions with [SOS]
        prediction_txt_t = torch.ones(batch_size, 1).to(self.device).long() * self.cfg.DATASET.SOS
        # print(prediction_txt_t.device)
        for t in range(max_len):
            char = self.embedding(prediction_txt_t).squeeze(1)
            # hidden state at time step t for each RNN
            hidden, lang_c_t = self.lang_decoder(char, hidden)
            # project hidden state to vocab
            lang_scores = self.lang_head(hidden)
            # update hidden states
            hidden = (hidden, lang_c_t)
            # store newly generated token
            lang_out.append(lang_scores.unsqueeze(1))
            # draw new token: greedy decoding
            prediction_txt_t = lang_scores.argmax(dim=1)
        
        return torch.cat(lang_out, 1)
            
    def _decode_motor_command(
        self, 
        hidden, 
        batch_size:int, 
        max_len:int,
        method:str="embed"
    ):
        """
        Parameters:
        ----------
        
            method: str
                The method to use for token 
        """
        
        motor_out = []
        # Initialize the predictions with [SOS]
        prediction_cmd_t = torch.ones(batch_size, 1).to(self.device).long() * self.cfg.DATASET.SOS
            
        for t in range(max_len):
            if method=="one-hot":
                command = one_hot(
                    prediction_cmd_t.long(),
                    num_classes=num_commands
                ).squeeze(1).float()
            else:
                command = self.embedding(prediction_cmd_t).squeeze(1)
            
            # hidden state at time step t for each RNN
            hidden, cmd_c_t = self.motor_decoder(command, hidden)
            # project hidden state to vocab
            cmd_scores = self.motor_cmd_head(hidden)
            # update hidden states
            hidden = (hidden, cmd_c_t)
            # store newly generated token
            motor_out.append(cmd_scores.unsqueeze(1))
            # draw new token: greedy decoding
            prediction_cmd_t = cmd_scores.argmax(dim=1)
            
        return torch.cat(motor_out, 1)

    def _reconstruct_image(self, hidden):
        """
        """
        conv_in = self.img_projection(hidden)
        # print(conv_in.shape)
        B, _ = conv_in.shape
        
        return self.img_decoder(conv_in.view(B, 2048, 7, 7))
    
    def _generate_goal_image(self, hidden):
        """
        """
        return self._reconstruct_image(hidden)# to be fixed

    
    def pred_to_str(
        self, 
        predictions:torch.Tensor
    )->list:
        """
            Decode predictions (from ids to token)
            
            Parameters:
            ----------
                - predictions: Tensor
                    batch predictions from decoder module
        """
        return self.tokenizer.batch_decode(predictions.argmax(dim=-1))

In [23]:
decoder = JEPSAMDecoder(cfg=cfg)

decoder

JEPSAMDecoder(
  (embedding): Embedding(65, 256)
  (img_projection): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=256, out_features=100352, bias=True)
  )
  (img_decoder): ResNetDecoder(
    (conv1): DecoderBottleneckBlock(
      (00 EncoderLayer): DecoderBottleneckLayer(
        (weight_layer1): Sequential(
          (0): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (weight_layer2): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (weight_layer3): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine

In [24]:
next(decoder.embedding.parameters()).is_cuda, decoder.device

(True, 'cuda:0')

#### Decoding language modalities

In [25]:
hh, cc = decoder._rearrange_states(h, c)
hh.shape, cc.shape

(torch.Size([2, 512]), torch.Size([2, 512]))

In [26]:
hh.device

device(type='cuda', index=0)

In [27]:
B, ML, NFTRS = o.shape

ad_out = decoder._decode_action_description(
    batch_size=B, 
    max_len=ML, 
    hidden=(hh, cc)
)

In [28]:
cmd_out = decoder._decode_motor_command(
    batch_size=B, 
    max_len=ML, 
    hidden=(hh, cc)
)

In [29]:
cmd_out.shape, ad_out.shape

(torch.Size([2, 11, 65]), torch.Size([2, 11, 65]))

In [30]:
cmd_out.argmax(dim=-1)[0]

tensor([ 7, 33, 62, 17, 59, 54, 17, 11, 42, 55, 26], device='cuda:0')

In [31]:
decoded_ad = decoder.pred_to_str(ad_out)
decoded_cmd = decoder.pred_to_str(cmd_out)

In [32]:
decoded_ad[0]

['top',
 'top',
 'top',
 'top',
 'top',
 'top',
 'left',
 '[PAD]',
 'top',
 'top',
 'top']

In [33]:
decoded_cmd[0]

[':BREAKFAST-CEREAL',
 'POSE-12',
 'the',
 ':MILK',
 'put',
 'in',
 ':MILK',
 ':CUBE',
 'POSE-7',
 'left',
 ':WEISSWURST']

#### Decoding visual modalities

In [34]:
# img_ftrs = encoder.image_feature_extractor(data[0])
# img_ftrs.shape

In [35]:
# rec_img = decoder._reconstruct_image(img_ftrs.cuda())

# rec_img.shape

#### Decoder I/O 

In [36]:
per_image_rec, goal_image, lang_out, motor_out = decoder(
    enc_output=o, 
    len_enc_output=lo, 
    hidden=h, 
    carousel=c
)

In [37]:
per_image_rec.shape, goal_image.shape, lang_out.shape, motor_out.shape

(torch.Size([2, 3, 224, 224]),
 torch.Size([2, 3, 224, 224]),
 torch.Size([2, 11, 65]),
 torch.Size([2, 11, 65]))

### AE: JEPSAM
    - Enc: JEPSAMEncoder
    - Dec: JEPSAMDecoder

In [38]:
class JEPSAM(nn.Module):
    """
    This class is an Autoencoder based deep learning implementation of a Joint Episdoic, Procedural, and Semantic Memory.

    Parameters
    ----------

    """
    def __init__(self, cfg: Dict):
        super().__init__()
        self.cfg = cfg

        # encoder
        self.encoder = JEPSAMEncoder(cfg=self.cfg)
        # decoder
        self.decoder = JEPSAMDecoder(cfg=self.cfg)
        # weight tying
        self.decoder.embedding.weight = self.encoder.embedding.weight
        
        self.device = self.cfg.TRAIN.GPU_DEVICE if torch.cuda.is_available() else "cpu"

        
    def forward(
        self, 
        inp, 
        mode:str="train"
    ):
        """
        """
        
        # encode
        o, lo, h, c  = self.encoder(inp)
        
        # decode
        reconstructed_image, goal_image, decoded_action_desc, decoded_cmd = self.decoder(
            enc_output=o, 
            len_enc_output=lo, 
            hidden=h, 
            carousel=c
        )

        return reconstructed_image, goal_image, decoded_action_desc, decoded_cmd



In [39]:
jepsam = JEPSAM(cfg=cfg).cuda()

jepsam

JEPSAM(
  (encoder): JEPSAMEncoder(
    (embedding): Embedding(65, 256)
    (image_feature_extractor): Sequential(
      (0): ResNetEncoder(
        (conv1): Sequential(
          (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (conv2): EncoderBottleneckBlock(
          (00 MaxPooling): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          (01 EncoderLayer): EncoderBottleneckLayer(
            (weight_layer1): Sequential(
              (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (weight_layer2): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

#### JEPSAM I/O

In [40]:
per_image_rec, goal_image, lang_out, motor_out = jepsam(data)

In [41]:
per_image_rec.shape, goal_image.shape, lang_out.shape, motor_out.shape

(torch.Size([2, 3, 224, 224]),
 torch.Size([2, 3, 224, 224]),
 torch.Size([2, 11, 65]),
 torch.Size([2, 11, 65]))

In [43]:
motor_out.argmax(dim=-1)

tensor([[51, 56, 34, 42, 60, 33, 38,  4,  4, 30, 35],
        [48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48]], device='cuda:0')

In [50]:
nn.functional.mse_loss(motor_out.float().argmax(dim=-1), cmd.squeeze(1).float().cuda())

tensor(607.4091, device='cuda:0')

In [49]:
cmd.shape, motor_out.argmax(dim=-1).shape

(torch.Size([2, 1, 11]), torch.Size([2, 11]))