# Indian Medicinal Plant Classifier 



In [5]:
# Environment and version checks
import os, sys, json, random, math, time, gc
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T
from torchvision import models as tvm

from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.auto import tqdm

print('Python', sys.version)
print('Torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('Torchvision', tvm.__name__)

# Base directories (Kaggle setup)
KAGGLE_INPUT = Path('/kaggle/input')
KAGGLE_WORKING = Path('/kaggle/working')
DEFAULT_LOCAL_ROOT = r"C:\Users\rahul\Downloads\archive\Medicinal plant dataset"

local_env = os.getenv('LOCAL_DATASET_ROOT', DEFAULT_LOCAL_ROOT)
local_path = Path(local_env)

if local_path.exists():
    DATASET_ROOT = local_path
    print("Using local DATASET_ROOT:", DATASET_ROOT)
elif (KAGGLE_INPUT.exists() and (KAGGLE_INPUT / 'indian-medicinal-plant-image-dataset').exists()):
    DATASET_ROOT = KAGGLE_INPUT / 'indian-medicinal-plant-image-dataset' / 'Medicinal plant dataset'
    print("Using Kaggle DATASET_ROOT:", DATASET_ROOT)
else:
    # Fallback: ask interactively (useful when running in notebook locally)
    user_in = input(f"Local dataset not found at {local_path}. Enter dataset path or leave blank to abort: ").strip()
    if user_in:
        DATASET_ROOT = Path(user_in)
    else:
        raise FileNotFoundError(f"Dataset folder not found. Set LOCAL_DATASET_ROOT or provide a valid path.")
    assert DATASET_ROOT.exists(), f"Dataset folder not found at: {DATASET_ROOT}"

# Choose output dir: use Kaggle working when available, otherwise project cwd
if KAGGLE_WORKING.exists() and 'kaggle' in str(KAGGLE_WORKING).lower():
    OUTPUT_DIR = KAGGLE_WORKING / 'impc_outputs'
else:
    OUTPUT_DIR = Path.cwd() / 'impc_outputs'

local_env = os.getenv('LOCAL_DATASET_ROOT', DEFAULT_LOCAL_ROOT)
local_path = Path(local_env)

if local_path.exists():
    DATASET_ROOT = local_path
    print("Using local DATASET_ROOT:", DATASET_ROOT)
elif (KAGGLE_INPUT.exists() and (KAGGLE_INPUT / 'indian-medicinal-plant-image-dataset').exists()):
    DATASET_ROOT = KAGGLE_INPUT / 'indian-medicinal-plant-image-dataset' / 'Medicinal plant dataset'
    print("Using Kaggle DATASET_ROOT:", DATASET_ROOT)
else:
    # Fallback: ask interactively (useful when running in notebook locally)
    user_in = input(f"Local dataset not found at {local_path}. Enter dataset path or leave blank to abort: ").strip()
    if user_in:
        DATASET_ROOT = Path(user_in)
    else:
        raise FileNotFoundError(f"Dataset folder not found. Set LOCAL_DATASET_ROOT or provide a valid path.")
    assert DATASET_ROOT.exists(), f"Dataset folder not found at: {DATASET_ROOT}"

# Choose output dir: use Kaggle working when available, otherwise project cwd
if KAGGLE_WORKING.exists() and 'kaggle' in str(KAGGLE_WORKING).lower():
    OUTPUT_DIR = KAGGLE_WORKING / 'impc_outputs'
else:
    OUTPUT_DIR = Path.cwd() / 'impc_outputs'

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print('Data root:', DATASET_ROOT)
print('Output dir:', OUTPUT_DIR)

Python 3.13.5 | packaged by Anaconda, Inc. | (main, Jun 12 2025, 16:37:03) [MSC v.1929 64 bit (AMD64)]
Torch 2.9.0+cpu | CUDA available: False
Torchvision torchvision.models
Using local DATASET_ROOT: C:\Users\rahul\Downloads\archive\Medicinal plant dataset
Using local DATASET_ROOT: C:\Users\rahul\Downloads\archive\Medicinal plant dataset
Data root: C:\Users\rahul\Downloads\archive\Medicinal plant dataset
Output dir: \kaggle\working\impc_outputs


In [6]:
# Configuration (adjust as needed)
from dataclasses import dataclass, asdict

@dataclass
class CFG:
    seed: int = 42
    img_size: int = 256           # training crop size
    train_batch_size: int = 32
    valid_batch_size: int = 64
    num_workers: int = 2
    epochs: int = 10
    base_lr: float = 3e-4
    weight_decay: float = 1e-4
    label_smoothing: float = 0.1
    model_name: str = 'efficientnet_b0'  # ['efficientnet_b0','resnet50','convnext_tiny'] depending on torchvision version
    mixup_alpha: float = 0.0       # set >0.0 to enable MixUp
    cutmix_alpha: float = 0.0      # set >0.0 to enable CutMix
    train_val_split: float = 0.15  # 15% validation
    train_test_split: float = 0.15 # 15% test
    early_stopping_patience: int = 5
    freeze_backbone_epochs: int = 0  # set 1-2 if you want to warm up classifier first
    fp16: bool = True

cfg = CFG()
print(asdict(cfg))

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)

