# Neural Network Modeling with MULTFS

#### Overview:
This notebook aims to provided an example of training a PyTorch Model on a premade task dataset
- Text-embedder: ```all-mpnet-base-v2 (pretrained Sentence Transformer)```
- Image-embedder: ```vit_b_16 (pretrained Vision Transformer)``` 
- Decoder/Classifier: Transformer Decoder only Model trained on both text and image embeddings to out put action class


#### Datasets/Training:
- In this notebook we use pre-saved and generated datasets
    - *See the ..._dataset_gen.ipynb notebooks for how to generate and save a dataset*




### Imports

In [1]:
import sys
sys.path.append('../')

import numpy as np
import os
from PIL import Image
import json
import math
import copy
from sklearn.model_selection import train_test_split

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from torchvision import transforms
from tqdm.notebook import tqdm

from cognitive.task_bank import CompareLocTemporal
from cognitive import task_generator as tg
from cognitive import constants as const
from cognitive import stim_generator as sg
from cognitive import info_generator as ig
import random

2023-11-01 17:22:26.825867: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-01 17:22:26.825999: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-01 17:22:26.826088: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-01 17:22:26.831961: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Constants

In [2]:
TRAIN_DIR = '../datasets/train_big'  # Training Dataset Directory
VAL_DIR = '../datasets/val_big'  # Validation Dataset Directory
TEST_DIR = '../datasets/val_big'  # Testing Dataset Directory
LM_PATH = 'offline_models/all-mpnet-base-v2'
IMGM_PATH = 'offline_models/resnet/resnet'
EMB_DIR = '../datasets/embeddings'
BATCH_SIZE = 128
MAX_FRAMES = 2  # the max possible frames across tasks
IMGM_OUT_DIM = 2048 # vision transformer output dimension
LM_OUT_DIM = 768 # language model output dimension



Using GPU: NVIDIA RTX A5000


# Language Encoder

### Language Embedder

In [3]:
def lm_embedder(instruction, encoder, tokenizer):
    instruction = tokenizer(instruction, padding=True, truncation=True, return_tensors='pt').to(device)
    
    #Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # Compute token embeddings
    with torch.no_grad():
        lm_output = encoder(**instruction)

    # Perform pooling
    sentence_embeddings = mean_pooling(lm_output, instruction['attention_mask'])
    
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        
    return sentence_embeddings

# Image Encoder

### Image Embedder
- We pad frames based on max possible

In [4]:
def img_embedder(frames, encoder):

    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    encoder.layer4[2].register_forward_hook(get_activation('layer'))

    with torch.no_grad():
        encoder(torch.tensor(frames))
        out = activation['layer']
        # out = torch.flatten(activation['layer'], start_dim=1, end_dim=3)
        # out = torch.squeeze(activation['layer'], (2,3)) #torch.flatten(, start_dim=1, end_dim=2)

    npads = MAX_FRAMES-len(out)
    pad = torch.ones((npads, out.shape[1], out.shape[2], out.shape[3]))
    out = torch.cat((out, pad.to(device)))

    return out

# Create Embeddings

In [5]:
class TaskDataset(Dataset):
    # todo: this is not the most efficient way to access data, since each time it has to read from the directory 
    def __init__(self, root_dir):
        self.root_dir = root_dir
        # preprocessing steps for pretrained ResNet models
        self.transform = transforms.Compose([
                            transforms.Resize(224),
                            transforms.CenterCrop(224), # todo: to delete for shapenet task; why?
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                        ])
    

        # check the size of the dataset
        self.dataset_size = 0
        items = os.listdir(self.root_dir)
        for item in items:
            item_path = os.path.join(self.root_dir, item)
            # Check if the item is a directory
            if os.path.isdir(item_path):
                self.dataset_size += 1

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        trial_path = os.path.join(self.root_dir, "trial%d"%idx)
        images = []
        
        for fp in os.listdir(trial_path):
            fp = os.path.join(trial_path, fp)

            if fp[-4:] == '.png':
                img = Image.open(fp)
                img = self.transform(img)
                images.append(img)
            elif 'trial_info' in fp:
                info = json.load(open(fp))
                
                actions = self._action_map(info["answers"])

                # npads = MAX_FRAMES - len(actions)
                # actions.extend([-1 for _ in range(0,npads)])
                
                instructions = info['instruction']
        
        images = np.stack(images) # (2*3*224*224)

        return {'instructions':instructions, 'frames':torch.tensor(images).to(device), 'actions': torch.tensor(actions).to(device)}
    
    def _action_map(self, actions):
        action_map = {'true': 0, 'false': 1, 'null': 2}
        updated_actions = []
        for action in actions:
            updated_actions.append(action_map[action])

        return updated_actions


