### Imports

In [1]:
import numpy as np
import pandas as pd
import os
from PIL import Image
import json

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from tqdm.notebook import tqdm
import copy





### Constants

In [2]:
DIR = 'datasets\\trials'
BATCH_SIZE = 1
MAX_FRAMES = 2
VIT_OUT_DIM = 1000
LM_OUT_DIM = 768
ACT_TOKEN = '[ACT]'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
np.array([1,2,3]).shape

(3,)

# Datasets

In [4]:
frames = []
infos = []

for trial_fp in os.listdir(DIR):
    trial_fp = os.path.join(DIR, trial_fp)
    imgs = []
    for fp in os.listdir(trial_fp):
        fp = os.path.join(trial_fp, fp)
        if fp[-4:] == '.png':
            imgs.append(np.rollaxis(np.array(Image.open(fp), dtype=np.float32),2,0))
        else:
            infos.append(json.load(open(fp)))
    frames.append(np.array(imgs))

In [5]:
instructions = [x['instruction'] for x in infos]
target_actions = [x['answers'] for x in infos]

## Instruction Dataset

In [6]:
class InstructionsDataset(Dataset):
  """
    Pytorch Dataset class to load the Instructions Data

    Data members:
      instructions: list of instructions
      n_ins: number of instructions in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, x):

    self.instructions = x

    self.n_ins = len(self.instructions)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_ins

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    return self.instructions[idx]

In [7]:
class InstructionsCollator(object):
  """
    Data Collator used for GPT2 in a classificaiton tasks

    Args:
      use_tokenizer :
        Transformer type tokenizer used to process raw text into numbers.

    Data members:
      use_tokenizer: Tokenizer to be used inside the class.

    Member functions:
      __init__: ctor
      __call__: tokenize input

    """

  def __init__(self, use_tokenizer):

    self.use_tokenizer = use_tokenizer

    return

  def __call__(self, instructions):
    """
        Tokenizes input
    """

    # Call tokenizer
    inputs = self.use_tokenizer(instructions, padding=True, truncation=True, return_tensors='pt')


    return inputs


In [8]:
lm_encoder = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Create data collator to encode text and labels into numbers.
InstructionsCollator = InstructionsCollator(use_tokenizer=tokenizer)

# Create pytorch dataset for instructions
ins_train_dataset = InstructionsDataset(instructions)

# Move pytorch dataset into dataloader 
ins_train_dataloader = DataLoader(ins_train_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=InstructionsCollator)


## Frames Dataset

In [9]:
class FramesDataset(Dataset):
  """
    Pytorch Dataset class to load the Frame Data

    Data members:
      frames``ist of frames
      n_imgs: number of iamges in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_imgs
      __getitem__: returns an frame
  """

  def __init__(self, x):

    self.frames = x

    self.n_imgs = len(self.frames)

    return

  def __len__(self):
    """
      Returns number of frames in the Dataset
    """

    return self.n_imgs

  def __getitem__(self, idx):
    """
      Given an index return a frame
    """

    return torch.tensor(self.frames[idx])

In [10]:
vit_encoder = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

# Create pytorch dataset for instructions
frames_train_dataset = FramesDataset(frames)

# Move pytorch dataset into dataloader 
frames_train_dataloader = DataLoader(frames_train_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Language Encoder

### Language Embedder

In [11]:
def lm_embedder(instruction, encoder):
    #Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


    # Compute token embeddings
    with torch.no_grad():
        lm_output = encoder(**instruction)
    # Perform pooling
    sentence_embeddings = mean_pooling(lm_output, instruction['attention_mask'])
    
    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        
    return sentence_embeddings

# Image Encoder

### Position Embeddings

### Image Embedder

In [12]:
def img_embedder(frames, encoder):
    with torch.no_grad():
        vit_out = encoder(torch.tensor(frames))
    return vit_out

##

# Create Embeddings

In [13]:
lm_embeddings = []
img_embeddings = []

for i,f in zip(ins_train_dataloader,frames_train_dataloader):
   f = f[0]
   lm_embeddings.append(lm_embedder(i, lm_encoder))
   img_embeddings.append(img_embedder(f, vit_encoder))

  vit_out = encoder(torch.tensor(frames))


In [14]:
lm_embeddings[0].shape

torch.Size([1, 768])

In [15]:
img_embeddings[0].shape

torch.Size([2, 1000])

# Action Decoder

In [16]:
from matplotlib import projections
import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CausalMatchTransformer(nn.Module):

    # Initialize Model with Params
    def __init__(self, nframes=MAX_FRAMES, blocks=3, nhead=5, emb_dim=VIT_OUT_DIM, device=device):
        super().__init__()

        # Device
        self.device = device

        # Number of frames
        self.nframes = nframes

        # Frame Position Embedder Layer
        self.pos_emb = nn.Embedding(nframes,emb_dim)
        self.pos_emb.weight = nn.init.xavier_uniform_(self.pos_emb.weight)

        # Instruction Dim Projection Layer
        self.lm_linear_layer = nn.Linear(LM_OUT_DIM, emb_dim).to(device)

        # Decoder Layers
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=nhead, batch_first=True).to(device)
        self.decoder_layers = _get_clones(self.decoder_layer, blocks)
        
        # Decoder
        self.decoder = nn.TransformerDecoder(self.decoder_layers, num_layers=blocks).to(device)

    # Function for forward pass
    def forward(self, instruction, frames, mask):

        output = self.lm_linear_layer(instruction)

        for i in range(len(frames)):
            frames[i] += self.pos_emb(torch.tensor([i]))[0]

        for layer in self.decoder_layers:
            output = layer(output, frames, memory_mask=mask, memory_is_causal=True) 
            # (tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_is_causal=False, memory_is_causal=False)

        return output

    # Generates a padding masks for each sequence in a batch
    def generate_pad_mask(self, batch):

        pad_tensor = torch.ones((batch.shape[2])).to(device)

        mask = np.zeros((batch.shape[0],batch.shape[1]))

        for s in range(0, batch.shape[0]):
            for v in range(0, batch[s].shape[0]):
                new_s = torch.all(batch[s][v] == pad_tensor)
                mask[s][v] = new_s

        return torch.tensor(mask).bool().to(self.device)
    
# Creates a list of torch duplicate torch modules
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def generate_causal_mask(sz: int) -> Tensor:
    
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

def generate_pad_mask(batch):

    pad_tensor = torch.ones((batch.shape[2])).to(device)
    
    mask = np.zeros((batch.shape[0],batch.shape[1]))

    for s in range(0, batch.shape[0]):
        for v in range(0, batch[s].shape[0]):
            new_s = torch.all(batch[s][v] == pad_tensor)
            mask[s][v] = new_s

    return torch.tensor(mask).bool()

In [17]:
mask = generate_causal_mask(2)

In [18]:

model = CausalMatchTransformer(nframes=MAX_FRAMES,
                               blocks=3,
                               nhead=5,
                               emb_dim=VIT_OUT_DIM,
                               device=device).float().to(device)

# Configurations
epochs = 50

criterion = nn.CrossEntropyLoss()

lr = 1e-3
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=lr
)

# mask = generate_causal_mask(nframes).to(device)

torch.manual_seed(0)

<torch._C.Generator at 0x214ed820d70>

In [19]:
model(lm_embeddings[0], img_embeddings[0], mask)

tensor([[-0.6380, -0.2778, -0.4223,  ..., -0.9591, -0.6478,  0.9478],
        [-0.8547, -0.1672, -0.2223,  ..., -1.2684, -0.4001,  0.6226]])
tensor([[-0.6701, -0.2735, -0.4988,  ..., -1.0212, -0.6825,  0.9605],
        [-0.8064, -0.0988, -0.2203,  ..., -1.2056, -0.3230,  0.6901]],
       grad_fn=<CopySlices>)


tensor([[ 0.8207,  0.1376, -0.8893, -0.0928,  1.1168, -2.0958, -0.4657, -0.5178,
          0.1308, -1.0753, -0.1561, -1.0498,  0.8365, -1.2998,  1.9882,  1.1052,
         -1.2403, -1.1389,  1.3018,  0.5328,  0.3162,  0.6033,  1.4248,  0.9096,
          2.1418,  0.6704, -0.0414, -0.5233, -0.8893, -0.4926, -0.3808,  0.3360,
          0.1819, -2.3062, -0.0175, -1.2074,  0.3069,  0.7717, -1.2177,  0.6617,
          0.3042, -1.6231, -0.5853,  0.2076, -0.2835,  0.5075, -1.1996,  0.4792,
         -1.3614,  0.3271, -0.2035, -1.1899, -0.7082,  0.7151,  0.2385, -1.1350,
         -1.7064,  0.4775,  0.5305,  0.6212,  0.4186,  0.1375, -1.2108, -0.2361,
          0.4878, -0.1315, -1.1718,  0.6085, -1.1749,  0.1685, -0.2346,  1.1782,
          0.0969,  0.3035,  1.0352, -1.0169,  0.6193,  1.3302,  0.3539,  0.8594,
         -0.3301, -0.9113, -1.0015, -0.5427,  0.6428, -0.3642,  1.0155, -0.3540,
         -1.3382, -1.3834, -1.2648, -0.5145, -1.5135, -1.1487, -1.7215,  0.7069,
          0.1486,  0.9219,  

# Evaluating