In [1]:
# Imports block
import torch
from einops import rearrange
from torchviz import make_dot
from transformers import VivitImageProcessor, AutoFeatureExtractor

from common.amigos.dataset import AMIGOSDataset
from models.FEEG.base_embedding import ViViTFoundationEmbedder, MiniLMFoundationEmbedder, CBraModFoundationEmbedder
from models.FEEG.layers import PerceiverResampler
from models.FEEG.model import EEGAVI

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return torch._C._cuda_getDeviceCount() > 0


## Dataset Interaction

In [None]:
ds = AMIGOSDataset("../../resources/AMIGOS/sampled/AMIGOS_sampled.csv")
ds.__getitem__(1)

In [None]:
import numpy as np

a = ds.__getitem__(1)
for i in a:
    if i is None:
        print("None")
    else:
        print(i.shape if isinstance(i, np.ndarray) else len(i))

In [None]:
len(a[0])

In [None]:
np.array(a[0]).shape

## Model
### Structure
We use ```make_dot``` to plot a structure of the actual model. <br>
This step is just to see if the shapes match and there were no mistakes on that behalf.

### ViViT
ViViT accepts only 32 frames sequence inputs. How to operate depends on approach:
- Uniform sampling → evenly pick 32 frames across the whole 4 s (good coverage)
- Random sampling → randomly pick 32 frames (common in training for augmentation).
- Sliding windows → split into multiple 32-frame clips (e.g. 120 frames → 3–4 clips of 32), process each, then average/aggregate.

> Sliding window could be what I need for face expressions altough more costy as I need to compute multiple times ViVIT downstream when feeding the video.

In [None]:
vivit = ViViTFoundationEmbedder()
video = torch.randint(low=0, high=256, size=(32, 3, 224, 224))

processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
a = processor(list(video), return_tensors="pt")

non_reshaped = vivit.base_model(a.pixel_values)
r = vivit(**a)

###  ViViT + Perceiver

In [None]:
print("Shape of video after ViViT:" + str(r.shape))
resampler_video = PerceiverResampler(768, 2, max_num_frames=16, max_num_media=None)(r)
print("After OpenFlamingo: " + str(resampler_video.shape))

In [None]:
# Fuse the time steps with the latent dimension space
R_v = rearrange(resampler_video, "b t l d -> b (t l) d")
R_v.shape

# WavLM

In [11]:
from models.FEEG.base_embedding import W2VBertFoundationEmbedder

audio_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
audio = audio_processor(torch.randn(16000), padding=True, return_tensors='pt')
wavlm = W2VBertFoundationEmbedder()

It is strongly recommended to pass the `sampling_rate` argument to `SeamlessM4TFeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.


In [12]:
audio

{'input_features': tensor([[[-0.1592,  0.5624,  1.4083,  ...,  1.4783, -0.8547, -0.8763],
         [ 1.2123, -1.8793, -0.4854,  ...,  0.2755,  0.4648, -0.5399],
         [ 0.4843,  0.2714, -0.1503,  ...,  0.3840, -0.2497, -0.9086],
         ...,
         [-0.4710, -0.8578,  0.9055,  ...,  0.2753,  0.2492,  1.8739],
         [ 0.1926,  0.4783,  0.3498,  ..., -0.8448, -2.1798,  1.5988],
         [ 0.2215,  0.0423, -0.0263,  ..., -0.2740,  0.4701,  0.1755]]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]], dtype=torch.int32)}

In [14]:
y = wavlm(**audio)
print("Shape after WavLM: " + str(y.shape))
resampled = PerceiverResampler(1024, 2)(y)
print("Simple perceiver:" + str(resampled.shape))

Shape after WavLM: torch.Size([1, 1, 49, 1, 1024])
Simple perceiver:torch.Size([1, 1, 64, 1024])


In [15]:
R_a = rearrange(resampled, "b t l d -> b (t l) d")
R_a.shape

torch.Size([1, 64, 1024])

In [16]:
# The gated attn is fed with the concatenation of the aux embeddings.
torch.cat([R_v, R_a], dim=1).shape

NameError: name 'R_v' is not defined

## MiniLM

In [None]:
minilm = MiniLMFoundationEmbedder()

In [None]:
from transformers import AutoTokenizer

inputs = "This is a text test"
minilm_processor = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
a = minilm_processor(inputs, padding=True, truncation=True, return_tensors='pt')

In [None]:
a

In [None]:
res = minilm(**a)
res.shape

In [None]:
from torch import nn

text_resampled = PerceiverResampler(384, 2)(res)
R_t = rearrange(text_resampled, "b t l d -> b (t l) d")
print(R_t.shape)  # Shape mismatch for building the input

# For mismatch simply project to correct feature space.
o = nn.Linear(384, 768)(R_t)
o.shape

## CBraMod


In [8]:
cbramod = CBraModFoundationEmbedder()
# mock_eeg.shape = (batch_size, num_of_channels, time_segments, points_per_patch)
x_eeg = torch.randn(1, 22, 4, 200)
res = cbramod(x=x_eeg)

In [12]:
res.shape

torch.Size([1, 22, 4, 200])

In [None]:
text_resampled = PerceiverResampler(384, 2)(res)

## Main Model

In [5]:
# Model initialization
model = EEGAVI(
    resampler_depth=2,
    text_kd_size=600,
    video_kd_size=600,
    audio_kd_size=600
)

processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
# Build input x as:
video = torch.randint(low=0, high=256, size=(32, 3, 224, 224))
x_vid = processor(list(video), return_tensors="pt")

processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
x_aud = processor(torch.randn(1, 16000), return_tensors="pt")

x_tex = None  # We try without text

# mock_eeg.shape = (batch_size, num_of_channels, time_segments, points_per_patch)
x_eeg = torch.randn(1, 22, 4, 200)

y = model(({"x":x_eeg}, x_vid, x_aud, x_tex))

make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)

It is strongly recommended to pass the `sampling_rate` argument to `SeamlessM4TFeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.


KeyboardInterrupt: 