In [6]:
lm_encoder = AutoModel.from_pretrained('offline_models/all-mpnet-base-v2').to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('offline_models/all-mpnet-base-v2')

img_encoder = torch.load(IMGM_PATH,map_location=device).to(device).eval()

train_TD = TaskDataset(TRAIN_DIR)
val_TD = TaskDataset(VAL_DIR)

train_dataloader = DataLoader(train_TD, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_TD, batch_size=BATCH_SIZE, shuffle=False)

len(train_TD)

20000

In [7]:
# train_lm_embeddings = []
# train_img_embeddings = []
# train_targets = []

# val_lm_embeddings = []
# val_img_embeddings = []
# val_targets = []


# for frames,instruction,actions in train_DL:
#     frames = frames[0]

#     train_lm_embeddings.append(lm_embedder(instruction, lm_encoder, tokenizer).cpu())
#     train_img_embeddings.append(frames)
#     # train_img_embeddings.append(img_embedder(frames, img_encoder).cpu())
#     train_targets.append(actions[0])


# for frames,instruction,actions in val_DL:
#     frames = frames[0]

#     val_lm_embeddings.append(lm_embedder(instruction, lm_encoder, tokenizer).cpu())
#     val_img_embeddings.append(frames)
#     # val_img_embeddings.append(img_embedder(frames, img_encoder).cpu())
#     val_targets.append(actions[0])

# Embeddings Dataset

In [8]:
# class EmbeddingsDataset(Dataset):
#   """
#     Pytorch Dataset class to load the embedded data

#     Data members:
#       lm_embeddings: list of language model embeddings
#       img_embeddings: list of language model embeddings
#       n_embs: number of embeddings in the dataset

#     Member functions:
#       __init__: ctor
#       __len__: returns n_ins
#       __getitem__: returns an instruction
#   """

#   def __init__(self, lm_embeddings, img_embeddings, actions):

#     self.lm_embeddings = lm_embeddings
#     self.img_embeddings = img_embeddings
#     self.actions = actions

#     self.n_embs = len(self.lm_embeddings)

#     return

#   def __len__(self):
#     """
#       Returns number of instructions in the Dataset
#     """

#     return self.n_embs

#   def __getitem__(self, idx):
#     """
#       Given an index return a instruction at that index
#     """

#     return {'instructions':self.lm_embeddings[idx].to(device), 'frames':self.img_embeddings[idx].to(device), 'actions':torch.tensor(self.actions[idx], dtype=torch.float32, device=device)}

In [9]:
# # Create pytorch dataset for train, val, test data
# train_dataset = EmbeddingsDataset(train_lm_embeddings, train_img_embeddings, train_targets)
# val_dataset = EmbeddingsDataset(val_lm_embeddings, val_img_embeddings, val_targets)
# # test_dataset  = EmbeddingsDataset(test_lm_embeddings, test_img_embeddings, test_targets)

# # Move pytorch datasets into dataloaders
# train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
# val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# # test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [10]:
# len(train_dataset)

# Action Decoder

In [11]:
from matplotlib import projections
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