{'seed': 42, 'img_size': 256, 'train_batch_size': 32, 'valid_batch_size': 64, 'num_workers': 2, 'epochs': 10, 'base_lr': 0.0003, 'weight_decay': 0.0001, 'label_smoothing': 0.1, 'model_name': 'efficientnet_b0', 'mixup_alpha': 0.0, 'cutmix_alpha': 0.0, 'train_val_split': 0.15, 'train_test_split': 0.15, 'early_stopping_patience': 5, 'freeze_backbone_epochs': 0, 'fp16': True}
Device: cpu


In [7]:
# Reproducibility and helpers

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

set_seed(cfg.seed)

IMG_EXTS = {'.jpg','.jpeg','.png','.bmp','.tif','.tiff'}

def list_images(root: Path):
    classes = sorted([d.name for d in root.iterdir() if d.is_dir()])
    samples, labels = [], []
    for idx, cls in enumerate(classes):
        for p in (root/cls).rglob('*'):
            if p.suffix.lower() in IMG_EXTS:
                samples.append(p)
                labels.append(idx)
    return classes, np.array(samples), np.array(labels, dtype=np.int64)

classes, all_paths, all_labels = list_images(DATASET_ROOT)
num_classes = len(classes)
print(f"Found {len(all_paths)} images across {num_classes} classes.")

# Save label mapping
label_map = {i:c for i,c in enumerate(classes)}
with open(OUTPUT_DIR/'labels.json','w') as f:
    json.dump(label_map, f, indent=2)

Found 5945 images across 40 classes.


In [8]:
# Train/Val/Test split (stratified)
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=cfg.train_val_split + cfg.train_test_split, random_state=cfg.seed)
train_idx, temp_idx = next(sss1.split(all_paths, all_labels))

paths_train, labels_train = all_paths[train_idx], all_labels[train_idx]
paths_temp, labels_temp = all_paths[temp_idx], all_labels[temp_idx]

val_ratio_of_temp = cfg.train_val_split / (cfg.train_val_split + cfg.train_test_split)
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=1 - val_ratio_of_temp, random_state=cfg.seed)
val_idx, test_idx = next(sss2.split(paths_temp, labels_temp))

paths_val, labels_val = paths_temp[val_idx], labels_temp[val_idx]
paths_test, labels_test = paths_temp[test_idx], labels_temp[test_idx]

print(f"Split -> train: {len(paths_train)}, val: {len(paths_val)}, test: {len(paths_test)}")

Split -> train: 4161, val: 892, test: 892


