## **Methodology: Curriculum-Driven Robust Fine-Tuning**

### **Problem**

Some plant disease classes still failed due to:

* visually-similar lesions across species
* long-tail labels with weak supervision
* Grad-CAM showing attention on background instead of lesions

---

### **Step-1: Diagnose**

We computed:

* per-class F1
* confusion matrix
* Grad-CAM lesion overlap

This revealed *where* the model was looking and *why* it failed.

---

### **Step-2: Counterfactual edits**

For misclassified images, we applied edits like:

* sharpening
* center cropping
* contrast/saturation tweaks
* jpeg degradation

For each edit we measured:

> the change in the probability margin between the true class and the confused (predicted) class, i.e. Δ[p(true) − p(confused)] after the edit versus before it.

Edits that improved the margin became **per-class recommendations**.

---

### **Step-3: Build curriculum**

From these results we generated:

| File                   | Purpose                          |
| ---------------------- | -------------------------------- |
| `aug_map.json`         | label-specific augmentations     |
| `w_label.json`         | sampling weights for weak labels |
| `alpha_per_class.json` | focal loss emphasis              |

This forces the model to **practice cases it previously failed**.

---

### **Step-4: Fine-Tune**

Two-stage training:

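1️⃣ Main fine-tune: low learning rate for the backbone, higher learning rate for the classifier head.
2️⃣ Short correction pass: one additional epoch starting from the stage-1 checkpoint, intended to stabilize the model without overfitting.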

---

### **Step-5: Validate**

We recomputed metrics and overlap deltas.
Improvements concentrated exactly on:

* overlapping lesion confusions
* low-F1 classes
* CAM shifting toward lesion regions

---

### **Result**

The curriculum delivered **targeted accuracy gains** without harming global performance — and interpretability confirmed behavior improved rather than just numbers.

Setup & metadata

In [None]:
# ===========================
# Section 1 — Setup & metadata
# ===========================
import os, json, math, cv2, timm, torch
import numpy as np, pandas as pd
from pathlib import Path
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IM_SIZE = 320  # side length (pixels) of the square network input
# Per-channel normalisation constants; the preprocessors below apply them
# after reversing the channel axis of the OpenCV image.
IM_MEAN = [0.485, 0.456, 0.406]
IM_STD  = [0.229, 0.224, 0.225]

print("Device:", DEVICE)

# ---- paths ----
# Every artefact produced by the diagnosis cells is written below BASE_DIR.
BASE_DIR = Path("/kaggle/working/v6_global_eval")
BASE_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_CSV = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_train_v6.csv"
VAL_CSV   = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_val_v6.csv"
MAP_JSON  = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/label2idx_v6.json"

CKPT_PATH = "/kaggle/input/v5-model/effb3_320_curated_no_cotton_best_v5.pt"

TEMP = 0.551  # current calibration temperature

train_df = pd.read_csv(TRAIN_CSV)
val_df   = pd.read_csv(VAL_CSV)

with open(MAP_JSON,"r") as f:
    label2idx = json.load(f)

# Cast indices to int so the lookup below works even if the JSON stores them
# as strings; this matches how the later cells build the same mapping.
idx2label = {int(i): l for l, i in label2idx.items()}
# Class names ordered by index, so classes[k] is the label for logit k.
classes = [idx2label[i] for i in range(len(idx2label))]

print("Classes:", len(classes))


Load model + preprocessing

In [None]:
# ===========================
# Section 2 — Model + preprocessing
# ===========================
# Recreate the EfficientNet-B3 architecture without pretrained weights; every
# parameter is then overwritten from the checkpoint at CKPT_PATH.
model = timm.create_model(
    "efficientnet_b3",
    pretrained=False,
    num_classes=len(classes),
)
state = torch.load(CKPT_PATH, map_location="cpu")
# The checkpoint is either a bare state_dict or a dict wrapping it under "model".
wrapped = isinstance(state, dict) and "model" in state
model.load_state_dict(state["model"] if wrapped else state, strict=True)

# Move to the selected device and switch to inference mode.
model = model.to(DEVICE).eval()

def load_rgb(path):
    """Read an image from disk with OpenCV.

    Despite the name, the array is returned exactly as ``cv2.imread`` produced
    it (OpenCV's BGR channel order); the preprocessors in this file reverse
    the channel axis themselves.

    Raises:
        FileNotFoundError: when ``cv2.imread`` returns ``None`` for ``path``.
    """
    image = cv2.imread(str(path))
    if image is not None:
        return image
    raise FileNotFoundError(path)

def preprocess_bgr_pad(img):
    """Letterbox a BGR image into an IM_SIZE x IM_SIZE normalised tensor.

    The longer side is scaled to IM_SIZE with the aspect ratio preserved, the
    result is centred on a zero-filled square canvas, channels are reversed
    to RGB, scaled to [0, 1] and normalised with IM_MEAN / IM_STD.

    Args:
        img: H x W x 3 array as returned by ``load_rgb`` (BGR order).

    Returns:
        Float tensor of shape (3, IM_SIZE, IM_SIZE).
    """
    h, w = img.shape[:2]
    s = IM_SIZE / max(h, w)
    # Clamp to one pixel so extremely elongated images never request a
    # zero-sized resize.
    nh, nw = max(1, int(h * s)), max(1, int(w * s))
    # The flag must be passed by keyword: the third positional parameter of
    # cv2.resize is `dst`, so a positional flag leaves interpolation at its
    # default.
    img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    canvas = np.zeros((IM_SIZE, IM_SIZE, 3), dtype=img.dtype)
    y0 = (IM_SIZE - nh) // 2
    x0 = (IM_SIZE - nw) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = img
    # BGR -> RGB, then scale and normalise per channel.
    x = (canvas[:, :, ::-1] / 255.0 - IM_MEAN) / IM_STD
    return torch.from_numpy(np.transpose(x, (2, 0, 1))).float()

def preprocess_bgr_center_crop(img):
    """Resize the short side to IM_SIZE, centre-crop, and normalise.

    The aspect ratio is preserved during the resize; the central
    IM_SIZE x IM_SIZE window is then cut out, channels are reversed to RGB,
    scaled to [0, 1] and normalised with IM_MEAN / IM_STD.

    Args:
        img: H x W x 3 array as returned by ``load_rgb`` (BGR order).

    Returns:
        Float tensor of shape (3, IM_SIZE, IM_SIZE).
    """
    h, w = img.shape[:2]
    if h < w:
        nh, nw = IM_SIZE, int(w * IM_SIZE / h)
    else:
        nh, nw = int(h * IM_SIZE / w), IM_SIZE
    # Keyword is required: the third positional parameter of cv2.resize is
    # `dst`, not the interpolation flag.
    img = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_AREA)
    y0 = (nh - IM_SIZE) // 2
    x0 = (nw - IM_SIZE) // 2
    img = img[y0:y0 + IM_SIZE, x0:x0 + IM_SIZE]
    # BGR -> RGB, then scale and normalise per channel.
    x = (img[:, :, ::-1] / 255.0 - IM_MEAN) / IM_STD
    return torch.from_numpy(np.transpose(x, (2, 0, 1))).float()


Global inference cache (validation split; the Google test set is cached by the evaluation cell further below)

In [None]:
# ===========================
# Section 3 — Global inference cache
# ===========================
@torch.no_grad()
def predict_dual_logits(img, T=TEMP):
    """Average the model's logits over the padded and centre-cropped views.

    Args:
        img: BGR image array.
        T: temperature; when truthy the averaged logits are divided by it.

    Returns:
        1-D CPU tensor of (temperature-scaled) logits, one entry per class.
    """
    views = (preprocess_bgr_pad(img), preprocess_bgr_center_crop(img))
    batches = [v.unsqueeze(0).to(DEVICE) for v in views]
    padded_logits = model(batches[0])
    cropped_logits = model(batches[1])
    mean_logits = (padded_logits + cropped_logits) / 2
    if T:
        mean_logits = mean_logits / float(T)
    return mean_logits.squeeze(0).cpu()

def run_infer_cache(df, out_csv):
    """Predict every row of ``df`` and write the top-3 results to ``out_csv``.

    ``df`` must provide ``filepath`` and ``label`` columns. One CSV row is
    written per image with the ground truth, the top-1 class and its
    probability, and the second and third ranked classes.
    """
    records = []
    for rec in df.itertuples(index=False):
        logits = predict_dual_logits(load_rgb(rec.filepath))
        probs = F.softmax(logits, dim=0).numpy()
        # Indices of the three largest probabilities, best first.
        first, second, third = probs.argsort()[::-1][:3]
        records.append({
            "filepath": rec.filepath,
            "gt": rec.label,
            "pred": classes[first],
            "prob": float(probs[first]),
            "pred2": classes[second],
            "pred3": classes[third],
        })
    pd.DataFrame(records).to_csv(out_csv, index=False)
    print("Saved:", out_csv)

# Cache validation predictions; the CAM analysis cell reads this file back.
OUT = BASE_DIR / "preds_v6_val_dualT.csv"
run_infer_cache(val_df[["filepath", "label"]], OUT)

CAM lesion-overlap analysis

In [None]:
# ===========================
# Section 4 — CAM lesion alignment
# ===========================
import matplotlib.pyplot as plt
from torch import nn

# Reload the cached validation predictions written by the inference cell.
PRED_CSV = BASE_DIR / "preds_v6_val_dualT.csv"
preds = pd.read_csv(PRED_CSV)
# Class name -> logit index, following the order of `classes`.
cls2idx = dict(zip(classes, range(len(classes))))

# Grad-CAM hooks are attached to the network's conv_head module.
target_layer = model.conv_head

def gradcam(img, cls):
    """Compute a Grad-CAM heat map for class index ``cls`` on one image.

    Activations and output gradients of ``target_layer`` are captured with
    hooks during one forward/backward pass over the centre-cropped view. The
    channel-weighted activation sum is rectified, rescaled to [0, 1] and
    resized to IM_SIZE x IM_SIZE.

    Args:
        img: BGR image array.
        cls: integer index of the class whose score is back-propagated.

    Returns:
        2-D float array of shape (IM_SIZE, IM_SIZE).
    """
    acts, grads = [], []

    def f(m, i, o):
        acts.append(o.detach())

    def b(m, gi, go):
        grads.append(go[0].detach())

    h1 = target_layer.register_forward_hook(f)
    h2 = target_layer.register_full_backward_hook(b)
    try:
        x = preprocess_bgr_center_crop(img).unsqueeze(0).to(DEVICE)
        model.zero_grad()
        s = model(x)[0, cls]
        s.backward()
    finally:
        # Always detach the hooks, otherwise a failed pass would leave them
        # registered and every later forward would keep appending tensors.
        h1.remove()
        h2.remove()

    a, g = acts[-1][0], grads[-1][0]
    # Channel importance = spatial mean of the gradient.
    w = g.mean((1, 2), keepdim=True)
    cam = (w * a).sum(0).cpu().numpy()
    # Keep only positive evidence for the class, as the evaluation cell's
    # gradcam_on_image does.
    cam = np.maximum(cam, 0)
    # Shift first, then divide by the shifted maximum so the map spans [0, 1].
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-6)
    return cv2.resize(cam, (IM_SIZE, IM_SIZE))

def lesion_mask(img):
    """Build a binary edge mask aligned with the Grad-CAM view.

    The mask is a heuristic, not a segmentation: the L channel of the LAB
    image is contrast-equalised with CLAHE, Canny edges are extracted and
    dilated, and every edge pixel is set to 1.

    The mask is then resized and centre-cropped with the same geometry as
    ``preprocess_bgr_center_crop``. ``gradcam`` runs on that cropped view, so
    squashing the whole image to a square would misalign mask and heat map
    for non-square images.

    Args:
        img: BGR image array.

    Returns:
        uint8 array of shape (IM_SIZE, IM_SIZE) containing 0/1.
    """
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)[:, :, 0]
    lab = cv2.createCLAHE(2.0, (8, 8)).apply(lab)
    e = cv2.Canny(lab, 60, 120)
    e = cv2.dilate(e, np.ones((3, 3), np.uint8))
    mask = (e > 0).astype(np.uint8)

    h, w = mask.shape[:2]
    if h < w:
        nh, nw = IM_SIZE, int(w * IM_SIZE / h)
    else:
        nh, nw = int(h * IM_SIZE / w), IM_SIZE
    # Nearest-neighbour keeps the mask strictly binary.
    mask = cv2.resize(mask, (nw, nh), interpolation=cv2.INTER_NEAREST)
    y0 = (nh - IM_SIZE) // 2
    x0 = (nw - IM_SIZE) // 2
    return mask[y0:y0 + IM_SIZE, x0:x0 + IM_SIZE]

# For each cached validation prediction, measure how much of the Grad-CAM
# mass for the ground-truth class falls on the edge mask.
rows = []
for _, rec in preds.iterrows():
    image = load_rgb(rec.filepath)
    heat = gradcam(image, cls2idx[rec.gt])
    edge_mask = lesion_mask(image)
    # Share of total CAM activation that lies inside the mask.
    share = float((heat * (edge_mask > 0)).sum() / (heat.sum() + 1e-6))
    rows.append(dict(rec.to_dict(), overlap=share))

df_cam = pd.DataFrame(rows)
df_cam.to_csv(BASE_DIR / "cam_overlap_all_species.csv", index=False)
print("Saved CAM overlap CSV")

Counterfactual edits → deltas

In [None]:
# ===========================
# Section 5 — Counterfactual edits
# ===========================
# Image edits applied to a BGR array; each returns an array of the same size.
edits = {
    # Crop away a 10% border on every side and scale back to the original size.
    "center_bias": lambda im: cv2.resize(im[int(0.1*im.shape[0]):int(0.9*im.shape[0]),
                                            int(0.1*im.shape[1]):int(0.9*im.shape[1])], (im.shape[1],im.shape[0])),
    # Unsharp mask.
    "sharpen": lambda im: cv2.addWeighted(im,1.5,cv2.GaussianBlur(im,(0,0),1.0),-0.5,0),
    "contrast+10%": lambda im: cv2.convertScaleAbs(im,alpha=1.1,beta=0),
    "bright+10": lambda im: cv2.convertScaleAbs(im,alpha=1.0,beta=10),
}

pair = df_cam.copy()
pair["correct"] = pair["gt"] == pair["pred"]

# Only misclassified images carry signal: for a correct prediction the true
# and predicted class coincide, so the margin delta is identically zero and
# would only dilute the per-class means computed in the next cell.
mis = pair[~pair["correct"]]

rows = []
# Cap the workload at 300 images; min() avoids the ValueError that sample()
# raises when fewer rows are available.
for _, r in mis.sample(min(300, len(mis)), random_state=42).iterrows():
    img = load_rgb(r.filepath)
    p0 = F.softmax(predict_dual_logits(img), dim=0).numpy()
    t = cls2idx[r.gt]
    pw = cls2idx[r.pred]
    for name, fn in edits.items():
        im1 = fn(img)
        p1 = F.softmax(predict_dual_logits(im1), dim=0).numpy()
        rows.append({
            "true": r.gt, "wrong": r.pred, "edit": name,
            # Positive value: the edit moved probability towards the true class.
            "d_margin": float((p1[t] - p1[pw]) - (p0[t] - p0[pw])),
        })

# Explicit columns keep the CSV schema stable even when no row was produced.
df_delta = pd.DataFrame(rows, columns=["true", "wrong", "edit", "d_margin"])
df_delta.to_csv(BASE_DIR/"counterfactual_deltas.csv", index=False)
print("Saved deltas")

Build class-specific augmentation map

In [None]:
# ===========================
# Section 6 — Build class augmentation map
# ===========================
# Mean margin change per (true label, edit), best edit first within a label.
winners = (
    df_delta.groupby(["true", "edit"])["d_margin"]
    .mean().reset_index()
    .sort_values(["true", "d_margin"], ascending=[True, False])
)

aug_map = {}
for lbl, g in winners.groupby("true"):
    # Recommend at most three edits, and only those that improved the margin
    # on average; an edit with a negative mean hurt the class.
    names = g[g["d_margin"] > 0].head(3)["edit"].tolist()
    # Baseline augmentation strengths shared by every label.
    cfg = dict(center_crop_p=0.1, blur_p=0.1, sharpen_p=0, color_jitter_p=0.08,
               cj_brightness=0.1, cj_contrast=0.1, cj_saturation=0.1, cj_hue=0.02)
    if "center_bias" in names:
        cfg["center_crop_p"] = 0.35
    if "sharpen" in names:
        cfg["sharpen_p"] = 0.25
    # The edit is registered as "contrast+10%"; testing for the bare string
    # "contrast" against the list of names could never match.
    if "contrast+10%" in names:
        cfg["cj_contrast"] = 0.14
    aug_map[lbl] = cfg

AUG = BASE_DIR / "curriculum/aug_map.json"
AUG.parent.mkdir(parents=True, exist_ok=True)
with open(AUG, "w") as f:
    json.dump(aug_map, f, indent=2)
print("Saved", AUG)

Curriculum label weights & alpha (hard classes get focal loss)

In [None]:
# ===========================
# Section 7 — Curriculum & alpha
# ===========================
cam = pd.read_csv(BASE_DIR/"cam_overlap_all_species.csv")
# The 30 classes whose mean CAM/mask overlap is lowest, worst first.
low = cam.groupby("gt")["overlap"].mean().sort_values().head(30)

# Sampling weights: 1.0 by default; the selected classes are boosted linearly
# from 1.6 (lowest overlap) down to 1.0 (last of the selection).
w_label = {c: 1.0 for c in classes}
for i, lbl in enumerate(low.index):
    w_label[lbl] = 1.0 + 0.6 * (1 - i / max(1, len(low) - 1))

# The ten lowest-overlap classes additionally get the focal-loss alpha,
# keyed by class index.
alpha_per_class = {int(label2idx[l]): 1.35 for l in low.head(10).index}

# Make sure the output directory exists even if the augmentation-map cell
# has not been run in this session.
curriculum_dir = BASE_DIR/"curriculum"
curriculum_dir.mkdir(parents=True, exist_ok=True)
with open(curriculum_dir/"w_label.json", "w") as f:
    json.dump(w_label, f, indent=2)
with open(curriculum_dir/"alpha_per_class.json", "w") as f:
    json.dump(alpha_per_class, f, indent=2)

Unified fine-tuning (final)

In [None]:
# fine_tune_v6_unified.py
import os, json, cv2, math, torch, timm, numpy as np, pandas as pd
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ---- Config/paths ----
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IM_SIZE = 320
IM_MEAN = [0.485, 0.456, 0.406]
IM_STD = [0.229, 0.224, 0.225]

# Input checkpoint (v5) and where the fine-tuned weights are written.
CKPT_IN = "/kaggle/input/v5-model/effb3_320_curated_no_cotton_best_v5.pt"
CKPT_OUT = "/kaggle/working/v6_global_eval/effb3_v6_unified_ft.pt"

MAP_JSON = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/label2idx_v6.json"
TRAIN_CSV = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_train_v6.csv"
VAL_CSV   = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_val_v6.csv"

# Curriculum artefacts produced by the diagnosis cells.
CURR_DIR = Path("/kaggle/working/v6_global_eval/curriculum")
AUGMAP_JSON = CURR_DIR / "aug_map.json"
W_LABEL_JSON = CURR_DIR / "w_label.json"
ALPHA_JSON = CURR_DIR / "alpha_per_class.json"

# Optimisation hyper-parameters.
EPOCHS = 2
BS_TRAIN = 32
BS_VAL = 64
LR_BACKBONE = 1e-5
LR_HEAD = 3e-5
WD = 1e-2
TEMP = 0.551

# ---- Load metadata and curriculum ----
train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)

with open(MAP_JSON, "r") as f:
    label2idx = json.load(f)
idx2label = {int(idx): name for name, idx in label2idx.items()}
classes = [idx2label[pos] for pos in range(len(idx2label))]

with open(AUGMAP_JSON, "r") as f:
    aug_map = json.load(f)
with open(W_LABEL_JSON, "r") as f:
    w_label = json.load(f)

# The per-class alpha file is optional; JSON keys come back as strings, so
# they are converted to integer class indices.
alpha_per_class = {}
if ALPHA_JSON.exists():
    with open(ALPHA_JSON, "r") as f:
        alpha_per_class = {int(k): float(v) for k, v in json.load(f).items()}

# ---- Aug helpers ----
def sharpen_unsharp(img):
    """Sharpen ``img`` with an unsharp mask.

    Returns ``1.5 * img - 0.5 * blurred``, where ``blurred`` is a Gaussian
    blur of the image with sigma 1.0.
    """
    softened = cv2.GaussianBlur(img, (0, 0), 1.0)
    sharpened = cv2.addWeighted(img, 1.5, softened, -0.5, 0)
    return sharpened

def sharpen_unsharp_aug(img, **kwargs):
    """Adapter so ``sharpen_unsharp`` can be used as an ``A.Lambda`` callback.

    Albumentations passes extra keyword arguments to the callback; they are
    accepted and discarded.
    """
    sharpened = sharpen_unsharp(img)
    return sharpened

def make_transform_for_label(lbl: str):
    """Build the training augmentation pipeline for one class label.

    Per-label strengths come from ``aug_map``; any key that is missing (or a
    label absent from the map) falls back to the defaults below. The returned
    pipeline always yields a normalised tensor of IM_SIZE x IM_SIZE pixels.

    Args:
        lbl: class label used to look up the augmentation config.

    Returns:
        An ``A.Compose`` pipeline to call as ``pipeline(image=img)["image"]``.
    """
    defaults = {
        "center_crop_p": 0.10, "blur_p": 0.10, "sharpen_p": 0.00,
        "color_jitter_p": 0.08, "cj_brightness": 0.10, "cj_contrast": 0.10,
        "cj_saturation": 0.10, "cj_hue": 0.02,
    }
    # Merge into a fresh dict so the shared aug_map entries are never mutated.
    cfg = {**defaults, **aug_map.get(lbl, {})}

    tfs = []
    if cfg["center_crop_p"] > 0:
        tfs.append(
            A.RandomResizedCrop(
                size=(IM_SIZE, IM_SIZE),
                scale=(0.85, 1.0),
                ratio=(0.9, 1.1),
                interpolation=cv2.INTER_AREA,
                p=cfg["center_crop_p"],
            )
        )
    # The crop above only fires with probability center_crop_p; when it is
    # skipped the image keeps its original size. Resizing unconditionally
    # guarantees a fixed output size, so samples can always be batched.
    tfs.append(A.Resize(IM_SIZE, IM_SIZE, interpolation=cv2.INTER_AREA))

    tfs.append(A.HorizontalFlip(p=0.5))
    tfs.append(A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.05, rotate_limit=15,
                                  border_mode=cv2.BORDER_REFLECT_101, p=0.35))

    if cfg["color_jitter_p"] > 0:
        tfs.append(A.ColorJitter(brightness=cfg["cj_brightness"], contrast=cfg["cj_contrast"],
                                 saturation=cfg["cj_saturation"], hue=cfg["cj_hue"],
                                 p=cfg["color_jitter_p"]))

    if cfg["blur_p"] > 0:
        tfs.append(A.GaussianBlur(blur_limit=3, p=cfg["blur_p"]))

    if cfg["sharpen_p"] > 0:
        tfs.append(A.Lambda(image=sharpen_unsharp_aug, p=cfg["sharpen_p"]))

    tfs.append(A.ImageCompression(quality_lower=70, quality_upper=95, p=0.15))
    tfs.append(A.Normalize(mean=IM_MEAN, std=IM_STD))
    tfs.append(ToTensorV2())
    return A.Compose(tfs)


# ---- Dataset ----
class PlantDS(Dataset):
    """Image dataset over a DataFrame with ``filepath`` and ``label`` columns.

    Args:
        df: source rows; the index is reset so positions run 0..len-1.
        train: when True (default) each sample goes through the label-specific
            random augmentation pipeline; when False a deterministic
            resize + normalise pipeline is used so evaluation is repeatable.
    """

    def __init__(self, df, train=True):
        self.df = df.reset_index(drop=True)
        self.train = train
        # Deterministic pipeline for evaluation: same resize and
        # normalisation as training, no random transforms.
        self._eval_tf = A.Compose([
            A.Resize(IM_SIZE, IM_SIZE, interpolation=cv2.INTER_AREA),
            A.Normalize(mean=IM_MEAN, std=IM_STD),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        r = self.df.iloc[i]
        lbl = r["label"]
        img = cv2.imread(r["filepath"], cv2.IMREAD_COLOR)
        if img is None:
            # Best-effort fallback kept from the original: an unreadable file
            # becomes a black image instead of aborting the epoch.
            img = np.zeros((IM_SIZE, IM_SIZE, 3), np.uint8)
        else:
            # cv2 decodes to BGR, while the inference preprocessors feed the
            # network RGB; convert so training sees the same channel order.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tf = make_transform_for_label(lbl) if self.train else self._eval_tf
        x = tf(image=img)["image"]
        y = label2idx[lbl]
        return x, y

# ---- Sampler ----
def build_weights(df):
    """Return one float32 sampling weight per row of ``df``.

    The weight starts at the label's entry in ``w_label`` (1.0 when absent)
    and is multiplied by 1.05 when the file path contains one of a few
    disease keywords.
    """
    keywords = ("spot", "blight", "rust", "mildew", "mosaic", "canker")
    w = np.ones(len(df), dtype=np.float32)
    for i, (fp, lbl) in enumerate(zip(df["filepath"], df["label"])):
        w[i] *= float(w_label.get(lbl, 1.0))
        # slight ambiguity nudge
        if any(k in fp.lower() for k in keywords):
            w[i] *= 1.05
    return w

w = build_weights(train_df)
sampler = WeightedRandomSampler(w, num_samples=len(w), replacement=True)

train_loader = DataLoader(PlantDS(train_df), batch_size=BS_TRAIN, sampler=sampler, num_workers=2, pin_memory=True)
# Validation must not be randomly augmented, otherwise the checkpoint
# comparison in the training loop is made on noisy accuracy.
val_loader   = DataLoader(PlantDS(val_df, train=False), batch_size=BS_VAL, shuffle=False, num_workers=2, pin_memory=True)

# ---- Model ----
num_classes = len(classes)
model = timm.create_model(
    "efficientnet_b3", pretrained=False, num_classes=num_classes
).to(DEVICE)
state = torch.load(CKPT_IN, map_location="cpu")
# Accept either a bare state_dict or one wrapped as {"model": state_dict}.
has_wrapper = (
    isinstance(state, dict)
    and "model" in state
    and isinstance(state["model"], dict)
)
model.load_state_dict(state["model"] if has_wrapper else state, strict=True)

# param groups: parameters whose name contains "classifier" form the head and
# get the higher learning rate; everything else is treated as backbone.
trainable = [(pname, prm) for pname, prm in model.named_parameters() if prm.requires_grad]
head_params = [prm for pname, prm in trainable if "classifier" in pname]
backbone_params = [prm for pname, prm in trainable if "classifier" not in pname]
opt = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": LR_BACKBONE},
        {"params": head_params, "lr": LR_HEAD},
    ],
    weight_decay=WD,
)

# Hybrid loss
class HybridLoss(nn.Module):
    """Cross-entropy with an alpha-weighted focal term for selected classes.

    Samples whose target index is a key of ``alpha_per_class`` contribute
    ``alpha * (1 - p_t) ** gamma * CE``; all other samples contribute plain
    cross-entropy. The batch loss is the mean over samples.

    Args:
        alpha_per_class: mapping from class index to alpha factor.
        gamma: focal exponent.
    """

    def __init__(self, alpha_per_class, gamma=1.5):
        super().__init__()
        self.alpha = alpha_per_class
        self.gamma = gamma

    def forward(self, logits, target):
        per_sample_ce = F.cross_entropy(logits, target, reduction="none")

        # Look up each target once: None marks "no alpha configured".
        looked_up = [self.alpha.get(cls) for cls in target.tolist()]
        if all(a is None for a in looked_up):
            return per_sample_ce.mean()

        flagged = torch.tensor(
            [a is not None for a in looked_up], dtype=torch.bool, device=logits.device
        )
        alphas = torch.ones(logits.size(0), device=logits.device)
        for pos, a in enumerate(looked_up):
            if a is not None:
                alphas[pos] = a

        # Probability assigned to the true class, clamped away from 0 and 1.
        p_true = torch.softmax(logits, dim=1).gather(1, target.view(-1, 1)).squeeze(1)
        p_true = p_true.clamp(1e-6, 1 - 1e-6)
        focal = ((1 - p_true) ** self.gamma) * per_sample_ce * alphas
        # Focal term where flagged, plain cross-entropy elsewhere.
        return torch.where(flagged, focal, per_sample_ce).mean()

# Loss used by the training loop: focal term for classes listed in
# alpha_per_class, plain cross-entropy for all others.
criterion = HybridLoss(alpha_per_class=alpha_per_class, gamma=1.5)

# ---- Train/Eval helpers ----
@torch.no_grad()
def eval_top1(loader):
    """Return top-1 accuracy of the global ``model`` over ``loader``.

    Switches the model to eval mode; the caller must call ``model.train()``
    again before resuming training. Returns 0.0 for an empty loader.
    """
    model.eval()
    hits, seen = 0, 0
    for images, targets in loader:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)
        predicted = model(images).argmax(1)
        hits += (predicted == targets).sum().item()
        seen += targets.numel()
    if seen > 0:
        return hits / seen
    return 0.0

# Accuracy of the incoming checkpoint; fine-tuned weights must stay within a
# small tolerance of it to be kept.
best_acc = eval_top1(val_loader)
print("Start val top1:", best_acc)
baseline_acc = best_acc
saved_acc = None  # accuracy of the checkpoint currently stored at CKPT_OUT

# ---- Train loop ----
for epoch in range(EPOCHS):
    model.train()
    for x, y in train_loader:
        x = x.to(DEVICE); y = y.to(DEVICE)
        logits = model(x)
        loss = criterion(logits, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
    acc = eval_top1(val_loader)
    print(f"Epoch {epoch+1}/{EPOCHS} val top1: {acc:.4f}")
    # Keep an epoch only if it is no worse than the baseline (within 1e-4)
    # and strictly better than what is already saved. Comparing against a
    # threshold that is itself lowered each time would let a later, worse
    # epoch overwrite a better checkpoint and make "best" drift downwards.
    if acc >= baseline_acc - 1e-4 and (saved_acc is None or acc > saved_acc):
        saved_acc = acc
        best_acc = max(best_acc, acc)
        torch.save({"model": model.state_dict(), "label2idx": label2idx}, CKPT_OUT)
        print("Saved:", CKPT_OUT)

if saved_acc is None:
    # Later cells load CKPT_OUT, so make the missing file explicit.
    print("WARNING: no epoch reached the baseline accuracy; nothing was written to", CKPT_OUT)
print("Best val top1:", best_acc)

In [None]:
# one_shot_eval_v6.py
import os, json, cv2, torch, timm, numpy as np, pandas as pd
from pathlib import Path
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# ---- Config ----
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IM_SIZE = 320
IM_MEAN = [0.485, 0.456, 0.406]
IM_STD  = [0.229, 0.224, 0.225]
TEMP = 0.551  # deployment temperature

# Paths (edit these three as needed)
# CKPT_PATH points at the newly fine-tuned model.
CKPT_PATH = "/kaggle/working/v6_global_eval/effb3_v6_unified_ft.pt"
MAP_JSON  = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/label2idx_v6.json"
TEST_ROOT = "/kaggle/input/plant-disease-google-test-images/test_google/test"

# Metadata
VAL_CSV  = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_val_v6.csv"

# Output dir: all evaluation artefacts of this cell are written here.
OUT = Path("/kaggle/working/v6_global_eval/new_eval")
OUT.mkdir(parents=True, exist_ok=True)

# ---- I/O helpers ----
def list_images(root):
    """Recursively list image files below ``root``.

    A path is included when it is a regular file whose suffix (compared
    case-insensitively) is a known image extension. The result is sorted so
    repeated runs process files in the same order.

    Args:
        root: directory to search.

    Returns:
        Sorted list of file paths as strings.
    """
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
    return sorted(
        str(p)
        for p in Path(root).rglob("*")
        # is_file() excludes directories that happen to carry an image suffix.
        if p.suffix.lower() in exts and p.is_file()
    )

def load_rgb(path):
    """Read an image with OpenCV and return it in BGR channel order.

    The name is historical: no colour conversion happens here; the
    preprocessors reverse the channel axis themselves.

    Raises:
        FileNotFoundError: when ``cv2.imread`` returns ``None`` for ``path``.
    """
    image = cv2.imread(str(path))
    if image is not None:
        return image  # BGR
    raise FileNotFoundError(path)

def preprocess_bgr_pad(img_bgr):
    """Letterbox a BGR image into an IM_SIZE x IM_SIZE normalised tensor.

    The longer side is scaled to IM_SIZE with the aspect ratio preserved, the
    result is centred on a zero-filled square canvas, channels are reversed
    to RGB, scaled to [0, 1] and normalised with IM_MEAN / IM_STD.

    Args:
        img_bgr: H x W x 3 array in BGR order.

    Returns:
        float32 tensor of shape (3, IM_SIZE, IM_SIZE).
    """
    h, w = img_bgr.shape[:2]
    s = IM_SIZE / max(h, w)
    # Clamp to one pixel: for extremely elongated images int() would truncate
    # the short side to 0 and cv2.resize would fail.
    nh, nw = max(1, int(h * s)), max(1, int(w * s))
    img = cv2.resize(img_bgr, (nw, nh), interpolation=cv2.INTER_AREA)
    canvas = np.zeros((IM_SIZE, IM_SIZE, 3), dtype=img.dtype)
    y0 = (IM_SIZE - nh) // 2
    x0 = (IM_SIZE - nw) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = img
    # BGR -> RGB and scale to [0, 1].
    x = canvas[:, :, ::-1].astype(np.float32) / 255.0
    x = (x - np.array(IM_MEAN, np.float32)) / np.array(IM_STD, np.float32)
    return torch.from_numpy(np.transpose(x, (2, 0, 1))).float()

def preprocess_bgr_center_crop(img_bgr):
    """Resize the short side to IM_SIZE, centre-crop, and normalise.

    Args:
        img_bgr: H x W x 3 array in BGR order.

    Returns:
        float32 tensor of shape (3, IM_SIZE, IM_SIZE) in RGB channel order,
        normalised with IM_MEAN / IM_STD.
    """
    height, width = img_bgr.shape[:2]
    # Scale so the shorter side becomes exactly IM_SIZE.
    if height < width:
        new_h, new_w = IM_SIZE, int(width * IM_SIZE / height)
    else:
        new_h, new_w = int(height * IM_SIZE / width), IM_SIZE
    resized = cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # Cut the central IM_SIZE x IM_SIZE window.
    top = (new_h - IM_SIZE) // 2
    left = (new_w - IM_SIZE) // 2
    window = resized[top:top + IM_SIZE, left:left + IM_SIZE]

    # BGR -> RGB, scale to [0, 1], then normalise per channel.
    rgb = window[:, :, ::-1].astype(np.float32) / 255.0
    mean = np.array(IM_MEAN, np.float32)
    std = np.array(IM_STD, np.float32)
    normalised = (rgb - mean) / std
    return torch.from_numpy(np.transpose(normalised, (2, 0, 1))).float()

# ---- Load label map and model ----
with open(MAP_JSON, "r") as f:
    label2idx = json.load(f)
idx2label = {int(idx): name for name, idx in label2idx.items()}
# classes[k] is the label that belongs to logit k.
classes = [idx2label[pos] for pos in range(len(idx2label))]

model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=len(classes))
model = model.to(DEVICE).eval()
state = torch.load(CKPT_PATH, map_location="cpu")
# Accept either a bare state_dict or one wrapped as {"model": state_dict}.
if isinstance(state, dict) and "model" in state and isinstance(state["model"], dict):
    sd = state["model"]
else:
    sd = state
model.load_state_dict(sd, strict=True)

# ---- Predictors ----
@torch.no_grad()
def predict_dualT_logits(img_bgr, temperature=TEMP):
    """Average logits over the padded and centre-cropped views of one image.

    Args:
        img_bgr: BGR image array.
        temperature: when not None and positive, the averaged logits are
            divided by it.

    Returns:
        1-D float CPU tensor of logits, one entry per class.
    """
    padded = preprocess_bgr_pad(img_bgr).unsqueeze(0).to(DEVICE)
    cropped = preprocess_bgr_center_crop(img_bgr).unsqueeze(0).to(DEVICE)
    mean_logits = (model(padded) + model(cropped)) / 2.0
    apply_temperature = temperature is not None and temperature > 0
    if apply_temperature:
        mean_logits = mean_logits / float(temperature)
    return mean_logits.squeeze(0).float().cpu()

@torch.no_grad()
def predict_file(path, topk=3):
    """Classify the image stored at ``path``.

    Returns:
        ``(ranked, probs)`` where ``ranked`` is a list of
        ``(class name, probability)`` pairs for the ``topk`` most likely
        classes, best first, and ``probs`` is the full probability vector as
        a NumPy array.
    """
    logits = predict_dualT_logits(load_rgb(path), TEMP)
    probs = F.softmax(logits, dim=0).numpy()
    best = probs.argsort()[::-1][:topk]
    ranked = []
    for idx in best:
        ranked.append((classes[idx], float(probs[idx])))
    return ranked, probs

# ---- Build val/test DataFrames (all species) ----
val_df = pd.read_csv(VAL_CSV)[["filepath", "label"]].reset_index(drop=True)
# Test_root enumeration: the parent directory name of each image is taken as
# its label, and images whose directory is not a known label are dropped.
rows = [
    {"filepath": img_path, "label": Path(img_path).parent.name}
    for img_path in list_images(TEST_ROOT)
    if Path(img_path).parent.name in label2idx
]
test_df = pd.DataFrame(rows).reset_index(drop=True)

# ---- Inference cache ----
def run_infer_cache(df, out_csv):
    """Predict every row of ``df`` and write the top-3 results to ``out_csv``.

    ``df`` must provide ``filepath`` and ``label`` columns. Images that fail
    to load or predict are skipped (best effort), but each failure is printed
    and the total is reported at the end so gaps in the CSV are visible.

    Args:
        df: rows to predict.
        out_csv: destination path of the CSV.
    """
    columns = ["filepath", "gt", "pred", "prob",
               "pred1", "prob1", "pred2", "prob2", "pred3", "prob3"]
    rows = []
    skipped = 0
    # Count processed rows by position, not by index label, so the progress
    # line is right for any DataFrame index.
    for done, (_, r) in enumerate(df.iterrows(), 1):
        fp, gt = r["filepath"], r["label"]
        try:
            _, p = predict_file(fp, topk=3)
        except Exception as exc:
            skipped += 1
            print(f"Skipping {fp}: {exc!r}")
            continue
        top = np.argsort(-p)[:3]
        row = {
            "filepath": fp, "gt": gt,
            "pred": classes[top[0]], "prob": float(p[top[0]]),
            "pred1": classes[top[0]], "prob1": float(p[top[0]]),
            "pred2": classes[top[1]] if len(top)>1 else "", "prob2": float(p[top[1]]) if len(top)>1 else 0.0,
            "pred3": classes[top[2]] if len(top)>2 else "", "prob3": float(p[top[2]]) if len(top)>2 else 0.0,
        }
        rows.append(row)
        if done % 500 == 0:
            print(f"{done}/{len(df)}")
    # Explicit columns: even with zero rows the CSV keeps its header and can
    # be read back with pd.read_csv.
    pd.DataFrame(rows, columns=columns).to_csv(out_csv, index=False)
    if skipped:
        print(f"Skipped {skipped} of {len(df)} images")
    print("Saved:", out_csv)

VAL_PRED = OUT / "preds_v6_val_dualT.csv"
TEST_PRED = OUT / "preds_v6_test_dualT.csv"
for frame, destination in ((val_df, VAL_PRED), (test_df, TEST_PRED)):
    run_infer_cache(frame, destination)

# ---- Metrics & confusions ----
val = pd.read_csv(VAL_PRED)
# Label vocabulary for the report: every class that occurs as ground truth
# or as a prediction in the validation cache, in sorted order.
labels = sorted(pd.concat([val["gt"], val["pred"]]).unique())
lab2idx_local = {name: pos for pos, name in enumerate(labels)}
idx2lab_local = dict(enumerate(labels))

def per_class_report(df):
    """Compute the confusion matrix and per-class metrics for ``df``.

    ``df`` needs ``gt`` and ``pred`` columns whose values are all present in
    the module-level ``labels`` list.

    Returns:
        ``(cm, rep, macro)``: ``cm`` is the n x n count matrix indexed
        ``[true, predicted]``; ``rep`` is a DataFrame with support,
        precision, recall and F1 per label, sorted by ascending F1; ``macro``
        holds the unweighted means and the overall accuracy.
    """
    true_idx = df["gt"].map(lab2idx_local).to_numpy()
    pred_idx = df["pred"].map(lab2idx_local).to_numpy()
    n = len(labels)

    cm = np.zeros((n, n), dtype=np.int64)
    # Unbuffered add so repeated (true, pred) pairs are each counted.
    np.add.at(cm, (true_idx, pred_idx), 1)

    support = cm.sum(1)
    tp = np.diag(cm)
    fp = cm.sum(0) - tp
    fn = cm.sum(1) - tp

    def _safe_ratio(num, den):
        # Element-wise num / den, with 0.0 wherever the denominator is zero.
        return np.divide(num, den, out=np.zeros_like(tp, dtype=float), where=den > 0)

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)

    rep = pd.DataFrame({
        "label": [idx2lab_local[i] for i in range(n)],
        "support": support, "precision": precision, "recall": recall, "f1": f1,
    }).sort_values("f1")

    total = cm.sum()
    macro = {
        "macro_precision": float(np.mean(precision)),
        "macro_recall": float(np.mean(recall)),
        "macro_f1": float(np.mean(f1)),
        "overall_acc": float(tp.sum() / total) if total > 0 else 0.0,
    }
    return cm, rep, macro

cm, rep, macro = per_class_report(val)
# The matrix is written without its row index; rows follow the same label
# order as the column header.
cm_frame = pd.DataFrame(cm, index=labels, columns=labels)
cm_frame.to_csv(OUT / "confusion_matrix.csv", index=False)
rep.to_csv(OUT / "per_class_report.csv", index=False)
with open(OUT / "summary.json", "w") as f:
    json.dump(macro, f, indent=2)

# Top confusions: for every true class, its off-diagonal cells with a
# non-zero count (at most 50 per class), then the 100 largest overall.
pairs = []
per_class_cap = min(50, cm.shape[1] - 1)
for true_pos in range(cm.shape[0]):
    off_diag = cm[true_pos].copy()
    off_diag[true_pos] = 0
    if off_diag.sum() == 0:
        continue
    for pred_pos in np.argsort(-off_diag)[:per_class_cap]:
        if off_diag[pred_pos] > 0:
            pairs.append({
                "true": labels[true_pos],
                "pred": labels[pred_pos],
                "count": int(off_diag[pred_pos]),
            })
pairs.sort(key=lambda d: d["count"], reverse=True)
with open(OUT / "top_confusions.json", "w") as f:
    json.dump({"global": pairs[:100]}, f, indent=2)

print("Saved metrics and confusions to:", str(OUT))

# ---- CAM overlap (all species) ----
# Target conv layer: conv_head preferred
target_layer = getattr(model, "conv_head", None)
if target_layer is None:
    def get_module_by_name(model, name):
        """Resolve a dotted module path such as "blocks.6.0.conv" on ``model``."""
        mod = model
        for part in name.split("."):
            mod = getattr(mod, part)
        return mod

    # fallback last Conv2d: keep overwriting so the final match wins.
    tname = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            tname = name
    target_layer = get_module_by_name(model, tname)
print("CAM target:", target_layer.__class__.__name__)

def gradcam_on_image(img_bgr, cls_idx, target_layer):
    """Compute a Grad-CAM heat map for one class on one image.

    Activations and output gradients of ``target_layer`` are captured with
    hooks during a forward/backward pass over the centre-cropped view. The
    channel-weighted activation sum is rectified, resized to
    IM_SIZE x IM_SIZE and rescaled to [0, 1].

    Args:
        img_bgr: BGR image array.
        cls_idx: integer index of the class whose score is back-propagated.
        target_layer: module whose output feature map is explained.

    Returns:
        2-D float array of shape (IM_SIZE, IM_SIZE).
    """
    model.zero_grad(set_to_none=True)
    activations, gradients = [], []

    def fwd_hook(m, i, o):
        activations.append(o.detach())

    def bwd_hook(m, gi, go):
        gradients.append(go[0].detach())

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)
    try:
        x = preprocess_bgr_center_crop(img_bgr).unsqueeze(0).to(DEVICE)
        out = model(x)
        score = out[0, cls_idx]
        score.backward()
    finally:
        # Remove the hooks even if the pass fails, so they cannot pile up on
        # the layer across calls.
        h1.remove()
        h2.remove()

    act = activations[-1][0]
    grad = gradients[-1][0]
    # Channel importance = spatial mean of the gradient.
    w = grad.mean(dim=(1, 2), keepdim=True)
    cam = (w * act).sum(0).cpu().numpy()
    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (IM_SIZE, IM_SIZE))
    # Shift first, then divide by the shifted maximum; dividing by the
    # original maximum would cap the map below 1 whenever its minimum is > 0.
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-6)
    return cam

def lesion_mask(img_bgr):
    """Build a binary edge mask aligned with the Grad-CAM view.

    The mask is a heuristic, not a segmentation: the L channel of the LAB
    image is contrast-equalised with CLAHE, Canny edges are extracted,
    dilated and slightly blurred, and every non-zero pixel is set to 1.

    The mask is then resized and centre-cropped with the same geometry as
    ``preprocess_bgr_center_crop``. ``gradcam_on_image`` runs on that cropped
    view, so squashing the whole image to a square would misalign mask and
    heat map for non-square images.

    Args:
        img_bgr: BGR image array.

    Returns:
        uint8 array of shape (IM_SIZE, IM_SIZE) containing 0/1.
    """
    imgL = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)[:, :, 0]
    imgL = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(imgL)
    edges = cv2.Canny(imgL, 60, 120)
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    edges = cv2.GaussianBlur(edges, (3, 3), 0)
    mask = (edges > 0).astype(np.uint8)

    h, w = mask.shape[:2]
    if h < w:
        nh, nw = IM_SIZE, int(w * IM_SIZE / h)
    else:
        nh, nw = int(h * IM_SIZE / w), IM_SIZE
    # Nearest-neighbour keeps the mask strictly binary.
    mask = cv2.resize(mask, (nw, nh), interpolation=cv2.INTER_NEAREST)
    y0 = (nh - IM_SIZE) // 2
    x0 = (nw - IM_SIZE) // 2
    return mask[y0:y0 + IM_SIZE, x0:x0 + IM_SIZE]

preds = val.copy()
preds["correct"] = preds["gt"] == preds["pred"]
cls2idx = {c: i for i, c in enumerate(classes)}


def _center_crop_bgr(img_bgr):
    """Resize the short side to IM_SIZE and centre-crop, without normalising.

    Uses the same geometry as ``preprocess_bgr_center_crop`` so the picture
    shown under the heat map is the exact view the CAM was computed on.
    """
    h, w = img_bgr.shape[:2]
    if h < w:
        nh, nw = IM_SIZE, int(w * IM_SIZE / h)
    else:
        nh, nw = int(h * IM_SIZE / w), IM_SIZE
    img = cv2.resize(img_bgr, (nw, nh), interpolation=cv2.INTER_AREA)
    y0 = (nh - IM_SIZE) // 2
    x0 = (nw - IM_SIZE) // 2
    return img[y0:y0 + IM_SIZE, x0:x0 + IM_SIZE]


PANEL_DIR = OUT/"cam_panels"; PANEL_DIR.mkdir(exist_ok=True)
metrics = []
# Labels that actually occur as ground truth, computed once.
present_labels = set(preds["gt"])
for lbl in classes:
    if lbl not in present_labels or lbl not in cls2idx:
        continue
    # Up to 16 validation images per class, fixed seed for reproducible panels.
    of_class = preds[preds["gt"] == lbl]
    sub = of_class.sample(min(16, len(of_class)), random_state=42)
    figs = []
    for _, r in sub.iterrows():
        fp = r["filepath"]
        try:
            img = load_rgb(fp)
        except Exception as exc:
            # Best effort: keep going, but say which file was dropped.
            print(f"Skipping {fp}: {exc!r}")
            continue
        cam = gradcam_on_image(img, cls2idx[lbl], target_layer)
        msk = lesion_mask(img)
        # Share of total CAM activation that falls on the mask.
        overlap = float((cam * (msk>0)).sum() / (cam.sum() + 1e-6))
        metrics.append({"label": lbl, "filepath": fp, "overlap": overlap, "pred": r["pred"], "correct": bool(r["correct"])})
        heat = (cv2.applyColorMap((cam*255).astype(np.uint8), cv2.COLORMAP_JET)[:,:,::-1])/255.0
        # Show the centre-cropped view: a plain square resize would not line
        # up with the CAM for non-square images.
        rgb  = cv2.cvtColor(_center_crop_bgr(img), cv2.COLOR_BGR2RGB)/255.0
        overlay = 0.4*heat + 0.6*rgb
        figs.append(overlay)
    if figs:
        cols = 4; rows = int(np.ceil(len(figs)/cols))
        plt.figure(figsize=(3*cols, 3*rows))
        for i, im in enumerate(figs, 1):
            plt.subplot(rows, cols, i); plt.imshow(im); plt.axis("off")
        plt.suptitle(f"Grad-CAM — {lbl}"); plt.tight_layout()
        plt.savefig(PANEL_DIR/f"cam_{lbl}.png", dpi=200); plt.close()

# Explicit columns keep the CSV header stable even when nothing was measured.
pd.DataFrame(metrics, columns=["label", "filepath", "overlap", "pred", "correct"]).to_csv(OUT/"cam_overlap_all_species.csv", index=False)
print("Saved CAM overlap and panels to:", str(OUT))


Short correction pass

In [None]:
# 3_finetune_one_epoch.py
import os, json, cv2, torch, timm, numpy as np, pandas as pd
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IM_SIZE = 320
IM_MEAN = [0.485, 0.456, 0.406]
IM_STD = [0.229, 0.224, 0.225]
EPOCHS = 1
BS_TRAIN = 32
BS_VAL = 64
LR_BACKBONE = 1e-5
LR_HEAD = 3e-5
WD = 1e-2

# Paths
TRAIN_CSV = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_train_v6.csv"
VAL_CSV   = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_val_v6.csv"
MAP_JSON  = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/label2idx_v6.json"
# CKPT_IN is the checkpoint written by the previous fine-tuning stage.
CKPT_IN   = "/kaggle/working/v6_global_eval/effb3_v6_unified_ft.pt"
CKPT_OUT  = "/kaggle/working/v6_global_eval/effb3_v6_unified_ft_final.pt"
CURR      = Path("/kaggle/working/v6_global_eval/curriculum")

# Load metadata and curriculum
train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)
with open(MAP_JSON, "r") as f:
    label2idx = json.load(f)
idx2label = {int(idx): name for name, idx in label2idx.items()}
classes = [idx2label[pos] for pos in range(len(idx2label))]

with open(CURR / "aug_map.json", "r") as f:
    aug_map = json.load(f)
with open(CURR / "w_label.json", "r") as f:
    w_label = json.load(f)

# The per-class alpha file is optional; JSON keys come back as strings, so
# they are converted to integer class indices.
alpha_per_class = {}
APC = CURR / "alpha_per_class.json"
if APC.exists():
    with open(APC, "r") as f:
        alpha_per_class = {int(k): float(v) for k, v in json.load(f).items()}

def sharpen_unsharp(img):
    """Sharpen ``img`` with an unsharp mask: ``1.5 * img - 0.5 * blurred``.

    ``blurred`` is a Gaussian blur of the image with sigma 1.0.
    """
    softened = cv2.GaussianBlur(img, (0, 0), 1.0)
    sharpened = cv2.addWeighted(img, 1.5, softened, -0.5, 0)
    return sharpened
def sharpen_unsharp_aug(img, **kwargs):
    """``A.Lambda`` callback wrapper: extra keyword arguments are discarded."""
    sharpened = sharpen_unsharp(img)
    return sharpened

def make_transform_for_label(lbl: str):
    """Build the training augmentation pipeline for one class label.

    Per-label strengths come from ``aug_map``; any key that is missing (or a
    label absent from the map) falls back to the defaults below. The returned
    pipeline always yields a normalised tensor of IM_SIZE x IM_SIZE pixels.

    Args:
        lbl: class label used to look up the augmentation config.

    Returns:
        An ``A.Compose`` pipeline to call as ``pipeline(image=img)["image"]``.
    """
    defaults = {"center_crop_p":0.10,"blur_p":0.10,"sharpen_p":0.00,"color_jitter_p":0.08,
                "cj_brightness":0.10,"cj_contrast":0.10,"cj_saturation":0.10,"cj_hue":0.02}
    # Merge into a fresh dict so the shared aug_map entries are never mutated.
    cfg = {**defaults, **aug_map.get(lbl, {})}
    tfs = []
    if cfg["center_crop_p"] > 0:
        tfs.append(A.RandomResizedCrop(size=(IM_SIZE,IM_SIZE), scale=(0.85,1.0), ratio=(0.9,1.1),
                                       interpolation=cv2.INTER_AREA, p=cfg["center_crop_p"]))
    # The crop above only fires with probability center_crop_p; when it is
    # skipped the image keeps its original size. Resizing unconditionally
    # guarantees a fixed output size, so samples can always be batched.
    tfs.append(A.Resize(IM_SIZE, IM_SIZE, interpolation=cv2.INTER_AREA))
    tfs.append(A.HorizontalFlip(p=0.5))
    tfs.append(A.Affine(scale=(0.95,1.05), translate_percent={"x":(-0.02,0.02),"y":(-0.02,0.02)},
                        rotate=(-15,15), mode=cv2.BORDER_REFLECT_101, p=0.35))
    if cfg["color_jitter_p"] > 0:
        tfs.append(A.ColorJitter(brightness=cfg["cj_brightness"], contrast=cfg["cj_contrast"],
                                 saturation=cfg["cj_saturation"], hue=cfg["cj_hue"], p=cfg["color_jitter_p"]))
    if cfg["blur_p"] > 0:
        tfs.append(A.GaussianBlur(blur_limit=3, p=cfg["blur_p"]))
    if cfg["sharpen_p"] > 0:
        tfs.append(A.Lambda(image=sharpen_unsharp_aug, p=cfg["sharpen_p"]))
    # NOTE(review): the earlier fine-tuning cell passes
    # quality_lower/quality_upper here; confirm which spelling the installed
    # albumentations version accepts.
    tfs.append(A.ImageCompression(quality=(70,95), p=0.15))
    tfs.append(A.Normalize(mean=IM_MEAN, std=IM_STD))
    tfs.append(ToTensorV2())
    return A.Compose(tfs)

class PlantDS(Dataset):
    def __init__(self, df): self.df = df.reset_index(drop=True)
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        img = cv2.imread(r["filepath"], cv2.IMREAD_COLOR)
        if img is None: img = np.zeros((IM_SIZE,IM_SIZE,3), np.uint8)
        x = make_transform_for_label(r["label"])(image=img)["image"]
        y = label2idx[r["label"]]
        return x, y

def build_weights(df):
    """Return one float32 sampling weight per row of ``df``.

    A row's weight is ``w_label[label]`` when its label has an entry in the
    module-level ``w_label`` mapping and 1.0 otherwise. Weights are returned
    in row order, independent of the frame's index labels.
    """
    # Only the label column is needed, so iterate it directly rather than
    # materialising a Series per row with iterrows().
    return np.array(
        [float(w_label.get(lbl, 1.0)) for lbl in df["label"]],
        dtype=np.float32,
    )

# Training loader: every epoch draws len(train_df) rows with replacement, in
# proportion to the per-row weights produced by build_weights.
train_loader = DataLoader(
    PlantDS(train_df),
    batch_size=BS_TRAIN,
    sampler=WeightedRandomSampler(
        build_weights(train_df),
        num_samples=len(train_df),
        replacement=True,
    ),
    num_workers=2,
    pin_memory=True,
)
# Validation loader: fixed order, no sampler.
# NOTE(review): PlantDS routes validation rows through the same
# make_transform_for_label pipeline as training rows — confirm that is intended.
val_loader = DataLoader(
    PlantDS(val_df),
    batch_size=BS_VAL,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Model: EfficientNet-B3 with one output per class, initialised from CKPT_IN.
model = timm.create_model(
    "efficientnet_b3",
    pretrained=False,
    num_classes=len(classes),
).to(DEVICE)
state = torch.load(CKPT_IN, map_location="cpu")
# The checkpoint is either a bare state dict or a dict wrapping it under "model".
sd = state
if isinstance(state, dict) and isinstance(state.get("model"), dict):
    sd = state["model"]
model.load_state_dict(sd, strict=True)

# Optim: two AdamW parameter groups so the classifier head and the rest of the
# network can use different learning rates.
backbone_params, head_params = [], []
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue
    # Parameters whose name contains "classifier" form the head group.
    if "classifier" in n:
        head_params.append(p)
    else:
        backbone_params.append(p)
opt = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": LR_BACKBONE},
        {"params": head_params, "lr": LR_HEAD},
    ],
    weight_decay=WD,
)

class HybridLoss(nn.Module):
    """Cross-entropy with a focal-style term for selected target classes.

    Samples whose target index is a key of ``alpha_per_class`` contribute
    ``alpha * (1 - p_t) ** gamma * CE``; every other sample contributes plain
    cross-entropy. The result is the mean over the batch.
    """

    def __init__(self, alpha_per_class, gamma=1.5):
        super().__init__()
        self.alpha = alpha_per_class
        self.gamma = gamma

    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, reduction="none")
        class_ids = target.tolist()
        focal_rows = [c in self.alpha for c in class_ids]
        if not any(focal_rows):
            # No sample in this batch belongs to an emphasised class.
            return ce.mean()

        use = torch.tensor(focal_rows, dtype=torch.bool, device=logits.device)
        alphas = torch.tensor(
            [float(self.alpha[c]) if hit else 1.0 for c, hit in zip(class_ids, focal_rows)],
            dtype=torch.float,
            device=logits.device,
        )
        # Probability assigned to the true class, clamped away from 0 and 1.
        probs = torch.softmax(logits, dim=1)
        pt = probs.gather(1, target.view(-1, 1)).squeeze(1).clamp(1e-6, 1 - 1e-6)
        fl = ((1 - pt) ** self.gamma) * ce * alphas
        return torch.where(use, fl, ce).mean()

# Loss for the fine-tuning loop: HybridLoss applies its focal-style term to the
# target indices that are keys of alpha_per_class and plain cross-entropy elsewhere.
criterion = HybridLoss(alpha_per_class=alpha_per_class, gamma=1.5)

@torch.no_grad()
def eval_top1(loader):
    """Return top-1 accuracy of the global ``model`` over ``loader``.

    Puts the model in eval mode (it is not switched back afterwards) and
    returns 0.0 when the loader yields no samples.
    """
    model.eval()
    n_hit, n_seen = 0, 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)
        top = model(batch_x).argmax(1)
        n_hit += (top == batch_y).sum().item()
        n_seen += batch_y.numel()
    if n_seen > 0:
        return n_hit / n_seen
    return 0.0

# Baseline: validation accuracy of the checkpoint before any update.
best_acc = eval_top1(val_loader)
print("Start val top1:", best_acc)

# Train one epoch
model.train()  # eval_top1 left the model in eval mode
for x, y in train_loader:
    x = x.to(DEVICE)
    y = y.to(DEVICE)
    logits = model(x)
    loss = criterion(logits, y)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    # Clip the global gradient norm to 1.0 before the optimizer step.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()

acc = eval_top1(val_loader)
print("Post-epoch val top1:", acc)
# Keep the fine-tuned weights unless accuracy fell by more than 1e-4.
if acc >= best_acc - 1e-4:
    torch.save({"model": model.state_dict(), "label2idx": label2idx}, CKPT_OUT)
    print("Saved:", CKPT_OUT)
else:
    # Make the skipped save visible so a missing CKPT_OUT is not a surprise later.
    print("Not saved: val top1 dropped from", best_acc, "to", acc)

Final evaluation

In [None]:
# 4_eval_final_checkpoint.py
# Set CKPT_PATH to the new final model and rerun the one-shot eval you used earlier.
import os, json, cv2, torch, timm, numpy as np, pandas as pd
from pathlib import Path
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input side length and per-channel normalisation used by the preprocess_* helpers.
IM_SIZE = 320
IM_MEAN = [0.485, 0.456, 0.406]
IM_STD = [0.229, 0.224, 0.225]
# Default temperature for predict_dualT_logits (logits are divided by it).
TEMP = 0.551

# Inputs: checkpoint, label map, validation CSV and the test image tree.
CKPT_PATH = "/kaggle/working/v6_global_eval/effb3_v6_unified_ft_final.pt"
MAP_JSON  = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/label2idx_v6.json"
VAL_CSV   = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/metadata/curated_val_v6.csv"
TEST_ROOT = "/kaggle/input/plant-disease-google-test-images/test_google/test"

# Output directory for every artifact written by this cell.
OUT = Path("/kaggle/working/v6_global_eval/final_eval")
OUT.mkdir(parents=True, exist_ok=True)

def list_images(root):
    """Return the image files found anywhere under ``root`` as sorted path strings.

    A path qualifies when it is a regular file whose suffix (compared
    case-insensitively) is one of the known image extensions. Sorting makes
    the result independent of filesystem traversal order.
    """
    exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
    return sorted(
        str(p)
        for p in Path(root).rglob("*")
        if p.suffix.lower() in exts and p.is_file()
    )
def load_rgb(path):
    """Read the image at ``path`` with ``cv2.imread`` and return it.

    Despite the name, the array is returned exactly as ``cv2.imread``
    produces it (BGR channel order); callers reverse the channels themselves.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None``.
    """
    image = cv2.imread(str(path))
    if image is not None:
        return image
    raise FileNotFoundError(path)
def preprocess_bgr_pad(img_bgr):
    """Letterbox a BGR image into an ``IM_SIZE`` square and normalise it.

    The image is scaled so its longer side is at most ``IM_SIZE``, centred on
    a zero-filled square canvas, converted to RGB, scaled to [0, 1],
    normalised with ``IM_MEAN`` / ``IM_STD`` and returned as a float tensor
    in channel-first layout.
    """
    h, w = img_bgr.shape[:2]
    s = IM_SIZE / max(h, w)
    # int() truncates, so a very thin image could round its short side down
    # to 0, which cv2.resize rejects; keep every side at least 1 pixel.
    nh, nw = max(1, int(h * s)), max(1, int(w * s))
    img = cv2.resize(img_bgr, (nw, nh), interpolation=cv2.INTER_AREA)
    canvas = np.zeros((IM_SIZE, IM_SIZE, 3), dtype=img.dtype)
    y0 = (IM_SIZE - nh) // 2
    x0 = (IM_SIZE - nw) // 2
    canvas[y0:y0 + nh, x0:x0 + nw] = img
    # Reverse the channel axis (BGR -> RGB) before applying the statistics.
    x = canvas[:, :, ::-1].astype(np.float32) / 255.0
    x = (x - np.array(IM_MEAN, np.float32)) / np.array(IM_STD, np.float32)
    return torch.from_numpy(np.transpose(x, (2, 0, 1))).float()
def preprocess_bgr_center_crop(img_bgr):
    """Resize the shorter side to ``IM_SIZE``, centre-crop a square, normalise.

    The crop is converted to RGB, scaled to [0, 1], normalised with
    ``IM_MEAN`` / ``IM_STD`` and returned as a float tensor in channel-first
    layout.
    """
    h, w = img_bgr.shape[:2]
    if h < w:
        new_h, new_w = IM_SIZE, int(w * IM_SIZE / h)
    else:
        new_h, new_w = int(h * IM_SIZE / w), IM_SIZE
    resized = cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA)

    top = (new_h - IM_SIZE) // 2
    left = (new_w - IM_SIZE) // 2
    crop = resized[top:top + IM_SIZE, left:left + IM_SIZE]

    # Reverse the channel axis (BGR -> RGB) before applying the statistics.
    rgb = crop[:, :, ::-1].astype(np.float32) / 255.0
    mean = np.array(IM_MEAN, np.float32)
    std = np.array(IM_STD, np.float32)
    chw = np.transpose((rgb - mean) / std, (2, 0, 1))
    return torch.from_numpy(chw).float()

# Label map: label name -> class index, plus the inverse and an index-ordered list.
with open(MAP_JSON, "r") as f:
    label2idx = json.load(f)
idx2label = {int(v): k for k, v in label2idx.items()}
classes = [idx2label[i] for i in range(len(idx2label))]

# Model: EfficientNet-B3 in eval mode, weights taken from CKPT_PATH.
model = timm.create_model(
    "efficientnet_b3",
    pretrained=False,
    num_classes=len(classes),
)
model = model.to(DEVICE).eval()
state = torch.load(CKPT_PATH, map_location="cpu")
# The checkpoint is either a bare state dict or a dict wrapping it under "model".
sd = state
if isinstance(state, dict) and isinstance(state.get("model"), dict):
    sd = state["model"]
model.load_state_dict(sd, strict=True)

@torch.no_grad()
def predict_dualT_logits(img_bgr, temperature=TEMP):
    """Return temperature-scaled logits averaged over two views of one image.

    The views are the padded and the centre-cropped preprocessing of
    ``img_bgr``. Their logits are averaged and, when ``temperature`` is
    truthy and positive, divided by it. The result is a 1-D float tensor on
    the CPU.
    """
    padded = preprocess_bgr_pad(img_bgr).unsqueeze(0).to(DEVICE)
    cropped = preprocess_bgr_center_crop(img_bgr).unsqueeze(0).to(DEVICE)
    logits = (model(padded) + model(cropped)) / 2.0
    if temperature and temperature > 0:
        logits = logits / float(temperature)
    return logits.squeeze(0).float().cpu()

def run_infer_cache(df, out_csv):
    """Predict every row of ``df`` and write the top-3 results to ``out_csv``.

    ``df`` needs ``filepath`` and ``label`` columns. Each output row holds the
    path, the ground-truth label, and up to three predicted labels with their
    softmax probabilities (``pred``/``prob`` duplicate ``pred1``/``prob1``).

    Rows whose image cannot be loaded or predicted are skipped (best effort),
    but each skip is reported instead of being dropped silently. The CSV
    always carries its header row, even when no row succeeded.
    """
    columns = ["filepath", "gt", "pred", "prob",
               "pred1", "prob1", "pred2", "prob2", "pred3", "prob3"]
    rows = []
    skipped = 0
    # Count rows positionally so progress output does not depend on the index labels.
    for n_seen, (_, r) in enumerate(df.iterrows(), start=1):
        fp, gt = r["filepath"], r["label"]
        try:
            img = load_rgb(fp)
            z = predict_dualT_logits(img, TEMP)
            p = F.softmax(z, dim=0).numpy()
        except Exception as exc:
            # Keep going on a bad file, but say which one and why.
            skipped += 1
            print(f"Skipped {fp}: {type(exc).__name__}: {exc}")
            continue
        top = np.argsort(-p)[:3]
        rows.append({
            "filepath": fp, "gt": gt,
            "pred": classes[top[0]], "prob": float(p[top[0]]),
            "pred1": classes[top[0]], "prob1": float(p[top[0]]),
            # Fewer than three classes: pad the missing ranks.
            "pred2": classes[top[1]] if len(top) > 1 else "",
            "prob2": float(p[top[1]]) if len(top) > 1 else 0.0,
            "pred3": classes[top[2]] if len(top) > 2 else "",
            "prob3": float(p[top[2]]) if len(top) > 2 else 0.0,
        })
        if n_seen % 500 == 0:
            print(f"{n_seen}/{len(df)}")
    if skipped:
        print(f"Skipped {skipped} of {len(df)} rows")
    pd.DataFrame(rows, columns=columns).to_csv(out_csv, index=False)
    print("Saved:", out_csv)

# Validation rows come straight from the curated CSV.
val_df = pd.read_csv(VAL_CSV)[["filepath", "label"]].reset_index(drop=True)

# Test rows: the parent folder name is the label; folders that are not in
# the label map are ignored.
rows = []
for p in list_images(TEST_ROOT):
    lbl = Path(p).parent.name
    if lbl not in label2idx:
        continue
    rows.append({"filepath": p, "label": lbl})
test_df = pd.DataFrame(rows).reset_index(drop=True)

# Cache the predictions for both splits as CSV files under OUT.
VAL_PRED = OUT / "preds_v6_val_dualT.csv"
TEST_PRED = OUT / "preds_v6_test_dualT.csv"
run_infer_cache(val_df, VAL_PRED)
run_infer_cache(test_df, TEST_PRED)

# Metrics/confusions
# Computed from the cached validation predictions only; the test predictions
# written above are not read here.
val=pd.read_csv(VAL_PRED)
# Local label space: every label that occurs as ground truth or as a prediction.
labels=sorted(pd.concat([val["gt"], val["pred"]]).unique())
lab2idx_local={l:i for i,l in enumerate(labels)}; idx2lab_local={i:l for l,i in lab2idx_local.items()}
# Confusion matrix: rows are ground truth, columns are predictions.
n=len(labels); cm=np.zeros((n,n), dtype=np.int64)
for t,pred in zip(val["gt"].map(lab2idx_local).to_numpy(), val["pred"].map(lab2idx_local).to_numpy()):
    cm[t,pred]+=1
support=cm.sum(1); tp=np.diag(cm); fp=cm.sum(0)-tp; fn=cm.sum(1)-tp
# Per-class scores; a zero denominator yields 0.0 instead of a division warning.
precision=np.divide(tp, tp+fp, out=np.zeros_like(tp,dtype=float), where=(tp+fp)>0)
recall   =np.divide(tp, tp+fn, out=np.zeros_like(tp,dtype=float), where=(tp+fn)>0)
f1=np.divide(2*precision*recall, precision+recall, out=np.zeros_like(tp,dtype=float), where=(precision+recall)>0)
# Per-class report, worst F1 first.
rep=pd.DataFrame({"label":[idx2lab_local[i] for i in range(n)],"support":support,"precision":precision,"recall":recall,"f1":f1}).sort_values("f1")
# Macro averages are unweighted means over every label in the local label space.
macro={"macro_precision":float(np.mean(precision)),"macro_recall":float(np.mean(recall)),"macro_f1":float(np.mean(f1)),"overall_acc":float(tp.sum()/cm.sum()) if cm.sum()>0 else 0.0}
# NOTE(review): index=False drops the row (ground-truth) labels from the file;
# only the column header carries label names — confirm readers rely on row order.
pd.DataFrame(cm, index=labels, columns=labels).to_csv(OUT/"confusion_matrix.csv", index=False)
rep.to_csv(OUT/"per_class_report.csv", index=False)
with open(OUT/"summary.json","w") as f: json.dump(macro, f, indent=2)

# Off-diagonal (true, pred) pairs: up to 50 per true class are collected, then
# the 50 largest overall are written.
pairs=[]
for i in range(cm.shape[0]):
    # Zero the diagonal on a copy so only misclassifications remain.
    row=cm[i].copy(); row[i]=0
    if row.sum()==0: continue
    for j in np.argsort(-row)[:min(50, cm.shape[1]-1)]:
        if row[j]>0: pairs.append({"true":labels[i],"pred":labels[j],"count":int(row[j])})
pairs.sort(key=lambda d: d["count"], reverse=True)
# NOTE(review): with no misclassification at all, pairs is empty and the CSV is
# written without a header — verify that downstream readers handle this.
pd.DataFrame(pairs[:50]).to_csv(OUT/"top_confusions_new.csv", index=False)
print("Saved metrics/confusions to:", str(OUT))

# CAM overlap (same as earlier)
# Grad-CAM target: the model's `conv_head` attribute when it exists, otherwise
# the last nn.Conv2d yielded by model.named_modules().
target_layer = getattr(model, "conv_head", None)
if target_layer is None:
    def get_module_by_name(model, name):
        """Follow a dotted module path starting from ``model``."""
        mod = model
        for part in name.split("."):
            mod = getattr(mod, part)
        return mod

    tname = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Keep overwriting so the final value is the last match.
            tname = name
    target_layer = get_module_by_name(model, tname)
print("CAM target:", target_layer.__class__.__name__)

def gradcam_on_image(img_bgr, cls_idx, target_layer):
    """Return a Grad-CAM heat map for class ``cls_idx`` on one BGR image.

    The image goes through ``preprocess_bgr_center_crop``; the activations of
    ``target_layer`` are weighted by the spatial mean of the gradient of the
    class score, summed over channels, rectified, resized to
    ``IM_SIZE x IM_SIZE`` and min-max normalised to [0, 1].

    The hooks are always removed, even when the forward or backward pass
    raises, so a failure cannot leave stale hooks on the model.
    """
    model.zero_grad(set_to_none=True)
    activations, gradients = [], []

    def fwd_hook(m, i, o):
        activations.append(o.detach())

    def bwd_hook(m, gi, go):
        gradients.append(go[0].detach())

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)
    try:
        x = preprocess_bgr_center_crop(img_bgr).unsqueeze(0).to(DEVICE)
        out = model(x)
        score = out[0, cls_idx]
        score.backward()
    finally:
        h1.remove()
        h2.remove()

    act = activations[-1][0]
    grad = gradients[-1][0]
    # Channel weights: gradient averaged over the two spatial axes.
    w = grad.mean(dim=(1, 2), keepdim=True)
    cam = (w * act).sum(0).cpu().numpy()
    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (IM_SIZE, IM_SIZE))
    # Min-max normalise: shift to zero, then divide by the remaining maximum.
    # (Dividing by the pre-shift maximum, as before, left the peak below 1
    # whenever the minimum was above 0.)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-6)
    return cam

def lesion_mask(img_bgr):
    """Return a binary ``IM_SIZE x IM_SIZE`` uint8 mask of edge regions.

    The mask is a heuristic built from the image alone: CLAHE on the first
    LAB channel, Canny edges, one 3x3 dilation, a 3x3 Gaussian blur, then
    every non-zero pixel is set to 1.
    """
    light = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)[:, :, 0]
    light = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(light)
    edges = cv2.Canny(light, 60, 120)
    # Pass the iteration count by keyword: the third positional parameter of
    # cv2.dilate is the destination array, not the number of iterations.
    edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
    edges = cv2.GaussianBlur(edges, (3, 3), 0)
    mask = (edges > 0).astype(np.uint8)
    mask = cv2.resize(mask, (IM_SIZE, IM_SIZE), interpolation=cv2.INTER_NEAREST)
    return mask

# For every class with validation rows: sample up to 16 images, score how much
# of the Grad-CAM mass falls inside lesion_mask, and save a panel of overlays.
preds = val.copy()
preds["correct"] = preds["gt"] == preds["pred"]
cls2idx = {c: i for i, c in enumerate(classes)}
PANEL = OUT / "cam_panels"
PANEL.mkdir(exist_ok=True)
rows = []
for lbl in classes:
    is_lbl = preds["gt"] == lbl
    if not is_lbl.any() or lbl not in cls2idx:
        continue
    # Fixed random_state keeps the sampled images the same between runs.
    sub = preds[is_lbl].sample(min(16, is_lbl.sum()), random_state=42)
    figs = []
    for _, r in sub.iterrows():
        fp = r["filepath"]
        try:
            img = load_rgb(fp)
        except Exception:
            continue
        # The CAM is taken for the ground-truth class, not the predicted one.
        cam = gradcam_on_image(img, cls2idx[lbl], target_layer)
        msk = lesion_mask(img)
        # Share of the CAM mass that lies on mask pixels.
        overlap = float((cam * (msk > 0)).sum() / (cam.sum() + 1e-6))
        rows.append({
            "label": lbl,
            "filepath": fp,
            "overlap": overlap,
            "pred": r["pred"],
            "correct": bool(r["correct"]),
        })
        jet = cv2.applyColorMap((cam * 255).astype(np.uint8), cv2.COLORMAP_JET)
        heat = (jet[:, :, ::-1]) / 255.0
        rgb = cv2.cvtColor(cv2.resize(img, (IM_SIZE, IM_SIZE)), cv2.COLOR_BGR2RGB) / 255.0
        figs.append(0.4 * heat + 0.6 * rgb)
    if figs:
        cols = 4
        rows_n = int(np.ceil(len(figs) / cols))
        plt.figure(figsize=(3 * cols, 3 * rows_n))
        for i, im in enumerate(figs, 1):
            plt.subplot(rows_n, cols, i)
            plt.imshow(im)
            plt.axis("off")
        plt.suptitle(f"Grad-CAM — {lbl}")
        plt.tight_layout()
        plt.savefig(PANEL / f"cam_{lbl}.png", dpi=200)
        plt.close()

pd.DataFrame(rows).to_csv(OUT / "cam_overlap_all_species.csv", index=False)
print("Saved CAM overlap/panels to:", str(OUT))

Compare vs previous

In [None]:
# compare_final_vs_prev.py
import json, pandas as pd, numpy as np
from pathlib import Path

# Previous iteration artifacts (from earlier run)
PREV = Path("/kaggle/working/v6_global_eval/new_eval")
# read_text() opens and closes the file itself, so no handle is left open.
prev_sum = json.loads((PREV / "summary.json").read_text())
prev_rep = pd.read_csv(PREV / "per_class_report.csv")
# The confusion list is optional for the previous run.
prev_conf = (
    pd.read_csv(PREV / "top_confusions_new.csv")
    if (PREV / "top_confusions_new.csv").exists()
    else None
)
# Fall back to the "-1" copy when the plain CAM overlap file is missing.
prev_cam = pd.read_csv(
    PREV / "cam_overlap_all_species.csv"
    if (PREV / "cam_overlap_all_species.csv").exists()
    else PREV / "cam_overlap_all_species-1.csv"
)

# Current final artifacts
CURR = Path("/kaggle/working/v6_global_eval/final_eval")
curr_sum = json.loads((CURR / "summary.json").read_text())
curr_rep = pd.read_csv(CURR / "per_class_report.csv")
curr_conf = pd.read_csv(CURR / "top_confusions_new.csv")
curr_cam = pd.read_csv(CURR / "cam_overlap_all_species.csv")

print("Prev macro:", prev_sum)
print("Curr macro:", curr_sum)

# Per-class F1 deltas
# Inner join on label: a class missing from either report is left out.
rep_join = pd.merge(prev_rep, curr_rep, on="label", suffixes=("_prev", "_curr"))
rep_join["d_f1"] = rep_join["f1_curr"] - rep_join["f1_prev"]
print("\nBottom 10 classes (prev) with F1 delta:")
# The ten classes with the lowest F1 in the previous run.
prev_bottom = prev_rep.sort_values("f1").head(10)["label"].tolist()
in_bottom = rep_join["label"].isin(prev_bottom)
print(rep_join.loc[in_bottom, ["label", "f1_prev", "f1_curr", "d_f1"]])

# Confusion deltas for overlapping pairs (if prev available)
if prev_conf is not None:
    pc = prev_conf.copy(); cc = curr_conf.copy()
    # Both tables must expose the (true, pred, count) schema before merging.
    for df in (pc, cc):
        if not {"true", "pred", "count"}.issubset(df.columns):
            raise RuntimeError("Confusion CSV must have true, pred, count columns")
    key = ["true", "pred"]
    # Outer join keeps pairs present in only one run; their missing count is 0.
    # NOTE(review): the confusion lists this notebook writes are truncated to
    # 50 pairs, so a 0 here may only mean the pair fell below that cutoff.
    conf_merge = pc.merge(cc, on=key, how="outer", suffixes=("_prev", "_curr")).fillna(0)
    conf_merge["count_prev"] = conf_merge["count_prev"].astype(int)
    conf_merge["count_curr"] = conf_merge["count_curr"].astype(int)
    conf_merge["d_count"] = conf_merge["count_curr"] - conf_merge["count_prev"]
    print("\nTop pairs by absolute change in confusions:")
    # Rank by |d_count| so the output matches the heading; the earlier signed
    # ascending sort listed only the largest reductions.
    by_abs_change = conf_merge.sort_values("d_count", key=lambda s: s.abs(), ascending=False)
    print(by_abs_change.head(12)[key + ["count_prev", "count_curr", "d_count"]])

# CAM overlap deltas (per class)
# Mean overlap per label in each run, then an inner join on label.
prev_cam_stat = prev_cam.groupby("label")["overlap"].mean().rename("mean_prev").reset_index()
curr_cam_stat = curr_cam.groupby("label")["overlap"].mean().rename("mean_curr").reset_index()
cam_join = pd.merge(prev_cam_stat, curr_cam_stat, on="label", how="inner")
cam_join["d_overlap"] = cam_join["mean_curr"] - cam_join["mean_prev"]

print("\nClasses with largest CAM overlap gains:")
print(cam_join.sort_values("d_overlap", ascending=False).head(10))

print("\nClasses with largest CAM overlap drops:")
print(cam_join.sort_values("d_overlap").head(10))

# Save reports
OUT = CURR / "comparison_final"
OUT.mkdir(parents=True, exist_ok=True)

# (frame, sort column, file name) for each report, in write order; the
# confusion delta exists only when the previous run had a confusion list.
_reports = [(rep_join, "d_f1", "per_class_f1_delta.csv")]
if prev_conf is not None:
    _reports.append((conf_merge, "d_count", "confusion_delta.csv"))
_reports.append((cam_join, "d_overlap", "cam_overlap_delta.csv"))

for _frame, _col, _fname in _reports:
    _frame.sort_values(_col).to_csv(OUT / _fname, index=False)
print("\nSaved comparison files to:", OUT)