class CNNRNNNet(nn.Module):

    def __init__(self, hidden_size, dim_transformer_ffl=2048, nhead = 16, blocks=1, output_size = 3,):
        super().__init__()
        
        # set up the CNN model
        self.cnnmodel = torch.load(IMGM_PATH, map_location=device)
        # freeze layers of cnn model
        for paras in self.cnnmodel.parameters():
            paras.requires_grad = False
        # get relu activation of last block of resnet50
        
        self.cnnmodel.layer4[2].relu.register_forward_hook(get_activation('relu'))

        self.cnnlayer = torch.nn.Conv2d(2048, hidden_size, 1) # we can also bring the resnet embedding dim to a number different from hidden size

        self.input_size = hidden_size*7*7
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.in2hidden = nn.Linear(self.input_size, hidden_size)
        self.layer_norm_in = nn.LayerNorm(self.hidden_size)
        
        self.rnn = nn.RNN(
            input_size = self.hidden_size, 
            hidden_size = self.hidden_size,
            nonlinearity = "relu", # guarnatee positive activations
            batch_first = True
            )

        self.pos_emb = PositionalEncoding(hidden_size, MAX_FRAMES)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, dim_feedforward=dim_transformer_ffl, nhead=nhead, batch_first=False)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=blocks)
        
#         self.mlp = nn.Sequential(
#                 nn.Linear(self.hidden_size, 1024),
#                 nn.ReLU(),
#                 nn.Linear(1024, 2048),
#                 nn.ReLU(),
#                 nn.Linear(2048, hidden_size),
#             )
        
        self.layer_norm_rnn = nn.LayerNorm(self.hidden_size)
        self.hidden2output = nn.Linear(self.hidden_size, self.output_size)


    def forward(self, input_img, hidden_state = None, is_noise = False,):
        # preprocess image with resnet
        self.batch_size = input_img.shape[0]
        self.seq_len = input_img.shape[1]

        # x = torch.swapaxes(input_img, 0, 1).float()# (seq_len, batchsize, nc, w, h)
        
        x = input_img.float()
        
        x_acts = []
        for i in range(self.seq_len):
            temp = self.cnnmodel(x[:,i,:,:,:])
            x_act = self.cnnlayer(activation["relu"])
            x_acts.append(x_act) # (batchsize, nc, w, h) = (batchsize, 2048, 7,7)
        
        x_acts = torch.stack(x_acts, axis = 1) # (seqlen, batchsize,nc, w,h) rnn_activations
        # self.cnn_acts = torch.stack(cnn_acts, axis = 0) # (seqlen, batchsize, nc, w,h) 
        # self.cnn_acts_down = x_acts
        
        
        x_acts = x_acts.reshape(self.batch_size, x_acts.shape[1], -1) # flatten nc,w,h into one dim
        
        
        """ RNN METHOD """
#         if hidden_state == None:
#             self.hidden_state = self.init_hidden(batch_size = self.batch_size)
#         hidden_x = self.layer_norm_in(torch.relu(self.in2hidden(x_acts.float()))).to(device)
        
#         rnn_output, _ = self.rnn(hidden_x, self.hidden_state.to(device))
#         rnn_output = self.layer_norm_rnn(rnn_output)
#         out = self.hidden2output(torch.tanh(rnn_output))
        
        """ TRANSFORMER METHOD """
        hidden_x = self.layer_norm_in(self.pos_emb(self.in2hidden(x_acts.float())))
        encoder_output = self.encoder(hidden_x)
        out = self.hidden2output(encoder_output)
        
        """ MLP METHOD """
        # hidden_x = self.pos_emb(self.layer_norm_in(torch.relu(self.in2hidden(x_acts.float()))).to(device))
        # mlp_output = self.mlp(hidden_x)
        # out = self.hidden2output(torch.tanh(mlp_output))
        
        
        
        return out
        


    def init_hidden(self, batch_size):
        return nn.init.kaiming_uniform_(torch.empty(1, batch_size, self.hidden_size))
            
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]



