In [None]:
import os, gc, random, math
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

import timm
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data import Mixup
try:
    from timm.utils import ModelEmaV2
except:
    ModelEmaV2 = None

In [None]:
# CONFIG
# Every tunable of the notebook lives in this cell.

# --- Paths -------------------------------------------------------------
TRAIN_DIR = "/kaggle/input/action-dm-dataset/action-dm-dataset_2/train"
TEST_DIR  = "/kaggle/input/action-dm-dataset/action-dm-dataset_2/test/test"
MODEL_SAVE_PATH = "swin_large_mixcut_ema_best.pth"
OUT_FILE = "submission_swin_large_tta_mixcut.csv"

# --- Runtime -----------------------------------------------------------
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
torch.backends.cudnn.benchmark = True

# --- Reproducibility ---------------------------------------------------
# Seed each RNG family used by the notebook (python, numpy, torch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Data loading ------------------------------------------------------
IMG_SIZE = 384
BATCH_SIZE = 8
NUM_WORKERS = 4

# --- Model / optimisation ----------------------------------------------
BACKBONE = "swinv2_large_window12to24_192to384_22kft1k"
EPOCHS = 20
WARMUP_EPOCHS = 2
BASE_LR = 5e-5

# --- Feature switches --------------------------------------------------
USE_EMA = True
USE_MIXUP = True
USE_TTA = True

print("DEVICE:", DEVICE)
print("BACKBONE:", BACKBONE)

In [None]:
# TRANSFORMS

# Per-channel normalisation statistics shared by every pipeline below.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Evaluation images are resized to 1.14x the crop size before the centre crop.
_EVAL_RESIZE = int(IMG_SIZE * 1.14)


def _eval_pipeline(hflip=False):
    """Build the deterministic resize -> centre-crop -> tensor -> normalise pipeline.

    With ``hflip=True`` an always-applied horizontal flip is inserted after the
    crop, which yields the mirrored test-time-augmentation view.
    """
    steps = [
        transforms.Resize(_EVAL_RESIZE),
        transforms.CenterCrop(IMG_SIZE),
    ]
    if hflip:
        steps.append(transforms.RandomHorizontalFlip(p=1.0))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(_NORM_MEAN, _NORM_STD))
    return transforms.Compose(steps)


# Training augmentation: geometric and colour jitter on the PIL image, then
# tensor conversion, normalisation and random erasing on the tensor.
_train_steps = [
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.85, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomRotation(15),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.25),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.15)),
]
train_tfms = transforms.Compose(_train_steps)

# Validation uses the plain (unflipped) evaluation pipeline.
val_tfms = _eval_pipeline()

# Test-time augmentation views: the validation view plus its mirror image.
tta_tfms = [val_tfms, _eval_pipeline(hflip=True)]

In [None]:
# DATASET

