# Neural Network Modeling with MULTFS

#### Overview:
This notebook aims to provided an example of training a PyTorch Model on a premade task dataset
- Text-embedder: ```all-mpnet-base-v2 (pretrained Sentence Transformer)```
- Image-embedder: ```vit_b_16 (pretrained Vision Transformer)``` 
- Decoder/Classifier: Transformer Decoder only Model trained on both text and image embeddings to out put action class


#### Datasets/Training:
- In this notebook we use pre-saved and generated datasets
    - *See the ..._dataset_gen.ipynb notebooks for how to generate and save a dataset*




### Imports

In [49]:
import sys
sys.path.append('../')

import numpy as np
import os
from PIL import Image
import json
import math
import copy
from sklearn.model_selection import train_test_split

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from tqdm.notebook import tqdm

from cognitive.task_bank import CompareLocTemporal
from cognitive import task_generator as tg
from cognitive import constants as const
from cognitive import stim_generator as sg
from cognitive import info_generator as ig
import random

### Constants

In [50]:
TRAIN_DIR = '../datasets/train'  # Training Dataset Directory
VAL_DIR = '../datasets/val'  # Validation Dataset Directory
TEST_DIR = '../datasets/test'  # Testing Dataset Directory
BATCH_SIZE = 1
MAX_FRAMES = 6  # the max possible frames across tasks
VIT_OUT_DIM = 1000 # vision transformer output dimension
LM_OUT_DIM = 768 # language model output dimension

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# Datasets
- Read in the pregenerated task trials organized into frames, instructions, and correct actions. 

In [51]:
def read_trials(path):
    frames = []
    infos = []

    for trial_fp in os.listdir(path):
        if 'trial' not in trial_fp:
            continue

        trial_fp = os.path.join(path, trial_fp)
        imgs = []
        info = None
        
        for fp in os.listdir(trial_fp):
            fp = os.path.join(trial_fp, fp)
            
            if fp[-4:] == '.png':
                img = np.rollaxis(np.array(Image.open(fp), dtype=np.float32),2,0)
                imgs.append(img)
            elif 'trial_info' in fp:
                info = json.load(open(fp))
                infos.append(info)
                
        if len(imgs) > MAX_FRAMES:
            raise Exception(trial_fp + " contains more frames than the set maximum (MAX_FRAMES) !!!")
        elif len(imgs) != len(info['answers']):
            raise Exception(trial_fp + " numbers of frames does not match number of actions")
            
        frames.append(np.array(imgs))

    return frames, infos

train_frames, train_infos = read_trials(TRAIN_DIR)
val_frames, val_infos = read_trials(VAL_DIR)
test_frames, test_infos = read_trials(TEST_DIR)

In [52]:
train_ins = [x['instruction'] for x in train_infos]
train_raw_targets = [x['answers'] for x in train_infos]

val_ins = [x['instruction'] for x in val_infos]
val_raw_targets = [x['answers'] for x in val_infos]

test_ins = [x['instruction'] for x in test_infos]
test_raw_targets = [x['answers'] for x in test_infos]

### Encode Target Actions
- Encodes the target actions into one hot encoding vectors corresponding to the actions ```true, false, and null```

In [53]:
action_map = {'true': 0, 'false': 1, 'null': 2}

def map_actions(amap, raw_actions):
    target_actions = []

    for actions in raw_actions:
        encoded = []
        for action in actions:
            encoded.append(amap[action])
        target_actions.append(encoded)
    
    return target_actions

train_targets = map_actions(action_map, train_raw_targets)
val_targets = map_actions(action_map, val_raw_targets)
test_targets = map_actions(action_map, test_raw_targets)

## Instruction Dataset