In [16]:
model = CNNRNNNet(hidden_size = 512, dim_transformer_ffl=64, nhead = 16, blocks=1, output_size = 3,).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 5, T_mult=1 )
                                                                 # torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 2)

# Training

## Pre-generated

In [17]:
# Training configurations
epochs = 20
# criterion = nn.CrossEntropyLoss(ignore_index=-1)

# lr = 1e-3
# optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr = lr)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3, T_mult=1 )

# mask = generate_causal_mask(MAX_FRAMES).to(device)

In [18]:
# Calculates the number of correct null action predictions and the number of correct non-null action predictions
def correct(preds, targs):
    null_idxs = torch.where(targs.cpu() == 2)
    non_null_idxs = torch.where(targs.cpu() < 2)
    
    null_preds = preds[null_idxs]
    non_null_preds = preds[non_null_idxs]
    
    c_null = torch.sum(null_preds == targs[null_idxs])
    n_null = len(null_preds)
    null_acc = c_null/n_null
    
    c_non_null = torch.sum(non_null_preds == targs[non_null_idxs])
    n_non_null = len(non_null_preds)
    non_null_acc = c_non_null/n_non_null
    
    return null_acc, non_null_acc

# Calculates the loss for a forward pass for both null and non-null action predictions (this is to avoid overfitting to null actions)
def loss(preds, targs):
    # Find indexes of null frames and non-null frames
    null_idxs = torch.where(targs.cpu() == 2)
    non_null_idxs = torch.where(targs.cpu() < 2)
    
    # Add batch dimension and reorder into (batch_size, n_classes, seq_len)
    null_preds = preds[null_idxs]
    non_null_preds = preds[non_null_idxs]

    null_loss = criterion(null_preds, targs[null_idxs])
    non_null_loss = criterion(non_null_preds, targs[non_null_idxs])
    
    return null_loss, non_null_loss, len(non_null_idxs)/len(null_idxs)
    

In [None]:
# Training and validation loop

# Store the average loss after each epoch
all_loss = {'train_null_loss':[],'train_non_null_loss':[], 'val_null_loss':[], 'val_non_null_loss':[]}
all_acc = {'train_null_acc':[], 'train_non_null_acc':[], 'val_null_acc':[], 'val_non_null_acc':[]}

