In [46]:
import os
import shutil
import random
from pathlib import Path
from PIL import Image

import torch
from torchvision import models, transforms
import torch.nn.functional as F
from tqdm import tqdm

In [48]:
# -----------------------
# Paths and parameters
# -----------------------
original_base    = Path(r"C:/plantvillage")
train_src        = original_base / "train"
val_src          = original_base / "val"

clean_root       = Path(r"D:/cleaned_plantvillage")  # New cleaned copy root
clean_train_src  = clean_root / "train"
clean_val_src    = clean_root / "val"

fewshot_base     = Path(r"D:/fewshot_dataset")
fewshot_train    = fewshot_base / "train"
fewshot_val      = fewshot_base / "val"

similarity_threshold = 0.7   # keep classes with avg_similarity >= 0.7
k_shot               = 10    # support samples per class
val_shot             = 10    # query samples per class

# -----------------------
# Setup device, model, transforms
# -----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])


In [50]:
# EfficientNet-B0 backbone without classifier head
backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
backbone.classifier = torch.nn.Identity()
backbone = backbone.to(device).eval()

def get_embeddings(image_paths):
    embs = []
    for p in image_paths:
        img = Image.open(p).convert("RGB")
        x = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            feat = backbone(x)
        feat = F.normalize(feat, p=2, dim=1)
        embs.append(feat[0].cpu())
    return torch.stack(embs)


In [52]:
# -----------------------
# 0) Prepare cleaned copy of the dataset (no deletions on original)
# -----------------------
if clean_root.exists():
    shutil.rmtree(clean_root)
clean_train_src.mkdir(parents=True, exist_ok=True)
clean_val_src.mkdir(parents=True, exist_ok=True)

# Copy original train and val folders into cleaned directory
for cls_dir in sorted(train_src.iterdir()):
    if cls_dir.is_dir():
        shutil.copytree(cls_dir, clean_train_src / cls_dir.name)
for cls_dir in sorted(val_src.iterdir()):
    if cls_dir.is_dir():
        shutil.copytree(cls_dir, clean_val_src / cls_dir.name)


In [54]:
# -----------------------
# 1) Cleaning: delete classes with low intra-class similarity
# -----------------------
kept_classes = []
for cls_dir in tqdm(sorted(clean_train_src.iterdir()), desc="Cleaning classes"):
    if not cls_dir.is_dir():
        continue

    imgs = [p for p in cls_dir.iterdir() if p.suffix.lower() in (".jpg",".jpeg",".png")]
    if len(imgs) < 2:
        continue

    emb = get_embeddings(imgs)
    sim_matrix = emb @ emb.t()
    n = sim_matrix.size(0)
    sims = sim_matrix[~torch.eye(n, dtype=torch.bool)].view(n, n-1)
    avg_sim = sims.mean().item()

    if avg_sim < similarity_threshold:
        # delete from cleaned train and val
        shutil.rmtree(cls_dir)
        val_cls = clean_val_src / cls_dir.name
        if val_cls.exists():
            shutil.rmtree(val_cls)
    else:
        kept_classes.append(cls_dir.name)

# -----------------------
# 1.1) Rebuild cleaned val to contain only original val images of kept classes
# -----------------------
if clean_val_src.exists():
    shutil.rmtree(clean_val_src)
clean_val_src.mkdir(parents=True, exist_ok=True)
for cls in kept_classes:
    src = val_src / cls
    if src.exists():
        shutil.copytree(src, clean_val_src / cls)

# Show remaining classes after cleaning
print("\nRemaining classes after cleaning:")
for name in kept_classes:
    print(f"  {name}")


Cleaning classes: 100%|█████████████████████████████████████████████████████████████| 38/38 [2:20:59<00:00, 222.62s/it]



Remaining classes after cleaning:
  Apple___Cedar_apple_rust
  Blueberry___healthy
  Cherry_(including_sour)___healthy
  Grape___Black_rot
  Grape___Esca_(Black_Measles)
  Grape___healthy
  Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
  Potato___healthy
  Raspberry___healthy
  Soybean___healthy
  Strawberry___healthy
  Tomato___Tomato_mosaic_virus


In [58]:
import os

# Set the dataset directory
dataset_dir = r"D:/cleaned_plantvillage"

