In [1]:
# Warnings ignoring
import warnings
warnings.filterwarnings("ignore")

# OS tools
import os
import typing
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

# Tables, arrays, and plotters 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score

# Video Processing
from torchvision.io import read_video
from torchvision.transforms import v2
import torchvision.transforms as tt
import torchvision.models as models
from torchvision.models.optical_flow import raft_small
from torchvision.utils import flow_to_image
from torchvision.transforms.functional import resize

# Lighting
import pytorch_lightning as pl
from pytorch_lightning import Trainer, strategies
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities import grad_norm
from pytorch_lightning.loggers import TensorBoardLogger

# Ready scripts
from lib.utils import init_weights, count_trainable_parameters
from lib.trainer import HParams, VideoDataModule

In [8]:
torch.set_float32_matmul_precision('medium')

In [9]:
class OpticalFlow(nn.Module):
    def __init__(self):
        super(OpticalFlow, self).__init__()
        self.backbone = raft_small(pretrained=True)
        
        for param in self.backbone.parameters():
            param.requires_grad = False
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, T, H, W] -> [B, T, C, H, W]
        x = x.permute(0, 2, 1, 3, 4)
        T = x.shape[1]
        
        flow_seq = []

        for t in range(T - 1):
            img1 = x[:, t]  # [B, 3, H, W]
            img2 = x[:, t + 1]  # [B, 3, H, W]

            with torch.no_grad():
                flow = self.backbone(img1, img2)  # returns [B, 2, H', W']
            flow_img = flow_to_image(flow[-1])
            # Resize back to original shape
            flow_seq.append(flow_img.unsqueeze(2))  # [B, 2, 1, H, W]

        # Pad the first frame with zeros to match original T
        zero_flow = torch.zeros_like(flow_seq[0])
        flow_seq = [zero_flow] + flow_seq

        flow_tensor = torch.cat(flow_seq, dim=2)  # [B, 2, T, H, W]
        return flow_tensor


class D3DNet(nn.Module):
    def __init__(self, n_outputs: int):
        super(D3DNet, self).__init__()
        
        self.raft = OpticalFlow()
        self.backbone = models.video.r3d_18(pretrained=True)
        self.backbone.fc = nn.Identity()
        
        self.teacher = models.resnet18(pretrained=True)
        self.teacher.fc = nn.Identity()
        
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            
            nn.Linear(512, n_outputs)
        )
        
        for param in self.backbone.parameters():
            param.requires_grad = False
        
        for param in self.teacher.parameters():
            param.requires_grad = False
    
    def forward(self, x: torch.Tensor):
        feat3d = self.backbone(x)
        logits = self.classifier(feat3d)
        
        if self.training:
            y: torch.Tensor = self.raft(x)
            B, C, T, H, W = y.shape
            y = y.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
            feat2d = self.teacher(y / 255)
            feat2d = feat2d.view(B, T, -1).mean(dim=1)
            return logits, feat3d, feat2d
        return logits

In [10]:
class train_model(pl.LightningModule):
    def __init__(self, model: nn.Module = None, params: HParams = None, alpha: float = 0.5):
        super().__init__()
        self.save_hyperparameters(
            params.__dict__,
            ignore=("dataset_dir", "train_meta", "test_meta", "validation_meta"),
        )
        self.params = params
        self.model = model

        self.accuracy = F1Score(
            task="multiclass", num_classes=params.n_classes, average="micro"
        )
        self.criterion = nn.CrossEntropyLoss(label_smoothing=params.ls)
        self.criterion_distill = nn.MSELoss()
        self.alpha = alpha

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        logits, feat3d, feat2d = self(x)
        
        loss_cls = self.criterion(logits, y)
        loss_distill = self.criterion_distill(feat3d, feat2d.detach())
        
        loss = loss_cls + self.alpha * loss_distill
        acc = self.accuracy(logits, y)

        self.log_dict(
            {
                "train_loss": loss,
                "train_acc": acc,
            },
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self.accuracy(logits, y)

        self.log_dict(
            {
                "val_loss": loss,
                "val_acc": acc,
            },
            on_step=True,
            on_epoch=True,
        )

        return loss

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = self.accuracy(logits, y)

        self.log_dict(
            {
                "test_loss": loss,
                "test_acc": acc,
            },
            on_step=True,
            on_epoch=True,
        )

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(), lr=self.params.lr, momentum=0.9, weight_decay=1e-4
        )

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        return [optimizer], [scheduler]

    def on_before_optimizer_step(self, optimizer):
        norm_order = 2.0
        norms = grad_norm(self, norm_type=norm_order)
        self.log(
            "grad_norm",
            norms[f"grad_{norm_order}_norm_total"],
            on_step=True,
            on_epoch=False,
        )

In [11]:
os.cpu_count()

24

In [12]:
hparams = HParams(
    num_workers=20,
    clip_format="CTHW",
    arch="d3d",
    num_epoch=20,
    clip_len=32,
)
data_module = VideoDataModule(hparams)

agent = D3DNet(
    n_outputs=hparams.n_classes
)

print("Number of model paramteres:", count_trainable_parameters(agent))

Number of model paramteres: {'trainable': '314,469', 'total': '45,647,415'}


In [13]:
%load_ext tensorboard
%tensorboard --logdir logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 5733), started 3:06:11 ago. (Use '!kill 5733' to kill it.)

In [14]:
checkpoint_callback_img = ModelCheckpoint(
    monitor='val_loss',
    dirpath=hparams.output_dir,
    filename=f"best_model_{hparams.arch}",
    save_top_k=1,
    mode='min',
)

model = train_model(model=agent, params=hparams)

logger = TensorBoardLogger("logs", name=hparams.arch)

trainer = Trainer(
    max_epochs=hparams.num_epoch,
    callbacks=[checkpoint_callback_img],
    accelerator="auto", 
    devices="auto",
    logger=logger
)

trainer.fit(model, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type              | Params | Mode 
----------------------------------------------------------------
0 | model             | D3DNet            | 45.6 M | train
1 | accuracy          | MulticlassF1Score | 0      | train
2 | criterion         | CrossEntropyLoss  | 0      | train
3 | criterion_distill | MSELoss           | 0      | train
----------------------------------------------------------------
314 K     Trainable params
45.3 M    Non-trainable params
45.6 M    Total params
182.590   Total estimated model params size (MB)
381       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [15]:
best_model_path = checkpoint_callback_img.best_model_path
info = trainer.test(
    model=model,
    dataloaders=data_module,
    ckpt_path=best_model_path
)

Restoring states from the checkpoint path at /home/slauva/Documents/innopolis/Computer Vision 2025/final/temp/saved_models/best_model_d3d-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/slauva/Documents/innopolis/Computer Vision 2025/final/temp/saved_models/best_model_d3d-v1.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.6732443571090698
     test_loss_epoch         3.59652042388916
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
