In [4]:
from vit_encoder import ViT
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from utils.util import count_model_params, train_epoch,eval_model,train_model
import os
import shutil
from torch.utils.tensorboard import SummaryWriter
from loader.Dataset import VideoDataset 
from torch.utils.data import DataLoader
from loader.transforms import RGBNormalizer,Composition,CustomResize,RandomHorizontalFlip,RandomVerticalFlip,CustomColorJitter

%load_ext autoreload
%autoreload 2

# Root of the MOVi-C dataset. It can be overridden with the MOVI_DATA_PATH environment
# variable so the notebook is not tied to one machine's NFS layout.
data_path = os.environ.get('MOVI_DATA_PATH', '/home/nfs/inf6/data/datasets/MOVi/movi_c/')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so model initialisation and torch-based random
# augmentation are reproducible under Restart & Run All.
# NOTE(review): if the custom loader transforms draw from `random`/numpy, seed those too.
SEED = 0
torch.manual_seed(SEED)

max_objects_in_scene = 11                  # object slots per frame in the bbox tensor
batch_size = 64
img_height = 64
img_width = 64
channels = 3                               # RGB
original_number_of_frames_per_video = 24   # frames per video in the source data
selected_number_of_frames_per_video = 4    # frames per clip used by the loader/model


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Training-time augmentation pipeline (Random* flips + colour jitter). It is kept
# under its original name so a training cell can reuse it.
transform_composition = Composition([
                                        RGBNormalizer(),
                                        CustomResize((img_height,img_width)),
                                        RandomVerticalFlip(0.6),
                                        RandomHorizontalFlip(0.6),
                                        CustomColorJitter(
                                            brightness=(0.8, 1.2),
                                            hue=(-0.3, 0.3),
                                            contrast=(0.6, 1.8),
                                            saturation=(0.5, 1.5)
                                        )
                                    ])

# Evaluation pipeline: deterministic preprocessing only. This validation split
# serves as the test set, so random flips/jitter would feed different inputs to
# the model on every pass and make results non-reproducible.
eval_transform_composition = Composition([
                                        RGBNormalizer(),
                                        CustomResize((img_height,img_width)),
                                    ])

test_dataset=VideoDataset(data_path,
                            split='validation',
                            max_objects_in_scene=max_objects_in_scene,
                            halve_dataset=True,
                            is_test_dataset=True,
                            transforms=eval_transform_composition,
                            original_number_of_frames_per_video=original_number_of_frames_per_video,
                            selected_number_of_frames_per_video=selected_number_of_frames_per_video)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                            batch_size=batch_size, 
                                            shuffle=False) 

bboxs,masks,rgbs=next(iter(test_loader))
print(f"Shapes: >>>>>>>>>>>>>>>>> \r\n{bboxs.shape=}, \r\n{masks.shape=}, \r\n{rgbs.shape=}, \r\n<<<<<<<<<<<<<<<<<<")

# Check per-sample shapes against the config. The batch dim is skipped because
# the split may hold fewer than batch_size clips.
assert rgbs.shape[1:] == (selected_number_of_frames_per_video, channels, img_height, img_width)
assert masks.shape[1:] == (selected_number_of_frames_per_video, img_height, img_width)
assert bboxs.shape[1:] == (selected_number_of_frames_per_video, max_objects_in_scene, 4)

Shapes: >>>>>>>>>>>>>>>>> 
bboxs.shape=torch.Size([64, 4, 11, 4]), 
masks.shape=torch.Size([64, 4, 64, 64]), 
rgbs.shape=torch.Size([64, 4, 3, 64, 64]), 
<<<<<<<<<<<<<<<<<<


In [7]:
# Video ViT encoder over clips of `selected_number_of_frames_per_video` frames.
# max_objects_in_scene comes from the config cell (not a literal), so the model
# and the dataset's bbox tensor can never disagree on the number of object slots.
vit = ViT(
        img_height=img_height,
        img_width=img_width,
        channels=channels,
        frame_numbers=selected_number_of_frames_per_video,
        token_dim=128,
        attn_dim=128,
        num_heads=4,
        mlp_size=512,
        num_tf_layers=4,
        max_objects_in_scene=max_objects_in_scene).to(device)
print(f"ViT has {count_model_params(vit)} parameters")
vit

ViT has 2390784 parameters


ViT(
  (patch_projection): Sequential(
    (0): LayerNorm((12288,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=12288, out_features=128, bias=True)
  )
  (pos_emb): PositionalEncoding()
  (encoderBlocks): Sequential(
    (0): EncoderBlock(
      (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (q): Linear(in_features=128, out_features=128, bias=True)
        (k): Linear(in_features=128, out_features=128, bias=True)
        (v): Linear(in_features=128, out_features=128, bias=True)
        (out_proj): Linear(in_features=128, out_features=128, bias=True)
      )
      (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
    (1): EncoderBlock(
      (ln_att): Lay

In [8]:
!nvidia-smi
#!kill -9 2418075

Mon Sep 29 17:44:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.158.01             Driver Version: 570.158.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0 Off |                  N/A |
| 34%   40C    P2            101W /  350W |   20186MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [9]:
# Run one validation batch through the encoder without autograd and fetch the
# attention maps the model exposes via get_attn_mask().
# Switch to eval mode for this inference pass, then restore the previous mode
# (even on error) so later training cells are unaffected.
was_training = vit.training
vit.eval()
try:
    with torch.no_grad():
        # NOTE(review): rgbs/masks come straight from the DataLoader; this relies on
        # ViT (or the dataset) placing them on `device` - confirm.
        y = vit(rgbs,masks=masks)
    attn_maps = vit.get_attn_mask()
finally:
    vit.train(was_training)
print(f"Input Shape: {rgbs.shape}")
print(f"Output Shape: {y.shape}")
print(f"Found {len(attn_maps)} Attn Maps of shape {attn_maps[0].shape}")

Input Shape: torch.Size([64, 4, 3, 64, 64])
Output Shape: torch.Size([64, 4, 128])
Found 4 Attn Masps of shape torch.Size([256, 4, 12, 12])
