# Unified Ablation Study - Video Action Recognition

**Toggle ON/OFF để chạy các ablation experiments khác nhau**

## Preset Configs

| Preset | Toggles | Expected Acc (10 epochs)|
|--------|---------|--------------|
| ViT-Small Baseline | `MODEL_TYPE="vit_small"`, all OFF | 69.22% |
| ViT-Base | `MODEL_TYPE="vit_base"`, all OFF | 73.73% |
| VideoMAE Baseline | `MODEL_TYPE="videomae"`, all OFF | 83.92% |
| VideoMAE 8-Frame | `MODEL_TYPE="videomae"`, `NUM_FRAMES=8` | 83.00% @ Phase 1: 50 epochs & Phase 2: 10 epochs |
| Phase 3 (Current) | VideoMAE + CONSISTENT + MIXUP + LABEL_SMOOTHING + TWO_PHASE + FLIP_TTA | **85.10%** @ Phase 1: 30 epochs & Phase 2: 10 epochs |
| Data Balance | VideoMAE + DATA_BALANCE + FOCAL_LOSS | 84.00% |
| Layer Decay | VideoMAE + LAYER_DECAY + MIXUP | 84.51% |

In [None]:
# =================== ABLATION STUDY CONFIGURATION ===================
# Flip the toggles below ON/OFF to enable or disable individual features.

# ===== EXPERIMENT MODE =====
# True = quick pipeline smoke test (all models, a few batches per phase).
WARMUP = False

# ===== MODEL SELECTION (pick exactly one) =====
# Accepted values: "vit_small" | "vit_base" | "videomae"
MODEL_TYPE = "videomae"

# ===== FRAME CONFIG =====
# Frames sampled per video: 8 or 16.
NUM_FRAMES = 16

# ===== DATA SPLIT =====
# 0.9 = 90% train / 10% val. Set to 1.0 to train on the full training set.
TRAIN_VAL_RATIO = 0.9
# True = download test_labels.csv so test accuracy can be computed.
USE_TEST_LABELS = True
# Google Drive file ID of the label file (despite the "_URL" suffix).
TEST_LABELS_URL = '1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_'

# ===== DATA AUGMENTATION TOGGLES =====
USE_DATA_BALANCE = False             # Offline augmentation to balance classes
USE_CONSISTENT_SPATIAL_AUG = True    # Same crop/flip for all frames of a clip
USE_MIXUP = True                     # Mixup augmentation (α=0.8)
USE_FOCAL_LOSS = False               # Focal loss for imbalanced data

# ===== TRAINING TOGGLES =====
USE_LABEL_SMOOTHING = True           # Label smoothing (ε=0.1)
USE_TWO_PHASE = True                 # 2-phase schedule: Mixup → Label Smoothing
USE_LAYER_DECAY = False              # Layer-wise LR decay

# ===== INFERENCE TOGGLES =====
USE_MULTI_SEGMENT = False            # Multi-segment temporal sampling
USE_FLIP_TTA = True                  # 6-view FlipTTA

# ===== TRAINING PARAMS =====
# A warmup run only needs a single epoch per phase.
EPOCHS_P1 = 1 if WARMUP else 30
EPOCHS_P2 = 1 if WARMUP else 10
WARMUP_BATCHES = 5                   # Batches per phase when WARMUP=True
BATCH_SIZE = 8
ACCUM_STEPS = 4
LR_P1 = 5e-5
LR_P2 = 1e-6
WEIGHT_DECAY = 0.05
WARMUP_RATIO = 0.1
MIXUP_ALPHA = 0.8
LABEL_SMOOTHING_EPS = 0.1

# ===== PATHS (Kaggle) =====
PATH_DATA_TRAIN = '/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = '/kaggle/input/action-video/data/test'

# Echo the active configuration, one entry per line.
print(
    "=" * 60,
    "ABLATION CONFIGURATION",
    "=" * 60,
    f"WARMUP:            {WARMUP}",
    f"MODEL_TYPE:        {MODEL_TYPE}",
    f"NUM_FRAMES:        {NUM_FRAMES}",
    f"TRAIN_VAL_RATIO:   {TRAIN_VAL_RATIO}",
    f"USE_TEST_LABELS:   {USE_TEST_LABELS}",
    f"DATA_BALANCE:      {USE_DATA_BALANCE}",
    f"CONSISTENT_AUG:    {USE_CONSISTENT_SPATIAL_AUG}",
    f"MIXUP:             {USE_MIXUP}",
    f"FOCAL_LOSS:        {USE_FOCAL_LOSS}",
    f"LABEL_SMOOTHING:   {USE_LABEL_SMOOTHING}",
    f"TWO_PHASE:         {USE_TWO_PHASE}",
    f"LAYER_DECAY:       {USE_LAYER_DECAY}",
    f"MULTI_SEGMENT:     {USE_MULTI_SEGMENT}",
    f"FLIP_TTA:          {USE_FLIP_TTA}",
    f"EPOCHS:            P1={EPOCHS_P1}, P2={EPOCHS_P2}",
    sep="\n",
)

In [None]:
# Imports
import os
import gc
import random
import shutil
import warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import timm
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from transformers import get_cosine_schedule_with_warmup
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Seed
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once, before any data or model code runs.
seed_everything(42)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f'Device: {DEVICE}')

