# ViT-Base Experiment for Video Action Recognition

**Objective**: Compare ViT-Base with ViT-Small for video action classification, using per-frame ViT features that are mean-pooled over time.

**Model**: `vit_base_patch16_224` (ImageNet-21k pretrained)

**Expected**: Higher accuracy than ViT-Small (63.92%) due to larger capacity.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import timm
import random
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
print(f'Using device: {DEVICE}')

## 1. Configuration

In [None]:
# --- Dataset locations ---
PATH_DATA_TRAIN = '/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = '/kaggle/input/action-video/data/test'

# --- Clip sampling and frame geometry ---
NUM_FRAMES = 16      # frames sampled from each video
IMG_SIZE = 224       # side length of the square crop fed to the model
# NOTE(review): RESIZE_SIZE is not passed to the datasets in the visible cells;
# they rely on VideoTransform's own default resize_size of 256.
RESIZE_SIZE = 256

# --- Optimisation settings ---
BATCH_SIZE = 8           # kept small because the model is larger
EPOCHS = 10
BASE_LR = 5e-5           # learning rate for the backbone parameters
HEAD_LR = 5e-4           # learning rate for the classification head
WEIGHT_DECAY = 0.05
GRAD_ACCUM_STEPS = 4     # micro-batches accumulated per optimizer step

# --- Backbone selection (ViT-Base) ---
PRETRAINED_NAME = 'vit_base_patch16_224'

# Echo the key settings so the run log records them.
for _label, _value in (
    ("Train data", PATH_DATA_TRAIN),
    ("Test data", PATH_DATA_TEST),
    ("Model", PRETRAINED_NAME),
    ("Frames per video", NUM_FRAMES),
    ("Batch size", BATCH_SIZE),
    ("Epochs", EPOCHS),
):
    print(f"{_label}: {_value}")

## 2. ViT-Base Model with Temporal Pooling

In [None]:
class ViTBaseForAction(nn.Module):
    """ViT video classifier: per-frame backbone features averaged over time.

    Every frame of a clip is embedded independently by the timm backbone, the
    per-frame embeddings are mean-pooled along the temporal axis, and a
    dropout + linear head maps the pooled vector to class logits.

    Args:
        num_classes: Number of output classes of the linear head.
        pretrained_name: timm model name used to build the backbone.
        pretrained: Forwarded to ``timm.create_model``. Pass ``False`` when
            the weights are about to be overwritten by ``load_state_dict``,
            which avoids fetching pretrained weights that would be discarded.
    """

    def __init__(self, num_classes=51, pretrained_name='vit_base_patch16_224', pretrained=True):
        super().__init__()

        # num_classes=0 builds the backbone without a classifier, so calling it
        # yields one feature vector per input image.
        self.vit = timm.create_model(pretrained_name, pretrained=pretrained, num_classes=0)

        # Feature width reported by the backbone (768 for ViT-Base).
        self.embed_dim = self.vit.num_features

        # Classification head with dropout
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.embed_dim, num_classes)
        )

    def forward(self, video):
        """Classify a batch of clips.

        Args:
            video: [B, T, C, H, W] - batch of video clips
        Returns:
            logits: [B, num_classes]
        """
        B, T, C, H, W = video.shape

        # Fold time into the batch axis so the backbone sees plain images.
        # reshape (not view) so non-contiguous clips, e.g. produced by a
        # permute/transpose upstream, are accepted instead of raising.
        x = video.reshape(B * T, C, H, W)

        # Extract features with ViT
        features = self.vit(x)  # [B*T, embed_dim]

        # Restore the temporal axis
        features = features.reshape(B, T, self.embed_dim)

        # Temporal pooling (mean)
        pooled = features.mean(dim=1)  # [B, embed_dim]

        # Classification
        return self.head(pooled)

# Confirmation printed once the model class has been defined; only the
# backbone line interpolates a value, so the others are plain strings.
print("ViT-Base model defined")
print(f"  Backbone: {PRETRAINED_NAME}")
print("  Expected embed_dim: 768")

## 3. Data Augmentation (Consistent Spatial Transform)

