# Notes

Things you might have to do:
- add nn.LayerNorm to decoder

### Imports

In [25]:
import numpy as np
import pandas as pd
import os
from PIL import Image
import json

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from tqdm.notebook import tqdm
import copy





### Constants

In [26]:
# Root folder containing one sub-directory per trial.
DIR = 'datasets/trials'
BATCH_SIZE = 1
# Every trial's frame embeddings are padded up to this many rows.
MAX_FRAMES = 3
# Output width of the torchvision vit_b_16 classifier head.
VIT_OUT_DIM = 1000
# Hidden size of the sentence encoder's token embeddings.
LM_OUT_DIM = 768
ACT_TOKEN = '[ACT]'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [27]:
# Scratch check: the shape of a one-dimensional, three-element array.
np.array([1,2,3]).shape

(3,)

# Datasets

In [28]:
# Load every trial under DIR. Each trial directory holds PNG frames and JSON
# file(s) describing the trial.
#   frames[k]: np.ndarray stacking trial k's images, channels-first (C, H, W)
#   infos:     one parsed JSON object per .json file, in trial order
# Other files in a trial directory are ignored.
frames = []
infos = []

# os.listdir order is arbitrary; sort so trials and frames are read in a
# deterministic (lexicographic file-name) order.
for trial_name in sorted(os.listdir(DIR)):
    trial_dir = os.path.join(DIR, trial_name)
    imgs = []
    for name in sorted(os.listdir(trial_dir)):
        fp = os.path.join(trial_dir, name)
        if fp.endswith('.png'):
            # Close the file promptly and force 3 channels: PNGs can be RGBA,
            # grayscale or palette images, while vit_b_16 expects RGB input.
            with Image.open(fp) as im:
                pixels = np.asarray(im.convert('RGB'), dtype=np.float32)
            imgs.append(np.moveaxis(pixels, 2, 0))  # HWC -> CHW
        elif fp.endswith('.json'):
            with open(fp) as fh:
                infos.append(json.load(fh))
    frames.append(np.array(imgs))

In [29]:
# Pull the instruction text and the list of answer strings out of each trial record.
instructions = [info['instruction'] for info in infos]
raw_target_actions = [info['answers'] for info in infos]

### Encode Target Actions

In [30]:
# One-hot vector for each possible answer string.
action_map = {
    'true': np.array([1, 0, 0]),
    'false': np.array([0, 1, 0]),
    'null': np.array([0, 0, 1]),
}

# Replace every answer string of every trial with its one-hot vector
# (an unknown answer string raises KeyError).
target_actions = [
    [action_map[answer] for answer in answers]
    for answers in raw_target_actions
]

target_actions

[[array([0, 0, 1]), array([0, 1, 0])],
 [array([0, 0, 1]), array([0, 1, 0])],
 [array([0, 0, 1]), array([0, 1, 0])],
 [array([0, 0, 1]), array([0, 1, 0])],
 [array([0, 0, 1]), array([0, 1, 0])]]

## Instruction Dataset

In [31]:
class InstructionsDataset(Dataset):
    """PyTorch Dataset wrapping an in-memory sequence of instructions.

    Attributes:
        instructions: the sequence passed to the constructor, stored as-is.
        n_ins: its length, captured when the dataset is built.
    """

    def __init__(self, x):
        self.instructions = x
        self.n_ins = len(x)

    def __len__(self):
        """Number of instructions, as counted at construction time."""
        return self.n_ins

    def __getitem__(self, idx):
        """Return the instruction stored at position ``idx``."""
        return self.instructions[idx]

In [32]:
class InstructionsCollator:
    """DataLoader ``collate_fn`` that tokenizes a batch of instruction strings.

    Attributes:
        use_tokenizer: callable tokenizer (e.g. a transformers AutoTokenizer)
            applied to each batch.
    """

    def __init__(self, use_tokenizer):
        self.use_tokenizer = use_tokenizer

    def __call__(self, instructions):
        """Tokenize ``instructions`` with padding and truncation, returning PyTorch tensors."""
        tokenize = self.use_tokenizer
        return tokenize(instructions, padding=True, truncation=True, return_tensors='pt')


