In [None]:
from __future__ import annotations
import math
import random
import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import amp
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import timm
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from transformers import VideoMAEForVideoClassification

# Seed every RNG the pipeline uses (augmentation uses `random`, the sampler and
# model init use torch, numpy is used for sampler weights) so runs are repeatable.
# torch.manual_seed also seeds all CUDA devices. cuDNN kernels may still be
# non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

In [None]:
class VideoAugmentation:
    """
    Clip-level video augmentation that is CONSISTENT across all frames.

    Every random parameter is drawn once per call and applied identically to
    every frame, so the clip stays temporally coherent.

    Training mode (``is_train=True``):
        random resized crop -> color jitter -> random erasing ->
        random horizontal flip (p=0.5) -> normalize.
    Eval mode (``is_train=False``):
        plain resize to ``image_size`` x ``image_size`` -> normalize.

    Input to ``__call__``: tensor of shape (T, H, W, C) with values in [0, 255].
    Output: float tensor of shape (T, C, image_size, image_size) normalized
    with ``mean`` / ``std``.

    Note: ``speed_range`` is stored but not used by any transform in this class.
    """
    def __init__(
        self,
        image_size=224,
        crop_scale=(0.8, 1.0),
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        speed_range=(0.9, 1.1),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        is_train=True,
        erase_prob=0.25,
        erase_scale=(0.02, 0.2),
        erase_ratio=(0.3, 3.3)
    ):
        self.image_size = image_size
        # Range of the crop SIDE-length fraction (area fraction is its square);
        # values must be <= 1.0 so the crop fits inside the frame.
        self.crop_scale = crop_scale
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.speed_range = speed_range
        self.mean = mean
        self.std = std
        self.is_train = is_train
        self.erase_prob = erase_prob
        self.erase_scale = erase_scale  # erased area as a fraction of H * W
        self.erase_ratio = erase_ratio  # aspect ratio (h / w) range of the erased box

    def __call__(self, frames):
        """
        Args:
            frames: tensor (T, H, W, C) with values in [0, 255] (any dtype).

        Returns:
            Normalized float tensor of shape (T, C, image_size, image_size).
        """
        # Convert to float up front so all arithmetic below happens in float.
        frames = frames.float()

        if self.is_train:
            frames = self._random_resized_crop(frames)
            frames = self._color_jitter(frames)
            frames = self._random_erasing(frames)
            if random.random() < 0.5:
                # Mirror along the WIDTH axis (dim 2 of (T, H, W, C)).
                # torchvision's TF.hflip flips the LAST dim, which in this
                # layout is the channel axis (it would swap RGB <-> BGR).
                frames = frames.flip(2)
        else:
            frames = self._resize(frames)

        # (T, H, W, C) in [0, 255] -> (T, C, H, W) normalized.
        return self._normalize(frames)

    def _resize(self, frames):
        """Bilinearly resize (T, H, W, C) frames to (T, image_size, image_size, C)."""
        frames = frames.permute(0, 3, 1, 2)  # F.interpolate expects (N, C, H, W)
        frames = F.interpolate(
            frames,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False
        )
        return frames.permute(0, 2, 3, 1)  # back to (T, H, W, C)

    def _random_resized_crop(self, frames):
        """Random crop (same box for every frame), then resize to image_size x image_size."""
        T, H, W, C = frames.shape

        # One scale and position shared by all frames.
        scale = random.uniform(self.crop_scale[0], self.crop_scale[1])
        crop_h = max(1, int(H * scale))
        crop_w = max(1, int(W * scale))

        top = random.randint(0, H - crop_h)
        left = random.randint(0, W - crop_w)

        frames = frames[:, top:top + crop_h, left:left + crop_w, :]
        return self._resize(frames)

    def _color_jitter(self, frames):
        """
        Brightness / contrast / saturation jitter with one set of factors per clip.

        frames: (T, H, W, C) float in [0, 255]; result is clamped back to [0, 255].
        """
        brightness_factor = 1.0 + random.uniform(-self.brightness, self.brightness)
        contrast_factor = 1.0 + random.uniform(-self.contrast, self.contrast)
        saturation_factor = 1.0 + random.uniform(-self.saturation, self.saturation)

        frames = frames.float()

        # Brightness: global scaling.
        frames = frames * brightness_factor

        # Contrast: stretch around the per-frame, per-channel spatial mean.
        mean = frames.mean(dim=(1, 2), keepdim=True)
        frames = (frames - mean) * contrast_factor + mean

        # Saturation: blend with the unweighted channel mean used as "gray".
        gray = frames.mean(dim=-1, keepdim=True)
        frames = gray + (frames - gray) * saturation_factor

        return torch.clamp(frames, 0, 255)

    def _random_erasing(self, frames):
        """
        Random Erasing with the same box for every frame.

        frames: (T, H, W, C) float in [0, 255]. With probability ``erase_prob``
        a box is filled in place with one random color (each channel drawn from
        0..255). The call is a no-op when the sampled box does not fit.
        """
        if random.random() > self.erase_prob:
            return frames

        T, H, W, C = frames.shape
        area = H * W

        erase_area = random.uniform(*self.erase_scale) * area
        aspect_ratio = random.uniform(*self.erase_ratio)

        h = int(round((erase_area * aspect_ratio) ** 0.5))
        w = int(round((erase_area / aspect_ratio) ** 0.5))

        if h <= 0 or w <= 0 or h >= H or w >= W:
            return frames

        top = random.randint(0, H - h)
        left = random.randint(0, W - w)

        # Fill value in the same [0, 255] range as the frames.
        erase_value = torch.randint(0, 256, (1, 1, 1, C), device=frames.device).float()

        frames[:, top:top + h, left:left + w, :] = erase_value
        return frames

    def _normalize(self, frames):
        """
        frames: (T, H, W, C) in [0, 255]
        return: (T, C, H, W) float, (x / 255 - mean) / std per channel
        """
        frames = frames.float() / 255.0
        frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W)

        mean = torch.tensor(self.mean, device=frames.device, dtype=frames.dtype).view(1, -1, 1, 1)
        std = torch.tensor(self.std, device=frames.device, dtype=frames.dtype).view(1, -1, 1, 1)

        return (frames - mean) / std