print("starting")
for epoch in range(epochs):

    
    
    # Epoch stat trackers
    null_accs = []
    non_null_accs = []
    null_losses = []
    non_null_losses = []
    for idx, batch in enumerate(iter(train_dataloader)):
        model.train()
        optimizer.zero_grad()

        # Inputs and Targets
        instruction = batch['instructions']
        frames = batch['frames']
        actions = batch['actions']

        # Frame Padding
        # padding_mask = generate_pad_mask(batch=frames)
        # pad_indexes = torch.argwhere(padding_mask == True).cpu()
        # pad_indexes = pad_indexes[:,0] * MAX_FRAMES + pad_indexes[:,1]

        # Get predictions
        output =  model(frames) # model(instruction, frames, mask, padding_mask)
        
        # Get Loss for both null and non-null actions
        null_loss, non_null_loss, scale = loss(output.reshape(-1,3), actions.type(torch.LongTensor).reshape(-1).to(device))
        null_losses.append(null_loss.item())
        non_null_losses.append(non_null_loss.item())
        
        if idx%30 == 0 and epoch == 0 :
            total_loss = non_null_loss
        else:
            total_loss = non_null_loss
        # null_loss = 0
        

        # Track stats
        _, predicted = torch.max(output.data, 2)
        
        predicted = predicted.reshape(-1).cpu()
        # predicted = torch.Tensor(np.delete(predicted.numpy(), pad_indexes.numpy())) # drop the pads 
        
        actions = actions.reshape(-1).cpu()
        # actions = torch.Tensor(np.delete(actions.numpy(), pad_indexes.numpy())) # drop the pads 
        
        null_acc, non_null_acc = correct(predicted, actions)
        null_accs.append(null_acc.cpu())
        non_null_accs.append(non_null_acc.cpu())

        # Backward pass
        total_loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step(epoch + idx / len(train_dataloader))
        
    print('epoch: ', epoch)
    print('avg train null acc: ', sum(null_accs)/len(null_accs))
    print('avg train non-null acc: ', sum(non_null_accs)/len(non_null_accs))

    all_loss['train_null_loss'].append(sum(null_losses)/len(null_losses))
    all_loss['train_non_null_loss'].append(sum(non_null_losses)/len(non_null_losses))
    all_acc['train_null_acc'].append(sum(null_accs)/len(null_accs))
    all_acc['train_non_null_acc'].append(sum(non_null_accs)/len(non_null_accs))
        
        
    # Validate on validation set
    # print(torch.norm(model.mlp[-3].weight.grad))
    # Turn off gradient calcs
    val_null_accs = []
    val_non_null_accs = []
    val_null_losses = []
    val_non_null_losses = []

    for idx, batch in enumerate(iter(val_dataloader)):
        model.eval()
        # Inputs and Targets
        instruction = batch['instructions']
        frames = batch['frames']
        actions = batch['actions']

        # Frame Padding
        # padding_mask = generate_pad_mask(batch=frames)
        # pad_indexes = torch.argwhere(padding_mask == True).cpu()
        # pad_indexes = pad_indexes[:,0] * MAX_FRAMES + pad_indexes[:,1]

        # Get predictions
        output = model(frames) # model(instruction, frames, mask, padding_mask)

        # Get Losses
        null_loss, non_null_loss, scale = loss(output.reshape(-1,3), actions.type(torch.LongTensor).reshape(-1).to(device))
        val_null_losses.append(null_loss.item())
        val_non_null_losses.append(non_null_loss.item())

        # Track Stats
        _, predicted = torch.max(output.data, 2)

        predicted = predicted.reshape(-1).cpu()
        # predicted = torch.Tensor(np.delete(predicted.numpy(), pad_indexes.numpy())) # drop the pads

        actions = actions.reshape(-1).cpu()
        # actions = torch.Tensor(np.delete(actions.numpy(), pad_indexes.numpy())) # drop the pads

        null_acc, non_null_acc = correct(predicted, actions)
        val_null_accs.append(null_acc.cpu())
        val_non_null_accs.append(non_null_acc.cpu())


    # Track loss and acc ever 5 epochs
    all_loss['val_null_loss'].append(sum(val_null_losses)/len(val_null_losses))
    all_loss['val_non_null_loss'].append(sum(val_non_null_losses)/len(val_non_null_losses))
    all_acc['val_null_acc'].append(sum(val_null_accs)/len(val_null_accs))
    all_acc['val_non_null_acc'].append(sum(val_non_null_accs)/len(val_non_null_accs))

    print("avg val null acc %.2f" % (sum(val_null_accs)/len(val_null_accs)))
    print("avg val non-null acc %.2f" % (sum(val_non_null_accs)/len(val_non_null_accs)))
    