In [33]:
# Pretrained sentence encoder and its matching tokenizer.
lm_name = 'sentence-transformers/all-mpnet-base-v2'
lm_encoder = AutoModel.from_pretrained(lm_name)
tokenizer = AutoTokenizer.from_pretrained(lm_name)

# Collator that tokenizes each batch of raw instruction strings. It is kept
# under its own name so the InstructionsCollator class is not shadowed by an
# instance (rebinding the class name would break re-running this cell).
instructions_collator = InstructionsCollator(use_tokenizer=tokenizer)

# PyTorch dataset over the raw instruction strings.
ins_train_dataset = InstructionsDataset(instructions)

# Unshuffled loader: its order must line up with frames_train_dataloader,
# with which it is zipped when building embeddings.
ins_train_dataloader = DataLoader(ins_train_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=instructions_collator)


## Frames Dataset

In [34]:
class FramesDataset(Dataset):
    """PyTorch Dataset over a sequence of frame stacks (one stack per trial here).

    Attributes:
        frames: the sequence passed to the constructor, stored as-is.
        n_imgs: its length, captured when the dataset is built.
    """

    def __init__(self, x):
        self.frames = x
        self.n_imgs = len(x)

    def __len__(self):
        """Number of frame stacks, as counted at construction time."""
        return self.n_imgs

    def __getitem__(self, idx):
        """Return a freshly copied tensor built from the frame stack at ``idx``."""
        stack = self.frames[idx]
        return torch.tensor(stack)

In [35]:
# Pretrained ViT-B/16 used as a frozen frame encoder. eval() puts it in
# inference mode, as appropriate for a feature extractor that is only run
# under torch.no_grad() in img_embedder.
vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vit_encoder = torchvision.models.vit_b_16(weights=vit_weights).eval()
# NOTE(review): these weights were trained on inputs prepared by
# vit_weights.transforms() (resize/crop to 224x224, scale to [0, 1], ImageNet
# mean/std normalization). The frames loaded above are raw float pixel values
# and reach the encoder unchanged - confirm whether that preprocessing is needed.

# PyTorch dataset over the per-trial frame stacks.
frames_train_dataset = FramesDataset(frames)

