# Neural Network Modeling with MULTFS

#### Overview:
This notebook aims to provided an example of training a PyTorch Model on a premade task dataset
- Text-embedder: ```all-mpnet-base-v2 (pretrained Sentence Transformer)```
- Image-embedder: ```vit_b_16 (pretrained Vision Transformer)``` 
- Decoder/Classifier: Transformer Decoder only Model trained on both text and image embeddings to out put action class


#### Datasets/Training:
- In this notebook we use pre-saved and generated datasets
    - *See the ..._dataset_gen.ipynb notebooks for how to generate and save a dataset*




### Imports

In [1]:
import sys
sys.path.append('../')

import numpy as np
import os
from PIL import Image
import json
import math
import copy
from sklearn.model_selection import train_test_split

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm

from cognitive.task_bank import CompareLocTemporal
from cognitive import task_generator as tg
from cognitive import constants as const
from cognitive import stim_generator as sg
from cognitive import info_generator as ig
import random

2023-10-30 19:17:45.162292: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-30 19:17:45.162447: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-30 19:17:45.162531: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-30 19:17:45.168451: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
TRAIN_DIR = '../datasets/train_big'  # Training Dataset Directory
VAL_DIR = '../datasets/val_big'  # Validation Dataset Directory
TEST_DIR = '../datasets/val_big'  # Testing Dataset Directory
LM_PATH = 'offline_models/all-mpnet-base-v2'
IMGM_PATH = 'offline_models/resnet/resnet'
EMB_DIR = '../datasets/embeddings'
BATCH_SIZE = 256
MAX_FRAMES = 4  # the max possible frames across tasks
IMGM_OUT_DIM = 2048 # vision transformer output dimension
LM_OUT_DIM = 768 # language model output dimension

# Check if a GPU is available
if torch.cuda.is_available():
    # Request GPU device 0
    device = torch.device("cuda:0")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    # If no GPU is available, fall back to CPU
    device = torch.device("cpu")
    print("No GPU available, using CPU.")

Using GPU: NVIDIA RTX A5000


# Datasets
- Read in the pregenerated task trials organized into frames, instructions, and correct actions. 

In [3]:
# transform = transforms.Compose([
#                             transforms.Resize(224),
#                             transforms.CenterCrop(224),
#                             transforms.ToTensor(),
#                             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#                         ])

# def read_trials(path):
#     frames = []
#     infos = []

#     for trial_fp in os.listdir(path):
#         if 'trial' not in trial_fp:
#             continue

#         trial_fp = os.path.join(path, trial_fp)
#         imgs = []
#         info = None
        
#         for fp in os.listdir(trial_fp):
#             fp = os.path.join(trial_fp, fp)
            
#             if fp[-4:] == '.png':
#                 img = Image.open(fp)
#                 img = transform(img)
#                 imgs.append(img)
#             elif 'trial_info' in fp:
#                 info = json.load(open(fp))
#                 infos.append(info)
                
#         if len(imgs) > MAX_FRAMES:
#             raise Exception(trial_fp + " contains more frames than the set maximum (MAX_FRAMES) !!!")
#         elif len(imgs) != len(info['answers']):
#             raise Exception(trial_fp + " numbers of frames does not match number of actions")
            
#         frames.append(np.array(imgs))

#     return frames, infos

# train_frames, train_infos = read_trials(TRAIN_DIR)
# val_frames, val_infos = read_trials(VAL_DIR)
# # test_frames, test_infos = read_trials(TEST_DIR)

In [4]:
# train_ins = [x['instruction'] for x in train_infos]
# train_raw_targets = [x['answers'] for x in train_infos]

# val_ins = [x['instruction'] for x in val_infos]
# val_raw_targets = [x['answers'] for x in val_infos]

# test_ins = [x['instruction'] for x in test_infos]
# test_raw_targets = [x['answers'] for x in test_infos]

### Encode Target Actions
- Encodes the target actions into one hot encoding vectors corresponding to the actions ```true, false, and null```

In [5]:
# action_map = {'true': 0, 'false': 1, 'null': 2}



