In [6]:
from vit_like_encoder import ViT_ImageBased
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from utils.util import count_model_params, train_epoch,eval_model,train_model
import os
import shutil
from torch.utils.tensorboard import SummaryWriter
from loader.Dataset import VideoDataset 
from torch.utils.data import DataLoader
from loader.transforms import RGBNormalizer,Composition,CustomResize,RandomHorizontalFlip,RandomVerticalFlip,CustomColorJitter

%load_ext autoreload
%autoreload 2

data_path='/home/nfs/inf6/data/datasets/MOVi/movi_c/'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size=64
img_height=64
img_width=64
channels=3
original_number_of_frames_per_video=24
selected_number_of_frames_per_video=4
patch_size=32

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from loader.Dataset_image_based import VideoDataset_ImageBased

# Training-time augmentation pipeline: normalise, resize to (img_height, img_width),
# vertical/horizontal flips configured with 0.6 (presumably the flip probability),
# and colour jitter. It is deliberately NOT passed to the evaluation set below.
# NOTE(review): because the resize lives only in this pipeline, the test frames keep
# their native resolution (the printed batch is 128x128, not 64x64). If the test set
# should match the configured resolution, pass a deterministic normalise+resize-only
# transform instead -- TODO confirm the intended resolution.
transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((img_height, img_width)),
    RandomVerticalFlip(0.6),
    RandomHorizontalFlip(0.6),
    CustomColorJitter(
        brightness=(0.8, 1.2),
        hue=(-0.3, 0.3),
        contrast=(0.6, 1.8),
        saturation=(0.5, 1.5),
    ),
])

test_dataset = VideoDataset_ImageBased(
    data_path,
    split='validation',
    halve_dataset=True,
    is_test_dataset=True,
    # transforms=transform_composition,  # augmentations disabled for evaluation
    original_number_of_frames_per_video=original_number_of_frames_per_video,
    selected_number_of_frames_per_video=selected_number_of_frames_per_video,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Pull one batch to sanity-check shapes before building the model.
rgbs = next(iter(test_loader))
# assumes batches are (batch, frames, channels, height, width), as the printed shape suggests
assert rgbs.dim() == 5, f"expected a 5-D video batch, got {tuple(rgbs.shape)}"
assert rgbs.shape[1] == selected_number_of_frames_per_video, \
    f"expected {selected_number_of_frames_per_video} frames per video, got {rgbs.shape[1]}"
assert rgbs.shape[2] == channels, f"expected {channels} channels, got {rgbs.shape[2]}"
print(f"Shapes: >>>>>>>>>>>>>>>>>\n{rgbs.shape=},\n<<<<<<<<<<<<<<<<<<")

Shapes: >>>>>>>>>>>>>>>>> 
rgbs.shape=torch.Size([21, 4, 3, 128, 128]), 
<<<<<<<<<<<<<<<<<<


In [9]:
# Build the image-based ViT encoder on the configured device. img_size is read from
# the loaded batch's height (rgbs.shape[-2]) so the model matches the data the
# loader actually yields rather than the img_height constant.
vit = ViT_ImageBased(
    img_size=rgbs.shape[-2],
    frame_numbers=selected_number_of_frames_per_video,
    patch_size=patch_size,
    token_dim=128,
    attn_dim=128,
    num_heads=4,
    mlp_size=512,
    num_tf_layers=4,
).to(device)
print(f"ViT has {count_model_params(vit)} parameters")
vit

ViT has 1192704 parameters


ViT_ImageBased(
  (patch_projection): Sequential(
    (0): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=3072, out_features=128, bias=True)
  )
  (pos_emb): PositionalEncoding()
  (transformer_blocks): Sequential(
    (0): EncoderBlock(
      (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (q): Linear(in_features=128, out_features=128, bias=True)
        (k): Linear(in_features=128, out_features=128, bias=True)
        (v): Linear(in_features=128, out_features=128, bias=True)
        (out_proj): Linear(in_features=128, out_features=128, bias=True)
      )
      (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
    (1): EncoderBlock(
     

In [10]:
# Single inference pass over the sample batch, then inspect the attention maps the
# model stores. Switch to eval mode for inference (so any dropout-style layers are
# inactive) and restore the previous mode afterwards, so a later training cell is
# unaffected even if the forward pass raises.
was_training = vit.training
vit.eval()
try:
    with torch.no_grad():
        y = vit(rgbs.to(device))
    # presumably one attention tensor per transformer layer -- confirm in ViT_ImageBased
    attn_maps = vit.get_attn_mask()
finally:
    vit.train(was_training)

print(f"Input Shape: {rgbs.shape}")
print(f"Encoder Output Shape: {y.shape}")
print(f"Found {len(attn_maps)} Attn Maps of shape {attn_maps[0].shape}")

Input Shape: torch.Size([21, 4, 3, 128, 128])
Encoder Output Shape: torch.Size([21, 4, 128])
Found 4 Attn Masps of shape torch.Size([84, 4, 17, 17])
