In [None]:
# Disabled shell command; presumably used to terminate a stale process that still
# occupies the GPU. The PID is specific to one past session - replace it before use.
#!kill 2128689
# Print the current GPU status (utilisation, memory, running processes) before training.
!nvidia-smi

In [None]:
import sys
import os
import shutil
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
sys.path.append(os.path.abspath(".."))
from loader.Dataset import VideoDataset
from torch.utils.data import DataLoader
from models.transformers.encoders.vit_encoder import ViT
from models.transformers.decoders.vit_decoder import ViT_Decoder
from models.transformers.CustomTransformer import CustomizableTransformer
from utils.util import count_model_params, train_epoch, eval_model, train_model, load_model 
from matplotlib import pyplot as plt
from matplotlib import patches
import torch
from torch.utils.tensorboard import SummaryWriter
from utils.loss_function import ReconstructionLoss_L1_Ssim, ReconstructionLoss_PSNR_SSIM
from loader.transforms import RGBNormalizer,Composition,CustomResize,RandomHorizontalFlip,RandomVerticalFlip,CustomColorJitter

%load_ext autoreload
%autoreload 2

#### Configs

In [None]:

# Experiment-wide settings read by the data loading, model construction and training cells.
general_configs = {
    "data_path": "/home/nfs/inf6/data/datasets/MOVi/movi_c/",
    "original_number_of_frames_per_video": 24,
    "selected_number_of_frames_per_video": 4,
    "max_objects_in_scene": 11,
    "batch_size": 64,
    # Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "img_height": 64,
    "img_width": 64,
    "channels": 3,
    "learning_rate": 3e-4,
    "num_epochs": 60,
    # 0: training by bounding boxes, 1: training by masks.
    "trainingMode": 1,
}

# Hyper-parameters forwarded to the ViT encoder constructor.
encoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}

# Hyper-parameters forwarded to the ViT decoder constructor (kept separate from the
# encoder's so the two can be tuned independently).
decoder_configs = {
    "token_dim": 128,
    "attn_dim": 128,
    "num_heads": 4,
    "mlp_size": 512,
    "num_tf_layers": 4,
}

# Parameters of the data transforms; the target size mirrors the global image size.
data_transform_config = {
    "img_height": general_configs["img_height"],
    "img_width": general_configs["img_width"],
    "vFlip_probability": 0.3,
    "hFlip_probability": 0.3,
    "color_jitter_brightness": (0.8, 1.2),
    "color_jitter_hue": (-0.3, 0.3),
    "color_jitter_contrast": (0.6, 1.8),
    "color_jitter_saturation": (0.5, 1.5),
}

#### Datasets

In [16]:
# Target (height, width) shared by the training and validation pipelines.
target_size = (data_transform_config["img_height"], data_transform_config["img_width"])

# Training pipeline: normalisation and resizing followed by random flips for augmentation.
transform_composition = Composition([
    RGBNormalizer(),
    CustomResize(target_size),
    RandomVerticalFlip(data_transform_config["vFlip_probability"]),
    RandomHorizontalFlip(data_transform_config["hFlip_probability"]),
    # Colour jitter is currently disabled; to re-enable it, append:
    # CustomColorJitter(
    #     brightness=data_transform_config["color_jitter_brightness"],
    #     hue=data_transform_config["color_jitter_hue"],
    #     contrast=data_transform_config["color_jitter_contrast"],
    #     saturation=data_transform_config["color_jitter_saturation"]
    # )
])

# Validation pipeline: same normalisation and resizing but without the random flips,
# so the validation loss is not perturbed by augmentation and stays comparable
# from epoch to epoch.
valid_transform_composition = Composition([
    RGBNormalizer(),
    CustomResize(target_size),
])

validation_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='validation',
    original_number_of_frames_per_video=general_configs["original_number_of_frames_per_video"],
    selected_number_of_frames_per_video=general_configs["selected_number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=True,
    is_test_dataset=False,
    transforms=valid_transform_composition,
)
# Validation batches keep a fixed order; the incomplete last batch is dropped.
valid_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=False,
    drop_last=True,
)