# Constants
# Frames are resized to RESIZE_SIZE and then cropped to IMG_SIZE.
IMG_SIZE, RESIZE_SIZE = 224, 256
# Per-channel normalisation statistics (the usual ImageNet RGB mean / std).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

## 1. Model Factory

In [None]:
class LightweightViTForAction(nn.Module):
    """Frame-level ViT classifier: encode every frame, average over time, classify.

    The backbone comes from ``timm.create_model(..., pretrained=True,
    num_classes=0)``, so it returns feature vectors; a single linear layer maps
    the time-averaged feature to ``num_classes`` logits.
    """

    def __init__(self, num_classes=51, model_name='vit_small_patch16_224'):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        self.vit = backbone
        self.embed_dim = backbone.num_features
        self.head = nn.Linear(self.embed_dim, num_classes)

    def forward(self, video):
        """Map a ``[B, T, C, H, W]`` clip batch to ``[B, num_classes]`` logits."""
        batch, frames, channels, height, width = video.shape
        # Fold time into the batch axis so the image backbone sees single frames.
        flat_frames = video.view(batch * frames, channels, height, width)
        per_frame = self.vit(flat_frames).view(batch, frames, self.embed_dim)
        # Temporal average pooling, then classification.
        return self.head(per_frame.mean(dim=1))