# def map_actions(amap, raw_actions):
#     count = {'true': 0, 'false': 0, 'null': 0}
#     target_actions = []

#     for actions in raw_actions:
#         encoded = []
#         for action in actions:
#             count[action] += 1
#             encoded.append(amap[action])
#         target_actions.append(encoded)
    
#     return target_actions, count

# train_targets, train_targets_count = map_actions(action_map, train_raw_targets)
# val_targets, val_targets_count = map_actions(action_map, val_raw_targets)
# # test_targets, test_targets_count = map_actions(action_map, test_raw_targets)

In [6]:
# train_targets_count
# val_targets_count

## Instruction Dataset

In [7]:
# class InstructionsDataset(Dataset):
#   """
#     Pytorch Dataset class to load the Instructions Data

#     Data members:
#       instructions: list of instructions
#       n_ins: number of instructions in the dataset

#     Member functions:
#       __init__: ctor
#       __len__: returns n_ins
#       __getitem__: returns an instruction
#   """

#   def __init__(self, x):

#     self.instructions = x

#     self.n_ins = len(self.instructions)

#     return

#   def __len__(self):
#     """
#       Returns number of instructions in the Dataset
#     """

#     return self.n_ins

#   def __getitem__(self, idx):
#     """
#       Given an index return a instruction at that index
#     """

#     return self.instructions[idx]

In [8]:
# class InstructionsCollator(object):
#   """
#     Data Collator used for GPT2 in a classificaiton tasks

#     Args:
#       use_tokenizer :
#         Transformer type tokenizer used to process raw text into numbers.

#     Data members:
#       use_tokenizer: Tokenizer to be used inside the class.

#     Member functions:
#       __init__: ctor
#       __call__: tokenize input

#     """

#   def __init__(self, use_tokenizer):

#     self.use_tokenizer = use_tokenizer

#     return

#   def __call__(self, instructions):
#     """
#         Tokenizes input
#     """

#     # Call tokenizer
#     inputs = self.use_tokenizer(instructions, padding=True, truncation=True, return_tensors='pt').to(device)

#     return inputs


In [9]:
# Pretrained Language Model and Tokenizer 
# lm_encoder = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
# tokenizer = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Uncomment for offline load of lang embedder
# lm_encoder = AutoModel.from_pretrained('offline_models/all-mpnet-base-v2').to(device).eval()
# tokenizer = AutoTokenizer.from_pretrained('offline_models/all-mpnet-base-v2')

# # Create data collator to encode text and labels into numbers.
# InstructionsCollator = InstructionsCollator(use_tokenizer=tokenizer)

# # Create pytorch datasets for instructions
# ins_train_dataset = InstructionsDataset(train_ins)
# ins_val_dataset = InstructionsDataset(val_ins)
# # ins_test_dataset = InstructionsDataset(test_ins)

# # Move pytorch datasets into dataloaders
# ins_train_dataloader = DataLoader(ins_train_dataset, batch_size=1, shuffle=False, collate_fn=InstructionsCollator)
# ins_val_dataloader = DataLoader(ins_val_dataset, batch_size=1, shuffle=False, collate_fn=InstructionsCollator)
# ins_test_dataloader = DataLoader(ins_test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=InstructionsCollator)


## Frames Dataset

In [10]:
# class FramesDataset(Dataset):
#   """
#     Pytorch Dataset class to load the Frame Data

#     Data members:
#       frames``ist of frames
#       n_imgs: number of iamges in the dataset

#     Member functions:
#       __init__: ctor
#       __len__: returns n_imgs
#       __getitem__: returns an frame
#   """

#   def __init__(self, x):

#     self.frames = x

#     self.n_imgs = len(self.frames)

#     return

#   def __len__(self):
#     """
#       Returns number of frames in the Dataset
#     """

#     return self.n_imgs

#   def __getitem__(self, idx):
#     """
#       Given an index return a frame
#     """

#     return torch.tensor(self.frames[idx]).to(device)

In [11]:
# Pretrained Vision Transformer
# img_encoder = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT).to(device)
# img_encoder = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT).to(device)

