In [7]:
# Imports block
import torch
from transformers import VivitImageProcessor

from common.amigos.dataset import AMIGOSDataset
from models.FEEG.base_embedding import BaseEmbedding

from torchviz import make_dot
from models.FEEG.model import EEGAVI

## Dataset Interaction

In [None]:

ds = AMIGOSDataset("../../resources/AMIGOS/sampled/AMIGOS_sampled.csv")
ds.__getitem__(1)

In [None]:
import numpy as np

a = ds.__getitem__(1)
for i in a:
    if i is None:
        print("None")
    else:
        print(i.shape if isinstance(i, np.ndarray) else len(i))

In [None]:
len(a[0])

In [None]:
np.array(a[0]).shape

In [None]:
eeg = BaseEmbedding.get_cbramod_base()

## Model
### Structure
We use ```make_dot``` to plot a structure of the actual model. <br>
This step is just to see if the shapes match and there were no mistakes on that behalf.

### ViViT
ViViT accepts only 32 frames sequence inputs. How to operate depends on approach:
- Uniform sampling → evenly pick 32 frames across the whole 4 s (good coverage)
- Random sampling → randomly pick 32 frames (common in training for augmentation).
- Sliding windows → split into multiple 32-frame clips (e.g. 120 frames → 3–4 clips of 32), process each, then average/aggregate.

> Sliding window could be what I need for face expressions altough more costy as I need to compute multiple times ViVIT downstream when feeding the video.

In [25]:
vivit = BaseEmbedding.get_ViViT_base().model
processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
video = torch.randint(low=0, high=256, size=(32, 3, 224, 224))
a = processor(list(video), return_tensors="pt")
r = vivit(**a)

Some weights of VivitModel were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
print(r.last_hidden_state.shape)
# So I could use Patch tokens for mid-fusion from this call? Yes
# As 3137 contains: 3136 patch tokens + 1 [CLS] token.
# I could opt for Flamingo?

torch.Size([1, 3137, 768])


In [28]:
from torch import nn

# TODO: HERE
# I could work with the full video shape. What does VATE give me back tho?
nn.Linear(768, 400)(r.last_hidden_state).shape

torch.Size([1, 3137, 400])

In [5]:
res.last_hidden_state

tensor([[[ 0.5893,  0.3459,  0.9137,  ..., -0.4443, -0.7825,  0.0405],
         [ 0.3717,  0.0462,  0.3935,  ..., -0.4207, -0.1892,  0.0375],
         [ 0.5313, -0.5689,  0.3701,  ..., -0.2536, -0.8913,  0.0213],
         ...,
         [ 0.6091,  0.2223,  0.1393,  ..., -0.0443, -0.8038, -0.1764],
         [ 0.2583, -0.0116,  0.7239,  ..., -0.0897, -0.8880, -0.1585],
         [ 0.6755,  0.4019,  0.0686,  ..., -0.2835, -0.0257,  0.1094]],

        [[ 0.6270,  0.3154,  0.9685,  ..., -0.4629, -0.8222,  0.0963],
         [ 0.5641, -0.9205,  0.5163,  ..., -0.5630, -0.0640, -0.0462],
         [ 0.7916, -0.4010, -0.3732,  ..., -0.0926, -0.5847,  0.0397],
         ...,
         [ 0.6397,  0.1159, -0.2169,  ...,  0.1289, -0.4131, -0.1863],
         [ 0.9740, -0.2264,  0.1516,  ..., -0.2253, -0.5794,  0.0424],
         [ 0.0617, -0.0646,  0.3379,  ..., -0.1600, -0.8507,  0.2688]]],
       grad_fn=<NativeLayerNormBackward0>)

In [None]:
model = EEGAVI()  # Model initialization

# Build input x as:
x_vid = torch.randn(2, 32, 3, 224, 224)
x_aud = torch.randn(2, 16000)
x_tex = None  # We try without text

# mock_eeg.shape = (batch_size, num_of_channels, time_segments, points_per_patch)
x_eeg = torch.randn(8, 22, 4, 200)

y = model((x_eeg, x_vid, x_aud, x_tex))

make_dot(y.mean(), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)