In [None]:
class VideoTransform:
    """Clip preprocessing that applies one spatial transform to every frame.

    Training draws a single random resized-crop window and a single flip
    decision per clip and reuses them for all frames; evaluation uses a
    center crop. Frames are then converted to tensors, normalized, and stacked.
    """

    def __init__(self, image_size=224, resize_size=256, is_train=True):
        self.image_size = image_size
        self.resize_size = resize_size
        self.is_train = is_train
        # ImageNet stats
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]

    def _finalize(self, img):
        """Convert one cropped frame to a normalized tensor."""
        return TF.normalize(TF.to_tensor(img), self.mean, self.std)

    def __call__(self, frames):
        """Return the transformed frames stacked along a new leading axis."""
        # Every frame is resized before any cropping happens.
        resized = [
            TF.resize(frame, self.resize_size, interpolation=InterpolationMode.BILINEAR)
            for frame in frames
        ]

        if not self.is_train:
            # Evaluation path: deterministic center crop.
            return torch.stack([
                self._finalize(TF.center_crop(frame, self.image_size))
                for frame in resized
            ])

        # Training path: sample the crop window and flip decision once, from
        # the first frame, then reuse them for the whole clip.
        top, left, height, width = transforms.RandomResizedCrop.get_params(
            resized[0], scale=(0.8, 1.0), ratio=(0.75, 1.33)
        )
        flip = random.random() > 0.5
        out_size = (self.image_size, self.image_size)

        clip = []
        for frame in resized:
            frame = TF.resized_crop(frame, top, left, height, width, out_size)
            if flip:
                frame = TF.hflip(frame)
            clip.append(self._finalize(frame))
        return torch.stack(clip)

print("Augmentation defined (Consistent Spatial Transform)")

## 4. Dataset Classes