# Function to count images in a directory (recursively)
def count_images(root_dir, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif')):
    total = 0
    for dirpath, dirnames, filenames in os.walk(root_dir):
        total += sum(1 for file in filenames if file.lower().endswith(extensions))
    return total

# Count total images in dataset
total_images = count_images(dataset_dir)

print(f"Total images in '{dataset_dir}': {total_images}")


Total images in 'D:/cleaned_plantvillage': 13135


In [56]:
# -----------------------
# 2) Build few-shot splits from cleaned copy
# -----------------------
# Remove old few-shot directories
for d in (fewshot_train, fewshot_val):
    if d.exists():
        shutil.rmtree(d)
    d.mkdir(parents=True, exist_ok=True)

# Helper to create few-shot dataset
def create_fewshot_dataset(src_path, dest_path):
    for class_folder in sorted(src_path.iterdir()):
        if not class_folder.is_dir():
            continue
        images = [p for p in class_folder.iterdir() if p.suffix.lower() in (".jpg",".jpeg",".png")]
        if len(images) < k_shot + val_shot:
            print(f"Not enough images in {class_folder.name} to sample {k_shot+val_shot}, skipping.")
            continue

        random.shuffle(images)
        support = images[:k_shot]
        query   = images[k_shot:k_shot + val_shot]

        # copy support images to few-shot train
        sup_dest = dest_path / class_folder.name
        sup_dest.mkdir(parents=True, exist_ok=True)
        for img in support:
            shutil.copy2(img, sup_dest / img.name)

# Create few-shot train and val splits
create_fewshot_dataset(clean_train_src, fewshot_train)
create_fewshot_dataset(clean_val_src, fewshot_val)

# -----------------------
# 3) Verification: print counts
# -----------------------
def count_images(folder_path):
    print(f"\n📁 Checking folder: {folder_path}")
    for class_folder in sorted(folder_path.iterdir()):
        if class_folder.is_dir():
            count = len([f for f in class_folder.iterdir() if f.suffix.lower() in (".jpg",".jpeg",".png")])
            print(f"  {class_folder.name}: {count} images")

count_images(fewshot_train)
count_images(fewshot_val)


📁 Checking folder: D:\fewshot_dataset\train
  Apple___Cedar_apple_rust: 10 images
  Blueberry___healthy: 10 images
  Cherry_(including_sour)___healthy: 10 images
  Grape___Black_rot: 10 images
  Grape___Esca_(Black_Measles): 10 images
  Grape___healthy: 10 images
  Grape___Leaf_blight_(Isariopsis_Leaf_Spot): 10 images
  Potato___healthy: 10 images
  Raspberry___healthy: 10 images
  Soybean___healthy: 10 images
  Strawberry___healthy: 10 images
  Tomato___Tomato_mosaic_virus: 10 images

📁 Checking folder: D:\fewshot_dataset\val
  Apple___Cedar_apple_rust: 10 images
  Blueberry___healthy: 10 images
  Cherry_(including_sour)___healthy: 10 images
  Grape___Black_rot: 10 images
  Grape___Esca_(Black_Measles): 10 images
  Grape___healthy: 10 images
  Grape___Leaf_blight_(Isariopsis_Leaf_Spot): 10 images
  Potato___healthy: 10 images
  Raspberry___healthy: 10 images
  Soybean___healthy: 10 images
  Strawberry___healthy: 10 images
  Tomato___Tomato_mosaic_virus: 10 images


In [60]:
from pathlib import Path

fewshot_base = Path(r"D:/fewshot_dataset")

def count_total_images(base_path):
    totals = {}
    for split in ['train', 'val']:
        split_path = base_path / split
        count = sum(len(list(folder.glob("*.jpg"))) + 
                    len(list(folder.glob("*.jpeg"))) + 
                    len(list(folder.glob("*.png")))
                    for folder in split_path.iterdir() if folder.is_dir())
        totals[split] = count
    return totals

totals = count_total_images(fewshot_base)

print(f"📂 Total images in TRAIN: {totals['train']}")
print(f"📂 Total images in VAL: {totals['val']}")
print(f"📦 TOTAL images in FEWSHOT dataset: {totals['train'] + totals['val']}")


📂 Total images in TRAIN: 120
📂 Total images in VAL: 120
📦 TOTAL images in FEWSHOT dataset: 240


In [83]:
import os
import random
import numpy as np
import argparse
from collections import defaultdict
from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torchvision import models, transforms
import h5py
from torch.cuda.amp import autocast, GradScaler

In [85]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

###############################################
# Helper Functions: Save Model & Prototypes
###############################################
def save_model_h5(state_dict, filename):
    with h5py.File(filename, 'w') as f:
        grp = f.create_group("model")
        for k, v in state_dict.items():
            grp.create_dataset(k, data=v.cpu().numpy())

def save_prototypes_npy(prototypes, filename):
    np.save(filename, prototypes)

In [87]:
###############################################
# Dataset for Few-Shot
###############################################
class PlantDiseaseDataset(Dataset):
    def __init__(self, root_dir, transform=None, augment_factor=1):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.augment_factor = augment_factor
        self.classes, self.disease_classes = self._find_classes()
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.image_cache = []
        self._load_cache()
        self._build_indices()
        if self.augment_factor > 1:
            self._balance()

    def _find_classes(self):
        plants, diseases = set(), set()
        for d in self.root_dir.iterdir():
            if d.is_dir():
                parts = d.name.split('___', 1)
                if len(parts) == 2:
                    plants.add(parts[0]); diseases.add(parts[1])
        return sorted(plants), sorted(diseases)

    def _load_cache(self):
        for d in self.root_dir.iterdir():
            if not d.is_dir(): continue
            parts = d.name.split('___', 1)
            if len(parts) != 2: continue
            plant, disease = parts
            if disease not in self.disease_classes: continue
            p_lbl = self.class_to_idx[plant]
            d_lbl = self.disease_classes.index(disease)
            for imgf in d.iterdir():
                if imgf.suffix.lower() in ('.jpg', '.png', '.jpeg'):
                    img = Image.open(imgf).convert('RGB')
                    self.image_cache.append((img, p_lbl, d_lbl))
        print(f"Loaded {len(self.image_cache)} images from {len(self.classes)} plants and {len(self.disease_classes)} diseases")

    def _build_indices(self):
        self.plant_idx = defaultdict(list)
        self.disease_idx = defaultdict(list)
        for i, (_, p, d) in enumerate(self.image_cache):
            self.plant_idx[p].append(i)
            self.disease_idx[d].append(i)

    def _balance(self):
        counts = {d: len(v) for d, v in self.disease_idx.items()}
        maxc = max(counts.values())
        balanced = list(self.image_cache)
        for d, c in counts.items():
            if c < maxc:
                choices = [x for x in self.image_cache if x[2] == d]
                needed = (maxc - c) * self.augment_factor
                balanced.extend(random.choices(choices, k=needed))
        self.image_cache = balanced
        self._build_indices()

    def __len__(self): return len(self.image_cache)

    def __getitem__(self, idx):
        img, p, d = self.image_cache[idx]
        if self.transform:
            img = self.transform(img)
        return img, p, d

In [89]:
##############################################
# Prototypical Network Model
##############################################
class EfficientProtoNet(nn.Module):
    def __init__(self, num_disease, emb_dim=384):
        super().__init__()
        backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        orig_cls = backbone.classifier
        in_f = orig_cls[1].in_features if isinstance(orig_cls, nn.Sequential) else None
        backbone.classifier = nn.Identity()
        self.backbone = backbone
        self.embed = nn.Sequential(
            nn.Linear(in_f, 768), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(768, emb_dim), nn.BatchNorm1d(emb_dim)
        )
        self.head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(), nn.Linear(256, num_disease)
        )
        self.l2norm = True

    def forward(self, x):
        f = self.backbone(x)
        e = self.embed(f)
        if self.l2norm:
            e = F.normalize(e, p=2, dim=1)
        logits = self.head(e)
        return e, logits


