# Video Classification using Space-Time Attention (TimeSFormer)

Reference: "Is Space-Time Attention All You Need for Video Understanding?" (TimeSformer), Bertasius, Wang, and Torresani, ICML 2021. [https://arxiv.org/abs/2102.05095]


In [1]:
import torch 
from torch import nn 

import torchvision 

In [None]:
from vision.transformers.blocks import MLP
from vision.transformers.attention import Attention

class EncoderLayer(nn.Module):
    """
    Pre-norm encoder block with separate temporal and spatial attention modules.

    Three residual sub-layers are applied in order: temporal attention,
    spatial attention, then the MLP. Each sub-layer normalises its own input
    with a dedicated LayerNorm.

    NOTE(review): both attention modules are called on the token sequence
    exactly as it is passed in. Restricting `mha_time` to tokens that share a
    spatial location and `mha_space` to tokens that share a frame (the
    "divided space-time attention" scheme) would require regrouping the
    tokens around each call, which depends on how `Attention` treats leading
    dimensions -- TODO confirm and implement.

    Args:
        num_heads: forwarded to both `Attention` modules.
        num_channels: size of the last (feature) dimension; used for the
            LayerNorms and forwarded to `Attention` and `MLP`.
        d_linear: forwarded to `MLP`.
        num_linear_layers: forwarded to `MLP`.
        num_groups: forwarded to both `Attention` modules.
        dropout: forwarded to `Attention` and `MLP`.
        is_masked: accepted so the signature matches `Encoder`; currently unused.
    """
    def __init__(
        self, 
        num_heads: int,
        num_channels: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super(EncoderLayer, self).__init__()
        # One LayerNorm per residual sub-layer (time attention, space attention, MLP).
        self.norm_time = nn.LayerNorm(num_channels)
        self.norm1, self.norm2 = nn.LayerNorm(num_channels),  nn.LayerNorm(num_channels)
        self.mha_time = Attention(dropout, num_heads, num_channels, num_groups)
        self.mha_space = Attention(dropout, num_heads, num_channels, num_groups)
        self.mlp = MLP(num_channels, d_linear, dropout, num_linear_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply temporal attention, spatial attention and the MLP, each wrapped
        in a residual connection. The output has the same shape as `x`.
        """
        # The previous implementation called `self.mha`, which is never
        # defined on this class; use the two attention modules that are.
        h = self.mha_time(self.norm_time(x)) + x
        h = self.mha_space(self.norm1(h)) + h
        return self.mlp(self.norm2(h)) + h


class Encoder(nn.Module):
    """
    A stack of `num_layers` `EncoderLayer` blocks applied one after another.

    Every constructor argument except `num_layers` is handed unchanged to
    each `EncoderLayer`.
    """
    def __init__(
        self,
        num_heads: int,
        num_channels: int,
        num_layers: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super().__init__()
        blocks = nn.ModuleList()
        for _ in range(num_layers):
            blocks.append(
                EncoderLayer(
                    num_heads=num_heads,
                    num_channels=num_channels,
                    d_linear=d_linear,
                    num_linear_layers=num_linear_layers,
                    num_groups=num_groups,
                    dropout=dropout,
                    is_masked=is_masked,
                )
            )
        self.layers = blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed `x` through every layer in order and return the final output."""
        out = x
        for block in self.layers:
            out = block(out)
        return out

In [None]:
class MLPClassicationHead(nn.Module):
    """
    MLP classification head mapping a feature vector to per-class logits.

    Layout: Linear(num_channels, d_ff) -> ReLU -> Dropout, followed by
    `num_layers - 2` hidden Linear(d_ff, d_ff) -> ReLU -> Dropout groups, and
    a final Linear(d_ff, num_classes). No activation follows the last layer,
    so the output is raw logits.

    Args:
        num_classes: size of the output (logit) dimension.
        num_channels: size of the input feature dimension.
        d_ff: width of the hidden layers.
        num_layers: number of Linear layers. At least two are always built
            (input projection and output projection), even for values < 2.
        dropout: dropout probability applied after every hidden activation.
    """
    def __init__(
        self,
        num_classes: int,
        num_channels:int,
        d_ff: int,
        num_layers: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super(MLPClassicationHead, self).__init__()

        # The input projection needs its own non-linearity; without it the
        # default two-layer head reduces to a single linear map.
        layers = [
            nn.Linear(num_channels, d_ff, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        for _ in range(num_layers - 2):
            layers.extend([
                nn.Linear(d_ff, d_ff, bias=True),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])

        layers.append(nn.Linear(d_ff, num_classes, bias=True))
        self.mlp_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits; the last dimension of `x` must be `num_channels`."""
        return self.mlp_layers(x)

class TimeSFormer(nn.Module):
    """
    Video classifier built from a patch embedding, learnable spatial and
    temporal position embeddings, a transformer encoder and an MLP head that
    reads the class token.

    Input clips are expected as (batch, channels, height, width, frames) with
    3 channels and height == width == `image_size`.

    Args:
        num_classes: number of output classes (size of the logit vector).
        num_heads: forwarded to `Encoder`.
        d_model: token embedding size.
        d_mlp: hidden width used by both the encoder MLPs and the
            classification head.
        patch_size: side length, in pixels, of each square patch.
        image_size: side length, in pixels, of the (square) input frames;
            must be a multiple of `patch_size`.
        num_encoder_layers: number of encoder layers.
        encoder_mlp_depth: forwarded to `Encoder` as `num_linear_layers`.
        classification_mlp_depth: number of layers in the classification head.
        num_groups: forwarded to `Encoder`.
        dropout: forwarded to `Encoder` and the classification head.
        num_frames: maximum number of frames per clip; sizes the temporal
            position embedding. Shorter clips use a prefix of it.

    Raises:
        ValueError: if `image_size` is not a multiple of `patch_size`.
    """
    def __init__(
        self,
        num_classes,
        num_heads: int,
        d_model: int,
        d_mlp: int,
        patch_size: int = 16,
        image_size: int = 32,
        num_encoder_layers: int = 2,
        encoder_mlp_depth: int = 2,
        classification_mlp_depth: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        num_frames: int = 16,
    ):
        super(TimeSFormer, self).__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be a multiple of patch_size ({patch_size})"
            )
        self.d_model, self.patch_size, self.image_size = d_model, patch_size, image_size
        self.num_frames = num_frames
        self.n_patches = (image_size // patch_size) ** 2  # patches per frame; assumes square frames
        self.linear = nn.Linear((3*patch_size*patch_size), d_model)  # assumes rgb frames
        self.encoder = Encoder(
            num_heads, d_model, num_encoder_layers, d_mlp, encoder_mlp_depth, num_groups, dropout
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.classification_head = MLPClassicationHead(num_classes, d_model, d_mlp, classification_mlp_depth, dropout)

        # Spatial embedding: index 0 is the class token, 1..n_patches the
        # patch positions inside a frame (shared by every frame).
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Temporal embedding: one vector per frame index (shared by every patch).
        self.time_embed = nn.Parameter(torch.zeros(1, num_frames, d_model))
        nn.init.trunc_normal_(self.time_embed, std=0.02)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Classify a batch of clips.

        Args:
            x: tensor of shape (batch, 3, image_size, image_size, frames).

        Returns:
            Unnormalised class logits of shape (batch, num_classes).

        Raises:
            ValueError: if the clip has more frames than `num_frames`, or its
                frame size does not match `image_size`.
        """
        batch_size = x.shape[0]
        tokens = self.linear(self._patchify(x))  # (B, F, N, d_model)
        frames = tokens.size(1)
        if frames > self.time_embed.size(1):
            raise ValueError(
                f"clip has {frames} frames but the model was built for at most "
                f"{self.time_embed.size(1)}"
            )
        spatial = self.pos_embed[:, 1:, :].unsqueeze(1)         # (1, 1, N, d_model)
        temporal = self.time_embed[:, :frames, :].unsqueeze(2)  # (1, F, 1, d_model)
        tokens = tokens + spatial + temporal
        tokens = tokens.flatten(1, 2)                           # (B, F*N, d_model), frame-major
        cls = self.cls_token.expand(batch_size, -1, -1) + self.pos_embed[:, :1, :]
        tokens = torch.cat((cls, tokens), 1)
        tokens = self.encoder(tokens)
        return self.classification_head(tokens[:,0,:])
    
    def _patchify(self, x: torch.Tensor) -> torch.Tensor:
        """
        Cut every frame into non-overlapping square patches.

        Args:
            x: tensor of shape (batch, channels, height, width, frames).

        Returns:
            Tensor of shape (batch, frames, n_patches, channels * patch_size**2).
            Patches are ordered row-major within a frame; each patch vector is
            flattened in (channel, row, column) order.

        Raises:
            ValueError: if height or width differs from `image_size`.
        """
        batch_size, channels, height, width, frames = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(
                f"expected {self.image_size}x{self.image_size} frames, got {height}x{width}"
            )
        n_patch_side = self.image_size // self.patch_size
        x = x.reshape(
            batch_size,
            channels,
            n_patch_side,
            self.patch_size,
            n_patch_side,
            self.patch_size,
            frames
        )
        # -> (B, F, patch_row, patch_col, C, row_in_patch, col_in_patch) so the
        # flattened patch vector ends up on the last axis, where nn.Linear acts.
        x = x.permute(0, 6, 2, 4, 1, 3, 5)
        return x.reshape(
            batch_size,
            frames,
            n_patch_side * n_patch_side,
            channels * self.patch_size * self.patch_size,
        )
        
    # Removed _positional_embedding, as we now use a learnable positional embedding


Load the UCF101 dataset with torchvision.

In [2]:
from torchvision.datasets import UCF101
from torchvision import transforms

import os


def _clip_to_tchw(clip: torch.Tensor) -> torch.Tensor:
    """Reorder a (T, H, W, C) clip to (T, C, H, W), the layout torchvision's tensor transforms expect."""
    return clip.permute(0, 3, 1, 2)


def _clip_to_chwt(clip: torch.Tensor) -> torch.Tensor:
    """Reorder a (T, C, H, W) clip to (C, H, W, T), the layout `TimeSFormer._patchify` unpacks."""
    return clip.permute(1, 2, 3, 0)


# Per-clip preprocessing. torchvision's UCF101 passes `transform` the whole
# clip as a uint8 tensor laid out (T, H, W, C), so every step works on
# tensors; `ToTensor` only accepts PIL images / ndarrays and would reject it.
# Named functions (rather than lambdas) keep the pipeline picklable for
# multi-worker data loading.
video_transform = transforms.Compose([
    transforms.Lambda(_clip_to_tchw),
    transforms.ConvertImageDtype(torch.float32),  # uint8 [0, 255] -> float [0, 1]
    transforms.Resize((128, 171)),  # resize every frame to exactly 128x171 (aspect ratio is not preserved)
    transforms.CenterCrop(112),     # centre crop to 112x112; the model's image_size must match
    # NOTE(review): presumably the Kinetics-400 statistics used by
    # torchvision's video models -- confirm they suit this model.
    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
    transforms.Lambda(_clip_to_chwt),
])

# Root directories holding the UCF101 videos and the train/test split files
ucf101_root = "./data/ucf101"
ucf101_anns = "./data/ucf101_anns"

# torchvision's UCF101 does not download anything: both directories must
# already exist. Fail early with an actionable message instead of an error
# from deep inside the dataset constructor.
for required_dir in (ucf101_root, ucf101_anns):
    if not os.path.isdir(required_dir):
        raise FileNotFoundError(
            f"UCF101 directory not found: {required_dir!r}. torchvision's UCF101 "
            "does not download the data; place the extracted videos under "
            "ucf101_root and the train/test split files under ucf101_anns."
        )

# Settings shared by the train and test datasets
ucf101_kwargs = dict(
    root=ucf101_root,
    annotation_path=ucf101_anns,
    frames_per_clip=16,
    step_between_clips=1,
    transform=video_transform,
)

# Training clips (the dataset's default fold)
ucf_train = UCF101(train=True, **ucf101_kwargs)

# Test clips (the dataset's default fold)
ucf_test = UCF101(train=False, **ucf101_kwargs)


FileNotFoundError: [Errno 2] No such file or directory: './data/ucf101'