# MobileNetV2 Leaf Disease Classifier

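This notebook fine-tunes an ImageNet-pretrained MobileNetV2 model for crop leaf disease classification.

- Place your dataset under `../data/PlantVillage` (relative to this notebook), or update `data_dir` in the code cells.
- The dataset must use the `torchvision.datasets.ImageFolder` layout, with one subfolder per class under both `train/` and `val/` (e.g., `train/class_name/*.jpg`).
- The notebook saves PyTorch `state_dict` checkpoints (`.pth`); the TorchServe `.mar` archive is built afterwards with `torch-model-archiver` (see the packaging section at the end).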
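Before training, it can help to confirm that the folders match what `ImageFolder` expects. The following is a minimal stdlib-only sketch. `check_imagefolder_layout` is a hypothetical helper (not part of torchvision), and it assumes the `train/` and `val/` split names used in this notebook. It mirrors how `ImageFolder` discovers classes: subdirectories only, sorted by name.

```python
import os


def check_imagefolder_layout(data_dir, splits=("train", "val")):
    """Return the sorted class names shared by every split.

    Hypothetical helper. Raises ValueError if a split directory is missing
    or if the splits disagree on their class subfolders (ImageFolder would
    then assign different label indices to the same class name).
    """
    class_sets = {}
    for split in splits:
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            raise ValueError(f"missing split directory: {split_dir}")
        # Like ImageFolder, count only subdirectories; stray files are ignored.
        class_sets[split] = sorted(
            entry.name for entry in os.scandir(split_dir) if entry.is_dir()
        )
    reference = class_sets[splits[0]]
    for split, classes in class_sets.items():
        if classes != reference:
            raise ValueError(f"class folders in '{split}' differ from '{splits[0]}'")
    return reference
```

For example, `check_imagefolder_layout('../data/PlantVillage')` should return the same class list that the evaluation cell prints.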

In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm  # progress bar

# Paths
data_dir = '../data/PlantVillage'

# Transforms: resize to MobileNetV2's 224x224 input and apply ImageNet normalization
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Derive the class count from ImageFolder (class subfolders only), not
# os.listdir, which would also count stray files such as .DS_Store
num_classes = len(train_dataset.classes)

# Model: ImageNet-pretrained MobileNetV2 with a new final layer for our classes
model = models.mobilenet_v2(weights='IMAGENET1K_V1')
model.classifier[1] = torch.nn.Linear(model.last_channel, num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 5  # kept small for faster debugging
best_accuracy = 0.0

# Training loop with progress bar and running accuracy
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        loop.set_postfix(loss=loss.item(), acc=100. * correct / total)

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    print(f"\n✅ Epoch {epoch+1} complete. Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

    # Save the "best" model. Note: this selects on *training* accuracy, which
    # tends to rise every epoch; validation accuracy is a more reliable criterion.
    if epoch_acc > best_accuracy:
        best_accuracy = epoch_acc
        torch.save(model.state_dict(), 'mobilenetv2_leaf_disease_best.pth')
        print("💾 Saved best model so far!")

# Save final model
torch.save(model.state_dict(), 'mobilenetv2_leaf_disease_final.pth')
print("🎉 Final model saved.")

Epoch 1/5: 100%|██████████| 516/516 [18:20<00:00,  2.13s/it, acc=92.6, loss=0.037]

✅ Epoch 1 complete. Loss: 0.2932, Accuracy: 92.61%
💾 Saved best model so far!

Epoch 2/5: 100%|██████████| 516/516 [14:34<00:00,  1.70s/it, acc=98.8, loss=0.119]

✅ Epoch 2 complete. Loss: 0.0461, Accuracy: 98.79%
💾 Saved best model so far!

Epoch 3/5: 100%|██████████| 516/516 [17:17<00:00,  2.01s/it, acc=99.2, loss=0.00533]

✅ Epoch 3 complete. Loss: 0.0276, Accuracy: 99.19%
💾 Saved best model so far!

Epoch 4/5: 100%|██████████| 516/516 [32:30<00:00,  3.78s/it, acc=99.3, loss=0.00584]

✅ Epoch 4 complete. Loss: 0.0226, Accuracy: 99.34%
💾 Saved best model so far!

Epoch 5/5: 100%|██████████| 516/516 [21:28<00:00,  2.50s/it, acc=99.5, loss=0.0104]

✅ Epoch 5 complete. Loss: 0.0179, Accuracy: 99.48%
💾 Saved best model so far!
🎉 Final model saved.
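The training loop picks its "best" checkpoint by training accuracy, which rose every epoch in the run above, so the best and final weights are identical. Below is a minimal sketch of selecting on validation accuracy instead, assuming a `val_loader` built as in the evaluation cell. `evaluate_accuracy` and `maybe_save_best` are hypothetical helper names, not library functions.

```python
import torch


@torch.no_grad()
def evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy (in %) of `model` over every batch in `loader`."""
    model.eval()  # dropout off, BatchNorm uses running statistics
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / max(total, 1)


def maybe_save_best(model, val_acc, best_acc, path):
    """Save the model's state_dict if val_acc beats best_acc; return the new best."""
    if val_acc > best_acc:
        torch.save(model.state_dict(), path)
        return val_acc
    return best_acc
```

Inside the epoch loop this would replace the training-accuracy check, e.g. `val_acc = evaluate_accuracy(model, val_loader, device)` followed by `best_accuracy = maybe_save_best(model, val_acc, best_accuracy, 'mobilenetv2_leaf_disease_best.pth')`. `evaluate_accuracy` leaves the model in eval mode; the `model.train()` call at the start of each epoch switches it back.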





In [2]:
%pip install scikit-learn

Collecting scikit-learn
  Using cached scikit_learn-1.7.0-cp311-cp311-win_amd64.whl.metadata (14 kB)
Collecting scipy>=1.8.0 (from scikit-learn)
  Using cached scipy-1.16.0-cp311-cp311-win_amd64.whl.metadata (60 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Using cached joblib-1.5.1-py3-none-any.whl.metadata (5.6 kB)
Collecting threadpoolctl>=3.1.0 (from scikit-learn)
  Using cached threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Using cached scikit_learn-1.7.0-cp311-cp311-win_amd64.whl (10.7 MB)
Using cached joblib-1.5.1-py3-none-any.whl (307 kB)
Using cached scipy-1.16.0-cp311-cp311-win_amd64.whl (38.6 MB)
Using cached threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn

In [2]:
print("📊 Evaluating model on validation set...")

import os

import numpy as np
import torch
from sklearn.metrics import (classification_report, confusion_matrix,
                             precision_recall_fscore_support)
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Data directory (update if different)
data_dir = '../data/PlantVillage'

# Validation transforms (must match the preprocessing used in training)
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load validation dataset
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), val_transforms)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Get class names; the class count comes from the folders instead of being hard-coded
class_names = val_dataset.classes
num_classes = len(class_names)
print(f"Found {num_classes} classes: {class_names}")

# Define model architecture (must match training). weights=None replaces the
# deprecated pretrained=False; the fine-tuned weights come from the checkpoint.
model = models.mobilenet_v2(weights=None)
model.classifier[1] = torch.nn.Linear(model.last_channel, num_classes)

# Load the best model (weights_only=True restricts unpickling to tensors and plain containers)
state_dict = torch.load('mobilenetv2_leaf_disease_best.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Collect predictions and true labels
all_preds = []
all_labels = []

# Disable gradient calculation and run evaluation
with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Evaluating"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Pass every class index explicitly so the outputs stay aligned with
# class_names even if a class never appears in labels or predictions
label_ids = list(range(num_classes))

# Per-class metrics
print("\n" + "="*80)
print("📈 Classification Report")
print("="*80)
print(classification_report(all_labels, all_preds, labels=label_ids,
                            target_names=class_names, digits=4, zero_division=0))

# Confusion matrix (rows = true class, columns = predicted class)
print("\n" + "="*80)
print("📊 Confusion Matrix")
print("="*80)
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
print(cm)

# Overall metrics. Precision, recall and F1 are macro-averaged (unweighted mean
# over classes); macro F1 is the mean of per-class F1 scores, which is not the
# harmonic mean of macro precision and macro recall.
accuracy = np.trace(cm) / np.sum(cm)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, labels=label_ids, average='macro', zero_division=0)
print(f"\n✅ Overall Accuracy: {accuracy:.4f}")
print(f"✅ Macro Precision: {precision:.4f}")
print(f"✅ Macro Recall: {recall:.4f}")
print(f"✅ Macro F1-Score: {f1:.4f}")

print("\nEvaluation complete! 🎉")

📊 Evaluating model on validation set...
Using device: cpu
Found 15 classes: ['Pepper__bell___Bacterial_spot', 'Pepper__bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Tomato_Bacterial_spot', 'Tomato_Early_blight', 'Tomato_Late_blight', 'Tomato_Leaf_Mold', 'Tomato_Septoria_leaf_spot', 'Tomato_Spider_mites_Two_spotted_spider_mite', 'Tomato__Target_Spot', 'Tomato__Tomato_YellowLeaf__Curl_Virus', 'Tomato__Tomato_mosaic_virus', 'Tomato_healthy']
Evaluating: 100%|██████████| 130/130 [02:03<00:00,  1.05it/s]

📈 Classification Report
                                             precision    recall  f1-score   support

              Pepper__bell___Bacterial_spot     1.0000    1.0000    1.0000       200
                     Pepper__bell___healthy     1.0000    1.0000    1.0000       296
                      Potato___Early_blight     1.0000    0.9950    0.9975       200
                       Potato___Late_blight     1.0000    0.9950    0.9975       200
                           Potato___healthy     1.0000    1.0000    1.0000        31
                      Tomato_Bacterial_spot     0.9953    0.9883    0.9918       426
                        Tomato_Early_blight     0.9800    0.9800    0.9800       200
                         Tomato_Late_blight     0.9845    0.9974    0.9909       382
                           Tomato_Leaf_Mold     1.0000    0.9895    0.9947       191
                  Tomato_Septoria_leaf_spot     0.9916    1.0000    0.9958       355
Tomato_Spider_mites_Two_spotted_spider_




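The harmonic mean of macro-averaged precision and recall is a different quantity from macro F1, which averages the per-class F1 scores (this is what scikit-learn reports in the `macro avg` row). A small NumPy sketch computing both from a confusion matrix laid out like `sklearn.metrics.confusion_matrix` (rows = true class, columns = predicted class). `macro_metrics` is a hypothetical helper, and the 2×2 matrix in the usage line is made-up illustration data, not a result from this notebook.

```python
import numpy as np


def macro_metrics(cm):
    """Macro-averaged metrics from a confusion matrix (hypothetical helper).

    cm[i, j] counts samples of true class i predicted as class j.
    Classes with no predictions (or no samples) get 0 instead of NaN.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    pred_counts = cm.sum(axis=0)  # column sums: how often each class was predicted
    true_counts = cm.sum(axis=1)  # row sums: support of each class
    precision = np.divide(tp, pred_counts, out=np.zeros_like(tp), where=pred_counts > 0)
    recall = np.divide(tp, true_counts, out=np.zeros_like(tp), where=true_counts > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    macro_p, macro_r = precision.mean(), recall.mean()
    pr_sum = macro_p + macro_r
    return {
        "macro_precision": macro_p,
        "macro_recall": macro_r,
        "macro_f1": f1.mean(),  # mean of per-class F1 scores
        # harmonic mean of macro P and macro R: generally NOT equal to macro F1
        "f1_of_macro_pr": 2 * macro_p * macro_r / pr_sum if pr_sum > 0 else 0.0,
    }
```

For `macro_metrics([[9, 1], [5, 5]])`, the per-class F1 scores are 0.75 and 0.625, so macro F1 is 0.6875, while the harmonic mean of macro precision (about 0.738) and macro recall (0.7) is about 0.719.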
## TorchServe Model Packaging

After training, use the following command to create a TorchServe `.mar` archive from the best checkpoint saved by the training loop:

```bash
torch-model-archiver --model-name leaf-disease --version 1.0 --serialized-file mobilenetv2_leaf_disease_best.pth --handler image_classifier --extra-files index_to_name.json --export-path model_store
```
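Since the checkpoint is only a `state_dict`, one option is to export the fine-tuned model to TorchScript, which TorchServe can load without a `--model-file`. A minimal sketch, assuming `model` has been rebuilt and loaded from `mobilenetv2_leaf_disease_best.pth` as in the evaluation cell; `export_torchscript` and the output name `mobilenetv2_leaf_disease_best.pt` are hypothetical.

```python
import torch


def export_torchscript(model, path, input_size=(1, 3, 224, 224)):
    """Trace `model` on a dummy input and save it as a TorchScript file.

    Tracing records the ops executed for this input; MobileNetV2 has no
    data-dependent control flow, so a trace captures its full forward pass.
    Note: .cpu() moves the caller's model to CPU in place.
    """
    model = model.cpu().eval()  # eval mode: dropout off, BatchNorm uses running stats
    example = torch.randn(*input_size)
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(path)
    return traced
```

After `export_torchscript(model, 'mobilenetv2_leaf_disease_best.pt')`, the resulting file would be passed as `--serialized-file` in the archiver command instead of the `.pth` checkpoint.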

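- `index_to_name.json` should map class indices to disease names in the order `ImageFolder` assigned them during training, e.g. `{"0": "Pepper__bell___Bacterial_spot", ...}`.
- Because the `.pth` file is a `state_dict`, the archiver also needs `--model-file` pointing to a Python file that defines the model class (eager mode); alternatively, pass a TorchScript export as `--serialized-file`.
- The built-in `image_classifier` handler applies its own preprocessing (resize to 256, center-crop to 224, ImageNet normalization), which differs slightly from the `Resize((224, 224))` used in this notebook.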
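A minimal stdlib sketch for generating `index_to_name.json`. `write_index_to_name` is a hypothetical helper; it assumes the class folders under `train/` are the ones the model was trained on and reproduces `ImageFolder`'s sorted-subfolder ordering so that indices line up with the model's outputs.

```python
import json
import os


def write_index_to_name(train_dir, out_path="index_to_name.json"):
    """Write a {"<index>": "<class name>"} mapping (hypothetical helper).

    Class names are the sorted subfolder names of `train_dir`, the same
    ordering torchvision's ImageFolder uses to assign label indices.
    """
    classes = sorted(e.name for e in os.scandir(train_dir) if e.is_dir())
    mapping = {str(i): name for i, name in enumerate(classes)}
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(mapping, f, indent=2)
    return mapping
```

For this notebook, `write_index_to_name('../data/PlantVillage/train')` would produce the file expected by `--extra-files`. Inside the training notebook, `{str(i): c for i, c in enumerate(train_dataset.classes)}` gives the same mapping directly.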