In [None]:
import os, gc, random, math, time
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets

import timm
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.data import Mixup
try:
    from timm.utils import ModelEmaV2
except Exception:
    ModelEmaV2 = None

In [None]:
# CONFIG (tweak if needed)

# --- data & output locations ---
TRAIN_DIR = "/kaggle/input/action-dm-dataset/action-dm-dataset_2/train"
TEST_DIR = "/kaggle/input/action-dm-dataset/action-dm-dataset_2/test/test"
MODEL_SAVE_PATH = "efficientnetv2_l_best.pth"      # best checkpoint written during training
OUT_FILE = "submission_efficientnetv2_l_tta.csv"   # submission CSV written by the inference cell

# --- runtime & reproducibility ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.benchmark = True  # let cuDNN autotune conv kernels for the fixed input size

# --- training hyper-parameters ---
BACKBONE = "tf_efficientnetv2_l_in21ft1k"
IMG_SIZE = 480
BATCH_SIZE = 16
NUM_WORKERS = 4
EPOCHS = 25
WARMUP_EPOCHS = 1
BASE_LR = 1e-4

# --- feature switches ---
USE_EMA = True    # validate/save an exponential moving average of the weights
USE_MIXUP = True  # Mixup/CutMix during training
USE_TTA = True    # average logits over several augmented views at inference

print("DEVICE:", DEVICE)
print("BACKBONE:", BACKBONE)

In [None]:
# TRANSFORMS & TTA
# Every pipeline shares the same normalisation statistics and eval resize so
# train / validation / TTA preprocessing cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
EVAL_RESIZE = int(IMG_SIZE * 1.14)  # shorter side is resized to this, then centre-cropped to IMG_SIZE


def _to_normalized_tensor():
    """Return the ToTensor + ImageNet-Normalize tail shared by every pipeline."""
    return [transforms.ToTensor(), transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)]


train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.25, 0.25, 0.25, 0.1),
    transforms.RandomRotation(10),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.18),
    *_to_normalized_tensor(),
    # RandomErasing operates on the normalised tensor, so it must come last.
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.12)),
])

val_tfms = transforms.Compose([
    transforms.Resize(EVAL_RESIZE),
    transforms.CenterCrop(IMG_SIZE),
    *_to_normalized_tensor(),
])

# Test-time augmentation views (logits are averaged over them at inference).
# Views 0-1 are deterministic; views 2-3 use random rotation / random resized
# crop, so repeated inference runs can give slightly different predictions.
tta_tfms = [
    val_tfms,
    transforms.Compose([
        transforms.Resize(EVAL_RESIZE),
        transforms.CenterCrop(IMG_SIZE),
        transforms.RandomHorizontalFlip(p=1.0),  # p=1.0 -> always flipped
        *_to_normalized_tensor(),
    ]),
    transforms.Compose([
        transforms.Resize(EVAL_RESIZE),
        transforms.RandomRotation(10),
        transforms.CenterCrop(IMG_SIZE),
        *_to_normalized_tensor(),
    ]),
    transforms.Compose([
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)),
        *_to_normalized_tensor(),
    ]),
]

In [None]:
# DATASET

class SimpleImageDataset(Dataset):
    """Map-style dataset over a list of image file paths.

    Each item is ``(image, label)`` when ``labels`` is given, otherwise
    ``(image, basename_of_path)`` so predictions can be matched back to files.
    ``image`` is the RGB-converted PIL image passed through ``transform`` (if any).
    """

    def __init__(self, paths, labels=None, transform=None):
        """
        Args:
            paths: sequence of image file paths.
            labels: optional sequence of integer class ids aligned with ``paths``.
            transform: optional callable applied to each RGB PIL image.

        Raises:
            ValueError: if ``labels`` is given and its length differs from ``paths``.
        """
        if labels is not None and len(labels) != len(paths):
            raise ValueError(f"got {len(paths)} paths but {len(labels)} labels")
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        # The context manager releases the file handle immediately instead of
        # leaving it open until the image object is garbage-collected.
        with Image.open(path) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        if self.labels is not None:
            return img, int(self.labels[idx])
        return img, os.path.basename(path)

