# TimeSformer: Video Classification with Divided Space-Time Attention

**Paper**: [Is Space-Time Attention All You Need for Video Understanding?](https://arxiv.org/abs/2102.05095) — Bertasius et al., ICML 2021

A from-scratch PyTorch implementation of TimeSformer — a **pure transformer** for video that extends ViT by factorizing attention into temporal and spatial components.

---

## Why TimeSformer?

Traditional video models (C3D, I3D, SlowFast) rely on 3D convolutions. TimeSformer shows that **self-attention alone** can match or beat them — but only if you decompose the attention correctly.

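The paper tests **5 attention schemes** and finds that **Divided Space-Time Attention (T+S)** gives the best accuracy-efficiency tradeoff. In the table, N is the number of patches per frame and F is the number of frames (called `T` in the code below):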

| Scheme | How It Works | Complexity | K400 Top-1 |
|--------|-------------|-----------|------------|
| Space Only (S) | Attend within same frame | O(FN²) | 75.2% |
| Joint (ST) | Attend to all patches, all frames | O((NF)²) | 77.9% |
| **Divided (T+S)** | **Temporal then Spatial** | **O(NF(N+F))** | **78.0%** |
| Sparse Local-Global (L+G) | Local then global subsampled | ~approx full | 75.7% |
| Axial (T+W+H) | 3 separate 1D attentions | O(NF(F+W+H)) | 76.7% |

---

### Architecture Overview

```
Video (B, T, C, H, W)
    │
    ▼
Patch Embedding ──── Conv2d(3, 768, k=16, s=16)
    │
    ▼
(B, T, 196, 768)  ── + time_embed + space_embed
    │
    ▼
Prepend CLS token → (B, T+1, N, D)
    │
    ▼
╔═══════════════════════════════════════╗
║       ×12 TimeSformerBlock            ║
║                                       ║
║  1. Temporal Attn  (B*N, T, D)        ║
║       └─ + residual + LayerNorm       ║
║  2. Spatial Attn   (B*T, N, D)        ║
║       └─ + residual + LayerNorm       ║
║  3. MLP (dim→4*dim→dim)              ║
║       └─ + residual + LayerNorm       ║
╚═══════════════════════════════════════╝
    │
    ▼
CLS output x[:, 0, 0] → LayerNorm → Linear → logits
```
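
The patch-embedding step at the top of the diagram can be checked in isolation. The sketch below is illustrative only: it uses a small embedding width (`D = 32`, an assumption made to keep it fast) instead of 768, and folds the frame axis into the batch so that a plain `nn.Conv2d` patchifies every frame at once.

```python
import torch
import torch.nn as nn

B, T, C, H, W = 2, 8, 3, 224, 224
D, P = 32, 16  # D is shrunk from 768 for speed (assumption); P is the patch size

patch_embed = nn.Conv2d(C, D, kernel_size=P, stride=P)

video = torch.randn(B, T, C, H, W)
x = video.reshape(B * T, C, H, W)      # fold frames into the batch
x = patch_embed(x)                     # (B*T, D, H/P, W/P) = (16, 32, 14, 14)
x = x.flatten(2).transpose(1, 2)       # (B*T, N, D) with N = 14 * 14 = 196
x = x.reshape(B, T, -1, D)             # unfold frames: (B, T, N, D)

print(x.shape)  # torch.Size([2, 8, 196, 32])
```

Each of the 196 tokens per frame is a linear projection of one non-overlapping 16×16 patch, which is why a strided convolution is enough here.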


---
## 1. Setup


In [23]:
!pip install torchvision timm einops



## 2. Imports

Key dependencies:
- **`einops`**: the `rearrange` function is the backbone of divided attention. It reshapes 4D video tensors into 3D for standard `nn.MultiheadAttention`.
- **`timm`**: provides pretrained ViT-Base weights for transfer learning (spatial attention + MLP initialized from ImageNet).


In [24]:
# Standard library
import os

# Third-party
import timm
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"

print(device)

cuda


---
## 3. Video Frame Dataset

Loads videos stored as folders of frame images. Uniformly samples `num_frames=8` frames using `torch.linspace`.

```
Expected directory structure:

root_dir/
├── class_1/
│   ├── video_001/
│   │   ├── frame_0001.jpg
│   │   ├── frame_0002.jpg
│   │   └── ...
│   └── video_002/
│       └── ...
└── class_2/
    └── ...

Output per sample: (T, C, H, W) = (8, 3, 224, 224)
```

**Uniform sampling strategy**: If the video has 120 frames and we need 8, we pick frames at indices `[0, 17, 34, 51, 68, 85, 102, 119]` — spread evenly across the full duration.
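
The index computation can be tried on its own. This is a minimal sketch; the helper name `uniform_indices` is hypothetical and not part of the dataset class.

```python
import torch

def uniform_indices(total_frames, num_frames=8):
    """Evenly spaced frame indices, first and last frame included."""
    return torch.linspace(0, total_frames - 1, num_frames).long().tolist()

print(uniform_indices(120))  # [0, 17, 34, 51, 68, 85, 102, 119]
print(uniform_indices(3))    # [0, 0, 0, 0, 1, 1, 1, 2]
```

The second call shows what `linspace` does by itself when a clip has fewer frames than requested: indices repeat, and they stay in temporal order.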


In [26]:
class VideoFrameDataset(Dataset):
    def __init__(self, root_dir, num_frames=8, image_size=224):
        self.root_dir = root_dir
        self.num_frames = num_frames

        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        self.samples = []
        for cls in self.classes:
            cls_path = os.path.join(root_dir, cls)
            for video in os.listdir(cls_path):
                video_path = os.path.join(cls_path, video)
                self.samples.append((video_path, self.class_to_idx[cls]))

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frames = sorted(os.listdir(video_path))

        # linspace repeats indices when the clip is shorter than num_frames,
        # which pads short clips while keeping the frames in temporal order.
        idxs = torch.linspace(0, len(frames) - 1, self.num_frames).long()
        selected = [frames[i] for i in idxs]

        imgs = []
        for f in selected:
            img = Image.open(os.path.join(video_path, f)).convert("RGB")
            imgs.append(self.transform(img))

        video = torch.stack(imgs)  # (T, C, H, W)
        return video, label

---
## 4. TimeSformerBlock — Divided Space-Time Attention (Core of the Paper)

This is the most important cell. Each block performs attention in **two factorized steps** using `einops.rearrange` to reshape the 4D tensor `(B, T, N, D)` into 3D for standard `nn.MultiheadAttention`.

---

### Step 1: Temporal Attention — "What happened at this position over time?"

```
                    rearrange('b t n d -> (b n) t d')
  (B, T, N, D)  ─────────────────────────────────────→  (B*N, T, D)

  What this does visually:
  ┌──────────────────────────────────────────────────────────────┐
  │                                                              │
  │   Frame 1       Frame 2       Frame 3      ...  Frame T     │
  │  ┌───┬───┐    ┌───┬───┐    ┌───┬───┐         ┌───┬───┐     │
  │  │ . │ . │    │ . │ . │    │ . │ . │         │ . │ . │     │
  │  ├───┼───┤    ├───┼───┤    ├───┼───┤         ├───┼───┤     │
  │  │ . │ Q │    │ . │ K │    │ . │ K │         │ . │ K │     │
  │  └───┴───┘    └───┴───┘    └───┴───┘         └───┴───┘     │
  │                                                              │
  │  For EACH patch position (B*N total):                        │
  │  → gather that position across ALL T frames                  │
  │  → run self-attention over T tokens                          │
  │  → captures MOTION / temporal dynamics at each location      │
  │                                                              │
  └──────────────────────────────────────────────────────────────┘
```

**Concrete example** (B=1, T=8, N=196, D=768):
- Input: `(1, 8, 196, 768)` → rearrange → `(196, 8, 768)`
- 196 independent attention operations, each over 8 temporal tokens
- For example, the patch at (row=3, col=5) attends to the same position in all 8 frames

---

### Step 2: Spatial Attention — "What's in this frame?"

```
                    rearrange('b t n d -> (b t) n d')
  (B, T, N, D)  ─────────────────────────────────────→  (B*T, N, D)

  What this does visually:
  ┌──────────────────────────────────────────────────────────────┐
  │                                                              │
  │   Frame 1 only:                                              │
  │  ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┐               │
  │  │ K₁  │ K₂  │ K₃  │ K₄  │ K₅  │ ... │K₁₉₆│               │
  │  └─────┴─────┴──┬──┴─────┴─────┴─────┴─────┘               │
  │                  │                                           │
  │            Query patch attends to ALL                        │
  │            196 patches in the SAME frame                     │
  │                                                              │
  │  For EACH frame instance (B*T total):                        │
  │  → gather all N patches in that single frame                 │
  │  → run self-attention over N tokens                          │
  │  → captures APPEARANCE / spatial context                     │
  │                                                              │
  └──────────────────────────────────────────────────────────────┘
```

**Concrete example** (B=1, T=8, N=196, D=768):
- Input: `(1, 8, 196, 768)` → rearrange → `(8, 196, 768)`
- 8 independent attention operations, each over 196 spatial tokens
- Each frame's patches attend to all other patches within that same frame

---

### Why This Ordering Works (Temporal → Spatial)

After temporal attention, each patch token is already enriched with temporal context. When spatial attention then runs, it operates on **temporally-informed** representations — so the spatial attention implicitly reasons about motion too.

---

### Complexity Breakdown

```
Joint attention (ST):
  Each token attends to N*F others → O(NF) per token
  Total: (NF+1) tokens × (NF+1) keys ≈ O((NF)²)

Divided attention (T+S):
  Temporal: each token attends to F others  → O(F) per token
  Spatial:  each token attends to N others  → O(N) per token
  Combined: O(N + F) per token
  Total: (NF) tokens × (N + F) keys ≈ O(NF(N+F))

With N=196, F=8:
  Joint:   196×8 = 1,568 keys per token  →  1568² ≈ 2.46M total
  Divided: 196+8 = 204 keys per token    →  1568×204 ≈ 320K total
  
  That's ~7.7× fewer operations!
```
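
The arithmetic above can be reproduced in a few lines. This sketch counts query-key pairs per layer and ignores the CLS token; the helper name `attention_pairs` is hypothetical.

```python
def attention_pairs(N, F):
    """Query-key pairs per layer for joint vs. divided attention (CLS ignored)."""
    tokens = N * F
    joint = tokens * tokens        # every token attends to every token
    divided = tokens * (N + F)     # F temporal keys plus N spatial keys per token
    return joint, divided

joint, divided = attention_pairs(N=196, F=8)
print(joint, divided, round(joint / divided, 1))   # 2458624 319872 7.7

# A longer clip (F=32, chosen only for illustration) widens the gap
joint, divided = attention_pairs(N=196, F=32)
print(round(joint / divided, 1))                   # 27.5
```

The ratio simplifies to NF / (N + F), so the saving grows as either the frame count or the spatial resolution increases.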


In [27]:
class TimeSformerBlock(nn.Module):
  def __init__(self, dim, heads) -> None:
      super().__init__()

      self.temporal_attn = nn.MultiheadAttention(dim, heads, batch_first = True)
      self.spatial_attn = nn.MultiheadAttention(dim, heads, batch_first = True)

      self.norm1 = nn.LayerNorm(dim)
      self.norm2 = nn.LayerNorm(dim)
      self.norm3 = nn.LayerNorm(dim)

      self.mlp = nn.Sequential(
          nn.Linear(dim, dim * 4),
          nn.GELU(),
          nn.Linear(dim * 4, dim),
      )

  def forward(self, x):
      B, T, N, D  = x.shape

      # Temporal Attention
      xt = rearrange(x, 'b t n d -> (b n) t d')
      xt = self.temporal_attn(xt, xt, xt)[0]
      xt = rearrange(xt, '(b n) t d -> b t n d', b = B, n = N)
      x = x + self.norm1(xt)

      # Spatial Attention
      xs = rearrange(x, 'b t n d -> (b t) n d')
      xs = self.spatial_attn(xs, xs, xs)[0]
      xs = rearrange(xs, '(b t) n d -> b t n d', b = B, t = T)
      x = x + self.norm2(xs)

      # MLP
      xm = self.mlp(x)
      x = x + self.norm3(xm)

      return x

### Let's verify the dimensions step by step

This cell traces exact tensor shapes through the block with concrete numbers.


In [None]:
# Dimension verification — trace through TimeSformerBlock
# Using concrete values: B=2, T=8, N=196, D=768

B, T, N, D = 2, 8, 196, 768
x = torch.randn(B, T, N, D)
print(f"Input shape: {x.shape}")  # (2, 8, 196, 768)

print("\n--- TEMPORAL ATTENTION ---")
xt = rearrange(x, 'b t n d -> (b n) t d')
print(f"After rearrange to (B*N, T, D): {xt.shape}")  # (392, 8, 768)
print(f"  → {B}*{N} = {B*N} independent sequences, each of length T={T}")
print(f"  → Each sequence = one patch position tracked across {T} frames")

# After attention (shape doesn't change)
print(f"After MHA: {xt.shape}")  # still (392, 8, 768)

xt_back = rearrange(xt, '(b n) t d -> b t n d', b=B, n=N)
print(f"Rearranged back: {xt_back.shape}")  # (2, 8, 196, 768)

print("\n--- SPATIAL ATTENTION ---")
xs = rearrange(x, 'b t n d -> (b t) n d')
print(f"After rearrange to (B*T, N, D): {xs.shape}")  # (16, 196, 768)
print(f"  → {B}*{T} = {B*T} independent sequences, each of length N={N}")
print(f"  → Each sequence = all patches within one frame")

print(f"After MHA: {xs.shape}")  # still (16, 196, 768)

xs_back = rearrange(xs, '(b t) n d -> b t n d', b=B, t=T)
print(f"Rearranged back: {xs_back.shape}")  # (2, 8, 196, 768)

print("\n--- SUMMARY ---")
print(f"Temporal: {B*N} attention ops, seq_len={T} → captures motion")
print(f"Spatial:  {B*T} attention ops, seq_len={N} → captures appearance")
print(f"Output shape = Input shape = {x.shape}")


---
## 5. TimeSformer — Full Model with CLS Token & Positional Embeddings

### The Three Learnable Parameters — Why These Shapes?

```python
self.cls_token   = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))          # (1, 1, 1, D)
self.time_embed  = nn.Parameter(torch.randn(1, num_frames + 1, embed_dim))   # (1, T+1, D)
self.space_embed = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))  # (1, N+1, D)
```

---

### `cls_token` — shape `(1, 1, 1, D)`

**What**: A learnable "summary" token that aggregates global video information.

**Shape breakdown**:
```
  (  1,       1,          1,       D  )
     │        │           │        │
     │        │           │        └─ embedding dimension (768)
     │        │           └─ 1 spatial slot (expanded to N during forward)
     │        └─ 1 temporal slot (prepended as frame index 0)
     └─ 1 (broadcast across batch B via .expand())
```

**Why `zeros` not `randn`?**  
The CLS token has no prior spatial/temporal meaning — it learns purely from data. Starting at zero ensures it doesn't inject noise before training.

**In forward():**
```python
cls = self.cls_token.expand(B, -1, self.num_patches, -1)  # (1,1,1,D) → (B, 1, N, D)
cls = cls + self.time_embed[:, :1, None, :]               # add time position 0
x = torch.cat((cls, x), dim=1)                            # prepend → (B, T+1, N, D)
```

---

### `time_embed` — shape `(1, T+1, D)`

**What**: Tells the model *when* each frame appears in the video.

**Why T+1?** The `+1` reserves position 0 for the CLS token. Actual frames use positions `1` through `T`.

```
time_embed indices:  [  0  ,   1  ,   2  ,  ...  ,   T  ]
                        │       │       │              │
                       CLS   frame1  frame2  ...    frameT
```

**Applied via broadcasting:**
```python
x = x + self.time_embed[:, 1:T+1, None, :]
#        shape: (1, T, 1, D)
#                         ↑
#                 broadcasts across N patches
#   → SAME time encoding added to ALL patches in a given frame
```

---

### `space_embed` — shape `(1, N+1, D)`

**What**: Tells the model *where* each patch sits spatially.

**Why N+1?** The `+1` reserves index 0 for the CLS token. Patch positions use indices `1` through `N`.

```
space_embed indices:  [  0  ,   1  ,   2  ,  ...  ,   N  ]
                         │       │       │              │
                        CLS  patch_0  patch_1  ...   patch_N-1
```

**Applied via broadcasting:**
```python
x = x + self.space_embed[:, None, 1:, :]
#        shape: (1, 1, N, D)  (index 0, the CLS slot, is skipped)
#                    ↑
#            broadcasts across T frames
#   → SAME spatial encoding added to the same patch across ALL frames
```

---

### The Combined Positional Identity

Each token receives a **unique identity** from 3 additive sources:

```
  token[t, n] = patch_content  +  time_embed[t]  +  space_embed[n]
                  │                    │                   │
                  │                    │                   └─ "I'm at row 3, col 5"
                  │                    └─ "I'm in frame 4"
                  └─ visual features from Conv2d
```

This **factorized** approach stores T+N embedding vectors, far fewer than the T×N vectors a full space-time position table would need, while still giving each token a unique space-time position.
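
The additive scheme can be reproduced with two small tensors. This sketch leaves out the CLS slots for clarity and uses a small `D` (both are simplifications for the demo).

```python
import torch

T, N, D = 8, 196, 16   # D shrunk for the demo (assumption)
time_embed = torch.randn(1, T, D)    # one vector per frame
space_embed = torch.randn(1, N, D)   # one vector per patch position

# (1, T, 1, D) + (1, 1, N, D) broadcasts to (1, T, N, D)
pos = time_embed[:, :, None, :] + space_embed[:, None, :, :]
print(pos.shape)  # torch.Size([1, 8, 196, 16])

# Entry (t, n) is exactly time_embed[t] + space_embed[n]
t, n = 4, 47
assert torch.allclose(pos[0, t, n], time_embed[0, t] + space_embed[0, n])

# Stored embedding vectors: factorized vs. one vector per (t, n) pair
print(T + N, T * N)  # 204 1568
```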


In [28]:
class TimeSformer(nn.Module):
  def __init__(self,
               num_classes = 2,
               num_frames = 8,
               img_size = 224,
               patch_size = 16,
               embed_dim = 768,
               depth = 12,
               heads = 12):
    super().__init__()

    self.num_frames = num_frames
    self.patch_size = patch_size
    self.embed_dim = embed_dim
    self.num_patches = (img_size // patch_size) ** 2

    self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size = patch_size,
                                 stride = patch_size)
    self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))

    self.time_embed = nn.Parameter(torch.randn(1, self.num_frames + 1, embed_dim))
    self.space_embed = nn.Parameter(torch.randn(1, 1 + self.num_patches, embed_dim))

    self.blocks = nn.ModuleList(
        TimeSformerBlock(embed_dim, heads) for _ in range(depth)
    )

    self.norm = nn.LayerNorm(embed_dim)
    self.head = nn.Linear(embed_dim, num_classes)

  def forward(self, x):
    B, T, C, H, W = x.shape

    x = x.view(B * T, C, H, W)
    x = self.patch_embed(x)
    x = x.flatten(2).transpose(1, 2)
    x = x.view(B, T, -1, self.embed_dim)

    # Add positional embeddings to the patch tokens (index 0 of each table is the CLS slot)
    x = x + self.time_embed[:, 1:T+1, None, :] + self.space_embed[:, None, 1:, :]
    # Add cls token
    cls = self.cls_token.expand(B, -1, self.num_patches, -1)
    cls = cls + self.time_embed[:, :1, None, :]
    # Prepend cls token
    x = torch.cat((cls, x), dim = 1)

    # transformer blocks
    for block in self.blocks:
      x = block(x)

    cls_out = self.norm(x[:,0,0])

    return self.head(cls_out)

### Let's trace the forward pass dimensions


In [None]:
# Trace the forward pass of TimeSformer with concrete shapes
# B=2, T=8, C=3, H=W=224, patch_size=16, embed_dim=768

B, T, C, H, W = 2, 8, 3, 224, 224
num_patches = (224 // 16) ** 2  # = 196
D = 768

print("=== Forward Pass Dimension Trace ===")
print(f"\n1. Input video: ({B}, {T}, {C}, {H}, {W})")

print(f"\n2. Reshape for patch_embed: ({B*T}, {C}, {H}, {W})")
print(f"   → Conv2d(3, 768, k=16, s=16)")
print(f"   → ({B*T}, {D}, {H//16}, {W//16})")
print(f"   → flatten(2).transpose(1,2) → ({B*T}, {num_patches}, {D})")
print(f"   → view back → ({B}, {T}, {num_patches}, {D})")

print(f"\n3. Add positional embeddings:")
print(f"   time_embed[:, 1:{T}+1, None, :] shape: (1, {T}, 1, {D}) → broadcasts to (B, T, N, D)")
print(f"   space_embed[:, None, 1:, :]      shape: (1, 1, {num_patches}, {D}) → broadcasts to (B, T, N, D)")

print(f"\n4. Prepend CLS token:")
print(f"   cls: (1,1,1,{D}) → expand → ({B}, 1, {num_patches}, {D})")
print(f"   + time_embed[:, :1, None, :] → add time pos 0")
print(f"   cat(cls, x, dim=1) → ({B}, {T}+1, {num_patches}, {D}) = ({B}, {T+1}, {num_patches}, {D})")

print(f"\n5. Through 12 TimeSformerBlocks:")
print(f"   Each block: ({B}, {T+1}, {num_patches}, {D}) → ({B}, {T+1}, {num_patches}, {D})")

print(f"\n6. Classification:")
print(f"   x[:, 0, 0] → ({B}, {D}) → norm → head(768, num_classes) → logits")


---
## 6. Transfer Learning: ImageNet ViT → TimeSformer

The key insight: **spatial attention in TimeSformer is exactly the same operation as ViT's self-attention**, just applied per-frame. So we can directly copy ViT weights into the spatial attention + MLP.

The **temporal attention** is new (ViT has no concept of time) — it's initialized randomly and must learn from video data.

```
ViT-Base (ImageNet)              TimeSformer
═══════════════════              ════════════
patch_embed.proj    ──────────→  patch_embed (Conv2d)
                                 
For each of 12 blocks:
  attn.qkv          ──────────→  spatial_attn.in_proj
  attn.proj          ──────────→  spatial_attn.out_proj
  mlp.fc1            ──────────→  mlp[0] (Linear)
  mlp.fc2            ──────────→  mlp[2] (Linear)
                                 
  (nothing)          ──────────→  temporal_attn (random init)
```

This gives the model strong spatial understanding from day one — it only needs to learn temporal dynamics during fine-tuning.


In [29]:
# pretrained vit

vit = timm.create_model('vit_base_patch16_224', pretrained = True)

vit.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False

### Weight Loading Function

Maps ViT's attention QKV weights (packed as a single matrix) to TimeSformer's `nn.MultiheadAttention` `in_proj_weight`.


In [30]:
def load_vit_weights_into_timesformer(timesformer, vit):
    # Patch embedding
    timesformer.patch_embed.weight.data.copy_(
        vit.patch_embed.proj.weight.data
    )
    timesformer.patch_embed.bias.data.copy_(
        vit.patch_embed.proj.bias.data
    )

    # Transformer blocks (spatial parts only)
    for ts_block, vit_block in zip(timesformer.blocks, vit.blocks):
        ts_block.spatial_attn.in_proj_weight.data.copy_(
            vit_block.attn.qkv.weight.data
        )
        ts_block.spatial_attn.in_proj_bias.data.copy_(
            vit_block.attn.qkv.bias.data
        )
        ts_block.spatial_attn.out_proj.weight.data.copy_(
            vit_block.attn.proj.weight.data
        )
        ts_block.spatial_attn.out_proj.bias.data.copy_(
            vit_block.attn.proj.bias.data
        )

        ts_block.mlp[0].weight.data.copy_(
            vit_block.mlp.fc1.weight.data
        )
        ts_block.mlp[0].bias.data.copy_(
            vit_block.mlp.fc1.bias.data
        )
        ts_block.mlp[2].weight.data.copy_(
            vit_block.mlp.fc2.weight.data
        )
        ts_block.mlp[2].bias.data.copy_(
            vit_block.mlp.fc2.bias.data
        )

    print("Loaded ImageNet ViT weights into TimeSformer (spatial only)")
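
The direct copy relies on both modules packing the projection as `[q | k | v]` along the output dimension and splitting heads the same way. The sketch below checks this on small, randomly initialized stand-in layers. The function `vit_style_attention` is hypothetical and mirrors how a ViT attention layer with a packed `qkv` projection is commonly written, so treat that layout as an assumption about the source model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, heads, B, N = 32, 4, 2, 10   # small sizes chosen for the demo (assumption)

# Stand-ins for a ViT block's attention layers (hypothetical, randomly initialized)
qkv = nn.Linear(D, 3 * D)
proj = nn.Linear(D, D)

def vit_style_attention(x):
    """Self-attention with a packed qkv projection, split as [q | k | v]."""
    hd = D // heads
    q, k, v = qkv(x).reshape(B, N, 3, heads, hd).permute(2, 0, 3, 1, 4)
    attn = (q @ k.transpose(-2, -1)) * hd ** -0.5
    out = attn.softmax(dim=-1) @ v                  # (B, heads, N, hd)
    return proj(out.transpose(1, 2).reshape(B, N, D))

mha = nn.MultiheadAttention(D, heads, batch_first=True)
with torch.no_grad():
    mha.in_proj_weight.copy_(qkv.weight)    # both are (3*D, D)
    mha.in_proj_bias.copy_(qkv.bias)
    mha.out_proj.weight.copy_(proj.weight)
    mha.out_proj.bias.copy_(proj.bias)

    x = torch.randn(B, N, D)
    same = torch.allclose(mha(x, x, x)[0], vit_style_attention(x), atol=1e-5)

print(same)  # True
```

If the two outputs did not match, the packed matrix would have to be split and reordered before copying.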

---
## 7. Instantiate & Initialize

ViT-Base config: 12 layers, 12 heads, 768-dim, 16×16 patches → **N = (224/16)² = 196** patches per frame.


In [31]:
model = TimeSformer(
    num_classes=2,
    num_frames=8,
    embed_dim=768,
    depth=12,
    heads=12,
).to(device)

load_vit_weights_into_timesformer(model, vit)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

Loaded ImageNet ViT weights into TimeSformer (spatial only)


---
## 8. Dataset & DataLoader

Point `root_dir` to your video frames directory. Use `batch_size=2` because TimeSformer is memory-hungry (12 blocks × 8 frames × 196 patches).


In [32]:
# train_dataset = VideoFrameDataset(
#     root_dir=" ______ ",
#     num_frames=8
# )

# train_loader = DataLoader(
#     train_dataset,
#     batch_size=2,
#     shuffle=True,
#     num_workers=2
# )

---
## 9. Training

Uses `AdamW` with a low learning rate (`1e-5`) since we're fine-tuning from pretrained spatial weights. The temporal attention learns from scratch.


In [33]:
def train_one_epoch(model, loader):
  model.train()
  total_loss = 0.0
  correct = 0
  total = 0

  for videos, labels in loader:
      videos = videos.to(device)
      labels = labels.to(device)

      logits = model(videos)
      loss = criterion(logits, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      total_loss += loss.item()

      preds = torch.argmax(logits, dim=1)
      correct += (preds == labels).sum().item()
      total += labels.size(0)

  avg_loss = total_loss / len(loader)
  accuracy = correct / total

  return avg_loss, accuracy




In [34]:
# epochs = 10

# for epoch in range(epochs):
#     train_loss, train_acc = train_one_epoch(model, train_loader)

#     print(
#         f"Epoch {epoch+1} | "
#         f"Train loss: {train_loss:.4f} | "
#         f"Train accuracy: {train_acc:.4f}"
#     )

---
## Summary

### What We Built

A complete TimeSformer that:
1. **Patches** video frames using Conv2d (same as ViT)
2. **Encodes position** with factorized time + space embeddings
3. **Attends temporally** (same patch across frames — captures motion)
4. **Attends spatially** (all patches within a frame — captures appearance)
5. **Classifies** via CLS token aggregation

### Key Takeaways

- The `einops.rearrange` trick makes factorized attention trivial to implement
- Divided attention achieves the **best accuracy at much lower cost** than joint attention
- Transfer learning from ViT (spatial weights) enables strong performance with limited video data
- The CLS token at `x[:, 0, 0]` aggregates information from all frames and all patches through the attention blocks