class SimpleImageDataset(Dataset):
    """Dataset that loads RGB images from a list of file paths.

    Args:
        paths: Sequence of image file paths.
        labels: Optional sequence of labels aligned with ``paths``. When
            omitted, each item carries the file's base name instead of a label.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``labels`` is given but its length differs from ``paths``.
    """

    def __init__(self, paths, labels=None, transform=None):
        # A length mismatch would otherwise only surface as an IndexError
        # somewhere in the middle of an epoch.
        if labels is not None and len(labels) != len(paths):
            raise ValueError(
                f"paths and labels must have the same length "
                f"(got {len(paths)} and {len(labels)})"
            )
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` when labels exist, else ``(image, file_name)``."""
        path = self.paths[idx]
        # Image.open keeps the underlying file open; convert() produces an
        # independent RGB copy, so the handle is released as soon as the
        # ``with`` block exits instead of lingering until garbage collection.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        # Compare against None rather than truthiness: some container-style
        # transforms define __len__ and can be falsy while still being callable.
        if self.transform is not None:
            img = self.transform(img)
        if self.labels is not None:
            return img, int(self.labels[idx])
        return img, os.path.basename(path)

In [None]:
# MODEL

def get_model(backbone_name, n_classes, pretrained=True):
    """Create a timm model and switch on gradient checkpointing when possible.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        n_classes: Size of the classification head (``num_classes``).
        pretrained: Forwarded to ``timm.create_model``.

    Returns:
        The created model. Gradient checkpointing is enabled when the model
        exposes ``set_grad_checkpointing`` and the call succeeds.
    """
    print("Loading model:", backbone_name)
    model = timm.create_model(backbone_name, pretrained=pretrained, num_classes=n_classes)
    if hasattr(model, "set_grad_checkpointing"):
        try:
            model.set_grad_checkpointing(True)
            print(" - grad checkpointing ON")
        except Exception as e:
            # Checkpointing is optional, so keep going without it. Report the
            # failure instead of hiding it, since it changes memory use.
            print(" - grad checkpointing unavailable:", e)
    return model

In [None]:
# LR schedule

def make_scheduler(optimizer, total_epochs, steps_per_epoch, base_lr, warmup_epochs):
    """Build a per-step LR schedule: linear warmup followed by cosine decay.

    The returned scheduler multiplies the optimizer's configured learning rate
    by a factor that rises linearly from 0 to 1 over the warmup steps and then
    follows a half cosine from 1 down to 0 at ``total_epochs * steps_per_epoch``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        total_epochs: Number of training epochs.
        steps_per_epoch: Number of scheduler steps per epoch.
        base_lr: Unused; kept for interface compatibility (the base rate is
            taken from the optimizer itself).
        warmup_epochs: Number of epochs spent in linear warmup.

    Returns:
        A ``LambdaLR`` meant to be stepped once per optimizer step.
    """
    total_steps = total_epochs * steps_per_epoch
    warmup_steps = max(1, warmup_epochs * steps_per_epoch)
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(warmup_steps)
        # Clamp so that stepping past total_steps keeps the factor at 0
        # instead of letting the cosine swing back up towards 1.
        progress = min(1.0, float(step - warmup_steps) / float(decay_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)

In [None]:
# Validation

def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` and return ``(accuracy, mean_loss)``.

    The model is put in eval mode, and batches are run without gradients under
    CUDA autocast. Both values are averaged per sample; an empty loader yields
    ``(0.0, 0.0)``.
    """
    model.eval()
    n_seen = 0
    n_hit = 0
    loss_sum = 0.0
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            for batch_imgs, batch_labels in loader:
                batch_imgs = batch_imgs.to(DEVICE)
                batch_labels = batch_labels.to(DEVICE)
                logits = model(batch_imgs)
                batch_loss = criterion(logits, batch_labels)
                # Weight the batch loss by batch size so the final mean is per sample.
                loss_sum += batch_loss.item() * batch_imgs.size(0)
                n_hit += (logits.argmax(dim=1) == batch_labels).sum().item()
                n_seen += batch_labels.size(0)
    if n_seen == 0:
        return 0.0, 0.0
    return n_hit / n_seen, loss_sum / n_seen

In [None]:
# TRAIN LOOP

def train_model(model, train_loader, val_loader, epochs, save_path, mixup_fn=None, use_ema=USE_EMA):
    """Train ``model`` and save the weights with the best validation accuracy.

    Each step runs mixed-precision forward/backward (on CUDA only), clips the
    gradient norm to 1.0, steps AdamW and the warmup+cosine scheduler, and
    optionally updates an EMA copy of the model. After every epoch the EMA
    copy (when available, otherwise the raw model) is validated, and its
    ``state_dict`` is written to ``save_path`` whenever accuracy improves.

    Args:
        model: Network to train. It is not moved by this function, so it must
            already be on ``DEVICE``.
        train_loader: Iterable of ``(images, integer_labels)`` training batches.
        val_loader: Iterable of validation batches passed to ``validate``.
        epochs: Number of passes over ``train_loader``.
        save_path: Destination file for the best ``state_dict``.
        mixup_fn: Optional callable ``(images, labels) -> (images, soft_labels)``.
            When given, the soft-target loss is used.
        use_ema: Whether to maintain a ``ModelEmaV2`` copy (ignored when
            ``ModelEmaV2`` could not be imported).

    Returns:
        float: Best validation accuracy observed (0.0 if none exceeded it).
    """
    soft_loss_fn = SoftTargetCrossEntropy()
    hard_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1)

    optimizer = AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01)
    scheduler = make_scheduler(optimizer, total_epochs=epochs, steps_per_epoch=len(train_loader),
                               base_lr=BASE_LR, warmup_epochs=WARMUP_EPOCHS)

    # AMP only applies on CUDA. On CPU the scaler and autocast are explicitly
    # disabled and become pass-throughs.
    use_amp = DEVICE.startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    ema_model = None
    if use_ema and ModelEmaV2 is not None:
        try:
            ema_model = ModelEmaV2(model, decay=0.9999)
            print("EMA enabled")
        except Exception as e:
            print("EMA init failed:", e)
            ema_model = None

    best_val_acc = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss, running_corrects, running_total = 0.0, 0, 0
        pbar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")

        for imgs, labels in pbar:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad(set_to_none=True)

            if mixup_fn is not None:
                imgs, labels = mixup_fn(imgs, labels)

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(imgs)
                if mixup_fn is not None:
                    loss = soft_loss_fn(outputs, labels)
                    # With mixed targets the running accuracy is only a rough
                    # signal: count a hit when the prediction matches the
                    # class with the largest target weight.
                    hard_targets = labels.argmax(dim=1) if labels.dim() > 1 else None
                else:
                    loss = hard_loss_fn(outputs, labels)
                    hard_targets = labels

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            if ema_model is not None:
                try:
                    ema_model.update(model)
                except Exception as e:
                    # A failed update would leave stale EMA weights that are
                    # later validated and saved, so fall back to the raw model.
                    print("EMA update failed; disabling EMA:", e)
                    ema_model = None

            with torch.no_grad():
                preds = torch.argmax(outputs.detach(), dim=1)
                if hard_targets is not None:
                    running_corrects += (preds == hard_targets).sum().item()
                running_total += imgs.size(0)
                running_loss += loss.item() * imgs.size(0)

            pbar.set_postfix(loss=(running_loss / running_total), acc=(running_corrects / running_total if running_total>0 else 0.0))

        # Validate (and later save) the EMA weights when they are available.
        if ema_model is not None and hasattr(ema_model, "module"):
            model_for_eval = ema_model.module
        else:
            model_for_eval = model

        val_acc, val_loss = validate(model_for_eval, val_loader, hard_loss_fn)
        print(f"[Epoch {epoch+1}] val_acc={val_acc:.4f} val_loss={val_loss:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            try:
                # Save exactly the weights that produced val_acc.
                torch.save(model_for_eval.state_dict(), save_path)
                print(f"Saved best model -> {save_path} (val_acc={best_val_acc:.4f})")
            except Exception as e:
                print("Save failed:", e)

    print("Training finished. Best val acc:", best_val_acc)
    return best_val_acc

In [None]:
# MAIN

# ImageFolder is used only to enumerate (path, class_index) samples and the
# class names; images are loaded later by SimpleImageDataset.
full = datasets.ImageFolder(TRAIN_DIR)
n_classes = len(full.classes)
paths = [p for p, _ in full.samples]
labels = [int(l) for _, l in full.samples]

# 90/10 random train/validation split. A dedicated generator seeded with SEED
# makes the split independent of how much of the global `random` state was
# consumed earlier, so re-running this cell reproduces the same split.
indices = list(range(len(paths)))
random.Random(SEED).shuffle(indices)
split = int(0.9 * len(indices))
train_idx = indices[:split]
val_idx = indices[split:]

train_paths = [paths[i] for i in train_idx]
train_labels = [labels[i] for i in train_idx]
val_paths = [paths[i] for i in val_idx]
val_labels = [labels[i] for i in val_idx]

train_ds = SimpleImageDataset(train_paths, train_labels, train_tfms)
val_ds = SimpleImageDataset(val_paths, val_labels, val_tfms)
# drop_last=True keeps every training batch at exactly BATCH_SIZE samples.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

model = get_model(BACKBONE, n_classes, pretrained=True).to(DEVICE)

# Mixup/Cutmix
mixup_fn = None
if USE_MIXUP:
    mixup_fn = Mixup(
        mixup_alpha=0.4, cutmix_alpha=1.0,
        prob=1.0, switch_prob=0.5,
        label_smoothing=0.1, num_classes=n_classes
    )
    print("Mixup/CutMix enabled")

train_model(model, train_loader, val_loader, EPOCHS, MODEL_SAVE_PATH, mixup_fn=mixup_fn)

In [None]:
# INFERENCE + TTA

# Restore the best checkpoint written during training, if there is one.
if os.path.exists(MODEL_SAVE_PATH):
    # NOTE: torch.load unpickles the file; only load checkpoints produced by
    # this notebook or another trusted source.
    sd = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    try:
        model.load_state_dict(sd)
    except Exception as e:
        # Keep the best-effort fallback, but report what did not match so a
        # partially loaded model is never used silently.
        print("Strict checkpoint load failed, retrying with strict=False:", e)
        load_result = model.load_state_dict(sd, strict=False)
        print(" - missing keys:", list(load_result.missing_keys))
        print(" - unexpected keys:", list(load_result.unexpected_keys))
else:
    print(f"Checkpoint {MODEL_SAVE_PATH} not found; predicting with the in-memory weights")
model.eval()

# Only regular files: a stray sub-directory in TEST_DIR would crash Image.open.
test_imgs = sorted(p for p in glob(os.path.join(TEST_DIR, "*")) if os.path.isfile(p))
rows = []

model_for_infer = model

# Autocast is only meaningful on CUDA; it is disabled explicitly on CPU.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=DEVICE.startswith("cuda")):
    for p in tqdm(test_imgs, desc="Inferencing with TTA"):
        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(p) as raw_img:
            img = raw_img.convert("RGB")
        preds = []
        if USE_TTA:
            # Average the class probabilities over all TTA views.
            for tta in tta_tfms:
                x = tta(img).unsqueeze(0).to(DEVICE)
                out = F.softmax(model_for_infer(x), dim=1)
                preds.append(out)
            avg_pred = torch.stack(preds, dim=0).mean(0)
        else:
            x = val_tfms(img).unsqueeze(0).to(DEVICE)
            avg_pred = F.softmax(model_for_infer(x), dim=1)
        idx = torch.argmax(avg_pred, dim=1).item()
        # Map the predicted index back to the class folder name.
        label = full.classes[idx]
        rows.append({"ID": os.path.basename(p), "label": label})

pd.DataFrame(rows).to_csv(OUT_FILE, index=False)
print("✅ Submission saved to", OUT_FILE)