In [1]:
from vit import ViT
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from utils.util import count_model_params, train_epoch,eval_model,train_model
import os
import shutil
from torch.utils.tensorboard import SummaryWriter
from loader.Dataset import VideoDataset 
from torch.utils.data import DataLoader

%load_ext autoreload
%autoreload 2

# Root of the MOVi-C dataset. Set the MOVI_DATA_PATH environment variable to run the
# notebook on a machine without this NFS mount; the original path remains the default.
data_path = os.environ.get("MOVI_DATA_PATH", '/home/nfs/inf6/data/datasets/MOVi/movi_c/')

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

number_of_frames_per_video = 24  # passed to VideoDataset
max_objects_in_scene = 11        # passed to VideoDataset and to the ViT forward pass
batch_size = 64                  # shared by the train and validation DataLoaders

# Seed torch's RNGs (all devices) so DataLoader shuffling and weight init are reproducible.
seed = 0
torch.manual_seed(seed)

In [None]:
# Training split: the loader reshuffles the samples on every pass.
train_dataset = VideoDataset(
    data_path,
    split='train',
    number_of_frames_per_video=number_of_frames_per_video,
    max_objects_in_scene=max_objects_in_scene,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation split (later passed to train_model as the validation loader): fixed order.
test_dataset = VideoDataset(
    data_path,
    split='validation',
    number_of_frames_per_video=number_of_frames_per_video,
    max_objects_in_scene=max_objects_in_scene,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Pull a single training batch and report the shape of each field it yields.
coms, bboxs, masks, rgbs, flows = next(iter(train_loader))
print(f"Shapes: >>>>>>>>>>>>>>>>> \r\n{coms.shape=}, \r\n{bboxs.shape=}, \r\n{masks.shape=}, \r\n{rgbs.shape=}, \r\n{flows.shape=}\r\n<<<<<<<<<<<<<<<<<<")

In [None]:
# Build the ViT used to inspect a forward pass, move it to the selected device,
# report its parameter count, and display its module tree as the cell output.
vit = ViT(patch_size=8, token_dim=128, attn_dim=128, num_heads=4,
          mlp_size=512, num_tf_layers=4, num_classes=10)
vit = vit.to(device)
print(f"ViT has {count_model_params(vit)} parameters")
vit

In [5]:
# Sanity-check one forward pass and the attention maps it produces.
# DataLoader batches live on the CPU while the model lives on `device`, so the inputs
# are moved explicitly (otherwise the call fails with a device mismatch on CUDA).
# The full 64-clip batch was slow enough that the run was interrupted, so only a
# small slice along the batch dimension (dim 0 from default collation) is used here.
num_debug_samples = 2
rgbs_debug = rgbs[:num_debug_samples].to(device)
masks_debug = masks[:num_debug_samples].to(device)

vit.eval()  # disable train-time behaviour (e.g. dropout), if the model has any
with torch.no_grad():
    y = vit(rgbs_debug, masks=masks_debug, max_objects_in_scene=max_objects_in_scene)
attn_maps = vit.get_attn_mask()
print(f"Input Shape: {rgbs_debug.shape}")
print(f"Output Shape: {y.shape}")
print(f"Found {len(attn_maps)} Attn Maps of shape {attn_maps[0].shape}")

KeyboardInterrupt: 

In [5]:
# Fresh model for training; same hyper-parameters as the inspection model.
vit = ViT(
        patch_size=8,
        token_dim=128,
        attn_dim=128,
        num_heads=4,
        mlp_size=512,
        num_tf_layers=4,
        num_classes=10
    ).to(device)

# classification loss function
criterion = nn.CrossEntropyLoss()

# All model parameters are optimized with Adam.
optimizer = torch.optim.Adam(vit.parameters(), lr=3e-4)

# Decay LR by a factor of 3 every 10 scheduler steps
# (presumably one step per epoch inside train_model -- TODO confirm).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=1/3)

# TensorBoard run directory. Logs left by a previous run under the same name are
# removed so curves from different runs do not mix; SummaryWriter recreates the
# directory (including missing parents) itself.
TBOARD_LOGS = os.path.join(os.getcwd(), "tboard_logs", "ViT_30")
if os.path.isdir(TBOARD_LOGS):
    shutil.rmtree(TBOARD_LOGS)
writer = SummaryWriter(TBOARD_LOGS)

In [None]:
# Train for 30 epochs, validating on the 'validation' split and logging to TensorBoard.
# The writer is closed afterwards, even if training raises or is interrupted, so that
# buffered events are flushed to disk (SummaryWriter reopens itself if used again).
try:
    train_loss, val_loss, loss_iters, valid_acc = train_model(
            model=vit,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            train_loader=train_loader,
            valid_loader=test_loader,
            num_epochs=30,
            tboard=writer
        )
finally:
    writer.close()