# VideoMAE Ablation Study: All Experiments

**Goal**: Measure the impact of each improvement, from the paper baseline up to the full custom configuration (0.87 test accuracy)

| Exp | Name | Changes (relative to the experiment it builds on) |
|-----|------|-------------------|
| 1 | Paper Baseline | VideoMAE + paper settings |
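| 2 | + Multi-Segment TTA | Exp1 + 2 temporal segments × 3 spatial crops at test time |
| 3 | + Consistent Transform | Exp1 + same crop/flip for all frames of a clip |
| 4 | + Mixup | Exp3 + Mixup α=0.8 |
| 5 | + Label Smoothing | Exp3 + ε=0.1 (no Mixup) |
| 6 | + 2-Stage | Exp4 + Phase 1 → Phase 2 (low-LR fine-tune with label smoothing ε=0.1, no Mixup) |
| 7 | + Flip TTA | Exp6 + 6-view TTA (3 crops + their horizontal flips) |
| 8 | Full Custom | Exp7 config trained 30 + 10 epochs = videoMAE.ipynb (target 0.87; summarized, not re-run here) |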

In [None]:
## 1. Setup & Shared Imports

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from transformers import get_cosine_schedule_with_warmup
import pandas as pd
from sklearn.metrics import accuracy_score

def seed_everything(seed=42):
    """Seed every RNG the notebook relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and the generators of all visible CUDA devices, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

seed_everything(42)

# Prefer the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

# Results tracking: one dict per finished experiment (run_experiment appends to it)
RESULTS = []

In [None]:
## 2. Configuration

# --- Data locations ---
PATH_DATA_TRAIN = Path('/kaggle/input/action-video/data/data_train')
PATH_DATA_TEST = Path('/kaggle/input/action-video/data/test')

# --- Model and clip geometry ---
MODEL_CKPT = "MCG-NJU/videomae-base-finetuned-kinetics"
NUM_FRAMES = 16
IMAGE_SIZE = 224
RESIZE_SIZE = 256  # short edge is resized to this before cropping to IMAGE_SIZE

# --- Optimization ---
BATCH_SIZE = 16
GRAD_ACCUM_STEPS = 2
EPOCHS = 4  # Quick test (use 30 for full)
WEIGHT_DECAY = 0.05
WARMUP_RATIO = 0.1

# Paper LR scaling rule: lr = base_lr * effective_batch / 256
BASE_LR = 1e-3
EFFECTIVE_BATCH = BATCH_SIZE * GRAD_ACCUM_STEPS
LR = BASE_LR * EFFECTIVE_BATCH / 256

# Normalization statistics are taken from the checkpoint's own image processor.
processor = VideoMAEImageProcessor.from_pretrained(MODEL_CKPT)
MEAN, STD = processor.image_mean, processor.image_std

print(f"LR: {LR:.2e}, Effective Batch: {EFFECTIVE_BATCH}")
print(f"Norm - Mean: {MEAN}, Std: {STD}")

In [None]:
## 3. Transform Classes

class VideoTransformBaseline:
    """Baseline: Per-frame independent random transforms.

    Input is a float tensor ``[T, C, H, W]`` (frames already converted with
    ``TF.to_tensor``). In training mode EVERY frame independently draws its own
    random scale (0.8-1.0), crop offset and horizontal flip. Spatial
    augmentation is therefore NOT consistent across time; this is the behaviour
    the ablation contrasts with ``VideoTransformConsistent``. In eval mode the
    whole clip is resized to IMAGE_SIZE x IMAGE_SIZE (aspect ratio not kept).

    Returns a normalized tensor ``[T, C, IMAGE_SIZE, IMAGE_SIZE]``.
    """
    def __init__(self, is_train=True):
        self.is_train = is_train

    @staticmethod
    def _augment_frame(frame):
        """Apply an independently sampled scale / crop / flip to one [C, H, W] frame."""
        h, w = frame.shape[-2:]
        scale = random.uniform(0.8, 1.0)
        new_h, new_w = int(h * scale), int(w * scale)
        frame = TF.resize(frame, [new_h, new_w])
        i = random.randint(0, max(0, new_h - IMAGE_SIZE))
        j = random.randint(0, max(0, new_w - IMAGE_SIZE))
        frame = TF.crop(frame, i, j, min(IMAGE_SIZE, new_h), min(IMAGE_SIZE, new_w))
        # Upscale back if the scaled frame was smaller than IMAGE_SIZE
        frame = TF.resize(frame, [IMAGE_SIZE, IMAGE_SIZE])
        if random.random() < 0.5:
            frame = TF.hflip(frame)
        return frame

    def __call__(self, frames):
        # frames: [T, C, H, W] tensor
        if self.is_train:
            # Transforms on a 4-D tensor would share one crop/flip across all T
            # frames, so augment each frame separately.
            frames = torch.stack([self._augment_frame(f) for f in frames])
        else:
            frames = TF.resize(frames, [IMAGE_SIZE, IMAGE_SIZE])
        # TF.normalize broadcasts over leading dims, so the clip is normalized at once.
        return TF.normalize(frames, MEAN, STD)


class VideoTransformConsistent:
    """Consistent: Same crop/flip params for ALL frames.

    Takes a list of PIL frames. Every frame is first resized so its short edge
    equals RESIZE_SIZE.

    In training mode a single RandomResizedCrop box (scale 0.8-1.0, ratio
    0.75-1.33) and a single horizontal-flip decision are drawn. They are then
    applied to every frame, and each crop is resized to IMAGE_SIZE x IMAGE_SIZE.

    In eval mode every frame is center-cropped to IMAGE_SIZE.

    Frames are converted to tensors, normalized with MEAN/STD and stacked
    along a new leading (time) axis.
    """
    def __init__(self, is_train=True):
        self.is_train = is_train

    @staticmethod
    def _finalize(img):
        # PIL image -> float tensor -> normalized with the checkpoint's statistics
        return TF.normalize(TF.to_tensor(img), MEAN, STD)

    def __call__(self, frames):
        # frames: list of PIL Images
        resized = [TF.resize(img, RESIZE_SIZE) for img in frames]

        if not self.is_train:
            return torch.stack([self._finalize(TF.center_crop(img, IMAGE_SIZE)) for img in resized])

        # Draw the random parameters ONCE so the whole clip shares them.
        top, left, height, width = T.RandomResizedCrop.get_params(
            resized[0], scale=(0.8, 1.0), ratio=(0.75, 1.33)
        )
        flip = random.random() > 0.5

        clip = []
        for img in resized:
            img = TF.resized_crop(img, top, left, height, width, size=(IMAGE_SIZE, IMAGE_SIZE))
            if flip:
                img = TF.hflip(img)
            clip.append(self._finalize(img))
        return torch.stack(clip)

In [None]:
## 4. Dataset Classes

class VideoDataset(Dataset):
    """Training dataset of frame-folder videos laid out as ``root/<class>/<video>/*.jpg``.

    Classes are the sorted sub-folder names of ``root`` and are indexed in that
    order. Each item samples NUM_FRAMES frame indices evenly spaced over the
    video (``torch.linspace``) and loads those frames as RGB. They are passed
    to ``transform`` in one of two forms:

    - ``use_pil=True`` (e.g. VideoTransformConsistent): a list of PIL images.
    - ``use_pil=False`` (e.g. VideoTransformBaseline): a stacked float tensor
      ``[T, C, H, W]``.

    Returns ``(transform(frames), label)``.
    """
    def __init__(self, root, transform, use_pil=False):
        self.root = Path(root)
        self.transform = transform
        self.use_pil = use_pil  # True for ConsistentTransform

        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        # iterdir() order is filesystem-dependent; sort the video folders so a
        # seeded shuffle selects the same videos on every machine.
        self.samples = [
            (vid, self.class_to_idx[cls])
            for cls in self.classes
            for vid in sorted((self.root / cls).iterdir())
            if vid.is_dir()
        ]

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_rgb(path):
        """Open one frame as an RGB PIL image, closing the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        vid_dir, label = self.samples[idx]
        files = sorted(vid_dir.glob('*.jpg'))
        if not files:
            # Without this, linspace(0, -1, ...) yields an opaque IndexError.
            raise FileNotFoundError(f"No .jpg frames found in {vid_dir}")
        indices = torch.linspace(0, len(files) - 1, NUM_FRAMES).long().tolist()

        frames = [self._load_rgb(files[i]) for i in indices]
        if not self.use_pil:
            frames = torch.stack([TF.to_tensor(img) for img in frames])

        return self.transform(frames), label

In [None]:
## 5. Test Dataset Variants

class TestDatasetSingle(Dataset):
    """Single center crop (no TTA).

    Expects ``root/<numeric id>/*.jpg`` and serves videos in ascending id
    order. Each item returns ``(clip, vid_id)``. ``clip`` stacks NUM_FRAMES
    evenly spaced frames, each short-edge-resized to RESIZE_SIZE,
    center-cropped to IMAGE_SIZE and normalized with MEAN/STD.
    """
    def __init__(self, root):
        self.root = Path(root)
        self.samples = sorted(
            ((d, int(d.name)) for d in self.root.iterdir() if d.is_dir()),
            key=lambda item: item[1],
        )

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_frame(path):
        """Open one frame and turn it into a normalized IMAGE_SIZE center crop."""
        img = TF.resize(Image.open(path).convert('RGB'), RESIZE_SIZE)
        tensor = TF.to_tensor(TF.center_crop(img, IMAGE_SIZE))
        return TF.normalize(tensor, MEAN, STD)

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        files = sorted(vid_dir.glob('*.jpg'))
        picks = torch.linspace(0, len(files) - 1, NUM_FRAMES).long()
        clip = torch.stack([self._load_frame(files[k]) for k in picks])
        return clip, vid_id


class TestDatasetMultiSegment(Dataset):
    """Paper-like TTA: N temporal segments × 3 spatial crops.

    Expects ``root/<numeric id>/*.jpg`` and serves videos in ascending id order.

    Temporal views: the frame list is split into ``num_segments`` contiguous
    segments covering every frame. NUM_FRAMES evenly spaced frames are sampled
    from each segment.

    Spatial views: frames are short-edge-resized to RESIZE_SIZE. Three
    IMAGE_SIZE crops (start / center / end) are then taken along the LONGER
    side, centered on the shorter side, matching the usual "3-crop" video
    testing protocol.

    Returns ``(views, vid_id)`` with views shaped
    ``[num_segments * 3, T, C, H, W]``, ordered segment-major.
    """
    def __init__(self, root, num_segments=2):
        self.root = Path(root)
        self.num_segments = num_segments
        self.samples = [(d, int(d.name)) for d in self.root.iterdir() if d.is_dir()]
        self.samples.sort(key=lambda x: x[1])

    def __len__(self):
        return len(self.samples)

    def _segment_indices(self, total):
        """Yield NUM_FRAMES evenly spaced frame indices for each temporal segment."""
        for seg in range(self.num_segments):
            # Integer boundaries partition [0, total): no trailing frames are dropped.
            start = seg * total // self.num_segments
            end = (seg + 1) * total // self.num_segments
            # Videos shorter than num_segments would otherwise produce empty segments.
            start = min(start, total - 1)
            end = max(end, start + 1)
            yield torch.linspace(start, end - 1, NUM_FRAMES).long()

    @staticmethod
    def _three_crop_offsets(h, w):
        """(top, left) of 3 IMAGE_SIZE crops along the longer side, centered on the shorter."""
        if w >= h:
            top = max(0, (h - IMAGE_SIZE) // 2)
            return [(top, max(0, x)) for x in (0, (w - IMAGE_SIZE) // 2, w - IMAGE_SIZE)]
        left = max(0, (w - IMAGE_SIZE) // 2)
        return [(max(0, y), left) for y in (0, (h - IMAGE_SIZE) // 2, h - IMAGE_SIZE)]

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        files = sorted(vid_dir.glob('*.jpg'))
        total = len(files)
        if total == 0:
            raise FileNotFoundError(f"No .jpg frames found in {vid_dir}")

        views = []
        for seg_indices in self._segment_indices(total):
            frames = [TF.resize(Image.open(files[i]).convert('RGB'), RESIZE_SIZE) for i in seg_indices]
            w, h = frames[0].size  # PIL: (width, height)

            # 3 spatial crops: left, center, right (or top, center, bottom for portrait)
            for top, left in self._three_crop_offsets(h, w):
                view_frames = [
                    TF.normalize(TF.to_tensor(TF.crop(img, top, left, IMAGE_SIZE, IMAGE_SIZE)), MEAN, STD)
                    for img in frames
                ]
                views.append(torch.stack(view_frames))

        return torch.stack(views), vid_id  # [num_segments*3, T, C, H, W]


class TestDatasetFlipTTA(Dataset):
    """Custom TTA: 3 spatial crops + 3 flipped = 6 views.

    Videos live in ``root/<numeric id>/*.jpg`` and are served in ascending id
    order. NUM_FRAMES evenly spaced frames are short-edge-resized to
    RESIZE_SIZE. Each view is an IMAGE_SIZE crop at the center, top-center or
    bottom-center position. Views 0-2 are the plain crops; views 3-5 are the
    same crops mirrored horizontally.

    Returns ``(views, vid_id)`` with views shaped ``[6, T, C, H, W]``.
    """

    def __init__(self, root):
        self.root = Path(root)
        candidates = [(d, int(d.name)) for d in self.root.iterdir() if d.is_dir()]
        self.samples = sorted(candidates, key=lambda item: item[1])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        files = sorted(vid_dir.glob('*.jpg'))
        picks = torch.linspace(0, len(files) - 1, NUM_FRAMES).long()
        clip = [TF.resize(Image.open(files[k]).convert('RGB'), RESIZE_SIZE) for k in picks]

        width, height = clip[0].size  # PIL reports (width, height)
        max_top = height - IMAGE_SIZE
        max_left = width - IMAGE_SIZE
        raw_anchors = (
            (max_top // 2, max_left // 2),  # center
            (0, max_left // 2),             # top-center
            (max_top, max_left // 2),       # bottom-center
        )
        anchors = [
            (max(0, min(top, max_top)), max(0, min(left, max_left)))
            for top, left in raw_anchors
        ]

        views = []
        for mirrored in (False, True):  # plain crops first, then their mirror images
            for top, left in anchors:
                frames_out = []
                for img in clip:
                    patch = TF.crop(img, top, left, IMAGE_SIZE, IMAGE_SIZE)
                    if mirrored:
                        patch = TF.hflip(patch)
                    frames_out.append(TF.normalize(TF.to_tensor(patch), MEAN, STD))
                views.append(torch.stack(frames_out))

        return torch.stack(views), vid_id  # [6, T, C, H, W]

In [None]:
## 6. Mixup Collate Function

class MixupCollate:
    """Collate function applying batch-level Mixup.

    Collates ``(video, label)`` pairs with ``default_collate`` and draws
    ``lam ~ Beta(alpha, alpha)`` once per batch. Every sample and its one-hot
    target are then blended with a randomly permuted partner from the same
    batch:

        x = lam * x + (1 - lam) * x[perm]
        y = lam * onehot(y) + (1 - lam) * onehot(y)[perm]

    Returns ``(inputs, soft_targets)`` where soft_targets is a float tensor
    ``[B, num_classes]`` whose rows sum to 1.

    ``alpha <= 0`` disables mixing (``np.random.beta`` rejects it). One-hot
    targets are still returned, so the soft-target loss path keeps working.
    """
    def __init__(self, num_classes, alpha=0.8):
        self.num_classes = num_classes
        self.alpha = alpha

    def __call__(self, batch):
        inputs, targets = torch.utils.data.default_collate(batch)

        if self.alpha <= 0:
            return inputs, F.one_hot(targets, self.num_classes).float()

        lam = float(np.random.beta(self.alpha, self.alpha))
        index = torch.randperm(inputs.size(0))

        inputs = lam * inputs + (1 - lam) * inputs[index]
        targets_onehot = F.one_hot(targets, self.num_classes).float()
        targets = lam * targets_onehot + (1 - lam) * targets_onehot[index]

        return inputs, targets

In [None]:
## 7. Training & Evaluation Functions

def train_epoch(model, loader, optimizer, scheduler, scaler, use_mixup=False, label_smoothing=0.0):
    """Train ``model`` for one pass over ``loader``; return (mean loss per sample, accuracy).

    Mixed precision runs through CUDA autocast and ``scaler``. Gradients are
    accumulated over GRAD_ACCUM_STEPS micro-batches. Each optimizer step clips
    the gradient norm to 1.0 and is followed by exactly one
    ``scheduler.step()``.

    With ``use_mixup`` the targets are soft label distributions (as built by
    MixupCollate). The loss is then soft-target cross-entropy, and accuracy is
    measured against each sample's dominant label. Otherwise targets are class
    indices and ``label_smoothing`` is forwarded to ``F.cross_entropy``.

    If the number of batches is not a multiple of GRAD_ACCUM_STEPS, the
    trailing partial accumulation is still applied as a final optimizer step.
    Otherwise those gradients would be silently discarded by the next
    ``zero_grad()``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    optimizer.zero_grad()
    num_batches = len(loader)

    for batch_idx, (videos, targets) in enumerate(tqdm(loader, leave=False)):
        videos = videos.to(DEVICE)
        targets = targets.to(DEVICE)

        with torch.amp.autocast(device_type='cuda'):
            logits = model(videos).logits

            if use_mixup:
                # Soft-target cross-entropy against the mixed one-hot labels
                log_probs = F.log_softmax(logits, dim=1)
                loss = -torch.sum(targets * log_probs, dim=1).mean()
                true_labels = targets.argmax(dim=1)
            else:
                loss = F.cross_entropy(logits, targets, label_smoothing=label_smoothing)
                true_labels = targets

        preds = logits.argmax(dim=1)
        batch_n = true_labels.size(0)
        correct += (preds == true_labels).sum().item()
        total += batch_n
        total_loss += loss.item() * batch_n

        # Divide so the accumulated gradient is an average over micro-batches.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward()

        at_boundary = (batch_idx + 1) % GRAD_ACCUM_STEPS == 0
        is_last = (batch_idx + 1) == num_batches
        if at_boundary or is_last:
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

    if total == 0:
        raise ValueError("train_epoch: loader yielded no batches")
    return total_loss / total, correct / total


def evaluate(model, loader, multi_view=False, id2label=None):
    """Predict one class per video in ``loader``.

    Args:
        model: classifier whose output exposes ``.logits`` of shape [N, num_classes].
        loader: yields ``(data, vid_ids)``. ``data`` is one clip per video, or
            ``[B, V, T, C, H, W]`` when ``multi_view`` is True.
        multi_view: if True, the V views of each video are run as one flattened
            batch and their logits are averaged before the argmax.
        id2label: optional mapping from class index to class name. When omitted,
            the raw integer class index is returned (the old default crashed).

    Returns:
        List of ``(vid_id, predicted_label)`` tuples in loader order.
    """
    model.eval()
    predictions = []

    with torch.no_grad():
        for data, vid_ids in tqdm(loader, leave=False, desc="Eval"):
            if multi_view:
                B, V = data.shape[:2]
                data = data.reshape(B * V, *data.shape[2:]).to(DEVICE)
                logits = model(data).logits.reshape(B, V, -1).mean(dim=1)
            else:
                logits = model(data.to(DEVICE)).logits

            preds = logits.argmax(dim=1).tolist()
            for vid, pred in zip(vid_ids.tolist(), preds):
                label = id2label[pred] if id2label is not None else pred
                predictions.append((vid, label))

    return predictions

In [None]:
## 8. Load Test Labels

# Download the ground-truth test labels (IPython shell escape, not plain Python).
# The CSV is expected to provide 'id' and 'class' columns, as read below.
!gdown "1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_" -O test_labels.csv -q
gt_df = pd.read_csv("test_labels.csv")
# Keys are stringified ids; calc_accuracy looks predictions up via str(vid_id).
GT_LABELS = dict(zip(gt_df['id'].astype(str), gt_df['class']))

def calc_accuracy(predictions, gt_labels=None):
    """Top-1 accuracy of ``(vid_id, predicted_class)`` pairs against ground truth.

    Args:
        predictions: iterable of ``(vid_id, pred_cls)`` pairs, as returned by
            ``evaluate``.
        gt_labels: mapping from ``str(vid_id)`` to the true class. Defaults to
            the module-level ``GT_LABELS`` loaded from test_labels.csv.

    Returns:
        Fraction of matched predictions whose class equals the ground truth.
        Predictions whose id has no ground-truth entry are skipped; their count
        is printed so a partial overlap does not go unnoticed.

    Raises:
        ValueError: if no prediction id appears in ``gt_labels``. This usually
            means an id-format mismatch rather than a genuinely 0% model.
    """
    if gt_labels is None:
        gt_labels = GT_LABELS

    matched = hits = skipped = 0
    for vid_id, pred_cls in predictions:
        key = str(vid_id)
        if key not in gt_labels:
            skipped += 1
            continue
        matched += 1
        hits += int(pred_cls == gt_labels[key])

    if skipped:
        print(f"  [calc_accuracy] skipped {skipped} prediction(s) with no ground-truth label")
    if matched == 0:
        raise ValueError("No prediction ids matched the ground-truth labels")
    return hits / matched

# Sanity check: number of ground-truth test labels available for scoring
print(f"Loaded {len(GT_LABELS)} test labels")

In [None]:
## 9. Experiment Runner

def _make_optimizer_and_scheduler(model, lr, steps_per_epoch, epochs, warmup_ratio):
    """Create AdamW (weight decay WEIGHT_DECAY) and a cosine LR schedule with linear warmup.

    The schedule is sized in optimizer steps (``steps_per_epoch * epochs //
    GRAD_ACCUM_STEPS``), because train_epoch calls ``scheduler.step()`` once
    per accumulated optimizer step rather than once per micro-batch.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    num_steps = steps_per_epoch * epochs // GRAD_ACCUM_STEPS
    num_warmup = int(num_steps * warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup, num_steps)
    return optimizer, scheduler


def run_experiment(
    exp_name, train_dataset, test_dataset,
    epochs=EPOCHS, use_mixup=False, label_smoothing=0.0,
    multi_view=False, two_stage=False, mixup_collate=None,
    p2_epochs=3, p2_lr=1e-6,
):
    """Train a fresh VideoMAE, score it on ``test_dataset`` and record the result.

    Phase 1: train for ``epochs`` epochs at LR with a cosine schedule and
    WARMUP_RATIO warmup. Mixup is enabled via ``use_mixup`` + ``mixup_collate``;
    label smoothing via ``label_smoothing``. There is no validation split, so
    after each epoch the weights are saved to ``{exp_name}_best.pt`` whenever
    the TRAINING accuracy is the best so far.

    Phase 2 (``two_stage=True``): the best Phase 1 checkpoint is fine-tuned for
    ``p2_epochs`` epochs at ``p2_lr``, without Mixup and with label smoothing
    0.1. Without ``two_stage``, the best Phase 1 checkpoint is simply reloaded.

    The resulting model is evaluated (view-averaged when ``multi_view``). A dict
    ``{'exp', 'train_acc', 'test_acc'}`` is appended to RESULTS.

    Returns:
        Test accuracy as computed by ``calc_accuracy``.

    Raises:
        ValueError: if ``epochs < 1`` (no checkpoint would exist to reload).
    """
    import gc

    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    print(f"\n{'='*60}")
    print(f"EXPERIMENT: {exp_name}")
    print(f"{'='*60}")

    seed_everything(42)  # Reset seed for reproducibility

    # Create fresh model
    label2id = train_dataset.class_to_idx
    id2label = {v: k for k, v in label2id.items()}

    model = VideoMAEForVideoClassification.from_pretrained(
        MODEL_CKPT, label2id=label2id, id2label=id2label,
        ignore_mismatched_sizes=True, num_frames=NUM_FRAMES
    ).to(DEVICE)

    # DataLoaders
    train_loader = DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=True,
        collate_fn=mixup_collate
    )
    test_loader = DataLoader(
        test_dataset, batch_size=4 if multi_view else BATCH_SIZE,
        shuffle=False, num_workers=2
    )

    # Optimizer & Scheduler
    optimizer, scheduler = _make_optimizer_and_scheduler(
        model, LR, len(train_loader), epochs, WARMUP_RATIO
    )
    scaler = torch.amp.GradScaler()

    ckpt_path = f'{exp_name}_best.pt'

    # Phase 1 Training. Starting from -inf guarantees the first epoch is always
    # checkpointed, so the reload below can never hit a missing file.
    best_acc = float('-inf')
    for epoch in range(epochs):
        loss, acc = train_epoch(
            model, train_loader, optimizer, scheduler, scaler,
            use_mixup=use_mixup, label_smoothing=label_smoothing
        )
        print(f"  Epoch {epoch+1}/{epochs}: Loss={loss:.4f}, Acc={acc:.4f}")
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), ckpt_path)

    # Phase 2 (if 2-stage)
    if two_stage:
        print("  --> Phase 2: Fine-tuning with Label Smoothing...")
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))

        # Phase 2 settings: low LR, no warmup, no mixup, label smoothing
        optimizer, scheduler = _make_optimizer_and_scheduler(
            model, p2_lr, len(train_loader), p2_epochs, 0.0
        )
        p2_loader = DataLoader(
            train_dataset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=2, pin_memory=True, drop_last=True
        )

        for epoch in range(p2_epochs):
            loss, acc = train_epoch(
                model, p2_loader, optimizer, scheduler, scaler,
                use_mixup=False, label_smoothing=0.1
            )
            print(f"  P2 Epoch {epoch+1}/{p2_epochs}: Loss={loss:.4f}, Acc={acc:.4f}")
            best_acc = max(best_acc, acc)
    else:
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))

    # Evaluate
    predictions = evaluate(model, test_loader, multi_view=multi_view, id2label=id2label)
    test_acc = calc_accuracy(predictions)

    print(f"\n  >>> TEST ACCURACY: {test_acc:.4f} ({test_acc*100:.2f}%)")
    RESULTS.append({'exp': exp_name, 'train_acc': best_acc, 'test_acc': test_acc})

    # Cleanup: drop every holder of parameters / optimizer state before
    # releasing cached GPU memory, otherwise empty_cache() cannot free it.
    del model, optimizer, scheduler, scaler
    gc.collect()
    torch.cuda.empty_cache()

    return test_acc

In [None]:
## ========================================
## PHASE 1: PAPER BASELINE EXPERIMENTS
## ========================================

In [None]:
## Exp 1: Paper Baseline (VideoMAE + paper settings, single view)

# Baseline transform consumes a stacked [T, C, H, W] tensor, hence use_pil=False.
train_ds_baseline = VideoDataset(
    root=PATH_DATA_TRAIN,
    transform=VideoTransformBaseline(is_train=True),
    use_pil=False,
)
test_ds_single = TestDatasetSingle(root=PATH_DATA_TEST)

run_experiment(
    exp_name="Exp1_Paper_Baseline",
    train_dataset=train_ds_baseline,
    test_dataset=test_ds_single,
)

In [None]:
## Exp 2: Paper + Multi-Segment TTA

# Retrains the Exp1 configuration; only the test-time views change.
test_ds_multi = TestDatasetMultiSegment(root=PATH_DATA_TEST, num_segments=2)

run_experiment(
    exp_name="Exp2_MultiSegment_TTA",
    train_dataset=train_ds_baseline,
    test_dataset=test_ds_multi,
    multi_view=True,  # logits are averaged over all views of a video
)

In [None]:
## ========================================
## PHASE 2: CUSTOM IMPROVEMENTS (ABLATION)
## ========================================

In [None]:
## Exp 3: Consistent Spatial Transforms

# The consistent transform consumes a list of PIL frames, hence use_pil=True.
train_ds_consistent = VideoDataset(
    root=PATH_DATA_TRAIN,
    transform=VideoTransformConsistent(is_train=True),
    use_pil=True,
)

run_experiment(
    exp_name="Exp3_Consistent_Transform",
    train_dataset=train_ds_consistent,
    test_dataset=test_ds_single,
)

In [None]:
## Exp 4: Consistent + Mixup

# One output class per training class folder; mixing happens in the collate step.
mixup_collate = MixupCollate(alpha=0.8, num_classes=len(train_ds_consistent.classes))

run_experiment(
    exp_name="Exp4_Mixup",
    train_dataset=train_ds_consistent,
    test_dataset=test_ds_single,
    use_mixup=True,
    mixup_collate=mixup_collate,
)

In [None]:
## Exp 5: Consistent + Label Smoothing (no mixup)

run_experiment(
    exp_name="Exp5_LabelSmoothing",
    train_dataset=train_ds_consistent,
    test_dataset=test_ds_single,
    label_smoothing=0.1,  # hard labels with epsilon = 0.1, Mixup left off
)

In [None]:
## Exp 6: Consistent + Mixup + 2-Stage

run_experiment(
    exp_name="Exp6_2Stage",
    train_dataset=train_ds_consistent,
    test_dataset=test_ds_single,
    use_mixup=True,
    mixup_collate=mixup_collate,
    two_stage=True,  # Phase 1 with Mixup, then a low-LR label-smoothing phase
)

In [None]:
## Exp 7: Full Custom + Flip TTA (= videoMAE.ipynb)

# 6 test views per video: 3 crops plus their horizontal flips.
test_ds_flip = TestDatasetFlipTTA(root=PATH_DATA_TEST)

run_experiment(
    exp_name="Exp7_FlipTTA_Full",
    train_dataset=train_ds_consistent,
    test_dataset=test_ds_flip,
    use_mixup=True,
    mixup_collate=mixup_collate,
    two_stage=True,
    multi_view=True,
)

In [None]:
## Exp 8: Summary

print("\n" + "="*60)
print("EXP 8: Full Custom Config (= videoMAE.ipynb)")
print("="*60)
print("Same as Exp 7 but with 40 epochs (30+10) to reach 0.87")

# Look Exp 7 up by name (latest run wins). RESULTS[-1] would silently report a
# different experiment if the Exp 7 cell failed or was skipped.
exp7_result = next((r for r in reversed(RESULTS) if r['exp'] == "Exp7_FlipTTA_Full"), None)
if exp7_result is not None:
    print(f"Current Exp 7 result with {EPOCHS} epochs: {exp7_result['test_acc']:.4f}")
else:
    print("Exp 7 has not been run yet (no 'Exp7_FlipTTA_Full' entry in RESULTS).")
print("\nTo reproduce 0.87, run with EPOCHS=30 (Phase 1) + 10 (Phase 2)")

In [None]:
## ========================================
## FINAL RESULTS SUMMARY
## ========================================

print("\n" + "="*60)
print("ALL EXPERIMENTS RESULTS")
print("="*60)

if not RESULTS:
    raise RuntimeError("RESULTS is empty: run at least one experiment before summarizing.")

results_df = pd.DataFrame(RESULTS)

# Deltas are relative to Exp1 when it was run; otherwise to the first recorded
# experiment. Assuming row 0 is Exp1 breaks if that cell failed or was skipped.
baseline_rows = results_df.loc[results_df['exp'] == "Exp1_Paper_Baseline", 'test_acc']
baseline_acc = baseline_rows.iloc[0] if not baseline_rows.empty else results_df['test_acc'].iloc[0]

delta_pct = (results_df['test_acc'] - baseline_acc) * 100
results_df['delta'] = delta_pct.apply(lambda x: f"+{x:.2f}%" if x > 0 else f"{x:.2f}%")

print(results_df.to_string(index=False))

best_acc = results_df['test_acc'].max()
print(f"\nBaseline: {baseline_acc:.4f} ({baseline_acc*100:.2f}%)")
print(f"Best: {best_acc:.4f} ({best_acc*100:.2f}%)")
print(f"Improvement: +{(best_acc - baseline_acc)*100:.2f}%")

results_df.to_csv('ablation_results.csv', index=False)
print("\nResults saved to ablation_results.csv")