In [1]:
import os, argparse, itertools
from pathlib import Path
from collections import Counter, defaultdict
import torch, timm
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms
from torch import amp
import numpy as np
from PIL import Image
from PIL import ImageFilter

class LetterboxSquare(object):
    """Fit a PIL image inside a ``size`` x ``size`` square, keeping aspect ratio.

    The image is scaled (bilinear) so that its longer side equals ``size``,
    then the shorter side is padded symmetrically with reflected pixels until
    the result is ``size`` x ``size``.

    Args:
        size: Side length, in pixels, of the square output.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img: Image.Image):
        """Return the letterboxed version of ``img``."""
        w, h = img.size
        s = min(self.size / w, self.size / h)

        # Clamp each resized side into [1, size]. For very elongated inputs
        # the shorter side can round down to 0, which makes ``resize`` raise;
        # the upper bound keeps the padding amounts below non-negative.
        nw = min(self.size, max(1, int(round(w * s))))
        nh = min(self.size, max(1, int(round(h * s))))

        img = img.resize((nw, nh), Image.BILINEAR)

        # Split the leftover space as evenly as possible; any odd pixel goes
        # to the right / bottom edge.
        pad_l = (self.size - nw) // 2
        pad_t = (self.size - nh) // 2
        pad_r = self.size - nw - pad_l
        pad_b = self.size - nh - pad_t

        return transforms.functional.pad(
            img,
            (pad_l, pad_t, pad_r, pad_b),
            padding_mode="reflect"
        )

class PreprocessUpsample(object):
    """Resize with bicubic interpolation, then sharpen with PIL's UnsharpMask.

    Args:
        out_size: Side length of the square output image.
        unsharp_radius: ``radius`` passed to ``ImageFilter.UnsharpMask``.
        unsharp_percent: ``percent`` passed to ``ImageFilter.UnsharpMask``.
        unsharp_threshold: ``threshold`` passed to ``ImageFilter.UnsharpMask``.
    """

    def __init__(self, out_size, unsharp_radius=1, unsharp_percent=150, unsharp_threshold=3):
        self.out_size = out_size
        self.radius, self.percent, self.threshold = (
            unsharp_radius,
            unsharp_percent,
            unsharp_threshold,
        )

    def __call__(self, img: Image.Image):
        """Return ``img`` resized to ``out_size`` x ``out_size`` and sharpened."""
        # The input is expected to be letterboxed already (when LetterboxSquare
        # is used earlier in the pipeline).
        target_shape = (self.out_size, self.out_size)
        upsampled = img.resize(target_shape, resample=Image.BICUBIC)

        # Unsharp mask: increases edge sharpness.
        sharpen = ImageFilter.UnsharpMask(
            radius=self.radius,
            percent=self.percent,
            threshold=self.threshold,
        )
        return upsampled.filter(sharpen)


def get_loaders(root="/kaggle/input/datav2/dataset_clsv2/visible", bs=120, size=224, workers=4, use_sampler=False, alpha0=0.05):
    """Build the train/val data loaders and the class-balancing loss weights.

    Images are read with ``ImageFolder`` from ``{root}/train`` and
    ``{root}/val``. Both splits are letterboxed, upsampled/sharpened,
    converted to tensors and normalized; the training split additionally
    gets colour jitter.

    Args:
        root: Directory containing the ``train/`` and ``val/`` sub-folders.
        bs: Batch size used by both loaders.
        size: Side length of the square network input.
        workers: Number of ``DataLoader`` worker processes.
        use_sampler: If true, draw training samples with a
            ``WeightedRandomSampler`` (inverse class frequency) instead of
            plain shuffling.
        alpha0: Multiplier applied to class index 0 in the loss weights and,
            when the sampler is enabled, in the sampling weights.

    Returns:
        ``(train_loader, val_loader, classes, weights_ce, counts)`` where
        ``classes`` is the training set's class-name list, ``weights_ce`` is
        a float32 per-class weight tensor rescaled to sum to the number of
        classes, and ``counts`` is the float32 per-class training sample
        count.
    """
    MEAN, STD = (0.485,0.456,0.406), (0.229,0.224,0.225)

    prep = PreprocessUpsample(size)   # bicubic upsample + unsharp mask

    tf_train = transforms.Compose([
        LetterboxSquare(size),          # keep content, pad to square
        prep,                           # bicubic resize -> unsharp
        # RandomHorizontalFlip and GaussianBlur(kernel_size=3) are disabled.
        transforms.ColorJitter(0.2,0.2,0.2,0.1),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    tf_val = transforms.Compose([
        LetterboxSquare(size),
        prep,
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    train_ds = datasets.ImageFolder(f"{root}/train", tf_train)
    val_ds   = datasets.ImageFolder(f"{root}/val",   tf_val)

    # ----- class counts & weights -----
    # One label tensor for the whole training set; it drives both the
    # histogram below and the per-sample sampler weights, avoiding a Python
    # loop of tensor-scalar operations per sample.
    targets = torch.tensor([y for _, y in train_ds.samples], dtype=torch.long)
    counts = torch.bincount(targets, minlength=len(train_ds.classes)).to(torch.float32)

    # Inverse-frequency weights; the epsilon avoids division by zero.
    inv = 1.0 / (counts + 1e-6)
    weights_ce = inv / inv.sum() * len(inv)

    # De-weight class 0 in the loss, then rescale so the weights sum to the
    # number of classes again.
    weights_ce[0] *= float(alpha0)
    weights_ce = weights_ce / weights_ce.sum() * len(weights_ce)

    # ----- sampler (optional) -----
    if use_sampler:
        # Per-sample weight ~ 1/freq of its class, with class 0 reduced by
        # alpha0 in the sampler as well.
        per_class_sample_w = inv.clone()
        per_class_sample_w[0] *= float(alpha0)
        per_class_sample_w = per_class_sample_w / per_class_sample_w.sum() * len(per_class_sample_w)
        sample_weights = per_class_sample_w[targets]
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
        train_loader = DataLoader(train_ds, batch_size=bs, sampler=sampler, num_workers=workers, pin_memory=True)
    else:
        train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=workers, pin_memory=True)

    val_loader   = DataLoader(val_ds,   batch_size=bs, shuffle=False, num_workers=workers, pin_memory=True)
    return train_loader, val_loader, train_ds.classes, weights_ce, counts

@torch.no_grad()
def evaluate(model, loader, device, n_classes):
    """Evaluate ``model`` on ``loader`` and collect accuracy statistics.

    The model is switched to eval mode and is not switched back.

    Args:
        model: Callable module mapping an input batch to per-class logits.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        n_classes: Number of classes; sets the confusion-matrix size.

    Returns:
        ``(acc, cm, per_class_acc)``: overall accuracy (``0.0`` when the
        loader yields no samples), an ``(n_classes, n_classes)`` int64
        confusion matrix indexed ``[true, predicted]``, and a list with the
        per-class accuracy (``0.0`` for classes with no samples).
    """
    model.eval()
    correct = 0
    tot = 0
    # confusion matrix: rows = true label, columns = predicted label
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        pred = logits.argmax(1)

        # Move the labels to host memory once per batch rather than calling
        # .item() on every element.
        true_np = y.reshape(-1).detach().cpu().numpy()
        pred_np = pred.reshape(-1).detach().cpu().numpy()

        correct += int(np.count_nonzero(true_np == pred_np))
        tot += int(true_np.size)
        # Unbuffered add, so repeated (true, pred) pairs are all counted.
        np.add.at(cm, (true_np, pred_np), 1)
    acc = correct / tot if tot else 0.0
    # per-class acc: diagonal entry over the row total
    per_class_acc = []
    for c in range(n_classes):
        denom = cm[c].sum()
        per_class_acc.append(cm[c, c] / denom if denom > 0 else 0.0)
    return acc, cm, per_class_acc

def pretty_cm(cm, classes):
    """Render a confusion matrix as a fixed-width text table.

    Args:
        cm: 2-D matrix indexable as ``cm[i, j]``.
        classes: Class list; only its length is used, and rows/columns are
            labelled by index.

    Returns:
        A string with a header line of column indices followed by one line
        per row. Every column is printed (there is no truncation).
    """
    indices = range(len(classes))
    header = "     " + " ".join(f"{j:>4}" for j in indices)
    body = "\n".join(
        f"{i:>3}: " + " ".join(f"{cm[i,j]:>4}" for j in indices)
        for i in indices
    )
    return "\n".join((header, body))

def main():
    """Parse options, train the classifier, and save checkpoints.

    Builds the data loaders, creates a timm model, and trains it with
    class-weighted cross-entropy, AdamW, cosine learning-rate annealing and
    (when CUDA is available) mixed precision. After every epoch the model is
    evaluated on the validation loader, ``last.pt`` is written, and
    ``best.pt`` is written whenever validation accuracy improves. Training
    stops early once accuracy has not improved for ``--patience`` epochs.

    When ``get_ipython`` is present in the module globals (notebook run),
    the defaults are used instead of reading ``sys.argv``.
    """
    # The text is the parser *description*; passing it positionally would set
    # ``prog`` and make it appear as the program name in the usage line.
    ap = argparse.ArgumentParser(
        description="Advanced jersey-group classifier training (12 classes)"
    )
    ap.add_argument("--root", default="/kaggle/input/datav2/dataset_clsv2/visible", help="root containing train/ and val/")
    ap.add_argument("--model", default="convnext_tiny")
    ap.add_argument("--epochs", type=int, default=100)
    ap.add_argument("--bs", type=int, default=120)
    ap.add_argument("--size", type=int, default=224)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--pretrained", type=int, default=1, help="1 to use timm pretrained weights")
    ap.add_argument("--sampler", type=int, default=1, help="1 = enable WeightedRandomSampler")
    ap.add_argument("--alpha0", type=float, default=0.05, help="downweight factor for class 0 (loss & sampler)")
    ap.add_argument("--patience", type=int, default=50, help="early stopping patience (epochs)")
    ap.add_argument("--outdir", default="/kaggle/working/cls_ckpt")
    args = ap.parse_args([]) if "get_ipython" in globals() else ap.parse_args()

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # loaders & weights
    train_loader, val_loader, classes, weights_ce, counts = get_loaders(
        root=args.root, bs=args.bs, size=args.size, workers=4,
        use_sampler=bool(args.sampler), alpha0=args.alpha0
    )
    n_classes = len(classes)
    print("Classes:", classes)
    print("Train class counts:", {i:int(c.item()) for i,c in enumerate(counts)})
    print(f"Using sampler: {bool(args.sampler)} | alpha0={args.alpha0}")

    # model
    backbone = timm.create_model(args.model, pretrained=bool(args.pretrained), num_classes=n_classes)
    backbone.to(device)

    criterion = nn.CrossEntropyLoss(weight=weights_ce.to(device))
    optim = torch.optim.AdamW(backbone.parameters(), lr=args.lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

    # Gradient scaling / autocast are only enabled when CUDA is available.
    scaler = amp.GradScaler("cuda", enabled=use_cuda)

    # training loop with early stopping
    best = 0.0
    best_ep = -1
    patience_left = args.patience

    for ep in range(1, args.epochs + 1):
        backbone.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optim.zero_grad(set_to_none=True)
            with amp.autocast("cuda", enabled=use_cuda):
                logits = backbone(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()

        # The confusion matrix is returned as well but is not reported here.
        acc, _, per_class = evaluate(backbone, val_loader, device, n_classes)
        sched.step()
        print(f"Epoch {ep}/{args.epochs}  val_acc={acc:.4f}  | per-class acc: "
              + ', '.join(f'{i}:{a:.2f}' for i, a in enumerate(per_class)))

        # save last
        torch.save({"model": backbone.state_dict(), "classes": classes}, outdir / "last.pt")

        # early stopping on best val acc: an improvement resets the patience
        # budget, anything else consumes one epoch of it
        if acc > best:
            best = acc
            best_ep = ep
            torch.save({"model": backbone.state_dict(), "classes": classes}, outdir / "best.pt")
            patience_left = args.patience
        else:
            patience_left -= 1

        if patience_left <= 0:
            print(f"Early stopping at epoch {ep} (best={best:.4f} @ {best_ep})")
            break

    print(f"Best val acc: {best:.6f} (epoch {best_ep})")
    print(f"Saved best to: {outdir/'best.pt'} | last to: {outdir/'last.pt'}")

# Script entry point: start training only when this file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()



Classes: ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11']
Train class counts: {0: 18184, 1: 1863, 2: 2062, 3: 2356, 4: 878, 5: 3922, 6: 1583, 7: 3457, 8: 3589, 9: 2728, 10: 1802, 11: 5118}
Using sampler: True | alpha0=0.05


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Epoch 1/100  val_acc=0.4268  | per-class acc: 0:0.00, 1:0.47, 2:0.01, 3:0.99, 4:0.79, 5:0.87, 6:0.94, 7:0.93, 8:1.00, 9:0.87, 10:0.80, 11:0.51
Epoch 2/100  val_acc=0.4498  | per-class acc: 0:0.00, 1:0.86, 2:0.08, 3:0.95, 4:0.92, 5:0.94, 6:0.96, 7:0.39, 8:0.95, 9:0.56, 10:0.74, 11:0.42
Epoch 3/100  val_acc=0.4470  | per-class acc: 0:0.00, 1:0.52, 2:0.06, 3:0.96, 4:0.66, 5:0.97, 6:0.96, 7:0.93, 8:1.00, 9:0.70, 10:0.83, 11:0.71
Epoch 4/100  val_acc=0.4776  | per-class acc: 0:0.00, 1:0.69, 2:0.00, 3:0.86, 4:0.78, 5:0.89, 6:0.96, 7:0.93, 8:0.89, 9:0.95, 10:0.96, 11:0.91
Epoch 5/100  val_acc=0.3578  | per-class acc: 0:0.00, 1:0.43, 2:0.00, 3:0.59, 4:0.93, 5:0.92, 6:0.98, 7:0.85, 8:0.31, 9:0.64, 10:0.70, 11:0.27
Epoch 6/100  val_acc=0.3725  | per-class acc: 0:0.00, 1:0.29, 2:0.00, 3:0.90, 4:0.86, 5:0.95, 6:0.98, 7:0.75, 8:0.95, 9:0.72, 10:1.00, 11:0.16
Epoch 7/100  val_acc=0.4764  | per-class acc: 0:0.10, 1:0.47, 2:0.00, 3:0.92, 4:0.94, 5:0.99, 6:0.93, 7:0.59, 8:0.85, 9:0.64, 10:0.80, 11:0.57