# Neural Network Modeling with iWISDM
### Overview:

This notebook aims to provide an example of training a PyTorch model (transformer) on a premade task dataset
- Text-embedder: all-mpnet-base-v2 (pretrained Sentence Transformer)
- Image-embedder:  EfficientNetV2 (pretrained)
- Decoder/Classifier: Transformer decoder-only model trained on both text and image embeddings to output an action class

Datasets/Training:
- See the task_dataset_gen.ipynb notebooks for how to generate and save a dataset for training and validation

## Imports

In [35]:
# General Imports
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
import math
import os
import json
from torch.utils.data import Dataset, DataLoader

# For Instruction Encoding
from sentence_transformers import SentenceTransformer

# For Image Encoding
import torchvision
import torchvision.transforms as T
from PIL import Image

## Constants

In [36]:
# General
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64

# Data
TRAIN_PATH = "outputs/trials/train/"
VAL_PATH = "outputs/trials/validation/"
# Upper bound on the number of frames a single trial may contain
MAX_FRAMES = 3
# Mapping from answer strings found in the trial info files to integer class ids
ACTIONS = {"null": 0, "true": 1, "false": 2}

# Instruction Encoder
INS_ENCODER_DIM = 768

# Image Encoder
# Channel count of the feature map captured by the forward hook on the
# EfficientNetV2-S `avgpool` layer (1280 channels; 2048 would be a ResNet-50 value).
IMG_ENCODER_DIM = 1280

# Decoder
DECODER_HIDDEN = 64



## Image Encoder

