# Video Classification using Space-Time Attention (TimeSFormer)

Reference: "Is Space-Time Attention All You Need for Video Understanding?" (TimeSFormer), Bertasius et al., NeurIPS 2021. [https://arxiv.org/abs/2102.05095]


In [1]:
import torch 
from torch import nn 


from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import UCF101

from vision.transformers.blocks import MLP
from vision.transformers.attention import Attention

import lightning as L

## Space-Time Encoder 

The space time encoder is designed to follow the divided space-time attention module defined in the TimeSFormer paper.

![Space-Time Attention](images/st_attention.png)

*Figure: Divided space-time attention mechanism as described in the TimeSFormer paper*

In [2]:
class EncoderLayer(nn.Module):
    """
    Encoder layer block for ViT
    """
    def __init__(
        self, 
        num_heads: int,
        num_channels: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super(EncoderLayer, self).__init__()
        self.norm1, self.norm2, self.norm3 = (
            nn.LayerNorm(num_channels),  nn.LayerNorm(num_channels), nn.LayerNorm(num_channels)
        )
        self.mha_space = Attention(dropout, num_heads, num_channels, num_groups)
        self.mha_time = Attention(dropout, num_heads, num_channels, num_groups)
        self.mlp = MLP(num_channels, d_linear, dropout, num_linear_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, time, patches, channels = x.shape
        # convert to [B, patches, time, channels] to compute attention across time dimension
        h = x.permute(0, 2, 1, 3).reshape(batch*patches, time, channels) 
        h = self.mha_time(self.norm1(h))  # 
        h = h.reshape(batch, patches, time, channels).permute(0, 2, 1, 3) + x
        h2 = h.reshape(batch*time, patches, channels)
        h2 = self.mha_space(self.norm2(h2))
        h2  = h2.reshape(batch, time, patches, channels) + h
        return self.mlp(self.norm3(h2)) + h2


class Encoder(nn.Module):
    def __init__(
        self, 
        num_heads: int,
        num_channels: int,
        num_layers: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(
                num_heads, num_channels, d_linear, num_linear_layers, num_groups, dropout, is_masked
            ) for _ in range(num_layers)
        ])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x 


class MLPClassicationHead(nn.Module):
    """
    MLP based classification head
    """
    def __init__(
        self,
        num_classes: int,
        num_channels:int,
        d_ff: int,
        num_layers: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super(MLPClassicationHead, self).__init__()

        layers = []
        layers.append(nn.Linear(num_channels, d_ff, bias=True))
        for i in range(1, num_layers - 1):
            layers.append(nn.Linear(d_ff, d_ff, bias=True))
            layers.append(nn.ReLU())
        
        layers.append(nn.Linear(d_ff, num_classes, bias=True))
        self.mlp_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp_layers(x)

## TimeSFormer Model 

Much like a ViT, the TimeSFormer model patchifies input frames and adds positional embeddings to the input before passing it to the Encoder. In the space-time case, we want to encode positional information about patches through both space and time, so we learn two separate embeddings. Like any other transformer based classifier, we append a cls_token, and run the encoder outputs through a classification MLP at the end. 

In [3]:
class TimeSFormer(nn.Module):
    """
    Cloned from Vision Transformer -- need to adapt to TimeSFormer architecture
    """
    def __init__(
        self,
        num_classes,
        num_heads: int,
        d_model: int,
        d_mlp: int,
        patch_size: int = 16,
        frames: int = 8,
        image_size: tuple[int] = 224,
        num_encoder_layers: int = 2,
        encoder_mlp_depth: int = 2,
        classification_mlp_depth: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
    ):
        super(TimeSFormer, self).__init__()
        self.d_model, self.patch_size, self.image_size = d_model, patch_size, image_size
        self.num_frames = frames
        self.n_patches = (image_size // patch_size) ** 2
        self.linear = nn.Linear((3*patch_size*patch_size), d_model)
        self.encoder = Encoder(
            num_heads, d_model, num_encoder_layers, d_mlp, encoder_mlp_depth, num_groups, dropout
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, d_model))
        self.classification_head = MLPClassicationHead(num_classes, d_model, d_mlp, classification_mlp_depth, dropout) # d_ff and depth are different things

        self.pos_embed_space = nn.Parameter(torch.zeros(1, 1, self.n_patches + 1, d_model))
        self.pos_embed_time = nn.Parameter(torch.zeros(1, self.num_frames, 1, d_model))
        
        nn.init.trunc_normal_(self.pos_embed_space, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_time, std=0.02)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        forward pass for our vit classifier. takes in raw images and outputs a probability distribution over classes
        """
        batch_size, frames = x.shape[0], x.shape[2]
        x = self.linear(self._patchify(x))
        cls = self.cls_token.expand(batch_size, frames, -1, -1)
        x = torch.cat((cls, x), 2)
        x = x + self.pos_embed_space + self.pos_embed_time
        x = self.encoder(x)
        cls_token = x[:, :, 0, :]
        return self.classification_head(cls_token.mean(dim=1))
    
    def _patchify(self, x: torch.Tensor) -> torch.Tensor:
        """
        Splits a batch of videos into non-overlapping patches.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, F, H, W]

        Returns:
            torch.Tensor: Patchified tensor of shape [B, F, NUM_P, DIM_P],
                          where DIM_P = channels * patch_size * patch_size.
        """
        batch_size, channels, frames, _, _ = x.shape
        n_patch_side = self.image_size // self.patch_size
        x = x.permute(0, 2, 1, 3, 4)
        x = x.reshape(
            batch_size,
            frames,
            channels,
            n_patch_side,
            self.patch_size,
            n_patch_side,
            self.patch_size,
        )
        x = x.permute(0, 1, 3, 5, 2, 4, 6)
        return x.reshape(batch_size, frames, -1, channels * self.patch_size * self.patch_size)

## Lightning module

We will use pytorch lightning to reduce boiler plate (there's a lot in previous notebooks, despite the centralized modules in `/src`)

In [18]:
class LightningTimeSformer(L.LightningModule):
    def __init__(self, num_classes, num_heads, d_model, d_mlp):
        super().__init__()
        self.model = TimeSFormer(
            num_classes=num_classes,
            num_heads=num_heads,
            d_model=d_model,
            d_mlp=d_mlp
        )

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs)
        loss = torch.nn.functional.cross_entropy(output, target)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, target = batch
        output = self(inputs)
        loss = torch.nn.functional.cross_entropy(output, target)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=0.01)

## Load UCF101 Action Recognition dataset 

We use the UCF101 dataset which contains 13,320 videos from 101 action categories. This dataset is commonly used for benchmarking video action recognition models, such as basketball shooting, biking, diving, golf swinging, horse riding, and playing musical instruments.


In [9]:
train_dataset = UCF101(
    root='./data/UCF-101',
    annotation_path='./data/ucfTrainTestlist',
    frames_per_clip=8,
    train=True,
)

test_dataset = UCF101(
    root='./data/UCF-101',
    annotation_path='./data/ucfTrainTestlist',
    frames_per_clip=8,
    train=False,
)

100%|██████████| 833/833 [10:41<00:00,  1.30it/s]
100%|██████████| 833/833 [06:01<00:00,  2.30it/s]
100%|██████████| 833/833 [05:53<00:00,  2.36it/s]


In [10]:
import torch.nn.functional as F


def collate_ucf101(batch):
    # batch: list of (video, label, index) where label is detection labels 
    # and index is the index of the class for recognition
    xs, ys = [], []
    for v, _, l in batch:
        # v: T, H, W, C  (uint8)
        v = v.permute(0, 3, 1, 2)            # -> T, C, H, W
        v = v.float() / 255.0
        v = F.interpolate(v, size=(224, 224), mode='bilinear', align_corners=False)  # resize frames
        v = v.permute(1, 0, 2, 3).contiguous()  # -> C, F, H, W
        xs.append(v.clone())                  # new storage
        ys.append(int(l))
    return torch.stack(xs, 0), torch.tensor(ys, dtype=torch.long)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,          # try 0 if debugging
    collate_fn=collate_ucf101,
    pin_memory=True,
)

test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_ucf101)

In [19]:
lightning_timesformer = LightningTimeSformer(num_classes=101, num_heads=4, d_model=512, d_mlp=512)
trainer = L.Trainer(max_epochs=1)
trainer.fit(model=lightning_timesformer, train_dataloaders=train_dataloader)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/aryamanpandya/vision/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params | Mode 
----------------------------------------------
0 | model | TimeSFormer | 18.7 M | train
----------------------------------------------
18.7 M    Trainable params
0         Non-trainable params
18.7 M    Total params
74.722    Total estimated model params size (MB)
60        Modules in train mode
0         Modules in eval mode


Epoch 0:   0%|          | 0/859661 [00:00<?, ?it/s] 



Epoch 0:   0%|          | 0/100 [06:03<?, ?it/s]19, 18.06it/s, v_num=1]
Epoch 0:   0%|          | 0/100 [06:03<?, ?it/s]
Epoch 0:   0%|          | 0/100 [06:03<?, ?it/s]
Epoch 0:   0%|          | 0/100 [06:03<?, ?it/s]
Epoch 0:   0%|          | 2199/859661 [01:27<9:27:43, 25.17it/s, v_num=1]


Detected KeyboardInterrupt, attempting graceful shutdown ...
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x79417911c070>>
Traceback (most recent call last):
  File "/home/aryamanpandya/vision/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x79417911c070>>
Traceback (most recent call last):
  File "/home/aryamanpandya/vision/.venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x79417911c070>>
Traceback (most recent call last):
  File "/home/aryamanpandya/vision/.venv/lib/python3.10/site-packages/ipykernel/

NameError: name 'exit' is not defined