starting
epoch:  0
avg train null acc:  tensor(0.0028)
avg train non-null acc:  tensor(0.5017)
avg val null acc 0.00
avg val non-null acc 0.49
epoch:  1
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.5283)
avg val null acc 0.00
avg val non-null acc 0.50
epoch:  2
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.5554)
avg val null acc 0.00
avg val non-null acc 0.50
epoch:  3
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.6025)
avg val null acc 0.00
avg val non-null acc 0.49
epoch:  4
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.6515)
avg val null acc 0.00
avg val non-null acc 0.49
epoch:  5
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.5608)
avg val null acc 0.00
avg val non-null acc 0.50
epoch:  6
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.6061)
avg val null acc 0.00
avg val non-null acc 0.50
epoch:  7
avg train null acc:  tensor(0.)
avg train non-null acc:  tensor(0.6

In [None]:
import matplotlib.pyplot as plt

def plot_dict(dict_arrays, use_xlabel='Epochs', use_ylabel='Value', use_title=None):
    # Font size select custom or adjusted on `magnify` value.
    font_size = np.interp(0.1, [0.1,1], [10.5,50])

    # Font variables dictionary. Keep it in this format for future updates.
    font_dict = dict(
        family='DejaVu Sans',
        color='black',
        weight='normal',
        size=font_size,
        )

    # Single plot figure.
    plt.subplot(1, 2, 1)

    # Use maximum length of steps. In case each arrya has different lengths.
    max_steps = []

    # Plot each array.
    for index, (use_label, array) in enumerate(dict_arrays.items()):
        # Set steps plotted on x-axis - we can use step if 1 unit has different value.
        if 0 > 0:
            # Offset all steps by start_step.
            steps = np.array(range(0, len(array))) * 1 + 0
            max_steps = steps if len(max_steps) < len(steps) else max_steps
        else:
            steps = np.array(range(1, len(array) + 1)) * 1
            max_steps = steps if len(max_steps) < len(steps) else max_steps

        # Plot array as a single line.
        plt.plot(steps, array, linestyle=(['-'] * len(dict_arrays))[index], label=use_label)

        # Plots points values.
        if ([False] * len(dict_arrays))[index]:
            # Loop through each point and plot the label.
            for x, y in zip(steps, array):
                # Add text label to plot.
                plt.text(x, y, str(round(y, 3)), fontdict=font_dict)

    # Set horizontal axis name.
    plt.xlabel(use_xlabel, fontdict=font_dict)

    # Use x ticks with steps or labels.
    plt.xticks(max_steps, None, rotation=0)

    # Set vertical axis name.
    plt.ylabel(use_ylabel, fontdict=font_dict)

    # Adjust both axis labels font size at same time.
    plt.tick_params(labelsize=font_dict['size'])

    # Place legend best position.
    plt.legend(loc='best', fontsize=font_dict['size'])

    # Adjust font for title.
    font_dict['size'] *= 1.8

    # Set title of figure.
    plt.title(use_title, fontdict=font_dict)

    # Rescale `magnify` to be used on inches.
    magnify = 0.1
    magnify *= 15

    # Display grid depending on `use_grid`.
    plt.grid(True)

    # Make figure nice.
    plt.tight_layout()

    # Get figure object from plot.
    fig = plt.gcf()

    # Get size of figure.
    figsize = fig.get_size_inches()

    # Change size depending on height and width variables.
    figsize = [figsize[0] * 3 * magnify, figsize[1] * 1 * magnify]

    # Set the new figure size with magnify.
    fig.set_size_inches(figsize)

    return

In [None]:
plot_dict(all_acc)

In [None]:
plot_dict(all_loss)

## Evaluating

In [None]:
model.eval()

val_epoch_correct_non_null = 0

val_epoch_count_non_null = 0

with torch.no_grad():
    test_correct_null = 0
    test_correct_non_null = 0
    test_count_null = 0
    test_count_non_null = 0

    for idx, batch in enumerate(iter(val_dataloader)):
        # Inputs and Targets
        instruction = batch['instruction']
        frames = batch['frames']
        targets = batch['actions'][:,-1]
        print(targets)

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask.cpu()) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        print(predictions)
        predictions = predictions[:,pad_indexes]

        # Get Losses
        non_null_loss = criterion(predictions, targets.long())
        # null_loss, non_null_loss, scale = loss(predictions, targets)
        # total_loss = non_null_loss #null_loss*scale + non_null_loss*(1/scale)
        # # total_loss = criterion(predictions.permute(0, 2, 1), targets.long())

        # Track Stats
        val_epoch_correct_non_null += (predictions.argmax(dim=-1).cpu() == targets.cpu()).sum().item()
        val_epoch_count_non_null += (predictions.argmax(dim=-1).cpu() == targets.cpu()).size(0)

print("Test Non-Null Accuracy: ", round(val_epoch_correct_non_null/val_epoch_count_non_null,4))