# Uncommnent for offline load of vit
# img_encoder = torch.load(IMGM_PATH,map_location=device).to(device).eval()

# # Create pytorch datasets for instructions
# frames_train_dataset = FramesDataset(train_frames)
# frames_val_dataset = FramesDataset(val_frames)
# # frames_test_dataset = FramesDataset(test_frames)

# # Move pytorch datasets into dataloaders
# frames_train_dataloader = DataLoader(frames_train_dataset, batch_size=1, shuffle=False)
# frames_val_dataloader = DataLoader(frames_val_dataset, batch_size=1, shuffle=False)
# frames_test_dataloader = DataLoader(frames_test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [12]:
# from torchsummary import summary

# list(img_encoder.named_children())[-2]
# summary(torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT), input_size=(3, 224, 224))

# Language Encoder

### Language Embedder

In [13]:
def lm_embedder(instruction, encoder, tokenizer):
    instruction = tokenizer(instruction, padding=True, truncation=True, return_tensors='pt').to(device)
    
    #Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # Compute token embeddings
    with torch.no_grad():
        lm_output = encoder(**instruction)

    # Perform pooling
    sentence_embeddings = mean_pooling(lm_output, instruction['attention_mask'])
    
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        
    return sentence_embeddings

# Image Encoder

### Position Embeddings

### Image Embedder
- We pad frames based on max possible

In [14]:
def img_embedder(frames, encoder):

    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    encoder.avgpool.register_forward_hook(get_activation('layer'))

    with torch.no_grad():
        out = encoder(torch.tensor(frames))
        out = torch.squeeze(activation['layer'], (2,3)) #torch.flatten(, start_dim=1, end_dim=2)

    npads = MAX_FRAMES-len(out)
    pad = torch.ones((npads, out.shape[1]))
    out = torch.cat((out, pad.to(device)))

    return out

# Create Embeddings

In [15]:
class TaskDataset(Dataset):
    # todo: this is not the most efficient way to access data, since each time it has to read from the directory 
    def __init__(self, root_dir):
        self.root_dir = root_dir
        # preprocessing steps for pretrained ResNet models
        self.transform = transforms.Compose([
                            transforms.Resize(224),
                            transforms.CenterCrop(224), # todo: to delete for shapenet task; why?
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                        ])
    

        # check the size of the dataset
        self.dataset_size = 0
        items = os.listdir(self.root_dir)
        for item in items:
            item_path = os.path.join(self.root_dir, item)
            # Check if the item is a directory
            if os.path.isdir(item_path):
                self.dataset_size += 1

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        trial_path = os.path.join(self.root_dir, "trial%d"%idx)
        images = []
        
        for fp in os.listdir(trial_path):
            fp = os.path.join(trial_path, fp)

            if fp[-4:] == '.png':
                img = Image.open(fp)
                img = self.transform(img)
                images.append(img)
            elif 'trial_info' in fp:
                info = json.load(open(fp))
                
                actions = self._action_map(info["answers"])

                npads = MAX_FRAMES - len(actions)
                actions.extend([-1 for _ in range(0,npads)])
                
                instructions = info['instruction']
        
        images = np.stack(images) # (2*3*224*224)

        return torch.tensor(images).to(device), instructions, torch.tensor(actions).cpu()
    
    def _action_map(self, actions):
        action_map = {'true': 0, 'false': 1, 'null': 2}
        updated_actions = []
        for action in actions:
            updated_actions.append(action_map[action])

        return updated_actions


In [16]:
lm_encoder = AutoModel.from_pretrained('offline_models/all-mpnet-base-v2').to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('offline_models/all-mpnet-base-v2')

img_encoder = torch.load(IMGM_PATH,map_location=device).to(device).eval()

train_TD = TaskDataset(TRAIN_DIR)
val_TD = TaskDataset(VAL_DIR)

train_DL = DataLoader(train_TD, batch_size=1, shuffle=True)
val_DL = DataLoader(val_TD, batch_size=1, shuffle=False)


In [17]:
embDir = os.listdir('../datasets/embeddings/')
if '.ipynb_checkpoints' in embDir:
    embDir.remove('.ipynb_checkpoints')
