In [1]:
# Report the GPU model, driver / CUDA versions and current memory usage.
!nvidia-smi

Thu Sep 25 20:58:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.158.01             Driver Version: 570.158.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0 Off |                  N/A |
| 34%   39C    P8             27W /  350W |     321MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import sys
import os
import shutil
sys.path.append(os.path.abspath(".."))
from loader.Dataset import VideoDataset
from torch.utils.data import DataLoader
from models.transformers.encoders.vit_encoder import ViT
from models.transformers.decoders.vit_decoder import ViT_Decoder
from models.transformers.CustomTransformer import CustomizableTransformer
from utils.util import count_model_params, train_epoch,eval_model,train_model
from matplotlib import pyplot as plt
from matplotlib import patches
import torch
from utils.util import l1_and_ssim_loss_function
from torch.utils.tensorboard import SummaryWriter

%load_ext autoreload
%autoreload 2

#### Configs

In [3]:
# Run-wide settings shared by the data pipeline, the model builders and the training loop.
general_configs = {
    # Dataset root. Set the MOVI_DATA_PATH environment variable to point the notebook at a
    # different location; the cluster path below remains the default.
    "data_path": os.environ.get("MOVI_DATA_PATH", "/home/nfs/inf6/data/datasets/MOVi/movi_c/"),
    "number_of_frames_per_video": 24,
    "max_objects_in_scene": 11,
    "batch_size": 64,
    # Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "img_height": 128,
    "img_width": 128,
    "channels": 3,
    "learning_rate": 3e-4,
    "num_epochs": 60,
    "trainingMode": 0,  # 0 training by bounding boxes, 1 training by masks.
}

# Transformer hyper-parameters for the encoder (ViT).
encoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}

# Transformer hyper-parameters for the decoder (ViT_Decoder); kept as a separate dict so the
# two halves of the model can be tuned independently.
decoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}


#### Datasets

In [4]:
def _make_split(split, shuffle):
    """Build the VideoDataset for `split` and a DataLoader over it; returns (dataset, loader)."""
    split_dataset = VideoDataset(
        data_path=general_configs["data_path"],
        split=split,
        number_of_frames_per_video=general_configs["number_of_frames_per_video"],
        max_objects_in_scene=general_configs["max_objects_in_scene"],
    )
    split_loader = DataLoader(
        dataset=split_dataset,
        batch_size=general_configs["batch_size"],
        shuffle=shuffle,
    )
    return split_dataset, split_loader


# Validation is read in a fixed order; training batches are reshuffled.
validation_dataset, valid_loader = _make_split("validation", shuffle=False)
train_dataset, train_loader = _make_split("train", shuffle=True)

Data Loaded Successfully: len(self.coord_addresses)=250, len(self.mask_addresses)=250, len(self.rgb_addresses)=6000, len(self.flow_addresses)=6000
Data Loaded Successfully: len(self.coord_addresses)=9737, len(self.mask_addresses)=9737, len(self.rgb_addresses)=233688, len(self.flow_addresses)=233688


#### Data shapes

In [5]:
# Pull a single training batch and report the shape of each of its five tensors.
iterator = iter(train_loader)
coms, bboxes, masks, rgbs, flows = next(iterator)
print(
    "shapes: \r\n"
    f"{coms.shape=},\r\n"
    f"{bboxes.shape=},\r\n"
    f"{masks.shape=},\r\n"
    f"{rgbs.shape=},\r\n"
    f"{flows.shape=}\r\n"
    "============================================"
)


shapes: 
coms.shape=torch.Size([64, 24, 11, 2]),
bboxes.shape=torch.Size([64, 24, 11, 4]),
masks.shape=torch.Size([64, 24, 128, 128]),
rgbs.shape=torch.Size([64, 24, 3, 128, 128]),
flows.shape=torch.Size([64, 24, 3, 128, 128])


### Encoder