In [91]:
##############################################
# Loss, Sampling, Prototype Extraction
##############################################
def enhanced_proto_loss(s_sup, l_sup, s_q, l_q, temp=20.0, ls=0.2):
    classes = torch.unique(l_sup)
    protos = torch.stack([s_sup[l_sup == c].mean(0) for c in classes])
    logits = -torch.cdist(s_q, protos) * temp
    new_lbl = torch.tensor([(classes == l).nonzero().item() for l in l_q], device=s_q.device)
    loss = nn.CrossEntropyLoss(label_smoothing=ls)(logits, new_lbl)
    acc = (logits.argmax(1) == new_lbl).float().mean().item()
    return loss, acc

def sample_episode(ds, task, n_way, k_shot, q_query):
    idx_dict = ds.plant_idx if task == 'plant' else ds.disease_idx
    classes = random.sample(list(idx_dict.keys()), n_way)
    s_imgs, s_lbls, q_imgs, q_lbls = [], [], [], []
    for i, c in enumerate(classes):
        idxs = idx_dict[c]
        total = k_shot + q_query
        sel = random.sample(idxs, total) if len(idxs) >= total else random.choices(idxs, k=total)
        for j, idx in enumerate(sel):
            img, p, d = ds[idx]
            if j < k_shot:
                s_imgs.append(img)
                s_lbls.append(i)
            else:
                q_imgs.append(img)
                q_lbls.append(i)
    return (torch.stack(s_imgs), torch.tensor(s_lbls)), (torch.stack(q_imgs), torch.tensor(q_lbls))