train_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='train',
    original_number_of_frames_per_video=general_configs["original_number_of_frames_per_video"],
    selected_number_of_frames_per_video=general_configs["selected_number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=False,
    is_test_dataset=False,
    transforms=transform_composition,
)
# Training batches are reshuffled every epoch; the incomplete last batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=True,
    drop_last=True,
)

#### Data shapes

In [17]:
# Pull a single training batch to inspect the shapes of the tensors the loader yields.
iterator = iter(train_loader)
bboxes, masks, rgbs = next(iterator)
shape_report = (
    f"shapes: \r\n{bboxes.shape=},"
    f"\r\n{masks.shape=},"
    f"\r\n{rgbs.shape=},"
    "\r\n============================================"
)
print(shape_report)
# Number of batches the training loader produces per epoch.
print(len(train_loader))

shapes: 
bboxes.shape=torch.Size([64, 4, 11, 4]),
masks.shape=torch.Size([64, 4, 64, 64]),
rgbs.shape=torch.Size([64, 4, 3, 64, 64]),
152


### Encoder

In [18]:
def defineVIT():
    """Build the ViT encoder from the notebook configs and move it to the configured device.

    Image geometry, channel count, object limit and frame count are read from
    ``general_configs``; the transformer hyper-parameters from ``encoder_configs``.

    Returns:
        The ``ViT`` instance after ``.to(general_configs["device"])``.
    """
    vit_kwargs = {
        "img_height": general_configs["img_height"],
        "img_width": general_configs["img_width"],
        "channels": general_configs["channels"],
        "max_objects_in_scene": general_configs["max_objects_in_scene"],
        "frame_numbers": general_configs["selected_number_of_frames_per_video"],
        "token_dim": encoder_configs["token_dim"],
        "attn_dim": encoder_configs["attn_dim"],
        "num_heads": encoder_configs["num_heads"],
        "mlp_size": encoder_configs["mlp_size"],
        "num_tf_layers": encoder_configs["num_tf_layers"],
    }
    encoder = ViT(**vit_kwargs)
    return encoder.to(general_configs["device"])

In [19]:
# Instantiate the encoder and report its parameter count; the bare name on the
# last line makes the notebook render the module summary.
vit = defineVIT()
vit_param_count = count_model_params(vit)
print(f"ViT has {vit_param_count} parameters")
vit

ViT has 2390784 parameters