def denormalize(frames, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Invert per-channel ``(x - mean) / std`` normalization, e.g. for plotting.

    Args:
        frames: tensor whose dim 1 is the channel axis, e.g. (T, C, H, W).
        mean, std: per-channel statistics indexed by channel position.

    Returns:
        A new tensor (the input is left untouched) clamped to [0, 1].
    """
    restored = frames.clone()
    num_channels = restored.shape[1]
    for ch in range(num_channels):
        scale, shift = std[ch], mean[ch]
        restored[:, ch] = restored[:, ch] * scale + shift
    return torch.clamp(restored, 0, 1)

In [None]:
class VideoOneClipDataset(Dataset):
    """
    Labelled video dataset that yields one clip of ``num_frames`` frames per video.

    Expected directory layout (as read by ``__init__``)::

        root/<class_name>/<video_dir>/<frame>.{jpg,jpeg,png}

    Classes are the sorted sub-directory names of ``root``; label indices follow
    that order. Video directories without any frame image are skipped.

    ``__getitem__`` returns ``(video, label)``: ``video`` is the output of
    ``VideoAugmentation`` (a (T, C, H, W) float tensor) and ``label`` is the
    class index.
    """

    def __init__(
        self,
        root,
        num_frames=16,
        image_size=224,
        is_train=False
    ):
        self.root = Path(root)
        self.num_frames = num_frames
        self.is_train = is_train

        # Clip-level augmentation (random only when is_train=True).
        self.transform = VideoAugmentation(
            image_size=image_size,
            is_train=is_train
        )

        # Kept as a public attribute for compatibility; frames are converted
        # with torch.from_numpy in __getitem__ instead.
        self.to_tensor = transforms.ToTensor()

        # Classes = sorted sub-directories of root.
        self.classes = sorted([d.name for d in self.root.iterdir() if d.is_dir()])
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        # samples[i] = (sorted frame paths, label); label_idx mirrors the labels
        # so samplers can read them without touching the frame lists.
        self.samples = []
        self.label_idx = []
        for cls in self.classes:
            cls_dir = self.root / cls
            for video_dir in sorted([d for d in cls_dir.iterdir() if d.is_dir()]):
                # NOTE(review): lexicographic sort; assumes frame file names are
                # zero-padded so this matches temporal order — confirm.
                frame_paths = sorted([
                    p for p in video_dir.iterdir()
                    if p.suffix.lower() in {".jpg", ".jpeg", ".png"}
                ])
                if len(frame_paths) > 0:
                    self.samples.append((frame_paths, self.class_to_idx[cls]))
                    self.label_idx.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.samples)

    def uniform_clip_sampling(self, frame_paths):
        """
        Select ``self.num_frames`` entries from ``frame_paths``.

        With at least ``num_frames`` frames, the sequence is split into
        ``num_frames`` equal segments and one frame is taken from each: a random
        one in training mode, the middle one otherwise (deterministic). Shorter
        videos use every frame and repeat the last frame as padding.

        Raises:
            ValueError: if ``frame_paths`` is empty.
        """
        total = len(frame_paths)
        if total == 0:
            raise ValueError("uniform_clip_sampling: frame_paths is empty")

        if total < self.num_frames:
            idxs = list(range(total)) + [total - 1] * (self.num_frames - total)
        else:
            seg_size = total / self.num_frames
            idxs = []
            for i in range(self.num_frames):
                start = int(i * seg_size)
                end = int((i + 1) * seg_size)  # end > start since seg_size >= 1
                if self.is_train:
                    idxs.append(random.randint(start, end - 1))
                else:
                    idxs.append((start + end) // 2)

        return [frame_paths[i] for i in idxs]

    def __getitem__(self, idx):
        frame_paths, label = self.samples[idx]

        clip_paths = self.uniform_clip_sampling(frame_paths)

        imgs = []
        for path in clip_paths:
            # Context manager releases the file handle as soon as pixels are read.
            with Image.open(path) as img:
                imgs.append(torch.from_numpy(np.array(img.convert("RGB"))))  # (H, W, C) uint8

        video = torch.stack(imgs)      # (T, H, W, C) uint8
        video = self.transform(video)  # -> (T, C, H, W) normalized float

        return video, label

In [None]:
class TestOneClipDataset(Dataset):
    """
    Unlabelled test dataset that yields one clip of ``num_frames`` frames per video.

    Expected directory layout (as read by ``__init__``)::

        root/<video_id>/<frame>.{jpg,jpeg,png}

    Every sub-directory name of ``root`` must parse as an integer video id
    (``int(name)``); directories are ordered numerically by that id.

    Output of ``__getitem__``:
        video: output of ``VideoAugmentation`` — a (T, C, H, W) float tensor
        video_id: int parsed from the directory name
    """

    def __init__(
        self,
        root,
        num_frames=16,
        image_size=224,
        is_train=False
    ):
        self.root = Path(root)
        self.num_frames = num_frames
        self.is_train = is_train

        # Clip-level preprocessing (random only when is_train=True).
        self.transform = VideoAugmentation(
            image_size=image_size,
            is_train=is_train
        )

        # Kept as a public attribute for compatibility; frames are converted
        # with torch.from_numpy in __getitem__ instead.
        self.to_tensor = transforms.ToTensor()

        self.video_dirs = sorted([d for d in self.root.iterdir() if d.is_dir()], key=lambda x: int(x.name))
        self.video_ids = [int(d.name) for d in self.video_dirs]

    def __len__(self):
        return len(self.video_dirs)

    def uniform_clip_sampling(self, frame_paths):
        """
        Select ``self.num_frames`` entries from ``frame_paths``.

        With at least ``num_frames`` frames, the sequence is split into
        ``num_frames`` equal segments and one frame is taken from each: a random
        one in training mode, the middle one otherwise (deterministic). Shorter
        videos use every frame and repeat the last frame as padding.

        Raises:
            ValueError: if ``frame_paths`` is empty.
        """
        total = len(frame_paths)
        if total == 0:
            raise ValueError("uniform_clip_sampling: frame_paths is empty")

        if total < self.num_frames:
            idxs = list(range(total)) + [total - 1] * (self.num_frames - total)
        else:
            seg_size = total / self.num_frames
            idxs = []
            for i in range(self.num_frames):
                start = int(i * seg_size)
                end = int((i + 1) * seg_size)  # end > start since seg_size >= 1
                if self.is_train:
                    idxs.append(random.randint(start, end - 1))
                else:
                    idxs.append((start + end) // 2)

        return [frame_paths[i] for i in idxs]

    def __getitem__(self, idx):
        video_dir = self.video_dirs[idx]
        video_id = self.video_ids[idx]
        # NOTE(review): lexicographic sort; assumes frame file names are
        # zero-padded so this matches temporal order — confirm.
        frame_paths = sorted([p for p in video_dir.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}])
        if not frame_paths:
            # Test videos cannot be skipped (each id needs a prediction), so fail loudly.
            raise ValueError(f"No frame images found in {video_dir}")

        clip_paths = self.uniform_clip_sampling(frame_paths)

        imgs = []
        for path in clip_paths:
            # Context manager releases the file handle as soon as pixels are read.
            with Image.open(path) as img:
                imgs.append(torch.from_numpy(np.array(img.convert("RGB"))))  # (H, W, C) uint8

        video = torch.stack(imgs)      # (T, H, W, C) uint8
        video = self.transform(video)  # -> (T, C, H, W) normalized float

        return video, video_id

In [None]:
def collate_fn(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Stack ``(clip, label)`` samples into a batch.

    Returns:
        (videos, labels): clips stacked along a new leading dim, and labels as
        an int64 tensor.
    """
    clip_list: List[torch.Tensor] = []
    label_list: List[int] = []
    for sample in batch:
        clip_list.append(sample[0])
        label_list.append(sample[1])
    videos = torch.stack(clip_list)
    labels = torch.tensor(label_list, dtype=torch.long)
    return videos, labels

def create_balanced_sampler(dataset):
    """
    Create a sampler that draws every class with equal probability.

    Args:
        dataset: either a dataset exposing ``label_idx`` (one class index per
            sample), or a ``Subset``-like wrapper exposing ``.dataset`` and
            ``.indices`` into such a dataset.

    Returns:
        ``WeightedRandomSampler`` with ``len(dataset)`` draws (with replacement),
        where each sample is weighted by 1 / (number of samples of its class).

    Raises:
        ValueError: if the dataset contains no labels.
    """
    if hasattr(dataset, 'dataset'):
        all_labels = [dataset.dataset.label_idx[i] for i in dataset.indices]
    else:
        all_labels = dataset.label_idx

    if len(all_labels) == 0:
        raise ValueError("create_balanced_sampler: dataset has no labels")

    labels = np.asarray(all_labels, dtype=np.int64)
    class_counts = np.bincount(labels)
    # Classes absent from this split get weight 0 instead of 1/0 (inf plus a
    # RuntimeWarning); such classes are never indexed below anyway.
    class_weights = np.divide(
        1.0,
        class_counts,
        out=np.zeros(len(class_counts), dtype=np.float64),
        where=class_counts > 0,
    )
    sample_weights = torch.as_tensor(class_weights[labels], dtype=torch.float32)

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    print(f"Balanced Sampler: class counts min={class_counts.min()}, max={class_counts.max()}")
    return sampler

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, scaler, device, grad_accum_steps=1):
    """
    Run one training pass over ``loader`` with AMP and gradient accumulation.

    Args:
        model: module whose forward returns an object with a ``.logits`` tensor
            of shape (batch, num_classes).
        loader: iterable of ``(videos, labels)`` batches supporting ``len()``.
        criterion: loss callable taking ``(logits, labels)``.
        optimizer: optimizer stepped through ``scaler``.
        scaler: GradScaler-like object providing ``scale``/``step``/``update``.
        device: ``torch.device`` (or device string) to move batches to.
        grad_accum_steps: micro-batches whose gradients are accumulated before
            each optimizer step; values < 1 are treated as 1.

    Returns:
        (avg_loss, avg_acc): sample-weighted mean loss and accuracy for the epoch.
    """
    device = torch.device(device)
    accum = max(1, int(grad_accum_steps))
    num_batches = len(loader)
    # Batches from index `tail_start` on form a final, shorter accumulation
    # window of `tail_size` micro-batches (0 when num_batches divides evenly).
    tail_start = num_batches - num_batches % accum
    tail_size = num_batches - tail_start

    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    optimizer.zero_grad()
    progress = tqdm(loader, desc="Train", leave=False)
    for batch_idx, (videos, labels) in enumerate(progress):
        videos = videos.to(device, non_blocking=True).float()
        labels = labels.to(device, non_blocking=True)
        # Use the real device type; mixed precision is enabled only on CUDA.
        with torch.amp.autocast(device_type=device.type, enabled=(device.type == 'cuda')):
            outputs = model(videos)
            logits = outputs.logits
            loss = criterion(logits, labels)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        loss_value = loss.item()
        # Average over the micro-batches actually in this window so the last,
        # shorter window gets the same gradient scale as the full ones.
        window = tail_size if batch_idx >= tail_start else accum
        scaler.scale(loss / window).backward()
        should_step = ((batch_idx + 1) % accum == 0) or (batch_idx + 1 == num_batches)
        if should_step:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        batch_size = videos.size(0)
        total_loss += loss_value * batch_size
        progress.set_postfix(loss=f"{loss_value:.4f}", acc=f"{correct / max(total, 1):.4f}")
    avg_loss = total_loss / max(total, 1)
    avg_acc = correct / max(total, 1)
    return avg_loss, avg_acc

