In [1]:
# Leftover manual step: killing a process by PID (the PID is specific to one
# machine/session, so this stays commented out).
#!kill 2128689
# Print the current GPU state (driver, memory in use, utilization) before training.
!nvidia-smi

Tue Sep 30 16:29:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
|  0%   46C    P8             30W /  450W |    2461MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import sys
import os
import shutil
import os
# CUDA caching-allocator option; placed in the environment here, ahead of
# `import torch` further down in this cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Make the parent directory importable so the project packages used below
# (`loader`, `models`, `utils`) resolve when the notebook runs from its own folder.
sys.path.append(os.path.abspath(".."))
from loader.Dataset import VideoDataset
from torch.utils.data import DataLoader
from models.transformers.encoders.vit_encoder import ViT
from models.transformers.decoders.vit_decoder import ViT_Decoder
from models.transformers.CustomTransformer import CustomizableTransformer
from utils.util import count_model_params, train_epoch, eval_model, train_model, load_model 
from matplotlib import pyplot as plt
from matplotlib import patches
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.loss_function import L1_SSIM_LPIPS_Loss5D_MemoryEfficient#ReconstructionLoss_L1_Ssim, ReconstructionLoss_PSNR_SSIM
from loader.transforms import RGBNormalizer,Composition,CustomResize,RandomHorizontalFlip,RandomVerticalFlip,CustomColorJitter

# Reload edited project modules automatically on every cell execution.
%load_ext autoreload
%autoreload 2

#### Configs

In [3]:

# Experiment-wide settings read by the data pipeline, the model factories
# and the training call further down in the notebook.
general_configs = {
    "data_path": "/home/nfs/inf6/data/datasets/MOVi/movi_c/",
    "original_number_of_frames_per_video": 24,
    "selected_number_of_frames_per_video": 4,
    "max_objects_in_scene": 11,
    "batch_size": 64,
    # Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "img_height": 64,
    "img_width": 64,
    "channels": 3,
    "learning_rate": 3e-4,
    "num_epochs": 60,
    # 0: training by bounding boxes, 1: training by masks.
    "trainingMode": 1,
}

# Transformer hyper-parameters forwarded to the ViT encoder factory.
encoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}

# Transformer hyper-parameters forwarded to the ViT decoder factory
# (currently the same values as the encoder).
decoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}

# Parameters for the data transforms; the target size mirrors the model
# input size defined in `general_configs`.
data_transform_config = {
    "img_height": general_configs["img_height"],
    "img_width": general_configs["img_width"],
    "vFlip_probability": 0.3,
    "hFlip_probability": 0.3,
    "color_jitter_brightness": (0.8, 1.2),
    "color_jitter_hue": (-0.3, 0.3),
    "color_jitter_contrast": (0.6, 1.8),
    "color_jitter_saturation": (0.5, 1.5),
}

#### Datasets

In [4]:
# Training pipeline: normalization and resize followed by random flips, which
# act as data augmentation.
transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((data_transform_config["img_height"], data_transform_config["img_width"])),
    RandomVerticalFlip(data_transform_config["vFlip_probability"]),
    RandomHorizontalFlip(data_transform_config["hFlip_probability"]),
    # Colour jitter is currently disabled; re-enable with:
    # CustomColorJitter(
    #     brightness=data_transform_config["color_jitter_brightness"],
    #     hue=data_transform_config["color_jitter_hue"],
    #     contrast=data_transform_config["color_jitter_contrast"],
    #     saturation=data_transform_config["color_jitter_saturation"]
    # )
])
# transform_composition = None

# Evaluation pipeline: same normalization and resize, but WITHOUT the random
# flips. Applying random augmentation to the validation split makes the
# reported validation loss vary between runs/epochs for reasons unrelated to
# the model, so validation uses this deterministic pipeline instead.
eval_transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((data_transform_config["img_height"], data_transform_config["img_width"])),
])

# Validation split. NOTE(review): `halve_dataset=True` presumably uses only
# part of the split -- confirm against VideoDataset.
validation_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='validation',
    original_number_of_frames_per_video=general_configs["original_number_of_frames_per_video"],
    selected_number_of_frames_per_video=general_configs["selected_number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=True,
    is_test_dataset=False,
    transforms=eval_transform_composition,
)
# Fixed order for validation; incomplete final batches are dropped.
valid_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=False,
    drop_last=True,
)