In [9]:
# Transforms and Dataset
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_tfms = T.Compose([
    T.RandomResizedCrop(cfg.img_size, scale=(0.7, 1.0)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(p=0.2),
    T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
    T.ColorJitter(0.2, 0.2, 0.2, 0.1),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

valid_tfms = T.Compose([
    T.Resize(int(cfg.img_size*1.15)),
    T.CenterCrop(cfg.img_size),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

class ImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = list(map(str, paths))
        self.labels = labels
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert('RGB')
        if self.transform:
            img = self.transform(img)
        target = int(self.labels[idx])
        return img, target

train_ds = ImageDataset(paths_train, labels_train, train_tfms)
val_ds   = ImageDataset(paths_val, labels_val, valid_tfms)
test_ds  = ImageDataset(paths_test, labels_test, valid_tfms)

# Balanced sampling for training if classes are imbalanced
class_counts = np.bincount(labels_train, minlength=num_classes)
class_weights = 1.0 / np.clip(class_counts, 1, None)
weights = class_weights[labels_train]
sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

train_loader = DataLoader(train_ds, batch_size=cfg.train_batch_size, sampler=sampler,
                          num_workers=cfg.num_workers, pin_memory=True, persistent_workers=False)
val_loader   = DataLoader(val_ds, batch_size=cfg.valid_batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=cfg.valid_batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=True)

len(train_loader), len(val_loader), len(test_loader)

(131, 14, 14)

In [6]:
# Model factory

def build_model(model_name: str, num_classes: int):
    model_name = model_name.lower()

    def safe_load(fn_with_weights, fn_no_weights):
        try:
            return fn_with_weights()
        except Exception as e:
            print(f"[Info] Could not load pretrained weights (likely no internet/cache). Falling back to non-pretrained. Error: {str(e)[:120]}")
            return fn_no_weights()

    if model_name == 'efficientnet_b0':
        def with_w():
            weights = getattr(tvm, 'EfficientNet_B0_Weights', None)
            if weights is not None:
                return tvm.efficientnet_b0(weights=weights.IMAGENET1K_V1)
            return tvm.efficientnet_b0(weights='IMAGENET1K_V1')
        def no_w():
            return tvm.efficientnet_b0(weights=None)
        model = safe_load(with_w, no_w)
        in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_features, num_classes)

    elif model_name == 'resnet50':
        def with_w():
            weights = getattr(tvm, 'ResNet50_Weights', None)
            if weights is not None:
                return tvm.resnet50(weights=weights.IMAGENET1K_V2)
            return tvm.resnet50(weights='IMAGENET1K_V2')
        def no_w():
            return tvm.resnet50(weights=None)
        model = safe_load(with_w, no_w)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)

    elif model_name == 'convnext_tiny' and hasattr(tvm, 'convnext_tiny'):
        def with_w():
            weights = getattr(tvm, 'ConvNeXt_Tiny_Weights', None)
            return tvm.convnext_tiny(weights=weights.IMAGENET1K_V1 if weights else 'IMAGENET1K_V1')
        def no_w():
            return tvm.convnext_tiny(weights=None)
        model = safe_load(with_w, no_w)
        in_features = model.classifier[2].in_features
        model.classifier[2] = nn.Linear(in_features, num_classes)

    else:
        print('Unknown model, defaulting to resnet50')
        def with_w():
            weights = getattr(tvm, 'ResNet50_Weights', None)
            return tvm.resnet50(weights=weights.IMAGENET1K_V2 if weights else 'IMAGENET1K_V2')
        def no_w():
            return tvm.resnet50(weights=None)
        model = safe_load(with_w, no_w)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)

    return model

model = build_model(cfg.model_name, num_classes).to(DEVICE)

criterion = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.base_lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=cfg.base_lr, steps_per_epoch=len(train_loader), epochs=cfg.epochs
)

scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16 and DEVICE.type=='cuda')

print('Model built:', cfg.model_name)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to C:\Users\rahul/.cache\torch\hub\checkpoints\efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [01:00<00:00, 352kB/s]