In [None]:
PATH_DATA_TRAIN = r'/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = r'/kaggle/input/action-video/data/test'
NUM_FRAMES = 16
IMG_SIZE = 224
BATCH_SIZE = 8
NUM_WORKERS = 2

# Pass the config constants (not literals) so editing NUM_FRAMES / IMG_SIZE
# above actually takes effect.
train_dataset = VideoOneClipDataset(PATH_DATA_TRAIN, num_frames=NUM_FRAMES, image_size=IMG_SIZE, is_train=True)
# The class-balanced sampler takes the place of shuffle=True (they are mutually exclusive).
balanced_sampler = create_balanced_sampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=balanced_sampler,
                          num_workers=NUM_WORKERS, pin_memory=True)
test_dataset = TestOneClipDataset(PATH_DATA_TEST, num_frames=NUM_FRAMES, image_size=IMG_SIZE, is_train=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=True)

print(f'Train clips: {len(train_dataset)} | Test clips: {len(test_dataset)}')
print(f'Train Class count: {len(train_dataset.classes)}')

In [None]:
import os
from collections import Counter
def count_classes(parent_dir):
    """
    Count video sub-directories per class directory.

    Args:
        parent_dir: directory whose sub-directories are class names, each
            holding one sub-directory per video.

    Returns:
        dict mapping class name -> number of video sub-directories. Keys are in
        sorted order, the same order VideoOneClipDataset uses for label indices.
        Non-directory entries of ``parent_dir`` are ignored.
    """
    classes = {}
    # os.listdir order is arbitrary; sort so dict order matches label indices.
    for folder in sorted(os.listdir(parent_dir)):
        subdir = os.path.join(parent_dir, folder)
        if not os.path.isdir(subdir):
            continue  # stray files are not classes (and listdir on them would fail)
        classes[folder] = sum(
            1 for d in os.listdir(subdir)
            if os.path.isdir(os.path.join(subdir, d))
        )
    return classes