# Training split with the augmenting pipeline.
train_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='train',
    original_number_of_frames_per_video=general_configs["original_number_of_frames_per_video"],
    selected_number_of_frames_per_video=general_configs["selected_number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=False,
    is_test_dataset=False,
    transforms=transform_composition,
)
# Reshuffled every epoch; incomplete final batches are dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=True,
    drop_last=True,
)

#### Data shapes

In [5]:
# Pull a single batch from the training loader to sanity-check tensor shapes.
iterator = iter(train_loader)
bboxes, masks, rgbs = next(iterator)
print(f"shapes: \r\n{bboxes.shape=},\r\n{masks.shape=},\r\n{rgbs.shape=},\r\n============================================")
# Number of batches per training epoch.
print(len(train_loader))

shapes: 
bboxes.shape=torch.Size([64, 4, 11, 4]),
masks.shape=torch.Size([64, 4, 64, 64]),
rgbs.shape=torch.Size([64, 4, 3, 64, 64]),
152


### Encoder

In [6]:
def defineVIT():
    """Build the ViT encoder from the notebook configuration.

    Image/scene settings are read from ``general_configs`` and the
    transformer hyper-parameters from ``encoder_configs``.

    Returns:
        The ``ViT`` instance after ``.to(general_configs["device"])``.
    """
    # Geometry of the input clips and scene-level limits.
    vit_kwargs = {
        "img_height": general_configs["img_height"],
        "img_width": general_configs["img_width"],
        "channels": general_configs["channels"],
        "max_objects_in_scene": general_configs["max_objects_in_scene"],
        "frame_numbers": general_configs["selected_number_of_frames_per_video"],
    }
    # Transformer hyper-parameters are forwarded under identical names.
    for hyper_param in ("token_dim", "attn_dim", "num_heads", "mlp_size", "num_tf_layers"):
        vit_kwargs[hyper_param] = encoder_configs[hyper_param]

    encoder = ViT(**vit_kwargs)
    return encoder.to(general_configs["device"])

In [7]:
# Instantiate the encoder and report its parameter count.
vit = defineVIT()
print(f"ViT has {count_model_params(vit)} parameters")
# Last expression: the notebook displays the module structure.
vit

ViT has 1122752 parameters