In [54]:
class InstructionsDataset(Dataset):
  """
    Pytorch Dataset class to load the Instructions Data

    Data members:
      instructions: list of instructions
      n_ins: number of instructions in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, x):

    self.instructions = x

    self.n_ins = len(self.instructions)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_ins

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    return self.instructions[idx]

In [55]:
class InstructionsCollator(object):
  """
    Data Collator used for GPT2 in a classificaiton tasks

    Args:
      use_tokenizer :
        Transformer type tokenizer used to process raw text into numbers.

    Data members:
      use_tokenizer: Tokenizer to be used inside the class.

    Member functions:
      __init__: ctor
      __call__: tokenize input

    """

  def __init__(self, use_tokenizer):

    self.use_tokenizer = use_tokenizer

    return

  def __call__(self, instructions):
    """
        Tokenizes input
    """

    # Call tokenizer
    inputs = self.use_tokenizer(instructions, padding=True, truncation=True, return_tensors='pt')

    return inputs


In [56]:
# Pretrained Language Model and Tokenizer 
lm_encoder = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
tokenizer = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Uncomment for offline load of lang embedder
# lm_encoder = AutoModel.from_pretrained('offline_models/all-mpnet-base-v2')
# tokenizer = AutoTokenizer.from_pretrained('offline_models/all-mpnet-base-v2')

# Create data collator to encode text and labels into numbers.
InstructionsCollator = InstructionsCollator(use_tokenizer=tokenizer)

# Create pytorch datasets for instructions
ins_train_dataset = InstructionsDataset(train_ins)
ins_val_dataset = InstructionsDataset(val_ins)
ins_test_dataset = InstructionsDataset(test_ins)

# Move pytorch datasets into dataloaders
ins_train_dataloader = DataLoader(ins_train_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=InstructionsCollator)
ins_val_dataloader = DataLoader(ins_val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=InstructionsCollator)
ins_test_dataloader = DataLoader(ins_test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=InstructionsCollator)


## Frames Dataset

In [57]:
class FramesDataset(Dataset):
  """
    Pytorch Dataset class to load the Frame Data

    Data members:
      frames``ist of frames
      n_imgs: number of iamges in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_imgs
      __getitem__: returns an frame
  """

  def __init__(self, x):

    self.frames = x

    self.n_imgs = len(self.frames)

    return

  def __len__(self):
    """
      Returns number of frames in the Dataset
    """

    return self.n_imgs

  def __getitem__(self, idx):
    """
      Given an index return a frame
    """

    return torch.tensor(self.frames[idx])

In [58]:
# Pretrained Vision Transformer
vit_encoder = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

# Uncommnent for offline load of vit
# vit_encoder = torch.load('offline_models/vit_b_16/vit_b_16')

# Create pytorch datasets for instructions
frames_train_dataset = FramesDataset(train_frames)
frames_val_dataset = FramesDataset(val_frames)
frames_test_dataset = FramesDataset(test_frames)

# Move pytorch datasets into dataloaders
frames_train_dataloader = DataLoader(frames_train_dataset, batch_size=BATCH_SIZE, shuffle=False)
frames_val_dataloader = DataLoader(frames_val_dataset, batch_size=BATCH_SIZE, shuffle=False)
frames_test_dataloader = DataLoader(frames_test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Language Encoder

### Language Embedder

In [59]:
def lm_embedder(instruction, encoder):
    #Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    # Compute token embeddings
    with torch.no_grad():
        lm_output = encoder(**instruction)

    # Perform pooling
    sentence_embeddings = mean_pooling(lm_output, instruction['attention_mask'])
    
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        
    return sentence_embeddings

# Image Encoder

### Position Embeddings

### Image Embedder
- We pad frames based on max possible

In [60]:
def img_embedder(frames, encoder):
    with torch.no_grad():
        vit_out = encoder(torch.tensor(frames))

    npads = MAX_FRAMES-len(vit_out)
    pad = torch.ones((npads, vit_out.shape[1]))
    vit_out = torch.cat((vit_out, pad))

    return vit_out

# Create Embeddings

In [61]:
train_lm_embeddings = []
train_img_embeddings = []

val_lm_embeddings = []
val_img_embeddings = []

test_lm_embeddings = []
test_img_embeddings = []

for train_i in ins_train_dataloader:
    train_f = train_f[0]

    train_lm_embeddings.append(lm_embedder(train_i, lm_encoder))
    train_img_embeddings.append(img_embedder(train_f, vit_encoder))
    
for val_i,val_f, test_i,test_f in zip(ins_val_dataloader,frames_val_dataloader, ins_test_dataloader,frames_test_dataloader):
    val_f = val_f[0]
    test_f = test_f[0]

    val_lm_embeddings.append(lm_embedder(val_i, lm_encoder))
    val_img_embeddings.append(img_embedder(val_f, vit_encoder))
    
    test_lm_embeddings.append(lm_embedder(test_i, lm_encoder))
    test_img_embeddings.append(img_embedder(test_f, vit_encoder))
    

  vit_out = encoder(torch.tensor(frames))


In [62]:
len(train_lm_embeddings)

800

# Embeddings Dataset

In [63]:
class EmbeddingsDataset(Dataset):
  """
    Pytorch Dataset class to load the embedded data

    Data members:
      lm_embeddings: list of language model embeddings
      img_embeddings: list of language model embeddings
      n_embs: number of embeddings in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, lm_embeddings, img_embeddings, actions):

    self.lm_embeddings = lm_embeddings
    self.img_embeddings = img_embeddings
    self.actions = actions

    self.n_embs = len(self.lm_embeddings)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_embs

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    return {'instruction':self.lm_embeddings[idx], 'frames':self.img_embeddings[idx], 'actions':torch.tensor(self.actions[idx], dtype=torch.float32)}

