In [None]:
import sys
import os
import shutil
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
sys.path.append(os.path.abspath(".."))
from loader.Dataset import VideoDataset
from torch.utils.data import DataLoader
from models.transformers.encoders.vit_encoder import ViT
from models.transformers.decoders.vit_decoder import ViT_Decoder
from models.transformers.CustomTransformer import CustomizableTransformer
from utils.util import count_model_params, train_epoch,eval_model,train_model
from matplotlib import pyplot as plt
from matplotlib import patches
import torch
from utils.util import l1_and_ssim_loss_function
from torch.utils.tensorboard import SummaryWriter
from utils.loss_function import ReconstructionLoss_L1_Ssim

%load_ext autoreload
%autoreload 2
from loader.transforms import RGBNormalizer,Composition,CustomResize,RandomHorizontalFlip,RandomVerticalFlip,CustomColorJitter


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# NOTE(review): "data_path" is an absolute path on a specific machine; consider
# making it configurable before sharing this notebook.
general_configs = dict(
    data_path="/home/nfs/inf6/data/datasets/MOVi/movi_c/",
    number_of_frames_per_video=24,
    max_objects_in_scene=11,
    batch_size=64,
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    img_height=64,
    img_width=64,
    channels=3,
    learning_rate=3e-4,
    num_epochs=60,
    # 0 training by bounding boxes, 1 training by masks.
    trainingMode=1,
)

# Hyper-parameters of the transformer encoder.
encoder_configs = dict(
    token_dim=128,
    attn_dim=128,
    num_heads=4,
    mlp_size=512,
    num_tf_layers=4,
)

# Hyper-parameters of the transformer decoder (currently the same values as
# the encoder, but kept as an independent dict so they can be tuned separately).
decoder_configs = dict(
    token_dim=128,
    attn_dim=128,
    num_heads=4,
    mlp_size=512,
    num_tf_layers=4,
)

# Parameters handed to the transforms built from `loader.transforms`.
# The target image size is taken from `general_configs`.
data_transform_config = dict(
    img_height=general_configs["img_height"],
    img_width=general_configs["img_width"],
    vFlip_probability=0.6,
    hFlip_probability=0.6,
    color_jitter_brightness=(0.8, 1.2),
    color_jitter_hue=(-0.3, 0.3),
    color_jitter_contrast=(0.6, 1.8),
    color_jitter_saturation=(0.5, 1.5),
)
# ---------------------------------------------------------------------------
# Transform pipeline, datasets and data loaders
# ---------------------------------------------------------------------------
# The transforms are listed in the order they are handed to `Composition`:
# normalisation, resize to the configured size, random flips, colour jitter.
transform_composition = Composition([
    RGBNormalizer(),
    CustomResize((
        data_transform_config["img_height"],
        data_transform_config["img_width"],
    )),
    RandomVerticalFlip(data_transform_config["vFlip_probability"]),
    RandomHorizontalFlip(data_transform_config["hFlip_probability"]),
    CustomColorJitter(
        brightness=data_transform_config["color_jitter_brightness"],
        hue=data_transform_config["color_jitter_hue"],
        contrast=data_transform_config["color_jitter_contrast"],
        saturation=data_transform_config["color_jitter_saturation"],
    ),
])
#transform_composition=None

# Keyword arguments common to both dataset splits.
_shared_dataset_kwargs = dict(
    data_path=general_configs["data_path"],
    number_of_frames_per_video=general_configs["number_of_frames_per_video"],
    max_objects_in_scene=general_configs["max_objects_in_scene"],
    is_test_dataset=False,
)

# Keyword arguments common to both loaders; incomplete trailing batches are dropped.
_shared_loader_kwargs = dict(
    batch_size=general_configs["batch_size"],
    drop_last=True,
)

# NOTE(review): the validation split receives the (random) transform pipeline
# while the train split below is built without any `transforms` argument.
# This looks inverted for a training setup -- confirm it is intentional.
validation_dataset = VideoDataset(
    split='validation',
    halve_dataset=True,
    transforms=transform_composition,
    **_shared_dataset_kwargs,
)
valid_loader = DataLoader(
    dataset=validation_dataset,
    shuffle=False,
    **_shared_loader_kwargs,
)

train_dataset = VideoDataset(
    split='train',
    halve_dataset=False,
    **_shared_dataset_kwargs,
)
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    **_shared_loader_kwargs,
)

In [None]:
import time

# Measure how long the DataLoader needs to deliver each batch.
# `time.perf_counter()` is a monotonic, high-resolution clock intended for
# interval measurements. The per-batch timer is restarted after every batch,
# so the "Data loading time" value describes that batch only; the running
# total since the loop started is reported separately.
start_time = time.perf_counter()
batch_start_time = start_time
for batch_index, batch in enumerate(train_loader):
    now = time.perf_counter()
    print(
        f"Batch {batch_index}: "
        f"Data loading time: {now - batch_start_time:.2f} seconds "
        f"(cumulative: {now - start_time:.2f} seconds)"
    )
    # Restart the per-batch timer *after* printing so the time spent in
    # print() is not attributed to the next batch's loading time.
    batch_start_time = time.perf_counter()


In [3]:
def myTempFunc(T):
    """Small dummy workload: compute ``T * 2 + T`` and return it.

    The computation is deliberately done in two steps (a multiplication
    followed by an in-place addition on the intermediate result) so that it
    shows up as two separate operations when used as profiling load.

    Args:
        T: Any object supporting ``T * 2`` and ``+=`` with itself
            (e.g. a number, a list, or a tensor).

    Returns:
        The result of ``T * 2`` followed by ``+= T``. The input ``T`` itself
        is not modified because the in-place add targets the intermediate.
    """
    temp = T * 2
    temp += T
    # The original discarded the result and implicitly returned None.
    return temp

In [None]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Example: Profiling LPIPS computation
# Run on the GPU when one is available and fall back to the CPU otherwise,
# instead of hard-coding 'cuda' (which raises on CPU-only machines).
profile_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
profile_on_cuda = profile_device.type == "cuda"

# Two random inputs; dimension 1 is iterated over in the loop below.
# presumably laid out as (batch, frames, channels, height, width) -- TODO confirm
tensor1 = torch.rand(2, 10, 3, 64, 64, device=profile_device)
tensor2 = torch.rand(2, 10, 3, 64, 64, device=profile_device)
lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex', normalize=True).to(profile_device)

# Only request CUDA activity tracing when a CUDA device is actually in use.
profiled_activities = [ProfilerActivity.CPU]
if profile_on_cuda:
    profiled_activities.append(ProfilerActivity.CUDA)

with profile(activities=profiled_activities, record_shapes=True) as prof:
    with record_function("lpips_computation"):
        for t in range(tensor1.shape[1]):
            # Extra dummy work so that it appears in the profile as well.
            myTempFunc(tensor1)
            myTempFunc(tensor2)
            loss = lpips(tensor1[:, t, :, :, :], tensor2[:, t, :, :, :])
            if profile_on_cuda:
                torch.cuda.empty_cache()  # Profile this too

# Sort by the time column that matches the device that was profiled.
sort_key = "cuda_time_total" if profile_on_cuda else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      lpips_computation         0.00%       0.000us         0.00%       0.000us       0.000us      69.104ms      1095.97%      69.104ms      69.104ms             1  
                                      lpips_computation        32.76%      23.209ms        99.99%      70.826ms      70.826ms       0.000us         0.00%       6.305ms       6.305ms             1  
         

: 