In [None]:
from __future__ import annotations
import math
import random
import re
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import amp
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import timm
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
import os
from collections import Counter
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

In [None]:
class VideoTransform:
    """Clip-level resize, augmentation and normalisation.

    The callable takes a stack of frames laid out ``[T, C, H, W]`` (pixel
    values either in 0..255 or already in 0..1) and returns a float tensor of
    shape ``[T, C, image_size, image_size]`` normalised with ``self.mean`` and
    ``self.std``. Every random parameter is drawn once per call, so all frames
    of a clip receive the same augmentation.

    Args:
        image_size: side length of the square output frames.
        is_train: when True apply random rescale/crop/flip/colour jitter,
            otherwise only resize to ``image_size`` x ``image_size``.
    """

    def __init__(self, image_size=224, is_train=True):
        self.image_size = image_size
        self.is_train = is_train
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]

    @staticmethod
    def _to_unit_range(frames):
        """Return ``frames`` as a float tensor with pixel values in [0, 1]."""
        if not torch.is_floating_point(frames):
            return frames.float() / 255.0
        # Float input above 1.0 is taken to be 0..255 data (same heuristic as before).
        return frames / 255.0 if frames.max() > 1.0 else frames

    def __call__(self, frames):
        # Rescale BEFORE any colour jitter: torchvision's adjust_brightness /
        # adjust_contrast clamp floating-point images to [0, 1], so running them
        # on 0..255 floats would saturate almost every pixel to 1.0.
        frames = self._to_unit_range(frames)

        if self.is_train:
            h, w = frames.shape[-2:]
            # Never shrink below the crop size, otherwise the crop window would
            # extend past the frame (zero padding or a wrongly sized output).
            min_scale = self.image_size / min(h, w)
            scale = max(random.uniform(0.8, 1.0), min_scale)
            # max() guards against int() truncating to image_size - 1.
            new_h = max(self.image_size, int(h * scale))
            new_w = max(self.image_size, int(w * scale))

            frames = TF.resize(frames, [new_h, new_w], interpolation=InterpolationMode.BILINEAR)

            # One crop offset shared by all frames of the clip.
            i = random.randint(0, new_h - self.image_size)
            j = random.randint(0, new_w - self.image_size)
            frames = TF.crop(frames, i, j, self.image_size, self.image_size)

            if random.random() < 0.5:
                frames = TF.hflip(frames)

            if random.random() < 0.3:
                frames = TF.adjust_brightness(frames, random.uniform(0.9, 1.1))
            if random.random() < 0.3:
                frames = TF.adjust_contrast(frames, random.uniform(0.9, 1.1))
        else:
            frames = TF.resize(frames, [self.image_size, self.image_size], interpolation=InterpolationMode.BILINEAR)

        # Reshape mean/std to [1, C, 1, 1] so they broadcast over [T, C, H, W];
        # build them on the frames' own device/dtype so GPU input also works.
        m = torch.tensor(self.mean, dtype=frames.dtype, device=frames.device).view(1, -1, 1, 1)
        s = torch.tensor(self.std, dtype=frames.dtype, device=frames.device).view(1, -1, 1, 1)
        return (frames - m) / s

print("Augmentation defined")