In [None]:
class VideoDataset(Dataset):
    """Labelled clips stored as ``root/<class>/<video>/<frame image>``.

    Class names are the sorted sub-directory names of ``root``; each video is
    a directory of ``.jpg``/``.jpeg``/``.png`` frames. Videos without any
    frame image are skipped when the index is built.

    Args:
        root: Directory containing one sub-directory per class.
        num_frames: Number of frames sampled uniformly from each video.
        image_size: Side length of the square crop produced by the transform.
        is_train: Selects the training (random crop/flip) or evaluation
            (center crop) behaviour of ``VideoTransform``.
        resize_size: Resize target forwarded to ``VideoTransform``.
    """

    def __init__(self, root, num_frames=16, image_size=224, is_train=True, resize_size=256):
        self.root = Path(root)
        self.num_frames = num_frames
        self.transform = VideoTransform(image_size, resize_size=resize_size, is_train=is_train)

        self.classes = sorted([d.name for d in self.root.iterdir() if d.is_dir()])
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        self.samples = []
        for cls in self.classes:
            cls_dir = self.root / cls
            for video_dir in sorted([d for d in cls_dir.iterdir() if d.is_dir()]):
                frame_paths = sorted([p for p in video_dir.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}])
                if frame_paths:
                    self.samples.append((frame_paths, self.class_to_idx[cls]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(video, label)`` for sample ``idx``."""
        frame_paths, label = self.samples[idx]
        total = len(frame_paths)

        # Uniform sampling; short videos repeat frames to reach num_frames.
        indices = torch.linspace(0, total - 1, self.num_frames).long().tolist()

        frames = []
        for i in indices:
            # convert() returns an independent in-memory image, so the file
            # handle can be released as soon as the with-block ends.
            with Image.open(frame_paths[i]) as img:
                frames.append(img.convert("RGB"))

        video = self.transform(frames)
        return video, label


class TestDataset(Dataset):
    """Unlabelled clips stored as ``root/<integer id>/<frame image>``.

    Video directories are ordered by the integer value of their name, which
    also serves as the video id returned with each clip.

    Args:
        root: Directory containing one numerically named sub-directory per video.
        num_frames: Number of frames sampled uniformly from each video.
        image_size: Side length of the square crop produced by the transform.
        resize_size: Resize target forwarded to ``VideoTransform``.
    """

    def __init__(self, root, num_frames=16, image_size=224, resize_size=256):
        self.root = Path(root)
        self.num_frames = num_frames
        self.transform = VideoTransform(image_size, resize_size=resize_size, is_train=False)
        self.video_dirs = sorted([d for d in self.root.iterdir() if d.is_dir()], key=lambda x: int(x.name))
        self.video_ids = [int(d.name) for d in self.video_dirs]

    def __len__(self):
        return len(self.video_dirs)

    def __getitem__(self, idx):
        """Return ``(video, video_id)`` for the ``idx``-th video directory.

        Raises:
            FileNotFoundError: If the video directory contains no frame image.
        """
        video_dir = self.video_dirs[idx]
        video_id = self.video_ids[idx]
        frame_paths = sorted([p for p in video_dir.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}])

        # Fail with the offending directory named, rather than an opaque
        # IndexError from indexing an empty frame list.
        if not frame_paths:
            raise FileNotFoundError(f"No image frames (.jpg/.jpeg/.png) found in {video_dir}")

        total = len(frame_paths)
        indices = torch.linspace(0, total - 1, self.num_frames).long().tolist()

        frames = []
        for i in indices:
            # convert() returns an independent in-memory image, so the file
            # handle can be released as soon as the with-block ends.
            with Image.open(frame_paths[i]) as img:
                frames.append(img.convert("RGB"))

        video = self.transform(frames)
        return video, video_id

print("Dataset classes defined")

## 5. Training Functions

In [None]:
def train_one_epoch(model, loader, optimizer, scaler, device, grad_accum_steps=1):
    """Train ``model`` for one pass over ``loader`` with AMP and gradient accumulation.

    Args:
        model: Module mapping a batch of clips to class logits.
        loader: Iterable of ``(videos, labels)`` batches; must support ``len()``.
        optimizer: Optimizer that updates ``model``'s parameters.
        scaler: Gradient scaler used to scale the loss and step the optimizer.
        device: ``torch.device`` the batches are moved to; autocast is enabled
            only when its type is ``'cuda'``.
        grad_accum_steps: Number of micro-batches accumulated per optimizer step.

    Returns:
        ``(avg_loss, avg_acc)``: sample-weighted mean cross-entropy and
        accuracy over all batches seen in the epoch.

    Raises:
        ValueError: If ``grad_accum_steps`` is smaller than 1.
    """
    if grad_accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}")

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = len(loader)

    optimizer.zero_grad()
    progress = tqdm(loader, desc="Train", leave=False)

    for batch_idx, (videos, labels) in enumerate(progress):
        videos = videos.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast(device_type='cuda', enabled=(device.type == 'cuda')):
            logits = model(videos)
            loss = F.cross_entropy(logits, labels)

        batch_size = labels.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

        # Log the unscaled loss; the division below only affects gradients.
        loss_value = loss.item()
        total_loss += loss_value * batch_size

        # Divide by the real size of this accumulation group. The final group
        # is shorter when num_batches is not a multiple of grad_accum_steps,
        # and dividing it by grad_accum_steps would under-scale its gradient.
        group_start = batch_idx - (batch_idx % grad_accum_steps)
        group_size = min(grad_accum_steps, num_batches - group_start)
        scaler.scale(loss / group_size).backward()

        should_step = ((batch_idx + 1) % grad_accum_steps == 0) or (batch_idx + 1 == num_batches)
        if should_step:
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        progress.set_postfix(loss=f"{loss_value:.4f}", acc=f"{correct / max(total, 1):.4f}")

    # max(..., 1) keeps an empty loader from dividing by zero.
    avg_loss = total_loss / max(total, 1)
    avg_acc = correct / max(total, 1)
    return avg_loss, avg_acc

print("Training functions defined")

## 6. Load Data & Create Model

In [None]:
print("Loading training dataset...")

# Build the labelled training set with train-time augmentation enabled.
train_dataset = VideoDataset(
    PATH_DATA_TRAIN,
    num_frames=NUM_FRAMES,
    image_size=IMG_SIZE,
    is_train=True,
)

# Shuffled loader; drop_last=True discards a final incomplete batch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
)

print(f"Train samples: {len(train_dataset)}")
print(f"Classes: {len(train_dataset.classes)}")
print(f"Batches per epoch: {len(train_loader)}")

In [None]:
print("Creating ViT-Base model...")

# One output per discovered class; the model is moved to the selected device.
model = ViTBaseForAction(
    num_classes=len(train_dataset.classes),
    pretrained_name=PRETRAINED_NAME,
)
model = model.to(DEVICE)

# Count all parameters and the subset that receives gradients in one pass.
num_params = 0
trainable_params = 0
for _param in model.parameters():
    _count = _param.numel()
    num_params += _count
    if _param.requires_grad:
        trainable_params += _count

print(f"Total parameters: {num_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter.
print(f"Model size: {num_params * 4 / 1024 / 1024:.2f} MB")

In [None]:
# Optimizer with different LR for backbone and head.
# Parameters are routed by their top-level module: names starting with
# "head." belong to the classification head, everything else is backbone.
# A prefix test is used because a substring test would also capture any
# backbone parameter whose qualified name merely contains "head".
backbone_params = []
head_params = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if name.startswith('head.'):
        head_params.append(param)
    else:
        backbone_params.append(param)

optimizer = torch.optim.AdamW([
    {"params": backbone_params, "lr": BASE_LR},
    {"params": head_params, "lr": HEAD_LR},
], weight_decay=WEIGHT_DECAY)

# Enable loss scaling under the same condition the training step uses for
# autocast (the selected device is CUDA), so the two can never disagree.
scaler = torch.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))
print(f"Optimizer: AdamW | Base LR: {BASE_LR} | Head LR: {HEAD_LR}")

## 7. Training Loop

In [None]:
# The checkpoint is refreshed whenever an epoch beats the best *training*
# accuracy so far; no validation metric is computed in this loop.
checkpoint_path = Path('./vit_base_best.pt')
best_acc = 0.0

_rule = "=" * 50
print(_rule)
print("TRAINING VIT-BASE")
print(_rule)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, scaler, DEVICE, GRAD_ACCUM_STEPS
    )
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

    # Nothing to save unless this epoch improved on the best accuracy.
    if not train_acc > best_acc:
        continue

    best_acc = train_acc
    # Store the class list with the weights so inference can rebuild the model.
    _snapshot = {
        'model': model.state_dict(),
        'classes': train_dataset.classes,
        'acc': best_acc,
        'epoch': epoch + 1,
    }
    torch.save(_snapshot, checkpoint_path)
    print(f"  >>> Best model saved (acc: {best_acc:.4f})")

print("\n" + _rule)
print(f"Training completed! Best train accuracy: {best_acc:.4f}")
print(f"Model saved to: {checkpoint_path}")

## 8. Inference on Test Set

In [None]:
print("INFERENCE ON TEST SET")

# Restore the saved weights together with the class list stored beside them.
print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
classes = checkpoint['classes']

# Rebuild the architecture with a matching head size, then load the weights
# and switch to evaluation mode.
model = ViTBaseForAction(
    num_classes=len(classes),
    pretrained_name=PRETRAINED_NAME,
)
model = model.to(DEVICE)
model.load_state_dict(checkpoint['model'])
model.eval()
print(f"Model loaded (trained acc: {checkpoint['acc']:.4f})")

# Test clips are read in a fixed order (shuffle=False).
print("\nLoading test dataset...")
test_dataset = TestDataset(
    PATH_DATA_TEST,
    num_frames=NUM_FRAMES,
    image_size=IMG_SIZE,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
print(f"Test samples: {len(test_dataset)}")

In [None]:
print("\nRunning inference...")
predictions = []

# Gradients are not needed for prediction.
with torch.no_grad():
    for videos, video_ids in tqdm(test_loader, desc="Inference"):
        videos = videos.to(DEVICE)
        logits = model(videos)
        preds = logits.argmax(dim=1)

        # Pair each video id with the class name of its arg-max logit.
        ids_np = video_ids.cpu().numpy()
        preds_np = preds.cpu().numpy()
        predictions.extend(
            (video_id, classes[pred_idx])
            for video_id, pred_idx in zip(ids_np, preds_np)
        )

# Order the (id, class) pairs by video id.
predictions.sort(key=lambda pair: pair[0])
print(f"\nTotal predictions: {len(predictions)}")

## 9. Evaluate with Ground Truth Labels

In [None]:
# Notebook shell escape: run the gdown command-line tool to fetch the file
# with the id below and write it to test_labels.csv.
# NOTE(review): the confirmation below is printed unconditionally, so it also
# appears if the download failed — confirm the file exists before relying on it.
!gdown "1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_" -O test_labels.csv -q
print("Downloaded test_labels.csv")

In [None]:
# Accuracy of the ViT-Small baseline that this run is compared against,
# kept in one place so the printed figure and the delta cannot drift apart.
VIT_SMALL_BASELINE_ACC = 0.6392

# Load ground truth: map video id (as a string) to its class name.
gt_df = pd.read_csv("test_labels.csv")
test_labels = dict(zip(gt_df['id'].astype(str), gt_df['class']))

# Match predictions with ground truth, remembering ids that have no label so
# they are reported instead of silently shrinking the evaluated set.
y_pred = []
y_true = []
unmatched_ids = []
for video_id, pred_class in predictions:
    video_id_str = str(video_id)
    if video_id_str in test_labels:
        y_pred.append(pred_class)
        y_true.append(test_labels[video_id_str])
    else:
        unmatched_ids.append(video_id_str)

if unmatched_ids:
    print(
        f"Warning: {len(unmatched_ids)} prediction(s) have no ground-truth label "
        f"and were skipped (first ids: {unmatched_ids[:5]})"
    )
if not y_true:
    raise ValueError("No prediction id matches an id in test_labels.csv; cannot compute accuracy")

# Calculate accuracy
accuracy = accuracy_score(y_true, y_pred)

print("=" * 50)
print("VIT-BASE TEST SET EVALUATION")
print("=" * 50)
print(f"Total: {len(y_true)} | Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print()
print("Comparison:")
print(f"  ViT-Small (baseline): {VIT_SMALL_BASELINE_ACC * 100:.2f}%")
print(f"  ViT-Base (this):      {accuracy*100:.2f}%")
# The delta is a difference of percentages, i.e. percentage points.
print(f"  Improvement:          {(accuracy - VIT_SMALL_BASELINE_ACC)*100:+.2f}%")
print()
print(classification_report(y_true, y_pred, zero_division=0))

## 10. Save Submission

In [None]:
# Write the sorted (id, class) predictions to the submission CSV, without the
# DataFrame index, then show a short preview.
_submission_columns = ['id', 'class']
submission = pd.DataFrame(data=predictions, columns=_submission_columns)
submission.to_csv('submission_vit_base.csv', index=False)
_row_count = len(submission)
print(f"Saved submission_vit_base.csv ({_row_count} rows)")
print(submission.head())