In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

# OS tools
import os
import typing
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

# Tables, arrays, and plotters 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score

# Video Processing
from torchvision.io import read_video
from torchvision.transforms import v2
import torchvision.transforms as tt
from torchvision.models.video import r3d_18

# Lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities import grad_norm

In [47]:
def count_trainable_parameters(model):
    """Return how many scalar values live in the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is off are left out of the total.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total

In [49]:
def compute_tensor_size_conv(size, model):
    """Propagate a per-dimension size through the convolutions directly under ``model``.

    Only the immediate children of ``model`` that expose both ``weight`` and
    ``kernel_size`` are taken into account; every other child (pooling,
    normalisation, activations, nested containers) is skipped.

    For each such layer the size is updated per dimension with

        out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)

    which reduces to ``floor((in + 2 * padding - kernel) / stride + 1)`` when
    the dilation is 1.

    Args:
        size: Sequence of input sizes, one entry per convolved dimension.
        model: Module whose direct children are inspected in order.

    Returns:
        numpy.ndarray: A copy of ``size`` (same dtype as ``np.copy(size)``)
        holding the resulting sizes.
    """
    t_size = np.copy(size)
    for layer in model.children():
        if not (hasattr(layer, "weight") and hasattr(layer, "kernel_size")):
            continue

        k, s, p = layer.kernel_size, layer.stride, layer.padding
        # Fall back to no dilation for layers that do not carry the attribute.
        d = getattr(layer, "dilation", (1,) * len(k))

        if isinstance(p, str):
            # String padding modes cannot be indexed numerically.
            if p == "same":
                # "same" keeps the size unchanged for this layer.
                continue
            # "valid" means no padding at all.
            p = (0,) * len(k)

        for i in range(len(size)):
            effective_kernel = d[i] * (k[i] - 1) + 1
            t_size[i] = np.floor((t_size[i] + 2 * p[i] - effective_kernel) / s[i] + 1)
    return t_size

In [126]:
# Scratch arithmetic in the same floor(x / stride + 1) form that
# compute_tensor_size_conv applies to each dimension.
np.floor(27/4 + 1)

np.float64(7.0)

In [122]:
class AttentionFusion3D(nn.Module):
    """Fuse two 3D feature maps with a learned two-way softmax gate.

    ``y`` is first brought to ``x``'s channel count by a convolution that is
    strided along the temporal axis only. The two maps are then concatenated
    on the channel axis, turned into two weights per position (softmax over
    the two branches), and blended as ``x * wx + y_aligned * wy``.

    When ``y_channels == x_channels`` no alignment convolution is created, so
    ``y`` must already match ``x`` in every dimension.

    NOTE: a later cell of this notebook rebinds the name ``AttentionFusion3D``
    to a ``Fusion3D``-based class with a different constructor signature.
    Zero-argument ``super()`` is used here so this class does not depend on
    what that global name currently refers to.

    Args:
        x_channels: Channel count of ``x`` (and of the fused output).
        y_channels: Channel count of ``y`` before alignment.
        embed_channels: Hidden width of the attention sub-network.
        temporal_kernel: Temporal kernel size of the alignment convolution.
        temporal_stride: Temporal stride of the alignment convolution.
    """

    def __init__(
        self,
        x_channels: int,
        y_channels: int,
        embed_channels: int,
        temporal_kernel: int = 5,
        temporal_stride: int = 8,
    ):
        super().__init__()

        # Project y onto x's channel count while striding over time only.
        if y_channels != x_channels:
            self.align = nn.Conv3d(
                y_channels,
                x_channels,
                kernel_size=(temporal_kernel, 1, 1),
                stride=(temporal_stride, 1, 1),
            )
        else:
            self.align = nn.Identity()

        # Attention mechanism (Conv3D version)
        self.attention = nn.Sequential(
            nn.Conv3d(x_channels * 2, embed_channels, kernel_size=1),
            nn.Tanh(),
            nn.Conv3d(embed_channels, 2, kernel_size=1),  # One map per branch
            nn.Softmax(dim=1),  # Softmax over the 2 branches (not channels)
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Blend ``x`` with the aligned ``y``.

        Args:
            x: Tensor of shape [B, C, T, H, W].
            y: Tensor of shape [B, C', T', H, W]; after ``self.align`` it must
                have the same shape as ``x``.

        Returns:
            Tensor of shape [B, C, T, H, W].
        """
        y_aligned = self.align(y)

        # Concatenate along the channel dimension: [B, 2C, T, H, W]
        concat = torch.cat([x, y_aligned], dim=1)

        # Branch weights: [B, 2, T, H, W]; the two maps sum to 1 per position.
        weights = self.attention(concat)

        # Keep the singleton channel axis so the weights broadcast over C.
        wx = weights[:, 0:1, :, :, :]
        wy = weights[:, 1:2, :, :, :]

        # Weighted sum
        fused = x * wx + y_aligned * wy
        return fused

