# VideoMAE 8-Frame: Single Checkpoint Resume

In [None]:
# Notebook-wide noise suppression: hide Python warnings, lower TensorFlow's
# C++ log verbosity and only let `transformers` log errors.
import warnings
warnings.filterwarnings('ignore')
import os
# NOTE(review): TensorFlow is not imported in the visible cells; presumably it
# is pulled in indirectly by another library - confirm this is still needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import logging
logging.getLogger('transformers').setLevel(logging.ERROR)

In [None]:
# Install the Hugging Face dependencies quietly (-q). `transformers` is imported
# below; `accelerate` is not imported directly in the visible cells.
!pip install -q transformers accelerate

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from transformers import get_cosine_schedule_with_warmup
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import numpy as np
import random
import pandas as pd
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import gc

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')

# Dataset roots. The train root is scanned by VideoDataset as
# <class>/<video>/*.jpg; the test root is scanned by TestDataset, which expects
# video folders whose names parse as integers.
PATH_DATA_TRAIN = '/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = '/kaggle/input/action-video/data/test'

In [None]:
# Download the ground-truth labels for the test videos into test_labels.csv.
!gdown "1Xv2CWOqdBj3kt0rkNJKRsodSIEd3-wX_" -O test_labels.csv -q
gt_df = pd.read_csv("test_labels.csv")
# Map video id (as a string) -> value of the 'class' column. evaluate() looks
# predictions up with str(video_id) and compares against the class names.
TEST_LABELS = dict(zip(gt_df['id'].astype(str), gt_df['class']))
# The DataFrame is only needed to build the dict; drop it right away.
del gt_df; gc.collect()
print(f"Test labels: {len(TEST_LABELS)}")

## 1. Config

In [None]:
# Hugging Face checkpoint used for both the image processor (normalisation
# statistics) and the classification model.
MODEL_CKPT = "MCG-NJU/videomae-base-finetuned-kinetics"
NUM_FRAMES = 8        # frames sampled (evenly spaced) from every video
IMG_SIZE = 224        # side length of the square crop fed to the model
RESIZE_SIZE = 256     # size passed to TF.resize before cropping

EPOCHS_P1 = 30        # phase 1 (mixup) epochs
EPOCHS_P2 = 10        # phase 2 (hard labels + label smoothing) epochs
LR_P1 = 5e-5          # AdamW learning rate for phase 1 (cosine schedule with warmup)
LR_P2 = 1e-6          # AdamW learning rate for phase 2 (cosine schedule, no warmup)
LABEL_SMOOTHING = 0.1 # only passed to train_epoch in phase 2

BATCH_SIZE = 32
NUM_WORKERS = 4       # DataLoader worker processes
WEIGHT_DECAY = 0.05   # AdamW weight decay, both phases
MIXUP_ALPHA = 0.8     # Beta(alpha, alpha) parameter used by MixupCollate

print(f"LR_P1={LR_P1}, LR_P2={LR_P2}, BATCH={BATCH_SIZE}, WORKERS={NUM_WORKERS}")

## 2. Dataset

In [None]:
# The processor is only used as the source of the per-channel normalisation
# statistics; the datasets below apply TF.normalize with these values themselves.
image_processor = VideoMAEImageProcessor.from_pretrained(MODEL_CKPT)
MEAN, STD = image_processor.image_mean, image_processor.image_std

class MixupCollate:
    """Collate function that applies mixup to a batch of ``(input, label)`` samples.

    The samples are first stacked with the default collate function. Each
    input is then blended with another input of the same batch (chosen by a
    random permutation) using a single coefficient ``lam ~ Beta(alpha, alpha)``
    for the whole batch. Labels are converted to one-hot vectors and blended
    with the same coefficient, so the returned targets are soft labels.

    Args:
        num_classes: Width of the one-hot target vectors.
        alpha: Beta distribution parameter. A value ``<= 0`` disables mixing:
            inputs are returned unchanged with plain one-hot targets.
    """

    def __init__(self, num_classes, alpha=0.8):
        self.num_classes, self.alpha = num_classes, alpha

    def __call__(self, batch):
        """Return ``(mixed_inputs, soft_targets)`` for a list of samples."""
        inputs, targets = torch.utils.data.default_collate(batch)
        oh = F.one_hot(targets, self.num_classes).float()
        if self.alpha <= 0:
            # Beta(alpha, alpha) is undefined here; treat it as "mixup off"
            # (equivalent to lam == 1) instead of raising inside a worker.
            return inputs, oh
        lam = float(np.random.beta(self.alpha, self.alpha))
        idx = torch.randperm(inputs.size(0))
        inputs = lam * inputs + (1 - lam) * inputs[idx]
        return inputs, lam * oh + (1 - lam) * oh[idx]