ViT(
  (patch_projection): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(4, 4), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 96, kernel_size=(4, 4), stride=(4, 4), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(96, 128, kernel_size=(4, 4), stride=(4, 4), padding=(2, 2))
    (7): Sigmoid()
  )
  (pos_emb): PositionalEncoding()
  (encoderBlocks): Sequential(
    (0): EncoderBlock(
      (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (q): Linear(in_features=128, out_features=128, bias=True)
        (k): Linear(in_features=128, out_features=128, bias=True)
        (v): Linear(in_features=128, out_features=128, bias=True)
        (out_proj): Linear(in_features=128, out_features=128, bias=True)
      )
      (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (mlp): Sequential(
    

### Decoder

In [8]:
def defineDecoder():
    """Build the ViT decoder from the notebook configuration.

    Batch/image/scene settings are read from ``general_configs`` and the
    transformer hyper-parameters from ``decoder_configs``. The target device
    is both passed to the constructor and used to move the module.

    Returns:
        The ``ViT_Decoder`` instance after ``.to(general_configs["device"])``.
    """
    target_device = general_configs["device"]

    # Settings taken from the general configuration.
    decoder_kwargs = {
        "batch_size": general_configs["batch_size"],
        "img_height": general_configs["img_height"],
        "img_width": general_configs["img_width"],
        "channels": general_configs["channels"],
        "frame_numbers": general_configs["selected_number_of_frames_per_video"],
        "max_objects_in_scene": general_configs["max_objects_in_scene"],
        "device": target_device,
    }
    # Transformer hyper-parameters are forwarded under identical names.
    for hyper_param in ("token_dim", "attn_dim", "num_heads", "mlp_size", "num_tf_layers"):
        decoder_kwargs[hyper_param] = decoder_configs[hyper_param]

    decoder_module = ViT_Decoder(**decoder_kwargs)
    return decoder_module.to(target_device)

In [9]:
# Instantiate the decoder and report its parameter count.
decoder = defineDecoder()
print(f"Decoder has {count_model_params(decoder)} parameters")
# Last expression: the notebook displays the module structure.
decoder

Decoder has 1248163 parameters


ViT_Decoder(
  (patch_projection): Sequential(
    (0): ConvTranspose2d(128, 256, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Sigmoid()
  )
  (pos_emb): PositionalEncoding()
)

### Transformer

In [10]:
# Combine the encoder and decoder into one model on the configured device.
transformer = CustomizableTransformer(encoder=vit, decoder=decoder).to(general_configs["device"])
# Sanity check: the wrapper must contribute no parameters of its own.
assert count_model_params(transformer) == count_model_params(vit) + count_model_params(decoder)
print(f"transformer has {count_model_params(transformer)} parameters")
# Last expression: the notebook displays the module structure.
transformer

transformer has 2370915 parameters


CustomizableTransformer(
  (encoder): ViT(
    (patch_projection): Sequential(
      (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(4, 4), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(64, 96, kernel_size=(4, 4), stride=(4, 4), padding=(1, 1))
      (5): ReLU()
      (6): Conv2d(96, 128, kernel_size=(4, 4), stride=(4, 4), padding=(2, 2))
      (7): Sigmoid()
    )
    (pos_emb): PositionalEncoding()
    (encoderBlocks): Sequential(
      (0): EncoderBlock(
        (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (k): Linear(in_features=128, out_features=128, bias=True)
          (v): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (ln_mlp): LayerNorm((128,), eps=1

#### Training

In [11]:
# Reconstruction loss combining L1, SSIM and LPIPS terms with the weights below.
criterion = L1_SSIM_LPIPS_Loss5D_MemoryEfficient(l1_lambda=1.0, ssim_lambda=0.5, lpips_lambda=0.1)
optimizer = torch.optim.Adam(transformer.parameters(), lr=general_configs["learning_rate"])
# Multiply the learning rate by 0.95 on every scheduler step.
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

# TensorBoard log directory for this run.
TBOARD_LOGS = os.path.join(os.getcwd(), "../tboard_logs", "ViT_30")
# Start every run from an empty log directory: remove logs left by a previous
# run (if any) and then create the directory. The previous version created
# the directory only to delete it unconditionally right afterwards.
# WARNING: this discards earlier TensorBoard logs stored under the same name.
if os.path.isdir(TBOARD_LOGS):
    shutil.rmtree(TBOARD_LOGS)
os.makedirs(TBOARD_LOGS, exist_ok=True)

writer = SummaryWriter(TBOARD_LOGS)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/user/masroora1/VideoTracker/.venv/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth


In [None]:
# Launch training for the configured number of epochs. The four returned
# values are kept for checkpointing and plotting in the cells below.
train_loss, val_loss, loss_iters, valid_acc = train_model(
    model=transformer,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=general_configs["num_epochs"],
    trainingmode=general_configs["trainingMode"],
    tboard=writer,
    saveImagesPerEachEpoch=True,
)

Started Epoch 1/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:30<00:00,  1.67it/s]


Epoch 1/60
    Train loss: 0.52324
    Valid loss: 0.84216
    Valid Accuracy: 0.0%


Started Epoch 2/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:30<00:00,  1.68it/s]


Epoch 2/60
    Train loss: 0.48874
    Valid loss: 0.48428
    Valid Accuracy: 0.0%


Started Epoch 3/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:30<00:00,  1.68it/s]


Epoch 3/60
    Train loss: 0.45979
    Valid loss: 0.47815
    Valid Accuracy: 0.0%


Started Epoch 4/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:31<00:00,  1.67it/s]


Epoch 4/60
    Train loss: 0.42794
    Valid loss: 0.44617
    Valid Accuracy: 1.5625%


Started Epoch 5/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:05<00:00,  1.21it/s]


Epoch 5/60
    Train loss: 0.41787
    Valid loss: 0.4042
    Valid Accuracy: 0.0%


Started Epoch 6/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:13<00:00,  1.14it/s]


Epoch 6/60
    Train loss: 0.41033
    Valid loss: 0.39952
    Valid Accuracy: 0.0%


Started Epoch 7/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:00<00:00,  1.26it/s]


Epoch 7/60
    Train loss: 0.40987
    Valid loss: 0.41544
    Valid Accuracy: 0.0%


Started Epoch 8/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:13<00:00,  1.14it/s]


Epoch 8/60
    Train loss: 0.40655
    Valid loss: 0.39775
    Valid Accuracy: 0.0%


Started Epoch 9/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:11<00:00,  1.15it/s]


Epoch 9/60
    Train loss: 0.40547
    Valid loss: 0.39137
    Valid Accuracy: 0.0%


Started Epoch 10/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:10<00:00,  1.16it/s]


Epoch 10/60
    Train loss: 0.40436
    Valid loss: 0.39474
    Valid Accuracy: 0.0%


Started Epoch 11/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:23<00:00,  1.06it/s]


Epoch 11/60
    Train loss: 0.40528
    Valid loss: 0.3905
    Valid Accuracy: 0.0%


Started Epoch 12/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:02<00:00,  1.25it/s]


