In [3]:
import os, gc, json, random, datetime, logging, csv
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score, balanced_accuracy_score, f1_score
import timm
from torch.amp import autocast, GradScaler

# ---------------------------- CONFIG ----------------------------
class CFG:
    DATA_PATHS = {
        "train": "/kaggle/input/minida/mini_output1/train",
        "val":   "/kaggle/input/minida/mini_output1/val",
        "test":  "/kaggle/input/minida/mini_output1/test",
    }

    IMG_SIZE = 256
    BATCH = 32
    NUM_WORKERS = 2
    AMP = True
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # CE-only fine-tuning (safe/simple)
    EPOCHS = 30
    LR = 1e-4
    WD = 0.05
    LABEL_SMOOTHING = 0.1
    PATIENCE = 7

    SEED = 42

    OUTPUT_DIR = "./swin_controlled_final_strict_eval"
    RUN_TS = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Controlled architecture: SwinV2-Base 192->256 window12->16
    BASE_SCRATCH = "swinv2_base_window12to16_192to256"
    BASE_IN22K_FT_IN1K = "swinv2_base_window12to16_192to256.ms_in22k_ft_in1k"

    # Optional IN1K baseline (different capacity, but helpful)
    RUN_SMALL_IN1K_BASELINE = True
    SMALL_IN1K = "swinv2_small_window16_256.ms_in1k"

cfg = CFG()
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# ---------------------------- LOGGING ----------------------------
logger = logging.getLogger("study")
logger.setLevel(logging.INFO)
fmt = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
if not logger.handlers:
    fh = logging.FileHandler(os.path.join(cfg.OUTPUT_DIR, f"run_{cfg.RUN_TS}.log"))
    fh.setFormatter(fmt); logger.addHandler(fh)
    sh = logging.StreamHandler()
    sh.setFormatter(fmt); logger.addHandler(sh)

def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(cfg.SEED)

CLASS_NAMES = sorted(os.listdir(cfg.DATA_PATHS["train"]))
NUM_CLASSES = len(CLASS_NAMES)
logger.info(f"Classes ({NUM_CLASSES}): {CLASS_NAMES}")
logger.info(f"timm={timm.__version__} torch={torch.__version__} device={cfg.DEVICE}")