class ResidualBlock(nn.Module):
    """Sequential stack of Conv3d -> BatchNorm3d -> ReLU units.

    One pass over ``config`` creates one unit per entry. Every unit reads
    ``in_dim`` channels; all but the last emit ``in_dim`` channels and the
    last emits ``out_dim``. The pass is repeated ``depth`` times, with a
    1x1x1 convolution (``out_dim`` -> ``in_dim``) between consecutive passes.

    Despite the name, ``forward`` only applies this stack; no skip connection
    is added.

    Args:
        in_dim: Channel count entering the block.
        out_dim: Channel count leaving the block.
        config: One dict per unit with optional keys ``"kernel"`` (default 1),
            ``"stride"`` (default 1) and ``"padding"`` (default 0).
        depth: Number of passes over ``config``.
        inplace: Forwarded to every ``nn.ReLU``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        config: list[dict],
        depth: int = 1,
        inplace: bool=False
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self._construct(config, depth, inplace)

    def _construct(self, config, depth, inplace) -> None:
        """Assemble ``self.hidden`` from the unit specs."""
        final_step = len(config) - 1
        final_pass = depth - 1
        modules = []

        for pass_idx in range(depth):
            for step, spec in enumerate(config):
                # Only the closing unit of a pass widens to out_dim.
                width = self.out_dim if step == final_step else self.in_dim

                modules.append(
                    nn.Conv3d(
                        in_channels=self.in_dim,
                        out_channels=width,
                        kernel_size=spec.get("kernel", 1),
                        stride=spec.get("stride", 1),
                        padding=spec.get("padding", 0),
                        bias=False,
                    )
                )
                modules.append(nn.BatchNorm3d(num_features=width))
                modules.append(nn.ReLU(inplace=inplace))

            if pass_idx != final_pass:
                # Squeeze back to in_dim so the next pass can consume it.
                modules.append(
                    nn.Conv3d(self.out_dim, self.in_dim, kernel_size=1, stride=1)
                )

        self.hidden = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack."""
        return self.hidden(x)

