In [1]:
import os, argparse, itertools
from pathlib import Path
from collections import Counter, defaultdict
import torch, timm
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms
from torch import amp
import numpy as np
from PIL import Image
from PIL import ImageFilter
from torch.utils.data import ConcatDataset


# Reuses LetterboxSquare and PreprocessUpsample from the original training file.

def remap_to_vis_mapping(ds, vis_class_to_idx, vis_classes):
    """Relabel an ImageFolder-style dataset in place with the visible split's mapping.

    Each sample's class name is looked up through the dataset's own
    ``classes`` list (indexed by the sample's current label) and then
    translated to the index used by the visible split. Using the stored label
    instead of the file's parent directory keeps the result correct when
    images live in nested sub-folders below a class folder.

    Args:
        ds: dataset exposing ``samples`` (list of ``(path, label)``) and
            ``classes`` (list of class names), e.g. the partial ImageFolder.
        vis_class_to_idx: class-name -> index mapping of the visible split
            (e.g. ``{'00': 0, '01': 1, ...}``).
        vis_classes: list of class names of the visible split.

    Returns:
        The same ``ds`` object, with ``samples``, ``targets``,
        ``class_to_idx`` and ``classes`` rewritten (and ``imgs`` too when the
        dataset has that attribute).

    Raises:
        ValueError: if ``ds`` contains a class that the visible split lacks.
    """
    own_classes = ds.classes
    new_samples = []
    for path, old_label in ds.samples:
        cls_name = own_classes[old_label]  # class folder name, e.g. '02'
        if cls_name not in vis_class_to_idx:
            # The partial split has a class the visible split does not know.
            raise ValueError(f"Class '{cls_name}' trong partial KHÔNG tồn tại trong visible. Sửa lại folder trước.")
        new_samples.append((path, vis_class_to_idx[cls_name]))

    ds.samples = new_samples
    if hasattr(ds, "imgs"):
        # Keep the alias consistent with the relabelled samples.
        ds.imgs = new_samples
    ds.targets = [label for _, label in new_samples]
    ds.class_to_idx = vis_class_to_idx
    ds.classes = vis_classes
    return ds