# ---------------------------- TRANSFORMS ----------------------------
tf_train = transforms.Compose([
    transforms.RandomResizedCrop(cfg.IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# STRICT eval for ALL models: Resize directly to (H,W), no CenterCrop
tf_eval_strict = transforms.Compose([
    transforms.Resize((cfg.IMG_SIZE, cfg.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ---------------------------- DATA ----------------------------
train_ds = datasets.ImageFolder(cfg.DATA_PATHS["train"], transform=tf_train)
val_ds   = datasets.ImageFolder(cfg.DATA_PATHS["val"], transform=tf_eval_strict)
test_ds  = datasets.ImageFolder(cfg.DATA_PATHS["test"], transform=tf_eval_strict)

train_loader = DataLoader(
    train_ds, batch_size=cfg.BATCH, shuffle=True,
    num_workers=cfg.NUM_WORKERS, pin_memory=True
)
val_loader = DataLoader(
    val_ds, batch_size=cfg.BATCH, shuffle=False,
    num_workers=cfg.NUM_WORKERS, pin_memory=True
)
test_loader = DataLoader(
    test_ds, batch_size=cfg.BATCH, shuffle=False,
    num_workers=cfg.NUM_WORKERS, pin_memory=True
)

logger.info(f"Dataset sizes | train={len(train_ds)} val={len(val_ds)} test={len(test_ds)}")
logger.info("Eval protocol: STRICT resize to (IMG_SIZE, IMG_SIZE) with NO CenterCrop (applied to ALL models).")

# ---------------------------- HELPERS ----------------------------
def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

def estimate_state_dict_size_mb(model):
    total_params, _ = count_params(model)
    return (total_params * 4) / (1024**2)  # fp32 bytes

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    preds_all, tgts_all = [], []
    for x, y in loader:
        x = x.to(cfg.DEVICE)
        logits = model(x)
        preds = logits.argmax(1).cpu().numpy()
        preds_all.append(preds)
        tgts_all.append(y.numpy())
    preds_all = np.concatenate(preds_all)
    tgts_all = np.concatenate(tgts_all)

    return {
        "acc": float(accuracy_score(tgts_all, preds_all)),
        "bal_acc": float(balanced_accuracy_score(tgts_all, preds_all)),
        "f1_macro": float(f1_score(tgts_all, preds_all, average="macro")),
        "report": classification_report(tgts_all, preds_all, target_names=CLASS_NAMES, digits=4),
    }

def finetune_ce(model, run_dir):
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.LR, weight_decay=cfg.WD)
    crit = nn.CrossEntropyLoss(label_smoothing=cfg.LABEL_SMOOTHING)
    scaler = GradScaler(device="cuda", enabled=(cfg.AMP and cfg.DEVICE == "cuda"))

    best_f1 = -1.0
    best_state = None
    patience = 0

    for ep in range(cfg.EPOCHS):
        model.train()
        losses = []

        for x, y in tqdm(train_loader, desc=f"Epoch {ep+1}/{cfg.EPOCHS}", leave=False):
            x, y = x.to(cfg.DEVICE), y.to(cfg.DEVICE)
            opt.zero_grad(set_to_none=True)
            with autocast(device_type="cuda" if cfg.DEVICE == "cuda" else "cpu", enabled=scaler.is_enabled()):
                logits = model(x)
                loss = crit(logits, y)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            losses.append(loss.item())

        valm = evaluate(model, val_loader)
        logger.info(f"[VAL] ep={ep+1} loss={np.mean(losses):.4f} val_f1={valm['f1_macro']:.4f} val_acc={valm['acc']:.4f}")

        if valm["f1_macro"] > best_f1 + 1e-6:
            best_f1 = valm["f1_macro"]
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, os.path.join(run_dir, "best_ft.pth"))
            patience = 0
        else:
            patience += 1
            if patience >= cfg.PATIENCE:
                logger.info(f"Early stopping at epoch {ep+1}")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    testm = evaluate(model, test_loader)
    logger.info(f"[TEST] acc={testm['acc']:.4f} bal_acc={testm['bal_acc']:.4f} f1={testm['f1_macro']:.4f}")
    print("\nTEST REPORT:\n", testm["report"])
    return testm

# ---------------------------- FIXED build_model ----------------------------
def build_model(model_id, pretrained, run_dir):
    model = timm.create_model(
        model_id,
        pretrained=pretrained,
        num_classes=NUM_CLASSES,
        img_size=cfg.IMG_SIZE,
    ).to(cfg.DEVICE)

    pcfg = getattr(model, "pretrained_cfg", {}) or {}

    # Only report pretrained_cfg when pretrained=True
    tag = pcfg.get("tag", None) if pretrained else None
    dataset = pcfg.get("dataset", None) if pretrained else None
    url = pcfg.get("url", None) if pretrained else None
    pcfg_num_classes = pcfg.get("num_classes", None) if pretrained else None
    init_desc = "pretrained weights" if pretrained else "random initialization"

    total_params, trainable_params = count_params(model)
    est_mb = estimate_state_dict_size_mb(model)

    meta = {
        "model_id": model_id,
        "pretrained": pretrained,
        "init_desc": init_desc,
        "img_size": cfg.IMG_SIZE,
        "num_classes": NUM_CLASSES,
        "timm_version": timm.__version__,
        "torch_version": torch.__version__,
        "device": cfg.DEVICE,
        "eval_transform": "STRICT Resize((IMG_SIZE, IMG_SIZE)) no CenterCrop",
        "pretrained_cfg": {
            "dataset": dataset,
            "tag": tag,
            "url": url,
            "num_classes": pcfg_num_classes,
        },
        "params_total": int(total_params),
        "params_trainable": int(trainable_params),
        "estimated_fp32_state_dict_mb": float(est_mb),
        "timestamp": cfg.RUN_TS
    }

    with open(os.path.join(run_dir, "model_metadata.json"), "w") as f:
        json.dump(meta, f, indent=2)

    logger.info(f"Model: {model_id} | pretrained={pretrained} | init={init_desc}")
    logger.info(f"pretrained_cfg.tag={tag} dataset={dataset}")
    logger.info(f"Params: {total_params/1e6:.2f}M | est fp32 weights ~ {est_mb:.1f} MB")

    return model, meta

# ---------------------------- EXPERIMENT LIST ----------------------------
EXPS = [
    {"name": "BASE_scratch", "model_id": cfg.BASE_SCRATCH, "pretrained": False},
    {"name": "BASE_in22k_ft_in1k", "model_id": cfg.BASE_IN22K_FT_IN1K, "pretrained": True},
]

if cfg.RUN_SMALL_IN1K_BASELINE:
    EXPS.append({"name": "SMALL_in1k_baseline", "model_id": cfg.SMALL_IN1K, "pretrained": True})

results_csv = os.path.join(cfg.OUTPUT_DIR, f"results_{cfg.RUN_TS}.csv")
rows = []

logger.info("Experiments:")
for e in EXPS:
    logger.info(f" - {e['name']} | {e['model_id']} | pretrained={e['pretrained']}")

# ---------------------------- RUN ----------------------------
for exp in EXPS:
    run_dir = os.path.join(cfg.OUTPUT_DIR, exp["name"])
    os.makedirs(run_dir, exist_ok=True)

    logger.info("="*100)
    logger.info(f"Running: {exp['name']}")

    model, meta = build_model(exp["model_id"], exp["pretrained"], run_dir)
    testm = finetune_ce(model, run_dir)

    row = {
        "exp_name": exp["name"],
        "model_id": exp["model_id"],
        "pretrained": exp["pretrained"],
        "init_desc": meta.get("init_desc"),
        "eval_transform": meta.get("eval_transform"),
        "pretrained_cfg.dataset": meta["pretrained_cfg"]["dataset"],
        "pretrained_cfg.tag": meta["pretrained_cfg"]["tag"],
        "params_total_M": round(meta["params_total"] / 1e6, 3),
        "estimated_fp32_state_dict_mb": round(meta["estimated_fp32_state_dict_mb"], 2),
        "acc": testm["acc"],
        "bal_acc": testm["bal_acc"],
        "f1_macro": testm["f1_macro"],
        "img_size": cfg.IMG_SIZE,
        "epochs_max": cfg.EPOCHS,
        "label_smoothing": cfg.LABEL_SMOOTHING,
        "lr": cfg.LR,
        "wd": cfg.WD,
        "timm_version": timm.__version__,
        "torch_version": torch.__version__,
    }
    rows.append(row)

    write_header = not os.path.exists(results_csv)
    with open(results_csv, "a", newline="") as f:
        w = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            w.writeheader()
        w.writerow(row)

    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

logger.info(f"Done! Results saved: {results_csv}")
print("\nRESULTS CSV:", results_csv)

print("\n=== SUMMARY ===")
for r in rows:
    print(f"{r['exp_name']:20s} | f1={r['f1_macro']:.4f} acc={r['acc']:.4f} "
          f"| params={r['params_total_M']:.2f}M | init={r['init_desc']} | tag={r['pretrained_cfg.tag']}")

2025-12-04 04:02:07,921 - INFO - Classes (3): ['Alternaria', 'Healthy Leaf', 'straw_mite']
2025-12-04 04:02:07,922 - INFO - timm=1.0.19 torch=2.6.0+cu124 device=cuda
2025-12-04 04:02:08,368 - INFO - Dataset sizes | train=473 val=99 test=99
2025-12-04 04:02:08,369 - INFO - Eval protocol: STRICT resize to (IMG_SIZE, IMG_SIZE) with NO CenterCrop (applied to ALL models).
2025-12-04 04:02:08,372 - INFO - Experiments:
2025-12-04 04:02:08,372 - INFO -  - BASE_scratch | swinv2_base_window12to16_192to256 | pretrained=False
2025-12-04 04:02:08,373 - INFO -  - BASE_in22k_ft_in1k | swinv2_base_window12to16_192to256.ms_in22k_ft_in1k | pretrained=True
2025-12-04 04:02:08,374 - INFO -  - SMALL_in1k_baseline | swinv2_small_window16_256.ms_in1k | pretrained=True
2025-12-04 04:02:08,376 - INFO - Running: BASE_scratch
2025-12-04 04:02:09,842 - INFO - Model: swinv2_base_window12to16_192to256 | pretrained=False | init=random initialization
2025-12-04 04:02:09,843 - INFO - pretrained_cfg.tag=None dataset=No


TEST REPORT:
               precision    recall  f1-score   support

  Alternaria     0.9706    0.8919    0.9296        37
Healthy Leaf     0.9688    1.0000    0.9841        31
  straw_mite     0.9091    0.9677    0.9375        31

    accuracy                         0.9495        99
   macro avg     0.9495    0.9532    0.9504        99
weighted avg     0.9508    0.9495    0.9491        99



2025-12-04 04:22:02,694 - INFO - Running: BASE_in22k_ft_in1k
2025-12-04 04:22:04,588 - INFO - Model: swinv2_base_window12to16_192to256.ms_in22k_ft_in1k | pretrained=True | init=pretrained weights
2025-12-04 04:22:04,589 - INFO - pretrained_cfg.tag=ms_in22k_ft_in1k dataset=None
2025-12-04 04:22:04,589 - INFO - Params: 86.90M | est fp32 weights ~ 331.5 MB
2025-12-04 04:22:43,638 - INFO - [VAL] ep=1 loss=1.0336 val_f1=0.9215 val_acc=0.9192
2025-12-04 04:23:30,265 - INFO - [VAL] ep=2 loss=0.4072 val_f1=0.9703 val_acc=0.9697
2025-12-04 04:24:15,889 - INFO - [VAL] ep=3 loss=0.3417 val_f1=0.9901 val_acc=0.9899
2025-12-04 04:24:57,685 - INFO - [VAL] ep=4 loss=0.3280 val_f1=0.9802 val_acc=0.9798
2025-12-04 04:25:37,206 - INFO - [VAL] ep=5 loss=0.3332 val_f1=0.9802 val_acc=0.9798
2025-12-04 04:26:15,384 - INFO - [VAL] ep=6 loss=0.3238 val_f1=0.9705 val_acc=0.9697
2025-12-04 04:26:55,254 - INFO - [VAL] ep=7 loss=0.3439 val_f1=0.9602 val_acc=0.9596
2025-12-04 04:27:35,442 - INFO - [VAL] ep=8 loss=


TEST REPORT:
               precision    recall  f1-score   support

  Alternaria     1.0000    0.9459    0.9722        37
Healthy Leaf     0.9394    1.0000    0.9688        31
  straw_mite     1.0000    1.0000    1.0000        31

    accuracy                         0.9798        99
   macro avg     0.9798    0.9820    0.9803        99
weighted avg     0.9810    0.9798    0.9798        99



2025-12-04 04:29:03,408 - INFO - Running: SMALL_in1k_baseline
2025-12-04 04:29:04,709 - INFO - Model: swinv2_small_window16_256.ms_in1k | pretrained=True | init=pretrained weights
2025-12-04 04:29:04,709 - INFO - pretrained_cfg.tag=ms_in1k dataset=None
2025-12-04 04:29:04,711 - INFO - Params: 48.96M | est fp32 weights ~ 186.8 MB
2025-12-04 04:29:42,968 - INFO - [VAL] ep=1 loss=0.6146 val_f1=0.9608 val_acc=0.9596
2025-12-04 04:30:21,804 - INFO - [VAL] ep=2 loss=0.3475 val_f1=0.9901 val_acc=0.9899
2025-12-04 04:31:01,193 - INFO - [VAL] ep=3 loss=0.3243 val_f1=0.9901 val_acc=0.9899
2025-12-04 04:31:39,422 - INFO - [VAL] ep=4 loss=0.3025 val_f1=0.9705 val_acc=0.9697
2025-12-04 04:32:16,619 - INFO - [VAL] ep=5 loss=0.3052 val_f1=0.9802 val_acc=0.9798
2025-12-04 04:32:54,278 - INFO - [VAL] ep=6 loss=0.3071 val_f1=0.9702 val_acc=0.9697
2025-12-04 04:33:33,658 - INFO - [VAL] ep=7 loss=0.3010 val_f1=0.9803 val_acc=0.9798
2025-12-04 04:34:11,784 - INFO - [VAL] ep=8 loss=0.2936 val_f1=0.9801 val_


TEST REPORT:
               precision    recall  f1-score   support

  Alternaria     1.0000    1.0000    1.0000        37
Healthy Leaf     1.0000    1.0000    1.0000        31
  straw_mite     1.0000    1.0000    1.0000        31

    accuracy                         1.0000        99
   macro avg     1.0000    1.0000    1.0000        99
weighted avg     1.0000    1.0000    1.0000        99



2025-12-04 04:34:59,308 - INFO - Done! Results saved: ./swin_controlled_final_strict_eval/results_20251204_040207.csv



RESULTS CSV: ./swin_controlled_final_strict_eval/results_20251204_040207.csv

=== SUMMARY ===
BASE_scratch         | f1=0.9504 acc=0.9495 | params=86.90M | init=random initialization | tag=None
BASE_in22k_ft_in1k   | f1=0.9803 acc=0.9798 | params=86.90M | init=pretrained weights | tag=ms_in22k_ft_in1k
SMALL_in1k_baseline  | f1=1.0000 acc=1.0000 | params=48.96M | init=pretrained weights | tag=ms_in1k