def create_model(model_type, num_classes, num_frames):
    """Build the model selected by ``model_type`` and move it to ``DEVICE``.

    Args:
        model_type: 'vit_small', 'vit_base' or 'videomae'.
        num_classes: Size of the classification head.
        num_frames: Forwarded to the VideoMAE config; unused by the ViT models.

    Returns:
        The constructed model, already on ``DEVICE``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Reject unsupported names before any weights are loaded.
    if model_type not in ('vit_small', 'vit_base', 'videomae'):
        raise ValueError(f'Unknown model_type: {model_type}')

    if model_type == 'videomae':
        # The checkpoint's classifier has a different size, hence
        # ignore_mismatched_sizes=True so a new head of num_classes is created.
        model = VideoMAEForVideoClassification.from_pretrained(
            'MCG-NJU/videomae-base-finetuned-kinetics',
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
            num_frames=num_frames
        )
        print(f'Created VideoMAE model ({sum(p.numel() for p in model.parameters()):,} params)')
    elif model_type == 'vit_small':
        model = LightweightViTForAction(num_classes, 'vit_small_patch16_224')
        print(f'Created ViT-Small model ({sum(p.numel() for p in model.parameters()):,} params)')
    else:
        model = LightweightViTForAction(num_classes, 'vit_base_patch16_224')
        print(f'Created ViT-Base model ({sum(p.numel() for p in model.parameters()):,} params)')

    return model.to(DEVICE)

print('Model factory defined')

## 2. Dataset Classes

In [None]:
class VideoDataset(Dataset):
    """Training dataset of frame folders laid out as ``root/<class>/<video>/*.jpg``.

    Each item is ``(clip, label)`` where ``clip`` stacks ``num_frames`` frames
    sampled evenly across the video and ``label`` is the class index.

    Args:
        root: Directory containing one sub-directory per class.
        num_frames: Number of frames sampled per video.
        consistent_aug: If True, one random resized crop and one random
            horizontal flip are drawn per clip and applied to every frame.
            If False, frames are only center-cropped (no random augmentation).
    """
    def __init__(self, root, num_frames=16, consistent_aug=True):
        self.root = Path(root)
        self.num_frames = num_frames
        self.consistent_aug = consistent_aug

        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        # iterdir() yields entries in filesystem order, which differs between
        # machines; sorting keeps sample indices (and any index-based split)
        # reproducible.
        self.samples = [
            (vid, self.class_to_idx[cls])
            for cls in self.classes
            for vid in sorted((self.root / cls).iterdir())
            if vid.is_dir()
        ]
        print(f'Loaded {len(self.samples)} videos, {len(self.classes)} classes')

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_frame(path):
        """Read one frame as RGB, closing the file handle, resized to RESIZE_SIZE."""
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return TF.resize(rgb, RESIZE_SIZE)

    def __getitem__(self, idx):
        """Return ``(clip, label)`` for sample ``idx``.

        Raises:
            FileNotFoundError: If the video folder contains no ``*.jpg`` frame.
        """
        vid_dir, label = self.samples[idx]
        files = sorted(vid_dir.glob('*.jpg'))
        if not files:
            # Without this guard the failure is an IndexError with no path in it.
            raise FileNotFoundError(f'No .jpg frames found in {vid_dir}')
        # Evenly spaced frame indices; short videos repeat frames.
        indices = torch.linspace(0, len(files) - 1, self.num_frames).long().tolist()

        frames = [self._load_frame(files[i]) for i in indices]

        if self.consistent_aug:
            # Draw the crop box and flip decision once, then reuse them for
            # every frame so the clip stays spatially coherent.
            i, j, h, w = T.RandomResizedCrop.get_params(frames[0], (0.8, 1.0), (0.75, 1.33))
            do_flip = random.random() > 0.5
            processed = []
            for img in frames:
                img = TF.resized_crop(img, i, j, h, w, (IMG_SIZE, IMG_SIZE))
                if do_flip:
                    img = TF.hflip(img)
                processed.append(TF.normalize(TF.to_tensor(img), MEAN, STD))
        else:
            processed = [
                TF.normalize(TF.to_tensor(TF.center_crop(img, IMG_SIZE)), MEAN, STD)
                for img in frames
            ]

        return torch.stack(processed), label


class TestDatasetSingle(Dataset):
    """Test-time dataset producing one center-cropped clip per video.

    ``root`` holds one folder per video whose name is the integer video id;
    items are ``(clip, video_id)`` ordered by ascending id.
    """
    def __init__(self, root, num_frames=16):
        self.root = Path(root)
        self.num_frames = num_frames
        video_dirs = [entry for entry in self.root.iterdir() if entry.is_dir()]
        # Folder names are numeric ids; order the samples by that id.
        self.samples = sorted(
            ((entry, int(entry.name)) for entry in video_dirs),
            key=lambda pair: pair[1],
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        frame_paths = sorted(vid_dir.glob('*.jpg'))
        # Evenly spaced frame positions across the whole video.
        picks = torch.linspace(0, len(frame_paths)-1, self.num_frames).long()

        def load(path):
            # resize -> center crop -> tensor -> normalise
            image = Image.open(path).convert('RGB')
            image = TF.center_crop(TF.resize(image, RESIZE_SIZE), IMG_SIZE)
            return TF.normalize(TF.to_tensor(image), MEAN, STD)

        clip = torch.stack([load(frame_paths[k]) for k in picks])
        return clip, vid_id


class TestDatasetFlipTTA(Dataset):
    """6-view TTA: 3 spatial crops × 2 flip states.

    ``root`` holds one folder per video whose name is the integer video id.
    Each item is ``(views, video_id)`` with ``views`` of shape
    ``[6, T, C, H, W]`` ordered as center, center-flipped, top-left,
    top-left-flipped, bottom-right, bottom-right-flipped.
    """
    def __init__(self, root, num_frames=16):
        self.root = Path(root)
        self.num_frames = num_frames
        self.samples = [(d, int(d.name)) for d in self.root.iterdir() if d.is_dir()]
        # Order by numeric video id.
        self.samples.sort(key=lambda x: x[1])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        files = sorted(vid_dir.glob('*.jpg'))
        # Evenly spaced frame indices across the whole video.
        indices = torch.linspace(0, len(files)-1, self.num_frames).long()

        frames = [TF.resize(Image.open(files[i]).convert('RGB'), RESIZE_SIZE) for i in indices]

        w, h = frames[0].size

        # 3 spatial crops, given as (top, left)
        crop_positions = [
            ((h - IMG_SIZE) // 2, (w - IMG_SIZE) // 2),  # center
            (0, 0),  # top-left
            (max(0, h - IMG_SIZE), max(0, w - IMG_SIZE))  # bottom-right
        ]

        views = []
        for top, left in crop_positions:
            view = torch.stack([
                TF.normalize(TF.to_tensor(TF.crop(img, top, left, IMG_SIZE, IMG_SIZE)), MEAN, STD)
                for img in frames
            ])
            views.append(view)
            # Normalisation is per channel, so mirroring the finished tensor
            # along its width axis equals flipping each image before conversion,
            # without cropping and converting every frame a second time.
            views.append(view.flip(-1))

        return torch.stack(views), vid_id  # [6, T, C, H, W]

print('Dataset classes defined')

## 3. Augmentation Utilities

In [None]:
class MixupCollate:
    """Collate callable that applies Mixup to a batch and returns soft labels.

    Args:
        num_classes: Width of the one-hot label vectors.
        alpha: Both parameters of the Beta distribution the mixing weight is drawn from.
        prob: Probability of mixing a batch; otherwise plain one-hot labels are returned.
    """
    def __init__(self, num_classes, alpha=0.8, prob=1.0):
        self.num_classes, self.alpha, self.prob = num_classes, alpha, prob

    def __call__(self, batch):
        clips, labels = torch.utils.data.default_collate(batch)
        soft_labels = F.one_hot(labels, num_classes=self.num_classes).float()

        # With probability (1 - prob) leave the batch untouched.
        skip_mixing = np.random.rand() > self.prob
        if skip_mixing:
            return clips, soft_labels

        # Pair every sample with a randomly chosen partner from the same batch
        # and blend inputs and labels with the same weight.
        perm = torch.randperm(clips.size(0))
        lam = np.random.beta(self.alpha, self.alpha)
        mixed_clips = lam * clips + (1 - lam) * clips[perm, :]
        mixed_labels = lam * soft_labels + (1 - lam) * soft_labels[perm, :]
        return mixed_clips, mixed_labels


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Mean focal loss: ``alpha * (1 - p_t) ** gamma * CE`` averaged over the batch.

    ``p_t`` is recovered from the per-sample cross entropy as ``exp(-CE)``.
    """
    per_sample_ce = F.cross_entropy(logits, targets, reduction='none')
    prob_true = (-per_sample_ce).exp()
    # Down-weight samples the model already classifies confidently.
    weighted = alpha * (1 - prob_true).pow(gamma) * per_sample_ce
    return weighted.mean()

print('Augmentation utilities defined')

## 4. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, scaler, 
                use_mixup=False, use_focal=False, label_smoothing=0.0,
                warmup_mode=False, warmup_batches=5):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Gradients are accumulated over the module-level ``ACCUM_STEPS`` batches
    before each optimizer (and scheduler) step; gradients are clipped to norm
    1.0. Batches left over after the last full group do not trigger a step
    (their gradients are cleared by the ``zero_grad`` at the start of the next
    call).

    Loss selection (first match wins):
      * ``use_mixup`` - soft-target cross entropy; ``targets`` must be rows of
        class probabilities, as produced by ``MixupCollate``.
      * ``use_focal`` - ``focal_loss`` on integer class targets.
      * otherwise     - ``F.cross_entropy`` with ``label_smoothing``.

    Args:
        model: Network returning logits, or an object with a ``.logits`` attribute.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: Optional LR scheduler, stepped once per optimizer step.
        scaler: AMP gradient scaler.
        use_mixup: Select the soft-target loss (see above).
        use_focal: Select focal loss (ignored when ``use_mixup`` is set).
        label_smoothing: Smoothing for the plain cross-entropy branch only.
        warmup_mode: If True, stop after ``warmup_batches`` batches.
        warmup_batches: Batch budget used when ``warmup_mode`` is True.

    Returns:
        ``(mean_loss, accuracy)`` over the batches actually processed;
        ``(0.0, 0.0)`` if no batch was processed.
    """
    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    # Counted explicitly: after the warmup ``break`` the loop index is one past
    # the last processed batch, and it is undefined for an empty loader.
    num_batches = 0
    pbar = tqdm(loader, desc='Training', leave=False)
    optimizer.zero_grad()

    for step, (inputs, targets) in enumerate(pbar):
        if warmup_mode and step >= warmup_batches:
            break

        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        with torch.amp.autocast('cuda'):
            # Get logits (handle VideoMAE output)
            output = model(inputs)
            logits = output.logits if hasattr(output, 'logits') else output

            # Compute loss based on toggles
            if use_mixup:
                log_probs = F.log_softmax(logits, dim=1)
                loss = -torch.sum(targets * log_probs, dim=1).mean()
                # Accuracy is measured against the dominant class of the mix.
                true_labels = targets.argmax(dim=1)
            elif use_focal:
                loss = focal_loss(logits, targets)
                true_labels = targets
            else:
                loss = F.cross_entropy(logits, targets, label_smoothing=label_smoothing)
                true_labels = targets

        total_correct += (logits.argmax(1) == true_labels).sum().item()
        total_samples += inputs.size(0)

        # Scale so the accumulated gradient is the mean over ACCUM_STEPS batches.
        scaler.scale(loss / ACCUM_STEPS).backward()

        if (step + 1) % ACCUM_STEPS == 0:
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if scheduler:
                scheduler.step()

        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{total_loss/num_batches:.4f}', 'acc': f'{total_correct/max(total_samples,1):.4f}'})

    return total_loss / max(num_batches, 1), total_correct / max(total_samples, 1)


@torch.no_grad()
def evaluate(model, loader, classes, gt_dict, use_tta=False, warmup_mode=False, warmup_batches=5):
    """Predict on the test loader and score against ``gt_dict``.

    Args:
        model: Network returning logits, or an object with a ``.logits`` attribute.
        loader: Yields ``(videos, video_ids)``; with ``use_tta`` the videos are
            ``[B, V, T, C, H, W]`` and view logits are averaged per video.
        classes: Maps a predicted class index to its class name.
        gt_dict: Maps ``str(video_id)`` to the true class name, or None.
        use_tta: Whether batches carry multiple views per video.
        warmup_mode: If True, stop after ``warmup_batches`` batches and skip scoring.
        warmup_batches: Batch budget used when ``warmup_mode`` is True.

    Returns:
        0.0 in warmup mode, None when ``gt_dict`` is None, otherwise the accuracy.
    """
    model.eval()
    id_pred_pairs = []

    for batch_idx, (clips, clip_ids) in enumerate(tqdm(loader, desc='Evaluating', leave=False)):
        if warmup_mode and batch_idx >= warmup_batches:
            break

        if use_tta:
            # Fold the view axis into the batch axis for a single forward pass.
            n_videos, n_views, n_frames, ch, height, width = clips.shape
            clips = clips.view(n_videos * n_views, n_frames, ch, height, width)

        out = model(clips.to(DEVICE))
        scores = out.logits if hasattr(out, 'logits') else out

        if use_tta:
            # Average the logits of all views belonging to the same video.
            scores = scores.view(n_videos, n_views, -1).mean(dim=1)

        id_pred_pairs.extend(zip(clip_ids.tolist(), scores.argmax(1).cpu().tolist()))

    # Warmup only checks that the pipeline runs; no accuracy is computed.
    if warmup_mode:
        return 0.0

    # No test labels available.
    if gt_dict is None:
        return None

    truth = [gt_dict[str(vid)] for vid, _ in id_pred_pairs]
    predicted = [classes[p] for _, p in id_pred_pairs]
    return accuracy_score(truth, predicted)


@torch.no_grad()
def evaluate_val(model, loader):
    """Return the accuracy of ``model`` on a labelled loader (0.0 if it yields no samples)."""
    model.eval()
    n_correct = n_seen = 0
    for clips, targets in tqdm(loader, desc='Val Eval', leave=False):
        clips = clips.to(DEVICE)
        targets = targets.to(DEVICE)
        out = model(clips)
        # VideoMAE wraps its logits in an output object; the ViT models return them directly.
        scores = out.logits if hasattr(out, 'logits') else out
        n_correct += (scores.argmax(1) == targets).sum().item()
        n_seen += targets.size(0)
    if n_seen > 0:
        return n_correct / n_seen
    return 0.0

print('Training functions defined')

## 5. WARMUP Mode (Test All Pipelines)

In [None]:
def run_warmup_test():
    """Test ALL models to ensure pipeline works."""
    print("\n" + "="*60)
    print("WARMUP MODE: Testing all model pipelines")
    print("="*60)
    
    # Small dataset for testing
    train_ds = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, USE_CONSISTENT_SPATIAL_AUG)
    test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)
    
    # WARMUP uses num_workers=0 to avoid multiprocessing warnings
    train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True)
    test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=0)
    
    # Ground truth
    !gdown "1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_" -O test_labels.csv -q
    gt_df = pd.read_csv('test_labels.csv')
    gt_dict = dict(zip(gt_df['id'].astype(str), gt_df['class']))
    
    WARMUP_MODELS = ['vit_small', 'vit_base', 'videomae']
    
    for model_type in WARMUP_MODELS:
        print(f"\n{'='*50}")
        print(f"WARMUP: Testing {model_type}")
        print(f"{'='*50}")
        
        try:
            model = create_model(model_type, len(full_ds.classes), NUM_FRAMES)
            optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
            scaler = torch.amp.GradScaler()
            
            # Phase 1 - NO mixup in warmup (just test forward/backward pass)
            loss, acc = train_epoch(model, train_loader, optimizer, None, scaler,
                                   use_mixup=False, warmup_mode=True, warmup_batches=WARMUP_BATCHES)
            print(f"✓ Phase 1 OK (loss={loss:.4f}, acc={acc:.4f})")
            
            # Phase 2 - test label smoothing
            loss, acc = train_epoch(model, train_loader, optimizer, None, scaler,
                                   label_smoothing=LABEL_SMOOTHING_EPS, warmup_mode=True, warmup_batches=WARMUP_BATCHES)
            print(f"✓ Phase 2 OK (loss={loss:.4f}, acc={acc:.4f})")
            
            # Eval
            evaluate(model, test_loader, full_ds.classes, gt_dict, warmup_mode=True, warmup_batches=WARMUP_BATCHES)
            print(f"✓ Eval OK")
            
            del model
            torch.cuda.empty_cache()
            gc.collect()
            
        except Exception as e:
            print(f"✗ FAILED: {e}")
            raise
    
    print("\n" + "="*60)
    print("✅ WARMUP COMPLETE - All models pipeline OK!")
    print("="*60)

if WARMUP:
    run_warmup_test()

## 6. Main Training

In [None]:
if not WARMUP:
    # Load full dataset
    full_ds = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, USE_CONSISTENT_SPATIAL_AUG)
    
    # Split train/val if ratio < 1.0
    if TRAIN_VAL_RATIO < 1.0:
        train_size = int(len(full_ds) * TRAIN_VAL_RATIO)
        val_size = len(full_ds) - train_size
        train_ds, val_ds = torch.utils.data.random_split(full_ds, [train_size, val_size])
        print(f'Split: {train_size} train, {val_size} val')
    else:
        train_ds = full_ds
        val_ds = None
        print('Using full dataset for training (no validation split)')
    
    if USE_FLIP_TTA:
        test_ds = TestDatasetFlipTTA(PATH_DATA_TEST, NUM_FRAMES)
    else:
        test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)
    
    # DataLoaders
    if USE_MIXUP:
        mixup_collate = MixupCollate(len(full_ds.classes), MIXUP_ALPHA)
        train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, 
                                  collate_fn=mixup_collate, drop_last=True)
    else:
        train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)
    
    if val_ds:
        val_loader = DataLoader(val_ds, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)
    
    test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, persistent_workers=True)
    
    # Ground truth for test accuracy (optional)
    gt_dict = None
    if USE_TEST_LABELS:
        !gdown "{TEST_LABELS_URL}" -O test_labels.csv -q
        gt_df = pd.read_csv('test_labels.csv')
        gt_dict = dict(zip(gt_df['id'].astype(str), gt_df['class']))
        print(f'Loaded test labels: {len(gt_dict)} samples')
    else:
        print('Test labels disabled - test accuracy will not be calculated')
    
    # Create model
    model = create_model(MODEL_TYPE, len(full_ds.classes), NUM_FRAMES)
    
    print(f'\nTrain samples: {len(train_ds)}')
    if val_ds:
        print(f'Val samples: {len(val_ds)}')
    print(f'Test samples: {len(test_ds)}')


In [None]:
if not WARMUP:
    # Visualize original class distribution
    print('\n' + '='*60)
    print('ORIGINAL CLASS DISTRIBUTION')
    print('='*60)
    
    # Videos per class index, in class-index order
    class_counts = Counter(label for _, label in full_ds.samples)
    classes_names = full_ds.classes
    counts = [class_counts.get(i, 0) for i in range(len(classes_names))]
    mean_count = np.mean(counts)
    
    # Sort by count (largest class first)
    sorted_pairs = sorted(zip(classes_names, counts), key=lambda x: x[1], reverse=True)
    sorted_names = [p[0] for p in sorted_pairs]
    sorted_counts = [p[1] for p in sorted_pairs]
    
    plt.figure(figsize=(16, 6))
    # Green: at or above the mean, orange: at least half the mean, red: below that.
    colors = ['darkgreen' if c >= mean_count else 'orange' if c >= mean_count*0.5 else 'darkred' for c in sorted_counts]
    plt.bar(range(len(sorted_names)), sorted_counts, color=colors, alpha=0.7)
    plt.axhline(y=mean_count, color='blue', linestyle='--', linewidth=2, label=f'Mean ({mean_count:.1f})')
    plt.xlabel('Action Category', fontsize=12)
    plt.ylabel('Sample Count', fontsize=12)
    plt.title('Original Training Data Distribution (Before Augmentation)', fontsize=14, fontweight='bold')
    plt.xticks(range(len(sorted_names)), sorted_names, rotation=90, ha='right', fontsize=8)
    plt.legend()
    plt.tight_layout()
    plt.savefig('original_distribution.png', dpi=150)
    plt.show()
    
    print(f'Total samples: {sum(counts)}')
    print(f'Min class: {sorted_names[-1]} ({sorted_counts[-1]} samples)')
    print(f'Max class: {sorted_names[0]} ({sorted_counts[0]} samples)')
    # max(1, ...) guards against a class folder that contains no videos.
    print(f'Imbalance ratio: {sorted_counts[0]/max(1, sorted_counts[-1]):.2f}x')

In [None]:
if not WARMUP and USE_DATA_BALANCE:
    # Visualize balanced class distribution after augmentation.
    # Relies on `sorted_names` / `sorted_counts` from the original-distribution cell.
    # NOTE(review): nothing visible in this notebook changes full_ds.samples when
    # USE_DATA_BALANCE is on, so the "after" panel may simply mirror "before" -
    # confirm that the balancing step exists.
    print('\n' + '='*60)
    print('BALANCED CLASS DISTRIBUTION (After Augmentation)')
    print('='*60)
    
    # Current per-class counts, listed in the same order as the "before" panel
    balanced_class_counts = Counter(label for _, label in full_ds.samples)
    sorted_balanced = [balanced_class_counts.get(full_ds.class_to_idx[n], 0) for n in sorted_names]
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    
    # Left: before, right: after - identical styling apart from colour and title
    panels = (
        (sorted_counts, 'coral', 'Before Balance Augmentation'),
        (sorted_balanced, 'steelblue', 'After Balance Augmentation'),
    )
    for ax, (values, bar_color, title) in zip(axes, panels):
        ax.bar(range(len(sorted_names)), values, color=bar_color, alpha=0.7)
        ax.axhline(y=np.mean(values), color='blue', linestyle='--', linewidth=2)
        ax.set_title(title, fontweight='bold')
        ax.set_xlabel('Class')
        ax.set_ylabel('Count')
        ax.set_xticks(range(len(sorted_names)))
        ax.set_xticklabels(sorted_names, rotation=90, fontsize=6)
    
    plt.tight_layout()
    plt.savefig('balanced_distribution.png', dpi=150)
    plt.show()
    
    print(f'Balanced total samples: {sum(sorted_balanced)}')
    print(f'New imbalance ratio: {max(sorted_balanced)/max(1, min(sorted_balanced)):.2f}x')

In [None]:
if not WARMUP:
    history = []
    best_acc = 0.0
    scaler = torch.amp.GradScaler()
    
    # Remove checkpoints left behind by an earlier run (possibly of a different
    # model) so that neither Phase 2 nor the final evaluation can load them.
    for stale_ckpt in ('best_p1.pt', 'best_final.pt'):
        Path(stale_ckpt).unlink(missing_ok=True)
    
    # PHASE 1
    print('\n' + '='*50)
    print(f'PHASE 1: Epochs={EPOCHS_P1}, LR={LR_P1}')
    if USE_MIXUP:
        print('  Mixup: ON')
    if USE_FOCAL_LOSS:
        print('  Focal Loss: ON')
    print('='*50)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
    # The scheduler advances once per optimizer step, i.e. once per ACCUM_STEPS batches.
    total_steps = len(train_loader) * EPOCHS_P1 // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    
    for epoch in range(1, EPOCHS_P1 + 1):
        loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, scaler,
                                      use_mixup=USE_MIXUP, use_focal=USE_FOCAL_LOSS)
        
        # Evaluate val if available
        val_acc = evaluate_val(model, val_loader) if val_ds else None
        # Evaluate test if labels available
        test_acc = evaluate(model, test_loader, full_ds.classes, gt_dict, use_tta=USE_FLIP_TTA)
        
        history.append({'epoch': epoch, 'phase': 1, 'loss': loss, 'train_acc': train_acc, 'val_acc': val_acc, 'test_acc': test_acc})
        
        # Determine best metric (test_acc preferred, else val_acc)
        # NOTE(review): choosing checkpoints by test accuracy makes the reported
        # test score optimistic; consider selecting on val_acc only.
        current_metric = test_acc if test_acc is not None else val_acc
        status = ''
        improved = current_metric is not None and current_metric > best_acc
        if improved:
            best_acc = current_metric
            status = '>>> BEST'
        # Save on improvement; without any metric keep the latest weights; and
        # always make sure a checkpoint exists so Phase 2 has something to load
        # even when the metric never rises above 0.
        if improved or current_metric is None or not os.path.exists('best_p1.pt'):
            torch.save(model.state_dict(), 'best_p1.pt')
        
        # Build log message
        msg = f'Ep {epoch}/{EPOCHS_P1}: L={loss:.4f} TrAcc={train_acc:.4f}'
        if val_acc is not None:
            msg += f' ValAcc={val_acc:.4f}'
        if test_acc is not None:
            msg += f' TestAcc={test_acc:.4f}'
        print(f'{msg} {status}')
        
        gc.collect()
        torch.cuda.empty_cache()
    
    # Persist the history here as well: the Phase 2 cell (which rewrites this
    # file with both phases) is skipped when two-phase training is disabled.
    pd.DataFrame(history).to_csv('training_history.csv', index=False)

In [None]:
if not WARMUP and USE_TWO_PHASE and EPOCHS_P2 > 0:
    # Honour the USE_LABEL_SMOOTHING toggle (0.0 disables smoothing).
    p2_label_smoothing = LABEL_SMOOTHING_EPS if USE_LABEL_SMOOTHING else 0.0
    
    # PHASE 2
    print('\n' + '='*50)
    print(f'PHASE 2: Epochs={EPOCHS_P2}, LR={LR_P2}')
    print(f'  Label Smoothing: {p2_label_smoothing}')
    print('='*50)
    
    # Continue from the best Phase 1 weights.
    model.load_state_dict(torch.load('best_p1.pt', map_location=DEVICE))
    
    # New loader without mixup
    train_loader_p2 = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)
    
    # Fresh optimizer, schedule and scaler for the low-LR fine-tuning phase.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P2, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader_p2) * EPOCHS_P2 // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    scaler = torch.amp.GradScaler()
    
    for epoch in range(1, EPOCHS_P2 + 1):
        loss, train_acc = train_epoch(model, train_loader_p2, optimizer, scheduler, scaler,
                                      label_smoothing=p2_label_smoothing)
        
        # Evaluate val if available
        val_acc = evaluate_val(model, val_loader) if val_ds else None
        # Evaluate test if labels available
        test_acc = evaluate(model, test_loader, full_ds.classes, gt_dict, use_tta=USE_FLIP_TTA)
        
        history.append({'epoch': EPOCHS_P1 + epoch, 'phase': 2, 'loss': loss, 'train_acc': train_acc, 'val_acc': val_acc, 'test_acc': test_acc})
        
        # Determine best metric (test_acc preferred, else val_acc).
        # best_acc carries over from Phase 1, so best_final.pt is only written
        # when Phase 2 actually beats the Phase 1 result.
        current_metric = test_acc if test_acc is not None else val_acc
        status = ''
        if current_metric is not None and current_metric > best_acc:
            best_acc = current_metric
            torch.save(model.state_dict(), 'best_final.pt')
            status = '>>> BEST'
        elif current_metric is None:
            # No metric to compare: keep the latest weights.
            torch.save(model.state_dict(), 'best_final.pt')
        
        # Build log message
        msg = f'P2 Ep {epoch}/{EPOCHS_P2}: L={loss:.4f} TrAcc={train_acc:.4f}'
        if val_acc is not None:
            msg += f' ValAcc={val_acc:.4f}'
        if test_acc is not None:
            msg += f' TestAcc={test_acc:.4f}'
        print(f'{msg} {status}')
        
        gc.collect()
        torch.cuda.empty_cache()
    
    # Save history
    df_history = pd.DataFrame(history)
    df_history.to_csv('training_history.csv', index=False)
    if best_acc > 0:
        print(f'\nTraining Complete! Best Acc: {best_acc:.4f}')
    else:
        print(f'\nTraining Complete!')

## 7. Visualization & Analysis

In [None]:
if not WARMUP:
    # Training curves
    df = pd.DataFrame(history)
    
    # test_acc is None for every epoch when test labels are disabled; fall back
    # to validation accuracy in that case instead of plotting None values.
    if df['test_acc'].notna().any():
        acc_col, acc_title = 'test_acc', 'Test Accuracy'
    else:
        acc_col, acc_title = 'val_acc', 'Val Accuracy'
    
    # (history column, panel title, line format)
    panels = (
        (acc_col, acc_title, 'b-o'),
        ('train_acc', 'Train Accuracy', 'g-s'),
        ('loss', 'Loss', 'r-^'),
    )
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (column, title, line_fmt) in zip(axes, panels):
        # errors='coerce' turns missing entries into NaN, which matplotlib skips.
        ax.plot(df['epoch'], pd.to_numeric(df[column], errors='coerce'), line_fmt, markersize=4)
        if USE_TWO_PHASE:
            # Mark the boundary between Phase 1 and Phase 2.
            ax.axvline(x=EPOCHS_P1, color='gray', linestyle='--', alpha=0.5)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=150)
    plt.show()

In [None]:
if not WARMUP:
    # The analysis below compares predictions with ground truth, so fail early
    # with a clear message instead of a TypeError after inference has started.
    if gt_dict is None:
        raise RuntimeError(
            'Test labels are required for the analysis cells: '
            'set USE_TEST_LABELS = True and re-run the setup cell.')
    
    # Load best model: prefer the Phase 2 checkpoint, fall back to Phase 1;
    # if neither file exists the current in-memory weights are evaluated.
    for ckpt in ('best_final.pt', 'best_p1.pt'):
        if os.path.exists(ckpt):
            model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
            break
    model.eval()
    
    # Get all predictions (single center-crop view, no TTA)
    all_preds, all_true = [], []
    simple_test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)
    # This loader is iterated exactly once, so its workers need not persist.
    simple_loader = DataLoader(simple_test_ds, BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    
    with torch.no_grad():
        for videos, video_ids in tqdm(simple_loader, desc='Final Eval'):
            videos = videos.to(DEVICE)
            output = model(videos)
            logits = output.logits if hasattr(output, 'logits') else output
            preds = logits.argmax(1).cpu().tolist()
            for vid, pred in zip(video_ids.tolist(), preds):
                # Both lists hold class names, aligned by position.
                all_true.append(gt_dict[str(vid)])
                all_preds.append(full_ds.classes[pred])
    
    overall_acc = accuracy_score(all_true, all_preds)
    print(f'\n{"="*60}')
    print(f'OVERALL TEST ACCURACY: {overall_acc:.4f} ({overall_acc*100:.2f}%)')
    print(f'{"="*60}')

In [None]:
if not WARMUP:
    # Per-class accuracy (the recall of a class is its per-class accuracy).
    # labels= pins the label set and order to full_ds.classes; without it the
    # report infers labels from the data, and target_names no longer line up
    # (or the call fails) when a class is missing from both lists.
    report_dict = classification_report(all_true, all_preds, labels=full_ds.classes,
                                        target_names=full_ds.classes,
                                        output_dict=True, zero_division=0)
    
    class_accs = [(cls, report_dict[cls]['recall'] * 100) for cls in full_ds.classes]
    class_accs_sorted = sorted(class_accs, key=lambda x: x[1], reverse=True)
    
    cls_names = [c[0] for c in class_accs_sorted]
    cls_accs_vals = [c[1] for c in class_accs_sorted]
    
    plt.figure(figsize=(16, 6))
    # Green: >= 90%, orange: >= 70%, red: below 70%.
    colors = ['darkgreen' if acc >= 90 else 'orange' if acc >= 70 else 'darkred' for acc in cls_accs_vals]
    plt.bar(range(len(cls_names)), cls_accs_vals, color=colors, alpha=0.7)
    plt.axhline(y=overall_acc*100, color='blue', linestyle='--', linewidth=2, label=f'Overall ({overall_acc*100:.2f}%)')
    plt.xlabel('Action Category', fontsize=12)
    plt.ylabel('Accuracy (%)', fontsize=12)
    plt.title('Per-Class Test Accuracy', fontsize=14, fontweight='bold')
    plt.xticks(range(len(cls_names)), cls_names, rotation=90, ha='right', fontsize=8)
    plt.legend()
    plt.tight_layout()
    plt.savefig('per_class_accuracy.png', dpi=150)
    plt.show()

In [None]:
if not WARMUP:
    # Compare how often each class occurs in the labels vs. in the predictions.
    classes_sorted = sorted(full_ds.classes)
    gt_counts, pred_counts = Counter(all_true), Counter(all_preds)
    gt_vals = [gt_counts[name] for name in classes_sorted]
    pred_vals = [pred_counts[name] for name in classes_sorted]
    
    # Two bars per class, placed side by side around each tick.
    width = 0.35
    x = np.arange(len(classes_sorted))
    
    fig, ax = plt.subplots(figsize=(16, 6))
    for offset, values, series_label, bar_color in (
        (-width/2, gt_vals, 'Ground Truth', 'steelblue'),
        (width/2, pred_vals, 'Predictions', 'coral'),
    ):
        ax.bar(x + offset, values, width, label=series_label, alpha=0.7, color=bar_color)
    
    ax.set_xlabel('Action Category')
    ax.set_ylabel('Count')
    ax.set_title('Prediction Distribution vs Ground Truth (detect model bias)', fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(classes_sorted, rotation=90, fontsize=7)
    ax.legend()
    plt.tight_layout()
    plt.savefig('prediction_distribution.png', dpi=150)
    plt.show()

In [None]:
if not WARMUP:
    # Rank the most frequent misclassifications (off-diagonal confusion-matrix cells).
    cm = confusion_matrix(all_true, all_preds, labels=full_ds.classes)
    
    # (true class, predicted class, count) for every non-zero off-diagonal cell
    confusions = [
        (true_name, pred_name, cm[row, col])
        for row, true_name in enumerate(full_ds.classes)
        for col, pred_name in enumerate(full_ds.classes)
        if row != col and cm[row, col] > 0
    ]
    confusions = sorted(confusions, key=lambda item: item[2], reverse=True)
    
    print('\n' + '='*70)
    print('TOP-15 CONFUSION PAIRS (True → Predicted)')
    print('='*70)
    for i, (true_cls, pred_cls, count) in enumerate(confusions[:15]):
        print(f'{i+1:2d}. {true_cls:20s} → {pred_cls:20s} ({int(count):2d} errors)')

In [None]:
if not WARMUP:
    # Recap of the run: model, best tracked accuracy and the active toggles.
    print(
        '\n' + '='*70,
        'FINAL SUMMARY',
        '='*70,
        f'Model: {MODEL_TYPE}',
        f'Best Test Accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)',
        f'\nConfiguration:',
        f'  Consistent Spatial Aug: {USE_CONSISTENT_SPATIAL_AUG}',
        f'  Mixup: {USE_MIXUP}',
        f'  Label Smoothing: {USE_LABEL_SMOOTHING}',
        f'  Two Phase: {USE_TWO_PHASE}',
        f'  Flip TTA: {USE_FLIP_TTA}',
        sep='\n',
    )