## How to use on Kaggle

1. Open this notebook on Kaggle and add the dataset input: `indian-medicinal-plant-image-dataset`.
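2. In the right sidebar, set Accelerator to a GPU (P100/T4/V100). Turn Internet on so the pretrained ImageNet weights can be downloaded; without Internet (and without cached weights) the notebook still runs but falls back to randomly initialised weights.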
3. Click “Save & Run All”.
4. Artifacts will be saved under `/kaggle/working/impc_outputs/`:
   - `best_model.pth` – PyTorch weights
   - `model.torchscript.pt` – TorchScript for production
   - `model.onnx` – ONNX export
   - `labels.json`, `metrics.json`, `train_history.json`, `classification_report.json`

Tip: Tune `CFG` in the configuration cell (the second code cell) to change epochs, model, image size, etc.

# Indian Medicinal Plant Classifier – Kaggle-ready

This notebook fine-tunes a modern pretrained CNN on the Indian Medicinal Plant Image Dataset hosted on Kaggle, produces a professional analytics dashboard, and exports ready-to-use model artifacts (PyTorch, TorchScript, ONNX) to `/kaggle/working` for submission or reuse.

Highlights
- Stratified train/val/test split (70/15/15 by default) built from a class-per-folder layout
- Strong augmentations and mixed-precision (AMP) for fast training on P100
- OneCycleLR scheduler and label smoothing for stable convergence
- Metrics: accuracy, macro F1, confusion matrix, per-class report
- Plotly dashboard with learning curves and per-class performance
- Exports: `best_model.pth`, `labels.json`, `model.torchscript.pt`, `model.onnx`, `metrics.json`, `train_history.json`, `classification_report.json`

Notes
- The dataset is expected at `/kaggle/input/indian-medicinal-plant-image-dataset/Medicinal plant dataset` (one sub-folder per class); the first code cell stops with an error if that folder is missing.
- All outputs are written to `/kaggle/working` so they persist when you “Save & Run All”.

In [None]:
# Environment and version checks
import os, sys, json, random, math, time, gc
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T
from torchvision import models as tvm

from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score

import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.auto import tqdm

import torchvision  # imported here only to report the installed version string

print('Python', sys.version)
print('Torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
# `tvm.__name__` would only print the module path ('torchvision.models'), not a version.
print('Torchvision', torchvision.__version__)

# Base directories (Kaggle setup)
KAGGLE_INPUT = Path('/kaggle/input')
KAGGLE_WORKING = Path('/kaggle/working')
DATASET_ROOT = KAGGLE_INPUT / 'indian-medicinal-plant-image-dataset' / 'Medicinal plant dataset'
# Raise explicitly instead of `assert` so the check is not stripped under `python -O`.
if not DATASET_ROOT.exists():
    raise FileNotFoundError(
        f"Dataset folder not found at: {DATASET_ROOT}. Add the dataset as input in Kaggle."
    )

# All artifacts go under /kaggle/working so they persist after "Save & Run All".
OUTPUT_DIR = KAGGLE_WORKING / 'impc_outputs'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print('Data root:', DATASET_ROOT)
print('Output dir:', OUTPUT_DIR)

In [None]:
# Configuration (adjust as needed)
from dataclasses import dataclass, asdict

@dataclass
class CFG:
    """Hyper-parameters and run options shared by every cell of the notebook.

    Note on the split fields: ``train_val_split`` and ``train_test_split`` are
    the fractions of the *whole* dataset reserved for validation and test
    respectively; training gets the remainder (70% with the defaults).

    Raises:
        ValueError: from ``__post_init__`` if the split fractions or the
            epoch count would make the split/scheduler cells fail later.
    """
    seed: int = 42
    img_size: int = 256           # training crop size
    train_batch_size: int = 32
    valid_batch_size: int = 64
    num_workers: int = 2
    epochs: int = 10
    base_lr: float = 3e-4          # also used as the peak LR of the OneCycle schedule
    weight_decay: float = 1e-4
    label_smoothing: float = 0.1
    model_name: str = 'efficientnet_b0'  # 'efficientnet_b0' | 'resnet50' | 'convnext_tiny' (if torchvision has it); anything else -> resnet50
    mixup_alpha: float = 0.0       # reserved: MixUp is NOT implemented by the training loop yet
    cutmix_alpha: float = 0.0      # reserved: CutMix is NOT implemented by the training loop yet
    train_val_split: float = 0.15  # fraction of ALL images used for validation
    train_test_split: float = 0.15 # fraction of ALL images used for test
    early_stopping_patience: int = 5  # epochs without val macro-F1 improvement before stopping
    freeze_backbone_epochs: int = 0  # set 1-2 to train only the classifier head first
    fp16: bool = True               # mixed precision (autocast + GradScaler); only active on CUDA

    def __post_init__(self):
        """Fail fast on settings that would otherwise crash deep inside the pipeline."""
        # Both hold-out fractions feed StratifiedShuffleSplit, which rejects 0 and 1.
        for field_name in ('train_val_split', 'train_test_split'):
            frac = getattr(self, field_name)
            if not 0.0 < frac < 1.0:
                raise ValueError(f"{field_name} must be in (0, 1), got {frac}")
        if self.train_val_split + self.train_test_split >= 1.0:
            raise ValueError(
                "train_val_split + train_test_split must be < 1 so that some data is left "
                f"for training, got {self.train_val_split + self.train_test_split}"
            )
        # OneCycleLR needs at least one epoch.
        if self.epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {self.epochs}")

# Build the default configuration and echo it into the run log.
cfg = CFG()
print(asdict(cfg))

# Use the GPU whenever the session has one attached.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', DEVICE)

In [None]:
# Reproducibility and helpers

def set_seed(seed: int = 42, *, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Value used for every RNG.
        deterministic: When False (the default, same as before) cuDNN
            autotuning is enabled for speed, so GPU results may still differ
            slightly between runs. Pass True to force deterministic cuDNN
            kernels and disable autotuning, trading speed for run-to-run
            reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

# Seed every RNG once with the configured seed, before the dataset is listed/split and the model is built.
set_seed(cfg.seed)

IMG_EXTS = {'.jpg','.jpeg','.png','.bmp','.tif','.tiff'}


def list_images(root: Path):
    """Enumerate images in a class-per-subfolder dataset.

    Each immediate subdirectory of ``root`` is one class. Classes are sorted by
    name, so label indices are stable. Images are collected recursively inside
    each class folder. A file is kept when its suffix, compared
    case-insensitively, is in ``IMG_EXTS``.

    Args:
        root: Dataset root directory (``Path`` or ``str``).

    Returns:
        ``(classes, paths, labels)``:
            - ``classes``: list of class names.
            - ``paths``: object array of ``Path``s, sorted within each class.
            - ``labels``: int64 array of class indices aligned with ``paths``.
    """
    root = Path(root)
    classes = sorted(d.name for d in root.iterdir() if d.is_dir())
    samples, labels = [], []
    for idx, cls in enumerate(classes):
        # rglob() order depends on the filesystem; sort so the seeded stratified
        # split yields the same train/val/test partition on every machine.
        # is_file() excludes directories that merely have an image-like suffix.
        files = sorted(
            p for p in (root / cls).rglob('*')
            if p.is_file() and p.suffix.lower() in IMG_EXTS
        )
        samples.extend(files)
        labels.extend([idx] * len(files))
    # dtype=object keeps the Path objects intact (also for an empty listing).
    return classes, np.array(samples, dtype=object), np.array(labels, dtype=np.int64)

classes, all_paths, all_labels = list_images(DATASET_ROOT)
num_classes = len(classes)
print(f"Found {len(all_paths)} images across {num_classes} classes.")
# Fail early with a readable message instead of an opaque error inside the splitter.
if len(all_paths) == 0:
    raise RuntimeError(
        f"No images found under {DATASET_ROOT}; expected one sub-folder per class "
        "containing image files."
    )

# Save label mapping (class index -> class name); JSON stores the int keys as strings.
label_map = {i: c for i, c in enumerate(classes)}
with open(OUTPUT_DIR / 'labels.json', 'w', encoding='utf-8') as f:
    json.dump(label_map, f, indent=2)

In [None]:
# Train/Val/Test split (stratified)
# Stage 1: carve a stratified hold-out pool (val + test) away from the training data.
holdout_frac = cfg.train_val_split + cfg.train_test_split
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=holdout_frac, random_state=cfg.seed)
train_idx, temp_idx = next(sss1.split(all_paths, all_labels))

paths_train = all_paths[train_idx]
labels_train = all_labels[train_idx]
paths_temp = all_paths[temp_idx]
labels_temp = all_labels[temp_idx]

# Stage 2: split the pool so val/test keep their configured share of the full dataset.
val_ratio_of_temp = cfg.train_val_split / holdout_frac
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=1 - val_ratio_of_temp, random_state=cfg.seed)
val_idx, test_idx = next(sss2.split(paths_temp, labels_temp))

paths_val = paths_temp[val_idx]
labels_val = labels_temp[val_idx]
paths_test = paths_temp[test_idx]
labels_test = labels_temp[test_idx]

print(f"Split -> train: {len(paths_train)}, val: {len(paths_val)}, test: {len(paths_test)}")

In [None]:
# Transforms and Dataset
# ImageNet channel statistics. build_model() tries to load ImageNet-pretrained
# weights, so inputs are normalized with the same mean/std.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Training pipeline (order matters: all augmentations run on the PIL image,
# then ToTensor + Normalize produce the network input).
train_tfms = T.Compose([
    T.RandomResizedCrop(cfg.img_size, scale=(0.7, 1.0)),  # random crop covering 70-100% of the area, resized to img_size
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(p=0.2),
    T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
    T.ColorJitter(0.2, 0.2, 0.2, 0.1),  # brightness, contrast, saturation, hue (on top of AutoAugment)
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Deterministic evaluation pipeline: resize to ~1.15x the crop size
# (a single int resizes the shorter side per torchvision docs), then
# center-crop to img_size. Shared by the val and test datasets.
valid_tfms = T.Compose([
    T.Resize(int(cfg.img_size*1.15)),
    T.CenterCrop(cfg.img_size),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

class ImageDataset(Dataset):
    """Map-style dataset over image file paths with integer class labels.

    Args:
        paths: Iterable of image paths (``str`` or ``Path``); stored as strings.
        labels: Indexable of integer labels aligned with ``paths``.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    ``__getitem__`` returns ``(image, label)``. ``image`` is the transform's
    output, or the RGB PIL image when no transform is set. ``label`` is a
    Python ``int``.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths = list(map(str, paths))
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        # PIL opens files lazily and keeps the handle open until the image is
        # closed. Close it deterministically so DataLoader workers iterating
        # thousands of files don't accumulate open file descriptors.
        # convert() loads the pixels and returns a new image independent of the file.
        with Image.open(p) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        target = int(self.labels[idx])
        return img, target

train_ds = ImageDataset(paths_train, labels_train, train_tfms)
val_ds   = ImageDataset(paths_val, labels_val, valid_tfms)
test_ds  = ImageDataset(paths_test, labels_test, valid_tfms)

# Class-balanced sampling for training (always on): each image is drawn with
# probability inversely proportional to its class frequency, with replacement.
# With perfectly balanced classes this reduces to uniform sampling.
class_counts = np.bincount(labels_train, minlength=num_classes)
class_weights = 1.0 / np.clip(class_counts, 1, None)
weights = class_weights[labels_train]
sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

# Pinned host memory only benefits host->CUDA copies; on CPU-only sessions recent
# PyTorch versions merely emit a warning, so enable it just for CUDA.
PIN_MEMORY = DEVICE.type == 'cuda'

train_loader = DataLoader(train_ds, batch_size=cfg.train_batch_size, sampler=sampler,
                          num_workers=cfg.num_workers, pin_memory=PIN_MEMORY, persistent_workers=False)
val_loader   = DataLoader(val_ds, batch_size=cfg.valid_batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds, batch_size=cfg.valid_batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=PIN_MEMORY)

# Notebook display: number of batches per loader.
len(train_loader), len(val_loader), len(test_loader)

In [None]:
# Model factory

def build_model(model_name: str, num_classes: int):
    """Build a torchvision classifier (ImageNet-pretrained when possible) with a new head.

    Supported names are ``'efficientnet_b0'``, ``'resnet50'`` and
    ``'convnext_tiny'``. Any other name, or an architecture the installed
    torchvision does not provide, falls back to ``'resnet50'``.

    Pretrained weights are loaded best-effort. If that fails for any reason
    (e.g. no internet and no cached checkpoint), the same architecture is built
    without pretrained weights and an info message is printed.

    Args:
        model_name: Architecture name (case-insensitive).
        num_classes: Output size of the replacement final ``nn.Linear``.

    Returns:
        The constructed model (device placement is left to the caller).
    """
    # architecture -> (weights enum attribute in torchvision.models, weights member name)
    specs = {
        'efficientnet_b0': ('EfficientNet_B0_Weights', 'IMAGENET1K_V1'),
        'resnet50': ('ResNet50_Weights', 'IMAGENET1K_V2'),
        'convnext_tiny': ('ConvNeXt_Tiny_Weights', 'IMAGENET1K_V1'),
    }

    model_name = model_name.lower()
    if model_name not in specs or not hasattr(tvm, model_name):
        print('Unknown model, defaulting to resnet50')
        model_name = 'resnet50'

    ctor = getattr(tvm, model_name)
    enum_name, tag = specs[model_name]

    try:
        weights_enum = getattr(tvm, enum_name, None)
        # Newer torchvision exposes typed weight enums; otherwise pass the tag string.
        weights = getattr(weights_enum, tag) if weights_enum is not None else tag
        model = ctor(weights=weights)
    except Exception as e:  # deliberate best-effort: any load failure -> non-pretrained model
        print(f"[Info] Could not load pretrained weights (likely no internet/cache). Falling back to non-pretrained. Error: {str(e)[:120]}")
        model = ctor(weights=None)

    # Replace the final Linear layer with one sized to num_classes.
    if model_name == 'resnet50':
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        # The final Linear sits at classifier[1] (efficientnet_b0) / classifier[2] (convnext_tiny).
        head_idx = 1 if model_name == 'efficientnet_b0' else 2
        model.classifier[head_idx] = nn.Linear(model.classifier[head_idx].in_features, num_classes)

    return model

model = build_model(cfg.model_name, num_classes)
model = model.to(DEVICE)

# Label-smoothed cross-entropy + AdamW, with a OneCycle schedule stepped once per batch
# whose peak learning rate is cfg.base_lr.
criterion = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=cfg.base_lr,
    weight_decay=cfg.weight_decay,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=cfg.base_lr,
    epochs=cfg.epochs,
    steps_per_epoch=len(train_loader),
)

# AMP loss scaling is only enabled for fp16 on CUDA; otherwise the scaler is disabled.
scaler = torch.cuda.amp.GradScaler(
    enabled=cfg.fp16 and DEVICE.type == 'cuda'
)

print('Model built:', cfg.model_name)

In [None]:
# Training and validation loops

def accuracy(outputs, targets):
    """Return top-1 accuracy of logits ``outputs`` against ``targets`` as a Python float."""
    hits = outputs.argmax(dim=1).eq(targets)
    return hits.float().mean().item()

@torch.no_grad()
def evaluate(model, loader):
    """Score ``model`` on every batch of ``loader`` in eval mode.

    Relies on the globals ``criterion``, ``DEVICE`` and ``cfg`` (autocast is
    used only when ``cfg.fp16`` is set and the device is CUDA).

    Returns:
        ``(mean_loss, accuracy, macro_f1, targets, predictions)``. The last two
        are NumPy arrays concatenated over the batches in loader order.
    """
    model.eval()
    use_amp = cfg.fp16 and DEVICE.type == 'cuda'
    pred_chunks = []
    target_chunks = []
    loss_sum = 0.0
    for batch_imgs, batch_targets in loader:
        batch_imgs = batch_imgs.to(DEVICE, non_blocking=True)
        batch_targets = batch_targets.to(DEVICE, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(batch_imgs)
            loss = criterion(logits, batch_targets)
        # Weight the batch-mean loss by batch size so the final mean is per-sample.
        loss_sum += loss.item() * batch_imgs.size(0)
        pred_chunks.append(logits.argmax(1).detach().cpu().numpy())
        target_chunks.append(batch_targets.detach().cpu().numpy())

    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(target_chunks)
    mean_loss = loss_sum / len(loader.dataset)
    return (
        mean_loss,
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
        y_true,
        y_pred,
    )


def train_model(model, train_loader, val_loader):
    """Train ``model`` with early stopping on validation macro-F1.

    Uses the globals ``cfg``, ``DEVICE``, ``criterion``, ``optimizer``,
    ``scheduler`` (stepped once per batch), ``scaler`` (AMP loss scaling),
    ``evaluate`` and ``OUTPUT_DIR``.

    During the first ``cfg.freeze_backbone_epochs`` epochs, only these
    parameters are trained: those whose name contains ``'classifier'`` or
    ends in ``fc.weight`` / ``fc.bias``.

    Whenever validation macro-F1 improves, a CPU snapshot of the weights is
    kept and written to ``OUTPUT_DIR/'best_model.pth'``. Training stops after
    ``cfg.early_stopping_patience`` epochs without improvement.

    Returns:
        ``(model, history)``:
            - ``model``: the model with the best snapshot loaded.
            - ``history``: dict of per-epoch lists (train/val loss and
              accuracy, val F1, last LR).
    """
    if cfg.mixup_alpha > 0 or cfg.cutmix_alpha > 0:
        # The config exposes these knobs, but this loop has no MixUp/CutMix implementation.
        print('[Warn] mixup_alpha/cutmix_alpha > 0 are not implemented in this training loop and are ignored.')

    use_amp = cfg.fp16 and DEVICE.type == 'cuda'
    best_f1, best_state, epochs_no_improve = -1.0, None, 0
    history = {"train_loss":[], "train_acc":[], "val_loss":[], "val_acc":[], "val_f1":[], "lr": []}

    for epoch in range(cfg.epochs):
        model.train()
        # Optional warm-up: train only the head for the first freeze_backbone_epochs, then everything.
        freeze_backbone = bool(cfg.freeze_backbone_epochs) and epoch < cfg.freeze_backbone_epochs
        for name, p in model.named_parameters():
            is_head = 'classifier' in name or name.endswith('fc.weight') or name.endswith('fc.bias')
            p.requires_grad = is_head or not freeze_backbone

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs}", leave=False)
        running_loss, running_acc, n = 0.0, 0.0, 0

        for imgs, targets in pbar:
            imgs, targets = imgs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(imgs)
                loss = criterion(logits, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # OneCycleLR is defined per batch (steps_per_epoch=len(train_loader))

            running_loss += loss.item() * imgs.size(0)
            running_acc += (logits.argmax(1) == targets).float().sum().item()
            n += imgs.size(0)
            pbar.set_postfix({"loss": running_loss/n, "acc": running_acc/n, "lr": scheduler.get_last_lr()[0]})

        train_loss = running_loss / n
        train_acc = running_acc / n
        val_loss, val_acc, val_f1, _, _ = evaluate(model, val_loader)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)
        history['lr'].append(scheduler.get_last_lr()[0])

        print(f"Epoch {epoch+1:02d}: train_loss={train_loss:.4f} acc={train_acc:.4f} | val_loss={val_loss:.4f} val_acc={val_acc:.4f} val_f1={val_f1:.4f}")

        # Early stopping & checkpoint
        if val_f1 > best_f1:
            best_f1 = val_f1
            # clone(): when the model already lives on the CPU, `.cpu()` returns the very
            # same tensor, so without a copy this "best" snapshot would keep changing as
            # training continues and end up equal to the final weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, OUTPUT_DIR/'best_model.pth')
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= cfg.early_stopping_patience:
                print('Early stopping.')
                break

    # Load best
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history

model, history = train_model(model, train_loader, val_loader)
# Persist the per-epoch learning curves next to the other artifacts.
history_path = OUTPUT_DIR / 'train_history.json'
history_path.write_text(json.dumps(history, indent=2))

In [None]:
# Dashboard: learning curves
hist = history
fig = make_subplots(rows=2, cols=2, subplot_titles=('Loss', 'Accuracy', 'Val F1', 'Learning Rate'))

# (history key, subplot row, subplot col), in the order the traces are added;
# each trace is named after its history key.
panel_layout = [
    ('train_loss', 1, 1), ('val_loss', 1, 1),
    ('train_acc', 1, 2), ('val_acc', 1, 2),
    ('val_f1', 2, 1),
    ('lr', 2, 2),
]
for key, row, col in panel_layout:
    fig.add_trace(go.Scatter(y=hist[key], name=key), row=row, col=col)

fig.update_layout(height=700, width=1000, title_text='Training Dashboard', showlegend=True)
fig.show()

In [None]:
# Final evaluation on validation and test sets
val_loss, val_acc, val_f1, val_t, val_p = evaluate(model, val_loader)
print(f"Validation -> loss={val_loss:.4f}, acc={val_acc:.4f}, f1={val_f1:.4f}")

te_loss, te_acc, te_f1, te_t, te_p = evaluate(model, test_loader)
print(f"Test -> loss={te_loss:.4f}, acc={te_acc:.4f}, f1={te_f1:.4f}")

# Every class index, so the report, confusion matrix and bar chart always cover all classes.
all_label_ids = list(range(num_classes))

# Detailed classification report (test). Passing `labels` explicitly matters:
# without it sklearn infers labels only from te_t/te_p. If a class never occurs
# in either, that list is shorter than `target_names` and sklearn raises a
# ValueError. zero_division=0 reports 0 (instead of warning) for such classes.
report = classification_report(te_t, te_p, labels=all_label_ids, target_names=classes,
                               output_dict=True, zero_division=0)
with open(OUTPUT_DIR/'classification_report.json','w') as f:
    json.dump(report, f, indent=2)

# Confusion matrix (test)
cm = confusion_matrix(te_t, te_p, labels=all_label_ids)
cm_fig = px.imshow(cm, text_auto=True, color_continuous_scale='Blues',
                   labels=dict(x='Predicted', y='True', color='Count'),
                   x=classes, y=classes)
cm_fig.update_layout(title='Confusion Matrix – Test')
cm_fig.show()

# Per-class F1
per_class_f1 = [report[c]['f1-score'] for c in classes]
bar_fig = px.bar(x=classes, y=per_class_f1, labels={'x':'Class','y':'F1-score'}, title='Per-Class F1 (Test)')
bar_fig.update_xaxes(tickangle=45)
bar_fig.show()

# Save metrics
metrics = {
    'val': {'loss': val_loss, 'acc': val_acc, 'f1': val_f1},
    'test': {'loss': te_loss, 'acc': te_acc, 'f1': te_f1}
}
with open(OUTPUT_DIR/'metrics.json','w') as f:
    json.dump(metrics, f, indent=2)

print('Metrics saved to', OUTPUT_DIR)

In [None]:
# Export: PyTorch state dict, TorchScript, ONNX
# Export from a CPU copy of the model. A module traced on CUDA saves its
# parameters on CUDA, so loading the TorchScript file on a CPU-only host would
# need map_location. The live `model` stays on DEVICE for the inference cell below.
import copy

model.eval()

# Save state dict (already saved best during training, but ensure copy)
best_pth = OUTPUT_DIR/'best_model.pth'
if not best_pth.exists():
    torch.save({k: v.cpu() for k,v in model.state_dict().items()}, best_pth)
print('Saved:', best_pth)

export_model = copy.deepcopy(model).cpu().eval()
example = torch.randn(1, 3, cfg.img_size, cfg.img_size)

with torch.no_grad():
    # TorchScript
    traced = torch.jit.trace(export_model, example)
    script_path = OUTPUT_DIR/'model.torchscript.pt'
    traced.save(str(script_path))
    print('Saved:', script_path)

    # ONNX (dynamic axes for batch)
    onx_path = OUTPUT_DIR/'model.onnx'
    torch.onnx.export(
        export_model, example, str(onx_path), input_names=['images'], output_names=['logits'],
        dynamic_axes={'images': {0: 'batch'}, 'logits': {0: 'batch'}}, opset_version=12
    )
    print('Saved:', onx_path)

# Free the temporary CPU copy.
del export_model

print('Label map at:', OUTPUT_DIR/'labels.json')
print('Artifacts ready in /kaggle/working/impc_outputs')

In [None]:
# Inference demo with TTA
from torchvision.transforms.functional import resize, center_crop

@torch.no_grad()
def predict_batch(imgs):
    """Return softmax class probabilities for the batch ``imgs``.

    Puts the global ``model`` in eval mode and runs it under autocast when
    ``cfg.fp16`` is set on a CUDA device. ``imgs`` is expected to already be on
    ``DEVICE`` (the caller in this notebook moves it there).
    """
    model.eval()
    amp_on = cfg.fp16 and DEVICE.type == 'cuda'
    with torch.cuda.amp.autocast(enabled=amp_on):
        return F.softmax(model(imgs), dim=1)

@torch.no_grad()
def tta_predict(img):
    """Average the class probabilities over three views of one image.

    The views are: the image itself, a horizontal flip and a vertical flip.

    Args:
        img: A single normalized image tensor laid out C x H x W.

    Returns:
        The mean probability vector (batch dimension squeezed away).
    """
    views = [img]
    views.append(T.RandomHorizontalFlip(p=1.0)(img))
    views.append(T.RandomVerticalFlip(p=1.0)(img))
    per_view = [predict_batch(v.unsqueeze(0).to(DEVICE)) for v in views]
    return torch.stack(per_view).mean(0).squeeze(0)

# Show a few predictions from test set
# Show a few TTA predictions from the test set (skipped when the test set is empty,
# since make_subplots cannot build a zero-row grid).
n_show = min(12, len(test_ds))
if n_show == 0:
    print('Test set is empty; skipping prediction preview.')
else:
    sample_idx = np.random.choice(len(test_ds), size=n_show, replace=False)

    rows, cols = math.ceil(n_show/4), 4
    fig = make_subplots(rows=rows, cols=cols,
                        subplot_titles=[f"true:{classes[int(labels_test[i])]}" for i in sample_idx])

    # Per-channel statistics shaped (3, 1, 1) to broadcast over a C x H x W tensor.
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

    for k, i in enumerate(sample_idx):
        img, _ = test_ds[i]
        prob = tta_predict(img)
        pred_cls = classes[int(prob.argmax().item())]

        # Undo the normalization for display and convert to HxWxC uint8.
        img_disp = (img * std + mean).clamp(0, 1)
        img_disp = (img_disp.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

        r, c = k//cols + 1, k%cols + 1
        fig.add_trace(go.Image(z=img_disp), row=r, col=c)
        # Subplot titles are created in order, so annotation k belongs to sample k.
        fig.layout.annotations[k].text += f" | pred:{pred_cls}"

    # make_subplots created every axis up front, so one call hides all tick labels.
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    fig.update_layout(height=300*rows, width=250*cols, title_text='Sample Test Predictions (TTA)')
    fig.show()