Epoch 12/60
    Train loss: 0.4026
    Valid loss: 0.39017
    Valid Accuracy: 0.0%


Started Epoch 13/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:21<00:00,  1.07it/s]


Epoch 13/60
    Train loss: 0.40358
    Valid loss: 0.38917
    Valid Accuracy: 1.5625%


Started Epoch 14/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:27<00:00,  1.03it/s]


Epoch 14/60
    Train loss: 0.40271
    Valid loss: 0.39253
    Valid Accuracy: 0.0%


Started Epoch 15/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:59<00:00,  1.27it/s]


Epoch 15/60
    Train loss: 0.40275
    Valid loss: 0.39318
    Valid Accuracy: 1.5625%


Started Epoch 16/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:24<00:00,  1.05it/s]


Epoch 16/60
    Train loss: 0.40236
    Valid loss: 0.38833
    Valid Accuracy: 0.0%


Started Epoch 17/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:12<00:00,  1.14it/s]


Epoch 17/60
    Train loss: 0.4018
    Valid loss: 0.38781
    Valid Accuracy: 1.5625%


Started Epoch 18/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:15<00:00,  1.12it/s]


Epoch 18/60
    Train loss: 0.40152
    Valid loss: 0.38998
    Valid Accuracy: 0.0%


Started Epoch 19/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:12<00:00,  1.14it/s]


Epoch 19/60
    Train loss: 0.40113
    Valid loss: 0.38945
    Valid Accuracy: 0.0%


Started Epoch 20/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:12<00:00,  1.15it/s]


Epoch 20/60
    Train loss: 0.40086
    Valid loss: 0.38764
    Valid Accuracy: 1.5625%


Started Epoch 21/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:15<00:00,  1.12it/s]


Epoch 21/60
    Train loss: 0.40106
    Valid loss: 0.38824
    Valid Accuracy: 1.5625%


Started Epoch 22/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:26<00:00,  1.04it/s]


Epoch 22/60
    Train loss: 0.40102
    Valid loss: 0.39222
    Valid Accuracy: 1.5625%


Started Epoch 23/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:00<00:00,  1.26it/s]


Epoch 23/60
    Train loss: 0.40099
    Valid loss: 0.38832
    Valid Accuracy: 0.0%


Started Epoch 24/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:29<00:00,  1.02it/s]


Epoch 24/60
    Train loss: 0.40008
    Valid loss: 0.38853
    Valid Accuracy: 0.0%


Started Epoch 25/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:16<00:00,  1.11it/s]


Epoch 25/60
    Train loss: 0.40039
    Valid loss: 0.38672
    Valid Accuracy: 0.0%


Started Epoch 26/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:14<00:00,  1.13it/s]


Epoch 26/60
    Train loss: 0.40003
    Valid loss: 0.39052
    Valid Accuracy: 0.0%


Started Epoch 27/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:13<00:00,  1.14it/s]


Epoch 27/60
    Train loss: 0.39991
    Valid loss: 0.38804
    Valid Accuracy: 1.5625%