if '.DS_Store' in embDir:
    embDir.remove('.DS_Store')

if len(embDir) != 6:
    train_lm_embeddings = []
    train_img_embeddings = []
    train_targets = []

    val_lm_embeddings = []
    val_img_embeddings = []
    val_targets = []


    for frames,instruction,actions in train_DL:
        frames = frames[0]

        train_lm_embeddings.append(lm_embedder(instruction, lm_encoder, tokenizer).cpu())
        train_img_embeddings.append(img_embedder(frames, img_encoder).cpu())
        train_targets.append(actions[0])
        
        
    for frames,instruction,actions in val_DL:
        frames = frames[0]

        val_lm_embeddings.append(lm_embedder(instruction, lm_encoder, tokenizer).cpu())
        val_img_embeddings.append(img_embedder(frames, img_encoder).cpu())
        val_targets.append(actions[0])


    torch.save(train_lm_embeddings, EMB_DIR + '/train_lm_embeddings')
    torch.save(train_img_embeddings, EMB_DIR + '/train_img_embeddings')
    np.save(EMB_DIR + '/train_targets', np.array(train_targets, dtype=object))
    
    torch.save(val_lm_embeddings, EMB_DIR + '/val_lm_embeddings')
    torch.save(val_img_embeddings, EMB_DIR + '/val_img_embeddings')
    np.save(EMB_DIR + '/val_targets', np.array(val_targets, dtype=object))
    
else:
    train_lm_embeddings = torch.load(EMB_DIR + '/train_lm_embeddings',map_location=device)
    train_img_embeddings = torch.load(EMB_DIR + '/train_img_embeddings',map_location=device)
    train_targets = np.load(EMB_DIR + '/train_targets.npy', allow_pickle=True).tolist()
    
    val_lm_embeddings = torch.load(EMB_DIR + '/val_lm_embeddings',map_location=device)
    val_img_embeddings = torch.load(EMB_DIR + '/val_img_embeddings',map_location=device)
    val_targets = np.load(EMB_DIR + '/val_targets.npy', allow_pickle=True).tolist()


# Embeddings Dataset

