In [2]:
# Suppress all Python warnings for this session
import warnings
warnings.filterwarnings("ignore")

# OS tools
import os
import typing
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

# Tables, arrays, and plotters 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange

# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import F1Score

# Video Processing
from torchvision.io import read_video
from torchvision.transforms import v2
import torchvision.transforms as tt
import torchvision.models as models

# Transformers
from transformers import TimesformerModel

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer, strategies
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities import grad_norm
from pytorch_lightning.loggers import TensorBoardLogger

In [3]:
class TimeSformerClassifier(nn.Module):
    """Video classifier: a pretrained TimeSformer backbone followed by a linear head.

    The backbone is loaded with ``TimesformerModel.from_pretrained`` and the head
    maps the backbone's ``config.hidden_size`` features to ``n_outputs`` logits.

    Args:
        n_outputs: Number of output logits produced by the linear head.
        freeze: If True, all backbone parameters get ``requires_grad=False`` so
            only the linear head is trainable.
        pretrained_name: Checkpoint id or local path handed to
            ``TimesformerModel.from_pretrained``. Defaults to the checkpoint this
            class originally hard-coded, so existing callers are unaffected.
    """

    def __init__(self, n_outputs, freeze=False,
                 pretrained_name="facebook/timesformer-base-finetuned-k400"):
        super().__init__()
        self.n_outputs = n_outputs
        self.freeze = freeze
        self.pretrained_name = pretrained_name

        # Pretrained TimeSformer backbone (no task head of its own).
        self.backbone = TimesformerModel.from_pretrained(pretrained_name)

        # Linear classification head sized from the backbone's hidden dimension.
        self.classifier = nn.Linear(self.backbone.config.hidden_size, n_outputs)

        if freeze:
            self._freeze_layers()

    def _freeze_layers(self):
        """Freeze the backbone's parameters and keep the classifier head trainable."""
        # Only the backbone is frozen; the head is handled separately below.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # The head is trainable by default; set it explicitly so this method
        # guarantees a trainable head even if it was frozen earlier.
        for param in self.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return classification logits of shape ``(batch, n_outputs)``.

        Args:
            x: Video tensor passed as the backbone's ``pixel_values``. Assumed to
                follow Timesformer's (batch, frames, channels, height, width)
                layout; the notebook's smoke test uses (8, 32, 3, 224, 224).
        """
        outputs = self.backbone(pixel_values=x)
        # The first token of the final hidden states is used as the pooled
        # clip representation fed to the head.
        pooled_output = outputs.last_hidden_state[:, 0]
        return self.classifier(pooled_output)

def count_trainable_parameters(model):
    """Return the total element count of all parameters with ``requires_grad=True``.

    Frozen parameters are skipped, so for a model with a frozen backbone this
    counts only the trainable part.
    """
    total = 0
    for parameter in model.parameters():
        if not parameter.requires_grad:
            continue
        total += parameter.numel()
    return total

In [4]:
# Frozen backbone: only the linear head should be trainable, i.e. the count
# below is hidden_size * 101 + 101. 101 outputs presumably matches the target
# dataset's class count — TODO confirm.
net = TimeSformerClassifier(n_outputs=101, freeze=True)
count_trainable_parameters(net)

77669

In [6]:
# Smoke test on random input, assumed (batch, frames, channels, height, width).
x = torch.rand((8, 32, 3, 224, 224))
net(x).shape

torch.Size([8, 101])