# Video Classification using Space-Time Attention (TimeSFormer)

Reference: "Is Space-Time Attention All You Need for Video Understanding?" (TimeSFormer), Bertasius et al., NeurIPS 2021. [https://arxiv.org/abs/2102.05095]


In [1]:
import torch 
from torch import nn 

import torchvision 

In [2]:
from vision.transformers.blocks import MLP
from vision.transformers.attention import Attention

class EncoderLayer(nn.Module):
    """
    Encoder layer block for ViT
    """
    def __init__(
        self, 
        num_heads: int,
        num_channels: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super(EncoderLayer, self).__init__()
        self.norm1, self.norm2, self.norm3 = (
            nn.LayerNorm(num_channels),  nn.LayerNorm(num_channels), nn.LayerNorm(num_channels)
        )
        self.mha_space = Attention(dropout, num_heads, num_channels, num_groups)
        self.mha_time = Attention(dropout, num_heads, num_channels, num_groups)
        self.mlp = MLP(num_channels, d_linear, dropout, num_linear_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        #  assumes x: b, time, num_patches + 1, d_model
        #  at each frame, we have num_patches + 1 patches 
        #  that have d_model dimensionality
        batch, time, patches, channels = x.shape
        h = x.permute(0, 2, 1, 3).reshape(batch*patches, time, channels)
        h = self.mha_time(self.norm1(h))
        h = h.reshape(batch, patches, time, channels).permute(0, 2, 1, 3) + x
        h2 = h.reshape(batch*time, patches, channels)
        h2 = self.mha_space(self.norm2(h2))
        h2  = h2.reshape(batch, time, patches, channels) + h
        return self.mlp(self.norm3(h2)) + h2


class Encoder(nn.Module):
    def __init__(
        self, 
        num_heads: int,
        num_channels: int,
        num_layers: int,
        d_linear: int,
        num_linear_layers: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
        is_masked: bool = False
    ):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(
                num_heads, num_channels, d_linear, num_linear_layers, num_groups, dropout, is_masked
            ) for _ in range(num_layers)
        ])
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x 

In [None]:
class MLPClassicationHead(nn.Module):
    """
    Class implementation of a position wise MLP
    """
    def __init__(
        self,
        num_classes: int,
        num_channels:int,
        d_ff: int,
        num_layers: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super(MLPClassicationHead, self).__init__()

        layers = []
        layers.append(nn.Linear(num_channels, d_ff, bias=True))
        for i in range(1, num_layers - 1):
            layers.append(nn.Linear(d_ff, d_ff, bias=True))
            layers.append(nn.ReLU())
        
        layers.append(nn.Linear(d_ff, num_classes, bias=True))
        self.mlp_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp_layers(x)

class TimeSFormer(nn.Module):
    """
    Cloned from Vision Transformer -- need to adapt to TimeSFormer architecture
    """
    def __init__(
        self,
        num_classes,
        num_heads: int,
        d_model: int,
        d_mlp: int,
        patch_size: int = 16,
        frames: int = 8,
        image_size: tuple[int] = 32,
        num_encoder_layers: int = 2,
        encoder_mlp_depth: int = 2,
        classification_mlp_depth: int = 2,
        num_groups: int = 8,
        dropout: float = 0.1,
    ):
        super(TimeSFormer, self).__init__()
        self.d_model, self.patch_size, self.image_size = d_model, patch_size, image_size
        self.num_frames = frames
        self.n_patches = (image_size // patch_size) ** 2  # assumes square image 
        self.linear = nn.Linear((3*patch_size*patch_size), d_model)  # assumes rgb image 
        self.encoder = Encoder(
            num_heads, d_model, num_encoder_layers, d_mlp, encoder_mlp_depth, num_groups, dropout
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, d_model))  # assumes rgb image 
        self.classification_head = MLPClassicationHead(num_classes, d_model, d_mlp, classification_mlp_depth, dropout) # d_ff and depth are different things

        self.pos_embed_space = nn.Parameter(torch.zeros(1, 1, self.n_patches + 1, d_model))
        self.pos_embed_time = nn.Parameter(torch.zeros(1, self.num_frames, 1, d_model))
        nn.init.trunc_normal_(self.pos_embed_space, std=0.02)
        nn.init.trunc_normal_(self.pos_embed_time, std=0.02)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        forward pass for our vit classifier. takes in raw images and outputs a probability distribution over classes
        """
        batch_size, frames = x.shape[0], x.shape[2]
        x = self.linear(self._patchify(x))
        print(self.cls_token.shape)
        cls = self.cls_token.expand(batch_size, frames, -1, -1)
        x = torch.cat((cls, x), 2)
        print(f"pos_embed_space: {self.pos_embed_space.shape}")
        print(f"x: {x.shape}")
        x = x + self.pos_embed_space[:, : x.size(1), :].to(x.device)
        x = self.encoder(x)
        return self.classification_head(x[:,0,:])
    
    def _patchify(self, x: torch.Tensor) -> torch.Tensor:
        """
        Splits a batch of videos into non-overlapping patches.

        Args:
            x (torch.Tensor): Input tensor of shape [B, C, F, H, W]

        Returns:
            torch.Tensor: Patchified tensor of shape [B, F, NUM_P, DIM_P],
                          where DIM_P = channels * patch_size * patch_size.
        """
        batch_size, channels, frames, _, _ = x.shape
        n_patch_side = self.image_size // self.patch_size
        x = x.permute(0, 2, 1, 3, 4)
        x = x.reshape(
            batch_size,
            frames,
            channels,
            n_patch_side,
            self.patch_size,
            n_patch_side,
            self.patch_size,
        )
        x = x.permute(0, 1, 3, 5, 2, 4, 6)
        return x.reshape(batch_size, frames, -1, channels * self.patch_size * self.patch_size)
        
    # Removed _positional_embedding, as we now use a learnable positional embedding


In [19]:
batch_size = 2
channels = 3
frames = 8
height = 224
width = 224

dummy_videos = torch.randn(batch_size, channels, frames, height, width)
print(f"Input shape: {dummy_videos.shape}")

# Initialize TimeSFormer with correct parameters based on the class definition
timesformer = TimeSFormer(
    num_classes=10,
    num_heads=4,
    d_model=128,
    d_mlp=256,
    patch_size=16,
    image_size=224,
    num_encoder_layers=2,
    encoder_mlp_depth=2,
    classification_mlp_depth=2,
    num_groups=8,
    dropout=0.1
)

print(f"Model parameters: {sum(p.numel() for p in timesformer.parameters()):,}")

# Run forward pass
with torch.no_grad():
    output = timesformer(dummy_videos)
    print(f"Output shape: {output.shape}")
    print(f"Output probabilities sum: {torch.softmax(output, dim=-1).sum(dim=-1)}")


Input shape: torch.Size([2, 3, 8, 224, 224])
Model parameters: 1,349,514
torch.Size([1, 1, 1, 128])
pos_embed_space: torch.Size([1, 197, 128])
x: torch.Size([2, 8, 197, 128])


RuntimeError: The size of tensor a (197) must match the size of tensor b (8) at non-singleton dimension 2