In [1]:
# Warnings ignoring
import warnings
warnings.filterwarnings("ignore")

# OS tools
import os
import typing
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

# Tables, arrays, and plotters 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score

# Video Processing
from torchvision.io import read_video
from torchvision.transforms import v2
import torchvision.transforms as tt

# Lighting
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities import grad_norm

In [2]:
class SlowFast(nn.Module):
    
    def __init__(self, n_classes: int, freeze: bool) -> None:
        super(SlowFast, self).__init__()
        
        self.backbone = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)
        in_features = self.backbone.blocks[-1].proj.in_features
        self.backbone.blocks[-1].proj = nn.Linear(in_features=in_features, out_features=n_classes)
        
        if freeze:
            self._freeze_layers()
    
    def _freeze_layers(self):
        # Freeze all parameters in the model
        for param in self.backbone.blocks[:-1].parameters():
            param.requires_grad = False
        
        # Unfreeze the final classification layer (head)
        for param in self.backbone.blocks[-1].parameters():
            param.requires_grad = True
    
    def forward(self, x):
        s_x = x[:, :, :8, :, :]
        y = self.backbone([s_x, x])
        return y

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [3]:
net = SlowFast(101, True)
f"{count_trainable_parameters(net):,}"

Using cache found in /home/slauva/.cache/torch/hub/facebookresearch_pytorchvideo_main


'232,805'

In [4]:
net.backbone.blocks

ModuleList(
  (0): MultiPathWayWithFuse(
    (multipathway_blocks): ModuleList(
      (0): ResNetBasicStem(
        (conv): Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
        (pool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=[0, 1, 1], dilation=1, ceil_mode=False)
      )
      (1): ResNetBasicStem(
        (conv): Conv3d(3, 8, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False)
        (norm): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
        (pool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=[0, 1, 1], dilation=1, ceil_mode=False)
      )
    )
    (multipathway_fusion): FuseFastToSlow(
      (conv_fast_to_slow): Conv3d(8, 16, kernel_size=(7, 1, 1), stride=(4, 1, 1), padding=(3, 0, 0), bias=False)
    

In [84]:
inp = torch.rand((8, 3, 32, 224, 224))
out = net(inp)

out.argmax(dim=1)

tensor([100, 100,  25,  18,  78,  81,  41,   8])

In [79]:
net2 = SlowFast(101, False)
f"{count_trainable_parameters(net2):,}"

Using cache found in /home/slauva/.cache/torch/hub/facebookresearch_pytorchvideo_main


'33,877,293'