In [1]:
#!kill 2223683
!nvidia-smi

Sun Sep 28 10:41:02 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.158.01             Driver Version: 570.158.01     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0 Off |                  N/A |
| 34%   36C    P8             27W /  350W |   14858MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [1]:
import os
import random
import shutil
import sys
# Opt into PyTorch's expandable-segments CUDA allocator (intended to reduce memory
# fragmentation). Set before torch is imported so the allocator sees it.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})
# Make the parent directory (project root: loader/, models/, utils/) importable.
sys.path.append(os.path.abspath(os.pardir))
from loader.Dataset import VideoDataset
from torch.utils.data import DataLoader
from models.transformers.encoders.vit_encoder import ViT
from models.transformers.decoders.vit_decoder import ViT_Decoder
from models.transformers.CustomTransformer import CustomizableTransformer
from utils.util import count_model_params, train_epoch,eval_model,train_model
from matplotlib import pyplot as plt
from matplotlib import patches
import torch
from utils.util import l1_and_ssim_loss_function
from torch.utils.tensorboard import SummaryWriter
from utils.loss_function import ReconstructionLoss_L1_Ssim_Lpips

# Re-import edited modules (e.g. the project's loader/, models/, utils/ packages)
# automatically before each cell runs, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

#### Configuration

In [2]:
from loader.transforms import RGBNormalizer,Composition,CustomResize,RandomHorizontalFlip,RandomVerticalFlip,CustomColorJitter


# Global experiment settings shared by the data pipeline, models and training loop.
general_configs = {
    # Dataset root; override with the MOVI_C_DATA_PATH environment variable on other machines.
    "data_path": os.environ.get("MOVI_C_DATA_PATH", "/home/nfs/inf6/data/datasets/MOVi/movi_c/"),
    "number_of_frames_per_video": 24,
    "max_objects_in_scene": 11,
    "batch_size": 64,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "img_height": 64,
    "img_width": 64,
    "channels": 3,
    "learning_rate": 3e-4,
    "num_epochs": 60,
    "trainingMode": 1,  # 0: train with bounding boxes, 1: train with masks
    "seed": 42,  # RNG seed for shuffling, augmentation and weight initialisation
}

# Seed before datasets/models are built so shuffling, augmentation and init are reproducible.
# torch.manual_seed covers the CPU and all CUDA devices; cuDNN kernels may still be
# non-deterministic unless torch.backends.cudnn.deterministic is also set.
random.seed(general_configs["seed"])
torch.manual_seed(general_configs["seed"])

# Encoder transformer hyper-parameters (passed as keyword arguments to ViT).
encoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}
# Decoder transformer hyper-parameters (passed as keyword arguments to ViT_Decoder).
decoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}

# Augmentation / preprocessing parameters for the transform pipeline.
data_transform_config = {
    "img_height": general_configs["img_height"],
    "img_width": general_configs["img_width"],
    "vFlip_probability": 0.6,
    "hFlip_probability": 0.6,
    "color_jitter_brightness": (0.8, 1.2),
    "color_jitter_hue": (-0.3, 0.3),
    "color_jitter_contrast": (0.6, 1.8),
    "color_jitter_saturation": (0.5, 1.5),
}

#### Datasets

In [3]:
# Training pipeline: normalise + resize, then random flips and colour jitter for augmentation.
train_transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((data_transform_config["img_height"], data_transform_config["img_width"])),
    RandomVerticalFlip(data_transform_config["vFlip_probability"]),
    RandomHorizontalFlip(data_transform_config["hFlip_probability"]),
    CustomColorJitter(
        brightness=data_transform_config["color_jitter_brightness"],
        hue=data_transform_config["color_jitter_hue"],
        contrast=data_transform_config["color_jitter_contrast"],
        saturation=data_transform_config["color_jitter_saturation"],
    ),
])
# Evaluation pipeline: the same deterministic preprocessing without random augmentation,
# so validation loss reflects the model (not augmentation noise) and is comparable across epochs.
eval_transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((data_transform_config["img_height"], data_transform_config["img_width"])),
])
# Backward-compatible alias: the original name referred to the training pipeline.
transform_composition = train_transform_composition

validation_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='validation',
    number_of_frames_per_video=general_configs["number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=True,
    is_test_dataset=False,
    transforms=eval_transform_composition,
)
valid_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=False,
    drop_last=True,
)
train_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='train',
    number_of_frames_per_video=general_configs["number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=False,
    is_test_dataset=False,
    transforms=train_transform_composition,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=True,
    drop_last=True,
)

#### Data shapes

In [6]:
# Sanity check: draw a single training batch and report each tensor's shape,
# followed by the number of batches per training epoch.
coms, bboxes, masks, rgbs, flows = next(iterator := iter(train_loader))
print(f"shapes: \r\n{coms.shape=},\r\n{bboxes.shape=},\r\n{masks.shape=},\r\n{rgbs.shape=},\r\n{flows.shape=}\r\n============================================")
num_train_batches = len(train_loader)
print(num_train_batches)

