In [1]:
# Imports block
import torch
from torchviz import make_dot
from transformers import VivitImageProcessor, AutoProcessor
from models.FEEG.base_embedding import ViViTFoundationEmbedder, MiniLMFoundationEmbedder
from common.amigos.dataset import AMIGOSDataset
from models.FEEG.layers import PerceiverResampler, SimplePerceiverResampler
from models.FEEG.model import EEGAVI
from einops import rearrange

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return torch._C._cuda_getDeviceCount() > 0


## Dataset Interaction

In [None]:
# Build the dataset from the sampled AMIGOS CSV and display one sample.
ds = AMIGOSDataset("../../resources/AMIGOS/sampled/AMIGOS_sampled.csv")
# Use subscript syntax instead of calling the __getitem__ dunder directly.
ds[1]

In [None]:
import numpy as np

# Fetch the sample at index 1 (subscript syntax rather than the raw dunder call)
# and print a one-line summary for each of its fields.
a = ds[1]
for field in a:
    if field is None:
        print("None")
    else:
        # NumPy arrays report their shape; every other field reports its length.
        print(field.shape if isinstance(field, np.ndarray) else len(field))

In [None]:
# Number of entries in the first element of `a`.
len(a[0])

In [None]:
# Convert the first element of `a` to a NumPy array to see its full shape.
np.array(a[0]).shape

## Model
### Structure
We use `make_dot` to plot the structure of the actual model. <br>
This step only checks that the shapes match and that no mistakes were made in that respect.

### ViViT
ViViT only accepts input sequences of 32 frames. How to handle longer videos depends on the approach:
- Uniform sampling → evenly pick 32 frames across the whole 4 s (good coverage)
- Random sampling → randomly pick 32 frames (common in training for augmentation).
- Sliding windows → split into multiple 32-frame clips (e.g. 120 frames → 3–4 clips of 32), process each, then average/aggregate.

> Sliding windows could be what I need for facial expressions, although they are more costly, since ViViT has to be run multiple times (once per clip) when feeding the video.

In [3]:
# Instantiate the ViViT embedder and build a dummy clip:
# 32 frames of shape 3x224x224 with integer values in [0, 255].
vivit = ViViTFoundationEmbedder()
video = torch.randint(low=0, high=256, size=(32, 3, 224, 224))

processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
# Bound to `video_inputs` rather than `a`, so the dataset sample stored in `a`
# is not overwritten and the dataset cells can still be re-run afterwards.
video_inputs = processor(list(video), return_tensors="pt")

# Output of the underlying base model, called directly on the pixel values.
non_reshaped = vivit.base_model(video_inputs.pixel_values)
# Output of the embedder wrapper, called with the full processor output.
r = vivit(**video_inputs)

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


###  ViViT + Perceiver

In [5]:
# Shape check: run both resampler variants on the ViViT features.
print("Shape of video after ViViT:", r.shape, sep="")
other = SimplePerceiverResampler(768, 2)(non_reshaped.last_hidden_state)
print("Shape of processed after simple", other.shape, sep="")

resampler_video = PerceiverResampler(768, 2, max_num_frames=16, max_num_media=None)(r)
print("After OpenFlamingo: ", resampler_video.shape, sep="")

Shape of video after ViViT:torch.Size([1, 32, 1, 98, 768])
Shape of processed after simpletorch.Size([1, 1, 64, 768])
After OpenFlamingo: torch.Size([1, 32, 64, 768])


In [6]:
# Fuse the time steps with the latent dimension space
R_v = rearrange(resampler_video, "b t l d -> b (t l) d")
R_v.shape

torch.Size([1, 2048, 768])

### WavLM

In [7]:
from models.FEEG.base_embedding import WavLMFoundationEmbedder

# Dummy waveform: one row of 16000 random samples
# (presumably 1 s of audio at 16 kHz — TODO confirm the expected sampling rate).
audio = torch.randn(1, 16000)
wavlm = WavLMFoundationEmbedder()

In [8]:
# Embed the dummy audio, then pass the embeddings through a Perceiver resampler.
y = wavlm(audio)
print("Shape after WavLM: ", y.shape, sep="")
resampled = PerceiverResampler(768, 2)(y)
# NOTE(review): the label reads "Simple perceiver" but PerceiverResampler is what runs here.
print("Simple perceiver:", resampled.shape, sep="")

Shape after WavLM: torch.Size([1, 1, 49, 1, 768])
Simple perceiver:torch.Size([1, 1, 64, 768])


In [9]:
# Fuse the time and latent axes of the resampled audio: (b, t, l, d) -> (b, t * l, d).
R_a = rearrange(resampled, "b t l d -> b (t l) d")
R_a.shape

torch.Size([1, 64, 768])

In [10]:
# The gated attn is fed with the concatenation of the aux embeddings.
# Video and audio tokens are joined along dim 1; only the resulting shape is displayed.
torch.cat([R_v, R_a], dim=1).shape

torch.Size([1, 2112, 768])

### MiniLM

In [11]:
# Instantiate the MiniLM embedder with its default arguments.
minilm = MiniLMFoundationEmbedder()

In [12]:
from transformers import AutoTokenizer

# Tokenize a single test sentence and return PyTorch tensors.
inputs = "This is a text test"
minilm_processor = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
# NOTE: this rebinds `a`; from here on it holds the tokenizer output.
a = minilm_processor(inputs, padding=True, truncation=True, return_tensors='pt')

In [13]:
# Display the tokenizer output.
a

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 3793, 3231,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [15]:
# Run the MiniLM embedder on the tokenized sentence and display the output shape.
res = minilm(**a)
res.shape

torch.Size([1, 1, 7, 1, 384])

In [18]:
from torch import nn

# Resample the text embeddings, then fuse the time and latent axes.
text_resampled = PerceiverResampler(384, 2)(res)
R_t = rearrange(text_resampled, "b t l d -> b (t l) d")
print(R_t.shape)  # Feature size is 384 here, while the video/audio tokens use 768.

# For mismatch simply project to correct feature space.
# NOTE: this Linear layer is freshly initialised; it only demonstrates the shape fix.
o = nn.Linear(384, 768)(R_t)
o.shape

torch.Size([1, 64, 384])


torch.Size([1, 64, 768])

## Main Model

In [None]:
model = EEGAVI()  # Model initialization

# Every modality must describe the same samples, so all mock inputs share one
# batch size. Previously the EEG batch was 8 while video and audio used 2.
batch_size = 2

# Build input x as:
x_vid = torch.randn(batch_size, 32, 3, 224, 224)  # 32 frames of 3x224x224 per sample
x_aud = torch.randn(batch_size, 16000)  # 16000 audio samples per item
x_tex = None  # We try without text

# x_eeg.shape = (batch_size, num_of_channels, time_segments, points_per_patch)
x_eeg = torch.randn(batch_size, 22, 4, 200)

# The model takes a single tuple ordered as (eeg, video, audio, text).
y = model((x_eeg, x_vid, x_aud, x_tex))

# Reduce the output to a scalar so make_dot can trace the whole autograd graph.
make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)