per_class_counts = count_classes(PATH_DATA_TRAIN)
print("Samples per class:", per_class_counts)

# Build loss weights in the SAME order as the dataset's label indices
# (train_dataset.classes), not in directory-listing order. Otherwise each weight
# could be attached to the wrong class.
class_names = train_dataset.classes
num_classes = len(class_names)
total_samples = sum(per_class_counts.get(c, 0) for c in class_names)

# Inverse-frequency weights; max(..., 1) avoids ZeroDivisionError for a class
# directory that has no videos.
weights = [
    total_samples / (num_classes * max(per_class_counts.get(c, 0), 1))
    for c in class_names
]

# NOTE(review): train_loader already uses a class-balanced sampler; weighting
# the loss as well corrects the imbalance twice (over-favouring rare classes).
# Consider keeping only one of the two.
class_weights = torch.tensor(weights, dtype=torch.float).to(DEVICE)


In [None]:
EPOCHS = 10
GRAD_ACCUM_STEPS = 4
# Optional checkpoint saved by this cell ({'model': state_dict, 'acc': float})
# to resume from, e.g. '/kaggle/working/videomae_best.pt'. None = fresh start.
RESUME_CHECKPOINT = None

# Size the classification head from the data rather than a hard-coded count,
# so label indices produced by the dataset always fit the head.
NUM_CLASSES = len(train_dataset.classes)