def extract_prototypes(model, ds, device, n_way=5, k_shot=10, episodes=10):
    prototype_dict = np.zeros((len(ds.disease_classes), model.embed[-1].num_features))
    for ci in range(len(ds.disease_classes)):
        accumulated = []
        for _ in range(episodes):
            indices = ds.disease_idx[ci]
            sel = random.choices(indices, k=k_shot) if len(indices) < k_shot else random.sample(indices, k_shot)
            imgs = torch.stack([ds[i][0] for i in sel]).to(device)
            with torch.no_grad():
                emb, _ = model(imgs)
            accumulated.append(emb.mean(0).cpu().numpy())
        prototype_dict[ci] = np.mean(accumulated, axis=0)
    return prototype_dict


In [93]:
##############################################
# Epoch Functions
##############################################
def train_epoch(model, ds, opt, device, train_eps, n_way, k_shot, q_query, scaler=None):
    model.train()
    tloss, tacc, p_tot, d_tot, p_cnt, d_cnt = 0, 0, 0, 0, 0, 0
    for _ in range(train_eps):
        task = 'disease' if random.random() < 0.7 else 'plant'
        (s_img, s_lbl), (q_img, q_lbl) = sample_episode(ds, task, n_way, k_shot, q_query)
        s_img, s_lbl, q_img, q_lbl = [x.to(device) for x in (s_img, s_lbl, q_img, q_lbl)]
        if scaler:
            with autocast():
                s_e, _ = model(s_img)
                q_e, _ = model(q_img)
                loss, acc = enhanced_proto_loss(s_e, s_lbl, q_e, q_lbl)
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            s_e, _ = model(s_img)
            q_e, _ = model(q_img)
            loss, acc = enhanced_proto_loss(s_e, s_lbl, q_e, q_lbl)
            opt.zero_grad()
            loss.backward()
            opt.step()
        tloss += loss.item()
        tacc += acc
        if task == 'plant':
            p_tot += acc
            p_cnt += 1
        else:
            d_tot += acc
            d_cnt += 1
    combined_acc = tacc / train_eps
    plant_acc = p_tot / p_cnt if p_cnt else 0
    disease_acc = d_tot / d_cnt if d_cnt else 0
    return {'loss': tloss / train_eps, 'plant_acc': plant_acc, 'disease_acc': disease_acc, 'combined_acc': combined_acc}

def evaluate(model, ds, device, val_eps, n_way, k_shot, q_query, scaler=None):
    model.eval()
    plant_acc, disease_acc = 0, 0
    with torch.no_grad():
        for task in ['plant', 'disease']:
            for _ in range(val_eps):
                (s_img, s_lbl), (q_img, q_lbl) = sample_episode(ds, task, n_way, k_shot, q_query)
                s_img, s_lbl, q_img, q_lbl = [x.to(device) for x in (s_img, s_lbl, q_img, q_lbl)]
                if scaler:
                    with autocast():
                        s_e, _ = model(s_img)
                        q_e, _ = model(q_img)
                        acc = enhanced_proto_loss(s_e, s_lbl, q_e, q_lbl)[1]
                else:
                    s_e, _ = model(s_img)
                    q_e, _ = model(q_img)
                    acc = enhanced_proto_loss(s_e, s_lbl, q_e, q_lbl)[1]
                if task == 'plant':
                    plant_acc += acc
                else:
                    disease_acc += acc
    return {'plant_acc': plant_acc / val_eps, 'disease_acc': disease_acc / val_eps, 'avg_acc': (plant_acc + disease_acc) / (2 * val_eps)}

