# Incremental Ablation Study - VideoMAE

**Chiến lược**: Tích hợp từng kỹ thuật một cách tuần tự để đo lường cải thiện tích lũy.

| Exp | Name | Techniques | Expected Acc |
|-----|------|------------|--------------|
| 1 | Baseline | VideoMAE (LR=5e-5) | ~83.92% |
| 2 | +ConsistentAug | + Consistent Spatial Augmentation | ~84.31% |
| 3 | +Mixup | + Mixup (α=0.8) | ~82.55%* |
| 4 | +2Stage | + 2-Stage (Mixup→Label Smoothing) | ~84.71% |
| 5 | +TTA | + 6-View TTA at Inference | **~85.10%** |

*Mixup đơn lẻ giảm accuracy ngắn hạn nhưng khi kết hợp với 2-Stage sẽ hiệu quả hơn.

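**Epochs**: 20 (Phase 1: 15, Phase 2: 5 nếu 2-Stage) — khớp với `EPOCHS_TOTAL`, `EPOCHS_P1`, `EPOCHS_P2` trong cell cấu hình. Exp 4 và Exp 5 được chạy chung khi `RUN_EXP = 4`.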


In [63]:
# =================== CONFIGURATION ===================
# ===== EXPERIMENT SELECTION =====
RUN_EXP = 1                 # Which experiment to run: 1, 2, 3, or 4 (4 includes TTA)
                            # Run separately on different Kaggle sessions

# ===== QUICK TEST MODE =====
QUICK_TEST = False          # True = test pipeline with 5 batches per phase
QUICK_TEST_BATCHES = 5      # Number of batches when QUICK_TEST=True

# Model Config
MODEL_CKPT = 'MCG-NJU/videomae-base-finetuned-kinetics'
NUM_FRAMES = 16             # frames sampled uniformly per clip
IMG_SIZE = 224              # side of the square crop fed to the model
RESIZE_SIZE = 256           # shorter side after resizing, before cropping

# Training Config
# Single-stage runs (Exp 1-3) use EPOCHS_TOTAL; the 2-stage run (Exp 4) uses
# EPOCHS_P1 + EPOCHS_P2, which add up to the same budget outside QUICK_TEST.
EPOCHS_TOTAL = 1 if QUICK_TEST else 20
EPOCHS_P1 = 1 if QUICK_TEST else 15      # Phase 1 epochs (when 2-Stage)
EPOCHS_P2 = 1 if QUICK_TEST else 5      # Phase 2 epochs (when 2-Stage)
BATCH_SIZE = 8
ACCUM_STEPS = 4             # optimizer steps every ACCUM_STEPS batches (effective batch 32)
LR_P1 = 5e-5                # LR for single-stage runs and Phase 1
LR_P2 = 1e-6                # LR for Phase 2 (label-smoothing fine-tune)
WEIGHT_DECAY = 0.05
WARMUP_RATIO = 0.1          # fraction of scheduler steps used for linear warmup

# Augmentation Config
MIXUP_ALPHA = 0.8
LABEL_SMOOTHING_EPS = 0.1

# Paths (Kaggle)
PATH_DATA_TRAIN = '/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = '/kaggle/input/action-video/data/test'
TEST_LABELS_URL = '1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_'

EXP_NAMES = {1: 'Baseline', 2: '+ConsistentAug', 3: '+Mixup', 4: '+2Stage+TTA'}
# Fail fast with a readable message instead of a bare KeyError below.
if RUN_EXP not in EXP_NAMES:
    raise ValueError(f'RUN_EXP must be one of {sorted(EXP_NAMES)}, got {RUN_EXP!r}')

print('='*60)
print(f'RUNNING: Exp {RUN_EXP} - {EXP_NAMES[RUN_EXP]}')
if QUICK_TEST:
    print(f'  [QUICK_TEST] Only {QUICK_TEST_BATCHES} batches, 1 epoch')
print('='*60)

RUNNING: Exp 1 - Baseline


In [64]:
import os
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import gc
import random
import warnings
warnings.filterwarnings('ignore')

from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from transformers import get_cosine_schedule_with_warmup
from sklearn.metrics import accuracy_score

# Seed
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

Device: cuda


## 1. Dataset Classes

In [65]:
class VideoDataset(Dataset):
    """Training dataset of frame-folder videos, one sub-directory per class.

    Layout read below: ``root/<class_name>/<video_dir>/*.jpg``. For every
    video, ``num_frames`` frames are sampled uniformly, resized so the shorter
    side is RESIZE_SIZE, cropped to IMG_SIZE x IMG_SIZE and normalized with
    MEAN/STD. Items are ``(clip, label)`` with ``clip`` of shape
    ``[num_frames, C, IMG_SIZE, IMG_SIZE]``.

    Args:
        root: Directory containing one sub-directory per class.
        num_frames: Number of frames sampled per clip.
        consistent_aug: If True, draw one RandomResizedCrop box and one
            horizontal-flip decision per clip and apply them to all frames;
            otherwise every frame is center-cropped (no augmentation).
    """
    def __init__(self, root, num_frames=16, consistent_aug=False):
        self.root = Path(root)
        self.num_frames = num_frames
        self.consistent_aug = consistent_aug

        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        # Video dirs are sorted so that sample indices (and therefore the seeded
        # shuffle order) do not depend on the filesystem's iteration order.
        self.samples = [
            (vid, self.class_to_idx[cls])
            for cls in self.classes
            for vid in sorted((self.root / cls).iterdir())
            if vid.is_dir()
        ]
        print(f'Loaded {len(self.samples)} videos, {len(self.classes)} classes')

    def __len__(self):
        return len(self.samples)

    def _load_frames(self, vid_dir):
        """Load ``num_frames`` uniformly spaced RGB frames, shorter side resized to RESIZE_SIZE."""
        # NOTE(review): lexicographic sort — assumes frame file names are
        # zero-padded so that name order equals temporal order; confirm.
        files = sorted(vid_dir.glob('*.jpg'))
        if not files:
            raise FileNotFoundError(f'No .jpg frames found in {vid_dir}')
        indices = torch.linspace(0, len(files) - 1, self.num_frames).long()
        frames = []
        for i in indices:
            with Image.open(files[i]) as raw:
                frames.append(TF.resize(raw.convert('RGB'), RESIZE_SIZE))
        return frames

    def __getitem__(self, idx):
        vid_dir, label = self.samples[idx]
        frames = self._load_frames(vid_dir)

        if self.consistent_aug:
            # Crop box and flip are drawn once so every frame of the clip gets
            # the identical spatial transform.
            i, j, h, w = T.RandomResizedCrop.get_params(frames[0], (0.8, 1.0), (0.75, 1.33))
            do_flip = random.random() > 0.5
            processed = []
            for img in frames:
                img = TF.resized_crop(img, i, j, h, w, (IMG_SIZE, IMG_SIZE))
                if do_flip:
                    img = TF.hflip(img)
                processed.append(TF.normalize(TF.to_tensor(img), MEAN, STD))
        else:
            # Simple center crop (no aug)
            processed = [TF.normalize(TF.to_tensor(TF.center_crop(img, IMG_SIZE)), MEAN, STD)
                         for img in frames]

        return torch.stack(processed), label


class TestDatasetSingle(Dataset):
    """Test clips evaluated with a single center crop.

    ``root`` holds one sub-directory of ``*.jpg`` frames per video, named by
    its integer video id. Items are ``(clip, vid_id)`` with ``clip`` of shape
    ``[num_frames, C, IMG_SIZE, IMG_SIZE]``: uniformly sampled frames, resized
    so the shorter side is RESIZE_SIZE, center-cropped and normalized with
    MEAN/STD.
    """
    def __init__(self, root, num_frames=16):
        self.root = Path(root)
        self.num_frames = num_frames
        # Directory names are parsed as integer ids and sorted numerically.
        self.samples = sorted(
            ((d, int(d.name)) for d in self.root.iterdir() if d.is_dir()),
            key=lambda sample: sample[1],
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        # NOTE(review): lexicographic sort — assumes zero-padded frame names.
        files = sorted(vid_dir.glob('*.jpg'))
        if not files:
            raise FileNotFoundError(f'No .jpg frames found in {vid_dir}')
        indices = torch.linspace(0, len(files) - 1, self.num_frames).long()
        frames = []
        for i in indices:
            with Image.open(files[i]) as raw:
                img = TF.resize(raw.convert('RGB'), RESIZE_SIZE)
            img = TF.center_crop(img, IMG_SIZE)
            frames.append(TF.normalize(TF.to_tensor(img), MEAN, STD))
        return torch.stack(frames), vid_id


class TestDatasetTTA(Dataset):
    """Test clips expanded into 6 TTA views: 3 spatial crops x {original, h-flipped}.

    Frames are sampled uniformly, resized so the shorter side is RESIZE_SIZE,
    then cropped to IMG_SIZE x IMG_SIZE at the center, the top-left corner and
    the bottom-right corner. Each crop is emitted unflipped and horizontally
    flipped, so views are ordered
    [center, center-flip, top-left, top-left-flip, bottom-right, bottom-right-flip].

    Items are ``(views, vid_id)`` with ``views`` of shape
    ``[6, num_frames, C, IMG_SIZE, IMG_SIZE]``.
    """
    def __init__(self, root, num_frames=16):
        self.root = Path(root)
        self.num_frames = num_frames
        # Directory names are parsed as integer ids and sorted numerically.
        self.samples = sorted(
            ((d, int(d.name)) for d in self.root.iterdir() if d.is_dir()),
            key=lambda sample: sample[1],
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        vid_dir, vid_id = self.samples[idx]
        # NOTE(review): lexicographic sort — assumes zero-padded frame names.
        files = sorted(vid_dir.glob('*.jpg'))
        if not files:
            raise FileNotFoundError(f'No .jpg frames found in {vid_dir}')
        indices = torch.linspace(0, len(files) - 1, self.num_frames).long()

        frames = []
        for i in indices:
            with Image.open(files[i]) as raw:
                frames.append(TF.resize(raw.convert('RGB'), RESIZE_SIZE))

        w, h = frames[0].size
        # (top, left) of each IMG_SIZE x IMG_SIZE crop
        crop_positions = [
            ((h - IMG_SIZE) // 2, (w - IMG_SIZE) // 2),  # center
            (0, 0),  # top-left
            (max(0, h - IMG_SIZE), max(0, w - IMG_SIZE)),  # bottom-right
        ]

        views = []
        for top, left in crop_positions:
            clip = torch.stack([
                TF.normalize(TF.to_tensor(TF.crop(img, top, left, IMG_SIZE, IMG_SIZE)), MEAN, STD)
                for img in frames
            ])  # [T, C, H, W]
            views.append(clip)
            # Normalization is per-channel, so flipping the normalized tensor
            # along W equals normalizing the flipped crop; crop/convert only once.
            views.append(torch.flip(clip, dims=[-1]))

        return torch.stack(views), vid_id  # [6, T, C, H, W]

print('Dataset classes defined.')

Dataset classes defined.


## 2. Training Utilities

In [66]:
class MixupCollate:
    """DataLoader ``collate_fn`` that applies batch-level Mixup.

    Each call collates ``batch`` with ``default_collate``, draws one mixing
    coefficient ``lam ~ Beta(alpha, alpha)`` and one random permutation of the
    batch, and returns ``(lam * x + (1 - lam) * x[perm], soft_targets)``.
    ``soft_targets`` are the equally mixed one-hot labels, shape
    ``[batch, num_classes]``, with every row summing to 1.

    Args:
        num_classes: Number of classes for the one-hot encoding.
        alpha: Beta-distribution concentration. ``alpha <= 0`` disables mixing
            and only one-hot encodes the labels.
    """
    def __init__(self, num_classes, alpha=0.8):
        self.num_classes = num_classes
        self.alpha = alpha

    def __call__(self, batch):
        inputs, targets = torch.utils.data.default_collate(batch)
        targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()
        if self.alpha <= 0:
            # np.random.beta rejects alpha <= 0; treat it as "Mixup disabled".
            return inputs, targets_one_hot

        index = torch.randperm(inputs.size(0))
        # Plain Python float so the mixed batch keeps the inputs' dtype.
        lam = float(np.random.beta(self.alpha, self.alpha))
        inputs = lam * inputs + (1 - lam) * inputs[index]
        targets = lam * targets_one_hot + (1 - lam) * targets_one_hot[index]
        return inputs, targets


def _optimizer_step(model, optimizer, scheduler, scaler):
    """Unscale grads, clip to norm 1.0, step optimizer/scaler/scheduler, reset grads."""
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if scheduler is not None:
        scheduler.step()


def train_epoch(model, loader, optimizer, scheduler, scaler,
                use_mixup=False, label_smoothing=0.0, max_batches=None):
    """Train ``model`` for one epoch with CUDA autocast and gradient accumulation.

    Gradients are accumulated over ``ACCUM_STEPS`` batches per optimizer (and
    scheduler) step. A shorter final accumulation group at the end of the epoch
    is also applied, with its loss averaged over the batches it actually
    contains, instead of being computed and then silently discarded.

    NOTE(review): the experiment cells size the cosine schedule as
    ``len(loader) * epochs // ACCUM_STEPS``. When ``len(loader)`` is not a
    multiple of ``ACCUM_STEPS``, the schedule therefore runs a few steps past its
    planned end, where the LR should already be ~0 — confirm this is acceptable.

    Args:
        model: Model whose output is either logits or has a ``.logits`` attribute.
        loader: Yields ``(inputs, targets)``. Targets are integer labels, or
            soft label distributions ``[B, num_classes]`` when ``use_mixup``.
        optimizer, scheduler, scaler: Optimizer, optional LR scheduler
            (stepped once per optimizer step) and AMP grad scaler.
        use_mixup: Use soft-target cross entropy instead of ``F.cross_entropy``.
        label_smoothing: Passed to ``F.cross_entropy`` when not using mixup.
        max_batches: If set, stop after this many batches (quick testing).

    Returns:
        ``(mean loss per batch, training accuracy)``. With mixup, accuracy is
        measured against the argmax of each soft target.
    """
    model.train()
    n_batches = len(loader) if max_batches is None else min(len(loader), max_batches)
    total_loss, total_correct, total_samples, seen = 0.0, 0, 0, 0
    pbar = tqdm(loader, desc='Training', leave=False, total=n_batches)
    optimizer.zero_grad()

    for step, (inputs, targets) in enumerate(pbar):
        if step >= n_batches:
            break

        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        with torch.amp.autocast('cuda'):
            output = model(inputs)
            logits = output.logits if hasattr(output, 'logits') else output

            if use_mixup:
                log_probs = F.log_softmax(logits, dim=1)
                loss = -torch.sum(targets * log_probs, dim=1).mean()
                true_labels = targets.argmax(dim=1)
            else:
                loss = F.cross_entropy(logits, targets, label_smoothing=label_smoothing)
                true_labels = targets

        total_correct += (logits.argmax(1) == true_labels).sum().item()
        total_samples += inputs.size(0)

        # Average the loss over the accumulation group this batch belongs to;
        # only the epoch's last group can be shorter than ACCUM_STEPS.
        group_start = step - step % ACCUM_STEPS
        group_size = min(ACCUM_STEPS, n_batches - group_start)
        scaler.scale(loss / group_size).backward()

        if (step + 1) % ACCUM_STEPS == 0 or step + 1 == n_batches:
            _optimizer_step(model, optimizer, scheduler, scaler)

        seen += 1
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{total_loss/seen:.4f}',
                          'acc': f'{total_correct/max(total_samples,1):.4f}'})

    return total_loss / max(seen, 1), total_correct / max(total_samples, 1)


@torch.no_grad()
def evaluate(model, loader, classes, gt_dict, use_tta=False, max_batches=None):
    """Return the top-1 accuracy of ``model`` on ``loader`` against ``gt_dict``.

    ``loader`` yields ``(videos, video_ids)``. With ``use_tta`` each videos
    tensor carries a view axis ``[B, V, T, C, H, W]``: all views are run as one
    flattened batch and their logits are averaged per video before the argmax.
    Predicted indices are mapped through ``classes`` and compared with
    ``gt_dict[str(video_id)]``. When ``max_batches`` is set (quick-test mode),
    only that many batches are run and 0.0 is returned.
    """
    model.eval()
    pairs = []

    for batch_idx, (videos, video_ids) in enumerate(tqdm(loader, desc='Evaluating', leave=False)):
        if max_batches is not None and batch_idx >= max_batches:
            break

        if use_tta:
            n_vid, n_view, n_t, n_c, n_h, n_w = videos.shape
            videos = videos.view(n_vid * n_view, n_t, n_c, n_h, n_w)
        output = model(videos.to(DEVICE))
        scores = output.logits if hasattr(output, 'logits') else output
        if use_tta:
            scores = scores.view(n_vid, n_view, -1).mean(dim=1)

        pairs.extend(zip(video_ids.tolist(), scores.argmax(1).cpu().tolist()))

    if max_batches is not None:
        return 0.0  # Skip accuracy calc in quick test mode

    y_true = [gt_dict[str(vid)] for vid, _ in pairs]
    y_pred = [classes[class_idx] for _, class_idx in pairs]
    return accuracy_score(y_true, y_pred)

print('Training utilities defined.')

Training utilities defined.


## 3. Load Data

In [67]:
# Download the test-set ground-truth CSV with gdown via IPython shell magic
# (`!`; IPython expands {TEST_LABELS_URL} before running the command).
# NOTE(review): TEST_LABELS_URL looks like a Google Drive file id rather than a
# full URL — confirm gdown is meant to receive the bare id.
!gdown "{TEST_LABELS_URL}" -O test_labels.csv -q
gt_df = pd.read_csv('test_labels.csv')
# Map str(video id) -> class name (CSV columns 'id' and 'class');
# evaluate() looks labels up as gt_dict[str(vid)].
gt_dict = dict(zip(gt_df['id'].astype(str), gt_df['class']))
print(f'Loaded test labels: {len(gt_dict)} samples')

# Results storage: each experiment cell appends {'exp', 'name', 'test_acc'} here.
all_results = []

Loaded test labels: 510 samples


---
## Exp 1: VideoMAE Baseline
- **Techniques**: None (just fine-tune with LR=5e-5)
- **Expected**: ~83.92%

In [68]:
if RUN_EXP == 1:
    print('='*60)
    print('EXP 1: BASELINE (No augmentation)')
    print('='*60)

    # Dataset without consistent aug: training frames are only center-cropped.
    train_ds = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, consistent_aug=False)
    test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)

    train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True, persistent_workers=True)
    test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=2, persistent_workers=True)

    # Model: the checkpoint's classifier head does not match the number of
    # classes here, so it is re-initialized (ignore_mismatched_sizes=True).
    model = VideoMAEForVideoClassification.from_pretrained(
        MODEL_CKPT,
        num_labels=len(train_ds.classes),
        ignore_mismatched_sizes=True,
        num_frames=NUM_FRAMES
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
    # train_epoch steps the scheduler once per optimizer step (every ACCUM_STEPS batches).
    total_steps = len(train_loader) * EPOCHS_TOTAL // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    scaler = torch.amp.GradScaler()

    best_acc = 0.0
    max_batches = QUICK_TEST_BATCHES if QUICK_TEST else None
    for epoch in range(1, EPOCHS_TOTAL + 1):
        loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, scaler, max_batches=max_batches)
        # NOTE: the best epoch is selected on the test set itself, so best_acc
        # is an optimistic estimate rather than a held-out score.
        test_acc = evaluate(model, test_loader, train_ds.classes, gt_dict, max_batches=max_batches)
        
        status = ''
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'exp1_best.pt')
            status = '>>> BEST'
        
        print(f'Ep {epoch}/{EPOCHS_TOTAL}: Loss={loss:.4f} TrainAcc={train_acc:.4f} TestAcc={test_acc:.4f} {status}')

    print(f'\nExp1 Best: {best_acc:.4f}')
    all_results.append({'exp': 1, 'name': 'Baseline', 'test_acc': best_acc})
    del model, optimizer, scheduler
    # Collect garbage first so CUDA tensors kept alive only by reference cycles
    # are freed before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
else:
    print('Skipping Exp1 (RUN_EXP != 1)')

EXP 1: BASELINE (No augmentation)
Loaded 6254 videos, 51 classes


Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base-finetuned-kinetics and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([51]) in the model instantiated
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([51, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training:   0%|          | 0/781 [00:00<?, ?it/s]

KeyboardInterrupt: 

---
## Exp 2: + Consistent Spatial Augmentation
- **New**: Same crop/flip for all 16 frames
- **Expected**: +0.39% (~84.31%)

In [None]:
if RUN_EXP == 2:
    print('='*60)
    print('EXP 2: + CONSISTENT SPATIAL AUGMENTATION')
    print('='*60)

    # Training clips get one random resized-crop + flip shared by all frames.
    train_ds = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, consistent_aug=True)
    test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)
    train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True, persistent_workers=True)
    test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=2, persistent_workers=True)

    model = VideoMAEForVideoClassification.from_pretrained(
        MODEL_CKPT, num_labels=len(train_ds.classes),
        ignore_mismatched_sizes=True, num_frames=NUM_FRAMES
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
    # train_epoch steps the scheduler once per optimizer step (every ACCUM_STEPS batches).
    total_steps = len(train_loader) * EPOCHS_TOTAL // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    scaler = torch.amp.GradScaler()

    best_acc = 0.0
    max_batches = QUICK_TEST_BATCHES if QUICK_TEST else None
    for epoch in range(1, EPOCHS_TOTAL + 1):
        loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, scaler, max_batches=max_batches)
        # NOTE: best epoch is selected on the test set (optimistic estimate).
        test_acc = evaluate(model, test_loader, train_ds.classes, gt_dict, max_batches=max_batches)
        status = '>>> BEST' if test_acc > best_acc else ''
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'exp2_best.pt')
        print(f'Ep {epoch}/{EPOCHS_TOTAL}: Loss={loss:.4f} TrainAcc={train_acc:.4f} TestAcc={test_acc:.4f} {status}')

    print(f'Exp2 Best: {best_acc:.4f}')
    all_results.append({'exp': 2, 'name': '+ConsistentAug', 'test_acc': best_acc})
    del model, optimizer, scheduler
    # Collect garbage first so CUDA tensors kept alive only by reference cycles
    # are freed before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
else:
    print('Skipping Exp2')

---
## Exp 3: + Mixup
- **New**: Mixup (α=0.8) on top of Consistent Aug
- **Note**: Mixup alone may decrease accuracy but reduces overfitting gap
- **Expected**: ~82.55% (trade-off for better generalization)

In [None]:
if RUN_EXP == 3:
    print('='*60)
    print('EXP 3: + MIXUP')
    print('='*60)

    # Consistent spatial aug plus Mixup: the collate_fn returns soft targets,
    # hence use_mixup=True in train_epoch below.
    train_ds = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, consistent_aug=True)
    test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)
    mixup_collate = MixupCollate(len(train_ds.classes), MIXUP_ALPHA)
    train_loader = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=mixup_collate, drop_last=True, persistent_workers=True)
    test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=2, persistent_workers=True)

    model = VideoMAEForVideoClassification.from_pretrained(
        MODEL_CKPT, num_labels=len(train_ds.classes),
        ignore_mismatched_sizes=True, num_frames=NUM_FRAMES
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
    # train_epoch steps the scheduler once per optimizer step (every ACCUM_STEPS batches).
    total_steps = len(train_loader) * EPOCHS_TOTAL // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    scaler = torch.amp.GradScaler()

    best_acc = 0.0
    max_batches = QUICK_TEST_BATCHES if QUICK_TEST else None
    for epoch in range(1, EPOCHS_TOTAL + 1):
        loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, scaler, use_mixup=True, max_batches=max_batches)
        # NOTE: best epoch is selected on the test set (optimistic estimate).
        test_acc = evaluate(model, test_loader, train_ds.classes, gt_dict, max_batches=max_batches)
        status = '>>> BEST' if test_acc > best_acc else ''
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'exp3_best.pt')
        print(f'Ep {epoch}/{EPOCHS_TOTAL}: Loss={loss:.4f} TrainAcc={train_acc:.4f} TestAcc={test_acc:.4f} {status}')

    print(f'Exp3 Best: {best_acc:.4f}')
    all_results.append({'exp': 3, 'name': '+Mixup', 'test_acc': best_acc})
    del model, optimizer, scheduler
    # Collect garbage first so CUDA tensors kept alive only by reference cycles
    # are freed before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
else:
    print('Skipping Exp3')

---
## Exp 4+5: 2-Stage Training + TTA
- **Phase 1**: Mixup + LR=5e-5
- **Phase 2**: Label Smoothing + LR=1e-6
- **TTA**: 3 spatial crops × 2 flip states


In [None]:
if RUN_EXP == 4:
    print('='*60)
    print('EXP 4+5: 2-STAGE TRAINING + TTA')
    print('='*60)
    
    # Setup: Phase 1 batches go through MixupCollate (soft targets);
    # Phase 2 batches keep hard integer labels for label-smoothed CE.
    train_ds = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, consistent_aug=True)
    test_ds = TestDatasetSingle(PATH_DATA_TEST, NUM_FRAMES)
    mixup_collate = MixupCollate(len(train_ds.classes), MIXUP_ALPHA)
    train_loader_p1 = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=mixup_collate, drop_last=True, persistent_workers=True)
    train_loader_p2 = DataLoader(train_ds, BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True, persistent_workers=True)
    test_loader = DataLoader(test_ds, BATCH_SIZE, shuffle=False, num_workers=2, persistent_workers=True)
    
    model = VideoMAEForVideoClassification.from_pretrained(
        MODEL_CKPT, num_labels=len(train_ds.classes),
        ignore_mismatched_sizes=True, num_frames=NUM_FRAMES
    ).to(DEVICE)
    
    best_acc = 0.0
    max_batches = QUICK_TEST_BATCHES if QUICK_TEST else None
    ckpt_path = 'exp4_best.pt'
    # Remove a checkpoint left by an earlier run so it can never be reloaded
    # below in place of this run's weights.
    if os.path.exists(ckpt_path):
        os.remove(ckpt_path)
    
    # Phase 1: Mixup
    print(f'\n--- Phase 1: Mixup ({EPOCHS_P1} epochs) ---')
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader_p1) * EPOCHS_P1 // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    scaler = torch.amp.GradScaler()
    
    for epoch in range(1, EPOCHS_P1 + 1):
        loss, train_acc = train_epoch(model, train_loader_p1, optimizer, scheduler, scaler, use_mixup=True, max_batches=max_batches)
        # NOTE: checkpoint selection uses test accuracy (optimistic estimate).
        test_acc = evaluate(model, test_loader, train_ds.classes, gt_dict, max_batches=max_batches)
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), ckpt_path)
        print(f'P1 Ep {epoch}/{EPOCHS_P1}: Loss={loss:.4f} Acc={test_acc:.4f}')
    
    # No epoch improved on best_acc (e.g. QUICK_TEST, where evaluate returns
    # 0.0): fall back to the final Phase 1 weights.
    if not os.path.exists(ckpt_path):
        torch.save(model.state_dict(), ckpt_path)
    
    # Phase 2: Label Smoothing, restarting from the best Phase 1 weights at a low LR.
    print(f'\n--- Phase 2: Label Smoothing ({EPOCHS_P2} epochs) ---')
    model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P2, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader_p2) * EPOCHS_P2 // ACCUM_STEPS
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * WARMUP_RATIO), total_steps)
    scaler = torch.amp.GradScaler()
    
    for epoch in range(1, EPOCHS_P2 + 1):
        loss, train_acc = train_epoch(model, train_loader_p2, optimizer, scheduler, scaler, label_smoothing=LABEL_SMOOTHING_EPS, max_batches=max_batches)
        test_acc = evaluate(model, test_loader, train_ds.classes, gt_dict, max_batches=max_batches)
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), ckpt_path)
        print(f'P2 Ep {epoch}/{EPOCHS_P2}: Loss={loss:.4f} Acc={test_acc:.4f}')
    
    print(f'\nExp4 (2-Stage) Best: {best_acc:.4f}')
    all_results.append({'exp': 4, 'name': '+2Stage', 'test_acc': best_acc})
    
    # Exp5: TTA on the best checkpoint across both phases
    print('\n' + '='*60)
    print('EXP 5: + 6-VIEW TTA')
    print('='*60)
    model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    test_ds_tta = TestDatasetTTA(PATH_DATA_TEST, NUM_FRAMES)
    # NOTE(review): evaluate() runs BATCH_SIZE * 6 clips in one forward pass
    # here; lower the batch size for this loader if it runs out of GPU memory.
    test_loader_tta = DataLoader(test_ds_tta, BATCH_SIZE, shuffle=False, num_workers=2, persistent_workers=True)
    test_acc_tta = evaluate(model, test_loader_tta, train_ds.classes, gt_dict, use_tta=True, max_batches=max_batches)
    print(f'\nExp5 (TTA) Accuracy: {test_acc_tta:.4f}')
    all_results.append({'exp': 5, 'name': '+TTA', 'test_acc': test_acc_tta})
    
    del model, optimizer, scheduler
    # Collect garbage first so CUDA tensors kept alive only by reference cycles
    # are freed before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
else:
    print('Skipping Exp4+5 (RUN_EXP != 4)')