model = VideoMAEForVideoClassification.from_pretrained(
    'MCG-NJU/videomae-base-finetuned-kinetics',
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,  # replace the Kinetics head with a fresh one
).to(DEVICE)

if RESUME_CHECKPOINT is not None:
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=DEVICE)
    print(f"Loaded checkpoint {RESUME_CHECKPOINT} (acc={checkpoint['acc']})")
    model.load_state_dict(checkpoint['model'])

criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    weight_decay=0.05
)
# Scheduler is stepped once per epoch, so anneal over exactly EPOCHS steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS
)
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
best_acc = 0.0
best_ckpt = Path('./videomae_best.pt')
last_ckpt = Path('./videomae_last.pt')
history = {'train_loss': [], 'train_acc': []}

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        scaler,
        DEVICE,
        grad_accum_steps=GRAD_ACCUM_STEPS,
    )
    scheduler.step()
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    torch.save({'model': model.state_dict(), 'acc': train_acc}, last_ckpt)
    # NOTE(review): "best" is chosen by TRAIN accuracy (no validation split
    # exists here), which tracks fitting, not generalization.
    if train_acc > best_acc:
        best_acc = train_acc
        torch.save({'model': model.state_dict(), 'acc': best_acc}, best_ckpt)

    print(
        f'Epoch {epoch + 1}/{EPOCHS} | train_loss={train_loss:.4f} | train_acc={train_acc:.4f}'
    )

trained_model = model
training_history = history

In [None]:
# Training leaves the model in train mode with dropout (p=0.1) active; switch
# to eval mode so test predictions are deterministic.
model.eval()

predictions = []
classes = train_dataset.classes
with torch.no_grad():
    for videos, video_ids in tqdm(test_loader, desc="Inference"):
        videos = videos.to(DEVICE, non_blocking=True)
        logits = model(videos).logits
        preds = logits.argmax(dim=1)
        # video_ids is the default-collated tensor of integer ids.
        for video_id, pred_idx in zip(video_ids.tolist(), preds.tolist()):
            predictions.append((video_id, classes[pred_idx]))

predictions.sort(key=lambda x: x[0])

In [None]:
import csv

submission_path = Path('./submission.csv')
# csv.writer quotes class names containing commas/quotes. Explicit UTF-8 and
# newline='' + lineterminator='\n' give identical bytes on every platform.
with open(submission_path, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(['id', 'class'])
    writer.writerows(predictions)

print(f"Submission saved to: {submission_path}")