In [None]:
import os, gc, random, math
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from torchvision import transforms
from torchvision.transforms import ColorJitter, RandomPerspective, RandAugment

import timm
from timm.loss import LabelSmoothingCrossEntropy
try:
    from timm.utils import ModelEmaV2
except:
    ModelEmaV2 = None

In [None]:
# Config

# ---- Paths -----------------------------------------------------------------
TRAIN_DIR = "/kaggle/input/action-dm-dataset/action-dm-dataset_2/train"
TEST_DIR = "/kaggle/input/action-dm-dataset/action-dm-dataset_2/test/test"
MODEL_DIR_TEACHER = "models_teacher"
MODEL_DIR_STUDENT = "models_student"
PSEUDO_CSV_PATH = "pseudo_labels.csv"
OUT_FILE = "submission.csv"

for _out_dir in (MODEL_DIR_TEACHER, MODEL_DIR_STUDENT):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# ---- Device & reproducibility ----------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
# Seed every RNG the pipeline draws from: torch (model init, randperm),
# numpy (MixUp/CutMix lambda and boxes) and stdlib random (augmentation choice).
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
del _seed_fn
# cuDNN autotuning favours speed; it may select non-deterministic kernels.
torch.backends.cudnn.benchmark = True

# ---- Hyper-parameters ------------------------------------------------------
BATCH_SIZE = 8
IMG_SIZE = 384
BACKBONE = "convnextv2_large.fcmae_ft_in22k_in1k"
N_FOLDS_TEACHER, EPOCHS_TEACHER = 3, 15
N_FOLDS_STUDENT, EPOCHS_STUDENT = 3, 12
CONFIDENCE_THRESHOLD = 0.87  # minimum teacher confidence for a pseudo label to be kept
ACCUMULATION_STEPS = 2       # micro-batches per optimizer step (effective batch = BATCH_SIZE * this)

print("DEVICE:", DEVICE)
print("BACKBONE:", BACKBONE)

In [None]:
# Transforms

_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_EVAL_RESIZE = int(IMG_SIZE * 1.14)  # resize slightly larger, then center-crop to IMG_SIZE


def _eval_pipeline(*extra_pil_ops):
    """Build resize -> center-crop -> `extra_pil_ops` -> tensor -> ImageNet-normalise."""
    return transforms.Compose([
        transforms.Resize(_EVAL_RESIZE),
        transforms.CenterCrop(IMG_SIZE),
        *extra_pil_ops,
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ])


# Training: random crop/flip/colour/perspective/blur + RandAugment on PIL images,
# then normalisation and random erasing on the tensor.
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)], p=0.5),
    transforms.RandomApply([RandomPerspective(distortion_scale=0.2, p=0.5)], p=0.3),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
])

# Validation / test: deterministic resize + center crop.
val_test_tfms = _eval_pipeline()

# Test-time augmentation views: plain, horizontally flipped, and a random rotation of up to 10 degrees.
tta_tfms = [
    val_test_tfms,
    _eval_pipeline(transforms.RandomHorizontalFlip(p=1.0)),
    _eval_pipeline(transforms.RandomRotation(10)),
]

In [None]:
# Dataset wrapper

class SimpleImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label, source_flag)`` triples.

    Args:
        paths: sequence of image file paths.
        labels: sequence of integer class indices, aligned with `paths`.
        transform: optional callable applied to the RGB PIL image.
        source_flags: optional sequence aligned with `paths`; 0 = real sample,
            1 = pseudo-labelled sample. Defaults to all zeros.

    Raises:
        ValueError: if `labels` or `source_flags` are not the same length as `paths`.
    """

    def __init__(self, paths, labels, transform=None, source_flags=None):
        if len(labels) != len(paths):
            raise ValueError(f"got {len(paths)} paths but {len(labels)} labels")
        if source_flags is None:
            source_flags = [0] * len(paths)  # 0 = real, 1 = pseudo
        elif len(source_flags) != len(paths):
            raise ValueError(f"got {len(paths)} paths but {len(source_flags)} source flags")
        self.paths = paths
        self.labels = labels
        self.transform = transform
        self.source_flags = source_flags

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        label = int(self.labels[idx])
        src = int(self.source_flags[idx])
        # PIL opens lazily and keeps the file handle open; the context manager closes it,
        # and convert() returns an in-memory copy that outlives the handle.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label, src

In [None]:
# Model factory

def get_model(backbone_name, n_classes, pretrained=True, grad_checkpointing=True):
    """Create a timm classifier with an `n_classes`-way head.

    Args:
        backbone_name: timm model name.
        n_classes: number of output classes.
        pretrained: load timm pretrained weights.
        grad_checkpointing: enable timm gradient checkpointing when the model
            exposes `set_grad_checkpointing` (default True, as before).

    Returns:
        The (CPU) model; callers move it to a device.
    """
    print("Loading model:", backbone_name)
    model = timm.create_model(backbone_name, pretrained=pretrained, num_classes=n_classes)
    if grad_checkpointing and hasattr(model, "set_grad_checkpointing"):
        try:
            model.set_grad_checkpointing(True)
            print(" - grad checkpointing enabled")
        except Exception as exc:
            # Some timm models do not support checkpointing; training still works without it,
            # but memory use is higher, so report it instead of silently ignoring it.
            print(" - grad checkpointing unavailable:", exc)
    return model

In [None]:
# MixUp / CutMix utils

def rand_bbox(size, lam):
    """Sample a random CutMix box covering roughly ``1 - lam`` of the image area.

    Args:
        size: tensor shape ``(B, C, H, W)``; only H and W are read.
        lam: mixing coefficient; box sides are ``sqrt(1 - lam)`` times W and H.

    Returns:
        ``(x1, y1, x2, y2)`` as half-open slice bounds (the box is
        ``[..., y1:y2, x1:x2]``) with ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        The box is clipped at the image border, so its area may be smaller than
        requested; callers should recompute lam from the returned box.
    """
    H = size[2]
    W = size[3]
    cut_rat = math.sqrt(max(0.0, 1.0 - lam))
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(0, W)
    cy = np.random.randint(0, H)
    # Clip to W / H, not W-1 / H-1: x2 / y2 are exclusive slice ends, so clipping
    # to W-1 would make the last column / row impossible to cut.
    x1 = int(np.clip(cx - cut_w // 2, 0, W))
    y1 = int(np.clip(cy - cut_h // 2, 0, H))
    x2 = int(np.clip(cx + cut_w // 2, 0, W))
    y2 = int(np.clip(cy + cut_h // 2, 0, H))
    # Keep the box non-empty; x1 <= cx <= W-1 (and likewise y1), so x1+1 <= W.
    if x2 <= x1:
        x2 = min(W, x1 + 1)
    if y2 <= y1:
        y2 = min(H, y1 + 1)
    return x1, y1, x2, y2

def cutmix_data(x, y, alpha=1.0):
    """CutMix: paste a random box from a shuffled copy of the batch into each image.

    Args:
        x: image batch indexed as ``x[:, :, rows, cols]`` (B, C, H, W).
        y: labels, indexable by a permutation tensor.
        alpha: Beta(alpha, alpha) parameter; ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_x, y, y[perm], lam)`` where lam is the fraction of each image
        that was NOT replaced (recomputed from the clipped box). The input `x`
        is not modified.
    """
    if alpha <= 0:
        return x, y, y, 1.0
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(x.size(0)).to(x.device)
    bx1, by1, bx2, by2 = rand_bbox(x.size(), lam)
    # Work on a copy so the caller's batch tensor is not mutated as a side effect.
    mixed = x.clone()
    mixed[:, :, by1:by2, bx1:bx2] = x[index, :, by1:by2, bx1:bx2]
    # The box may be clipped at the border, so derive lam from the actual pasted area.
    box_area = (bx2 - bx1) * (by2 - by1)
    lam = 1.0 - box_area / float(x.size(-1) * x.size(-2))
    return mixed, y, y[index], lam

def mixup_data(x, y, alpha=1.0):
    """MixUp: blend every sample with a random partner from the same batch.

    Returns ``(mixed_x, y, y_partner, lam)`` where
    ``mixed_x = lam * x + (1 - lam) * x[perm]`` and ``lam ~ Beta(alpha, alpha)``.
    A non-positive `alpha` disables mixing and returns the inputs with ``lam = 1.0``.
    """
    if alpha <= 0:
        return x, y, y, 1.0
    weight = np.random.beta(alpha, alpha)
    partner = torch.randperm(x.size(0)).to(x.device)
    shuffled = x[partner, :]
    blended = weight * x + (1. - weight) * shuffled
    return blended, y, y[partner], weight

def mix_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for a mixed batch: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1. - lam) * loss_second

In [None]:
# TRAIN FUNCTION 

def _unwrap_ema(ema_model):
    """Return the averaged nn.Module inside an EMA wrapper (timm ModelEmaV2 keeps it in `.module`)."""
    if ema_model is None:
        return None
    return ema_model.module if hasattr(ema_model, "module") else ema_model


def _step_scheduler(scheduler, epoch_progress):
    """Advance `scheduler`, preferring the fractional-epoch form ``step(epoch + i / iters)``.

    That form is the documented per-batch usage of CosineAnnealingWarmRestarts;
    schedulers that reject an explicit epoch fall back to a plain ``step()``.
    A failure of both is reported as a warning (deduplicated by message) rather
    than being silently ignored.
    """
    import warnings

    try:
        scheduler.step(epoch_progress)
        return
    except Exception:
        pass
    try:
        scheduler.step()
    except Exception as exc:
        warnings.warn(f"LR scheduler step failed ({type(exc).__name__}); learning rate not updated")


def _evaluate(model, loader, criterion, use_amp):
    """Return ``(mean_loss, accuracy)`` of `model` over `loader`; source flags are ignored."""
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
        for images, labels, _ in loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * images.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += images.size(0)
    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen


def train_for_epochs_mix(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs, fold, model_save_dir, ema_model=None):
    """Train with MixUp/CutMix, AMP and gradient accumulation; keep the best checkpoint.

    Each training batch is MixUp with p=0.25, CutMix with p=0.25, otherwise plain.
    Batches containing pseudo-labelled samples (source flag 1) have their loss scaled
    by ``1 - 0.5 * pseudo_fraction``. Optimizer, EMA and scheduler step once every
    ACCUMULATION_STEPS batches (and on the last batch of an epoch).

    After every epoch the model is validated. When `ema_model` is given, the EMA
    weights are validated too, and whichever scores higher is the checkpoint
    candidate. The saved weights are therefore always the weights that achieved
    the recorded score. The best candidate is written to
    ``<model_save_dir>/best_model_fold_<fold>.pth``. The first epoch always writes
    one, so a fold is never left without a checkpoint.

    Returns:
        Best validation accuracy among saved checkpoints (0.0 if `epochs` is 0).
    """
    import warnings

    use_amp = str(DEVICE).startswith("cuda")  # AMP/GradScaler are CUDA-only here
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    accumulation = max(1, ACCUMULATION_STEPS)
    n_batches = len(train_loader)
    save_path = os.path.join(model_save_dir, f"best_model_fold_{fold}.pth")
    best_acc = -1.0  # below any real accuracy, so the first epoch always saves

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        running_loss, seen = 0.0, 0
        # Accuracy is only meaningful on un-mixed batches, so count those separately.
        clean_correct, clean_seen = 0, 0

        pbar = tqdm(train_loader, desc=f"Fold {fold+1} Epoch {epoch+1}/{epochs} [Train]")
        for batch_idx, (images, labels, src_flags) in enumerate(pbar):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            src_flags = src_flags.to(DEVICE)

            r = random.random()
            is_mixed = r < 0.5
            with torch.cuda.amp.autocast(enabled=use_amp):
                if r < 0.25:
                    inputs, y_a, y_b, lam = mixup_data(images, labels, alpha=1.0)
                elif is_mixed:
                    inputs, y_a, y_b, lam = cutmix_data(images, labels, alpha=1.0)
                else:
                    inputs = images
                outputs = model(inputs)
                if is_mixed:
                    loss = mix_criterion(criterion, outputs, y_a, y_b, lam)
                else:
                    loss = criterion(outputs, labels)

                # Batch-level pseudo-label down-weighting. The factor is exactly 1.0 for
                # all-real batches; computing it on-device avoids a host sync per batch.
                pseudo_ratio = (src_flags == 1).float().mean()
                loss = loss * (1.0 - 0.5 * pseudo_ratio)
                loss = loss / accumulation

            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation == 0 or (batch_idx + 1) == n_batches:
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradient magnitudes
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                if ema_model is not None:
                    try:
                        ema_model.update(model)
                    except Exception as exc:
                        warnings.warn(f"EMA update failed ({type(exc).__name__}); EMA weights not updated")

                _step_scheduler(scheduler, epoch + batch_idx / float(n_batches))

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size * accumulation
            seen += batch_size
            if not is_mixed:
                clean_correct += (outputs.detach().argmax(dim=1) == labels).sum().item()
                clean_seen += batch_size
            clean_acc = clean_correct / clean_seen if clean_seen else 0.0
            pbar.set_postfix_str(f"loss={running_loss/seen:.4f} acc={clean_acc:.4f}")

        train_loss = running_loss / seen if seen else 0.0
        train_acc = clean_correct / clean_seen if clean_seen else 0.0

        # Validation: raw model, plus EMA weights when present.
        val_loss, val_acc = _evaluate(model, val_loader, criterion, use_amp)
        summary = (f"[Fold {fold+1}][Epoch {epoch+1}/{epochs}] train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
                   f"| val_loss={val_loss:.4f} val_acc={val_acc:.4f}")
        candidate, candidate_acc, candidate_tag = model, val_acc, "model"
        ema_module = _unwrap_ema(ema_model)
        if ema_module is not None:
            ema_loss, ema_acc = _evaluate(ema_module, val_loader, criterion, use_amp)
            summary += f" | ema_val_loss={ema_loss:.4f} ema_val_acc={ema_acc:.4f}"
            if ema_acc >= val_acc:
                candidate, candidate_acc, candidate_tag = ema_module, ema_acc, "EMA"
        print(summary)

        # Save best checkpoint (the weights that produced candidate_acc).
        if candidate_acc > best_acc:
            best_acc = candidate_acc
            try:
                torch.save(candidate.state_dict(), save_path)
                print(f"Saved {candidate_tag} weights to", save_path)
            except Exception as e:
                if candidate is model:
                    raise
                print("EMA save failed:", e, "— saving model.state_dict() instead")
                torch.save(model.state_dict(), save_path)

    best_acc = max(best_acc, 0.0)
    print(f"Fold {fold+1} finished. Best val_acc={best_acc:.4f}")
    return best_acc

In [None]:
# Helper: K-Fold run (teacher/student reuse)

from sklearn.model_selection import StratifiedKFold

def _set_grad_checkpointing(model, enable):
    """Toggle timm gradient checkpointing when the model supports it (best effort, reported)."""
    if model is not None and hasattr(model, "set_grad_checkpointing"):
        try:
            model.set_grad_checkpointing(enable)
        except Exception as exc:
            print(" - could not set grad checkpointing:", exc)


def run_kfold_train(data_paths, data_labels, model_save_dir, n_folds, epochs, is_teacher=True):
    """Train one model per stratified fold with progressive unfreezing.

    Each fold runs up to three stages, each via `train_for_epochs_mix`:
      1. classifier head only       -- epochs // 2 epochs, lr 3e-4
      2. top ~30% of parameters     -- epochs // 4 epochs, lr 5e-5
      3. full fine-tune             -- remaining epochs,   lr 1e-5
    Stages with 0 epochs are skipped. Each stage writes its best checkpoint to
    ``<model_save_dir>/_stages``. A stage's checkpoint replaces the fold's final
    ``<model_save_dir>/best_model_fold_<fold>.pth`` only when it scores higher than
    the earlier stages, so a later stage cannot overwrite a better earlier one.

    Args:
        data_paths, data_labels: aligned lists of image paths and class indices.
        model_save_dir: directory receiving ``best_model_fold_<fold>.pth``.
        n_folds: number of StratifiedKFold splits.
        epochs: total epochs per fold, split across the stages above.
        is_teacher: when True (and timm's ModelEmaV2 is importable), track an EMA copy.

    Returns:
        List with the best validation accuracy of each fold.
    """
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=SEED)
    stage_epochs = (epochs // 2, epochs // 4, epochs - epochs // 2 - epochs // 4)
    stage_dir = os.path.join(model_save_dir, "_stages")
    os.makedirs(stage_dir, exist_ok=True)
    fold_scores = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(data_paths, data_labels)):
        print(f"\n=== Fold {fold+1}/{n_folds} ===")
        train_paths = [data_paths[i] for i in train_idx]
        train_labels = [data_labels[i] for i in train_idx]
        val_paths = [data_paths[i] for i in val_idx]
        val_labels = [data_labels[i] for i in val_idx]

        train_ds = SimpleImageDataset(train_paths, train_labels, transform=train_tfms, source_flags=[0]*len(train_paths))
        val_ds = SimpleImageDataset(val_paths, val_labels, transform=val_test_tfms, source_flags=[0]*len(val_paths))

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

        model = get_model(BACKBONE, n_classes, pretrained=True).to(DEVICE)
        all_params = list(model.parameters())
        n_total = len(all_params)
        optimizer_steps_per_epoch = max(1, math.ceil(len(train_loader) / max(1, ACCUMULATION_STEPS)))

        # EMA for teacher only. A fixed decay of 0.9999 averages over ~10k optimizer steps;
        # with far fewer steps the EMA stays close to the initial (random-head) weights.
        # Tie the averaging horizon to roughly one epoch of optimizer steps instead.
        ema_model = None
        if is_teacher and ModelEmaV2 is not None:
            ema_decay = min(0.9999, max(0.9, 1.0 - 1.0 / optimizer_steps_per_epoch))
            try:
                ema_model = ModelEmaV2(model, decay=ema_decay)
                # The EMA copy is never trained, so checkpointing only adds overhead there.
                _set_grad_checkpointing(getattr(ema_model, "module", None), False)
                print(f"ModelEmaV2 enabled for teacher (decay={ema_decay:.5f}).")
            except Exception as exc:
                print("ModelEmaV2 unavailable:", exc)
                ema_model = None

        criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

        try:
            head_params = list(model.get_classifier().parameters())
        except Exception:
            head_params = []
        if not head_params:
            head_params = all_params[-20:]  # fallback: treat the last 20 tensors as the head
        head_ids = {id(p) for p in head_params}

        # (name, epochs, lr, cosine T_0, trainable mask over all_params, grad checkpointing).
        # Checkpointing is off while the backbone is (partly) frozen. timm's checkpointing may
        # use reentrant torch.utils.checkpoint, which yields no parameter gradients for a
        # segment whose inputs do not require grad. That would silently keep the "unfrozen"
        # blocks of stage 2 frozen. Frozen layers store no activations, so memory stays modest.
        stage_specs = [
            ("head", stage_epochs[0], 3e-4, max(1, stage_epochs[0]),
             [id(p) in head_ids for p in all_params], False),
            ("top ~30%", stage_epochs[1], 5e-5, 10,
             [i > n_total * 0.7 for i in range(n_total)], False),
            ("full", stage_epochs[2], 1e-5, 10, [True] * n_total, True),
        ]

        final_path = os.path.join(model_save_dir, f"best_model_fold_{fold}.pth")
        stage_path = os.path.join(stage_dir, f"best_model_fold_{fold}.pth")
        best_fold_acc = -1.0
        optimizer = scheduler = None

        for stage_no, (stage_name, n_ep, lr, t0, trainable, use_ckpt) in enumerate(stage_specs, 1):
            if n_ep <= 0:
                continue
            for p, flag in zip(all_params, trainable):
                p.requires_grad = flag
            _set_grad_checkpointing(model, use_ckpt)
            stage_params = [p for p in all_params if p.requires_grad]
            # CosineAnnealingWarmRestarts supports the fractional-epoch stepping that
            # train_for_epochs_mix performs; OneCycleLR would treat those values as step counts.
            optimizer = AdamW(stage_params, lr=lr, weight_decay=1e-2)
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=t0, T_mult=1, eta_min=1e-6)
            print(f"--- Stage {stage_no} ({stage_name}): {len(stage_params)} trainable tensors, {n_ep} epochs, lr={lr}")

            if os.path.exists(stage_path):
                os.remove(stage_path)  # never promote a stale file from an earlier stage
            stage_acc = train_for_epochs_mix(model, train_loader, val_loader, criterion, optimizer, scheduler,
                                             n_ep, fold, stage_dir, ema_model=ema_model)
            if os.path.exists(stage_path):
                if stage_acc > best_fold_acc:
                    best_fold_acc = stage_acc
                    os.replace(stage_path, final_path)
                else:
                    os.remove(stage_path)

        fold_scores.append(max(best_fold_acc, 0.0))

        # Drop optimizer/scheduler too: they reference the model parameters (and Adam state),
        # which would otherwise stay on the GPU while the next fold's model is built.
        del model, all_params, head_params, optimizer, scheduler
        del train_loader, val_loader, train_ds, val_ds, ema_model
        gc.collect()
        torch.cuda.empty_cache()

    return fold_scores

In [None]:
# Pseudo-label generation

def load_teacher_models(model_dir, n_folds):
    """Load ``best_model_fold_<f>.pth`` for ``f in range(n_folds)`` from `model_dir` in eval mode.

    Missing fold checkpoints are reported and skipped. Gradient checkpointing is
    turned off because the models are only used for inference.

    Returns:
        List of loaded models on DEVICE.

    Raises:
        FileNotFoundError: if no fold checkpoint exists at all. Downstream TTA
            would otherwise fail with an opaque error on an empty ensemble.
    """
    models = []
    for f in range(n_folds):
        p = os.path.join(model_dir, f"best_model_fold_{f}.pth")
        if not os.path.exists(p):
            print("Missing checkpoint for fold", f, "->", p)
            continue
        m = get_model(BACKBONE, n_classes, pretrained=False).to(DEVICE)
        m.load_state_dict(torch.load(p, map_location=DEVICE))
        if hasattr(m, "set_grad_checkpointing"):
            m.set_grad_checkpointing(False)  # inference only
        m.eval()
        models.append(m)
        print("Loaded teacher fold", f)
    if not models:
        raise FileNotFoundError(f"No best_model_fold_*.pth checkpoints found in {model_dir!r}")
    return models

def tta_predict(models_list, img_path):
    """Ensemble + test-time-augmentation class probabilities for one image.

    Every view in `tta_tfms` is built once and run through each model as a single
    batch. Softmax probabilities are averaged over views per model, then over models.

    Args:
        models_list: non-empty list of classifiers already on DEVICE.
        img_path: path of the image file.

    Returns:
        1-D numpy array of class probabilities, length = number of classes.

    Raises:
        ValueError: if `models_list` is empty.
    """
    if not models_list:
        raise ValueError("tta_predict needs at least one model")
    with Image.open(img_path) as im:  # close the file handle promptly
        img = im.convert("RGB")
    # Stack all TTA views into one batch: one forward pass per model instead of one per view.
    views = torch.stack([tfm(img) for tfm in tta_tfms]).to(DEVICE)
    per_model = []
    with torch.no_grad():
        for m in models_list:
            m.eval()
            probs = F.softmax(m(views).float(), dim=1)  # (n_views, n_classes)
            per_model.append(probs.mean(dim=0))      # avg TTA
    return torch.stack(per_model).mean(dim=0).cpu().numpy()

def generate_pseudo_labels(teacher_models, test_imgs, csv_path, threshold=CONFIDENCE_THRESHOLD):
    """Predict every test image with the teacher ensemble and save the predictions to CSV.

    Each row holds ``ID`` (file name), ``path``, ``predicted_index`` (argmax class)
    and ``confidence`` (max probability). Returns ``(all_rows_df, strong_df)``,
    where `strong_df` keeps only rows with ``confidence >= threshold``.
    """
    def _row_for(img_path):
        probs = tta_predict(teacher_models, img_path)
        return {
            "ID": os.path.basename(img_path),
            "path": img_path,
            "predicted_index": int(np.argmax(probs)),
            "confidence": float(np.max(probs)),
        }

    df = pd.DataFrame([_row_for(img_path) for img_path in tqdm(test_imgs, desc="Pseudo labeling")])
    df.to_csv(csv_path, index=False)
    print("Pseudo saved:", csv_path)

    is_strong = df['confidence'] >= threshold
    df_strong = df.loc[is_strong].copy()
    print("Pseudo strong count:", len(df_strong))
    return df, df_strong

In [None]:
# Student training

def train_student_with_pseudo(original_paths, original_labels, df_pseudo_strong, student_model_dir, n_folds, epochs):
    """Train one student per fold on real data plus all confident pseudo labels.

    Folds are split over the REAL samples only. Every fold trains on its real
    training split plus all pseudo-labelled images (source flag 1), and validates
    on its real held-out split. Checkpoint selection is therefore scored against
    ground truth, not against the teachers' own pseudo labels.

    Args:
        original_paths, original_labels: aligned lists of real image paths / class indices.
        df_pseudo_strong: DataFrame with ``path`` and ``predicted_index`` columns.
        student_model_dir: directory receiving ``best_model_fold_<fold>.pth``.
        n_folds: number of StratifiedKFold splits of the real data.
        epochs: training epochs per fold.

    Returns:
        List with the best validation accuracy of each fold.
    """
    pseudo_paths = df_pseudo_strong['path'].tolist()
    pseudo_labels = df_pseudo_strong['predicted_index'].astype(int).tolist()
    print("Student dataset size:", len(original_paths) + len(pseudo_paths), "pseudo:", len(pseudo_paths))

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=SEED)
    fold_scores = []
    for fold, (train_idx, val_idx) in enumerate(skf.split(original_paths, original_labels)):
        print(f"\n=== Student Fold {fold+1}/{n_folds} ===")
        train_paths = [original_paths[i] for i in train_idx] + pseudo_paths
        train_labels = [original_labels[i] for i in train_idx] + pseudo_labels
        train_flags = [0] * len(train_idx) + [1] * len(pseudo_paths)

        val_paths = [original_paths[i] for i in val_idx]
        val_labels = [original_labels[i] for i in val_idx]

        train_ds = SimpleImageDataset(train_paths, train_labels, transform=train_tfms, source_flags=train_flags)
        val_ds = SimpleImageDataset(val_paths, val_labels, transform=val_test_tfms, source_flags=[0] * len(val_paths))

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=2, pin_memory=True)

        model = get_model(BACKBONE, n_classes, pretrained=True).to(DEVICE)
        criterion = LabelSmoothingCrossEntropy(smoothing=0.05)

        # Classify by TOP-LEVEL module name. A substring test like `'fc' in name` also
        # matches backbone layers such as ConvNeXt's `stages.*.blocks.*.mlp.fc1`.
        head_prefixes = ("head", "classifier", "fc")
        backbone_params, head_params = [], []
        for name, p in model.named_parameters():
            if name.split(".", 1)[0] in head_prefixes:
                head_params.append(p)
            else:
                backbone_params.append(p)

        if head_params:
            param_groups = [
                {'params': backbone_params, 'lr': 5e-6},
                {'params': head_params, 'lr': 5e-5},
            ]
        else:
            param_groups = [{'params': list(model.parameters()), 'lr': 5e-5}]

        optimizer = AdamW(param_groups, weight_decay=1e-2)
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6)

        fold_scores.append(train_for_epochs_mix(model, train_loader, val_loader, criterion, optimizer, scheduler,
                                                epochs, fold, student_model_dir, ema_model=None))

        # The optimizer references the parameters (and Adam state); drop it with the model
        # so GPU memory is actually released before the next fold.
        del model, optimizer, scheduler, backbone_params, head_params, param_groups
        del train_loader, val_loader, train_ds, val_ds
        gc.collect()
        torch.cuda.empty_cache()

    return fold_scores

In [None]:
# MAIN: execute pipeline

# 1) prepare original dataset info
full = datasets.ImageFolder(TRAIN_DIR)
class_names = full.classes
n_classes = len(class_names)
print("Classes:", class_names)

original_paths = [p for p, _ in full.samples]
original_labels = [int(l) for _, l in full.samples]
# Regular files only: a stray sub-directory in TEST_DIR would crash Image.open later.
test_imgs = sorted(p for p in glob(os.path.join(TEST_DIR, "*")) if os.path.isfile(p))
assert original_paths, f"no training images found under {TRAIN_DIR}"
assert test_imgs, f"no test images found under {TEST_DIR}"

# 2) Train Teacher (K-Fold)
print("\n*** START TEACHER TRAINING ***")
run_kfold_train(original_paths, original_labels, MODEL_DIR_TEACHER, N_FOLDS_TEACHER, EPOCHS_TEACHER, is_teacher=True)

# 3) Load teacher models & generate pseudo labels
teacher_models = load_teacher_models(MODEL_DIR_TEACHER, N_FOLDS_TEACHER)
df_pseudo_all, df_pseudo_strong = generate_pseudo_labels(teacher_models, test_imgs, PSEUDO_CSV_PATH, threshold=CONFIDENCE_THRESHOLD)

# Teachers are no longer needed; release their GPU memory before student training.
del teacher_models
gc.collect()
torch.cuda.empty_cache()

min_strong = max(50, int(0.05 * len(test_imgs)))
if len(df_pseudo_strong) < min_strong:
    # Relax to the 75th confidence percentile (floor 0.75), but never go ABOVE
    # CONFIDENCE_THRESHOLD: that would shrink the pseudo set this branch is meant to grow.
    thr = min(CONFIDENCE_THRESHOLD, max(0.75, df_pseudo_all['confidence'].quantile(0.75)))
    print("Adaptive lower threshold to", thr)
    df_pseudo_strong = df_pseudo_all[df_pseudo_all['confidence'] >= thr].copy()
    print("New strong count:", len(df_pseudo_strong))

# 4) Student training
print("\n*** START STUDENT TRAINING (real + pseudo) ***")
train_student_with_pseudo(original_paths, original_labels, df_pseudo_strong, MODEL_DIR_STUDENT, N_FOLDS_STUDENT, EPOCHS_STUDENT)

# 5) Final inference ensemble (Student models)
print("\n*** FINAL INFERENCE WITH STUDENT MODELS ***")
student_models = []
for f in range(N_FOLDS_STUDENT):
    p = os.path.join(MODEL_DIR_STUDENT, f"best_model_fold_{f}.pth")
    if not os.path.exists(p):
        print("Missing student checkpoint:", p)
        continue
    m = get_model(BACKBONE, n_classes, pretrained=False).to(DEVICE)
    m.load_state_dict(torch.load(p, map_location=DEVICE))
    if hasattr(m, "set_grad_checkpointing"):
        m.set_grad_checkpointing(False)  # inference only
    m.eval()
    student_models.append(m)
print("Loaded student models:", len(student_models))
if not student_models:
    raise RuntimeError(f"No student checkpoints found in {MODEL_DIR_STUDENT!r}; cannot build submission")

# inference
rows = []
for p in tqdm(test_imgs, desc="Final Inference"):
    probs = tta_predict(student_models, p)
    idx = int(np.argmax(probs))
    rows.append({"ID": os.path.basename(p), "label": class_names[idx]})

pd.DataFrame(rows).to_csv(OUT_FILE, index=False)
print("Submission saved to", OUT_FILE)