Started Epoch 28/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:27<00:00,  1.03it/s]


Epoch 28/60
    Train loss: 0.39968
    Valid loss: 0.38576
    Valid Accuracy: 0.0%


Started Epoch 29/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:58<00:00,  1.28it/s]


Epoch 29/60
    Train loss: 0.39943
    Valid loss: 0.38715
    Valid Accuracy: 0.0%


Started Epoch 30/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:24<00:00,  1.05it/s]


Epoch 30/60
    Train loss: 0.39987
    Valid loss: 0.38581
    Valid Accuracy: 0.0%


Started Epoch 31/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:16<00:00,  1.11it/s]


Epoch 31/60
    Train loss: 0.40014
    Valid loss: 0.38857
    Valid Accuracy: 0.0%


Started Epoch 32/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:18<00:00,  1.09it/s]


Epoch 32/60
    Train loss: 0.39949
    Valid loss: 0.38887
    Valid Accuracy: 0.0%


Started Epoch 33/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:24<00:00,  1.05it/s]


Epoch 33/60
    Train loss: 0.39881
    Valid loss: 0.38522
    Valid Accuracy: 1.5625%


Started Epoch 34/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:00<00:00,  1.26it/s]


Epoch 34/60
    Train loss: 0.39897
    Valid loss: 0.38928
    Valid Accuracy: 0.0%


Started Epoch 35/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:25<00:00,  1.05it/s]


Epoch 35/60
    Train loss: 0.39882
    Valid loss: 0.38633
    Valid Accuracy: 1.5625%


Started Epoch 36/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:15<00:00,  1.12it/s]


Epoch 36/60
    Train loss: 0.39897
    Valid loss: 0.38559
    Valid Accuracy: 0.0%


Started Epoch 37/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:14<00:00,  1.13it/s]


Epoch 37/60
    Train loss: 0.39854
    Valid loss: 0.38981
    Valid Accuracy: 0.0%


Started Epoch 38/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:14<00:00,  1.13it/s]


Epoch 38/60
    Train loss: 0.39903
    Valid loss: 0.38773
    Valid Accuracy: 0.0%


Started Epoch 39/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:26<00:00,  1.04it/s]


Epoch 39/60
    Train loss: 0.39814
    Valid loss: 0.38505
    Valid Accuracy: 0.0%


Started Epoch 40/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:58<00:00,  1.28it/s]


Epoch 40/60
    Train loss: 0.39826
    Valid loss: 0.38611
    Valid Accuracy: 4.6875%


Started Epoch 41/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:26<00:00,  1.04it/s]


Epoch 41/60
    Train loss: 0.39793
    Valid loss: 0.3858
    Valid Accuracy: 1.5625%


Started Epoch 42/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:07<00:00,  1.19it/s]


Epoch 42/60
    Train loss: 0.39794
    Valid loss: 0.38674
    Valid Accuracy: 1.5625%


Started Epoch 43/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:27<00:00,  1.03it/s]


Epoch 43/60
    Train loss: 0.39788
    Valid loss: 0.38564
    Valid Accuracy: 0.0%


Started Epoch 44/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:02<00:00,  1.24it/s]


Epoch 44/60
    Train loss: 0.39749
    Valid loss: 0.38654
    Valid Accuracy: 1.5625%


Started Epoch 45/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:15<00:00,  1.12it/s]


Epoch 45/60
    Train loss: 0.39759
    Valid loss: 0.38531
    Valid Accuracy: 1.5625%


Started Epoch 46/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:17<00:00,  1.11it/s]


Epoch 46/60
    Train loss: 0.39758
    Valid loss: 0.3863
    Valid Accuracy: 1.5625%


Started Epoch 47/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:17<00:00,  1.11it/s]


Epoch 47/60
    Train loss: 0.39734
    Valid loss: 0.38578
    Valid Accuracy: 0.0%


Started Epoch 48/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:14<00:00,  1.13it/s]


Epoch 48/60
    Train loss: 0.3973
    Valid loss: 0.38626
    Valid Accuracy: 0.0%


Started Epoch 49/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [02:27<00:00,  1.03it/s]


