In [2]:
# Suppress all warnings globally (note: this also hides deprecation warnings)
import warnings
warnings.filterwarnings("ignore")

# OS tools
import os
import typing
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

# Tables, arrays, and plotting
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange

# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score

# Video Processing
from torchvision.io import read_video
from torchvision.transforms import v2
import torchvision.transforms as tt
import torchvision.models as models

# Transformers
from transformers import TimesformerModel

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer, strategies
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities import grad_norm
from pytorch_lightning.loggers import TensorBoardLogger

In [3]:
class PatchEmbed(nn.Module):
    """Split a video into non-overlapping patches and project each one to ``embed_dim``.

    Args:
        in_channels: Channels per frame.
        patch_size: Spatial patch edge length (patches are ``patch_size x patch_size``).
        embed_dim: Width of each output token.
        tubelet_size: Number of consecutive frames folded into one token along
            time. The default ``1`` gives one token per frame per spatial patch.

    Input ``(B, C, T, H, W)`` -> tokens ``(B, T' * H' * W', embed_dim)`` with
    ``T' = T // tubelet_size``, ``H' = H // patch_size``, ``W' = W // patch_size``.
    Tokens are flattened in time-major ``(t, h, w)`` order.
    """

    def __init__(self, in_channels=3, patch_size=16, embed_dim=768, tubelet_size=1):
        super(PatchEmbed, self).__init__()

        # Kernel == stride, so patches do not overlap.
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=(tubelet_size, patch_size, patch_size),
            stride=(tubelet_size, patch_size, patch_size),
        )

    def forward(self, x):
        # x: (B, C, T, H, W)
        x = self.proj(x)  # (B, embed_dim, T', H', W')
        # Merge (T', H', W') into one token axis in (t, h, w) order, then move
        # channels last: (B, embed_dim, T'*H'*W') -> (B, T'*H'*W', embed_dim).
        return x.flatten(2).transpose(1, 2)

class Attention(nn.Module):
    """Multi-head self-attention over a token sequence (no masking, no dropout).

    Args:
        dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Maps ``(B, N, dim)`` -> ``(B, N, dim)``.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8):
        super(Attention, self).__init__()
        # Without this check the reshape in forward() fails later with an opaque size error.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim); identical to the default scaling used by
        # scaled_dot_product_attention below. Kept as an attribute for callers.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)  # (B, N, 3C)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B, heads, N, head_dim)

        # Same result as softmax(q @ k^T * scale) @ v, but the fused kernel can
        # avoid materialising the (N, N) attention matrix, which dominates memory
        # for long space-time token sequences.
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # (B, heads, N, head_dim)
        x = x.transpose(1, 2).reshape(B, N, C)  # (B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Applies ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Args:
        dim: Token width.
        num_heads: Number of heads passed to ``Attention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability after the MLP activation and after its output projection.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        # Registration order (norm1, attn, norm2, mlp) is kept so parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

class TimeSformer(nn.Module):
    """Video classifier: patch tokens + [CLS] token -> transformer blocks -> linear head.

    NOTE(review): despite the name, every block runs a single attention over
    *all* space-time tokens jointly (``num_frames * (img_size // patch_size) ** 2``
    of them, plus [CLS]). It does not implement the divided time/space attention
    of the TimeSformer paper, so attention cost grows quadratically in
    frames x patches.

    Args:
        img_size: Spatial size (H == W) the positional table is built for.
        patch_size: Spatial patch edge length.
        num_frames: Number of frames the positional table covers. Shorter clips
            reuse the leading (earliest-frame) positions.
        in_channels: Channels per frame.
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        num_classes: Size of the output logit vector.
        mlp_ratio: MLP hidden-width multiplier inside each block.
        dropout: Dropout probability after adding positional embeddings and
            inside each block's MLP.
    """

    def __init__(self, img_size=224, patch_size=16, num_frames=8, in_channels=3, embed_dim=768, depth=12, num_heads=12, num_classes=400, mlp_ratio=4.0, dropout=0.1):
        super(TimeSformer, self).__init__()

        # Geometry the positional table was built for; checked in forward().
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size  # patches per spatial side
        self.num_frames = num_frames

        self.patch_embed = PatchEmbed(in_channels, patch_size, embed_dim)
        num_patches = self.grid_size ** 2 * num_frames

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        # Only the learned tokens and the classifier head are re-initialised here.
        # Other layers keep PyTorch's default initialisation.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def _check_input_grid(self, x):
        """Raise ValueError unless ``x``'s (T, H, W) lines up with ``pos_embed``.

        Tokens are laid out in (t, h, w) order. Slicing ``pos_embed[:, :N+1]`` is
        only correct when the spatial patch grid matches the one the table was
        built for and T does not exceed ``num_frames``. Any other shape would
        silently pair tokens with the wrong positions, or fail with an opaque
        broadcast error.
        """
        t, h, w = x.shape[-3:]
        grid = (h // self.patch_size, w // self.patch_size)
        if grid != (self.grid_size, self.grid_size) or t > self.num_frames:
            raise ValueError(
                f"expected input (B, C, T<={self.num_frames}, H, W) with a "
                f"{self.grid_size}x{self.grid_size} patch grid; got T={t}, H={h}, W={w} "
                f"(patch grid {grid[0]}x{grid[1]})"
            )

    def forward(self, x):
        # x: (B, C, T, H, W)
        self._check_input_grid(x)
        x = self.patch_embed(x)  # (B, N, C), tokens in (t, h, w) order
        B, N, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, C)
        # Safe after _check_input_grid: the first N+1 positions are [CLS] plus
        # the first T frames of the table's grid.
        x = x + self.pos_embed[:, :N + 1]
        x = self.pos_drop(x)

        x = self.blocks(x)
        x = self.norm(x)
        return self.head(x[:, 0])  # classify from the [CLS] token

def count_trainable_parameters(model):
    """Return the total element count of ``model``'s parameters that have ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are excluded
        total += param.numel()
    return total

In [5]:
# Model config: 32 frames at 224x224, with depth/heads reduced from the class
# defaults (12/12) to 4/4.
timesformer_kwargs = {
    "img_size": 224,
    "patch_size": 16,
    "num_frames": 32,
    "in_channels": 3,
    "embed_dim": 768,
    "depth": 4,
    "num_heads": 4,
    "num_classes": 101,
}
net = TimeSformer(**timesformer_kwargs)
f"{count_trainable_parameters(net):,}"

'33,830,501'

In [7]:
# Smoke-test the forward pass on a random clip laid out as (B, C, T, H, W).
# no_grad: only the output shape is needed, so don't let autograd retain every
# layer's intermediate activations for a backward pass that never happens.
x = torch.rand(8, 3, 32, 224, 224)
with torch.no_grad():
    logits = net(x)
assert logits.shape == (8, 101), logits.shape
logits.shape

torch.Size([8, 101])