In [None]:
# MODEL helper

def get_model(backbone_name, n_classes, pretrained=True, grad_checkpointing=True):
    """Create a timm model with an ``n_classes``-way classification head.

    Args:
        backbone_name: timm model name passed to ``timm.create_model``.
        n_classes: number of output classes (``num_classes`` for timm).
        pretrained: forwarded to ``timm.create_model``.
        grad_checkpointing: if True, try to enable gradient checkpointing
            (recompute activations in backward to save memory) when the model
            exposes ``set_grad_checkpointing``.

    Returns:
        The created model; it is not moved to any device here.
    """
    print("=> create model:", backbone_name)
    model = timm.create_model(backbone_name, pretrained=pretrained, num_classes=n_classes)
    if grad_checkpointing and hasattr(model, "set_grad_checkpointing"):
        try:
            model.set_grad_checkpointing(True)
            print(" - grad checkpointing enabled")
        except Exception as e:
            # Best effort: keep the model usable, but say why checkpointing is off.
            print(" - grad checkpointing not enabled:", e)
    return model

In [None]:
# LR: warmup -> cosine schedule (per-batch)

def make_scheduler(optimizer, total_epochs, steps_per_epoch, warmup_epochs):
    """Build a per-batch LambdaLR: linear warmup, then cosine decay to 0.

    Step the returned scheduler once per optimizer step.

    Args:
        optimizer: optimizer whose base LR(s) are scaled by the multiplier.
        total_epochs: total number of training epochs.
        steps_per_epoch: optimizer steps per epoch (e.g. ``len(train_loader)``).
        warmup_epochs: epochs of linear warmup; 0 means no warmup.

    Returns:
        ``LambdaLR`` whose multiplier rises linearly to 1.0 over the warmup
        steps, then follows ``0.5 * (1 + cos(pi * progress))`` with
        ``progress`` clamped to ``[0, 1]``.
    """
    total_steps = total_epochs * steps_per_epoch
    warmup_steps = max(1, warmup_epochs * steps_per_epoch)
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            # step + 1 so the very first optimizer step already uses a non-zero LR.
            return float(step + 1) / float(warmup_steps)
        # Clamp so stepping past total_steps keeps the LR at its floor instead
        # of climbing back up the cosine curve.
        progress = min(1.0, float(step - warmup_steps) / float(decay_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)

In [None]:
# VALIDATE

def validate(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` and return ``(accuracy, mean_loss)``.

    The model is switched to eval mode (and left there). Mixed precision is
    used only when running on CUDA.

    Args:
        model: classifier mapping an image batch to ``(N, C)`` logits.
        loader: iterable of ``(images, int_labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``; the reported
            loss is whatever it computes (e.g. label-smoothed).
        device: target device; defaults to the notebook-level ``DEVICE``.

    Returns:
        ``(accuracy, loss)`` averaged per sample; ``(0.0, 0.0)`` for an empty loader.
    """
    device = DEVICE if device is None else device
    device_type = torch.device(device).type
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad(), torch.autocast(device_type=device_type, enabled=(device_type == "cuda")):
        for imgs, labels in loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(imgs)
            batch_size = imgs.size(0)
            # criterion returns a batch mean; re-weight by batch size so the
            # final average is per-sample even if the last batch is smaller.
            loss_sum += criterion(logits, labels).item() * batch_size
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += batch_size
    if total == 0:
        return 0.0, 0.0
    return correct / total, loss_sum / total

In [None]:
# TRAIN LOOP

def train_model(model, train_loader, val_loader, epochs, save_path, mixup_fn=None, use_ema=USE_EMA, ema_decay=0.9999):
    """Fine-tune ``model`` and keep the checkpoint with the best validation accuracy.

    Uses AdamW with the per-batch warmup/cosine schedule from ``make_scheduler``,
    gradient clipping (max norm 1.0), mixed precision + GradScaler on CUDA,
    optional Mixup/CutMix (soft-target loss) and optional EMA weights
    (timm ``ModelEmaV2``). When EMA is active, the EMA weights are the ones
    validated and saved.

    Args:
        model: network to train, already on ``DEVICE``.
        train_loader: loader of ``(images, int_labels)`` training batches.
        val_loader: loader of ``(images, int_labels)`` validation batches.
        epochs: number of training epochs.
        save_path: file the best ``state_dict`` is written to with ``torch.save``.
        mixup_fn: optional timm ``Mixup``; when given, the soft-target loss is used.
        use_ema: maintain an EMA copy of the weights (requires ``ModelEmaV2``).
        ema_decay: EMA decay applied per optimizer step.

    Returns:
        The best validation accuracy seen (0.0 if no epoch exceeded it).
    """
    device = DEVICE
    use_amp = device.startswith("cuda")
    amp_device_type = "cuda" if use_amp else "cpu"

    soft_loss = SoftTargetCrossEntropy()                   # Mixup/CutMix soft targets
    hard_loss = LabelSmoothingCrossEntropy(smoothing=0.1)  # integer labels + validation

    optimizer = AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01)
    scheduler = make_scheduler(optimizer, total_epochs=epochs,
                               steps_per_epoch=len(train_loader),
                               warmup_epochs=WARMUP_EPOCHS)
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    ema_model = None
    if use_ema and ModelEmaV2 is not None:
        try:
            ema_model = ModelEmaV2(model, decay=ema_decay)
            print("=> EMA enabled")
        except Exception as e:
            print("=> EMA init failed:", e)
            ema_model = None

    best_val_acc = 0.0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        running_total = 0
        pbar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")

        for imgs, labels in pbar:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)

            if mixup_fn is not None:
                imgs, targets = mixup_fn(imgs, labels)
                criterion = soft_loss
                # Dominant class of each mixed target; only used for the
                # (approximate) running train accuracy.
                hard_targets = targets.argmax(dim=1) if targets.ndim == 2 else targets
            else:
                targets = labels
                criterion = hard_loss
                hard_targets = labels

            with torch.autocast(device_type=amp_device_type, enabled=use_amp):
                logits = model(imgs)
                loss = criterion(logits, targets)

            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            scheduler.step()  # per-batch schedule

            if ema_model is not None:
                try:
                    ema_model.update(model)
                except Exception as e:
                    # Never keep validating/saving a silently stale EMA copy.
                    print("=> EMA update failed; using the raw model from now on:", e)
                    ema_model = None

            batch_size = imgs.size(0)
            running_total += batch_size
            running_loss += loss.item() * batch_size
            running_correct += (logits.detach().argmax(dim=1) == hard_targets).sum().item()
            pbar.set_postfix(loss=running_loss / running_total,
                             acc=running_correct / running_total,
                             lr=optimizer.param_groups[0]["lr"])

        # Validate the EMA weights when available (getattr(None, ...) -> model).
        model_for_eval = getattr(ema_model, "module", model)

        val_acc, val_loss = validate(model_for_eval, val_loader, hard_loss)
        print(f"[Epoch {epoch+1}] val_acc={val_acc:.4f} val_loss={val_loss:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            try:
                # Save exactly the weights that were just validated, with plain
                # (un-prefixed) keys that load directly into a get_model(...) instance.
                torch.save(model_for_eval.state_dict(), save_path)
                print(f"=> Saved best model to {save_path} (val_acc={best_val_acc:.4f})")
            except Exception as e:
                print("Save failed:", e)

    print("Training complete. Best val acc:", best_val_acc)
    return best_val_acc

In [None]:
# MAIN: prepare data, model, mixup

full = datasets.ImageFolder(TRAIN_DIR)
n_classes = len(full.classes)
paths = [p for p, _ in full.samples]
labels = [int(l) for _, l in full.samples]

# Hold out 10% for validation. A dedicated RNG seeded with SEED makes the split
# identical on every re-run of this cell, independent of how much of the global
# `random` stream other cells have consumed.
indices = list(range(len(paths)))
random.Random(SEED).shuffle(indices)
split = int(0.9 * len(indices))
train_idx = indices[:split]
val_idx = indices[split:]
assert train_idx and val_idx, f"empty train/val split from {len(paths)} images - check TRAIN_DIR"

train_paths = [paths[i] for i in train_idx]; train_labels = [labels[i] for i in train_idx]
val_paths   = [paths[i] for i in val_idx];   val_labels   = [labels[i] for i in val_idx]

train_ds = SimpleImageDataset(train_paths, train_labels, train_tfms)
val_ds   = SimpleImageDataset(val_paths,   val_labels,   val_tfms)

# drop_last keeps every training batch full-sized (batch-mode Mixup expects that).
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE*2, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

model = get_model(BACKBONE, n_classes, pretrained=True).to(DEVICE)

mixup_fn = None
if USE_MIXUP:
    mixup_fn = Mixup(
        mixup_alpha=0.5,
        cutmix_alpha=1.0,
        prob=0.5,
        switch_prob=0.5,
        mode='batch',
        label_smoothing=0.05,
        num_classes=n_classes
    )
    print("Mixup/CutMix enabled (mixup_alpha=0.5, cutmix_alpha=1.0, prob=0.5)")

train_model(model, train_loader, val_loader, EPOCHS, MODEL_SAVE_PATH, mixup_fn=mixup_fn, use_ema=USE_EMA)

In [None]:
# INFERENCE + TTA
# Restore the best checkpoint written during training. Load problems are reported
# explicitly so a bad or missing checkpoint is never silently replaced by the
# in-memory (last-epoch) weights.
if os.path.exists(MODEL_SAVE_PATH):
    sd = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
    # Drop a wrapper "module." key prefix (e.g. from an EMA wrapper's state_dict) so keys match `model`.
    sd = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in sd.items()}
    try:
        model.load_state_dict(sd)
    except RuntimeError as strict_err:
        print("Warning: strict checkpoint load failed:", strict_err)
        try:
            incompatible = model.load_state_dict(sd, strict=False)
            print(f"Warning: loaded with strict=False "
                  f"({len(incompatible.missing_keys)} missing, "
                  f"{len(incompatible.unexpected_keys)} unexpected keys).")
        except RuntimeError as e:
            print("Warning: could not load checkpoint; predicting with in-memory weights:", e)
else:
    print(f"Warning: {MODEL_SAVE_PATH} not found; predicting with in-memory weights.")

use_amp = DEVICE.startswith("cuda")
tta_views = tta_tfms if USE_TTA else [val_tfms]

model.eval()
test_imgs = sorted(glob(os.path.join(TEST_DIR, "*")))
rows = []
with torch.no_grad(), torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
    for p in tqdm(test_imgs, desc="Inference with TTA"):
        with Image.open(p) as im:
            img = im.convert("RGB")
        # All views of one image go through the model as a single batch; this
        # assumes no batch-dependent layers in eval mode, so each view's logits
        # match a separate per-view forward pass.
        x = torch.stack([t(img) for t in tta_views]).to(DEVICE)
        avg_logits = model(x).float().mean(dim=0, keepdim=True)
        probs = F.softmax(avg_logits, dim=1)
        idx = torch.argmax(probs, dim=1).item()
        rows.append({"ID": os.path.basename(p), "label": full.classes[idx]})

# Explicit columns keep the CSV header even if TEST_DIR is empty.
pd.DataFrame(rows, columns=["ID", "label"]).to_csv(OUT_FILE, index=False)
print("Saved submission:", OUT_FILE)