shapes: 
coms.shape=torch.Size([64, 24, 11, 2]),
bboxes.shape=torch.Size([64, 24, 11, 4]),
masks.shape=torch.Size([64, 24, 64, 64]),
rgbs.shape=torch.Size([64, 24, 3, 64, 64]),
flows.shape=torch.Size([64, 24, 3, 64, 64])
152


### Encoder

In [None]:
# Encoder: ViT whose input geometry comes from general_configs and whose
# transformer hyper-parameters come from encoder_configs.
encoder_hparam_keys = ("token_dim", "attn_dim", "num_heads", "mlp_size", "num_tf_layers")
vit = ViT(
    img_height=general_configs["img_height"],
    img_width=general_configs["img_width"],
    channels=general_configs["channels"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    frame_numbers=general_configs["number_of_frames_per_video"],
    **{key: encoder_configs[key] for key in encoder_hparam_keys},
)
vit = vit.to(general_configs["device"])
print(f"ViT has {count_model_params(vit)} parameters")
vit

### Decoder

In [None]:
# Decoder: geometry, batch size and device come from general_configs; transformer
# hyper-parameters come from decoder_configs.
decoder_hparam_keys = ("token_dim", "attn_dim", "num_heads", "mlp_size", "num_tf_layers")
decoder = ViT_Decoder(
    batch_size=general_configs["batch_size"],
    img_height=general_configs["img_height"],
    img_width=general_configs["img_width"],
    channels=general_configs["channels"],
    frame_numbers=general_configs["number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    device=general_configs["device"],
    **{key: decoder_configs[key] for key in decoder_hparam_keys},
)
decoder = decoder.to(general_configs["device"])
print(f"Decoder has {count_model_params(decoder)} parameters")
decoder

### Transformer

In [None]:
# Wrap encoder + decoder into a single model and move it to the configured device.
transformer = CustomizableTransformer(encoder=vit, decoder=decoder).to(general_configs["device"])
transformer_param_count = count_model_params(transformer)
# The wrapper should own exactly the encoder's and decoder's parameters, nothing more.
assert count_model_params(vit) + count_model_params(decoder) == transformer_param_count
print(f"transformer has {transformer_param_count} parameters")
transformer

#### Training

In [9]:
# Reconstruction loss; presumably a weighted combination of L1, SSIM and LPIPS terms
# (per the lambda_* arguments) — weights sum to 1.0.
criterion = ReconstructionLoss_L1_Ssim_Lpips(
    device=general_configs["device"],
    lambda_l1=0.5,
    lambda_ssim=0.3,
    lambda_lpips=0.2,
)
optimizer = torch.optim.Adam(transformer.parameters(), lr=general_configs["learning_rate"])
# Multiply the learning rate by 0.95 on every scheduler.step() (exponential decay).
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

RUN_NAME = "ViT_30"  # TensorBoard run / log sub-directory name
TBOARD_LOGS = os.path.join(os.getcwd(), "../tboard_logs", RUN_NAME)
# Start from an empty log directory so TensorBoard does not mix this run's curves with a
# previous run of the same name. NOTE: this permanently deletes earlier logs for RUN_NAME.
if os.path.isdir(TBOARD_LOGS):
    shutil.rmtree(TBOARD_LOGS)
os.makedirs(TBOARD_LOGS, exist_ok=True)
writer = SummaryWriter(TBOARD_LOGS)

In [None]:
# Run the full training loop; the SummaryWriter is passed in as `tboard` for logging.
try:
    train_loss, val_loss, loss_iters, valid_acc = train_model(
        model=transformer,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        train_loader=train_loader,
        valid_loader=valid_loader,
        num_epochs=general_configs["num_epochs"],
        tboard=writer,
        trainingmode=general_configs["trainingMode"],
    )
finally:
    # SummaryWriter buffers events; close it so the final epochs reach disk even if
    # training is interrupted. It recreates its file writer lazily if logged to again.
    writer.close()

In [None]:
from utils.util import save_model

# Store the training curves alongside the model/optimizer checkpoint.
stats = dict(
    train_loss=train_loss,
    valid_loss=val_loss,
    loss_iters=loss_iters,
    valid_acc=valid_acc,
)
save_model(transformer, optimizer, epoch=general_configs["num_epochs"], stats=stats)

In [None]:
from utils.util import visualize_progress

# Unpack the curves from `stats` (the dict handed to save_model) and plot them.
loss_iters, val_loss, train_loss, valid_acc = (
    stats[key] for key in ("loss_iters", "valid_loss", "train_loss", "valid_acc")
)

visualize_progress(loss_iters, train_loss, val_loss, valid_acc, start=0)
plt.show()