In [64]:
# Create pytorch dataset for train, val, test data
train_dataset = EmbeddingsDataset(train_lm_embeddings, train_img_embeddings, train_targets)
val_dataset = EmbeddingsDataset(val_lm_embeddings, val_img_embeddings, val_targets)
test_dataset  = EmbeddingsDataset(test_lm_embeddings, test_img_embeddings, test_targets)

# Move pytorch datasets into dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Action Decoder

In [65]:
from matplotlib import projections
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CausalMatchTransformer(nn.Module):
    """
    Pytorch based transformer decoder model
    """

    # Initialize Model with Params
    def __init__(self, nframes=MAX_FRAMES, blocks=3, nhead=5, emb_dim=VIT_OUT_DIM, classes=3, device=device):
        super().__init__()

        # Device
        self.device = device

        # Embedding Dimension
        self.emb_dim = emb_dim

        # Number of frames
        self.nframes = nframes

        # Frame Position Embedder Layer
        self.pos_emb = nn.Parameter(torch.Tensor(nframes,emb_dim)).to(device)
        torch.nn.init.xavier_uniform_(
           self.pos_emb,
           gain=torch.nn.init.calculate_gain("linear"))

        # Instruction Dim Projection Layer
        self.lm_linear_layer = nn.Linear(LM_OUT_DIM, emb_dim).to(device)

        # Decoder Layers
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, batch_first=True).to(device)
        self.decoder_layers = _get_clones(self.decoder_layer, blocks)
        
        # Decoder
        self.decoder = nn.TransformerDecoder(self.decoder_layers, num_layers=blocks).to(device)

        # Action classifier
        self.classifier = nn.Linear(emb_dim, classes).to(device)

    # Function for forward pass
    def forward(self, instruction, frames, mask, padding_mask):

        # Project instruction embedding
        instruction = self.lm_linear_layer(instruction)

        # Add the frame position embedding
        for i in range(len(frames)):
            frames[i] += self.pos_emb[i,:]

        # Apply each Decoder Layer (block)
        for layer in self.decoder_layers:
            frames = layer(frames, instruction, tgt_mask=mask, tgt_key_padding_mask=padding_mask, ) 

        # Pass through linear layer for classification
        output = self.classifier(frames)

        return output
 
# Creates a list of torch duplicate torch modules
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

# Creates a square Sequential/Causal mask of size sz
def generate_causal_mask(sz: int) -> Tensor:
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

# Generates a padding masks for each sequence in a batch
def generate_pad_mask(batch):

    pad_tensor = torch.ones((batch.shape[2])).to(device)

    mask = np.zeros((batch.shape[0],batch.shape[1]))

    for s in range(0, batch.shape[0]):
        for v in range(0, batch[s].shape[0]):
            new_s = torch.all(batch[s][v] == pad_tensor)
            mask[s][v] = new_s

    return torch.tensor(mask).bool().to(device)

In [66]:
model = CausalMatchTransformer(nframes=MAX_FRAMES,
                               blocks=3,
                               nhead=8,
                               emb_dim=VIT_OUT_DIM,
                               classes=len(action_map.keys()),
                               device=device).float().to(device)

# Training

## Pre-generated

In [67]:
# Training configurations
epochs = 15

# weights = [2, 2, 1]
# class_weights = torch.FloatTensor(weights)
# criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='sum')
criterion = nn.CrossEntropyLoss()

