In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from torchinfo import summary
from tqdm import tqdm
from transformers.optimization import get_linear_schedule_with_warmup
import wandb

from dataset import ATMADataset
from models.timesformer_gru import TimesformerGRU

In [None]:
# Build the classifier from the "facebook/timesformer-base-finetuned-k400"
# checkpoint identifier, with a 2-layer / 128-unit GRU configuration and two
# output classes (argument meanings are defined by models.timesformer_gru).
model = TimesformerGRU(
    pretrained_tsf="facebook/timesformer-base-finetuned-k400",
    gru_hidden_size=128,
    gru_layers=2,
    num_classes=2,
)

In [None]:
# Eight random (3, 224, 224) arrays, one per leading index of the random block.
# NOTE(review): dummy_input is not passed to `summary`, so the report below
# contains parameter information only — confirm whether that is intended.
dummy_input = [frame for frame in np.random.rand(8, 3, 224, 224)]
summary(model)

In [None]:
class TrainingLoop:
    """Bare-bones supervised training loop for a classification model.

    Every batch is moved to ``device``, passed through ``model``, scored with
    cross-entropy, and followed by one optimizer step and one scheduler step.

    Args:
        model: ``torch.nn.Module`` to optimise. The last axis of its output is
            treated as the class axis (see ``_compute_loss``).
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        optimizer: Optimizer bound to the model's parameters.
        scheduler: Learning-rate scheduler. It is stepped once per batch, so
            it should be configured for ``len(dataloader) * num_epochs`` steps.
        num_epochs: Number of full passes over ``dataloader``.
        device: Device the model and every batch are moved to.
    """

    def __init__(self, model, dataloader, optimizer, scheduler, num_epochs, device):
        self.model = model
        self.train_dataloader = dataloader
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.num_epochs = num_epochs
        self.device = device

    def _grad_norm(self):
        """Return the global L2 norm of all parameter gradients as a float.

        Parameters whose ``grad`` is ``None`` are skipped; the result is
        ``0.0`` when no parameter currently holds a gradient.
        """
        squared_sum = 0.0
        for param in self.model.parameters():
            if param.grad is not None:
                # detach() replaces the deprecated `.data` access.
                squared_sum += param.grad.detach().norm(2).item() ** 2
        return squared_sum ** 0.5

    @staticmethod
    def _compute_loss(outputs, labels):
        """Cross-entropy with the LAST axis of ``outputs`` as the class axis.

        ``torch.nn.functional.cross_entropy`` expects classes on dim 1, so
        feeding it sequence outputs shaped ``(batch, steps, classes)`` would
        normalise over the steps instead of the classes. Flattening every
        leading axis into one batch axis avoids that, and is identical to the
        plain call for ``(batch, classes)`` outputs.

        Args:
            outputs: Logits of shape ``(..., num_classes)``.
            labels: Either per-class targets with the same shape as
                ``outputs`` (one-hot or probabilities), or integer class
                indices of shape ``outputs.shape[:-1]``.

        Returns:
            Scalar loss tensor averaged over all flattened positions.
        """
        num_classes = outputs.shape[-1]
        logits = outputs.reshape(-1, num_classes)
        if labels.shape == outputs.shape:
            # Per-class targets must be floating point for cross_entropy.
            targets = labels.reshape(-1, num_classes).to(logits.dtype)
        else:
            targets = labels.reshape(-1)
        return torch.nn.functional.cross_entropy(logits, targets)

    def train(self):
        """Train ``self.model`` for ``self.num_epochs`` epochs.

        Moves the model to ``self.device``, puts it in training mode, and
        performs one optimizer step plus one scheduler step per batch. The
        batch loss and gradient norm are shown on the tqdm progress bar.
        """
        self.model.to(self.device)
        self.model.train()
        for epoch in range(self.num_epochs):
            epoch_iterator = tqdm(self.train_dataloader, desc=f"Epoch {epoch + 1}/{self.num_epochs}")
            for inputs, labels in epoch_iterator:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(inputs)

                loss = self._compute_loss(outputs, labels)

                loss.backward()
                # Measure the gradients right after backward(), before the
                # optimizer gets a chance to touch them.
                grad_norm = self._grad_norm()
                self.optimizer.step()
                self.scheduler.step()
                # wandb.log({"batch_loss": loss.item(), "grad_norm": grad_norm, "epoch": epoch})

                epoch_iterator.set_postfix(loss=loss.item(), grad_norm=grad_norm)

In [None]:
# Training schedule.
num_epochs = 100
warm_up_steps = 0  # scheduler warm-up steps (0 = no warm-up)

# Optimizer / loader settings.
lr = 1e-4
batch_size = 1

In [None]:
# Training data: videos from the "aug" train folder plus the shared label file.
dataset = ATMADataset(vid_folder_path="./datasets/ATMA-V/videos/train/aug",
                      label_path="./datasets/ATMA-V/labels/labels.txt")

# Samples are drawn in random order each epoch.
train_sampler = RandomSampler(dataset)
train_dataloader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# The scheduler is stepped once per batch in every epoch, so the linear decay
# has to span the whole run. Sizing it to a single epoch would drive the
# learning rate to zero after epoch 1 and leave the remaining epochs without
# any parameter updates.
total_training_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warm_up_steps,
                                            num_training_steps=total_training_steps)

# Prefer the GPU when one is available; set device = torch.device("cpu") to
# force CPU execution.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

training_loop = TrainingLoop(model=model,
                             dataloader=train_dataloader,
                             optimizer=optimizer,
                             scheduler=scheduler,
                             num_epochs=num_epochs,
                             device=device)

In [None]:
# Run the full schedule: num_epochs passes over train_dataloader.
training_loop.train()

In [None]:
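# Tensor shapes noted for a single batch (batch_size = 1); presumably recorded
# from a debug run of the training loop:
# Inputs:  torch.Size([1, 30, 16, 3, 224, 224])
# Outputs:  torch.Size([1, 30, 2])
# Labels:  torch.Size([1, 30, 2])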