In [37]:
# Preprocessing applied to every frame: resize, convert to a tensor, then
# normalize each channel with the mean/std values listed below.
transform = T.Compose(
    [
        T.Resize(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Pretrained EfficientNetV2-S, switched to eval mode and moved to the target device
img_encoder = torchvision.models.efficientnet_v2_s(
    weights=torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
)
img_encoder = img_encoder.eval()
img_encoder = img_encoder.to(device)

# Most recent forward-hook outputs, keyed by the name given to `get_activation`
activation = {}


def get_activation(name):
    """
        desc: Builds a forward hook that stores the hooked module's output in `activation[name]`

        args:
            - name (str): key under which the output is stored
    """
    def hook(model, input, output):
        # Detach so the cached tensor is cut off from the autograd graph
        activation[name] = output.detach()

    return hook


# Freeze every encoder parameter; the encoder is only used as a feature extractor
for paras in img_encoder.parameters():
    paras.requires_grad = False

# Capture the output of the `avgpool` layer on every forward pass
img_encoder.avgpool.register_forward_hook(get_activation('activation'))

<torch.utils.hooks.RemovableHandle at 0x7fbc7021aed0>

## Datasets

#### Read in trials

In [39]:
def _frame_order(filename):
    """
        desc: Sort key that orders files by the integer formed from the digits in their name,
              falling back to the name itself for ties or digit-free names.
              NOTE(review): assumes the frame index is the only number in a frame's filename — confirm.
    """
    digits = ''.join(ch for ch in filename if ch.isdigit())
    return (int(digits) if digits else -1, filename)


def read_trials(path):
    """
        desc: Reads every trial directory under `path`, encodes its frames with `img_encoder`
              and collects the trial info dictionaries.

        args:
            - path (str): directory whose entries containing 'trial' in their name are trial folders

        returns:
            - embeddings (Tensor): hooked encoder activations of all frames of all trials,
              concatenated along dim 0 (one row per frame, NOT per trial)
            - infos (list): one trial info dict per trial, in the same trial order as the embeddings

        raises:
            - Exception: if a trial has no trial_info file, no frames, more than MAX_FRAMES frames,
              or a frame count different from its number of answers
    """
    frame_embs = []
    infos = []

    # Sort so that the trial order is reproducible (os.listdir order is arbitrary)
    for i, trial_name in enumerate(sorted(os.listdir(path))):
        if 'trial' not in trial_name:
            continue

        trial_dir = os.path.join(path, trial_name)
        frames = []
        info = None

        # Frames must be visited in a deterministic order so that they line up with the answers
        for fname in sorted(os.listdir(trial_dir), key=_frame_order):
            fp = os.path.join(trial_dir, fname)

            if fp[-4:] == '.png':
                # Context manager closes the image file once it has been transformed
                with Image.open(fp) as img:
                    frames.append(transform(img.convert('RGB')))
            elif 'trial_info' in fp:
                with open(fp) as info_file:
                    info = json.load(info_file)

        if info is None:
            raise Exception(trial_dir + " does not contain a trial_info file")
        if len(frames) == 0:
            raise Exception(trial_dir + " does not contain any .png frames")
        if len(frames) > MAX_FRAMES:
            raise Exception(trial_dir + " contains more frames than the set maximum (MAX_FRAMES) !!!")
        elif len(frames) != len(info['answers']):
            raise Exception(trial_dir + " numbers of frames does not match number of actions")

        # Only record the info once the trial has passed validation (exactly one entry per trial)
        infos.append(info)

        # Encode images; the forward hook stores the features in `activation`
        with torch.no_grad():
            img_encoder(torch.stack(frames).to(device))
        frame_embs.append(activation['activation'])

        if i % 100 == 0:
            print(f"Processed {i} trials")

    if not frame_embs:
        return torch.tensor([]).to(device), infos

    # Single concatenation instead of re-allocating the growing tensor on every trial
    return torch.cat(frame_embs, 0), infos

train_frame_embs, train_infos = read_trials(TRAIN_PATH)
val_frame_embs, val_infos = read_trials(VAL_PATH)

Processed 0 trials
Processed 100 trials
Processed 200 trials
Processed 300 trials
Processed 400 trials
Processed 500 trials
Processed 600 trials
Processed 700 trials
Processed 800 trials
Processed 900 trials
Processed 1000 trials
Processed 1100 trials
Processed 1200 trials
Processed 1300 trials
Processed 1400 trials
Processed 1500 trials
Processed 1600 trials
Processed 1700 trials
Processed 1800 trials
Processed 1900 trials
Processed 2000 trials
Processed 2100 trials
Processed 2200 trials
Processed 2300 trials
Processed 2400 trials
Processed 2500 trials
Processed 2600 trials
Processed 2700 trials
Processed 2800 trials
Processed 2900 trials
Processed 3000 trials
Processed 3100 trials
Processed 3200 trials
Processed 3300 trials
Processed 3400 trials
Processed 3500 trials
Processed 3600 trials
Processed 3700 trials
Processed 3800 trials
Processed 3900 trials
Processed 4000 trials
Processed 4100 trials
Processed 4200 trials
Processed 4300 trials
Processed 4400 trials
Processed 4500 trials


#### Encode Actions

In [45]:
def map_actions(infos, action_map=None, target_device=None):
    """
        desc: Converts the answer strings of each trial into integer class ids

        args:
            - infos (list): trial info dicts, each with an 'answers' list of action names
            - action_map (dict): action name -> class id; defaults to the notebook's ACTIONS
            - target_device: device the resulting tensor is moved to; defaults to the notebook's device

        returns:
            - Tensor with one row of class ids per trial (all trials must have the same number of answers)

        raises:
            - KeyError: if an answer is not a key of the action map
    """
    if action_map is None:
        action_map = ACTIONS
    if target_device is None:
        target_device = device

    # Iterate over the `infos` argument (not a global) so every split gets its own labels
    actions = [[action_map[a] for a in info['answers']] for info in infos]

    return torch.tensor(actions).to(target_device)

# Build the action-label tensors that are handed to the datasets below
train_actions = map_actions(train_infos)
val_actions = map_actions(val_infos)

#### Torch Dataset

In [47]:
class iWISDM_Dataset(Dataset):
    """
        desc: Torch Dataset that serves precomputed samples by index

        args:
            - image_embs (Tensor): image embeddings, indexed along dim 0
            - instructions (list): list of instructions (its length defines the dataset size)
            - answers (list): list of answers
    """

    def __init__(self, image_embs: Tensor, instructions: list, answers: list):
        self.image_embs, self.instructions, self.answers = image_embs, instructions, answers

    def __len__(self):
        # The dataset size is defined by the number of instructions
        n_samples = len(self.instructions)
        return n_samples

    def __getitem__(self, idx):
        # Same position is read from all three containers
        sample = (self.image_embs[idx], self.instructions[idx], self.answers[idx])
        return sample

def _group_frames_by_trial(frame_embs, n_trials):
    """
        desc: Reshapes per-frame embeddings (total_frames, ...) into per-trial embeddings
              (n_trials, frames_per_trial, ...) so that index i of the result holds trial i's frames.
              Relies on every trial having the same number of frames, which the rectangular
              action tensors built by map_actions already require.

        raises:
            - ValueError: if the frames cannot be split evenly across the trials
    """
    total_frames = frame_embs.shape[0]
    if n_trials == 0 or total_frames % n_trials != 0:
        raise ValueError(f"Cannot split {total_frames} frame embeddings evenly across {n_trials} trials")
    return frame_embs.reshape(n_trials, total_frames // n_trials, *frame_embs.shape[1:])

# read_trials returns one embedding row per FRAME, while instructions and actions hold one
# entry per TRIAL; group the frames by trial so dataset[idx] returns trial idx's frames.
train_dataset = iWISDM_Dataset(
    _group_frames_by_trial(train_frame_embs, len(train_infos)),
    [info['instruction'] for info in train_infos],
    train_actions,
)
val_dataset = iWISDM_Dataset(
    _group_frames_by_trial(val_frame_embs, len(val_infos)),
    [info['instruction'] for info in val_infos],
    val_actions,
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)

## Instructions Encoder

In [48]:
# Pretrained sentence-transformer (all-mpnet-base-v2 checkpoint) for encoding the task instructions.
# NOTE(review): INS_ENCODER_DIM is assumed to equal this checkpoint's embedding size — confirm.
ins_encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

## Transformer Decoder 

#### Position Encoder + Maskers

In [51]:
def generate_causal_mask(sz: int, device: str) -> Tensor:
    """
        desc: Creates a square sequential/causal mask of size sz

        args:
            - sz (int): side length of the mask
            - device: device the mask is moved to

        returns:
            - (sz, sz) float Tensor with -inf strictly above the diagonal and 0 elsewhere
    """
    blocked = torch.full((sz, sz), float('-inf'))
    # Keep -inf only strictly above the main diagonal; everything else becomes 0
    upper = torch.triu(blocked, diagonal=1)
    return upper.to(device)

def generate_pad_mask(batch, pad, device):
    """
        desc: Generates a padding mask for each sequence in a batch

        args:
            - batch (Tensor): batch whose first two dims are (batch, sequence)
            - pad (Tensor): padding element; a sequence entry counts as padding when
              ALL of its values equal `pad` (after broadcasting)
            - device: device the pad element and the resulting mask are moved to

        returns:
            - bool Tensor of shape (batch.shape[0], batch.shape[1]), True where the entry is padding
    """
    pad_tensor = pad.to(device)

    # Element-wise comparison of the whole batch at once (replaces a per-entry Python loop
    # that synchronised with the device on every single sequence element)
    matches = batch == pad_tensor

    # Collapse all per-entry dims: an entry is padding only if every value matches
    if matches.dim() > 2:
        matches = matches.flatten(start_dim=2).all(dim=-1)

    return matches.bool().to(device)

# Sinusoidal Positional Encoding Module
class PositionalEncoding(nn.Module):
    """
        desc: Adds fixed sinusoidal position encodings to a sequence and applies dropout.
              Positional encoding module taken from PyTorch Tutorial
              Link: https://pytorch.org/tutorials/beginner/transformer_tutorial.html

        args:
            - d_model (int): size of the last (feature) dim of the input; may be odd or even
            - dropout (float): dropout probability applied after adding the encoding
            - max_len (int): number of positions the encoding table is built for
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-index column than div_term entries,
        # so trim the cosine block (a no-op when d_model is even)
        pe[:, 0, 1::2] = torch.cos(angles)[:, : d_model // 2]
        # Buffer: moves with the module across devices but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
            args:
                - x (Tensor): input whose dim 0 is the sequence position; the table is sliced
                  with x.size(0) and has shape (max_len, 1, d_model), so x is expected to be
                  (seq_len, batch, d_model)
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

#### Decoder

In [55]:
class TFDecoder(nn.Module):
    """
        desc: Torch nn.Module inherited Transformer Decoder Only class.
              Frame feature maps are projected by a 1x1x1 convolution, flattened, embedded and
              decoded causally while cross-attending to the embedded instruction.
        args:
            - hidden_dim (int): dimension of hidden layers
            - device : torch device
            - img_hidden (int): number of channels of the image encoder's feature maps
            - dim_transformer_ffl (int): dimension of the decoder's feedforward layers
            - nhead (int): number of attention heads per decoder block
            - blocks (int): number of decoder blocks to stack
            - max_frames (int): sequence length
            - output_dim (int): number of classes
            - dropout (float): dropout used by the positional encoding and the decoder layers
            - ins_dim (int): dimension of the instruction embeddings (default 768)
            - img_size (int | tuple): spatial size of the feature maps, either a side length
              or an (height, width) pair (default 7, i.e. 7x7)
    """
    def __init__(self, hidden_dim, device, img_hidden=1280, dim_transformer_ffl=2048, nhead = 16, blocks=2, max_frames=6, output_dim = 3, dropout=0.1, ins_dim=768, img_size=7):
        super().__init__()

        self.device = device

        self.max_frames = max_frames

        if isinstance(img_size, int):
            feat_h = feat_w = img_size
        else:
            feat_h, feat_w = img_size

        # Point-wise (kernel size 1) convolution mapping encoder channels to hidden_dim
        self.cnnlayer = torch.nn.Conv3d(img_hidden, hidden_dim, 1)

        # Dimensions
        self.input_dim = hidden_dim * feat_h * feat_w
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Linear layers (instructions, in, & out)
        self.ins_hidden = nn.Linear(ins_dim, hidden_dim)
        self.in2hidden = nn.Linear(self.input_dim, hidden_dim)
        self.hidden2output = nn.Linear(self.hidden_dim, self.output_dim)

        # Input layer norm
        self.layer_norm_in = nn.LayerNorm(self.hidden_dim)

        # Position Encoding class
        self.pos_emb = PositionalEncoding(hidden_dim, dropout, self.max_frames)

        # Decoder only transformer (sequence-first layout: batch_first=False)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, dim_feedforward=dim_transformer_ffl, nhead=nhead, batch_first=False, dropout=dropout)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=blocks)

    # One forward pass through the model
    def forward(self, frames, instructions=None):
        """
            args:
                - frames (Tensor): (batch, seq_len, channels, height, width) frame feature maps;
                  a frame consisting only of ones is treated as padding
                - instructions (Tensor): instruction embeddings, last dim ins_dim; required

            returns:
                - Tensor of shape (seq_len, batch, output_dim) with the class scores per frame

            raises:
                - TypeError: if instructions is None
                - ValueError: if the flattened frame features do not match the size the input
                  linear layer was built for
        """
        if instructions is None:
            raise TypeError("TFDecoder.forward requires instruction embeddings, got None")

        # Get batch size and sequence length
        self.batch_size = frames.shape[0]
        self.seq_len = frames.shape[1]

        # Generate masks
        causal_mask = generate_causal_mask(self.seq_len, self.device)
        pad = torch.ones((frames.shape[2],frames.shape[3], frames.shape[4]))
        padding_mask = generate_pad_mask(frames, pad, self.device)

        # Conv3d expects (batch, channels, depth, height, width): use the frame axis as depth.
        # With kernel size 1 this applies the same channel projection to every frame at once;
        # slicing out single frames would hand Conv3d a 4D tensor, which it reads as unbatched
        # (batch taken for channels).
        x = frames.float().permute(0, 2, 1, 3, 4)
        x_acts = self.cnnlayer(x)

        # Back to sequence-first and flatten (hidden, height, width) into one dim
        x_acts = x_acts.permute(2, 0, 1, 3, 4).reshape(self.seq_len, self.batch_size, -1)

        if x_acts.shape[-1] != self.input_dim:
            raise ValueError(
                f"Flattened frame features have size {x_acts.shape[-1]} but the input layer expects "
                f"{self.input_dim}; check img_size against the frames' spatial dimensions"
            )

        # Embed the instructions (used as decoder memory of length 1) and the frame features
        ins_embeddings = self.ins_hidden(instructions).unsqueeze(0)
        hidden_x = self.in2hidden(x_acts)

        # Add positional encoding and normalize
        hidden_x = self.layer_norm_in(self.pos_emb(hidden_x))

        # Pass through transformer
        decoder_output = self.decoder(hidden_x, ins_embeddings, tgt_mask=causal_mask, tgt_key_padding_mask=padding_mask)

        # Pass through output linear layer
        out = self.hidden2output(decoder_output)

        return out

In [57]:
# Build the decoder from the notebook-level constants and move it to the target device.
# NOTE(review): TFDecoder sizes its input projection for 7x7 feature maps (hidden_dim*7*7),
# while the image-encoder hook taps `avgpool` — confirm the spatial sizes agree before training.
model = TFDecoder(
    hidden_dim=DECODER_HIDDEN,
    device=device,
    img_hidden=IMG_ENCODER_DIM,
    max_frames=MAX_FRAMES,
    output_dim=len(ACTIONS),
).to(device)

## Train

## Validate