Model built: efficientnet_b0


  scaler = torch.cuda.amp.GradScaler(enabled=cfg.fp16 and DEVICE.type=='cuda')


In [None]:
# Training and validation loops

def accuracy(outputs, targets):
    preds = outputs.argmax(1)
    return (preds == targets).float().mean().item()

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    all_preds, all_targets = [], []
    running_loss = 0.0
    for imgs, targets in loader:
        imgs, targets = imgs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=cfg.fp16 and DEVICE.type=='cuda'):
            logits = model(imgs)
            loss = criterion(logits, targets)
        running_loss += loss.item() * imgs.size(0)
        all_preds.append(logits.argmax(1).detach().cpu().numpy())
        all_targets.append(targets.detach().cpu().numpy())
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    val_loss = running_loss / len(loader.dataset)
    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds, average='macro')
    return val_loss, acc, f1, all_targets, all_preds


def train_model(model, train_loader, val_loader):
    best_f1, best_state, epochs_no_improve = -1.0, None, 0
    history = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[], "val_f1":[], "lr": []}

    for epoch in range(cfg.epochs):
        model.train()
        if cfg.freeze_backbone_epochs and epoch < cfg.freeze_backbone_epochs:
            for name, p in model.named_parameters():
                if 'classifier' not in name and (not name.endswith('fc.weight') and not name.endswith('fc.bias')):
                    p.requires_grad = False
        else:
            for p in model.parameters():
                p.requires_grad = True

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs}", leave=False)
        running_loss, running_acc, n = 0.0, 0.0, 0
        
        for imgs, targets in pbar:
            imgs, targets = imgs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=cfg.fp16 and DEVICE.type=='cuda'):
                logits = model(imgs)
                loss = criterion(logits, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            running_loss += loss.item() * imgs.size(0)
            running_acc += (logits.argmax(1) == targets).float().sum().item()
            n += imgs.size(0)
            pbar.set_postfix({"loss": running_loss/n, "acc": running_acc/n, "lr": scheduler.get_last_lr()[0]})

        train_loss = running_loss / n
        train_acc = running_acc / n
        val_loss, val_acc, val_f1, _, _ = evaluate(model, val_loader)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['lr'].append(scheduler.get_last_lr()[0])

        print(f"Epoch {epoch+1:02d}: train_loss={train_loss:.4f} acc={train_acc:.4f} | val_loss={val_loss:.4f} val_acc={val_acc:.4f} val_f1={val_f1:.4f}")

        # Early stopping & checkpoint
        if val_f1 > best_f1:
            best_f1 = val_f1
            best_state = {k: v.cpu() for k,v in model.state_dict().items()}
            torch.save(best_state, OUTPUT_DIR/'best_model.pth')
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= cfg.early_stopping_patience:
                print('Early stopping.')
                break

    # Load best
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

model, history = train_model(model, train_loader, val_loader)
with open(OUTPUT_DIR/'train_history.json','w') as f:
    json.dump(history, f, indent=2)



In [10]:
# Dashboard: learning curves (load saved history; do NOT retrain)
import json
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots

OUTPUT_DIR = Path.cwd() / 'impc_outputs'
hist_file = OUTPUT_DIR / 'train_history.json'

if not hist_file.exists():
    raise FileNotFoundError(f"{hist_file} missing. Run training once to create it.")

with open(hist_file, 'r') as f:
    hist = json.load(f)

fig = make_subplots(rows=2, cols=2, subplot_titles=('Loss','Accuracy','Val F1','Learning Rate'))
fig.add_trace(go.Scatter(y=hist['train_loss'], name='train_loss'), row=1, col=1)
fig.add_trace(go.Scatter(y=hist['val_loss'], name='val_loss'), row=1, col=1)
fig.add_trace(go.Scatter(y=hist['train_acc'], name='train_acc'), row=1, col=2)
fig.add_trace(go.Scatter(y=hist['val_acc'], name='val_acc'), row=1, col=2)
fig.add_trace(go.Scatter(y=hist['val_f1'], name='val_f1'), row=2, col=1)
fig.add_trace(go.Scatter(y=hist['lr'], name='lr'), row=2, col=2)
fig.update_layout(height=700, width=1000, title_text='Training Dashboard', showlegend=True)
fig.show()

In [3]:
# Quick evaluation (safe on Windows; no plots). Uses a small random subset.

import os, json, random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, models as tvm
from PIL import Image, ImageFile
from sklearn.metrics import accuracy_score, f1_score

ImageFile.LOAD_TRUNCATED_IMAGES = True  # tolerate truncated/corrupt files

# -------- Config --------
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
IMG_SIZE = 256
BATCH = 64 if DEVICE.type == 'cuda' else 16
MAX_IMAGES = 128         # total images to evaluate (keeps it fast)
SEED = 42
MODEL_NAME = 'efficientnet_b0'

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# -------- Paths --------
DEFAULT_LOCAL_ROOT = Path(r"C:\Users\rahul\Downloads\archive\Medicinal plant dataset")
DATASET_ROOT = Path(os.getenv("IMPC_DATASET_DIR", DEFAULT_LOCAL_ROOT))
OUTPUT_DIR = Path.cwd() / "impc_outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# -------- Classes (prefer saved mapping) --------
labels_json = OUTPUT_DIR / "labels.json"
if labels_json.exists():
    with open(labels_json, "r") as f:
        label_map = json.load(f)
    classes = [label_map[str(i)] for i in range(len(label_map))]
else:
    classes = sorted([d.name for d in DATASET_ROOT.iterdir() if d.is_dir()])
num_classes = len(classes)

# -------- Build lightweight sample set --------
IMG_EXTS = {'.jpg','.jpeg','.png','.bmp','.tif','.tiff'}
per_class_cap = max(1, MAX_IMAGES // max(1, num_classes))
paths, labels = [], []

for cls_idx, cls in enumerate(classes):
    taken = 0
    cls_dir = DATASET_ROOT / cls
    for p in cls_dir.rglob('*'):
        if p.suffix.lower() in IMG_EXTS:
            paths.append(p)
            labels.append(cls_idx)
            taken += 1
            if taken >= per_class_cap:
                break

paths = np.array(paths)
labels = np.array(labels, dtype=np.int64)

# Shuffle subset once
perm = np.random.permutation(len(paths))
paths, labels = paths[perm], labels[perm]

# -------- Dataset / Loader --------
valid_tfms = T.Compose([
    T.Resize(int(IMG_SIZE*1.15)),
    T.CenterCrop(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

class ImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = list(map(str, paths))
        self.labels = labels
        self.transform = transform
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        try:
            img = Image.open(p).convert('RGB')
        except Exception:
            # return None to skip corrupt/unreadable files
            return None
        if self.transform:
            img = self.transform(img)
        return img, int(self.labels[idx])

def collate_skip_none(batch):
    batch = [b for b in batch if b is not None]
    if not batch:
        # empty batch -> return empty tensors
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    imgs, targets = zip(*batch)
    return torch.stack(imgs, 0), torch.tensor(targets, dtype=torch.long)

eval_ds = ImageDataset(paths, labels, valid_tfms)
# IMPORTANT: num_workers=0 on Windows/Jupyter to avoid multiprocessing crashes
eval_loader = DataLoader(
    eval_ds, batch_size=BATCH, shuffle=False,
    num_workers=0, pin_memory=(DEVICE.type=='cuda'),
    collate_fn=collate_skip_none, persistent_workers=False
)

# -------- Model (minimal factory) --------
def build_model(model_name: str, num_classes: int):
    model_name = model_name.lower()
    def safe_load(w_fn, nw_fn):
        try: return w_fn()
        except Exception: return nw_fn()
    if model_name == 'efficientnet_b0':
        def w():
            W = getattr(tvm, 'EfficientNet_B0_Weights', None)
            return tvm.efficientnet_b0(weights=W.IMAGENET1K_V1 if W else 'IMAGENET1K_V1')
        def nw(): return tvm.efficientnet_b0(weights=None)
        m = safe_load(w, nw)
        in_f = m.classifier[1].in_features
        m.classifier[1] = nn.Linear(in_f, num_classes)
    elif model_name == 'resnet50':
        def w():
            W = getattr(tvm, 'ResNet50_Weights', None)
            return tvm.resnet50(weights=W.IMAGENET1K_V2 if W else 'IMAGENET1K_V2')
        def nw(): return tvm.resnet50(weights=None)
        m = safe_load(w, nw)
        in_f = m.fc.in_features
        m.fc = nn.Linear(in_f, num_classes)
    else:
        m = tvm.resnet50(weights=None)
        in_f = m.fc.in_features
        m.fc = nn.Linear(in_f, num_classes)
    return m

model = build_model(MODEL_NAME, num_classes).to(DEVICE)

# Load checkpoint if present
ckpt = OUTPUT_DIR / "best_model.pth"
if ckpt.exists():
    state = torch.load(ckpt, map_location=DEVICE)
    model.load_state_dict(state)
else:
    print(f"Warning: {ckpt} not found; using current weights.")

criterion = nn.CrossEntropyLoss(label_smoothing=0.1 if DEVICE.type=='cuda' else 0.0)

# -------- Evaluate (no plots) --------
@torch.no_grad()
def quick_evaluate(model, loader):
    model.eval()
    preds_all, targs_all = [], []
    total_loss, seen = 0.0, 0
    for imgs, targs in loader:
        if imgs.numel() == 0:
            continue  # skip empty batches
        imgs, targs = imgs.to(DEVICE), targs.to(DEVICE)
        with torch.cuda.amp.autocast(enabled=(DEVICE.type=='cuda')):
            logits = model(imgs)
            loss = criterion(logits, targs)
        total_loss += loss.item() * imgs.size(0)
        seen += imgs.size(0)
        preds_all.append(logits.argmax(1).cpu().numpy())
        targs_all.append(targs.cpu().numpy())
    if seen == 0:
        print("No valid images found to evaluate.")
        return 0.0, 0.0, 0.0, 0
    preds_all = np.concatenate(preds_all)
    targs_all = np.concatenate(targs_all)
    loss = total_loss / seen
    acc = accuracy_score(targs_all, preds_all)
    f1 = f1_score(targs_all, preds_all, average='macro')
    return loss, acc, f1, seen

loss, acc, f1, n = quick_evaluate(model, eval_loader)
print(f"Quick eval on {n} images -> loss={loss:.4f}, acc={acc:.4f}, f1={f1:.4f}")

  with torch.cuda.amp.autocast(enabled=(DEVICE.type=='cuda')):


Quick eval on 120 images -> loss=0.1462, acc=0.9833, f1=0.9812


In [8]:
# Export (state dict, TorchScript, ONNX with dependency check)
import torch, json, sys
from pathlib import Path
import torch.nn as nn
from torchvision import models as tvm

OUTPUT_DIR = globals().get('OUTPUT_DIR', Path.cwd() / 'impc_outputs')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
DEVICE = globals().get('DEVICE', torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
cfg = globals().get('cfg')
if cfg is None:
    class CFG: img_size=256; model_name='efficientnet_b0'; fp16=True
    cfg = CFG()

# Build / get model
def build_model(name, num_classes):
    name = name.lower()
    if name == 'efficientnet_b0':
        m = tvm.efficientnet_b0(weights=None)
        m.classifier[1] = nn.Linear(m.classifier[1].in_features, num_classes)
    elif name == 'resnet50':
        m = tvm.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
    else:
        m = tvm.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

# Classes
labels_json = OUTPUT_DIR / 'labels.json'
if labels_json.exists():
    with open(labels_json,'r') as f: lbl_map = json.load(f)
    classes = [lbl_map[str(i)] for i in range(len(lbl_map))]
else:
    classes = globals().get('classes', [])
num_classes = len(classes) if classes else globals().get('num_classes', 2) or 2

model = globals().get('model') or build_model(getattr(cfg,'model_name','efficientnet_b0'), num_classes).to(DEVICE)

# Load checkpoint if exists
best_pth = OUTPUT_DIR / 'best_model.pth'
if best_pth.exists():
    try:
        state = torch.load(best_pth, map_location=DEVICE)
        model.load_state_dict(state, strict=False)
        print('Loaded checkpoint:', best_pth)
    except Exception as e:
        print('Checkpoint load failed:', e)
else:
    torch.save({k:v.cpu() for k,v in model.state_dict().items()}, best_pth)
    print('Created baseline checkpoint:', best_pth)

model.eval()

# State dict already ensured above

# TorchScript
script_path = OUTPUT_DIR / 'model.torchscript.pt'
try:
    example = torch.randn(1,3,cfg.img_size,cfg.img_size, device=DEVICE)
    traced = torch.jit.trace(model, example)
    traced.save(str(script_path))
    print('Saved TorchScript:', script_path)
    with torch.no_grad():
        ts_out = traced(example)
    print('TorchScript test OK shape:', tuple(ts_out.shape))
except Exception as e:
    print('TorchScript export failed:', e)

# ONNX export with dependency check
onx_path = OUTPUT_DIR / 'model.onnx'
try:
    import onnx  # noqa
    try:
        import onnxscript  # required for new exporter in recent PyTorch
    except ImportError:
        print("onnxscript missing; install with: pip install onnxscript (skipping ONNX export).")
        raise RuntimeError("onnxscript not installed")
    dummy = torch.randn(1,3,cfg.img_size,cfg.img_size, device=DEVICE)
    torch.onnx.export(
        model,
        dummy,
        str(onx_path),
        input_names=['images'],
        output_names=['logits'],
        dynamic_axes={'images': {0: 'batch'}, 'logits': {0: 'batch'}},
        opset_version=17,
        do_constant_folding=True
    )
    print('Saved ONNX:', onx_path)
    # Optional runtime test
    try:
        import onnxruntime as ort
        ort_sess = ort.InferenceSession(str(onx_path), providers=['CPUExecutionProvider'])
        ort_out = ort_sess.run(None, {'images': dummy.cpu().numpy()})[0]
        print('ONNX runtime test OK shape:', tuple(ort_out.shape))
    except Exception as e:
        print('ONNX runtime test skipped/failed:', e)
except Exception as e:
    print('ONNX export failed:', e)

print('Artifacts:', OUTPUT_DIR)
print('Classes count:', num_classes)

Loaded checkpoint: c:\Users\rahul\OneDrive\Desktop\Plant\Indian-Medicinal-Plant-classifier\impc_outputs\best_model.pth
Saved TorchScript: c:\Users\rahul\OneDrive\Desktop\Plant\Indian-Medicinal-Plant-classifier\impc_outputs\model.torchscript.pt
TorchScript test OK shape: (1, 40)
ONNX export failed: No module named 'onnx'
Artifacts: c:\Users\rahul\OneDrive\Desktop\Plant\Indian-Medicinal-Plant-classifier\impc_outputs
Classes count: 40


In [9]:
# Inference demo with fast TTA (self-contained; no labels_test dependency)
import os, json, math, numpy as np
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image, ImageFile
import plotly.graph_objects as go
from plotly.subplots import make_subplots

ImageFile.LOAD_TRUNCATED_IMAGES = True

# Fallbacks for globals if not present
IMAGENET_MEAN = globals().get('IMAGENET_MEAN', (0.485, 0.456, 0.406))
IMAGENET_STD  = globals().get('IMAGENET_STD',  (0.229, 0.224, 0.225))
DEVICE = globals().get('DEVICE', torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
cfg = globals().get('cfg', type('C', (), {'img_size':256, 'fp16':True})())
OUTPUT_DIR = globals().get('OUTPUT_DIR', Path.cwd() / 'impc_outputs')
DATASET_ROOT = globals().get('DATASET_ROOT', Path(r"C:\Users\rahul\Downloads\archive\Medicinal plant dataset"))

# Fast TTA: single forward per image
@torch.no_grad()
def tta_predict(img):
    imgs = torch.stack([
        img,                       # original
        torch.flip(img, dims=[2]), # horizontal flip
        torch.flip(img, dims=[1])  # vertical flip
    ], dim=0).to(DEVICE)
    with torch.cuda.amp.autocast(enabled=(getattr(cfg, 'fp16', True) and DEVICE.type=='cuda')):
        logits = model(imgs)
        probs = F.softmax(logits, dim=1).mean(0)
    return probs

# Pick a dataset to display: prefer test_ds -> eval_ds -> build tiny fallback
ds = globals().get('test_ds') or globals().get('eval_ds')
classes = globals().get('classes')

if ds is None:
    # Build a tiny dataset from filesystem
    IMG_EXTS = {'.jpg','.jpeg','.png','.bmp','.tif','.tiff'}
    valid_tfms = T.Compose([
        T.Resize(int(cfg.img_size*1.15)),
        T.CenterCrop(cfg.img_size),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    # Classes from saved labels if present, else from folders
    labels_json = OUTPUT_DIR / "labels.json"
    if classes is None:
        if labels_json.exists():
            with open(labels_json, "r") as f:
                label_map = json.load(f)
            classes = [label_map[str(i)] for i in range(len(label_map))]
        else:
            classes = sorted([d.name for d in DATASET_ROOT.iterdir() if d.is_dir()])
    per_class_cap = 5
    paths, labels = [], []
    for cls_idx, cls in enumerate(classes):
        taken = 0
        for p in (DATASET_ROOT/cls).rglob('*'):
            if p.suffix.lower() in IMG_EXTS:
                paths.append(str(p)); labels.append(cls_idx)
                taken += 1
                if taken >= per_class_cap: break
    paths = np.array(paths); labels = np.array(labels, dtype=np.int64)

    class TinyDS(Dataset):
        def __init__(self, paths, labels, tfm): self.paths, self.labels, self.tfm = list(paths), labels, tfm
        def __len__(self): return len(self.paths)
        def __getitem__(self, i):
            img = Image.open(self.paths[i]).convert('RGB')
            img = self.tfm(img)
            return img, int(self.labels[i])
    ds = TinyDS(paths, labels, valid_tfms)

# Show a few predictions
n_show = min(12, len(ds))
if n_show == 0:
    raise RuntimeError("No images available for display.")
sample_idx = np.random.choice(len(ds), size=n_show, replace=False)

# Titles from dataset labels (avoid labels_test)
subplot_titles = [f"true:{classes[int(ds[i][1])]}" for i in sample_idx]
rows, cols = math.ceil(n_show/4), 4
fig = make_subplots(rows=rows, cols=cols, subplot_titles=subplot_titles)

for k, i in enumerate(sample_idx):
    img, true_lbl = ds[i]
    prob = tta_predict(img)
    pred_idx = int(prob.argmax().item())
    pred_cls = classes[pred_idx]

    # Denormalize for display
    img_disp = img.clone()
    for c,(m,s) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
        img_disp[c] = img_disp[c]*s + m
    img_disp = (img_disp.clamp(0,1).permute(1,2,0).cpu().numpy()*255).astype(np.uint8)

    r, c = k//cols + 1, k%cols + 1
    fig.add_trace(go.Image(z=img_disp), row=r, col=c)
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    fig.layout.annotations[k].text += f" | pred:{pred_cls}"

fig.update_layout(height=300*rows, width=250*cols, title_text='Sample Predictions (Fast TTA)')
fig.show()


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

