In [None]:
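# ============================================
# Large -> Seedlings transfer experiment (run this in a new cell)
# - Loads the weights (ckpt) obtained by training on the large dataset
# - Two strategies:
#     1) full_ft     : full fine-tuning
#     2) lp_unfreeze : linear probe (frozen backbone) -> gradual unfreezing
# - Supported models: convnext_tiny / resnet18
# - Outputs: metrics/train_curve.csv, confusion matrix & per-class report, best weights
# ============================================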

import os, math, random, time, json, warnings
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

import torchvision
from torchvision import transforms as T, models

# ---------------------------
# Settings the user must edit by hand
# ---------------------------
# Experiment configuration. Besides the main flow, `evaluate` and
# `train_one_phase` read CFG["mixed_precision"] from this module-level dict.
CFG = {
    "seed": 42,
    "strategy": "full_ft",        # "full_ft" | "lp_unfreeze"
    "model_name": "convnext_tiny",# "convnext_tiny" | "resnet18"
    "ckpt_path": "/kaggle/working/leafsnap_runs/convnext_tiny_leafsnap_phase4/convnext_tiny_leafsnap_phase4_best.pt",
    # ^^^ point this at the best weights (*.pt) from the large-dataset training stage

    # Data path (official Kaggle Plant Seedlings dataset)
    "data_root": "/kaggle/input/plant-seedlings-classification",
    "val_ratio": 0.1,

    # Training configuration
    "img_size": 224,
    "batch_size": 128,
    "num_workers": 2,
    "epochs_fullft": 30,          # epochs for the full fine-tuning strategy
    "epochs_lp": 5,               # epochs for the linear-probe phase (backbone frozen, head only)
    "epochs_unfreeze": 15,        # epochs for the gradual-unfreezing phase
    "lr_head": 1e-3,              # learning rate of the classification head
    "lr_backbone": 3e-4,          # learning rate of the backbone (full_ft or unfreezing phase)
    "weight_decay": 1e-2,
    "label_smoothing": 0.1,
    "mixed_precision": True,

    "out_root": "/kaggle/working/seedlings_transfer",
    "run_name": None,             # if None, a name is derived from strategy/model
}

# ---------------------------
# Basic utilities
# ---------------------------
# Module-level device used by the training and evaluation helpers below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=42):
    """Seed every RNG used by the experiment and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    cuDNN autotuning (``benchmark = True``) may pick a different convolution
    algorithm from run to run, which defeats the purpose of seeding, so it is
    switched off and deterministic algorithms are requested instead.

    Args:
        seed: Integer seed shared by all libraries.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents.

    Does nothing when the directory already exists.
    """
    p.mkdir(exist_ok=True, parents=True)

def build_seedlings_split(root: Path, val_ratio=0.1, seed=42):
    """Scan ``root/train`` and build a stratified train/validation split.

    Every sub-directory of ``root/train`` is treated as one class; class
    indices follow the alphabetical order of the directory names.

    Args:
        root: Dataset root; must contain a ``train`` directory.
        val_ratio: Fraction of the images assigned to the validation split.
        seed: Random state passed to ``StratifiedShuffleSplit``.

    Returns:
        ``(classes, train_items, val_items)`` where ``classes`` is the sorted
        list of class names and each item is a ``(path, class_index)`` tuple.

    Raises:
        AssertionError: If ``root/train`` does not exist.
        ValueError: If no image file is found under ``root/train``.
    """
    train_dir = root/"train"
    assert train_dir.exists(), f"Plant Seedlings 数据不存在：{train_dir}"
    classes = sorted(d.name for d in train_dir.iterdir() if d.is_dir())
    c2i = {c: i for i, c in enumerate(classes)}
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    # glob() yields files in filesystem order; sort so that the same seed
    # always produces the same split regardless of the machine.
    items = [
        (p, c2i[c])
        for c in classes
        for p in sorted((train_dir/c).glob("*.*"))
        if p.suffix.lower() in exts
    ]
    if not items:
        raise ValueError(f"No image files found under {train_dir}")
    y = np.array([label for _, label in items])
    idx = np.arange(len(items))
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    tr, va = next(sss.split(idx, y))
    return classes, [items[i] for i in tr], [items[i] for i in va]

class DS(Dataset):
    """Dataset over a list of ``(image_path, label)`` pairs.

    Each sample is returned as ``(transformed_image, label, path_string)``.
    """

    def __init__(self, items, tfm):
        self.items = items
        self.tfm = tfm

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        path, label = self.items[i]
        # Force three channels before handing the image to the transform.
        rgb = Image.open(path).convert("RGB")
        return self.tfm(rgb), label, str(path)

def get_tfms(size):
    """Build the training and validation transform pipelines.

    Both pipelines first resize to ``int(size * 1.14)`` and finish with the
    same per-channel normalisation. The training pipeline adds a random
    resized crop, horizontal flip, RandAugment and random erasing; the
    validation pipeline uses a centre crop only.

    Returns:
        ``(train_transform, valid_transform)``.
    """
    resize_to = int(size*1.14)
    mean = (0.485,0.456,0.406)
    std = (0.229,0.224,0.225)

    train_steps = [
        T.Resize(resize_to),
        T.RandomResizedCrop(size, scale=(0.85,1.0), ratio=(3/4,4/3)),
        T.RandomHorizontalFlip(0.5),
        T.RandAugment(2, 9),
        T.ToTensor(),
        T.RandomErasing(p=0.1),
        T.Normalize(mean, std),
    ]
    valid_steps = [
        T.Resize(resize_to),
        T.CenterCrop(size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
    return T.Compose(train_steps), T.Compose(valid_steps)

def accuracy(out, tgt):
    """Return the fraction of rows of ``out`` whose argmax over dim 1 equals ``tgt``.

    The result is a plain Python float computed without tracking gradients.
    """
    with torch.no_grad():
        hits = out.argmax(1).eq(tgt)
        return hits.float().mean().item()

@torch.no_grad()
def evaluate(model, loader, num_classes):
    """Run ``model`` over ``loader`` in eval mode and collect metrics.

    The loss is plain cross-entropy (no label smoothing). Loss and accuracy
    are averaged per sample; F1 is macro-averaged.

    Returns:
        ``(mean_loss, mean_accuracy, macro_f1, confusion_matrix, y_true, y_pred)``
        where the last three are NumPy arrays.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    pred_chunks, true_chunks = [], []
    total_loss = total_acc = seen = 0

    for images, targets, _ in loader:
        images = images.to(device)
        targets = targets.to(device)
        with torch.amp.autocast("cuda", enabled=CFG["mixed_precision"]):
            logits = model(images)
            batch_loss = criterion(logits, targets)
        batch = images.size(0)
        # Weight batch means by batch size so the final averages are per sample.
        total_loss += batch_loss.item()*batch
        total_acc += accuracy(logits, targets)*batch
        seen += batch
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        true_chunks.append(targets.cpu().numpy())

    pred = np.concatenate(pred_chunks)
    true = np.concatenate(true_chunks)
    macro_f1 = f1_score(true, pred, average="macro")
    cm = confusion_matrix(true, pred, labels=list(range(num_classes)))
    return total_loss/seen, total_acc/seen, macro_f1, cm, true, pred

def save_cm_and_report(cm, y_true, y_pred, class_names, out_dir: Path, epoch: int):
    """Write the confusion matrix and the per-class report of one epoch as CSV.

    Two files are created in ``out_dir`` (created if missing):
    ``confusion_matrix_epochNNN.csv`` and ``per_class_report_epochNNN.csv``.
    """
    ensure_dir(out_dir)

    # Confusion matrix: the first column carries the true-class name of each row.
    cm_frame = pd.DataFrame(cm, columns=class_names)
    cm_frame.insert(0, "true\\pred", class_names)
    cm_path = out_dir / f"confusion_matrix_epoch{epoch:03d}.csv"
    cm_frame.to_csv(cm_path, index=False)

    label_ids = list(range(len(class_names)))
    report = classification_report(
        y_true, y_pred, labels=label_ids,
        target_names=class_names, output_dict=True, zero_division=0,
    )
    report_frame = pd.DataFrame(report).T.reset_index().rename(columns={"index":"class"})
    report_path = out_dir / f"per_class_report_epoch{epoch:03d}.csv"
    report_frame.to_csv(report_path, index=False)

def save_curve(curves, out_csv: Path):
    """Write the metric history ``curves`` (column name -> list) to ``out_csv`` without an index column."""
    history = pd.DataFrame(curves)
    history.to_csv(out_csv, index=False)

# ---------------------------
# 模型与“只加载 backbone”工具
# ---------------------------
def build_model(num_classes, name="convnext_tiny"):
    """Create an untrained torchvision model with a fresh ``num_classes``-way head.

    Args:
        num_classes: Output size of the new final linear layer.
        name: ``"convnext_tiny"`` or ``"resnet18"`` (case-insensitive).

    Raises:
        ValueError: For any other model name.
    """
    key = name.lower()
    if key not in ("convnext_tiny", "resnet18"):
        raise ValueError(f"Unsupported model_name: {name}")

    if key == "resnet18":
        net = models.resnet18(weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net

    net = models.convnext_tiny(weights=None)
    head_in = net.classifier[2].in_features
    net.classifier[2] = nn.Linear(head_in, num_classes)
    return net

def load_backbone_from_ckpt(model, ckpt_path, model_name):
    """Load only the backbone weights of a checkpoint into ``model``.

    The checkpoint may be ``{"state_dict": ..., ...}`` or a bare state dict.
    Keys of the old classification head (``classifier.*`` for convnext_tiny,
    ``fc.*`` otherwise) are discarded; the remaining keys are copied when
    both name and shape match the target model. Everything else keeps the
    value ``model`` already has.

    ``model_name`` is compared case-insensitively, matching ``build_model``.
    A leading ``module.`` prefix is stripped from checkpoint keys that do not
    otherwise exist in ``model``, and a warning is printed when nothing
    matched (the model would silently train from scratch).

    Args:
        model: Target model with its new head already installed.
        ckpt_path: Path of the checkpoint file.
        model_name: ``"convnext_tiny"`` or ``"resnet18"``.

    Returns:
        The same ``model`` instance, updated in place.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd = ckpt.get("state_dict", ckpt)
    new_sd = model.state_dict()

    head_prefix = "classifier." if model_name.lower() == "convnext_tiny" else "fc."
    wrapper_prefix = "module."

    matched = {}
    dropped = 0
    for key, value in sd.items():
        name = key
        if name not in new_sd and name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        if name.startswith(head_prefix):
            dropped += 1
            continue
        # Keep only tensors whose name and shape fit the new model.
        if name in new_sd and new_sd[name].shape == value.shape:
            matched[name] = value

    missing = [k for k in new_sd if k not in matched]
    print(f"[ckpt] load backbone: matched={len(matched)} | missing(new head etc.)={len(missing)} | dropped_old_head={dropped}")
    if not matched:
        print("[ckpt] WARNING: no checkpoint tensor matched the model; backbone stays randomly initialised")
    model.load_state_dict({**new_sd, **matched})
    return model

# ---------------------------
# 冻结/解冻工具（BN 细节）
# ---------------------------
def set_backbone_trainable(model, model_name, trainable: bool):
    """Freeze or unfreeze every backbone parameter of ``model``.

    For ``"convnext_tiny"`` the backbone is ``model.features``; for any other
    name it is ``conv1``, ``bn1`` and ``layer1``..``layer4``. BatchNorm2d and
    LayerNorm modules inside the backbone are switched to train mode when
    ``trainable`` is true and to eval mode otherwise.
    """
    if model_name == "convnext_tiny":
        parts = [model.features]
    else:
        parts = [model.conv1, model.bn1,
                 model.layer1, model.layer2, model.layer3, model.layer4]

    norm_types = (nn.BatchNorm2d, nn.LayerNorm)
    for part in parts:
        for param in part.parameters():
            param.requires_grad = trainable
        for sub in part.modules():
            if not isinstance(sub, norm_types):
                continue
            if trainable:
                sub.train()
            else:
                sub.eval()

def unfreeze_last_stages(model, model_name, stages=1):
    """Freeze the whole backbone, then unfreeze its last ``stages`` blocks.

    Used by the gradual-unfreezing phase. For ``"convnext_tiny"`` the blocks
    are the trailing entries of ``model.features``; for any other name they
    are taken from ``[layer1, layer2, layer3, layer4]``. Norm layers in the
    unfrozen blocks are put back into train mode.

    ``stages <= 0`` leaves the backbone fully frozen. (Previously a value of
    0 unfroze *every* ResNet layer because ``layers[-0:]`` is the whole list.)

    NOTE(review): ``stages`` counts entries of ``model.features`` for
    ConvNeXt; if that Sequential interleaves stages with downsampling layers,
    values above 1 do not map one-to-one to ConvNeXt stages — confirm.
    """
    set_backbone_trainable(model, model_name, trainable=False)  # freeze everything first
    if stages <= 0:
        return

    if model_name == "convnext_tiny":
        blocks = list(model.features)
        norm_types = (nn.BatchNorm2d, nn.LayerNorm)
    else:
        # resnet18: layer1/2/3/4 (unfreeze the last n)
        blocks = [model.layer1, model.layer2, model.layer3, model.layer4]
        norm_types = (nn.BatchNorm2d,)

    for block in blocks[-stages:]:
        for p in block.parameters():
            p.requires_grad = True
        for mm in block.modules():
            if isinstance(mm, norm_types):
                mm.train()

# ---------------------------
# 优化器 & 训练循环
# ---------------------------
def make_optimizer(model, model_name, lr_backbone, lr_head, wd=1e-2, full=False):
    """Build an AdamW optimizer over the head, optionally plus the backbone.

    Args:
        model: Model exposing ``classifier``/``features`` (convnext_tiny) or
            ``fc`` (any other name).
        model_name: Selects which attributes form the head and the backbone.
        lr_backbone: Learning rate of the backbone group (used when ``full``).
        lr_head: Learning rate of the head group.
        wd: Weight decay applied to every group.
        full: When true, two groups are created (backbone first, then head);
            otherwise only the head is optimised.
    """
    if model_name == "convnext_tiny":
        head = list(model.classifier.parameters())
        backbone = list(model.features.parameters())
    else:
        head = list(model.fc.parameters())
        backbone = [param for name, param in model.named_parameters()
                    if not name.startswith("fc.")]

    groups = [{"params": head, "lr": lr_head}]
    if full:
        # Discriminative learning rates: the backbone group comes first.
        groups.insert(0, {"params": backbone, "lr": lr_backbone})
    return torch.optim.AdamW(groups, weight_decay=wd)

def _keep_frozen_batchnorm_in_eval(model):
    """Put BatchNorm layers whose own parameters are all frozen back into eval mode."""
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            own = list(module.parameters(recurse=False))
            if own and not any(p.requires_grad for p in own):
                module.eval()


def train_one_phase(model, train_loader, valid_loader, epochs, opt, crit, curves, out_met_dir, class_names, start_epoch=1):
    """Train ``model`` for ``epochs`` epochs and keep the best checkpoint.

    Each epoch runs one pass over ``train_loader`` with AMP and gradient
    clipping, steps a cosine learning-rate schedule, evaluates on
    ``valid_loader``, appends to ``curves`` and writes the curve CSV plus the
    confusion matrix / per-class report into ``out_met_dir``. The state dict
    with the highest validation macro-F1 *within this call* is saved to
    ``out_met_dir.parent / f"{run_name}_best.pt"``.

    Two training-loop defects are fixed here:
    * gradients are unscaled before clipping, so the 1.0 threshold applies to
      the true gradient norm rather than the loss-scaled one;
    * ``model.train()`` would re-enable running-statistics updates in frozen
      BatchNorm layers, so those layers are switched back to eval mode.

    Args:
        epochs: Number of epochs to run in this phase.
        opt: Optimizer; a fresh cosine schedule is attached to it.
        crit: Training loss.
        curves: Dict of lists, extended in place.
        start_epoch: Number of the first epoch (used for logging and file names).

    Returns:
        ``(best_f1, best_epoch, best_path)``; ``(-1, -1, best_path)`` when
        ``epochs`` is 0.
    """
    scaler = torch.amp.GradScaler("cuda", enabled=CFG["mixed_precision"])
    best_f1=-1; best_ep=-1; best_path = out_met_dir.parent / f"{run_name}_best.pt"
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1,epochs))
    for e in range(start_epoch, start_epoch+epochs):
        model.train()
        # model.train() flips every module; frozen BN must keep its statistics.
        _keep_frozen_batchnorm_in_eval(model)
        t0=time.time()
        loss_sum=acc_sum=n=0
        for x,y,_ in train_loader:
            x=x.to(device); y=y.to(device)
            opt.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=CFG["mixed_precision"]):
                lo = model(x)
                ls = crit(lo,y)
            scaler.scale(ls).backward()
            # Clip real gradients, not loss-scaled ones.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt); scaler.update()
            loss_sum += ls.item()*x.size(0)
            acc_sum  += accuracy(lo,y)*x.size(0)
            n += x.size(0)
        sched.step()

        tr_loss, tr_acc = loss_sum/n, acc_sum/n
        va_loss, va_acc, va_f1, cm, y_true, y_pred = evaluate(model, valid_loader, num_classes=len(class_names))
        print(f"[{run_name}] Epoch {e:02d}/{start_epoch+epochs-1:02d} | "
              f"train_loss={tr_loss:.4f} acc={tr_acc:.4f} | val_loss={va_loss:.4f} acc={va_acc:.4f} f1={va_f1:.4f} | time={int(time.time()-t0)}s")

        curves["epoch"].append(e)
        curves["train_loss"].append(tr_loss)
        curves["train_acc"].append(tr_acc)
        curves["val_acc"].append(va_acc)
        curves["val_f1"].append(va_f1)
        save_curve(curves, out_met_dir/"train_curve.csv")
        save_cm_and_report(cm, y_true, y_pred, class_names, out_met_dir, e)

        if va_f1 > best_f1:
            best_f1, best_ep = va_f1, e
            torch.save({"state_dict": model.state_dict(), "classes": class_names}, best_path)
    return best_f1, best_ep, best_path

# ---------------------------
# 主流程
# ---------------------------
import shutil  # used to preserve the linear-probe checkpoint across phases

set_seed(CFG["seed"])
root = Path(CFG["data_root"])
classes, tr_items, va_items = build_seedlings_split(root, CFG["val_ratio"], CFG["seed"])
num_classes = len(classes)
train_tfm, valid_tfm = get_tfms(CFG["img_size"])
train_ds, valid_ds = DS(tr_items, train_tfm), DS(va_items, valid_tfm)
train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                          num_workers=CFG["num_workers"], pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_ds, batch_size=CFG["batch_size"], shuffle=False,
                          num_workers=CFG["num_workers"], pin_memory=True)

# Run name and output directories
if CFG["run_name"] is None:
    run_name = f"seedlings_{CFG['model_name']}_{CFG['strategy']}"
else:
    run_name = CFG["run_name"]
out_dir = Path(CFG["out_root"])/run_name
met_dir = out_dir/"metrics"
ensure_dir(met_dir)

# Model: create a fresh Seedlings head, then load only the backbone weights
model = build_model(num_classes, CFG["model_name"]).to(device)
model = load_backbone_from_ckpt(model, CFG["ckpt_path"], CFG["model_name"])

# Record basic run/resource information
with open(met_dir/"resource.json", "w") as f:
    json.dump({
        "strategy": CFG["strategy"],
        "model": CFG["model_name"],
        "img_size": CFG["img_size"],
        "batch_size": CFG["batch_size"],
        "params": sum(p.numel() for p in model.parameters())
    }, f)

curves = {"epoch": [], "train_loss": [], "train_acc": [], "val_acc": [], "val_f1": []}
crit = nn.CrossEntropyLoss(label_smoothing=CFG["label_smoothing"])

if CFG["strategy"] == "full_ft":
    # Full fine-tuning: backbone trainable, grouped learning rates (smaller for the backbone)
    set_backbone_trainable(model, CFG["model_name"], trainable=True)
    opt = make_optimizer(model, CFG["model_name"], CFG["lr_backbone"], CFG["lr_head"],
                         wd=CFG["weight_decay"], full=True)
    best_f1, best_ep, best_path = train_one_phase(
        model, train_loader, valid_loader, CFG["epochs_fullft"], opt, crit, curves, met_dir, classes
    )

elif CFG["strategy"] == "lp_unfreeze":
    # STEP 1, linear probe: freeze the backbone and train the head only
    set_backbone_trainable(model, CFG["model_name"], trainable=False)
    opt = make_optimizer(model, CFG["model_name"], CFG["lr_backbone"], CFG["lr_head"],
                         wd=CFG["weight_decay"], full=False)
    best_f1, best_ep, best_path = train_one_phase(
        model, train_loader, valid_loader, CFG["epochs_lp"], opt, crit, curves, met_dir, classes, start_epoch=1
    )

    # Phase 2 writes to the same checkpoint path and restarts its own "best"
    # tracking, so it overwrites the phase-1 file even when it scores lower.
    # Keep a copy so the saved weights always match the reported best F1.
    lp_ckpt = best_path.with_name(best_path.stem + "_lp.pt")
    if best_path.exists():
        shutil.copy2(best_path, lp_ckpt)

    # STEP 2, gradual unfreezing: unfreeze the last n stages (1 for ConvNeXt, 2 for ResNet)
    unfreeze_last_stages(model, CFG["model_name"], stages=1 if CFG["model_name"]=="convnext_tiny" else 2)
    opt = make_optimizer(model, CFG["model_name"], CFG["lr_backbone"], CFG["lr_head"],
                         wd=CFG["weight_decay"], full=True)
    b2_f1, b2_ep, best_path = train_one_phase(
        model, train_loader, valid_loader, CFG["epochs_unfreeze"], opt, crit, curves, met_dir, classes,
        start_epoch=(CFG["epochs_lp"]+1)
    )
    if b2_f1 > best_f1:
        best_f1, best_ep = b2_f1, b2_ep
    elif lp_ckpt.exists():
        # The linear probe stayed best: restore its weights as the final checkpoint.
        shutil.copy2(lp_ckpt, best_path)

else:
    raise ValueError("CFG['strategy'] 必须是 'full_ft' 或 'lp_unfreeze'")

print(f"\nDone. Best macro-F1={best_f1:.4f} @ epoch {best_ep}.")
print("Outputs ->", out_dir)


### Experiment 3 – In-Domain vs. Generic Transfer Learning

#### This experiment investigates whether pretraining on a large, in-domain dataset (LeafSnap) offers a better initialization for a smaller target task (Plant Seedlings) than generic ImageNet pretraining. It compares full fine-tuning and gradual unfreezing to determine which transfer learning strategy maximizes performance.

### 1. Configuration
#### This cell contains all the hyperparameters and settings for the experiment in a single dictionary, `CFG`.

In [None]:

# Central experiment configuration. The helpers defined in later cells read
# this module-level dict directly (e.g. CFG["mixed_precision"] in `evaluate`,
# CFG["run_name"] in `train_one_phase`).
CFG = {
    # --- General ---
    "seed": 42,
    "strategy": "lp_unfreeze",       # Options: "full_ft" | "lp_unfreeze"
    "model_name": "convnext_tiny",    # Options: "convnext_tiny" | "resnet18"
    
    # --- Paths ---
    # Path to the best weights from your pre-training on the large dataset
    "ckpt_path": "/kaggle/working/leafsnap_runs/convnext_tiny_leafsnap_phase4/convnext_tiny_leafsnap_phase4_best.pt",
    # Path to the Plant Seedlings dataset
    "data_root": "/kaggle/input/plant-seedlings-classification",
    # Root directory for saving outputs (weights, logs, metrics)
    "out_root": "/kaggle/working/seedlings_transfer_convnext_tiny",
    # Name for this specific run. If None, it's auto-generated.
    # NOTE(review): `train_one_phase` formats CFG['run_name'] straight into the
    # checkpoint file name and log prefix; confirm the main cell replaces None
    # with the generated name before training starts.
    "run_name": None,                 

    # --- Data & Splitting ---
    "val_ratio": 0.1,                 # Ratio of data to be used for validation

    # --- Training Parameters ---
    "img_size": 224,
    "batch_size": 128,
    "num_workers": 2,
    "epochs_fullft": 30,              # Number of epochs for the "full fine-tuning" strategy
    "epochs_lp": 5,                   # Number of epochs for the "linear probing" phase
    "epochs_unfreeze": 15,            # Number of epochs for the "gradual unfreezing" phase
    "lr_head": 1e-3,                  # Learning rate for the classification head
    "lr_backbone": 3e-4,              # Learning rate for the backbone (used in full_ft or unfreeze phase)
    "weight_decay": 1e-2,
    "label_smoothing": 0.1,
    "mixed_precision": True,          # Enable/disable Automatic Mixed Precision (AMP)
}

### 2. Imports
#### This cell imports all the necessary libraries for the project, including PyTorch, Torchvision, NumPy, Pandas, and Scikit-learn.

In [None]:
# =========================================================================================
# Imports and Environment Setup
# =========================================================================================
import os, math, random, time, json, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

import torchvision
from torchvision import transforms as T, models

# Suppress warnings for a cleaner output
warnings.filterwarnings("ignore")

# Select the compute device once: CUDA if available, otherwise CPU.
# The evaluation and training helpers in later cells read this module-level name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### 3. Core Utilities
#### This section contains essential helper functions for reproducibility, file handling, and data splitting.

In [None]:
# =========================================================================================
# Core Utility Functions
# =========================================================================================

def set_seed(seed=42):
    """Sets the random seed for reproducibility across all relevant libraries.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and configures cuDNN for deterministic behaviour.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning (benchmark=True) may choose a different convolution
    # algorithm on every run, which is a source of non-determinism. Disable it
    # and request deterministic algorithms so seeded runs are repeatable.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def ensure_dir(p: Path) -> None:
    """Make sure directory ``p`` exists, creating missing parents as needed."""
    p.mkdir(exist_ok=True, parents=True)

def build_seedlings_split(root: Path, val_ratio=0.1, seed=42):
    """
    Scans the Plant Seedlings dataset directory and creates a stratified train/validation split.

    Each sub-directory of ``root/train`` is one class; class indices follow the
    alphabetical order of the directory names. Files are collected in sorted
    order so that the same ``seed`` always yields the same split.

    Args:
        root: Dataset root; must contain a ``train`` directory.
        val_ratio: Fraction of the images assigned to the validation split.
        seed: Random state passed to ``StratifiedShuffleSplit``.

    Returns: A tuple containing (class_names, train_items, validation_items),
    where every item is a ``(path, class_index)`` pair.

    Raises:
        AssertionError: If ``root/train`` does not exist.
        ValueError: If no image file is found under ``root/train``.
    """
    train_dir = root / "train"
    assert train_dir.exists(), f"Plant Seedlings data not found at: {train_dir}"

    # Get class names from subdirectories
    classes = sorted(d.name for d in train_dir.iterdir() if d.is_dir())
    class_to_idx = {c: i for i, c in enumerate(classes)}

    # Collect all image paths and their labels. glob() returns files in
    # filesystem order, which varies between machines, so sort explicitly.
    valid_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    items = [
        (p, class_to_idx[c])
        for c in classes
        for p in sorted((train_dir / c).glob("*.*"))
        if p.suffix.lower() in valid_exts
    ]
    if not items:
        raise ValueError(f"No image files found under {train_dir}")

    # Perform stratified split
    labels = np.array([label for _, label in items])
    indices = np.arange(len(items))
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    train_idx, val_idx = next(sss.split(indices, labels))

    train_items = [items[i] for i in train_idx]
    val_items = [items[i] for i in val_idx]

    return classes, train_items, val_items

### 4. Data Handling
#### Here, we define the PyTorch Dataset class to load images and the transforms for data augmentation and normalization.

In [None]:
# =========================================================================================
# Data Preparation: Dataset Class and Image Transforms
# =========================================================================================

class PlantDataset(Dataset):
    """Dataset over ``(image_path, label)`` pairs of plant seedling images.

    Indexing returns ``(transformed_image, label, path_string)``.
    """

    def __init__(self, items, transform):
        self.items, self.transform = items, transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        sample_path, target = self.items[i]
        # Convert to RGB so the transform always receives three channels.
        rgb = Image.open(sample_path).convert("RGB")
        return self.transform(rgb), target, str(sample_path)

def get_transforms(img_size):
    """
    Builds the training and validation image pipelines.

    Both resize to ``int(img_size * 1.14)`` first and end with the same
    per-channel normalization. Training adds a random resized crop, horizontal
    flip, RandAugment and random erasing; validation only centre-crops.

    Returns: ``(train_transform, valid_transform)``.
    """
    resized = int(img_size * 1.14)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Training: geometric and colour augmentation before tensor conversion,
    # random erasing afterwards.
    train_steps = [
        T.Resize(resized),
        T.RandomResizedCrop(img_size, scale=(0.85, 1.0), ratio=(3./4., 4./3.)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandAugment(num_ops=2, magnitude=9),
        T.ToTensor(),
        T.RandomErasing(p=0.1),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]

    # Validation: deterministic resize + centre crop only.
    valid_steps = [
        T.Resize(resized),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]

    return T.Compose(train_steps), T.Compose(valid_steps)

### 5. Model Architecture

In [1]:
# =========================================================================================
# Model Definition and Pre-trained Backbone Loader
# =========================================================================================

def build_model(num_classes, model_name="convnext_tiny"):
    """
    Creates an untrained torchvision model whose final linear layer outputs ``num_classes`` logits.

    ``model_name`` is matched case-insensitively against ``"convnext_tiny"``
    and ``"resnet18"``; anything else raises ``ValueError``.
    """
    key = model_name.lower()
    if key not in ("convnext_tiny", "resnet18"):
        raise ValueError(f"Unsupported model_name: {model_name}")

    if key == "resnet18":
        net = models.resnet18(weights=None)
        # Swap the final fully connected layer for a new head
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net

    net = models.convnext_tiny(weights=None)
    # Swap the last classifier layer for a new head
    net.classifier[2] = nn.Linear(net.classifier[2].in_features, num_classes)
    return net

def load_backbone_from_ckpt(model, ckpt_path, model_name):
    """
    Copies backbone weights from a checkpoint into ``model`` and leaves the head untouched.

    The checkpoint is either ``{'state_dict': ...}`` or a bare state dict.
    Keys of the old head (``classifier.*`` for convnext_tiny, ``fc.*``
    otherwise) are removed; the remaining tensors are loaded when their name
    and shape match the target model. A short summary is printed.

    Returns: the same ``model`` instance.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    source_weights = checkpoint.get("state_dict", checkpoint)
    target_weights = model.state_dict()

    # Prefix of the old classification head that must not be transferred
    if model_name.lower() == "convnext_tiny":
        head_prefixes = ("classifier.",)
    else:  # Assumes resnet18
        head_prefixes = ("fc.",)

    # Strip the old head from the loaded weights
    head_keys = [key for key in source_weights if key.startswith(head_prefixes)]
    for key in head_keys:
        del source_weights[key]

    # Keep tensors that agree with the new model in both name and shape
    compatible = {}
    for key, tensor in source_weights.items():
        if key in target_weights and target_weights[key].shape == tensor.shape:
            compatible[key] = tensor

    absent = [key for key in target_weights if key not in compatible]

    print(f"[CKPT INFO] Loading backbone weights:")
    print(f"  - Matched layers: {len(compatible)}")
    print(f"  - Missing (new head, etc.): {len(absent)}")
    print(f"  - Dropped (old head): {len(head_keys)}")

    # Unmatched entries fall back to the model's current values
    model.load_state_dict({**target_weights, **compatible})
    return model

### 6. Fine-Tuning & Layer Freezing Utilities

In [None]:
# =========================================================================================
# Model Freezing/Unfreezing Utilities
# =========================================================================================

def set_backbone_trainable(model, model_name, trainable: bool):
    """
    Freezes or unfreezes all backbone parameters of ``model``.

    The backbone is ``model.features`` for convnext_tiny and the stem plus
    ``layer1``..``layer4`` otherwise. BatchNorm2d/LayerNorm modules inside it
    are put in train mode when ``trainable`` is true, in eval mode otherwise.
    """
    if model_name.lower() == "convnext_tiny":
        parts = (model.features,)
    else:  # Assumes resnet18
        parts = (model.conv1, model.bn1, model.relu, model.maxpool,
                 model.layer1, model.layer2, model.layer3, model.layer4)

    norm_types = (nn.BatchNorm2d, nn.LayerNorm)
    for part in parts:
        for weight in part.parameters():
            weight.requires_grad = trainable
        norm_layers = [m for m in part.modules() if isinstance(m, norm_types)]
        for layer in norm_layers:
            if trainable:
                layer.train()
            else:
                layer.eval()

def unfreeze_last_stages(model, model_name, stages_to_unfreeze=1):
    """
    Freezes the whole backbone, then unfreezes its last N blocks.

    Used in the gradual unfreezing phase. For convnext_tiny the blocks are the
    trailing entries of ``model.features``; otherwise they are taken from
    ``[layer1, layer2, layer3, layer4]``. Norm layers inside the unfrozen
    blocks are returned to train mode.

    ``stages_to_unfreeze <= 0`` keeps the backbone fully frozen. (A plain
    ``[-0:]`` slice would select *every* block and unfreeze the whole backbone.)

    NOTE(review): for ConvNeXt the count refers to entries of
    ``model.features``; if those interleave stages with downsampling layers,
    values above 1 do not correspond one-to-one to ConvNeXt stages — confirm.
    """
    # First, freeze the entire backbone
    set_backbone_trainable(model, model_name, trainable=False)

    print(f"Unfreezing the last {stages_to_unfreeze} stage(s) of the {model_name} backbone...")

    if stages_to_unfreeze <= 0:
        return

    # Identify the target blocks
    if model_name.lower() == "convnext_tiny":
        candidates = list(model.features)
    else:  # Assumes resnet18
        candidates = [model.layer1, model.layer2, model.layer3, model.layer4]
    target_blocks = candidates[-stages_to_unfreeze:]

    # Unfreeze parameters and set norm layers to train mode for the target blocks
    for block in target_blocks:
        for param in block.parameters():
            param.requires_grad = True
        for module in block.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
                module.train()

### 7. Training & Evaluation

In [None]:
# =========================================================================================
# Optimizer, Training Loop, and Evaluation
# =========================================================================================

def make_optimizer(model, model_name, lr_backbone, lr_head, wd=1e-2, full_ft=False):
    """
    Builds an AdamW optimizer over the head and, when ``full_ft`` is true, the backbone.

    With ``full_ft`` two parameter groups are created, backbone first
    (``lr_backbone``) and head second (``lr_head``); otherwise only the head
    group exists. ``wd`` is the weight decay of every group.
    """
    if model_name.lower() == "convnext_tiny":
        head = list(model.classifier.parameters())
        backbone = list(model.features.parameters())
    else:  # Assumes resnet18
        head = list(model.fc.parameters())
        backbone = [param for name, param in model.named_parameters()
                    if not name.startswith("fc.")]

    head_group = {"params": head, "lr": lr_head}
    if full_ft:
        groups = [{"params": backbone, "lr": lr_backbone}, head_group]
    else:
        groups = [head_group]

    return torch.optim.AdamW(groups, weight_decay=wd)

@torch.no_grad()
def evaluate(model, loader, num_classes):
    """Runs ``model`` over ``loader`` in eval mode and computes metrics.

    Loss is plain cross-entropy; loss and accuracy are per-sample averages and
    F1 is macro-averaged.

    Returns: ``(mean_loss, accuracy, macro_f1, confusion_matrix, y_true, y_pred)``.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    pred_chunks, true_chunks = [], []
    total_loss = correct = seen = 0

    for images, targets, _ in loader:
        images = images.to(device)
        targets = targets.to(device)

        with torch.amp.autocast("cuda", enabled=CFG["mixed_precision"]):
            logits = model(images)
            batch_loss = loss_fn(logits, targets)

        batch = images.size(0)
        # Batch loss is a mean, so weight it by the batch size.
        total_loss += batch_loss.item() * batch
        correct += (logits.argmax(1) == targets).float().sum().item()
        seen += batch

        pred_chunks.append(logits.argmax(1).cpu().numpy())
        true_chunks.append(targets.cpu().numpy())

    preds = np.concatenate(pred_chunks)
    trues = np.concatenate(true_chunks)

    macro_f1 = f1_score(trues, preds, average="macro")
    cm = confusion_matrix(trues, preds, labels=list(range(num_classes)))

    return total_loss / seen, correct / seen, macro_f1, cm, trues, preds

def save_metrics(cm, y_true, y_pred, class_names, out_dir: Path, epoch: int):
    """Writes the confusion matrix and the per-class classification report of one epoch.

    Creates ``confusion_matrix_epochNNN.csv`` (class names as row index and
    column header) and ``per_class_report_epochNNN.csv`` inside ``out_dir``,
    which is created if missing.
    """
    ensure_dir(out_dir)

    # Confusion matrix, labelled on both axes
    cm_path = out_dir / f"confusion_matrix_epoch{epoch:03d}.csv"
    pd.DataFrame(cm, columns=class_names, index=class_names).to_csv(cm_path)

    # Per-class precision/recall/F1 report
    label_ids = list(range(len(class_names)))
    report = classification_report(
        y_true, y_pred, labels=label_ids,
        target_names=class_names, output_dict=True, zero_division=0,
    )
    report_frame = pd.DataFrame(report).T.reset_index().rename(columns={"index": "class"})
    report_path = out_dir / f"per_class_report_epoch{epoch:03d}.csv"
    report_frame.to_csv(report_path, index=False)
    
def save_curve(curves, out_csv: Path):
    """Writes the training history ``curves`` (column name -> list of values) to ``out_csv`` without an index column."""
    history = pd.DataFrame(curves)
    history.to_csv(out_csv, index=False)


def train_one_phase(model, train_loader, valid_loader, epochs, optimizer, criterion,
                    curves, out_metrics_dir, class_names, start_epoch=1):
    """Train `model` for one phase and checkpoint the best validation macro-F1.

    Each epoch runs a mixed-precision training pass with gradient clipping,
    evaluates on `valid_loader`, appends the metrics to `curves`, rewrites
    ``train_curve.csv`` and saves that epoch's confusion matrix / per-class
    report under `out_metrics_dir`.

    Args:
        model: Network to train (already on the target device).
        train_loader: Iterable of ``(x, y, _)`` training batches.
        valid_loader: Iterable of ``(x, y, _)`` validation batches.
        epochs: Number of epochs to run in this phase.
        optimizer: Optimizer to step; a cosine schedule over `epochs` is attached.
        criterion: Training loss.
        curves: Dict of metric lists shared across phases; mutated in place.
        out_metrics_dir: Directory for per-epoch metric files. The best
            checkpoint is written to its parent directory.
        class_names: Class labels stored in the checkpoint and used for reports.
        start_epoch: Epoch number of the first epoch in this phase.

    Returns:
        ``(best_f1, best_epoch, best_path)``. When `curves` already holds
        results from an earlier phase and the checkpoint file exists, the
        running best starts from that history, so the returned values describe
        the model actually stored at `best_path`.
    """
    scaler = torch.amp.GradScaler(enabled=CFG["mixed_precision"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs))

    best_path = out_metrics_dir.parent / f"{CFG['run_name']}_best.pt"
    best_f1 = -1
    best_epoch = -1

    # Every phase of a run writes to the same checkpoint file. Start from the
    # best score already recorded in `curves` so that a later phase only
    # replaces the checkpoint when it really beats the earlier one.
    prior_f1 = curves["val_f1"]
    if prior_f1 and best_path.exists():
        best_idx = max(range(len(prior_f1)), key=prior_f1.__getitem__)
        best_f1, best_epoch = prior_f1[best_idx], curves["epoch"][best_idx]

    num_total_epochs = start_epoch + epochs - 1

    for e in range(start_epoch, start_epoch + epochs):
        model.train()
        t0 = time.time()
        loss_sum = acc_sum = n = 0

        for x, y, _ in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda", enabled=CFG["mixed_precision"]):
                logits = model(x)
                loss = criterion(logits, y)

            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale at this point;
            # unscale them first so the clipping threshold applies to the true
            # gradient norm. scaler.step() knows they were already unscaled.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

            bs = x.size(0)
            loss_sum += loss.item() * bs
            acc_sum += (logits.argmax(1) == y).float().sum().item()
            n += bs

        scheduler.step()

        # End of epoch evaluation
        train_loss, train_acc = loss_sum / n, acc_sum / n
        val_loss, val_acc, val_f1, cm, y_true, y_pred = evaluate(model, valid_loader, len(class_names))

        print(f"[{CFG['run_name']}] Epoch {e:02d}/{num_total_epochs:02d} | "
              f"Train Loss={train_loss:.4f}, Acc={train_acc:.4f} | "
              f"Val Loss={val_loss:.4f}, Acc={val_acc:.4f}, F1={val_f1:.4f} | "
              f"Time={int(time.time()-t0)}s")

        # Log metrics
        curves["epoch"].append(e)
        curves["train_loss"].append(train_loss)
        curves["train_acc"].append(train_acc)
        curves["val_loss"].append(val_loss)
        curves["val_acc"].append(val_acc)
        curves["val_f1"].append(val_f1)
        save_curve(curves, out_metrics_dir / "train_curve.csv")
        save_metrics(cm, y_true, y_pred, class_names, out_metrics_dir, e)

        # Save best model based on validation F1-score
        if val_f1 > best_f1:
            best_f1, best_epoch = val_f1, e
            torch.save({"state_dict": model.state_dict(), "classes": class_names}, best_path)
            print(f"  -> New best model saved with F1-score: {best_f1:.4f}")

    return best_f1, best_epoch, best_path

In [None]:
# =========================================================================================
# Main Execution Block
# Runs the experiment end to end: data split -> dataloaders -> model -> training under
# CFG["strategy"] -> final summary. All settings are read from CFG.
# =========================================================================================

# 1. Set seed for reproducibility
set_seed(CFG["seed"])

# 2. Prepare data splits and dataloaders
root = Path(CFG["data_root"])
# NOTE(review): presumably a stratified train/val split seeded by CFG["seed"] — confirm in
# build_seedlings_split.
classes, train_items, val_items = build_seedlings_split(root, CFG["val_ratio"], CFG["seed"])
num_classes = len(classes)
print(f"Found {num_classes} classes. Splitting into {len(train_items)} train and {len(val_items)} validation samples.")

train_tfm, valid_tfm = get_transforms(CFG["img_size"])
train_ds = PlantDataset(train_items, train_tfm)
valid_ds = PlantDataset(val_items, valid_tfm)

# The training loader shuffles and drops the last incomplete batch; the validation loader
# keeps a fixed order and every sample.
train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True,
                          num_workers=CFG["num_workers"], pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_ds, batch_size=CFG["batch_size"], shuffle=False,
                          num_workers=CFG["num_workers"], pin_memory=True)

# 3. Setup output directories and run name
# When no run name is configured, derive one from the model and strategy. CFG is mutated
# here, so everything below (including the saved config) sees the resolved name.
if CFG["run_name"] is None:
    CFG["run_name"] = f"seedlings_{CFG['model_name']}_{CFG['strategy']}"
out_dir = Path(CFG["out_root"]) / CFG["run_name"]
metrics_dir = out_dir / "metrics"
ensure_dir(metrics_dir)
print(f"Outputs will be saved to: {out_dir}")

# 4. Build model and load pre-trained backbone
model = build_model(num_classes, CFG["model_name"]).to(device)
model = load_backbone_from_ckpt(model, CFG["ckpt_path"], CFG["model_name"])

# 5. Save run configuration
with open(metrics_dir / "config.json", "w") as f:
    json.dump(CFG, f, indent=4)

# 6. Initialize for training
# `curves` is shared by every training phase; each phase appends one entry per epoch.
curves = {"epoch": [], "train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "val_f1": []}
# The training loss applies label smoothing taken from CFG.
criterion = nn.CrossEntropyLoss(label_smoothing=CFG["label_smoothing"])
best_f1_overall, best_epoch_overall = -1, -1

# 7. Execute training based on the selected strategy
# ---------------------------------------------------
if CFG["strategy"] == "full_ft":
    # Single phase: the whole backbone is marked trainable from the first epoch.
    print("\n--- Starting Training: Full Fine-Tuning Strategy ---")
    set_backbone_trainable(model, CFG["model_name"], trainable=True)
    optimizer = make_optimizer(model, CFG["model_name"], CFG["lr_backbone"], CFG["lr_head"],
                               wd=CFG["weight_decay"], full_ft=True)
    best_f1_overall, best_epoch_overall, _ = train_one_phase(
        model, train_loader, valid_loader, CFG["epochs_fullft"], optimizer, criterion, 
        curves, metrics_dir, classes
    )

elif CFG["strategy"] == "lp_unfreeze":
    # Phase 1: Linear Probing (train only the head)
    print("\n--- Starting Training Phase 1: Linear Probing ---")
    set_backbone_trainable(model, CFG["model_name"], trainable=False)
    optimizer = make_optimizer(model, CFG["model_name"], CFG["lr_backbone"], CFG["lr_head"],
                               wd=CFG["weight_decay"], full_ft=False)
    best_f1_overall, best_epoch_overall, _ = train_one_phase(
        model, train_loader, valid_loader, CFG["epochs_lp"], optimizer, criterion, 
        curves, metrics_dir, classes, start_epoch=1
    )

    # Phase 2: Gradual Unfreezing (fine-tune deeper layers)
    print("\n--- Starting Training Phase 2: Gradual Unfreezing ---")
    # Unfreeze 1 stage for convnext_tiny and 2 stages for any other model name.
    stages = 1 if CFG["model_name"] == "convnext_tiny" else 2
    unfreeze_last_stages(model, CFG["model_name"], stages_to_unfreeze=stages)
    # A fresh optimizer is built for this phase with full_ft=True.
    # NOTE(review): which parameters it actually updates presumably depends on what
    # unfreeze_last_stages left trainable — confirm in make_optimizer.
    optimizer = make_optimizer(model, CFG["model_name"], CFG["lr_backbone"], CFG["lr_head"],
                               wd=CFG["weight_decay"], full_ft=True)
    # Epoch numbering continues after the linear-probe epochs so both phases share one
    # continuous history in `curves`.
    f1_p2, ep_p2, _ = train_one_phase(
        model, train_loader, valid_loader, CFG["epochs_unfreeze"], optimizer, criterion, 
        curves, metrics_dir, classes, start_epoch=(CFG["epochs_lp"] + 1)
    )
    # Keep whichever phase reported the higher macro-F1.
    if f1_p2 > best_f1_overall:
        best_f1_overall, best_epoch_overall = f1_p2, ep_p2

else:
    raise ValueError("CFG['strategy'] must be either 'full_ft' or 'lp_unfreeze'")

# Final summary of the run.
print("\n" + "="*50)
print("Training Finished!")
print(f"Best overall macro-F1 score: {best_f1_overall:.4f} achieved at epoch {best_epoch_overall}.")
print(f"All outputs, metrics, and best model are saved in: {out_dir}")
print("="*50)