## **1. Purpose of This Notebook**

After obtaining a robust Stage-6 EfficientNet-B3 baseline, this notebook focuses on:

* Selective class curation
* Fast adaptation to a reduced label space
* Targeted robustness improvements for difficult classes
* Final model selection

The goal here is not exploration, but controlled fine-tuning.

## **2. Experiment Overview**

**Fine-tune A**: Curate the dataset by removing noisy or weak classes

**Fine-tune B**: Targeted robustness fine-tuning with weighted sampling of focus classes

Both experiments reuse the same backbone weights and same training utilities.

## **3. Fine-Tune A — Class Curation & Label Space Reduction**
3.1 Motivation

Some classes in the original dataset:

* have inconsistent annotations
* are underrepresented
* are visually ambiguous
* or are not critical for the final objective

Removing them:

* reduces label noise
* simplifies the decision space
* improves per-class performance on retained classes

---

## **Tomato — Special Case Within Class Curation**

During the review phase, tomato classes stood out:

* They were **well-represented** in the dataset
* But also **highly confused among themselves** (blight vs. septoria vs. bacterial spot, etc.)
* CAM analysis showed attention drifting toward **background textures instead of lesions**
* UMAP revealed **overlapping feature clusters** for multiple tomato diseases

Because tomato is both **important** and **internally ambiguous**, it wasn’t a candidate for removal — instead, it required **focused refinement**.

Rather than retraining the whole network, we:

1. **Extracted only tomato classes** into a smaller label space.
2. **Loaded the global EfficientNet-B3 backbone** and kept its weights frozen.
3. **Trained only the classifier head** on tomato images.
4. Oversampled the four most problematic diseases:

   * tomato_bacterial_spot
   * tomato_septoria_leaf_spot
   * tomato_early_blight
   * tomato_late_blight
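
The effect of this oversampling can be estimated before training. With per-sample weights and sampling with replacement, the expected share of a class is its total weight divided by the total weight of all samples. The sketch below assumes a weight of 2.5 for focus classes (the value used in the tomato training cell later in this notebook); the class counts, the `healthy_tomato` label, and the `expected_share` helper are hypothetical and only serve as illustration.

```python
from collections import Counter

def expected_share(labels, focus, focus_weight=2.5, base_weight=1.0):
    """Expected fraction of draws per class under weighted sampling with replacement."""
    counts = Counter(labels)
    mass = {
        c: n * (focus_weight if c in focus else base_weight)
        for c, n in counts.items()
    }
    total = sum(mass.values())
    return {c: m / total for c, m in mass.items()}

# Hypothetical class counts, for illustration only
labels = ["tomato_early_blight"] * 100 + ["healthy_tomato"] * 300
share = expected_share(labels, focus={"tomato_early_blight"})
print({c: round(s, 3) for c, s in share.items()})
```

In this toy case the focus class makes up 25% of the images but roughly 45% of the sampled batches, which is the kind of shift to check before settling on a weight.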

This approach:

* preserved the backbone’s global knowledge
* allowed rapid specialization on subtle tomato lesion variations
* reduced cross-class confusion without harming other crops

We validated the refinement using:

* **Grad-CAM overlap** → attention shifted correctly toward diseased regions
* **UMAP embeddings** → tomato clusters separated more clearly
* **counterfactual edits** → lesion-enhancing transformations increased correct logits

The refined tomato head is then prepared for:

* merging its classifier rows back into the unified model, or
* deployment as a specialized sub-head for tomato-only routing.
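
The first option, merging classifier rows, might look like the sketch below. It assumes both heads are plain linear layers over the same frozen backbone features, so their weight matrices share a feature dimension and rows can be copied by label name. `merge_head_rows` and the toy label maps are hypothetical; in practice the merged logits may also need recalibration, since the two heads were trained on different label spaces.

```python
import numpy as np

def merge_head_rows(W_global, b_global, global_l2i, W_sub, b_sub, sub_l2i):
    """Return copies of the global head with the sub-head's rows written in by label name."""
    W, b = W_global.copy(), b_global.copy()
    for label, j in sub_l2i.items():
        i = global_l2i[label]  # raises KeyError if the label is missing globally
        W[i], b[i] = W_sub[j], b_sub[j]
    return W, b

rng = np.random.default_rng(0)
global_l2i = {"bean_rust": 0, "tomato_early_blight": 1, "tomato_late_blight": 2}
sub_l2i = {"tomato_early_blight": 0, "tomato_late_blight": 1}

W_g, b_g = rng.normal(size=(3, 8)), np.zeros(3)   # unified head (toy sizes)
W_s, b_s = rng.normal(size=(2, 8)), np.ones(2)    # refined tomato head

W_m, b_m = merge_head_rows(W_g, b_g, global_l2i, W_s, b_s, sub_l2i)
print(b_m)
```

Rows for non-tomato classes are left untouched, which is why the merge cannot directly change predictions among the other crops' classes relative to each other.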

3.2 Build Curated Metadata

In [None]:
# Build curated train/val splits after removing unwanted classes

import pandas as pd, json
from pathlib import Path

METADATA_CSV = "outputs/metadata/metadata.csv"

DROP_CLASSES = {
    "diseased_rice","guava_anthracnose","guava_fruit_fly",
    "healthy_corn","healthy_guava","healthy_rice","healthy_sugarcane","healthy_wheat",
    "sugarcane_bacterial_blight","sugarcane_mosaic","sugarcane_red_rot",
    "sugarcane_rust","sugarcane_yellow_leaf_disease",
    "wheat_aphid","wheat_black_rust","wheat_blast","wheat_brown_rust",
    "wheat_common_root_rot","wheat_fusarium_head_blight","wheat_leaf_blight",
    "wheat_mildew","wheat_mite","wheat_septoria","wheat_smut",
    "wheat_stem_fly","wheat_tan_spot","wheat_yellow_rust"
}

meta = pd.read_csv(METADATA_CSV)
curated = meta[~meta["label"].isin(DROP_CLASSES)].copy()

train_df = curated[curated["split"] == "train"]
val_df   = curated[curated["split"] == "val"]

classes = sorted(train_df["label"].unique())
label2idx = {c: i for i, c in enumerate(classes)}

OUT_DIR = Path("outputs/metadata")
OUT_DIR.mkdir(parents=True, exist_ok=True)

train_df.to_csv(OUT_DIR / "curated_train.csv", index=False)
val_df.to_csv(OUT_DIR / "curated_val.csv", index=False)
with open(OUT_DIR / "label2idx.json", "w") as f:
    json.dump(label2idx, f, indent=2)

print(f"Curated classes: {len(classes)}")


## **4. Fine-Tune A — Backbone Reuse + New Head**
4.1 Key Idea

* Reuse backbone weights from the best robust model
* Re-initialize the classifier head for the new class count
* Perform:
    * short head-only warm-up
    * followed by full fine-tuning with MixUp/CutMix
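
The two-phase schedule relies on two helpers, `set_requires_grad` and `get_param_groups`, that come from the shared training utilities and are not shown in this notebook. The sketch below is one plausible way to write them, assuming the head lives under a `classifier` attribute as in the EfficientNet-B3 model used here; `TinyNet` is a stand-in model for illustration only.

```python
import torch
from torch import nn

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Freeze (False) or unfreeze (True) every parameter of a module."""
    for p in module.parameters():
        p.requires_grad = flag

def get_param_groups(model, head_lr, backbone_lr, weight_decay):
    """Separate learning rates for the classifier head and the backbone."""
    head = [p for n, p in model.named_parameters() if n.startswith("classifier")]
    backbone = [p for n, p in model.named_parameters() if not n.startswith("classifier")]
    return [
        {"params": backbone, "lr": backbone_lr, "weight_decay": weight_decay},
        {"params": head, "lr": head_lr, "weight_decay": weight_decay},
    ]

class TinyNet(nn.Module):
    """Stand-in for the real model: a small backbone plus a `classifier` head."""
    def __init__(self, num_classes=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 4, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.classifier = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x))

model = TinyNet()

# Phase 1: head-only warm-up
set_requires_grad(model, False)
set_requires_grad(model.classifier, True)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)

# Phase 2: full fine-tune with a smaller learning rate on the backbone
set_requires_grad(model, True)
optimizer = torch.optim.AdamW(
    get_param_groups(model, head_lr=1.5e-4, backbone_lr=5e-5, weight_decay=1e-2)
)
```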

4.2 Curated Fine-Tune Training

In [None]:
# Load curated data
train_df = pd.read_csv("outputs/metadata/curated_train.csv")
val_df   = pd.read_csv("outputs/metadata/curated_val.csv")

with open("outputs/metadata/label2idx.json") as f:
    label2idx = json.load(f)

num_classes = len(label2idx)

train_ds = PlantDataset(train_df, get_train_transform_hardened(), label2idx)
val_ds   = PlantDataset(val_df,   get_val_transform(),           label2idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=64)

# Initialize model
model = create_efficientnet_b3(num_classes).to(DEVICE)

# Load backbone weights only
ckpt = torch.load("outputs/checkpoints/effb3_320_std_domain_mix_v2.pt", map_location="cpu")
state = ckpt["model"]
model_state = model.state_dict()
model_state.update({k: v for k, v in state.items() if not k.startswith("classifier")})
model.load_state_dict(model_state, strict=False)

criterion = LabelSmoothingCE(0.1)
scaler = torch.amp.GradScaler("cuda")


Phase A — Head Warm-Up

In [None]:
set_requires_grad(model, False)
set_requires_grad(model.classifier, True)

optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=3e-4)

for ep in range(2):
    tr = train_one_epoch(model, train_loader, optimizer, scaler, criterion, DEVICE)
    va_loss, va_acc, va_f1, va_top3, _ = validate(model, val_loader, criterion, DEVICE)
    print(f"[Warmup {ep+1}] val loss {va_loss:.4f}")


Phase B — Full Fine-Tune

In [None]:
set_requires_grad(model, True)

optimizer = torch.optim.AdamW(
    get_param_groups(model, head_lr=1.5e-4, backbone_lr=5e-5, weight_decay=1e-2)
)

best_val = float("inf")
best_state = None

for ep in range(8):
    tr = train_one_epoch_mix(
        model, train_loader, optimizer, scaler, DEVICE,
        mixup_alpha=0.3, cutmix_alpha=0.3
    )
    va_loss, va_acc, va_f1, va_top3, va_cm = validate(model, val_loader, criterion, DEVICE)

    if va_loss < best_val:
        best_val = va_loss
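        # Copy the tensors: state_dict() returns live references that later epochs overwrite
        best_state = {"model": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
                      "label2idx": label2idx}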

torch.save(best_state, "outputs/checkpoints/effb3_320_curated.pt")


## **5. Fine-Tune B — Targeted Robustness Training**
5.1 Motivation

Some classes are consistently:

* confused with visually similar diseases
* more variable across lighting and backgrounds

Instead of re-balancing globally, we target these classes explicitly.

5.2 Strategy

* Use WeightedRandomSampler to up-sample focus classes
* Apply hardened augmentations
* Use milder MixUp/CutMix
* Very short fine-tuning (avoid overfitting)
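
"Milder" MixUp here means a smaller `alpha`: the mixing coefficient is drawn from Beta(alpha, alpha), which concentrates near 0 and 1 as alpha shrinks, so most samples stay close to one of the two originals. The NumPy sketch below illustrates the idea; it is not the notebook's `train_one_epoch_mix`, whose implementation is not shown, and the function names are assumptions.

```python
import numpy as np

def mixup_batch(x, y, alpha, rng):
    """Blend each sample with a randomly chosen partner from the same batch."""
    lam = float(rng.beta(alpha, alpha)) if alpha > 0 else 1.0
    perm = rng.permutation(len(x))
    x_mix = lam * x + (1.0 - lam) * x[perm]
    return x_mix, y, y[perm], lam

def mixup_loss(loss_fn, logits, y_a, y_b, lam):
    """The loss is mixed with the same coefficient as the inputs."""
    return lam * loss_fn(logits, y_a) + (1.0 - lam) * loss_fn(logits, y_b)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8, 8))   # toy batch of 4 "images"
y = np.array([0, 1, 2, 1])

x_mix, y_a, y_b, lam = mixup_batch(x, y, alpha=0.2, rng=rng)
print(x_mix.shape, round(lam, 3))
```

CutMix follows the same pattern, except that a rectangular patch is swapped between the two images and `lam` is set to the fraction of pixels kept from the original.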

5.3 Targeted Fine-Tune

In [None]:
import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler

FOCUS_CLASSES = {
    "corn_common_rust","bean_angular_leaf_spot","bean_rust",
    "diseased_cucumber","healthy_cucumber","healthy_bean",
    "healthy_groundnut","healthy_pumpkin",
    "pumpkin_bacterial_leaf_spot","pumpkin_downy_mildew",
    "pumpkin_mosaic_disease","pumpkin_powdery_mildew"
}

weights = np.array([
    1.5 if lbl in FOCUS_CLASSES else 1.0
    for lbl in train_df["label"]
], dtype=np.float32)

sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

train_loader = DataLoader(train_ds, batch_size=32, sampler=sampler)
val_loader   = DataLoader(val_ds, batch_size=64)

# Resume from curated checkpoint
model = create_efficientnet_b3(num_classes).to(DEVICE)
ckpt = torch.load("outputs/checkpoints/effb3_320_curated.pt", map_location="cpu")
model.load_state_dict(ckpt["model"], strict=True)


In [None]:
# Short targeted fine-tune
optimizer = torch.optim.AdamW(
    get_param_groups(model, head_lr=1.5e-4, backbone_lr=5e-5, weight_decay=1e-2)
)

best_val = float("inf")

for ep in range(4):
    tr = train_one_epoch_mix(
        model, train_loader, optimizer, scaler, DEVICE,
        mixup_alpha=0.2, cutmix_alpha=0.2
    )
    va_loss, va_acc, va_f1, va_top3, va_cm = validate(model, val_loader, criterion, DEVICE)

    if va_loss < best_val:
        best_val = va_loss
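        # Copy the tensors: state_dict() returns live references that later epochs overwrite
        best_state = {"model": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
                      "label2idx": label2idx}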

torch.save(best_state, "outputs/checkpoints/effb3_320_curated_hardened.pt")


## **Section 4 — Cotton-Specific Investigation & Decision**

This section documents a focused investigation into cotton classes, which were found to negatively impact overall model performance despite multiple corrective attempts.

### 4.1 Motivation: Why Cotton Was Investigated Separately

During validation analysis of the curated and hardened models, cotton classes consistently showed:

* unstable predictions
* reduced generalization
* negative spillover effects on non-cotton crops

Initial inspection suggested that **many cotton images appeared black-and-white**, leading to the hypothesis that **color-based pretrained features were mismatched** for this crop.

### 4.2 Hypothesis 1 — Cotton Is Mostly Grayscale → Needs Special Handling

Assumption

If cotton images are primarily grayscale, a cotton-only fine-tune with dedicated preprocessing may help the model adapt.

This motivated:

* a **cotton-only audit**
* removal of the worst cotton class (`cotton_aphids`)
* grayscale / quality filtering
* cotton-focused fine-tuning

### 4.3 Cotton Audit & Cleaning (Evidence-Driven)

What Was Checked

For **train and validation splits**, cotton images were audited for:

* grayscale images
* black borders / letterboxing
* over-exposure
* extreme quality issues
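
Once each image carries boolean audit flags, the grayscale hypothesis can be tested with a simple per-class rate table. The sketch below uses a tiny hand-made frame with the same column names as the audit cell later in this section; the rows are invented for illustration and say nothing about the real dataset.

```python
import pandas as pd

# Hypothetical audit rows, same columns as the audit CSVs
audit = pd.DataFrame({
    "label": ["healthy_cotton", "healthy_cotton",
              "cotton_target_spot", "cotton_target_spot"],
    "grayscale":   [True,  False, False, False],
    "black_bars":  [False, False, True,  False],
    "overexposed": [False, True,  False, False],
})

flags = ["grayscale", "black_bars", "overexposed"]

# Share of images per class that trigger each flag
rates = audit.groupby("label")[flags].mean()

# Share of images with at least one quality issue
any_issue = audit[flags].any(axis=1).mean()

print(rates)
print("any issue:", any_issue)
```

A grayscale rate well below 1.0 for every class, combined with a high "any issue" rate, is the pattern that would contradict a grayscale-only explanation.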

#### Key Finding (Very Important)

> Cotton images were **not consistently grayscale**.
> Instead, they exhibited **mixed color spaces, heavy degradation, and poor acquisition quality**.

This invalidated the original hypothesis.

#### Outcome

* Significant portion of cotton data removed
* Remaining data still **highly inconsistent**
* Visual quality far below other crops

### 4.4 Experiment 1 — Replace Cotton with Cleaned Subset

#### Strategy

* Keep all non-cotton data unchanged
* Replace cotton rows with cleaned cotton images
* Reuse **curated hardened backbone**
* Apply **mild MixUp/CutMix**
* Short fine-tuning with early stopping

#### Intent

> Test whether *partial cotton cleanup* can stop performance degradation.

### 4.5 Experiment 2 — Cotton-Focused Fine-Tune (Final Attempt)

#### Key Design Choices

* Resume from **best curated hardened model**
* Use **hardened augmentations**
* Very **mild MixUp/CutMix**
* Conservative learning rates
* Early stopping
* Weighted sampling to stabilize cotton
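
The training cells in this notebook run a fixed number of epochs and keep the best checkpoint by validation loss. A patience-based early-stopping rule, as named in the design choices above, could be sketched as follows; `EarlyStopper` and the loss values are hypothetical.

```python
class EarlyStopper:
    """Stop when validation loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=2)
history = [0.90, 0.80, 0.82, 0.81, 0.79]  # hypothetical validation losses

stopped_at = None
for epoch, loss in enumerate(history, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # 4 0.8
```

With oscillating validation loss, a small patience stops early and misses the later improvement at epoch 5, so the patience value trades compute against the risk of stopping on noise.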

### 4.6 Results & Failure Analysis

The cotton fine-tunes failed to stabilize despite:

* cotton-only cleaning
* grayscale handling
* class removal
* hardened augmentations
* weighted sampling
* conservative fine-tuning

#### Observed Behavior

* Validation loss oscillated
* Confusion increased in **non-cotton classes**
* Generalization degraded
* Cotton features did not transfer reliably
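
Spillover into non-cotton classes can be quantified by comparing per-class recall before and after a cotton fine-tune. The sketch below assumes confusion matrices with true labels on rows and predictions on columns (the layout `validate` is assumed to return); the matrices and helper names are invented for illustration.

```python
import numpy as np

def per_class_recall(cm):
    """Diagonal over row sums; classes without support get 0."""
    cm = np.asarray(cm, dtype=np.float64)
    support = cm.sum(axis=1)
    return np.divide(np.diag(cm), support, out=np.zeros_like(support), where=support > 0)

def recall_delta(cm_before, cm_after, classes, exclude_substr="cotton"):
    """Recall change per class, skipping classes whose name contains `exclude_substr`."""
    delta = per_class_recall(cm_after) - per_class_recall(cm_before)
    return {c: float(d) for c, d in zip(classes, delta) if exclude_substr not in c}

classes = ["bean_rust", "cotton_target_spot", "healthy_bean"]
cm_before = [[9, 0, 1], [0, 5, 5], [1, 0, 9]]   # hypothetical counts
cm_after  = [[7, 2, 1], [0, 6, 4], [1, 1, 8]]

delta = recall_delta(cm_before, cm_after, classes)
print({c: round(d, 2) for c, d in delta.items()})
```

Negative deltas on non-cotton classes while cotton recall rises are the signature of negative transfer described in this section.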

#### Root Cause (Key Insight)

> Cotton data quality was **systematically different** from other crops, not just grayscale.
> Many images were:

* poorly exposed
* low resolution
* inconsistently labeled
* captured under non-representative conditions

This caused **negative transfer**.

---

### 4.7 Final Decision: Remove Cotton Entirely

#### Decision Rationale

* Cotton classes were **hurting the model**
* Fixing them required **dataset-level intervention**, not modeling tricks
* Retaining cotton would reduce trustworthiness of the system

#### Final Action

> **All cotton classes were removed from the final training set.**

This decision was:

* evidence-based
* validated through multiple experiments
* aligned with production ML best practices

4.1 Cotton Audit & Cleaning

In [None]:
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

TRAIN_ROOT = Path("/kaggle/input/plant-disease-detection-dataset-master-version/MasterDataset/train")
VAL_ROOT   = Path("/kaggle/input/plant-disease-detection-dataset-master-version/MasterDataset/val")

COTTON_KEEP = [
    "cotton_bacterial_blight",
    "cotton_powdery_mildew",
    "cotton_target_spot",
    "healthy_cotton",
]
COTTON_DROP = {"cotton_aphids"}

EXTS = {".jpg",".jpeg",".png",".bmp",".webp"}

def is_grayscale(img, tol=1):
    b,g,r = cv2.split(img)
    return (
        np.max(np.abs(r.astype(int)-g.astype(int))) <= tol and
        np.max(np.abs(r.astype(int)-b.astype(int))) <= tol
    )

def has_black_bars(img, thr=5, frac=0.95):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return (
        (gray[0,:] < thr).mean() > frac or
        (gray[-1,:] < thr).mean() > frac or
        (gray[:,0] < thr).mean() > frac or
        (gray[:,-1] < thr).mean() > frac
    )

def is_overexposed(img, v_thr=245, frac=0.15):
    v = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)[...,2]
    return (v > v_thr).mean() > frac

def scan_split(root):
    rows = []
    for cls in COTTON_KEEP:
        d = root / cls
        if not d.exists(): 
            continue
        for p in d.rglob("*"):
            if p.suffix.lower() not in EXTS: 
                continue
            img = cv2.imread(str(p))
            if img is None: 
                continue
            rows.append({
                "filepath": str(p),
                "label": cls,
                "grayscale": is_grayscale(img),
                "black_bars": has_black_bars(img),
                "overexposed": is_overexposed(img)
            })
    return pd.DataFrame(rows)

train_audit = scan_split(TRAIN_ROOT)
val_audit   = scan_split(VAL_ROOT)

train_audit.to_csv("cotton_audit_train.csv", index=False)
val_audit.to_csv("cotton_audit_val.csv", index=False)

print("Train audit rows:", len(train_audit))
print("Val audit rows:", len(val_audit))


4.2 Clean Cotton Subset

In [None]:
def clean_cotton(df):
    # Drop images flagged as grayscale or over-exposed.
    # The black_bars flag is recorded in the audit but is not used as a filter here.
    bad = df["grayscale"].astype(bool) | df["overexposed"].astype(bool)
    return df.loc[~bad].reset_index(drop=True)

clean_train = clean_cotton(train_audit)
clean_val   = clean_cotton(val_audit)

clean_train[["filepath","label"]].to_csv("cotton_clean_train.csv", index=False)
clean_val[["filepath","label"]].to_csv("cotton_clean_val.csv", index=False)

print("Clean train:", len(clean_train))
print("Clean val:", len(clean_val))


4.3 Replace Cotton Rows in Curated Dataset

In [None]:
META_DIR = Path("outputs/metadata")

train_all = pd.read_csv(META_DIR / "curated_train.csv")
val_all   = pd.read_csv(META_DIR / "curated_val.csv")

cot_tr = pd.read_csv("cotton_clean_train.csv")
cot_va = pd.read_csv("cotton_clean_val.csv")

# Match every cotton label, including "healthy_cotton", which startswith("cotton") would miss
train_final = pd.concat([
    train_all[~train_all["label"].str.contains("cotton")],
    cot_tr
], ignore_index=True)

val_final = pd.concat([
    val_all[~val_all["label"].str.contains("cotton")],
    cot_va
], ignore_index=True)

train_final.to_csv(META_DIR / "curated_train_cotton_fix.csv", index=False)
val_final.to_csv(META_DIR / "curated_val_cotton_fix.csv", index=False)

print("Final train:", len(train_final))
print("Final val:", len(val_final))


4.4 Cotton-Focused Fine-Tune

In [None]:
import torch
from torch.utils.data import DataLoader

CKPT_DIR = Path("outputs/checkpoints")
BASE_CKPT = CKPT_DIR / "effb3_320_curated_hardened.pt"

train_df = pd.read_csv(META_DIR / "curated_train_cotton_fix.csv")
val_df   = pd.read_csv(META_DIR / "curated_val_cotton_fix.csv")

label2idx = torch.load(BASE_CKPT, map_location="cpu")["label2idx"]

train_ds = PlantDataset(train_df, get_train_transform_hardened(), label2idx)
val_ds   = PlantDataset(val_df,   get_val_transform(),           label2idx)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=64)

model = create_efficientnet_b3(len(label2idx)).to(DEVICE)
model.load_state_dict(torch.load(BASE_CKPT, map_location="cpu")["model"], strict=True)

criterion = LabelSmoothingCE(0.1)
scaler = torch.amp.GradScaler("cuda")

# Warmup
set_requires_grad(model, False)
set_requires_grad(model.classifier, True)
opt = torch.optim.AdamW(model.classifier.parameters(), lr=3e-4)

train_one_epoch(model, train_loader, opt, scaler, criterion, DEVICE)

# Fine-tune
set_requires_grad(model, True)
opt = torch.optim.AdamW(
    get_param_groups(model, 1.5e-4, 5e-5, 1e-2)
)

best = None
best_loss = float("inf")

for ep in range(4):
    tr = train_one_epoch_mix(
        model, train_loader, opt, scaler, DEVICE,
        mixup_alpha=0.12, cutmix_alpha=0.12
    )
    va_loss, va_acc, va_f1, va_top3, _ = validate(model, val_loader, criterion, DEVICE)

    print(f"[Cotton FT {ep+1}] val loss {va_loss:.4f}")

    if va_loss < best_loss:
        best_loss = va_loss
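        # Copy the tensors: state_dict() returns live references that later epochs overwrite
        best = {"model": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
                "label2idx": label2idx}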

torch.save(best, CKPT_DIR / "effb3_320_curated_cotton_fix.pt")

# **Section 5: Tomato Head-Only Refinement — Full Experiment**

Setup & Data Filters (Tomato-only splits)

In [None]:
import os, json, torch, timm, numpy as np, pandas as pd
from pathlib import Path
import cv2
import torch.nn.functional as F
from torch import nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IM_SIZE = 320
IM_MEAN = [0.485, 0.456, 0.406]
IM_STD  = [0.229, 0.224, 0.225]
TEMP = 0.551

TRAIN_CSV = "/kaggle/working/outputs/metadata/curated_train_v6.csv"
VAL_CSV   = "/kaggle/working/outputs/metadata/curated_val_v6.csv"
MAP_JSON  = "/kaggle/working/outputs/metadata/label2idx_v6.json"
CKPT_PATH = "/kaggle/input/k/adiithape/cnn-model-v3/outputs/checkpoints/effb3_320_curated_hardened_v7.pt"

train_df = pd.read_csv(TRAIN_CSV)
val_df   = pd.read_csv(VAL_CSV)
with open(MAP_JSON, "r") as f: label2idx = json.load(f)

def is_tomato(lbl): return isinstance(lbl, str) and lbl.startswith("tomato_")

train_tom = train_df[train_df["label"].apply(is_tomato)].reset_index(drop=True)
val_tom   = val_df[val_df["label"].apply(is_tomato)].reset_index(drop=True)

print("Tomato train / val:", len(train_tom), len(val_tom))

Augmentations + Dataset + Loaders

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

def get_tomato_train_tf():
    return A.Compose([
        A.RandomResizedCrop((IM_SIZE,IM_SIZE), (0.85,1.0), (0.9,1.1), cv2.INTER_AREA, p=1),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(0.02,0.05,20, border_mode=cv2.BORDER_REFLECT_101, p=0.6),
        A.ColorJitter(0.12,0.12,0.12,0.05,p=0.6),
        A.CLAHE(2.0,(8,8),p=0.3),
        A.GaussianBlur(3,p=0.08),
        A.ImageCompression(60,95,p=0.25),
        A.Normalize(IM_MEAN,IM_STD),
        ToTensorV2(),
    ])

def get_val_tf():
    return A.Compose([
        A.LongestMaxSize(IM_SIZE),
        A.PadIfNeeded(IM_SIZE,IM_SIZE),
        A.Normalize(IM_MEAN,IM_STD),
        ToTensorV2(),
    ])

class TomatoDS(Dataset):
    def __init__(self, df, l2i, tf):
        self.df=df.reset_index(drop=True); self.l2i=l2i; self.tf=tf
    def __len__(self): return len(self.df)
    def __getitem__(self,i):
        r=self.df.iloc[i]
        img=cv2.imread(r["filepath"]); img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        return self.tf(image=img)["image"], self.l2i[r["label"]]

tom_labels = sorted(pd.concat([train_tom["label"], val_tom["label"]]).unique())
l2i_tom = {l:i for i,l in enumerate(tom_labels)}

train_ds = TomatoDS(train_tom, l2i_tom, get_tomato_train_tf())
val_ds   = TomatoDS(val_tom,   l2i_tom, get_val_tf())

w = np.where(train_tom["label"].isin({
    "tomato_bacterial_spot","tomato_septoria_leaf_spot",
    "tomato_early_blight","tomato_late_blight"
}),2.5,1.0)

sampler = WeightedRandomSampler(w, len(w), replacement=True)

train_loader = DataLoader(train_ds,32,sampler=sampler)
val_loader   = DataLoader(val_ds,64,shuffle=False)

print("Tomato labels:", tom_labels)

Train Tomato-only head (freeze backbone)

In [None]:
import torch.nn as nn
from copy import deepcopy

model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=len(l2i_tom)).to(DEVICE)
base = torch.load(CKPT_PATH,map_location="cpu")
state = base["model"] if isinstance(base,dict) and "model" in base else base

sd=model.state_dict(); loaded=0
for k,v in state.items():
    if not k.startswith("classifier") and k in sd and sd[k].shape==v.shape:
        sd[k]=v; loaded+=1
model.load_state_dict(sd, strict=False)
print("Backbone tensors loaded:", loaded)

criterion = nn.CrossEntropyLoss(label_smoothing=0.06)
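# Head-only refinement: freeze every backbone parameter and optimize the classifier alone.
# Note: BatchNorm running statistics still update while model.train() is active.
for n, p in model.named_parameters():
    p.requires_grad = n.startswith("classifier")

opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1.5e-4)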

best = {"loss":1e9,"state":deepcopy(model.state_dict())}

def run_epoch(loader):
    model.train(); total=0
    for x,y in loader:
        x,y=x.to(DEVICE), y.to(DEVICE)
        opt.zero_grad(); loss=criterion(model(x),y)
        loss.backward(); opt.step()
        total+=loss.item()*x.size(0)
    return total/len(loader.dataset)

@torch.no_grad()
def evaluate(loader):
    model.eval(); tot=0; hit=0
    for x,y in loader:
        x,y=x.to(DEVICE),y.to(DEVICE)
        z=model(x); tot+=criterion(z,y).item()*x.size(0)
        hit+=(z.argmax(1)==y).sum().item()
    return tot/len(loader.dataset), hit/len(loader.dataset)

for ep in range(2):
    tr = run_epoch(train_loader)
    vl, acc = evaluate(val_loader)
    print(f"ep {ep+1} | train {tr:.4f} | val {vl:.4f} | acc {acc:.4f}")
    if vl < best["loss"]:
        best = {"loss":vl,"state":deepcopy(model.state_dict())}

model.load_state_dict(best["state"])
OUT = Path("/kaggle/working/tomato_refine"); OUT.mkdir(exist_ok=True)
TOM_CKPT = OUT/"effb3_tomato_refined.pt"
torch.save({"model":model.state_dict(),"label2idx":l2i_tom}, TOM_CKPT)
print("Saved:", TOM_CKPT)

Quick results + CAM sanity check

In [None]:
def top1(csv):
    # Top-1 accuracy from a predictions CSV with "pred" and "gt" columns
    df = pd.read_csv(csv)
    return (df["pred"] == df["gt"]).mean()
for T in [0.5,0.6,0.7]:
    print(T, top1(OUT/f"tomato_val_refined_T{T}.csv"))

In [None]:
from torch import nn
def get_last_conv(m):
    last=None
    for _,mod in m.named_modules():
        if isinstance(mod, nn.Conv2d): last=mod
    return last

def lesion_mask(img):
    L=cv2.createCLAHE(2.0,(8,8)).apply(cv2.cvtColor(img,cv2.COLOR_BGR2LAB)[:,:,0])
    e=cv2.Canny(L,60,120); e=cv2.dilate(e,np.ones((3,3),np.uint8),1)
    return cv2.resize((e>0).astype(np.uint8),(IM_SIZE,IM_SIZE))

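# Assumption: the refined tomato model trained above is the one being inspected
model_t = model.eval()
target = get_last_conv(model_t)

def preprocess_bgr(img):
    # Resize exactly like lesion_mask so the CAM and the mask share one geometry
    rgb = cv2.cvtColor(cv2.resize(img, (IM_SIZE, IM_SIZE)), cv2.COLOR_BGR2RGB)
    arr = (rgb.astype(np.float32) / 255.0 - np.array(IM_MEAN, np.float32)) / np.array(IM_STD, np.float32)
    return torch.from_numpy(arr.transpose(2, 0, 1).copy())

def gradcam(img, cls):
    acts = []; grads = []
    def f(m, i, o): acts.append(o.detach())
    def b(m, gi, go): grads.append(go[0].detach())
    h1 = target.register_forward_hook(f); h2 = target.register_full_backward_hook(b)
    # requires_grad on the input keeps gradients flowing even if the backbone is frozen
    x = preprocess_bgr(img).unsqueeze(0).to(DEVICE).requires_grad_(True)
    model_t.zero_grad(); out = model_t(x); out[0, cls].backward()
    h1.remove(); h2.remove()
    a = acts[-1][0]; g = grads[-1][0]; w = g.mean((1, 2), True)
    cam = (w * a).sum(0).cpu().numpy(); cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (IM_SIZE, IM_SIZE))
    # Min-max normalize to [0, 1]
    return (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)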

rows=[]
focus = [
 "tomato_bacterial_spot","tomato_septoria_leaf_spot",
 "tomato_early_blight","tomato_late_blight"
]

for lbl in focus:
    sub=val_tom[val_tom["label"]==lbl]
    sub=sub.sample(min(12,len(sub)),random_state=42)  # cap by this class's size, not the whole split
    for _,r in sub.iterrows():
        img=cv2.imread(r["filepath"])
        m=lesion_mask(img); c=gradcam(img,l2i_tom[lbl])
        rows.append({"label":lbl,"filepath":r["filepath"],
                     "overlap":float((c*(m>0)).sum()/(c.sum()+1e-6))})

pd.DataFrame(rows).to_csv(OUT/"tomato_cam_overlap_refined.csv", index=False)

UMAP

In [None]:
# ===== Tomato UMAP — robust, self-contained =====
import numpy as np, pandas as pd, torch, cv2, importlib
import matplotlib.pyplot as plt
from tqdm import tqdm

# 0) Safety — remove any weird sklearn / UMAP monkey-patches
import sklearn.utils.validation as suv
import umap as umap_module
importlib.reload(suv)
importlib.reload(umap_module)
import umap  # clean handle

# 1) Safe penultimate feature extraction (no AMP, temporary hook)
@torch.no_grad()
def penultimate_feature(img_bgr: np.ndarray) -> np.ndarray:
    feats = []

    def hook(module, inp, out):
        feats.append(out.detach().cpu())

    handle = model.global_pool.register_forward_hook(hook)

    x = preprocess_bgr_center_crop(img_bgr).unsqueeze(0).to(DEVICE).float()
    model.eval()(x)

    handle.remove()
    assert len(feats) == 1, "Feature hook failed."

    v = feats[0].squeeze(0).numpy().astype(np.float32)

    # Defensive cleanup
    if not np.all(np.isfinite(v)):
        v = np.nan_to_num(v, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

    return v

# 2) Balanced selection + clean matrix builder
def extract_features_clean(df: pd.DataFrame, per_class_cap=200, random_state=42):
    parts = []
    for lbl, grp in df.groupby("label"):
        parts.append(grp.sample(min(per_class_cap, len(grp)), random_state=random_state))
    sub = pd.concat(parts, ignore_index=True)

    feats, labels, files = [], [], []

    for _, r in tqdm(sub.iterrows(), total=len(sub), desc="Extract tomato feats"):
        fp = r["filepath"]
        img = cv2.imread(fp)  # BGR, which is what penultimate_feature expects
        v = penultimate_feature(img)
        feats.append(v)
        labels.append(r["label"])
        files.append(fp)

    X = np.vstack(feats).astype(np.float32)
    y = np.array(labels)
    fpaths = np.array(files)

    # Final cleanup
    X[~np.isfinite(X)] = 0.0
    return X, y, fpaths

# 3) Extract tomato validation embeddings
X, y, fpaths = extract_features_clean(val_tom, per_class_cap=200)

# 4) UMAP embedding (cosine)
reducer = umap.UMAP(
    n_neighbors=15,
    min_dist=0.1,
    metric="cosine",
    random_state=42
)
U = reducer.fit_transform(X)

# 5) Plot & save
plot_df = pd.DataFrame({
    "x": U[:,0],
    "y": U[:,1],
    "label": y,
    "filepath": fpaths
})

plt.figure(figsize=(7,6))
for lbl in sorted(pd.unique(plot_df["label"])):
    pts = plot_df[plot_df["label"] == lbl]
    plt.scatter(pts["x"], pts["y"], s=10, label=lbl, alpha=0.75)

plt.legend(markerscale=2, fontsize=8)
plt.title("UMAP — Tomato diseases (val)")
plt.tight_layout()

png_path = OUT / "tomato_umap_val.png"
plt.savefig(png_path, dpi=200)
plt.close()
print("Saved UMAP:", png_path)

# 6) Save artifacts next to the other tomato outputs
np.save(OUT/"tomato_X.npy", X)
pd.Series(y).to_csv(OUT/"tomato_y.csv", index=False, header=False)
pd.Series(fpaths).to_csv(OUT/"tomato_files.csv", index=False, header=False)