In [6]:
# Gather the encoder's constructor arguments: input geometry comes from the general
# settings, the transformer hyper-parameters from encoder_configs.
vit_kwargs = dict(
    img_height=general_configs["img_height"],
    img_width=general_configs["img_width"],
    channels=general_configs["channels"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    frame_numbers=general_configs["number_of_frames_per_video"],
)
for hyper_param in ("token_dim", "attn_dim", "num_heads", "mlp_size", "num_tf_layers"):
    vit_kwargs[hyper_param] = encoder_configs[hyper_param]

# Build the encoder and move it onto the configured device.
vit = ViT(**vit_kwargs).to(general_configs["device"])
print(f"ViT has {count_model_params(vit)} parameters")
vit

ViT has 7181056 parameters


ViT(
  (patch_projection): Sequential(
    (0): LayerNorm((49152,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=49152, out_features=128, bias=True)
  )
  (pos_emb): PositionalEncoding()
  (encoderBlocks): Sequential(
    (0): EncoderBlock(
      (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (q): Linear(in_features=128, out_features=128, bias=False)
        (k): Linear(in_features=128, out_features=128, bias=False)
        (v): Linear(in_features=128, out_features=128, bias=False)
        (out_proj): Linear(in_features=128, out_features=128, bias=False)
      )
      (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
    (1): EncoderBlock(
      (ln_att):

### Decoder

In [7]:
# Build the decoder half of the model and move it onto the configured device.
# NOTE(review): the decoder is constructed with a fixed batch_size, while the DataLoaders
# above do not set drop_last -- confirm that a smaller final batch is handled correctly.
decoder = ViT_Decoder(
    batch_size=general_configs["batch_size"],
    img_height=general_configs["img_height"],
    img_width=general_configs["img_width"],
    channels=general_configs["channels"],
    frame_numbers=general_configs["number_of_frames_per_video"],
    token_dim=decoder_configs["token_dim"],
    attn_dim=decoder_configs["attn_dim"],
    num_heads=decoder_configs["num_heads"],
    mlp_size=decoder_configs["mlp_size"],
    num_tf_layers=decoder_configs["num_tf_layers"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    device=general_configs["device"],
).to(general_configs["device"])
# Label the count with the decoder's own name; it previously reused the encoder's "ViT" label.
print(f"ViT_Decoder has {count_model_params(decoder)} parameters")
decoder

ViT has 8508755 parameters


ViT_Decoder(
  (patch_projection): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=128, out_features=49152, bias=True)
  )
  (pos_emb): PositionalEncoding()
  (output_projector): Sequential(
    (0): Conv2d(33, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

### Transformer

In [8]:
# Wrap the encoder and decoder in one module and place it on the configured device.
transformer = CustomizableTransformer(encoder=vit, decoder=decoder)
transformer = transformer.to(general_configs["device"])

# Sanity check: the combined model's parameter count equals the sum of its two halves.
parts_param_count = count_model_params(decoder) + count_model_params(vit)
assert parts_param_count == count_model_params(transformer)
transformer

CustomizableTransformer(
  (encoder): ViT(
    (patch_projection): Sequential(
      (0): LayerNorm((49152,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=49152, out_features=128, bias=True)
    )
    (pos_emb): PositionalEncoding()
    (encoderBlocks): Sequential(
      (0): EncoderBlock(
        (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (q): Linear(in_features=128, out_features=128, bias=False)
          (k): Linear(in_features=128, out_features=128, bias=False)
          (v): Linear(in_features=128, out_features=128, bias=False)
          (out_proj): Linear(in_features=128, out_features=128, bias=False)
        )
        (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=128, out_features=512, bias=True)
            (1): GELU(approximate='none')
            (2): Linear(in_features=512, out_features=

#### Training

In [9]:
# Loss, optimiser and learning-rate schedule for the combined encoder-decoder model.
criterion = l1_and_ssim_loss_function
optimizer = torch.optim.Adam(transformer.parameters(), lr=general_configs["learning_rate"])
# MultiplicativeLR scales the current learning rate by the factor returned from lr_lambda
# (a constant 0.95 here) on each scheduler step.
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

TBOARD_LOGS = os.path.join(os.getcwd(), "../tboard_logs", "ViT_30")
# Discard logs from any earlier run of this experiment. The directory is only removed when it
# is already there (it used to be created just to be deleted again); SummaryWriter creates it.
if os.path.isdir(TBOARD_LOGS):
    shutil.rmtree(TBOARD_LOGS)
writer = SummaryWriter(TBOARD_LOGS)

In [10]:
# Collect the training-run arguments by keyword, then start the run.
training_setup = dict(
    model=transformer,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=general_configs["num_epochs"],
    tboard=writer,
    trainingmode=general_configs["trainingMode"],
)
train_loss, val_loss, loss_iters, valid_acc = train_model(**training_setup)

Started Epoch 1/60...
  --> Running valdiation epoch


AttributeError: module 'torchvision.ops' has no attribute 'ssim'