In [96]:
##############################################
# Main Execution
##############################################
def main(train_dir, valid_dir, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    set_seed(args.seed)
    scaler = GradScaler() if device.type == 'cuda' else None

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.65, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
        transforms.RandomRotation(15),
        transforms.RandomVerticalFlip(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.2)
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
    ])

    train_ds = PlantDiseaseDataset(train_dir, train_transform, augment_factor=2)
    val_ds   = PlantDiseaseDataset(valid_dir, val_transform)
    model    = EfficientProtoNet(len(train_ds.disease_classes)).to(device)
    optimizer= optim.AdamW(model.parameters(), lr=args.lr, weight_decay=5e-4)
    scheduler= ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    best_acc, patience_count, overfit_count = 0.0, 0, 0
    for epoch in range(args.epochs):
        train_metrics = train_epoch(model, train_ds, optimizer, device, args.train_eps, 5, 5, 10, scaler)
        val_metrics   = evaluate(model, val_ds, device, args.val_eps, 5, 5, 15, scaler)
        scheduler.step(val_metrics['avg_acc'])
        gap = train_metrics['combined_acc'] - val_metrics['avg_acc']

        print(f"Epoch {epoch+1}/{args.epochs}:")
        print(f"  Train Loss: {train_metrics['loss']:.4f}")
        print(f"  Train Plant Acc: {train_metrics['plant_acc']:.2%} | Train Disease Acc: {train_metrics['disease_acc']:.2%} | Combined Train Acc: {train_metrics['combined_acc']:.2%}")
        print(f"  Val Plant Acc: {val_metrics['plant_acc']:.2%} | Val Disease Acc: {val_metrics['disease_acc']:.2%} | Combined Val Acc: {val_metrics['avg_acc']:.2%}")

        if val_metrics['avg_acc'] > best_acc:
            best_acc = val_metrics['avg_acc']
            patience_count = 0
            overfit_count = 0
            save_model_h5(model.state_dict(), "best_model.h5")
            prototypes = extract_prototypes(model, val_ds, device)
            save_prototypes_npy(prototypes, "prototypes.npy")
            print("Best model and prototypes saved!")
        else:
            patience_count += 1
            if gap > 0.05:
                overfit_count += 1
                print(f"Warning: Overfitting suspected (gap={gap:.2%})")
            else:
                overfit_count = 0
            if patience_count >= args.patience:
                print("Early stopping triggered due to no improvement")
                break
            if overfit_count >= 3:
                print("Early stopping triggered due to overfitting")
                break

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--train_eps', type=int, default=20)
    parser.add_argument('--val_eps', type=int, default=10)
    args, _ = parser.parse_known_args()
    main(r"D:/fewshot_dataset/train", r"D:/fewshot_dataset/val", args)

Using device: cpu
Loaded 120 images from 9 plants and 6 diseases
Loaded 120 images from 9 plants and 6 diseases
Epoch 1/20:
  Train Loss: 1.4206
  Train Plant Acc: 40.33% | Train Disease Acc: 49.14% | Combined Train Acc: 46.50%
  Val Plant Acc: 89.07% | Val Disease Acc: 93.47% | Combined Val Acc: 91.27%
Best model and prototypes saved!
Epoch 2/20:
  Train Loss: 1.1427
  Train Plant Acc: 65.00% | Train Disease Acc: 72.33% | Combined Train Acc: 69.40%
  Val Plant Acc: 94.40% | Val Disease Acc: 96.53% | Combined Val Acc: 95.47%
Best model and prototypes saved!
Epoch 3/20:
  Train Loss: 1.0142
  Train Plant Acc: 66.00% | Train Disease Acc: 86.43% | Combined Train Acc: 80.30%
  Val Plant Acc: 93.60% | Val Disease Acc: 94.53% | Combined Val Acc: 94.07%