ViT(
  (patch_projection): Sequential(
    (0): LayerNorm((12288,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=12288, out_features=128, bias=True)
  )
  (pos_emb): PositionalEncoding()
  (encoderBlocks): Sequential(
    (0): EncoderBlock(
      (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (q): Linear(in_features=128, out_features=128, bias=True)
        (k): Linear(in_features=128, out_features=128, bias=True)
        (v): Linear(in_features=128, out_features=128, bias=True)
        (out_proj): Linear(in_features=128, out_features=128, bias=True)
      )
      (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
      )
    )
    (1): EncoderBlock(
      (ln_att): Lay

### Decoder

In [20]:
def defineDecoder():
    """Build the ViT decoder from the notebook configs and move it to the configured device.

    Batch size, image geometry, channel count, frame count, object limit and device
    are read from ``general_configs``; the transformer hyper-parameters from
    ``decoder_configs``.

    Returns:
        The ``ViT_Decoder`` instance after ``.to(general_configs["device"])``.
    """
    target_device = general_configs["device"]
    decoder_kwargs = {
        "batch_size": general_configs["batch_size"],
        "img_height": general_configs["img_height"],
        "img_width": general_configs["img_width"],
        "channels": general_configs["channels"],
        "frame_numbers": general_configs["selected_number_of_frames_per_video"],
        "token_dim": decoder_configs["token_dim"],
        "attn_dim": decoder_configs["attn_dim"],
        "num_heads": decoder_configs["num_heads"],
        "mlp_size": decoder_configs["mlp_size"],
        "num_tf_layers": decoder_configs["num_tf_layers"],
        "max_objects_in_scene": general_configs["max_objects_in_scene"],
        # The decoder also receives the device as a constructor argument.
        "device": target_device,
    }
    return ViT_Decoder(**decoder_kwargs).to(target_device)

In [21]:
# Instantiate the decoder and report its parameter count; the bare name on the
# last line makes the notebook render the module summary.
decoder = defineDecoder()
decoder_param_count = count_model_params(decoder)
print(f"Decoder has {decoder_param_count} parameters")
decoder

Decoder has 1618176 parameters


ViT_Decoder(
  (patch_projection): Sequential(
    (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=128, out_features=12288, bias=True)
  )
  (pos_emb): PositionalEncoding()
)

### Transformer

In [22]:
# Wrap the encoder and decoder in one model and place it on the configured device.
transformer = CustomizableTransformer(encoder=vit, decoder=decoder).to(general_configs["device"])

# Sanity check: the wrapper must not add or lose parameters relative to its two parts.
transformer_param_count = count_model_params(transformer)
assert count_model_params(decoder) + count_model_params(vit) == transformer_param_count
print(f"transformer has {transformer_param_count} parameters")
transformer

transformer has 4008960 parameters


CustomizableTransformer(
  (encoder): ViT(
    (patch_projection): Sequential(
      (0): LayerNorm((12288,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=12288, out_features=128, bias=True)
    )
    (pos_emb): PositionalEncoding()
    (encoderBlocks): Sequential(
      (0): EncoderBlock(
        (ln_att): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (q): Linear(in_features=128, out_features=128, bias=True)
          (k): Linear(in_features=128, out_features=128, bias=True)
          (v): Linear(in_features=128, out_features=128, bias=True)
          (out_proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (ln_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
        (mlp): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=128, out_features=512, bias=True)
            (1): GELU(approximate='none')
            (2): Linear(in_features=512, out_features=128,

#### Training

In [23]:
# Reconstruction loss; judging by its name and arguments it combines an L1 term and an
# SSIM term with the weights given here - confirm in utils.loss_function.
criterion = ReconstructionLoss_L1_Ssim(device=general_configs["device"], lambda_l1=0.8, lambda_ssim=0.2)
optimizer = torch.optim.Adam(transformer.parameters(), lr=general_configs["learning_rate"])
# MultiplicativeLR multiplies the current learning rate by the value returned from
# lr_lambda on each scheduler step, i.e. a constant factor of 0.95 here.
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.95)

TBOARD_LOGS = os.path.join(os.getcwd(), "../tboard_logs", "ViT_30")
# Start from a clean log directory so events from an earlier run under the same name
# do not mix with this one. Only a stale directory has to be removed: SummaryWriter
# creates the directory itself, so creating it first just to delete it is unnecessary.
if os.path.isdir(TBOARD_LOGS):
    shutil.rmtree(TBOARD_LOGS)
writer = SummaryWriter(TBOARD_LOGS)

In [None]:
# Run the training loop; the four returned histories are reused by the
# checkpointing and plotting cells further down.
training_history = train_model(
    model=transformer,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=general_configs["num_epochs"],
    tboard=writer,
    trainingmode=general_configs["trainingMode"],
    # Config dicts are passed in the order (decoder, encoder, general).
    conf=(decoder_configs, encoder_configs, general_configs),
    saveImagesPerEachEpoch=True,
)
train_loss, val_loss, loss_iters, valid_acc = training_history

Started Epoch 1/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:37<00:00,  1.56it/s]


Epoch 1/60
    Train loss: 0.28932
    Valid loss: 0.34271
    Valid Accuracy: 0.0%


Started Epoch 2/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:36<00:00,  1.58it/s]


Epoch 2/60
    Train loss: 0.2554
    Valid loss: 0.26146
    Valid Accuracy: 0.0%


Started Epoch 3/60...
  --> Running valdiation epoch
  --> Running train epoch


100%|██████████| 152/152 [01:36<00:00,  1.57it/s]


Epoch 3/60
    Train loss: 0.2497
    Valid loss: 0.24987
    Valid Accuracy: 1.5625%


Started Epoch 4/60...
  --> Running valdiation epoch
  --> Running train epoch


 12%|█▎        | 19/152 [00:11<01:21,  1.63it/s]

In [None]:
from utils.util import save_model

# Bundle the training histories so they are handed to save_model together with
# the model and optimizer.
stats = {
    "train_loss": train_loss,
    "valid_loss": val_loss,
    "loss_iters": loss_iters,
    "valid_acc": valid_acc,
}
# The checkpoint is labelled with the configured number of epochs.
save_model(
    transformer,
    optimizer,
    epoch=general_configs["num_epochs"],
    stats=stats,
)

In [None]:
"""
stats = {
    "train_loss": [],
    "valid_loss": [],
    "loss_iters": [],
    "valid_acc": []
}
model, optimizer, epoch, stats = load_model(transformer, optimizer, savepath="../../checkpoints/checkpoint_epoch_60_SSIM_L1.pth")"""

In [None]:
from utils.util import visualize_progress

# Re-read the histories from `stats` so this cell works the same whether `stats`
# came from the training run or from a loaded checkpoint.
loss_iters, val_loss, train_loss, valid_acc = (
    stats[key] for key in ("loss_iters", "valid_loss", "train_loss", "valid_acc")
)

visualize_progress(loss_iters, train_loss, val_loss, valid_acc, start=0)
plt.show()

In [None]:
# Evaluation pipeline: normalisation and resizing only. The random flips used for
# training are left out so that test samples are not randomly altered between runs.
test_transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((data_transform_config["img_height"], data_transform_config["img_width"])),
])

# NOTE(review): the test set is built from the 'validation' split with
# is_test_dataset=True - confirm this is intentional (e.g. no separate test split).
test_dataset = VideoDataset(
    data_path=general_configs["data_path"],
    split='validation',
    original_number_of_frames_per_video=general_configs["original_number_of_frames_per_video"],
    selected_number_of_frames_per_video=general_configs["selected_number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    halve_dataset=True,
    is_test_dataset=True,
    transforms=test_transform_composition,
)
# Fixed order; the incomplete last batch is dropped.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=general_configs["batch_size"],
    shuffle=False,
    drop_last=True,
)

In [None]:
import random

# `device` was never defined in this notebook; take it from the shared config so the
# cell also runs on a fresh kernel.
device = general_configs["device"]

# Fetch one test batch and reconstruct it in eval mode without tracking gradients.
iterator = iter(test_loader)
coms, bboxes, masks, rgbs, flows = next(iterator)
transformer.eval()
with torch.no_grad():
    recons = transformer(rgbs.to(device), masks=masks.to(device)).to('cpu')

# Show at most 10 rows, but never more than the number of frames available in
# either tensor (dimension 1 is the per-video frame index used below).
number_of_images = min(10, recons.shape[1], rgbs.shape[1])
number_of_columns = 2
# randrange excludes the upper bound, so the index is always a valid batch position
# (randint includes it and could index one past the end).
video_index = random.randrange(recons.shape[0])
# squeeze=False keeps `ax` two-dimensional even when only a single row is drawn.
fig, ax = plt.subplots(number_of_images, number_of_columns, squeeze=False)
fig.set_size_inches(number_of_columns*5, number_of_images*3)

# Left column: original frames; right column: reconstructions. Each panel is drawn once.
for i in range(number_of_images):
    for column, clip, label in ((0, rgbs, "Original"), (1, recons, "Reconstructed")):
        # Scale to 0..255, move channels last and convert to uint8 for imshow.
        frame = (clip[video_index, i]*255).clamp(0, 255).permute(1, 2, 0).byte().numpy()
        ax[i, column].imshow(frame)
        ax[i, column].axis("off")
        if i == 0:
            ax[i, column].set_title(f"{label} Image {i+1}")

plt.tight_layout()
plt.show()