Epoch 49/60
    Train loss: 0.39707
    Valid loss: 0.38649
    Valid Accuracy: 1.5625%


Started Epoch 50/60...
  --> Running valdiation epoch
  --> Running train epoch


 46%|████▌     | 70/152 [00:57<01:04,  1.27it/s]

In [None]:
from utils.util import save_model

# Bundle the training curves so they are stored together with the weights.
stats = dict(
    zip(
        ("train_loss", "valid_loss", "loss_iters", "valid_acc"),
        (train_loss, val_loss, loss_iters, valid_acc),
    )
)
# The checkpoint is labelled with the configured number of epochs.
save_model(transformer, optimizer, epoch=general_configs["num_epochs"], stats=stats)

In [None]:
# Optional: restore a previously saved checkpoint instead of training.
# Kept as real comments rather than a bare triple-quoted string, which the
# notebook would evaluate and echo as the cell's output. Uncomment to use.
#
# stats = {
#     "train_loss": [],
#     "valid_loss": [],
#     "loss_iters": [],
#     "valid_acc": []
# }
# model, optimizer, epoch, stats = load_model(transformer, optimizer, savepath="../../checkpoints/checkpoint_epoch_60_SSIM_L1.pth")

In [None]:
from utils.util import visualize_progress

# Read the curves back from `stats` so this cell works both right after
# training and after `stats` has been restored from a checkpoint.
loss_iters, train_loss, val_loss, valid_acc = (
    stats[curve_name]
    for curve_name in ("loss_iters", "train_loss", "valid_loss", "valid_acc")
)

visualize_progress(loss_iters, train_loss, val_loss, valid_acc, start=0)
plt.show()

In [None]:
# Deterministic evaluation pipeline: normalization and resize only. The
# training pipeline (`transform_composition`) contains random flips, which
# should not be applied to the data used for evaluation/visualization.
test_transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((general_configs["img_height"], general_configs["img_width"])),
])

# NOTE(review): the test set is built from the 'validation' split with
# `is_test_dataset=True` -- confirm this is the intended split.
test_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='validation',
    original_number_of_frames_per_video=general_configs["original_number_of_frames_per_video"],
    selected_number_of_frames_per_video=general_configs["selected_number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=True,
    is_test_dataset=True,
    transforms=test_transform_composition,
)
# Fixed order; incomplete final batches are dropped.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=False,
    drop_last=True,
)

In [None]:
import random


# Take the first batch of the test loader; in test mode the dataset yields
# five items per sample, of which only the RGB frames and masks are used here.
iterator = iter(test_loader)
coms, bboxes, masks, rgbs, flows = next(iterator)

# `device` was never defined in this notebook; read it from the config so the
# cell also works on a freshly restarted kernel.
device = general_configs["device"]

transformer.eval()
with torch.no_grad():
    recons = transformer(rgbs.to(device), masks=masks.to(device)).to('cpu')

number_of_columns = 2
# Never ask for more rows than there are frames in a clip (second axis of the
# batch); indexing past the last frame would raise an IndexError.
number_of_images = min(10, recons.shape[1], rgbs.shape[1])
# randrange excludes the upper bound; randint(0, N) could return N, which is
# one past the last valid clip index.
video_index = random.randrange(recons.shape[0])

# squeeze=False keeps `ax` two-dimensional even when only one row is plotted.
fig, ax = plt.subplots(number_of_images, number_of_columns, squeeze=False)
fig.set_size_inches(number_of_columns * 5, number_of_images * 3)

# Column 0 shows the input frame, column 1 the reconstruction. A single pass
# over the rows is enough (the previous nested loop redrew every panel
# `number_of_images` times).
for row in range(number_of_images):
    for col, (frames, label) in enumerate(((rgbs, "Original"), (recons, "Reconstructed"))):
        # Scale to 0..255, clamp, and move channels last for imshow.
        image = (frames[video_index, row] * 255).clamp(0, 255).permute(1, 2, 0).byte().numpy()
        ax[row, col].imshow(image)
        ax[row, col].axis("off")
        if row == 0:
            ax[row, col].set_title(f"{label} Image {row+1}")


plt.tight_layout()
plt.show()