class VideoDataset(Dataset):
    """Labelled video dataset read from ``root/<class>/<video>/*.jpg``.

    Class names are the sorted sub-directory names of ``root``; the label of a
    video is the index of its class in that sorted list. Video directories
    that contain no ``*.jpg`` frame are skipped at construction time, because
    they cannot be sampled and would otherwise fail later inside
    ``__getitem__``.

    Args:
        root: Directory containing one sub-directory per class.
        num_frames: Number of frames sampled (evenly spaced) per video.
        is_train: When true, apply one random resized crop and one random
            horizontal flip, shared by all frames of the clip; otherwise
            apply a centre crop.
    """

    def __init__(self, root, num_frames=8, is_train=True):
        self.root, self.num_frames, self.is_train = Path(root), num_frames, is_train
        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        # Each sample is (sorted list of frame paths, class index).
        self.samples = []
        for c in self.classes:
            for v in (self.root / c).iterdir():
                if not v.is_dir():
                    continue
                frame_paths = sorted(v.glob('*.jpg'))
                if frame_paths:  # skip videos without any frame
                    self.samples.append((frame_paths, self.class_to_idx[c]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(clip, label)``; clip is ``(num_frames, 3, IMG_SIZE, IMG_SIZE)``."""
        paths, label = self.samples[idx]
        # Evenly spaced frame indices; frames repeat when the video is short.
        indices = np.linspace(0, len(paths) - 1, self.num_frames, dtype=int)
        frames = [TF.resize(Image.open(paths[i]).convert('RGB'), RESIZE_SIZE) for i in indices]
        if self.is_train:
            # Draw the crop box and the flip decision once so that every frame
            # of the clip receives the same spatial augmentation.
            i, j, h, w = T.RandomResizedCrop.get_params(frames[0], (0.8, 1.0), (0.75, 1.33))
            flip = random.random() > 0.5
            augmented = []
            for f in frames:
                f = TF.resized_crop(f, i, j, h, w, (IMG_SIZE, IMG_SIZE))
                if flip:
                    f = TF.hflip(f)
                augmented.append(TF.normalize(TF.to_tensor(f), MEAN, STD))
            frames = augmented
        else:
            frames = [TF.normalize(TF.to_tensor(TF.center_crop(f, IMG_SIZE)), MEAN, STD) for f in frames]
        return torch.stack(frames), label

class TestDataset(Dataset):
    """Unlabelled video dataset read from ``root/<video_id>/*.jpg``.

    Every sub-directory of ``root`` is one video and its name must parse as an
    integer; videos are ordered by that integer. Items are
    ``(clip, video_id)`` where the clip holds ``num_frames`` evenly spaced
    frames, resized, centre-cropped to ``IMG_SIZE`` and normalised.
    """

    def __init__(self, root, num_frames=8):
        self.root = Path(root)
        self.num_frames = num_frames
        video_dirs = [entry for entry in self.root.iterdir() if entry.is_dir()]
        # Numeric ordering, so that e.g. "2" comes before "10".
        video_dirs.sort(key=lambda entry: int(entry.name))
        self.videos = video_dirs

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        video_dir = self.videos[idx]
        frame_paths = sorted(video_dir.glob('*.jpg'))
        picks = np.linspace(0, len(frame_paths) - 1, self.num_frames, dtype=int)
        clip = []
        for k in picks:
            img = Image.open(frame_paths[k]).convert('RGB')
            img = TF.center_crop(TF.resize(img, RESIZE_SIZE), IMG_SIZE)
            clip.append(TF.normalize(TF.to_tensor(img), MEAN, STD))
        return torch.stack(clip), int(video_dir.name)

In [None]:
# Build the datasets once; both training phases share the same train dataset.
train_dataset = VideoDataset(PATH_DATA_TRAIN, NUM_FRAMES, is_train=True)
test_dataset = TestDataset(PATH_DATA_TEST, NUM_FRAMES)
print(f"Train: {len(train_dataset)} | Test: {len(test_dataset)} | Classes: {len(train_dataset.classes)}")

# Phase 1 loader mixes samples inside the collate function and therefore yields
# soft (one-hot blended) targets; the phase 2 loader uses the default collate
# and yields integer labels. Both drop the last incomplete batch.
mixup_collate = MixupCollate(len(train_dataset.classes), MIXUP_ALPHA)
train_loader_p1 = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=mixup_collate, drop_last=True)
train_loader_p2 = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
# Test loader keeps dataset order (no shuffle) and yields (clip, video_id).
test_loader = DataLoader(test_dataset, BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

## 3. Model

In [None]:
# Load the pretrained checkpoint with a classification head sized for this
# dataset's classes. ignore_mismatched_sizes=True is needed because the head
# size differs from the checkpoint's.
# NOTE(review): num_frames=NUM_FRAMES presumably overrides the checkpoint's
# configured frame count - confirm which weights get re-initialised as a result.
model = VideoMAEForVideoClassification.from_pretrained(
    MODEL_CKPT, num_labels=len(train_dataset.classes), 
    ignore_mismatched_sizes=True, num_frames=NUM_FRAMES
).to(DEVICE)
print(f"Params: {sum(p.numel() for p in model.parameters()):,}")

## 4. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, scaler, use_mixup=True, label_smoothing=0.0):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    Args:
        model: Module whose forward output exposes ``.logits``.
        loader: Yields ``(inputs, targets)``. With ``use_mixup`` the targets
            are soft label vectors, otherwise integer class indices.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scheduler: LR scheduler stepped once per batch.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        use_mixup: Selects the soft-target cross entropy instead of
            ``F.cross_entropy``.
        label_smoothing: Forwarded to ``F.cross_entropy``; unused with mixup.

    Returns:
        ``(mean loss per batch, accuracy)``. With mixup, accuracy is measured
        against the arg-max of the soft targets.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for clips, targets in tqdm(loader, desc="Train", leave=False):
        clips = clips.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        with torch.amp.autocast('cuda'):
            logits = model(clips).logits
            if not use_mixup:
                loss = F.cross_entropy(logits, targets, label_smoothing=label_smoothing)
                hard_labels = targets
            else:
                # Cross entropy against a soft target distribution.
                loss = -torch.sum(targets * F.log_softmax(logits, 1), 1).mean()
                hard_labels = targets.argmax(1)

        running_loss += loss.item()
        n_correct += (logits.argmax(1) == hard_labels).sum().item()
        n_seen += clips.size(0)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping so the threshold applies
        # to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Release batch tensors before the next batch is moved to the device.
        del clips, targets, logits, loss

    mean_loss = running_loss / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

@torch.no_grad()
def evaluate(model, loader, classes):
    """Return the accuracy of ``model`` on the videos that have a known label.

    Args:
        model: Module whose forward output exposes ``.logits``.
        loader: Yields ``(inputs, video_ids)`` batches.
        classes: Sequence mapping a predicted class index to its class name.

    Predictions for ids that are missing from ``TEST_LABELS`` (keyed by
    ``str(video_id)``) are ignored.
    """
    model.eval()
    video_ids, pred_indices = [], []

    for clips, ids in tqdm(loader, desc="Eval", leave=False):
        clips = clips.to(DEVICE, non_blocking=True)
        batch_preds = model(clips).logits.argmax(1).cpu().numpy()
        pred_indices.extend(batch_preds.tolist())
        video_ids.extend(ids.numpy().tolist())
        del clips, batch_preds

    # Build both label lists in one pass so they stay aligned by construction.
    y_true, y_pred = [], []
    for vid, pred in zip(video_ids, pred_indices):
        key = str(vid)
        if key in TEST_LABELS:
            y_true.append(TEST_LABELS[key])
            y_pred.append(classes[pred])
    return accuracy_score(y_true, y_pred)

def save_checkpoint(epoch, phase, loss, train_acc, test_acc, best_acc):
    """Persist the resume state for one finished epoch.

    Overwrites ``checkpoint.pt`` with the state dict of the notebook-global
    ``model`` and appends one row of metrics to ``history.csv`` (the header is
    written only when the file does not exist yet). Only model weights are
    stored; optimizer, scheduler and scaler state are not part of the
    checkpoint.

    Args:
        epoch: Global epoch number that was just completed.
        phase: Training phase tag stored with the row (``'P1'`` or ``'P2'``).
        loss: Mean training loss of the epoch.
        train_acc: Training accuracy of the epoch.
        test_acc: Test accuracy measured after the epoch.
        best_acc: Best test accuracy seen so far.
    """
    # Write to a temporary file first and swap it in afterwards, so that an
    # interrupted save can never leave a truncated checkpoint.pt behind.
    tmp_path = 'checkpoint.pt.tmp'
    torch.save(model.state_dict(), tmp_path)
    os.replace(tmp_path, 'checkpoint.pt')

    # Append the metrics row; emit the header only for a brand-new file.
    row = {'epoch': epoch, 'phase': phase, 'loss': loss, 'train_acc': train_acc, 'test_acc': test_acc, 'best_acc': best_acc}
    pd.DataFrame([row]).to_csv(
        'history.csv', mode='a', header=not os.path.exists('history.csv'), index=False
    )

    print(f"  Loss:{loss:.4f} | Train:{train_acc:.4f} | Test:{test_acc:.4f} | Best:{best_acc:.4f}")

## 5. Check Resume

In [None]:
# Resume state consumed by the phase cells below.
#   START_PHASE: 'P1' or 'P2' - the phase to (re)enter.
#   START_EPOCH: 1-based epoch *within that phase* to start from.
#   best_acc:    best test accuracy recorded so far.
START_EPOCH = 1
START_PHASE = 'P1'
best_acc = 0

# Resume needs both the weights and the metrics log; the last CSV row tells
# which epoch the saved weights belong to.
df = None
if os.path.exists('history.csv') and os.path.exists('checkpoint.pt'):
    df = pd.read_csv('history.csv')

if df is not None and not df.empty:
    last = df.iloc[-1]

    # Only model weights are restored; the phase cells build a fresh optimizer
    # and scaler and fast-forward the LR scheduler.
    model.load_state_dict(torch.load('checkpoint.pt', map_location=DEVICE))

    last_epoch = int(last['epoch'])
    START_PHASE = last['phase']
    best_acc = float(last['best_acc'])

    if START_PHASE == 'P1' and last_epoch >= EPOCHS_P1:
        # Phase 1 was completed but no phase 2 epoch was logged yet. Hand over
        # to phase 2 explicitly; otherwise the phase 1 cell would be skipped
        # and the phase 2 cell would never run.
        START_PHASE = 'P2'
        START_EPOCH = 1
    elif START_PHASE == 'P1':
        START_EPOCH = last_epoch + 1
    else:  # P2: the CSV stores global epochs, convert back to a phase-local one
        START_EPOCH = last_epoch - EPOCHS_P1 + 1

    print(f"‚úÖ Resume: epoch {last_epoch+1}, phase {START_PHASE}, best={best_acc:.4f}")
else:
    print("üìù Starting fresh")

## 6. Phase 1: Mixup

In [None]:
# Phase 1: train with mixup (soft targets) for EPOCHS_P1 epochs, starting at
# START_EPOCH when resuming.
if START_PHASE == 'P1' and START_EPOCH <= EPOCHS_P1:
    # A fresh optimizer is created on every run; checkpoint.pt only holds
    # model weights, so optimizer state is not carried over on resume.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P1, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader_p1) * EPOCHS_P1
    # Cosine schedule over the whole phase, with the first 10% of steps as warmup.
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps*0.1), total_steps)
    
    # Fast-forward the scheduler by the number of steps of the epochs already
    # completed, so the learning rate continues where the previous run stopped.
    for _ in range((START_EPOCH - 1) * len(train_loader_p1)):
        scheduler.step()
    
    scaler = torch.amp.GradScaler()
    
    print("="*60)
    print(f"PHASE 1: Mixup | LR={LR_P1} | Epochs {START_EPOCH}-{EPOCHS_P1}")
    print("="*60)
    
    for ep in range(START_EPOCH, EPOCHS_P1 + 1):
        print(f"\nEpoch {ep}/{EPOCHS_P1}")
        loss, train_acc = train_epoch(model, train_loader_p1, optimizer, scheduler, scaler, use_mixup=True)
        test_acc = evaluate(model, test_loader, train_dataset.classes)
        
        # best.pt keeps the weights with the highest test accuracy so far,
        # together with the class list needed to decode predictions.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({'model': model.state_dict(), 'classes': train_dataset.classes}, 'best.pt')
            print(f"  >>> New Best!")
        
        # Written after every epoch so an interrupted run can resume from here.
        save_checkpoint(ep, 'P1', loss, train_acc, test_acc, best_acc)
        gc.collect(); torch.cuda.empty_cache()
    
    print(f"\nP1 Done! Best: {best_acc:.4f}")
    # Hand over to phase 2, which starts at its own first epoch.
    START_EPOCH = 1
    START_PHASE = 'P2'
else:
    # Nothing left to run in phase 1; START_PHASE / START_EPOCH are left
    # exactly as the resume cell set them.
    print(f"Skipping P1")

## 7. Phase 2: Label Smoothing

In [None]:
# Phase 2: continue from the current model weights with hard labels, label
# smoothing and a much smaller learning rate. Runs only when START_PHASE is 'P2'.
if START_PHASE == 'P2':
    # Fresh optimizer, as in phase 1; optimizer state is not checkpointed.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR_P2, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader_p2) * EPOCHS_P2
    # Cosine decay over the phase with zero warmup steps.
    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, total_steps)
    
    # Phase-local start epoch (never below 1); fast-forward the scheduler past
    # the epochs that were already completed in an earlier run.
    p2_start = START_EPOCH if START_EPOCH > 1 else 1
    for _ in range((p2_start - 1) * len(train_loader_p2)):
        scheduler.step()
    
    scaler = torch.amp.GradScaler()
    
    print("\n" + "="*60)
    print(f"PHASE 2: Label Smoothing | LR={LR_P2} | Epochs {p2_start}-{EPOCHS_P2}")
    print("="*60)
    
    for ep in range(p2_start, EPOCHS_P2 + 1):
        # history.csv stores a global epoch counter that continues after phase 1.
        global_ep = EPOCHS_P1 + ep
        print(f"\nEpoch {global_ep}/{EPOCHS_P1+EPOCHS_P2}")
        loss, train_acc = train_epoch(model, train_loader_p2, optimizer, scheduler, scaler, use_mixup=False, label_smoothing=LABEL_SMOOTHING)
        test_acc = evaluate(model, test_loader, train_dataset.classes)
        
        # Same best-model bookkeeping as in phase 1 (shared best_acc / best.pt).
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({'model': model.state_dict(), 'classes': train_dataset.classes}, 'best.pt')
            print(f"  >>> New Best!")
        
        save_checkpoint(global_ep, 'P2', loss, train_acc, test_acc, best_acc)
        gc.collect(); torch.cuda.empty_cache()
    
    print(f"\nüèÜ FINAL: {best_acc:.4f}")

## 8. Plot

In [None]:
# Plot the per-epoch metrics logged in history.csv: accuracy on the left, loss
# on the right. The dashed vertical line marks the last phase 1 epoch.
if os.path.exists('history.csv'):
    df = pd.read_csv('history.csv')
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(df['epoch'], df['train_acc'], 'b-o', label='Train', ms=3)
    axes[0].plot(df['epoch'], df['test_acc'], 'r-s', label='Test', ms=3)
    axes[0].axvline(x=EPOCHS_P1, color='gray', ls='--', label='P1‚ÜíP2')
    axes[0].set_title('Accuracy'); axes[0].legend(); axes[0].grid(alpha=0.3)
    axes[1].plot(df['epoch'], df['loss'], 'g-^', ms=3)
    axes[1].axvline(x=EPOCHS_P1, color='gray', ls='--')
    axes[1].set_title('Loss'); axes[1].grid(alpha=0.3)
    plt.tight_layout()
    # Save the figure next to the checkpoints, then show it and dump the raw table.
    plt.savefig('curves.png', dpi=150)
    plt.show()
    print(df.to_string(index=False))