def finetune_with_partial(
    root="/kaggle/input/datav2/dataset_clsv2",
    ckpt_path="/kaggle/working/cls_ckpt/best.pt",
    outdir="/kaggle/working/cls_ckpt_ft_partial",
    model_name="convnext_tiny",
    weak_classes=(2,),          # weak class indices (02, optionally 1, 9, ...)
    boost_factor=4.0,
    epochs=20,
    bs=140,
    size=224,
    lr=5e-5,
    alpha0=0.15,
    use_sampler=True,
    workers=4,
    partial_factor=0.3,
):
    """Fine-tune a checkpoint on the combined visible + partial training data.

    Builds the loaders with ``get_loaders_with_partial``, multiplies the
    cross-entropy weight of every index in ``weak_classes`` by
    ``boost_factor``, loads ``ckpt["model"]`` into a fresh timm model and
    trains with AdamW, cosine LR annealing and AMP. After every epoch the
    model is evaluated on both validation loaders, ``ft_partial_last.pt`` is
    overwritten, and ``ft_partial_best.pt`` is written when the visible
    accuracy improves (the first epoch always sets the baseline).

    Args:
        root: dataset root passed to ``get_loaders_with_partial``.
        ckpt_path: checkpoint file holding a ``"model"`` state dict.
        outdir: directory for the fine-tuned checkpoints (created if missing).
        model_name: timm architecture name.
        weak_classes: class indices whose loss weight is boosted;
            out-of-range indices are reported and ignored.
        boost_factor: multiplier applied to the weak classes' loss weight.
        epochs: number of fine-tuning epochs.
        bs: batch size.
        size: input resolution passed to the loaders.
        lr: AdamW learning rate.
        alpha0: class-0 down-weighting factor passed to the loaders.
        use_sampler: whether the loaders use the weighted sampler.
        workers: DataLoader worker count.
        partial_factor: sampling weight multiplier for partial samples,
            forwarded to ``get_loaders_with_partial``.

    Returns:
        The fine-tuned model (state after the last epoch).
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # 1) Loaders built from visible + partial.
    (train_loader,
     val_loader_visible,
     val_loader_partial,
     classes,
     weights_ce,
     counts) = get_loaders_with_partial(
        root=root,
        bs=bs,
        size=size,
        workers=workers,
        use_sampler=use_sampler,
        alpha0=alpha0,
        partial_factor=partial_factor,
    )

    n_classes = len(classes)
    print("Classes:", classes)
    print("Train class counts (vis+partial):", {i:int(c.item()) for i,c in enumerate(counts)})
    print("Weak classes:", weak_classes)

    # 2) Boost the loss weight of the weak classes, then renormalise so the
    #    weights still sum to the number of classes.
    weights_ft = weights_ce.clone()
    for c in weak_classes:
        if 0 <= c < n_classes:
            print(f"Boosting loss weight for class {c} by x{boost_factor}")
            weights_ft[c] *= float(boost_factor)
        else:
            # Same warning as finetune_weak_classes, instead of a silent skip.
            print(f"[WARN] class index {c} ngoài range, bỏ qua.")
    weights_ft = weights_ft / weights_ft.sum() * len(weights_ft)

    # 3) Restore the model from the checkpoint.
    ckpt = torch.load(ckpt_path, map_location=device)
    model = timm.create_model(model_name, pretrained=False, num_classes=n_classes)
    model.load_state_dict(ckpt["model"])
    model.to(device)

    criterion = nn.CrossEntropyLoss(weight=weights_ft.to(device))
    optim = torch.optim.AdamW(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs)
    scaler = amp.GradScaler("cuda", enabled=use_cuda)

    best_vis = 0.0
    best_ep = -1

    print("=== Start finetune on visible + partial ===")
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            # The loaders pin memory, so the copies can be asynchronous.
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optim.zero_grad(set_to_none=True)
            with amp.autocast("cuda", enabled=use_cuda):
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()

        # Evaluate on both splits to watch the visible/partial trade-off.
        acc_vis, cm_vis, pc_vis = evaluate(model, val_loader_visible, device, n_classes)
        acc_par, cm_par, pc_par = evaluate(model, val_loader_partial, device, n_classes)

        sched.step()

        print(f"[FT] Epoch {ep}/{epochs}  vis_acc={acc_vis:.4f}, par_acc={acc_par:.4f}")
        print("    vis per-class:", ", ".join(f"{i}:{a:.2f}" for i,a in enumerate(pc_vis)))
        print("    par per-class:", ", ".join(f"{i}:{a:.2f}" for i,a in enumerate(pc_par)))

        torch.save({"model": model.state_dict(), "classes": classes},
                   outdir / "ft_partial_last.pt")

        # Model selection uses visible accuracy; (acc_vis + acc_par) / 2 is
        # the obvious alternative.
        score = acc_vis
        # The first epoch always becomes the baseline, so a best checkpoint
        # exists even when the accuracy never rises above 0.
        if best_ep < 0 or score > best_vis:
            best_vis = score
            best_ep = ep
            torch.save({"model": model.state_dict(), "classes": classes},
                       outdir / "ft_partial_best.pt")

    print(f"[FT] Best score (vis) after finetune: {best_vis:.6f} (epoch {best_ep})")
    if best_ep > 0:
        print(f"Saved ft_partial_best to: {outdir/'ft_partial_best.pt'}")
    else:
        print("[FT] No epoch was run; no checkpoint was written.")

    return model




# The two preprocessing transforms below are reused unchanged from the original training script:
class LetterboxSquare(object):
    """Resize an image to fit inside a ``size`` x ``size`` square, then pad it.

    The aspect ratio is preserved: the image is scaled so that its longer side
    matches ``size`` and the remaining space is filled by reflect padding,
    split as evenly as possible between the two sides.
    """

    def __init__(self, size):
        self.size = size

    def _layout(self, w, h):
        """Return ``(new_w, new_h, (pad_l, pad_t, pad_r, pad_b))`` for a ``w`` x ``h`` input.

        Each resized side is clamped to ``[1, size]``: without the lower bound
        an extreme aspect ratio (e.g. 1000 x 1 at size 224) rounds the short
        side down to 0, which is not a valid resize target.
        """
        s = min(self.size / w, self.size / h)
        nw = min(self.size, max(1, int(round(w * s))))
        nh = min(self.size, max(1, int(round(h * s))))

        pad_l = (self.size - nw) // 2
        pad_t = (self.size - nh) // 2
        # The right/bottom side takes the extra pixel when the gap is odd.
        pad_r = self.size - nw - pad_l
        pad_b = self.size - nh - pad_t
        return nw, nh, (pad_l, pad_t, pad_r, pad_b)

    def __call__(self, img: "Image.Image"):
        w, h = img.size
        nw, nh, padding = self._layout(w, h)

        img = img.resize((nw, nh), Image.BILINEAR)

        return transforms.functional.pad(
            img,
            padding,
            padding_mode="reflect"
        )

class PreprocessUpsample(object):
    """Bicubic resize to ``out_size`` x ``out_size`` followed by an unsharp mask.

    The three ``unsharp_*`` arguments are forwarded to
    ``ImageFilter.UnsharpMask`` as ``radius``, ``percent`` and ``threshold``.
    """

    def __init__(self, out_size, unsharp_radius=1, unsharp_percent=150, unsharp_threshold=3):
        self.out_size, self.radius = out_size, unsharp_radius
        self.percent, self.threshold = unsharp_percent, unsharp_threshold

    def __call__(self, img: Image.Image):
        side = self.out_size
        sharpen = ImageFilter.UnsharpMask(
            radius=self.radius,
            percent=self.percent,
            threshold=self.threshold,
        )
        # Upsample first, then sharpen the resized image.
        resized = img.resize((side, side), resample=Image.BICUBIC)
        return resized.filter(sharpen)


def get_loaders_with_partial(root="/kaggle/input/datav2/dataset_clsv2",
                             bs=140, size=224, workers=4,
                             use_sampler=True, alpha0=0.15,
                             partial_factor=0.3):
    """Build train/val loaders over the visible and partial splits.

    Expects ``root/visible/{train,val}`` and ``root/partial/{train,val}`` in
    ImageFolder layout. The partial datasets are relabelled with
    ``remap_to_vis_mapping`` so both splits share the visible class indices,
    and training uses the concatenation of both train sets.

    Args:
        root: dataset root directory.
        bs: batch size for all loaders.
        size: side length of the square network input.
        workers: DataLoader worker count.
        use_sampler: if True, train with a ``WeightedRandomSampler`` based on
            inverse class frequency; otherwise shuffle uniformly.
        alpha0: multiplier applied to class index 0 in both the loss weights
            and the sampler weights (down-weights class 0 when < 1).
        partial_factor: multiplier applied to the sampling weight of every
            partial training sample.

    Returns:
        ``(train_loader, val_loader_visible, val_loader_partial, classes,
        weights_ce, counts)`` where ``weights_ce`` are per-class loss weights
        normalised to sum to the number of classes and ``counts`` are the
        per-class sample counts of the combined training set.
    """
    MEAN, STD = (0.485,0.456,0.406), (0.229,0.224,0.225)

    prep = PreprocessUpsample(size)

    tf_train = transforms.Compose([
        LetterboxSquare(size),
        prep,
        transforms.ColorJitter(0.2,0.2,0.2,0.1),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    tf_val = transforms.Compose([
        LetterboxSquare(size),
        prep,
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    root = Path(root)

    # === visible ===
    vis_train = datasets.ImageFolder(root/"visible"/"train", tf_train)
    vis_val   = datasets.ImageFolder(root/"visible"/"val",   tf_val)

    vis_classes = vis_train.classes
    vis_class_to_idx = vis_train.class_to_idx
    n_classes = len(vis_classes)

    # === partial (raw, still using its own class indices) ===
    part_train_raw = datasets.ImageFolder(root/"partial"/"train", tf_train)
    part_val_raw   = datasets.ImageFolder(root/"partial"/"val",   tf_val)

    print("VISIBLE classes:", vis_classes)
    print("PARTIAL raw classes:", part_train_raw.classes)

    # Relabel partial so it shares the visible class indices.
    part_train = remap_to_vis_mapping(part_train_raw, vis_class_to_idx, vis_classes)
    part_val   = remap_to_vis_mapping(part_val_raw,   vis_class_to_idx, vis_classes)

    # ===== combined training set: visible first, partial second =====
    train_ds = ConcatDataset([vis_train, part_train])

    # ----- class counts & loss weights (over the combined train set) -----
    # Labels are gathered once, in ConcatDataset order, so the per-sample
    # sampler weights below line up with the dataset indices.
    vis_targets = torch.tensor([y for _, y in vis_train.samples], dtype=torch.long)
    part_targets = torch.tensor([y for _, y in part_train.samples], dtype=torch.long)
    train_targets = torch.cat([vis_targets, part_targets])
    counts = torch.bincount(train_targets, minlength=n_classes).to(torch.float32)

    # Inverse frequency. A class without samples gets weight 0; otherwise its
    # 1/1e-6 term would swamp the normalisation and push every real class
    # towards zero weight.
    inv = torch.where(counts > 0, 1.0 / (counts + 1e-6), torch.zeros_like(counts))
    weights_ce = inv / inv.sum() * len(inv)

    # Down-weight class 0, then renormalise.
    weights_ce[0] *= float(alpha0)
    weights_ce = weights_ce / weights_ce.sum() * len(weights_ce)

    loader_kw = dict(batch_size=bs, num_workers=workers, pin_memory=True)

    # ----- sampler -----
    if use_sampler:
        per_class_sample_w = inv.clone()
        per_class_sample_w[0] *= float(alpha0)
        per_class_sample_w = per_class_sample_w / per_class_sample_w.sum() * len(per_class_sample_w)

        # Visible samples keep their class weight; partial samples are scaled
        # by partial_factor so they are drawn less (or more) often.
        sample_weights = per_class_sample_w[train_targets]
        sample_weights[len(vis_targets):] *= float(partial_factor)

        sampler = WeightedRandomSampler(
            sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
        train_loader = DataLoader(train_ds, sampler=sampler, **loader_kw)
    else:
        train_loader = DataLoader(train_ds, shuffle=True, **loader_kw)

    val_loader_visible = DataLoader(vis_val, shuffle=False, **loader_kw)
    val_loader_partial = DataLoader(part_val, shuffle=False, **loader_kw)

    return (train_loader,
            val_loader_visible,
            val_loader_partial,
            vis_classes,
            weights_ce,
            counts)


def get_loaders(root="/kaggle/input/datav2/dataset_clsv2/visible", bs=120, size=224, workers=4, use_sampler=False, alpha0=0.05):
    """Build train/val loaders for a single ImageFolder split.

    Expects ``root/train`` and ``root/val`` in ImageFolder layout.

    Args:
        root: directory containing the ``train`` and ``val`` folders.
        bs: batch size for both loaders.
        size: side length of the square network input.
        workers: DataLoader worker count.
        use_sampler: if True, train with a ``WeightedRandomSampler`` based on
            inverse class frequency; otherwise shuffle uniformly.
        alpha0: multiplier applied to class index 0 in both the loss weights
            and the sampler weights (down-weights class 0 when < 1).

    Returns:
        ``(train_loader, val_loader, classes, weights_ce, counts)`` where
        ``weights_ce`` are per-class loss weights normalised to sum to the
        number of classes and ``counts`` are per-class training sample counts.
    """
    MEAN, STD = (0.485,0.456,0.406), (0.229,0.224,0.225)

    prep = PreprocessUpsample(size)   # bicubic upsample + unsharp mask

    tf_train = transforms.Compose([
        LetterboxSquare(size),          # keep content, letterbox to a square
        prep,
        transforms.ColorJitter(0.2,0.2,0.2,0.1),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    tf_val = transforms.Compose([
        LetterboxSquare(size),
        prep,
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])
    train_ds = datasets.ImageFolder(f"{root}/train", tf_train)
    val_ds   = datasets.ImageFolder(f"{root}/val",   tf_val)

    # ----- class counts & loss weights -----
    n_classes = len(train_ds.classes)
    train_targets = torch.tensor([y for _, y in train_ds.samples], dtype=torch.long)
    counts = torch.bincount(train_targets, minlength=n_classes).to(torch.float32)

    # Inverse frequency. A class without samples gets weight 0; otherwise its
    # 1/1e-6 term would swamp the normalisation and push every real class
    # towards zero weight.
    inv = torch.where(counts > 0, 1.0 / (counts + 1e-6), torch.zeros_like(counts))
    weights_ce = inv / inv.sum() * len(inv)

    # Down-weight class 0 in the loss, then renormalise.
    weights_ce[0] *= float(alpha0)
    weights_ce = weights_ce / weights_ce.sum() * len(weights_ce)

    loader_kw = dict(batch_size=bs, num_workers=workers, pin_memory=True)

    # ----- sampler (optional) -----
    if use_sampler:
        # Per-sample weight ~ 1/frequency, with class 0 reduced here as well.
        per_class_sample_w = inv.clone()
        per_class_sample_w[0] *= float(alpha0)
        per_class_sample_w = per_class_sample_w / per_class_sample_w.sum() * len(per_class_sample_w)
        sample_weights = per_class_sample_w[train_targets]
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
        train_loader = DataLoader(train_ds, sampler=sampler, **loader_kw)
    else:
        train_loader = DataLoader(train_ds, shuffle=True, **loader_kw)

    val_loader = DataLoader(val_ds, shuffle=False, **loader_kw)
    return train_loader, val_loader, train_ds.classes, weights_ce, counts

@torch.no_grad()
def evaluate(model, loader, device, n_classes):
    """Compute accuracy, confusion matrix and per-class accuracy over ``loader``.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: callable module mapping a batch ``x`` to logits; the predicted
            class is the argmax over dimension 1.
        loader: iterable of ``(x, y)`` batches with integer class labels ``y``.
        device: device the batches are moved to before the forward pass.
        n_classes: number of classes (side length of the confusion matrix).

    Returns:
        ``(acc, cm, per_class_acc)``: overall accuracy as a float (0.0 when
        the loader is empty), an ``(n_classes, n_classes)`` int64 array with
        ``cm[true, pred]`` counts, and a list of per-class accuracies (0.0 for
        classes that have no samples).
    """
    model.eval()
    correct = 0
    tot = 0
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        # One host transfer per batch instead of one .item() call per sample.
        y_np = y.view(-1).cpu().numpy()
        p_np = pred.view(-1).cpu().numpy()
        correct += int((p_np == y_np).sum())
        tot += int(y_np.size)
        # Unbuffered add so repeated (true, pred) pairs are all counted.
        np.add.at(cm, (y_np, p_np), 1)
    acc = correct / tot if tot else 0.0

    row_totals = cm.sum(axis=1)
    per_class_acc = [
        float(cm[c, c] / row_totals[c]) if row_totals[c] > 0 else 0.0
        for c in range(n_classes)
    ]
    return acc, cm, per_class_acc

def pretty_cm(cm, classes):
    """Render a confusion matrix as a plain-text table.

    Rows and columns are labelled with class indices (not names); every cell
    is right-aligned in a 4-character column. All ``len(classes)`` columns are
    printed; nothing is truncated.
    """
    n = len(classes)
    header = "     " + " ".join(f"{col:>4}" for col in range(n))
    body = []
    for r in range(n):
        cells = " ".join(f"{cm[r,c]:>4}" for c in range(n))
        body.append(f"{r:>3}: " + cells)
    return header + "\n" + "\n".join(body)

def finetune_weak_classes(
    root="/kaggle/input/datav2/dataset_clsv2/visible",
    ckpt_path="/kaggle/working/cls_ckpt/best.pt",
    outdir="/kaggle/working/cls_ckpt_ft",
    model_name="convnext_tiny",
    weak_classes=(2,),          # weak class indices; by default only class 02 is boosted
    boost_factor=4.0,           # loss-weight multiplier for the weak classes
    epochs=20,
    bs=120,
    size=224,
    lr=5e-5,
    alpha0=0.05,
    use_sampler=True,
    workers=4,
    pretrained=False            # weights come from the checkpoint, so none are downloaded
):
    """Fine-tune a checkpoint on one split with boosted loss weights for weak classes.

    Builds the loaders with ``get_loaders``, multiplies the cross-entropy
    weight of every index in ``weak_classes`` by ``boost_factor``, loads
    ``ckpt["model"]`` into a timm model and trains with AdamW, cosine LR
    annealing and AMP. After every epoch ``ft_last.pt`` is overwritten and
    ``ft_best.pt`` is written when validation accuracy improves (the first
    epoch always sets the baseline).

    Args:
        root: directory containing ``train`` and ``val`` ImageFolder trees.
        ckpt_path: checkpoint file holding a ``"model"`` state dict.
        outdir: directory for the fine-tuned checkpoints (created if missing).
        model_name: timm architecture name.
        weak_classes: class indices whose loss weight is boosted;
            out-of-range indices are reported and ignored.
        boost_factor: multiplier applied to the weak classes' loss weight.
        epochs: number of fine-tuning epochs.
        bs: batch size.
        size: input resolution passed to the loaders.
        lr: AdamW learning rate.
        alpha0: class-0 down-weighting factor passed to the loaders.
        use_sampler: whether the train loader uses the weighted sampler.
        workers: DataLoader worker count.
        pretrained: forwarded to ``timm.create_model``.

    Returns:
        The fine-tuned model (state after the last epoch).
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # 1) Rebuild the loaders and the base loss weights.
    train_loader, val_loader, classes, weights_ce, counts = get_loaders(
        root=root,
        bs=bs,
        size=size,
        workers=workers,
        use_sampler=use_sampler,
        alpha0=alpha0
    )
    n_classes = len(classes)

    print("Classes:", classes)
    print("Train class counts:", {i:int(c.item()) for i,c in enumerate(counts)})
    print("Weak classes:", weak_classes)

    # 2) Boost the loss weight of the weak classes.
    weights_ft = weights_ce.clone()
    for c in weak_classes:
        if 0 <= c < n_classes:
            print(f"Boosting loss weight for class {c} by x{boost_factor}")
            weights_ft[c] *= float(boost_factor)
        else:
            print(f"[WARN] class index {c} ngoài range, bỏ qua.")

    # Renormalise so the weights sum to the number of classes again.
    weights_ft = weights_ft / weights_ft.sum() * len(weights_ft)

    print("Original CE weights :", weights_ce.cpu().numpy())
    print("Finetune CE weights :", weights_ft.cpu().numpy())

    # 3) Create the model and restore the checkpoint weights.
    ckpt = torch.load(ckpt_path, map_location=device)
    backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=n_classes)
    backbone.load_state_dict(ckpt["model"])
    backbone.to(device)

    # 4) Loss, optimizer, scheduler, AMP scaler.
    criterion = nn.CrossEntropyLoss(weight=weights_ft.to(device))
    optim = torch.optim.AdamW(backbone.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs)
    scaler = amp.GradScaler("cuda", enabled=use_cuda)

    best = 0.0
    best_ep = -1

    print("=== Start finetune on weak classes ===")
    for ep in range(1, epochs + 1):
        backbone.train()
        for x, y in train_loader:
            # The loaders pin memory, so the copies can be asynchronous.
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optim.zero_grad(set_to_none=True)
            with amp.autocast("cuda", enabled=use_cuda):
                logits = backbone(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()

        # The confusion matrix is not printed by default; use
        # pretty_cm(cm, classes) to inspect it.
        acc, cm, per_class = evaluate(backbone, val_loader, device, n_classes)
        sched.step()

        print(f"[FT] Epoch {ep}/{epochs}  val_acc={acc:.4f}  | per-class acc: "
              + ', '.join(f'{i}:{a:.2f}' for i, a in enumerate(per_class)))

        # Always keep the latest state.
        torch.save({"model": backbone.state_dict(), "classes": classes},
                   outdir / "ft_last.pt")

        # Keep the best state by validation accuracy. The first epoch always
        # becomes the baseline, so ft_best.pt exists even if accuracy stays 0.
        if best_ep < 0 or acc > best:
            best = acc
            best_ep = ep
            torch.save({"model": backbone.state_dict(), "classes": classes},
                       outdir / "ft_best.pt")

    print(f"[FT] Best val acc after finetune: {best:.6f} (epoch {best_ep})")
    if best_ep > 0:
        print(f"Saved ft_best to: {outdir/'ft_best.pt'} | ft_last to: {outdir/'ft_last.pt'}")
    else:
        print("[FT] No epoch was run; no checkpoint was written.")

    return backbone


if __name__ == "__main__":
    # Alternative entry point: fine-tune on the visible split only, boosting
    # the weak classes, e.g.
    #   finetune_weak_classes(
    #       root="/kaggle/input/datav2/dataset_clsv2/visible",
    #       ckpt_path="/kaggle/input/bestpt2/best2.pt",
    #       outdir="/kaggle/working/cls_ckpt_ft",
    #       model_name="convnext_tiny",
    #       weak_classes=(2,1),   # add more weak classes if needed, e.g. (2,1,9)
    #       boost_factor=4.0, epochs=20, bs=120, size=224,
    #       lr=5e-5, alpha0=0.05, use_sampler=True, workers=4,
    #   )

    # Notebook run: fine-tune on visible + partial, starting from a model
    # already trained on the visible split.
    ft_kwargs = dict(
        root="/kaggle/input/datav2/dataset_clsv2",
        ckpt_path="/kaggle/input/bestpt3/ft_best.pt",
        weak_classes=(2,1),      # add other weak classes here if needed
        boost_factor=3.0,
        epochs=50,
        lr=1e-4,
    )
    finetune_with_partial(**ft_kwargs)




VISIBLE classes: ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11']
PARTIAL raw classes: ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11']
Classes: ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10', '11']
Train class counts (vis+partial): {0: 18184, 1: 4497, 2: 3373, 3: 4770, 4: 3892, 5: 7387, 6: 4934, 7: 6917, 8: 7809, 9: 4748, 10: 7112, 11: 16586}
Weak classes: (2, 1)
Boosting loss weight for class 2 by x3.0
Boosting loss weight for class 1 by x3.0
=== Start finetune on visible + partial ===
[FT] Epoch 1/50  vis_acc=0.4730, par_acc=0.3515
    vis per-class: 0:0.13, 1:0.36, 2:0.00, 3:0.82, 4:0.89, 5:0.93, 6:0.98, 7:0.97, 8:1.00, 9:0.80, 10:1.00, 11:0.57
    par per-class: 0:0.00, 1:0.07, 2:0.00, 3:0.58, 4:0.73, 5:0.62, 6:0.55, 7:0.71, 8:0.40, 9:0.17, 10:0.67, 11:0.14
[FT] Epoch 2/50  vis_acc=0.5092, par_acc=0.3473
    vis per-class: 0:0.26, 1:0.29, 2:0.00, 3:0.86, 4:0.90, 5:0.94, 6:0.99, 7:0.95, 8:1.00, 9:0.58, 10:1.00, 11:0.47
  