lr = 1e-5
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=lr
)

mask = generate_causal_mask(MAX_FRAMES).to(device)

In [68]:
# Calculates the number of correct null action predictions and the number of correct non-null action predictions
def correct(preds, targs):
    c_null = (preds == targs).sum().item()
    c_non_null = 0
    
    # Get indexs of non-null target actions
    idxs = np.where(targs < 2)
    
    # Count correct and totals
    c_non_null = (preds[idxs] == targs[idxs]).sum().item()
    n_non_null = (preds[idxs] == targs[idxs]).size(0)
    c_null -= c_non_null
    n_null = (preds == targs).size(1) - n_non_null
    
    return c_null,n_null, c_non_null,n_non_null

# Calculates the loss for a forward pass for both null and non-null action predictions (this is to avoid overfitting to null actions)
def loss(preds, targs):
    null_idxs = np.where(targs == 2)
    non_null_idxs = np.where(targs < 2)
    
    # We use a permute to get the correct predictions shape (batch_size, n_classes, seq_len)
    null_loss = criterion(preds[null_idxs].unsqueeze(0).permute(0, 2, 1), targs[null_idxs].unsqueeze(0).long())
    non_null_loss = criterion(preds[non_null_idxs].unsqueeze(0).permute(0, 2, 1), targs[non_null_idxs].unsqueeze(0).long())
    
    return null_loss, non_null_loss
    

In [69]:
# Training and validation loop

# Store the average loss after each epoch
all_loss = {'train_loss':[], 'val_loss':[]}
all_acc = {'train_null_acc':[], 'train_non_null_acc':[], 'val_null_acc':[], 'val_non_null_acc':[]}

print("starting")
for epoch in range(epochs):
    print(f"epoch={epoch}")

    # Epoch stat trackers
    epoch_loss = 0
    epoch_correct_null = 0
    epoch_correct_non_null = 0
    epoch_count_null = 0
    epoch_count_non_null = 0
    epoch_count = 0
    for idx, batch in enumerate(iter(train_dataloader)):

        # Inputs and Targets
        instruction = batch['instruction']
        frames = batch['frames']
        targets = batch['actions']
        
        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        predictions = predictions[:,pad_indexes]
        
        # Get Loss for both null and non-null actions
        # null_loss, non_null_loss = loss(predictions, targets)
        # total_loss = null_loss + non_null_loss
        total_loss = criterion(predictions.permute(0, 2, 1), targets.long())

        # Track stats
        correct_counts = correct(predictions.argmax(dim=-1), targets)
        epoch_correct_null += correct_counts[0]
        epoch_count_null += correct_counts[1]
        epoch_correct_non_null += correct_counts[2]
        epoch_count_non_null += correct_counts[3]
        
        epoch_loss += total_loss.item()

        # Backward pass
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
    # Validate on validation set every 5 epochs
    if (epoch+1) % 2 == 0 or epoch == epochs:
        # Turn off gradient calcs
        with torch.no_grad():
            val_epoch_loss = 0
            val_epoch_correct_null = 0
            val_epoch_correct_non_null = 0
            val_epoch_count_null = 0
            val_epoch_count_non_null = 0

            for idx, batch in enumerate(iter(val_dataloader)):
                # Inputs and Targets
                instruction = batch['instruction']
                frames = batch['frames']
                targets = batch['actions']

                # Frame Padding
                padding_mask = generate_pad_mask(batch=frames)
                pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]
            
                # Get predictions
                predictions = model(instruction, frames, mask, padding_mask)
                predictions = predictions[:,pad_indexes]

                # Get Losses
                # null_loss, non_null_loss = loss(predictions, targets)
                # total_loss = null_loss + non_null_loss
                total_loss = criterion(predictions.permute(0, 2, 1), targets.long())
                
                # Track Stats
                val_correct_counts = correct(predictions.argmax(dim=-1), targets)
                val_epoch_correct_null += val_correct_counts[0]
                val_epoch_count_null += val_correct_counts[1]
                val_epoch_correct_non_null += val_correct_counts[2]
                val_epoch_count_non_null += val_correct_counts[3]

                val_epoch_loss += total_loss.item()

        # Track loss and acc ever 5 epochs
        avg_train_loss = epoch_loss / len(train_dataloader)
        avg_val_loss = val_epoch_loss / len(val_dataloader)

        all_loss['val_loss'].append(avg_val_loss)
        all_acc['val_null_acc'].append(val_epoch_correct_null / val_epoch_count_null)
        all_acc['val_non_null_acc'].append(val_epoch_correct_non_null / val_epoch_count_non_null)
        
        all_loss['train_loss'].append(avg_train_loss)
        all_acc['train_null_acc'].append(epoch_correct_null / epoch_count_null)
        all_acc['train_non_null_acc'].append(epoch_correct_non_null / epoch_count_non_null)

        print(all_acc)