In [18]:
class EmbeddingsDataset(Dataset):
  """
    Pytorch Dataset class to load the embedded data

    Data members:
      lm_embeddings: list of language model embeddings
      img_embeddings: list of language model embeddings
      n_embs: number of embeddings in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, lm_embeddings, img_embeddings, actions):

    self.lm_embeddings = lm_embeddings
    self.img_embeddings = img_embeddings
    self.actions = actions

    self.n_embs = len(self.lm_embeddings)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_embs

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    return {'instructions':self.lm_embeddings[idx].to(device), 'frames':self.img_embeddings[idx].to(device), 'actions':torch.tensor(self.actions[idx], dtype=torch.float32, device=device)}

In [19]:
# Create pytorch dataset for train, val, test data
train_dataset = EmbeddingsDataset(train_lm_embeddings, train_img_embeddings, train_targets)
val_dataset = EmbeddingsDataset(val_lm_embeddings, val_img_embeddings, val_targets)
# test_dataset  = EmbeddingsDataset(test_lm_embeddings, test_img_embeddings, test_targets)

# Move pytorch datasets into dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [20]:
len(train_dataloader)

79

# Action Decoder

In [21]:
from matplotlib import projections
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

class CausalMatchTransformer(nn.Module):
    """
    Pytorch based transformer decoder model
    """

    # Initialize Model with Params
    def __init__(self, nframes=MAX_FRAMES, blocks=3, nhead=5, emb_dim=IMGM_OUT_DIM, t_ffl_dim=512, classes=3, device=device):
        super().__init__()

        # Device
        self.device = device

        # Embedding Dimension
        self.emb_dim = emb_dim

        # Number of frames
        self.nframes = nframes

        # Frame Position Embedder Layer
        # self.pos_emb = nn.Parameter(torch.Tensor(nframes,emb_dim)).to(device)
        # torch.nn.init.xavier_uniform_(
        #    self.pos_emb,
        #    gain=torch.nn.init.calculate_gain("linear"))
        self.pos_emb = PositionalEncoding(emb_dim, 0)

        # Instruction Dim Projection Layer
        self.lm_linear_layer = nn.Linear(LM_OUT_DIM, emb_dim)
        
        # Image Dim Projection Layer
        self.img_linear_layer = nn.Linear(IMGM_OUT_DIM, emb_dim)

        # Decoder Layers
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, batch_first=True)
        
        # Decoder
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=blocks)

        # Action classifier
        self.classifier = nn.Linear(emb_dim, classes)

    # Function for forward pass
    def forward(self, instruction, frames, mask, padding_mask):

        # Project instruction embedding
        instruction = self.lm_linear_layer(instruction)
        
        frames = self.img_linear_layer(frames)

        # Add the frame position embedding
        # for i in range(len(frames)):
        #     frames[i] += self.pos_emb[i,:]
        frames = self.pos_emb(frames)
        
        output = self.decoder(frames, instruction, tgt_mask=mask, tgt_key_padding_mask=padding_mask)


        # Pass through linear layer for classification
        output = self.classifier(output)

        return output
 
# Creates a list of torch duplicate torch modules
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

# Creates a square Sequential/Causal mask of size sz
def generate_causal_mask(sz: int) -> Tensor:
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

# Generates a padding masks for each sequence in a batch
def generate_pad_mask(batch):

    pad_tensor = torch.ones((batch.shape[2])).to(device)

    mask = np.zeros((batch.shape[0],batch.shape[1]))

    for s in range(0, batch.shape[0]):
        for v in range(0, batch[s].shape[0]):
            new_s = torch.all(batch[s][v] == pad_tensor)
            mask[s][v] = new_s

    return torch.tensor(mask).bool().to(device)

class PositionalEncoding(nn.Module):
    # Positional encoding module taken from PyTorch Tutorial
    # Link: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [51]:
torch.cuda.empty_cache()
model = CausalMatchTransformer(nframes=MAX_FRAMES,
                               blocks=1,
                               nhead=16,
                               emb_dim=256,
                               t_ffl_dim = 2048, # transformer's feedforward layers dimension
                               classes=3,
                               device=device).to(device)

# Training

## Pre-generated

In [52]:
# Training configurations
epochs = 20
criterion = nn.CrossEntropyLoss(ignore_index=-1)

lr = 1e-4
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr = lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5, T_mult=1 )

mask = generate_causal_mask(MAX_FRAMES).to(device)

In [53]:
# Calculates the number of correct null action predictions and the number of correct non-null action predictions
def correct(preds, targs):
    null_idxs = torch.where(targs.cpu() == 2)
    non_null_idxs = torch.where(targs.cpu() < 2)
    
    null_preds = preds[null_idxs]
    non_null_preds = preds[non_null_idxs]
    
    c_null = torch.sum(null_preds == targs[null_idxs])
    n_null = len(null_preds)
    null_acc = c_null/n_null
    
    c_non_null = torch.sum(non_null_preds == targs[non_null_idxs])
    n_non_null = len(non_null_preds)
    non_null_acc = c_non_null/n_non_null
    
    return null_acc, non_null_acc

# Calculates the loss for a forward pass for both null and non-null action predictions (this is to avoid overfitting to null actions)
def loss(preds, targs):
    # Find indexes of null frames and non-null frames
    null_idxs = torch.where(targs.cpu() == 2)
    non_null_idxs = torch.where(targs.cpu() < 2)
    
    # Add batch dimension and reorder into (batch_size, n_classes, seq_len)
    null_preds = preds[null_idxs]
    non_null_preds = preds[non_null_idxs]

    null_loss = criterion(null_preds, targs[null_idxs])
    non_null_loss = criterion(non_null_preds, targs[non_null_idxs])
    
    return null_loss, non_null_loss, len(non_null_idxs)/len(null_idxs)
    

In [54]:
# Training and validation loop

# Store the average loss after each epoch
all_loss = {'train_null_loss':[],'train_non_null_loss':[], 'val_null_loss':[], 'val_non_null_loss':[]}
all_acc = {'train_null_acc':[], 'train_non_null_acc':[], 'val_null_acc':[], 'val_non_null_acc':[]}

print("starting")
for epoch in range(epochs):

    model.eval()
    
    # Epoch stat trackers
    null_accs = []
    non_null_accs = []
    null_losses = []
    non_null_losses = []
    for idx, batch in enumerate(iter(train_dataloader)):

        # Inputs and Targets
        instruction = batch['instructions']
        frames = batch['frames']
        actions = batch['actions']

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = torch.argwhere(padding_mask == True).cpu()
        pad_indexes = pad_indexes[:,0] * MAX_FRAMES + pad_indexes[:,1]

        # Get predictions
        output = model(instruction, frames, mask, padding_mask)
        
        # Get Loss for both null and non-null actions
        null_loss, non_null_loss, scale = loss(output.reshape(-1,3), actions.type(torch.LongTensor).reshape(-1).to(device))
        null_losses.append(null_loss.item())
        non_null_losses.append(non_null_loss.item())
        
        if epoch != 0:
            total_loss = non_null_loss
        else:
            total_loss = null_loss + non_null_loss

        # Track stats
        _, predicted = torch.max(output.data, 2)
        
        predicted = predicted.reshape(-1).cpu()
        predicted = torch.Tensor(np.delete(predicted.numpy(), pad_indexes.numpy())) # drop the pads
        
        actions = actions.reshape(-1).cpu()
        actions = torch.Tensor(np.delete(actions.numpy(), pad_indexes.numpy())) # drop the pads
        
        null_acc, non_null_acc = correct(predicted, actions)
        null_accs.append(null_acc.cpu())
        non_null_accs.append(non_null_acc.cpu())

        # Backward pass
        total_loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step(epoch + idx / len(train_dataloader))
        
    # Validate on validation set every 2 epochs
    if (epoch+1) % 2 == 0 or epoch == epochs:
        model.eval()
        
        # print(torch.norm(model.mlp[-3].weight.grad))
        # Turn off gradient calcs
        with torch.no_grad():
            val_null_accs = []
            val_non_null_accs = []
            val_null_losses = []
            val_non_null_losses = []
            
            for idx, batch in enumerate(iter(val_dataloader)):
                # Inputs and Targets
                instruction = batch['instructions']
                frames = batch['frames']
                actions = batch['actions']

                # Frame Padding
                padding_mask = generate_pad_mask(batch=frames)
                pad_indexes = torch.argwhere(padding_mask == True).cpu()
                pad_indexes = pad_indexes[:,0] * MAX_FRAMES + pad_indexes[:,1]

                # Get predictions
                output = model(instruction, frames, mask, padding_mask)

                # Get Losses
                null_loss, non_null_loss, scale = loss(output.reshape(-1,3), actions.type(torch.LongTensor).reshape(-1).to(device))
                val_null_losses.append(null_loss.item())
                val_non_null_losses.append(non_null_loss.item())
                
                # Track Stats
                _, predicted = torch.max(output.data, 2)
                
                predicted = predicted.reshape(-1).cpu()
                predicted = torch.Tensor(np.delete(predicted.numpy(), pad_indexes.numpy())) # drop the pads
        
                actions = actions.reshape(-1).cpu()
                actions = torch.Tensor(np.delete(actions.numpy(), pad_indexes.numpy())) # drop the pads
                
                null_acc, non_null_acc = correct(predicted, actions)
                val_null_accs.append(null_acc.cpu())
                val_non_null_accs.append(non_null_acc.cpu())
                

        # Track loss and acc ever 5 epochs
        all_loss['val_null_loss'].append(sum(val_null_losses)/len(val_null_losses))
        all_loss['val_non_null_loss'].append(sum(val_non_null_losses)/len(val_non_null_losses))
        all_acc['val_null_acc'].append(sum(val_null_accs)/len(val_null_accs))
        all_acc['val_non_null_acc'].append(sum(val_non_null_accs)/len(val_non_null_accs))

        print("average val null acc %.2f" % (sum(val_null_accs)/len(val_null_accs)))
        print("average val non-null acc %.2f" % (sum(val_non_null_accs)/len(val_non_null_accs)))
        
    print('epoch: ', epoch)
    print('train null acc: ', null_acc)
    print('train non-null acc: ', non_null_acc)

    all_loss['train_null_loss'].append(sum(null_losses)/len(null_losses))
    all_loss['train_non_null_loss'].append(sum(non_null_losses)/len(non_null_losses))
    all_acc['train_null_acc'].append(sum(null_accs)/len(null_accs))
    all_acc['train_non_null_acc'].append(sum(non_null_accs)/len(non_null_accs))
        


starting
epoch:  0
train null acc:  tensor(1.)
train non-null acc:  tensor(0.)
average val null acc 0.00
average val non-null acc 0.51
epoch:  1
train null acc:  tensor(0.)
train non-null acc:  tensor(0.5250)
epoch:  2
train null acc:  tensor(0.)
train non-null acc:  tensor(0.5312)
average val null acc 0.00
average val non-null acc 0.51
epoch:  3
train null acc:  tensor(0.)
train non-null acc:  tensor(0.5250)
epoch:  4
train null acc:  tensor(0.)
train non-null acc:  tensor(0.7188)
average val null acc 0.00
average val non-null acc 0.49
epoch:  5
train null acc:  tensor(0.)
train non-null acc:  tensor(0.4750)
epoch:  6
train null acc:  tensor(0.)
train non-null acc:  tensor(0.5312)
average val null acc 0.00
average val non-null acc 0.49
epoch:  7
train null acc:  tensor(0.)
train non-null acc:  tensor(0.4750)
epoch:  8
train null acc:  tensor(0.)
train non-null acc:  tensor(0.4375)
average val null acc 0.00
average val non-null acc 0.49
epoch:  9
train null acc:  tensor(0.)
train non-n

In [55]:
all_acc

{'train_null_acc': [tensor(0.6658),
  tensor(0.0977),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.)],
 'train_non_null_acc': [tensor(0.1872),
  tensor(0.4614),
  tensor(0.4995),
  tensor(0.5043),
  tensor(0.5057),
  tensor(0.5033),
  tensor(0.4955),
  tensor(0.5004),
  tensor(0.4960),
  tensor(0.4974),
  tensor(0.5024),
  tensor(0.5040),
  tensor(0.4990),
  tensor(0.4963),
  tensor(0.4984),
  tensor(0.4998),
  tensor(0.5033),
  tensor(0.5037),
  tensor(0.5047),
  tensor(0.5026)],
 'val_null_acc': [tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.),
  tensor(0.)],
 'val_non_null_acc': [tensor(0.5074),
  tensor(0.5074),
  tensor(0.4926),
  tensor(0.4926),
  tensor(0.4926),
  tensor(0.5074),
  tensor(0.4926),
  tenso

## Evaluating

In [27]:
model.eval()

val_epoch_correct_non_null = 0

val_epoch_count_non_null = 0

with torch.no_grad():
    test_correct_null = 0
    test_correct_non_null = 0
    test_count_null = 0
    test_count_non_null = 0

    for idx, batch in enumerate(iter(val_dataloader)):
        # Inputs and Targets
        instruction = batch['instruction']
        frames = batch['frames']
        targets = batch['actions'][:,-1]
        print(targets)

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask.cpu()) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        print(predictions)
        predictions = predictions[:,pad_indexes]

        # Get Losses
        non_null_loss = criterion(predictions, targets.long())
        # null_loss, non_null_loss, scale = loss(predictions, targets)
        # total_loss = non_null_loss #null_loss*scale + non_null_loss*(1/scale)
        # # total_loss = criterion(predictions.permute(0, 2, 1), targets.long())

        # Track Stats
        val_epoch_correct_non_null += (predictions.argmax(dim=-1).cpu() == targets.cpu()).sum().item()
        val_epoch_count_non_null += (predictions.argmax(dim=-1).cpu() == targets.cpu()).size(0)

print("Test Non-Null Accuracy: ", round(val_epoch_correct_non_null/val_epoch_count_non_null,4))

KeyError: 'instruction'