def denormalize(video_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    """Undo per-channel normalisation so a clip can be displayed.

    Args:
        video_tensor: normalised clip laid out ``[C, T, H, W]`` with 3 channels.
        mean: per-channel mean that was subtracted during normalisation.
        std: per-channel standard deviation that was divided out.

    Returns:
        Tensor of the same shape with values clamped to ``[0, 1]``.
    """
    target_device = video_tensor.device

    # Statistics are shaped [C, 1, 1, 1] so they broadcast across T, H and W.
    channel_mean = torch.tensor(mean, device=target_device).view(3, 1, 1, 1)
    channel_std = torch.tensor(std, device=target_device).view(3, 1, 1, 1)

    # Inverse of (x - mean) / std, clamped so plotting never sees out-of-range values.
    restored = video_tensor * channel_std + channel_mean
    return torch.clamp(restored, 0, 1)

In [None]:
class VideoDataset(Dataset):
    """Labelled video dataset read from ``root/<class>/<video>/<frame image>``.

    Every sub-directory of ``root`` is a class; every sub-directory of a class
    is one video stored as individual ``.jpg`` / ``.jpeg`` / ``.png`` frames.
    Videos without any frame image are skipped.

    Each item is ``(video, label)`` where ``video`` stacks ``num_segments``
    clips of ``clip_len`` frames as ``[N, C, T, H, W]`` and ``label`` is the
    integer class index (position of the class name in ``self.classes``).

    NOTE(review): frames are ordered by plain lexicographic path sort, which
    assumes zero-padded frame file names - confirm against the data.
    """

    def __init__(self, root, num_segments=4, clip_len=16, image_size=224, is_train=True):
        self.root = Path(root)
        self.num_segments = num_segments
        self.clip_len = clip_len
        self.is_train = is_train
        self.transform = VideoTransform(image_size, is_train)
        self.to_tensor = transforms.ToTensor()
        self.classes = sorted([d.name for d in self.root.iterdir() if d.is_dir()])
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.samples = []    # (sorted frame paths, class index) per video
        self.label_idx = []  # class index per video, parallel to self.samples
        for cls in self.classes:
            cls_dir = self.root / cls
            for video_dir in sorted([d for d in cls_dir.iterdir() if d.is_dir()]):
                frame_paths = sorted([p for p in video_dir.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}])
                if frame_paths:
                    self.samples.append((frame_paths, self.class_to_idx[cls]))
                    self.label_idx.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.samples)

    def _tsn_sample(self, frames):
        """Pick ``num_segments`` clips of ``clip_len`` frames from ``frames``.

        The video is split into ``num_segments`` equal spans. A centre frame is
        chosen per span (uniformly at random when training, the midpoint
        otherwise) and ``clip_len`` indices are spread evenly over a window of
        ``clip_len // 2`` frames on each side, clipped to the valid range - so
        frames repeat near the borders and for short videos.
        """
        total = len(frames)
        seg_size = total / self.num_segments
        clips = []

        for seg in range(self.num_segments):
            start = int(seg * seg_size)
            end = int((seg + 1) * seg_size)
            if self.is_train:
                # max(...) keeps the range non-empty when a span has no frame.
                center = np.random.randint(start, max(start + 1, end))
            else:
                center = (start + end) // 2

            idxs = np.linspace(
                max(0, center - self.clip_len // 2),
                min(total - 1, center + self.clip_len // 2),
                self.clip_len
            ).astype(int)

            clips.append([frames[k] for k in idxs])

        return clips

    @staticmethod
    def _read_rgb(path):
        """Read one frame as an RGB ``H x W x 3`` array, failing loudly if unreadable."""
        image = cv2.imread(str(path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read frame image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        frame_paths, label = self.samples[idx]

        clips = self._tsn_sample(frame_paths)

        processed_clips = []
        for clip in clips:
            # Load the whole clip into a single array: [T, H, W, C].
            video_array = np.stack([self._read_rgb(p) for p in clip])

            # Channels-first per frame for the transform: [T, C, H, W].
            video_tensor = torch.from_numpy(video_array).float()
            video_tensor = video_tensor.permute(0, 3, 1, 2)
            if self.transform:
                video_tensor = self.transform(video_tensor)

            # The transform keeps [T, C, H, W]; swap to [C, T, H, W] for the 3D backbone.
            video_tensor = video_tensor.permute(1, 0, 2, 3)
            processed_clips.append(video_tensor)

        video = torch.stack(processed_clips, dim=0)
        return video, label


class TestDataset(Dataset):
    """Unlabelled video dataset read from ``root/<numeric id>/<frame image>``.

    Every sub-directory of ``root`` must have an integer name, which is used
    as the video id; directories are visited in increasing id order.

    Each item is ``(video, video_id)`` where ``video`` stacks ``num_segments``
    deterministically sampled clips of ``clip_len`` frames as
    ``[N, C, T, H, W]``.
    """

    def __init__(self, root, num_segments=4, clip_len=16, image_size=224):
        self.root = Path(root)
        self.num_segments = num_segments
        self.clip_len = clip_len
        self.transform = VideoTransform(image_size, is_train=False)
        self.to_tensor = transforms.ToTensor()
        self.video_dirs = sorted([d for d in self.root.iterdir() if d.is_dir()], key=lambda x: int(x.name))
        self.video_ids = [int(d.name) for d in self.video_dirs]

    def __len__(self):
        return len(self.video_dirs)

    def _tsn_sample(self, frames):
        """Pick ``num_segments`` clips of ``clip_len`` frames from ``frames``.

        The video is split into ``num_segments`` equal spans; around the
        midpoint of each span ``clip_len`` indices are spread evenly over a
        window of ``clip_len // 2`` frames on each side, clipped to the valid
        range (so frames repeat near the borders and for short videos).
        """
        total = len(frames)
        seg_size = total / self.num_segments
        clips = []

        for seg in range(self.num_segments):
            start = int(seg * seg_size)
            end = int((seg + 1) * seg_size)
            center = (start + end) // 2

            idxs = np.linspace(
                max(0, center - self.clip_len // 2),
                min(total - 1, center + self.clip_len // 2),
                self.clip_len
            ).astype(int)

            clips.append([frames[k] for k in idxs])

        return clips

    @staticmethod
    def _read_rgb(path):
        """Read one frame as an RGB ``H x W x 3`` array, failing loudly if unreadable."""
        image = cv2.imread(str(path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read frame image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        video_dir = self.video_dirs[idx]
        video_id = self.video_ids[idx]
        frame_paths = sorted([p for p in video_dir.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}])
        if not frame_paths:
            # Without this guard the sampler fails later with an opaque IndexError.
            raise ValueError(f"No frame images found in {video_dir}")

        clips = self._tsn_sample(frame_paths)
        processed_clips = []
        for clip in clips:
            # Load the whole clip into a single array: [T, H, W, C].
            video_array = np.stack([self._read_rgb(p) for p in clip])

            # Channels-first per frame for the transform: [T, C, H, W].
            video_tensor = torch.from_numpy(video_array).float()
            video_tensor = video_tensor.permute(0, 3, 1, 2)
            if self.transform:
                video_tensor = self.transform(video_tensor)

            # The transform keeps [T, C, H, W]; swap to [C, T, H, W] for the 3D backbone.
            video_tensor = video_tensor.permute(1, 0, 2, 3)
            processed_clips.append(video_tensor)

        video = torch.stack(processed_clips, dim=0)
        return video, video_id


print("Dataset classes defined")

def collate_fn(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge ``(video, label)`` samples into one batch.

    Returns the videos stacked along a new leading dimension and the labels
    as a 1-D ``torch.long`` tensor, both in the order of ``batch``.
    """
    clip_list = []
    target_list = []
    for sample in batch:
        clip_list.append(sample[0])
        target_list.append(sample[1])
    stacked = torch.stack(clip_list)
    return stacked, torch.tensor(target_list, dtype=torch.long)

def create_balanced_sampler(dataset):
    """Build a ``WeightedRandomSampler`` that draws every class equally often.

    Args:
        dataset: an object exposing ``label_idx`` (one integer class index per
            sample), or a wrapper exposing ``dataset`` and ``indices`` (such as
            ``torch.utils.data.Subset``) around such an object.

    Returns:
        A sampler drawing ``len(dataset)`` samples with replacement, each
        sample weighted by the inverse frequency of its class.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    if hasattr(dataset, 'dataset'):
        all_labels = [dataset.dataset.label_idx[i] for i in dataset.indices]
    else:
        all_labels = dataset.label_idx

    labels = np.asarray(all_labels, dtype=np.int64)
    if labels.size == 0:
        raise ValueError("Cannot build a balanced sampler for an empty dataset")

    class_counts = np.bincount(labels)
    # Class indices missing from this (sub)set have count 0; give them weight 0
    # instead of dividing by zero (they are never looked up by any sample anyway).
    class_weights = np.zeros(len(class_counts), dtype=np.float64)
    present = class_counts > 0
    class_weights[present] = 1.0 / class_counts[present]

    sample_weights = torch.as_tensor(class_weights[labels], dtype=torch.float)

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    print(f"Balanced Sampler: class counts min={class_counts.min()}, max={class_counts.max()}")
    return sampler

In [None]:
class X3DFeatureExtractor(nn.Module):
    """Pretrained X3D-S from the pytorchvideo hub, used without its last block.

    ``forward`` runs every block except the final one and returns that
    intermediate output. The final block stays reachable as ``self.head`` (and
    through ``self.model``) but is never applied here.
    """

    def __init__(self):
        super().__init__()
        # Fetches the architecture and pretrained weights through torch.hub.
        hub_model = torch.hub.load("facebookresearch/pytorchvideo", "x3d_s", pretrained=True)
        self.model = hub_model

        # Split the block list: last block = original head, the rest = feature trunk.
        stages = self.model.blocks
        self.head = stages[-1]
        self.blocks = stages[:-1]

    def forward(self, x):
        # x: (B, C, T, H, W) clip batch, passed through the trunk blocks in order.
        features = x
        for stage in self.blocks:
            features = stage(features)
        # NOTE(review): expected to be a 5-D map (B, C, T', H', W') - confirm for the hub model.
        return features


class SpatioTemporalTokenizer(nn.Module):
    """Turn per-clip 2-D feature maps into one token sequence per video."""

    def __init__(self):
        super().__init__()

    def forward(self, x, B, N):
        """Flatten clip feature maps into tokens.

        Args:
            x: feature maps of shape ``(B*N, C, H, W)``.
            B: number of videos in the batch.
            N: number of clips per video.

        Returns:
            Tensor of shape ``(B, N*H*W, C)``; tokens are ordered by clip,
            then row, then column.
        """
        num_clips, channels, height, width = x.shape
        assert num_clips == B * N, "Batch mismatch"

        # (B*N, C, H, W) -> (B, N, C, H, W) -> channels last (B, N, H, W, C)
        per_video = x.view(B, N, channels, height, width).permute(0, 1, 3, 4, 2)
        # Merge the clip and spatial axes into a single token axis.
        return per_video.flatten(start_dim=1, end_dim=3)

class SpatioTemporalPositionalEncoding(nn.Module):
    """Learned additive position embedding for up to ``max_tokens`` tokens."""

    def __init__(self, dim, max_tokens=256):
        super().__init__()
        # One learnable vector per position, initialised from a truncated normal.
        table = torch.zeros(1, max_tokens, dim)
        nn.init.trunc_normal_(table, std=0.02)
        self.pos_embed = nn.Parameter(table)

    def forward(self, x):
        # x: (B, tokens, dim); only the first `tokens` embeddings are added.
        num_tokens = x.size(1)
        return x + self.pos_embed[:, :num_tokens]

class SpatioTemporalTransformer(nn.Module):
    """Pre-norm Transformer encoder over a token sequence.

    Adds learned position embeddings, applies ``depth`` encoder layers and
    finishes with a LayerNorm.

    Args:
        dim: token embedding size.
        depth: number of encoder layers.
        heads: attention heads per layer.
        dropout: dropout probability inside the encoder layers.
    """

    def __init__(self, dim=192, depth=2, heads=4, dropout=0.1):
        super().__init__()

        self.pos_embed = SpatioTemporalPositionalEncoding(dim)

        # batch_first: inputs are (B, tokens, D); norm_first: LayerNorm before attention/MLP.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
            ),
            num_layers=depth,
        )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Encode ``x`` of shape ``(B, tokens, D)``; the output has the same shape."""
        encoded = self.encoder(self.pos_embed(x))
        return self.norm(encoded)

class X3D_SpatioTemporalTransformer(nn.Module):
    """X3D backbone followed by a token Transformer and a linear classifier.

    Each clip is encoded by the backbone, its feature map is averaged over the
    temporal axis, the spatial positions of all clips of a video become one
    token sequence, and the mean of the encoded tokens is classified.
    """

    def __init__(self, num_classes=51):
        super().__init__()

        self.backbone = X3DFeatureExtractor()
        self.tokenizer = SpatioTemporalTokenizer()

        # Token width fed to the Transformer; must equal the backbone's channel count.
        self.embed_dim = 192

        self.st_transformer = SpatioTemporalTransformer(
            dim=self.embed_dim, depth=2, heads=4
        )

        self.fc = nn.Linear(self.embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of videos.

        Args:
            x: tensor of shape ``(B, N, C, T, H, W)`` - N clips per video.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B, N, C, T, H, W = x.shape

        # Fold the clip axis into the batch so the backbone sees ordinary clips.
        clip_batch = x.view(B * N, C, T, H, W)

        # (B*N, C', t', h', w') averaged over dim 2 (time) -> (B*N, C', h', w')
        spatial_maps = self.backbone(clip_batch).mean(dim=2)

        tokens = self.st_transformer(self.tokenizer(spatial_maps, B, N))

        # Global average over all tokens, then the classification layer.
        pooled = tokens.mean(dim=1)
        return self.fc(pooled)


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, scaler, device, grad_accum_steps=1):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Args:
        model: network called as ``model(videos)`` and returning class logits.
        loader: sized iterable yielding ``(videos, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer stepped once per accumulation group.
        scaler: AMP gradient scaler (``scale`` / ``step`` / ``update``).
        device: ``torch.device`` (or device string) to run on.
        grad_accum_steps: number of batches whose gradients are averaged
            before each optimizer step.

    Returns:
        ``(avg_loss, avg_acc)`` - sample-weighted mean loss and accuracy over
        the epoch (both 0 when the loader is empty).
    """
    model.train()
    device = torch.device(device)  # also accept plain strings such as 'cuda'
    num_batches = len(loader)
    total_loss = 0.0
    correct = 0
    total = 0
    optimizer.zero_grad()
    progress = tqdm(loader, desc="Train", leave=False)
    for batch_idx, (videos, labels) in enumerate(progress):
        # non_blocking only has an effect for pinned host memory; harmless otherwise.
        videos = videos.to(device, non_blocking=True).float()
        labels = labels.to(device, non_blocking=True)

        # device_type stays 'cuda' on purpose: autocast is simply disabled elsewhere.
        with torch.amp.autocast(device_type='cuda', enabled=(device.type == 'cuda')):
            logits = model(videos)
            loss = criterion(logits, labels)

        batch_size = labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size
        loss_value = loss.item()

        # Average over the batches actually present in this accumulation group,
        # so a shorter final group is not under-weighted.
        group_start = (batch_idx // grad_accum_steps) * grad_accum_steps
        group_size = min(grad_accum_steps, num_batches - group_start)
        scaler.scale(loss / group_size).backward()

        if batch_idx + 1 == group_start + group_size:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss_value * batch_size
        progress.set_postfix(loss=f"{loss_value:.4f}", acc=f"{correct / max(total, 1):.4f}")

    avg_loss = total_loss / max(total, 1)
    avg_acc = correct / max(total, 1)
    return avg_loss, avg_acc

print("Training functions defined")

In [None]:
# --- Data configuration -------------------------------------------------
PATH_DATA_TRAIN = r'/kaggle/input/action-video/data/data_train'
PATH_DATA_TEST = r'/kaggle/input/action-video/data/test'
NUM_FRAMES = 16    # frames per clip (clip_len)
FRAME_STRIDE = 2   # not referenced by any visible cell
IMG_SIZE = 224     # side length of the square frames fed to the model
BATCH_SIZE = 4     # videos per batch
# Clips per video. This was declared as 16 but ignored: both datasets were built
# with a literal 4. The constant now holds the value actually in use.
NUM_SEGMENTS = 4

# The datasets take their settings from the constants above so that training and
# test sampling cannot drift apart.
train_dataset = VideoDataset(
    PATH_DATA_TRAIN,
    num_segments=NUM_SEGMENTS,
    clip_len=NUM_FRAMES,
    image_size=IMG_SIZE,
    is_train=True,
)
balanced_sampler = create_balanced_sampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=balanced_sampler, num_workers=2, pin_memory=True)
test_dataset = TestDataset(
    PATH_DATA_TEST,
    num_segments=NUM_SEGMENTS,
    clip_len=NUM_FRAMES,
    image_size=IMG_SIZE,
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# len(dataset) counts videos (each item holds NUM_SEGMENTS clips).
print(f'Train clips: {len(train_dataset)} | Test clips: {len(test_dataset)}')
print(f'Train Class count: {len(train_dataset.classes)}')

In [None]:

def count_classes(parent_dir):
    """Count the video folders of every class folder under ``parent_dir``.

    Args:
        parent_dir: directory whose sub-directories are the classes.

    Returns:
        Dict mapping class folder name -> number of sub-directories it
        contains. Keys are inserted in sorted name order, so iterating the
        dict follows the same order as a sorted list of class names.
        Non-directory entries are ignored at both levels.
    """
    classes = {}
    # os.listdir order is arbitrary; sort so that position i is class index i.
    for folder in sorted(os.listdir(parent_dir)):
        subdir = os.path.join(parent_dir, folder)
        if not os.path.isdir(subdir):
            continue  # stray file next to the class folders
        classes[folder] = sum(
            1 for d in os.listdir(subdir)
            if os.path.isdir(os.path.join(subdir, d))
        )
    return classes
per_class_counts = count_classes(PATH_DATA_TRAIN)
print("Samples per class:", per_class_counts)

# Inverse-frequency class weights: total / (num_classes * count).
num_classes = len(per_class_counts)
total_samples = sum(per_class_counts.values())

# CrossEntropyLoss reads weight[i] as the weight of class index i, so the list
# must follow train_dataset.classes (the order that defines the label indices),
# not the arbitrary order in which the directory happened to be listed.
weights = [
    total_samples / (num_classes * per_class_counts[c])
    for c in train_dataset.classes
]

# Move the weights to the training device for use in the loss.
class_weights = torch.tensor(weights, dtype=torch.float).to(DEVICE)
print(class_weights)

In [None]:
# Visual sanity check: show the first frames of one training sample.
sample_idx = 0
sample_frames, sample_label = train_dataset[sample_idx]

# 1. sample_frames is [N, C, T, H, W]; take the first clip for display.
# Shape after selection: [C, T, H, W]
first_clip = sample_frames[0] 

# 2. Undo the normalisation of this clip
vis_frames = denormalize(first_clip).cpu()

class_name = train_dataset.classes[sample_label]

# 3. Number of frames comes from the T axis (second dimension);
# vis_frames is currently [C, T, H, W]. At most 12 frames are shown.
num_frames_in_clip = vis_frames.shape[1] 
frames_to_show = min(num_frames_in_clip, 12)

cols = 4
rows = math.ceil(frames_to_show / cols)

plt.figure(figsize=(12, 3 * rows))

for i in range(frames_to_show):
    plt.subplot(rows, cols, i + 1)
    
    # Take frame i along the T axis: vis_frames[:, i, :, :] is [C, H, W];
    # permute to [H, W, C], the layout imshow expects.
    frame = vis_frames[:, i, :, :].permute(1, 2, 0).numpy()
    
    plt.imshow(frame)
    plt.axis('off')
    plt.title(f'Frame {i + 1}')

plt.suptitle(f'Sample clip (Segment 1) from class: {class_name}', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
EPOCHS = 15
GRAD_ACCUM_STEPS = 4

# Optional resume: uncomment to continue from a saved checkpoint.
#checkpoint_path = Path('/kaggle/working/x3d_s_vit_best.pt')
#print(f"Loading checkpoint from {checkpoint_path}...")
#checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
#classes = checkpoint['classes']

model = X3D_SpatioTemporalTransformer(num_classes=len(train_dataset.classes)).to(DEVICE)
#model.load_state_dict(checkpoint['model'])

# Separate parameter groups so the pretrained backbone is updated more gently
# than the freshly initialised transformer and classifier.
backbone_params = model.backbone.parameters()
transformer_params = model.st_transformer.parameters()
head_params = model.fc.parameters()

# NOTE(review): the loader already uses a class-balanced sampler; combined with
# class weights in the loss, rare classes are compensated twice - confirm intended.
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 1e-4},
        {"params": transformer_params, "lr": 5e-4},
        {"params": head_params, "lr": 1e-3},
    ],
    weight_decay=0.05,
)
# Halve every learning rate after each 4 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=4, gamma=0.5
)
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())
best_acc = 0.0
best_ckpt = Path('./x3d_s_vit_best.pt')
last_ckpt = Path('./x3d_s_vit_last.pt')
history = {'train_loss': [], 'train_acc': []}

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        scaler,
        DEVICE,
        grad_accum_steps=GRAD_ACCUM_STEPS,
    )
    scheduler.step()
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)

    # Store the class list with the weights: the index -> name mapping is needed
    # to interpret predictions, and the resume code above reads checkpoint['classes'].
    torch.save(
        {'model': model.state_dict(), 'acc': train_acc, 'classes': train_dataset.classes},
        last_ckpt,
    )
    # There is no validation split, so "best" means best TRAINING accuracy.
    if train_acc > best_acc:
        best_acc = train_acc
        torch.save(
            {'model': model.state_dict(), 'acc': best_acc, 'classes': train_dataset.classes},
            best_ckpt,
        )

    print(
        f'Epoch {epoch + 1}/{EPOCHS} | train_loss={train_loss:.4f} | train_acc={train_acc:.4f}'
    )

trained_model = model
training_history = history


In [None]:
print("\nRunning inference...")
predictions = []
classes = train_dataset.classes

# train_one_epoch leaves the model in training mode. Switch to eval so dropout is
# disabled and BatchNorm uses (and stops updating) its running statistics.
model.eval()
with torch.no_grad():
    for videos, video_ids in tqdm(test_loader, desc="Inference"):
        videos = videos.to(DEVICE)
        logits = model(videos)
        preds = logits.argmax(dim=1)
        # tolist() yields plain Python ints for the ids and the class indices.
        for video_id, pred_idx in zip(video_ids.tolist(), preds.tolist()):
            pred_class = classes[pred_idx]
            predictions.append((video_id, pred_class))

# Order rows by video id for the submission file.
predictions.sort(key=lambda x: x[0])
print(f"\nTotal predictions: {len(predictions)}")

In [None]:
# Write the predictions as a two-column CSV: header row, then one "<id>,<class>" line per video.
submission_path = Path('./submission.csv')
with open(submission_path, 'w') as f:
    f.write('id,class\n')
    f.writelines(
        f'{video_id},{pred_class}\n' for video_id, pred_class in predictions
    )

print(f"Submission saved to: {submission_path}")