# Unshuffled loader: its order must line up with ins_train_dataloader.
frames_train_dataloader = DataLoader(frames_train_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Language Encoder

### Language Embedder

In [36]:
def lm_embedder(instruction, encoder):
    """Embed a tokenized batch of instructions as unit-length sentence vectors.

    Runs ``encoder(**instruction)`` without gradients, takes the first element
    of its output as the token embeddings, averages them over the positions
    where ``instruction['attention_mask']`` is non-zero, then L2-normalizes
    each pooled row.

    Args:
        instruction: tokenizer output mapping; must contain 'attention_mask'.
        encoder: callable accepting the tokenizer outputs as keyword arguments.

    Returns:
        Tensor of shape (batch, hidden) whose rows have L2 norm 1.
    """
    with torch.no_grad():
        encoded = encoder(**instruction)

    token_embeddings = encoded[0]
    # Broadcast the (batch, seq) attention mask across the hidden dimension.
    weights = (
        instruction['attention_mask']
        .unsqueeze(-1)
        .expand(token_embeddings.size())
        .float()
    )
    summed = (token_embeddings * weights).sum(dim=1)
    # Clamp so a row with no attended tokens divides by a tiny value, not zero.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    pooled = summed / counts

    return F.normalize(pooled, p=2, dim=1)

# Image Encoder

### Position Embeddings

### Image Embedder

In [37]:
def img_embedder(frames, encoder, max_frames=None):
    """Encode a stack of frames and pad the result to a fixed number of rows.

    Args:
        frames: array-like or tensor of frames, batched along dim 0, in the
            layout ``encoder`` expects.
        encoder: callable mapping the frame batch to a (n_frames, dim) tensor.
        max_frames: number of rows in the result; defaults to MAX_FRAMES.

    Returns:
        Tensor of shape (max_frames, dim): the encoder outputs followed by
        padding rows filled with ones. CausalMatchTransformer.generate_pad_mask
        treats rows that are exactly all ones as padding.

    Raises:
        ValueError: if there are more frames than ``max_frames``.
    """
    if max_frames is None:
        max_frames = MAX_FRAMES

    with torch.no_grad():
        # as_tensor avoids the redundant copy (and PyTorch's copy-construct
        # warning) when the DataLoader already yields a tensor.
        vit_out = encoder(torch.as_tensor(frames))

    npads = max_frames - len(vit_out)
    if npads < 0:
        raise ValueError(f"got {len(vit_out)} frames but max_frames is {max_frames}")

    # new_ones matches vit_out's dtype and device, so the concatenation also
    # works when the encoder runs on a GPU.
    pad = vit_out.new_ones((npads, vit_out.shape[1]))
    return torch.cat((vit_out, pad))

# Create Embeddings

In [38]:
# Embed every (instruction, frame stack) pair. The code pairs the k-th
# instruction batch with the k-th frame batch, so both loaders must yield the
# same number of batches; zip() would otherwise silently drop the extras.
if len(ins_train_dataloader) != len(frames_train_dataloader):
    raise ValueError(
        f"{len(ins_train_dataloader)} instruction batches vs "
        f"{len(frames_train_dataloader)} frame batches"
    )

lm_embeddings = []
img_embeddings = []

for instruction_batch, frame_batch in zip(ins_train_dataloader, frames_train_dataloader):
    # Drop the DataLoader batch dimension: img_embedder encodes a single
    # trial's frames. Only the first trial of each batch is used, so this
    # assumes BATCH_SIZE == 1.
    trial_frames = frame_batch[0]
    lm_embeddings.append(lm_embedder(instruction_batch, lm_encoder))
    img_embeddings.append(img_embedder(trial_frames, vit_encoder))


# Embeddings Dataset

In [39]:
class EmbeddingsDataset(Dataset):
    """PyTorch Dataset over precomputed embeddings and their target actions.

    Attributes:
        lm_embeddings: list of instruction (language-model) embeddings.
        img_embeddings: list of padded frame (image-model) embeddings.
        actions: list of per-sample target actions, each a list of one-hot
            vectors (or any array-like numpy can stack).
        n_embs: number of samples, taken from ``lm_embeddings``.
    """

    def __init__(self, lm_embeddings, img_embeddings, actions):
        self.lm_embeddings = lm_embeddings
        self.img_embeddings = img_embeddings
        self.actions = actions

        self.n_embs = len(self.lm_embeddings)

    def __len__(self):
        """Number of samples in the dataset."""
        return self.n_embs

    def __getitem__(self, idx):
        """Return a dict with the sample's 'instruction', 'frames' and 'actions'.

        'actions' is a new tensor stacked from the sample's action vectors.
        """
        # Stack through numpy first: torch.tensor() on a list of ndarrays takes
        # a slow element-wise path and emits a UserWarning. np.array always
        # copies, so the returned tensor never aliases the stored actions.
        actions = torch.from_numpy(np.array(self.actions[idx]))
        return {'instruction': self.lm_embeddings[idx], 'frames': self.img_embeddings[idx], 'actions': actions}

In [40]:
# Hold out the last sample for validation and train on the rest.
train_dataset = EmbeddingsDataset(lm_embeddings[:-1], img_embeddings[:-1], target_actions[:-1])
val_dataset = EmbeddingsDataset(lm_embeddings[-1:], img_embeddings[-1:], target_actions[-1:])

# Unshuffled loaders over the embedded datasets. The validation loader is
# named val_dataloader to match the name the training/validation loop reads.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Action Decoder

In [54]:
from matplotlib import projections
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class CausalMatchTransformer(nn.Module):
    """Transformer decoder that classifies an instruction against a frame sequence.

    The projected instruction embedding is the decoder target (query) and the
    position-encoded frame embeddings are the memory it cross-attends to.
    Frame rows that are entirely ones are treated as padding, matching the
    padding value img_embedder writes.

    Args:
        nframes: maximum number of frames (size of the position-embedding table).
        blocks: number of decoder layers.
        nhead: attention heads per layer; must divide ``emb_dim``.
        emb_dim: width of the frame embeddings and of the model.
        classes: number of output action classes.
        device: device the model's parameters are moved to; kept as ``self.device``.
    """

    def __init__(self, nframes=MAX_FRAMES, blocks=3, nhead=5, emb_dim=VIT_OUT_DIM, classes=3, device=device):
        super().__init__()

        self.device = device
        self.emb_dim = emb_dim
        self.nframes = nframes

        # Learned absolute position embedding for each frame slot.
        self.pos_emb = nn.Embedding(nframes, emb_dim)
        nn.init.xavier_uniform_(self.pos_emb.weight)

        # Projects the LM_OUT_DIM-wide sentence embedding to the model width.
        self.lm_linear_layer = nn.Linear(LM_OUT_DIM, emb_dim)

        # Independent decoder layers applied in sequence in forward(). Only the
        # cloned stack is registered, so every parameter the optimizer sees is
        # actually used.
        template = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, batch_first=True)
        self.decoder_layers = _get_clones(template, blocks)

        # Action classifier applied to each decoder output position.
        self.classifier = nn.Linear(emb_dim, classes)

        # Place every submodule on the same device.
        self.to(device)

    def forward(self, instruction, frames, mask):
        """Classify ``instruction`` against ``frames``.

        Args:
            instruction: sentence embeddings with last dim LM_OUT_DIM, used as
                the decoder target - assumed (batch, target_len, LM_OUT_DIM).
            frames: (batch, n_frames, emb_dim) frame embeddings with
                n_frames <= nframes. The caller's tensor is not modified.
            mask: passed to every layer as ``memory_mask``; must be None or of
                shape (target_len, n_frames).

        Returns:
            Logits shaped like the decoder target with ``classes`` as last dim.
        """
        output = self.lm_linear_layer(instruction)

        # Detect padding before adding position embeddings: padding is
        # recognized by rows that are exactly all ones.
        padding_mask = self.generate_pad_mask(batch=frames)

        # Add the embedding of each frame slot along the sequence dimension.
        # Out-of-place, so the caller's frames tensor is left untouched.
        positions = torch.arange(frames.size(1), device=frames.device)
        frames = frames + self.pos_emb(positions)

        # NOTE(review): with a single instruction query target_len is 1, so a
        # (nframes, nframes) causal mask does not fit memory_mask - confirm
        # the intended target sequence (e.g. one query per frame).
        for layer in self.decoder_layers:
            output = layer(output, frames, memory_mask=mask, memory_key_padding_mask=padding_mask)

        return self.classifier(output)

    def generate_pad_mask(self, batch):
        """Return a (batch, n_frames) bool mask that is True where a frame row is all ones.

        True entries mark padding frames, which attention ignores when the mask
        is passed as ``memory_key_padding_mask``. The mask is created on
        ``batch``'s device.
        """
        return (batch == 1).all(dim=-1)

    