starting
epoch=0
epoch=1
{'train_null_acc': [0.4078398665554629], 'train_non_null_acc': [0.48875], 'val_null_acc': [0.3824561403508772], 'val_non_null_acc': [0.53]}
epoch=2
epoch=3
{'train_null_acc': [0.4078398665554629, 0.4132610508757298], 'train_non_null_acc': [0.48875, 0.44375], 'val_null_acc': [0.3824561403508772, 0.3508771929824561], 'val_non_null_acc': [0.53, 0.49]}
epoch=4
epoch=5
{'train_null_acc': [0.4078398665554629, 0.4132610508757298, 0.37447873227689743], 'train_non_null_acc': [0.48875, 0.44375, 0.48125], 'val_null_acc': [0.3824561403508772, 0.3508771929824561, 0.49122807017543857], 'val_non_null_acc': [0.53, 0.49, 0.46]}
epoch=6
epoch=7
{'train_null_acc': [0.4078398665554629, 0.4132610508757298, 0.37447873227689743, 0.3686405337781485], 'train_non_null_acc': [0.48875, 0.44375, 0.48125, 0.46125], 'val_null_acc': [0.3824561403508772, 0.3508771929824561, 0.49122807017543857, 0.4456140350877193], 'val_non_null_acc': [0.53, 0.49, 0.46, 0.43]}
epoch=8
epoch=9
{'train_null_acc'

In [70]:

    
# p = torch.tensor([[[ 0.0548,  1.0788,  0.4486],
#                    [-0.1559,  1.4204,  0.5418],
#                    [ 0.1171,  1.3842, -0.5025]]])
# t = torch.tensor([[2., 2., 1.]])

# loss(p,t)

In [71]:
all_acc

{'train_null_acc': [0.4078398665554629,
  0.4132610508757298,
  0.37447873227689743,
  0.3686405337781485,
  0.371976647206005,
  0.3623853211009174,
  0.3519599666388657],
 'train_non_null_acc': [0.48875,
  0.44375,
  0.48125,
  0.46125,
  0.48125,
  0.49375,
  0.49875],
 'val_null_acc': [0.3824561403508772,
  0.3508771929824561,
  0.49122807017543857,
  0.4456140350877193,
  0.3719298245614035,
  0.5614035087719298,
  0.35789473684210527],
 'val_non_null_acc': [0.53, 0.49, 0.46, 0.43, 0.42, 0.46, 0.45]}

## Evaluating

In [72]:
with torch.no_grad():
    test_correct_null = 0
    test_correct_non_null = 0
    test_count_null = 0
    test_count_non_null = 0

    for idx, batch in enumerate(iter(test_dataloader)):
        # Inputs and Targets
        instruction = batch['instruction']
        frames = batch['frames']
        targets = batch['actions']

        # Frame Padding
        padding_mask = generate_pad_mask(batch=frames)
        pad_indexes = np.argwhere(np.array(padding_mask) == False)[:,1]

        # Get predictions
        predictions = model(instruction, frames, mask, padding_mask)
        predictions = predictions[:,pad_indexes]

        test_correct_counts = correct(predictions.argmax(dim=-1), targets)

        test_correct_null += test_correct_counts[0]
        test_count_null += test_correct_counts[1]
        test_correct_non_null += test_correct_counts[2]
        test_count_non_null += test_correct_counts[3]

print("Test Null Accuracy: ", round(test_correct_null/test_count_null,4))
print("Test Non-Null Accuracy: ", round(test_correct_non_null/test_count_non_null,4))

Test Null Accuracy:  0.339
Test Non-Null Accuracy:  0.57
