### Imports

In [113]:
import numpy as np
import pandas as pd
import os
from PIL import Image
import json

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torchvision
from tqdm.notebook import tqdm






### Constants

In [114]:
dir = 'datasets\\trials'
batch_size = 2

# Datasets

In [115]:
frames = []
infos = []

for trial_fp in os.listdir(dir):
    trial_fp = os.path.join(dir, trial_fp)
    imgs = []
    for fp in os.listdir(trial_fp):
        fp = os.path.join(trial_fp, fp)
        if fp[-4:] == '.png':
            imgs.append(np.array(Image.open(fp)))
        else:
            infos.append(json.load(open(fp)))
    frames.append(imgs)

In [116]:
instructions = [x['instruction'] for x in infos]
target_actions = [x['answers'] for x in infos]

## Instruction Dataset

In [117]:
class InstructionsDataset(Dataset):
  """
    Pytorch Dataset class to load the Instructions Data

    Data members:
      instructions: list of instructions
      n_ins: number of instructions in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_ins
      __getitem__: returns an instruction
  """

  def __init__(self, x):

    self.instructions = x

    self.n_ins = len(self.instructions)

    return

  def __len__(self):
    """
      Returns number of instructions in the Dataset
    """

    return self.n_ins

  def __getitem__(self, idx):
    """
      Given an index return a instruction at that index
    """

    return self.instructions[idx]

In [118]:
class InstructionsCollator(object):
  """
    Data Collator used for GPT2 in a classificaiton tasks

    Args:
      use_tokenizer :
        Transformer type tokenizer used to process raw text into numbers.

    Data members:
      use_tokenizer: Tokenizer to be used inside the class.

    Member functions:
      __init__: ctor
      __call__: tokenize input

    """

  def __init__(self, use_tokenizer):

    self.use_tokenizer = use_tokenizer

    return

  def __call__(self, instructions):
    """
        Tokenizes input
    """

    # Call tokenizer
    inputs = self.use_tokenizer(instructions, padding=True, truncation=True, return_tensors='pt')


    return inputs


In [119]:
lm_encoder = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Create data collator to encode text and labels into numbers.
InstructionsCollator = InstructionsCollator(use_tokenizer=tokenizer)

# Create pytorch dataset for instructions
ins_train_dataset = InstructionsDataset(instructions)

# Move pytorch dataset into dataloader 
ins_train_dataloader = DataLoader(ins_train_dataset, batch_size=batch_size, shuffle=False, collate_fn=InstructionsCollator)


## Frames Dataset

In [120]:
class FramesDataset(Dataset):
  """
    Pytorch Dataset class to load the Frame Data

    Data members:
      frames``ist of frames
      n_imgs: number of iamges in the dataset

    Member functions:
      __init__: ctor
      __len__: returns n_imgs
      __getitem__: returns an frame
  """

  def __init__(self, x):

    self.frames = x

    self.n_imgs = len(self.frames)

    return

  def __len__(self):
    """
      Returns number of frames in the Dataset
    """

    return self.n_imgs

  def __getitem__(self, idx):
    """
      Given an index return a frame
    """

    return self.instructions[idx]

In [121]:
vit_encoder = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)

# Create pytorch dataset for instructions
frames_train_dataset = FramesDataset(frames)

# Move pytorch dataset into dataloader 
frames_train_dataloader = DataLoader(frames_train_dataset, batch_size=batch_size, shuffle=False)


# Language Encoder

In [122]:
def lm_embedder(batch, batch_size, encoder):
    #Mean Pooling - Take attention mask into account for correct averaging
    def mean_pooling(model_output, attention_mask):
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


    # Compute token embeddings
    with torch.no_grad():
        lm_output = encoder(**batch)
    # Perform pooling
    sentence_embeddings = mean_pooling(lm_output, batch['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        
    return sentence_embeddings

# Image Encoder

In [134]:
vit_encoder([frames[0][0]])

AssertionError: Wrong image width! Expected 224 but got 3!

# Action Decoder

In [123]:
lm_embeddings = []

for b,v in zip(ins_train_dataloader,ins_train_dataloader):
   y = lm_embedder(b, 1, lm_encoder)
   lm_embeddings.append(lm_embedder(b, 1, lm_encoder))
   print(b,v)

{'input_ids': tensor([[    0, 11953,  4878,  1019,  1014, 11953,  4878,  1020,  1014,  8844,
          2282,  2001,  4878,  1019,  5024,  8844,  2282,  2001,  4878,  1020,
          1033,     2],
        [    0, 11953,  4878,  1019,  1014, 11953,  4878,  1020,  1014,  8844,
          2282,  2001,  4878,  1019,  5024,  8844,  2282,  2001,  4878,  1020,
          1033,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])} {'input_ids': tensor([[    0, 11953,  4878,  1019,  1014, 11953,  4878,  1020,  1014,  8844,
          2282,  2001,  4878,  1019,  5024,  8844,  2282,  2001,  4878,  1020,
          1033,     2],
        [    0, 11953,  4878,  1019,  1014, 11953,  4878,  1020,  1014,  8844,
          2282,  2001,  4878,  1019,  5024,  8844,  2282,  2001,  4878,  1020,
          1033,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,