# Creates a list of torch duplicate torch modules
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def generate_causal_mask(sz: int) -> Tensor:
    """Return an additive (sz, sz) causal attention mask.

    Entries on and below the diagonal are 0.0 (may attend); entries above it
    are -inf (blocked), so position i only attends to positions <= i.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    return blocked.triu(diagonal=1)



In [55]:
# Build the decoder with the notebook's frame/embedding sizes, cast its
# floating-point parameters to float32 and move it to the selected device.
model = CausalMatchTransformer(
    nframes=MAX_FRAMES,
    blocks=3,
    nhead=5,
    emb_dim=VIT_OUT_DIM,
    device=device,
)
model = model.float().to(device)

# Training

In [56]:
# Training configuration.
epochs = 5
lr = 1e-3

# Cross-entropy over the action classes.
criterion = nn.CrossEntropyLoss()

# Adam over the parameters that require gradients.
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)

# Additive causal mask over the MAX_FRAMES frame slots, on the model's device.
mask = generate_causal_mask(MAX_FRAMES).to(device)

# NOTE(review): the seed is set after the model has been constructed, so it
# does not make the weight initialization reproducible.
torch.manual_seed(0)

<torch._C.Generator at 0x7f97a2217a30>

In [57]:
# Display the causal mask built in the configuration cell.
mask

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])

In [58]:
# Smoke test: one forward pass per training batch, printing the input shapes.
for batch in train_dataloader:

    # The embeddings were computed on the CPU; move them to the model's device.
    instruction = batch['instruction'].to(device)
    frames = batch['frames'].to(device)
    targets = batch['actions'].to(device)

    print(frames.shape)
    print(mask.shape)

    # NOTE(review): the instruction target has a single position here, so the
    # (MAX_FRAMES, MAX_FRAMES) memory mask does not match (see the RuntimeError
    # recorded for this cell).
    prediction = model(instruction, frames, mask)

    # print(prediction)
    

torch.Size([1, 3, 1000])
torch.Size([3, 3])


RuntimeError: The shape of the 2D attn_mask is torch.Size([3, 3]), but should be (1, 3).

In [None]:
# Training and validation loop


def _batch_step(batch):
    """Run the model on one batch; return (loss, number correct, number of targets).

    ``batch`` is a dict from EmbeddingsDataset with 'instruction', 'frames' and
    'actions'. The actions are one-hot vectors, so they are converted to class
    indices for CrossEntropyLoss and for accuracy.
    """
    instruction = batch['instruction'].float().to(device)
    frames = batch['frames'].float().to(device)
    labels = batch['actions'].argmax(dim=-1).to(device)

    predictions = model(instruction, frames, mask)

    # Compare prediction positions with target actions one-to-one.
    # NOTE(review): this needs one prediction per target action. The trials
    # shown have two target actions each, while a single instruction query
    # yields one prediction - confirm the intended output layout.
    logits = predictions.reshape(-1, predictions.size(-1))
    flat_labels = labels.reshape(-1)

    loss = criterion(logits, flat_labels)
    n_correct = (logits.argmax(dim=-1) == flat_labels).sum().item()
    return loss, n_correct, flat_labels.numel()


# Average loss / accuracy recorded at each validation step, for plotting.
all_loss = {'train_loss': [], 'val_loss': []}
all_acc = {'train_acc': [], 'val_acc': []}

print("starting")
for epoch in range(epochs):
    print(f"epoch={epoch}")

    # Training pass (dropout in the decoder layers active).
    model.train()
    epoch_loss = 0.0
    epoch_correct = 0
    epoch_count = 0
    for batch in train_dataloader:
        optimizer.zero_grad()

        loss, n_correct, n_targets = _batch_step(batch)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        epoch_loss += loss.item()
        epoch_correct += n_correct
        epoch_count += n_targets

    # Validate every 5 epochs.
    if (epoch + 1) % 5 == 0:
        model.eval()
        val_epoch_loss = 0.0
        val_epoch_correct = 0
        val_epoch_count = 0
        with torch.no_grad():
            for batch in val_dataloader:
                v_loss, n_correct, n_targets = _batch_step(batch)
                val_epoch_loss += v_loss.item()
                val_epoch_correct += n_correct
                val_epoch_count += n_targets

        avg_train_loss = epoch_loss / len(train_dataloader)
        avg_val_loss = val_epoch_loss / len(val_dataloader)

        all_loss['train_loss'].append(avg_train_loss)
        all_loss['val_loss'].append(avg_val_loss)

        all_acc['train_acc'].append(epoch_correct / epoch_count)
        all_acc['val_acc'].append(val_epoch_correct / val_epoch_count)

# Evaluating