Epoch 4/20:
  Train Loss: 0.9302
  Train Plant Acc: 77.67% | Train Disease Acc: 91.14% | Combined Train Acc: 87.10%
  Val Plant Acc: 91.87% | Val Disease Acc: 95.47% | Combined Val Acc: 93.67%
Epoch 5/20:
  Train Loss: 0.8720
  Train Plant Acc

In [3]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import h5py

# 1) Dataset (same as before)
class PlantDiseaseDataset(Dataset):
    def __init__(self, root_dir, transform):
        self.transform = transform
        self.samples, diseases = [], set()
        for split in ("train","val"):
            for cls in (Path(root_dir)/split).iterdir():
                if cls.is_dir() and "___" in cls.name:
                    disease = cls.name.split("___",1)[1]
                    diseases.add(disease)
                    for imgf in cls.iterdir():
                        if imgf.suffix.lower() in (".jpg",".jpeg",".png"):
                            self.samples.append((imgf, disease))
        self.disease_classes = sorted(diseases)
        self.d2i = {d:i for i,d in enumerate(self.disease_classes)}
        self.samples = [(p, self.d2i[d]) for p,d in self.samples]

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        p, lbl = self.samples[idx]
        img = Image.open(p).convert("RGB")
        return self.transform(img), lbl

# 2) Model (encoder only)
class EfficientProtoNet(nn.Module):
    def __init__(self, num_disease, emb_dim=384):
        super().__init__()
        backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        in_f = backbone.classifier[1].in_features
        backbone.classifier = nn.Identity()
        self.backbone = backbone
        self.embed = nn.Sequential(
            nn.Linear(in_f, 768), nn.ReLU(), nn.Dropout(0.7),
            nn.Linear(768, emb_dim), nn.BatchNorm1d(emb_dim)
        )
    def forward(self, x):
        f = self.backbone(x)
        e = self.embed(f)
        return F.normalize(e, p=2.0, dim=1)

# 3) Load prototypes & weights
device = torch.device("cpu")
prototypes = torch.tensor(np.load("prototypes.npy"), dtype=torch.float32).to(device)


model = EfficientProtoNet(num_disease=prototypes.shape[0]).to(device).eval()
with h5py.File("best_model.h5",'r') as f:
    sd = {k: torch.tensor(v[()]) for k,v in f['model'].items()}
# load encoder & embed weights
backbone_sd = {k.replace("backbone.",""):v for k,v in sd.items() if k.startswith("backbone.")}
embed_sd    = {k.replace("embed.",""):v for k,v in sd.items() if k.startswith("embed.")}
model.backbone.load_state_dict(backbone_sd, strict=False)
model.embed.load_state_dict(embed_sd, strict=False)

# 4) Transforms & DataLoader (num_workers=0!)
transform = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224),
    transforms.ToTensor(), transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
ds = PlantDiseaseDataset(r"D:/cleaned_plantvillage", transform)
loader = DataLoader(ds, batch_size=128, shuffle=False, num_workers=0)  # <-- must be 0 on Windows

# 5) Inference with tqdm
all_preds, all_labels = [], []
start = time.time()
for imgs, labels in tqdm(loader, desc="Inference"):
    emb = model(imgs.to(device))             # [B, emb_dim]
    dists = torch.cdist(emb, prototypes)     # [B, num_classes]
    preds = dists.argmin(dim=1).cpu().tolist()
    all_preds.extend(preds)
    all_labels.extend(labels)
end = time.time()

# 6) Metrics
acc = accuracy_score(all_labels, all_preds)
prec, rec, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average='macro'
)

print(f"\nTotal time: {end-start:.1f}s")
print("=== Full-dataset metrics ===")
print(f"Accuracy:  {acc*100:.2f}%")
print(f"Precision: {prec*100:.2f}%")
print(f"Recall:    {rec*100:.2f}%")
print(f"F1 Score:  {f1*100:.2f}%")


Inference: 100%|███████████████████████████████████████████████████████████████████| 103/103 [1:11:54<00:00, 41.89s/it]



Total time: 4314.5s
=== Full-dataset metrics ===
Accuracy:  97.77%
Precision: 95.46%
Recall:    96.40%
F1 Score:  95.83%