class SlowFastNet(nn.Module):
    """Two-pathway 3D convolutional classifier with attention fusion.

    The fast pathway processes the full input; the slow pathway processes
    every 8th frame of it (``x[:, :, ::8]``). Each pathway consists of a stem
    (``*_conv_pool_1``) followed by four ``ResidualBlock`` stages
    (res2 .. res5), with a 1x1x1 channel-reducing convolution between
    consecutive stages. After the stem and after every stage the fast features
    are fused into the slow features through an ``AttentionFusion3D`` module.
    The final fused map (2048 channels) is globally average-pooled and
    classified by a linear layer.

    NOTE: ``AttentionFusion3D`` is looked up when ``__init__`` runs. The
    three-argument calls below match the ``nn.Module``-based definition given
    earlier in the notebook, not the ``Fusion3D``-based class of the same name
    that a later cell defines.

    Args:
        n_outputs: Number of output features of the final linear layer.
        blocks: ``depth`` passed to the res2, res3, res4 and res5 stages
            (the same value is used for the slow and the fast pathway).
    """

    def __init__(self, n_outputs: int, blocks: tuple[int, int, int, int]=(3, 4, 6, 3)):
        super(SlowFastNet, self).__init__()
        
        self.n_outputs = n_outputs
        
        
        # 1x1x1 convolution mapping x[0] channels to x[1] channels; used between stages.
        fn_downsampling = lambda x: nn.Conv3d(x[0], x[1], kernel_size=1, stride=1, padding=0)
        
        # Fusions: (slow channels, fast channels, attention hidden width)
        
        self.fusion_1 = AttentionFusion3D(64, 8, 64)
        self.fusion_2 = AttentionFusion3D(256, 32, 256)
        self.fusion_3 = AttentionFusion3D(512, 64, 512)
        self.fusion_4 = AttentionFusion3D(1024, 128, 1024)
        self.fusion_5 = AttentionFusion3D(2048, 256, 2048)
        
        # Slow pathway
        
        # Stem: all kernels/strides are 1 along the temporal axis.
        self.slow_conv_pool_1 = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 1, 1), padding=(0, 2, 2)),
            nn.Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),
            nn.Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 2, 2), padding=(0, 1, 1))
        )
        
        self.slow_res2 = ResidualBlock(
            in_dim=64, 
            out_dim=256, 
            config=[
                {"kernel": (1, 1, 1), "padding": (0, 0, 0)},
                {"kernel": (1, 3, 3), "padding": (0, 0, 0)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[0]
        )
        
        self.slow_downsampling_2 = fn_downsampling([256, 128])
        
        self.slow_res_3 = ResidualBlock(
            in_dim=128, 
            out_dim=512, 
            config=[
                {"kernel": (1, 1, 1), "padding": (0, 13, 13), "stride": (1, 2, 2)},
                {"kernel": (1, 3, 3), "padding": (0, 0, 0)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[1]
        )
        
        self.slow_downsampling_3 = fn_downsampling([512, 256])
        
        # From res4 on, the slow pathway also uses temporal kernels of size 3.
        self.slow_res_4 = ResidualBlock(
            in_dim=256, 
            out_dim=1024, 
            config=[
                {"kernel": (3, 1, 1), "padding": (1, 7, 7), "stride": (1, 3, 3)},
                {"kernel": (1, 3, 3), "padding": (0, 2, 2)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[2]
        )
        
        self.slow_downsampling_4 = fn_downsampling([1024, 512])
        
        self.slow_res_5 = ResidualBlock(
            in_dim=512, 
            out_dim=2048, 
            config=[
                {"kernel": (3, 1, 1), "padding": (1, 8, 8), "stride": (1, 5, 5)},
                {"kernel": (1, 3, 3), "padding": (0, 1, 1)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[3]
        )
        
        # Fast pathway: same layout as the slow pathway with 1/8 of the channels,
        # and a temporal kernel of size 3 in the first unit of every stage.
        
        self.fast_conv_pool_1 = nn.Sequential(
            nn.Conv3d(3, 8, kernel_size=(1, 7, 7), stride=(1, 1, 1), padding=(0, 2, 2)),
            nn.Conv3d(8, 8, kernel_size=(1, 1, 1), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),
            nn.Conv3d(8, 8, kernel_size=(1, 1, 1), stride=(1, 2, 2), padding=(0, 1, 1))
        )
        
        self.fast_res2 = ResidualBlock(
            in_dim=8, 
            out_dim=32, 
            config=[
                {"kernel": (3, 1, 1), "padding": (1, 0, 0)},
                {"kernel": (1, 3, 3), "padding": (0, 0, 0)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[0]
        )
        
        self.fast_downsampling_2 = fn_downsampling([32, 16])
        
        self.fast_res_3 = ResidualBlock(
            in_dim=16, 
            out_dim=64, 
            config=[
                {"kernel": (3, 1, 1), "padding": (1, 13, 13), "stride": (1, 2, 2)},
                {"kernel": (1, 3, 3), "padding": (0, 0, 0)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[1]
        )
        
        self.fast_downsampling_3 = fn_downsampling([64, 32])
        
        self.fast_res_4 = ResidualBlock(
            in_dim=32, 
            out_dim=128, 
            config=[
                {"kernel": (3, 1, 1), "padding": (1, 7, 7), "stride": (1, 3, 3)},
                {"kernel": (1, 3, 3), "padding": (0, 2, 2)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[2]
        )
        
        self.fast_downsampling_4 = fn_downsampling([128, 64])
        
        self.fast_res_5 = ResidualBlock(
            in_dim=64, 
            out_dim=256, 
            config=[
                {"kernel": (3, 1, 1), "padding": (1, 8, 8), "stride": (1, 5, 5)},
                {"kernel": (1, 3, 3), "padding": (0, 1, 1)},
                {"kernel": (1, 1, 1), "padding": (0, 1, 1)},
            ],
            depth=blocks[3]
        )
        
        # Classification head: global average pool -> flatten -> dropout -> linear.
        self.avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=2048, out_features=n_outputs)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of clips.

        Args:
            x: Input with 3 channels in dim 1 and frames in dim 2; the notebook
                exercises it with shape (8, 3, 32, 224, 224). Each fusion
                module must be able to align the fast features with the slow
                ones, which constrains the usable frame count and resolution.

        Returns:
            Output of the final linear layer, ``n_outputs`` features per sample.
        """
        # Slow pathway input: every 8th frame of the clip. From here on,
        # `x` carries the fast pathway and `x_sl` the slow pathway.
        x_sl = x[:, :, ::8, :, :]
        
        # Conv Pool 1
        x = self.fast_conv_pool_1(x)
        x_sl = self.slow_conv_pool_1(x_sl)
        # Fuse fast features into the slow pathway; the fast pathway itself is unchanged.
        x_sl = self.fusion_1(x_sl, x)
        
        # Res 2
        x = self.fast_res2(x)
        x_sl = self.slow_res2(x_sl)
        x_sl = self.fusion_2(x_sl, x)
        
        # Downsampling 2
        x = self.fast_downsampling_2(x)
        x_sl = self.slow_downsampling_2(x_sl)
        
        # Res 3
        x = self.fast_res_3(x)
        x_sl = self.slow_res_3(x_sl)
        x_sl = self.fusion_3(x_sl, x)
        
        # Downsampling 3
        x = self.fast_downsampling_3(x)
        x_sl = self.slow_downsampling_3(x_sl)
        
        # Res 4
        x = self.fast_res_4(x)
        x_sl = self.slow_res_4(x_sl)
        x_sl = self.fusion_4(x_sl, x)
        
        # Downsampling 4
        x = self.fast_downsampling_4(x)
        x_sl = self.slow_downsampling_4(x_sl)
        
        # Res 5
        x = self.fast_res_5(x)
        x_sl = self.slow_res_5(x_sl)
        
        # Final attention fusion, then the pooled classification head
        z = self.fusion_5(x_sl, x)
        z = self.avg_pool(z)
        
        return z

def slowfast_r18(n_outputs: int) -> nn.Module:
    """Build a ``SlowFastNet`` whose four residual stages each use depth 1.

    Args:
        n_outputs: Number of output features of the classification layer.
    """
    stage_depths = (1, 1, 1, 1)
    return SlowFastNet(n_outputs=n_outputs, blocks=stage_depths)

def slowfast_r50(n_outputs: int) -> nn.Module:
    """Build a ``SlowFastNet`` with stage depths (3, 4, 6, 3).

    Args:
        n_outputs: Number of output features of the classification layer.
    """
    stage_depths = (3, 4, 6, 3)
    return SlowFastNet(n_outputs=n_outputs, blocks=stage_depths)

In [124]:
# Smoke test: build the (3, 4, 6, 3) variant with 101 outputs and push one
# random batch through it.
net = slowfast_r50(101)
# Random batch of shape (8, 3, 32, 224, 224): 3 channels, 32 frames of 224x224.
x = torch.rand((8, 3, 32, 224, 224))
# NOTE(review): this forward pass runs with autograd enabled; consider
# torch.no_grad() if only a shape/parameter check is intended.
out = net(x)
# Last expression of the cell: trainable-parameter count with thousands separators.
f"{count_trainable_parameters(net):,}"

'39,601,551'

In [131]:
class Fusion3D(nn.Module):
    """Project a source tensor onto a target tensor's shape and concatenate them.

    Tensors are channel-first, ``[B, C, T, H, W]``, as required by
    ``nn.Conv3d`` and by the concatenation along dim 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel: typing.Optional[int] = None,
        stride: typing.Optional[int] = None,
    ):
        """Create the channel and temporal projections.

        Args:
            in_channels (int): Channel count of the source tensor.
            out_channels (int): Channel count of the target tensor.
            kernel (int, optional): Temporal kernel size of the resolution
                projection.
            stride (int, optional): Temporal stride of the resolution
                projection.

        The resolution projection is only created when BOTH ``kernel`` and
        ``stride`` are given; if either is ``None`` the temporal axis is left
        untouched. Likewise, no channel projection is created when
        ``in_channels == out_channels``.
        """
        super().__init__()

        # 1x1x1 conv that only changes the channel count.
        if in_channels != out_channels:
            self.proj_channels = nn.Conv3d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.proj_channels = nn.Identity()

        # Conv that only acts along the temporal axis (spatial kernel/stride are 1).
        if kernel is not None and stride is not None:
            self.proj_resolution = nn.Conv3d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(kernel, 1, 1),
                stride=(stride, 1, 1),
            )
        else:
            self.proj_resolution = nn.Identity()

    def forward(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project ``x`` and concatenate it behind ``y`` on the channel axis.

        Args:
            x (Tensor): Source tensor of shape [B, C, T, H, W].
            y (Tensor): Target tensor of shape [B, C', T', H, W].

        Returns:
            out (Tuple[Tensor, Tensor]): The concatenation ``[y, projected]``
            of shape [B, 2C', T', H, W], and the projected source tensor of
            shape [B, C', T', H, W].
        """
        # Channels first, then time: the temporal conv expects C' channels.
        projected = self.proj_channels(x)
        projected = self.proj_resolution(projected)

        z = torch.cat([y, projected], dim=1)

        return z, projected


class AttentionFusion3D(Fusion3D):
    """``Fusion3D`` followed by a two-way softmax gate over the two inputs.

    The concatenation produced by the base class is mapped to two weight maps
    that sum to 1 at every position; the output is the weighted sum of the
    target tensor ``y`` and the projected source tensor.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        embed: int,
        kernel: int = None,
        stride: int = None,
    ):
        """Set up the base projections and the attention sub-network.

        Args:
            in_channels (int): Channel count of the source tensor.
            out_channels (int): Channel count of the target tensor.
            embed (int): Hidden width of the attention sub-network.
            kernel (int, optional): Forwarded to ``Fusion3D``.
            stride (int, optional): Forwarded to ``Fusion3D``.
        """
        super().__init__(in_channels, out_channels, kernel, stride)

        gate_layers = [
            nn.Conv3d(out_channels * 2, embed, kernel_size=1),
            nn.Tanh(),
            nn.Conv3d(embed, 2, kernel_size=1),
            nn.Softmax(dim=1),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """Return ``y * w0 + projected(x) * w1`` with softmax weights ``w0, w1``."""
        stacked, source = super().forward(x, y)
        gates = self.attention(stacked)

        # Channel 0 gates the target `y`, channel 1 gates the projected source.
        # Slicing keeps a singleton channel axis so the gates broadcast.
        target_gate = gates[:, :1]
        source_gate = gates[:, 1:2]

        return y * target_gate + source * source_gate

In [132]:
# Shape check for the five-argument (Fusion3D-based) AttentionFusion3D.
# Source tensor: 8 channels, 32 frames.
fast = torch.rand((8, 8, 32, 224, 224))
# Target tensor: 32 channels, 4 frames.
slow = torch.rand((8, 32, 4, 224, 224))

# in_channels=8 -> out_channels=32, embed=32, temporal kernel 5 with stride 8
# (32 frames -> floor((32 - 5) / 8 + 1) = 4 frames).
fusion = AttentionFusion3D(8, 32, 32, 5, 8)

# Last expression of the cell: shape of the fused output.
fusion(fast, slow).shape

torch.Size([8, 32, 4, 224, 224])