<a href="https://colab.research.google.com/github/MadhavMenon10/FA25-Group20/blob/Zoya/plant_disease_fast_training_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Plant Disease Binary Classification - Fast Training Version

## Updates in this version:
- ✅ Uses **HuggingFace ResNet-50** (`microsoft/resnet-50`, pretrained on ImageNet)
- ✅ **Samples from ALL plant/disease categories** (no single category dominates the training set)
- ✅ **Faster training** (~30-45 min for a full run)
- ✅ **Loss tracking** after each epoch
- ✅ **Saves weights** automatically

## What this does:
- Binary classification: Healthy (0) vs Diseased (1)
- Uses pretrained ResNet from HuggingFace
- Trains on a per-category sample (up to 200 images each) from all 38 plant/disease categories

---
## Step 0: Mount Google Drive
Run this first to access your dataset!

In [6]:
from google.colab import drive
drive.mount('/content/drive')
print("✅ Google Drive mounted!")

Mounted at /content/drive
✅ Google Drive mounted!


---
## Step 1: Install Required Libraries

In [7]:
# Install transformers for HuggingFace models
!pip install -q transformers datasets
print("✅ Libraries installed!")

✅ Libraries installed!


---
## Step 2: Import Libraries

In [8]:
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# HuggingFace
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Image processing
from PIL import Image
import torchvision.transforms as transforms

# Utilities
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import defaultdict
import random

print("✅ All libraries imported!")
print(f"PyTorch version: {torch.__version__}")

✅ All libraries imported!
PyTorch version: 2.8.0+cu126


---
## Step 3: Configuration

In [4]:
# === PATHS ===
DATA_PATH = '/content/drive/MyDrive/Plantvillage dataset/plantvillage dataset named/segmented'

# === SAMPLING STRATEGY ===
# Instead of using ALL images, we sample from each category.
# This keeps every plant/disease category but makes training much faster.
MAX_IMAGES_PER_CATEGORY = 200  # Take at most 200 images from each plant/disease category
# With 38 categories this caps the dataset at 38 x 200 = 7,600 images (vs 50,000+);
# categories with fewer than 200 images contribute everything they have.
# Roughly 6-7x less data per epoch, so training is much faster while covering all plants.

# === TRAINING SETTINGS ===
BATCH_SIZE = 32
NUM_EPOCHS = 10  # Expected to finish in ~30-45 min
LEARNING_RATE = 2e-5  # Small LR for fine-tuning, so the pretrained weights change gradually
IMG_SIZE = 224

# === DEVICE ===
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

print(f"\n📊 Training Configuration:")
print(f"  - Sampling: Max {MAX_IMAGES_PER_CATEGORY} images per category")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Epochs: {NUM_EPOCHS}")
print(f"  - Learning rate: {LEARNING_RATE}")

Using device: cuda

📊 Training Configuration:
  - Sampling: Max 200 images per category
  - Batch size: 32
  - Epochs: 10
  - Learning rate: 2e-05
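The 7,600 figure in the config comments is an upper bound: a category with fewer than `MAX_IMAGES_PER_CATEGORY` images contributes everything it has. A minimal sketch of that arithmetic, using a hypothetical `capped_total` helper; two counts come from the Step 4 output, and the third category is made up for illustration:

```python
def capped_total(category_counts, cap):
    """Number of images kept when each category is capped at `cap` images.

    `category_counts` maps a category name to how many images it has on disk.
    """
    return sum(min(count, cap) for count in category_counts.values())


counts = {
    "Apple___healthy": 1645,            # from the Step 4 output
    "Apple___Cedar_apple_rust": 276,    # from the Step 4 output
    "Example___small_category": 150,    # hypothetical: fewer images than the cap
}

print(capped_total(counts, cap=200))  # 200 + 200 + 150 = 550
```

This is why the actual split in Step 5 adds up to fewer than 7,600 images.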


---
## Step 4: Load Dataset with Stratified Sampling
Load images from ALL plants, taking up to the same number of images (`MAX_IMAGES_PER_CATEGORY`) from each category

In [9]:
def load_dataset_with_sampling(data_path, max_per_category=200):
    """
    Load dataset with stratified sampling:
    - Takes up to max_per_category images from EACH plant disease
    - Ensures all plants are represented
    - Creates binary labels (Healthy=0, Diseased=1)

    Capping every category keeps large categories from dominating while keeping training fast!
    """
    image_paths = []
    labels = []

    categories = [d for d in os.listdir(data_path)
                  if os.path.isdir(os.path.join(data_path, d))]

    print(f"Found {len(categories)} plant categories\n")
    print("Sampling from each category...\n")

    healthy_count = 0
    diseased_count = 0

    for category in sorted(categories):
        category_path = os.path.join(data_path, category)

        # Determine label
        is_healthy = 'healthy' in category.lower()
        label = 0 if is_healthy else 1

        # Get all files (including files with no extension for segmented images)
        all_files = os.listdir(category_path)
        image_files = [f for f in all_files
                      if f.lower().endswith(('.png', '.jpg', '.jpeg')) or '.' not in f]

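        # Sample up to max_per_category images. Sort first: os.listdir order is
        # arbitrary, so sorting makes the seeded sample reproducible across machines.
        if len(image_files) > max_per_category:
            sampled_files = random.sample(sorted(image_files), max_per_category)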
        else:
            sampled_files = image_files

        # Add to dataset
        for img_name in sampled_files:
            img_path = os.path.join(category_path, img_name)
            image_paths.append(img_path)
            labels.append(label)

        if is_healthy:
            healthy_count += len(sampled_files)
        else:
            diseased_count += len(sampled_files)

        print(f"{category:50s} | Total: {len(image_files):4d} | Sampled: {len(sampled_files):3d} | {'✓ Healthy' if is_healthy else '✗ Diseased'}")

    print(f"\n{'='*80}")
    print(f"📊 Final Dataset Summary:")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Healthy: {healthy_count} ({100*healthy_count/len(image_paths):.1f}%)")
    print(f"  Diseased: {diseased_count} ({100*diseased_count/len(image_paths):.1f}%)")
    print(f"  Categories: {len(categories)}")
    print(f"{'='*80}")

    return image_paths, labels

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

# Load dataset
image_paths, labels = load_dataset_with_sampling(DATA_PATH, MAX_IMAGES_PER_CATEGORY)

Found 38 plant categories

Sampling from each category...

Apple___Apple_scab                                 | Total:  630 | Sampled: 200 | ✗ Diseased
Apple___Black_rot                                  | Total:  621 | Sampled: 200 | ✗ Diseased
Apple___Cedar_apple_rust                           | Total:  276 | Sampled: 200 | ✗ Diseased
Apple___healthy                                    | Total: 1645 | Sampled: 200 | ✓ Healthy
Blueberry___healthy                                | Total: 1502 | Sampled: 200 | ✓ Healthy
Cherry_(including_sour)___Powdery_mildew           | Total: 1052 | Sampled: 200 | ✗ Diseased
Cherry_(including_sour)___healthy                  | Total:  854 | Sampled: 200 | ✓ Healthy
Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot | Total:  513 | Sampled: 200 | ✗ Diseased
Corn_(maize)___Common_rust_                        | Total: 1192 | Sampled: 200 | ✗ Diseased
Corn_(maize)___Northern_Leaf_Blight                | Total:  985 | Sampled: 200 | ✗ Dis
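Seeding the global `random` module is only part of reproducibility: if `os.listdir` returns files in a different order on another machine, the same seed picks different images. A sketch of a per-category sampler that sorts the listing and uses its own `random.Random` instance, seeded from the seed plus the folder name. The helper `sample_category` is hypothetical, not part of the notebook:

```python
import os
import random
import tempfile

IMAGE_EXTS = ('.png', '.jpg', '.jpeg')


def sample_category(category_path, max_per_category, seed=42):
    """Return up to `max_per_category` file names from one category folder.

    Sorting removes any dependence on filesystem order, and a dedicated
    Random instance keeps this sample independent of other code that also
    draws from the global `random` module.
    """
    files = sorted(
        f for f in os.listdir(category_path)
        if f.lower().endswith(IMAGE_EXTS) or '.' not in f
    )
    if len(files) <= max_per_category:
        return files
    # A string seed is hashed deterministically by Python's random module
    rng = random.Random(f"{seed}:{os.path.basename(category_path)}")
    return rng.sample(files, max_per_category)


# Usage: sampling the same folder twice gives the same files
with tempfile.TemporaryDirectory() as demo_dir:
    for i in range(10):
        open(os.path.join(demo_dir, f"img_{i}.jpg"), "w").close()
    open(os.path.join(demo_dir, "notes.txt"), "w").close()  # ignored: not an image
    first = sample_category(demo_dir, 3)
    second = sample_category(demo_dir, 3)
    everything = sample_category(demo_dir, 50)

print(first == second)   # True
print(len(everything))   # 10
```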

---
## Step 5: Train/Validation Split

In [10]:
from sklearn.model_selection import train_test_split

# 80/20 split with stratification
train_paths, val_paths, train_labels, val_labels = train_test_split(
    image_paths,
    labels,
    test_size=0.2,
    random_state=42,
    stratify=labels
)

print(f"Training: {len(train_paths)} images")
print(f"  - Healthy: {train_labels.count(0)}")
print(f"  - Diseased: {train_labels.count(1)}")
print(f"\nValidation: {len(val_paths)} images")
print(f"  - Healthy: {val_labels.count(0)}")
print(f"  - Diseased: {val_labels.count(1)}")

Training: 5401 images
  - Healthy: 1561
  - Diseased: 3840

Validation: 1351 images
  - Healthy: 391
  - Diseased: 960
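Stratifying on the binary label keeps the healthy/diseased ratio equal in both splits, but it does not guarantee that every plant/disease category gets its own 80/20 share. One option is to stratify on the category folder name instead; with scikit-learn that amounts to passing the folder names as `stratify=`. The sketch below swaps in a small pure-Python version so the idea is visible. The helper `split_by_category` is hypothetical and assumes the `<category>/<image>` layout used in Step 4:

```python
import os
import random
from collections import defaultdict


def split_by_category(paths, labels, val_fraction=0.2, seed=42):
    """Split (path, label) pairs so that every category folder is split ~80/20.

    The category is the parent folder name, matching the layout read by
    load_dataset_with_sampling. Returns (train, val) lists of (path, label).
    """
    by_category = defaultdict(list)
    for path, label in zip(paths, labels):
        category = os.path.basename(os.path.dirname(path))
        by_category[category].append((path, label))

    rng = random.Random(seed)
    train, val = [], []
    for category in sorted(by_category):
        items = sorted(by_category[category])
        rng.shuffle(items)
        n_val = round(len(items) * val_fraction)
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val


# Hypothetical paths: 10 healthy apple images, 5 black-rot images
demo_paths = ([f"data/Apple___healthy/{i}.jpg" for i in range(10)]
              + [f"data/Apple___Black_rot/{i}.jpg" for i in range(5)])
demo_labels = [0] * 10 + [1] * 5
demo_train, demo_val = split_by_category(demo_paths, demo_labels)
print(len(demo_train), len(demo_val))  # 12 3
```

Because the label is a function of the category, splitting per category also keeps the healthy/diseased ratio close to the original one.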


---
## Step 6: Image Transforms
Note: the HuggingFace image processor (not the model itself) handles resizing and normalization,
but we still use basic torchvision transforms for data augmentation

In [11]:
# Load HuggingFace image processor for ResNet
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

print("✅ HuggingFace ResNet-50 processor loaded!")
print(f"\nExpected image size: {processor.size}")




Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


✅ HuggingFace ResNet-50 processor loaded!

Expected image size: {'shortest_edge': 224}
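`{'shortest_edge': 224}` is not the whole story. Assuming the checkpoint's `preprocessor_config.json` sets `crop_pct=0.875` (worth checking in your download), the processor resizes the shortest edge to `int(224 / 0.875) = 256`, center-crops to 224x224, and then normalizes. The hypothetical helper below is a sketch of only that size arithmetic, not the HuggingFace code:

```python
def resize_then_crop_box(width, height, shortest_edge=224, crop_pct=0.875):
    """Sketch of the assumed resize + center-crop geometry.

    Returns the size after resizing and the (left, top, right, bottom)
    crop box inside the resized image.
    """
    target_short = int(shortest_edge / crop_pct)  # 224 / 0.875 = 256
    if width <= height:
        new_w, new_h = target_short, int(target_short * height / width)
    else:
        new_w, new_h = int(target_short * width / height), target_short
    left = (new_w - shortest_edge) // 2
    top = (new_h - shortest_edge) // 2
    return (new_w, new_h), (left, top, left + shortest_edge, top + shortest_edge)


# PlantVillage images are typically 256x256 (assumption; check your copy)
print(resize_then_crop_box(256, 256))  # ((256, 256), (16, 16, 240, 240))
```

Under these assumptions a 256x256 leaf image is not resized at all; the center crop simply trims a 16-pixel border from each side.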


---
## Step 7: Custom Dataset Class

In [12]:
class PlantDiseaseDataset(Dataset):
    """
    Custom dataset that works with HuggingFace image processor
    """
    def __init__(self, image_paths, labels, processor, augment=False):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.augment = augment

        # Data augmentation transforms (optional for training)
        if augment:
            self.aug_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
            ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        # Apply augmentation if training
        if self.augment:
            image = self.aug_transform(image)

        # Process with HuggingFace processor
        # This handles resizing and normalization automatically!
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values'].squeeze(0)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        return pixel_values, label

# Create datasets
train_dataset = PlantDiseaseDataset(train_paths, train_labels, processor, augment=True)
val_dataset = PlantDiseaseDataset(val_paths, val_labels, processor, augment=False)

print(f"✅ Datasets created!")
print(f"  Training: {len(train_dataset)} images (with augmentation)")
print(f"  Validation: {len(val_dataset)} images (no augmentation)")

✅ Datasets created!
  Training: 5401 images (with augmentation)
  Validation: 1351 images (no augmentation)


---
## Step 8: Create Data Loaders

In [13]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"✅ Data loaders ready!")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
# Rough estimate, assuming ~0.5 s per training batch (validation time not included)
print(f"\n⚡ Estimated time per epoch: {len(train_loader) * 0.5 / 60:.1f} minutes")

✅ Data loaders ready!
  Training batches: 169
  Validation batches: 43

⚡ Estimated time per epoch: 1.4 minutes
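The batch counts above follow from ceiling division (the final partial batch is kept because `drop_last` defaults to `False`), and the time estimate is simply training batches times an assumed 0.5 s per batch. A quick check of the printed numbers:

```python
import math


def num_batches(num_samples, batch_size):
    """Batches per epoch when the final partial batch is kept (drop_last=False)."""
    return math.ceil(num_samples / batch_size)


def estimated_minutes(num_samples, batch_size, seconds_per_batch=0.5):
    """Rough epoch time; 0.5 s/batch is the notebook's assumption, not a measurement."""
    return num_batches(num_samples, batch_size) * seconds_per_batch / 60


print(num_batches(5401, 32))                 # 169
print(num_batches(1351, 32))                 # 43
print(f"{estimated_minutes(5401, 32):.1f}")  # 1.4
print(5401 - 168 * 32)                       # 25 images in the last training batch
```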


---
## Step 9: Load HuggingFace ResNet Model
Using `microsoft/resnet-50`, a ResNet-50 checkpoint pretrained on ImageNet-1k

In [14]:
# Load pretrained ResNet-50 from HuggingFace
base_model = AutoModelForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    num_labels=1,  # Binary classification (output 1 value)
    ignore_mismatched_sizes=True
)

print("✅ HuggingFace ResNet-50 loaded!")
print("\nModel info:")
print(f"  - Pretrained on ImageNet (1M+ images)")
print(f"  - 50 layers deep")
print(f"  - Modified final layer for binary classification")

# Move to device
base_model = base_model.to(device)
print(f"\n✅ Model moved to {device}")



Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([1]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([1, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ HuggingFace ResNet-50 loaded!

Model info:
  - Pretrained on ImageNet (1M+ images)
  - 50 layers deep
  - Modified final layer for binary classification

✅ Model moved to cuda
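The shape-mismatch warning above is expected: with `ignore_mismatched_sizes=True`, the 1000-class ImageNet head is replaced by a freshly initialized 2048-to-1 linear layer. A linear layer has `in_features x out_features` weights plus `out_features` biases, so the head arithmetic works out as follows:

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


imagenet_head = linear_params(2048, 1000)  # classifier.1 in the checkpoint
binary_head = linear_params(2048, 1)       # classifier.1 after num_labels=1

print(imagenet_head)                # 2049000
print(binary_head)                  # 2049
print(imagenet_head - binary_head)  # 2046951 parameters removed
```

So the fine-tuned model has roughly 2M fewer parameters than the ~25.6M-parameter ImageNet ResNet-50, and only 2,049 of them start from random initialization.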


---
## Step 10: Define Loss Function and Optimizer

In [17]:
# Binary Cross Entropy Loss (for binary classification)
criterion = nn.BCEWithLogitsLoss()

# AdamW: Adam with decoupled weight decay, a common default for fine-tuning pretrained models
optimizer = optim.AdamW(base_model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler (reduces LR when loss plateaus)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2
    # Removed 'verbose=True' - not supported in this PyTorch version
)

print("✅ Training setup complete!")
print(f"  Loss: Binary Cross Entropy")
print(f"  Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"  Scheduler: ReduceLROnPlateau")

✅ Training setup complete!
  Loss: Binary Cross Entropy
  Optimizer: AdamW (lr=2e-05)
  Scheduler: ReduceLROnPlateau
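`BCEWithLogitsLoss` applies the sigmoid inside the loss in a numerically stable way. Note also that the sampled training set is not balanced at the binary level (1,561 healthy vs 3,840 diseased in Step 5). PyTorch's `pos_weight` argument reweights the positive (label 1, Diseased) term, and its documentation suggests `num_negative / num_positive`. A pure-Python sketch of the per-example loss, `pos_weight * y * softplus(-z) + (1 - y) * softplus(z)`; the helper names are hypothetical, and in the notebook you would pass `pos_weight=torch.tensor([...])` (on `device`) to `nn.BCEWithLogitsLoss`:

```python
import math


def softplus(x):
    """log(1 + exp(x)) computed without overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logit, target, pos_weight=1.0):
    """Per-example binary cross entropy on a raw logit.

    Uses log(sigmoid(z)) = -softplus(-z) and log(1 - sigmoid(z)) = -softplus(z),
    so large |z| never overflows.
    """
    return pos_weight * target * softplus(-logit) + (1 - target) * softplus(logit)


# pos_weight suggested as (#negatives / #positives); here label 1 = Diseased
num_healthy, num_diseased = 1561, 3840   # training counts printed in Step 5
pos_weight = num_healthy / num_diseased  # < 1 because Diseased is the majority

print(round(bce_with_logits(0.0, 1), 4))  # 0.6931 (= ln 2)
print(round(pos_weight, 4))               # 0.4065
print(bce_with_logits(1000.0, 0) > 0)     # True, no overflow
```

Whether reweighting actually helps here is an empirical question; it is a knob to try, not a guaranteed improvement.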


---
## Step 11: Training Functions

In [18]:
def train_epoch(model, loader, criterion, optimizer, device):
    """
    Train for one epoch
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc="Training")

    for images, labels in pbar:
        images = images.to(device)
        labels = labels.unsqueeze(1).to(device)

        # Forward pass
        outputs = model(images).logits
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        running_loss += loss.item()
        predictions = (torch.sigmoid(outputs) > 0.5).float()
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100*correct/total:.2f}%'
        })

    return running_loss / len(loader), 100 * correct / total


def validate(model, loader, criterion, device):
    """
    Validate the model
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validating")

        for images, labels in pbar:
            images = images.to(device)
            labels = labels.unsqueeze(1).to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100*correct/total:.2f}%'
            })

    return running_loss / len(loader), 100 * correct / total

print("✅ Training functions defined!")

✅ Training functions defined!
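Both functions return `running_loss / len(loader)`, i.e. the mean of per-batch mean losses. Since the last training batch holds only 25 images (5401 = 168 x 32 + 25), it counts as much as a full batch. The difference is usually small; a sketch of a sample-weighted epoch mean follows (the helper is hypothetical; inside the loops you would accumulate `loss.item() * labels.size(0)` and divide by `total`):

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Mean loss per sample, given each batch's mean loss and its size."""
    total = sum(batch_sizes)
    return sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / total


batch_losses = [0.20, 0.20, 0.80]  # hypothetical: the small last batch is much worse
batch_sizes = [32, 32, 4]

print(round(sum(batch_losses) / len(batch_losses), 4))       # 0.4 (per-batch mean)
print(round(epoch_mean_loss(batch_losses, batch_sizes), 4))  # 0.2353 (per-sample mean)
```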


---
## Step 12: Training Loop with Loss Tracking
🚀 This is where the training happens!

In [1]:
# Track metrics
train_losses = []
val_losses = []
train_accs = []
val_accs = []
learning_rates = []

best_val_acc = 0.0
best_epoch = 0

print("🚀 Starting training!\n")
print("="*80)

for epoch in range(NUM_EPOCHS):
    print(f"\n📊 Epoch {epoch+1}/{NUM_EPOCHS}")
    print("-"*80)

    # Train
    train_loss, train_acc = train_epoch(base_model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate(base_model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Save metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    learning_rates.append(current_lr)

    # Print summary
    print(f"\n📈 Epoch {epoch+1} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    print(f"  Learning Rate: {current_lr:.2e}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save(base_model.state_dict(), 'best_model_weights.pth')
        print(f"  ⭐ New best model saved! (Val Acc: {val_acc:.2f}%)")

    print("="*80)

print(f"\n🎉 Training complete!")
print(f"\n🏆 Best Results:")
print(f"  Best validation accuracy: {best_val_acc:.2f}%")
print(f"  Achieved at epoch: {best_epoch}")

🚀 Starting training!



NameError: name 'NUM_EPOCHS' is not defined
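The loop calls `scheduler.step(val_loss)` once per epoch. With `mode='min'`, `factor=0.5`, `patience=2` and PyTorch's defaults (`threshold=1e-4` in relative mode, no cooldown, `min_lr=0`), the LR is halved only after three consecutive epochs without improvement. A simplified pure-Python re-implementation of that rule, meant to illustrate the behaviour rather than replace the real scheduler (it ignores `min_lr`, `cooldown` and the tiny `eps` check):

```python
def simulate_plateau(val_losses, lr, factor=0.5, patience=2, threshold=1e-4):
    """Return the LR after each epoch under a simplified ReduceLROnPlateau('min').

    An epoch counts as an improvement only if the loss beats the best so far
    by more than `threshold` (relative). After more than `patience` bad epochs
    in a row, the LR is multiplied by `factor` and the counter resets.
    """
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


# Hypothetical validation losses: improvement stops after epoch 2
lr_history = simulate_plateau([0.50, 0.40, 0.41, 0.42, 0.43, 0.44], lr=2e-5)
print(lr_history)  # LR is halved at epoch 5
```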

---
## Step 13: Save Final Model Weights

In [None]:
# Save final model
torch.save(base_model.state_dict(), 'final_model_weights.pth')

# Save complete model (can be loaded later)
torch.save({
    'epoch': NUM_EPOCHS,
    'model_state_dict': base_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_val_acc': best_val_acc,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'train_accs': train_accs,
    'val_accs': val_accs,
}, 'complete_checkpoint.pth')

print("✅ Model weights saved!")
print("\nSaved files:")
print("  1. best_model_weights.pth - Best performing model")
print("  2. final_model_weights.pth - Final epoch model")
print("  3. complete_checkpoint.pth - Full training state")

# Copy to Google Drive for permanent storage
!cp *.pth /content/drive/MyDrive/
print("\n✅ Weights also copied to your Google Drive!")
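`!cp` is a notebook shell escape, so it stops working if this code moves into a plain Python script. The standard library can do the same copy. A sketch; the helper `copy_checkpoints` and the destination folder `plant_disease_weights` are hypothetical choices, not something the notebook creates:

```python
import glob
import os
import shutil
import tempfile


def copy_checkpoints(src_dir, dest_dir, pattern="*.pth"):
    """Copy every file matching `pattern` from src_dir into dest_dir.

    Creates dest_dir if needed and returns the list of copied destination paths.
    """
    os.makedirs(dest_dir, exist_ok=True)
    copied = []
    for src in sorted(glob.glob(os.path.join(src_dir, pattern))):
        copied.append(shutil.copy2(src, dest_dir))  # copy2 also keeps timestamps
    return copied


# In Colab, for example (hypothetical destination folder):
# copy_checkpoints("/content", "/content/drive/MyDrive/plant_disease_weights")

# Self-contained demo with temporary folders
with tempfile.TemporaryDirectory() as src, tempfile.TemporaryDirectory() as dst_root:
    for name in ("best_model_weights.pth", "final_model_weights.pth", "training_results.png"):
        open(os.path.join(src, name), "w").close()
    dest = os.path.join(dst_root, "plant_disease_weights")
    copied = copy_checkpoints(src, dest)
    copied_names = sorted(os.path.basename(p) for p in copied)
    dest_exists = all(os.path.exists(p) for p in copied)

print(copied_names)  # ['best_model_weights.pth', 'final_model_weights.pth']
```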

---
## Step 14: Visualize Training Results

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Plot 1: Loss
axes[0, 0].plot(train_losses, 'b-', marker='o', label='Training Loss')
axes[0, 0].plot(val_losses, 'r-', marker='s', label='Validation Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot 2: Accuracy
axes[0, 1].plot(train_accs, 'b-', marker='o', label='Training Accuracy')
axes[0, 1].plot(val_accs, 'r-', marker='s', label='Validation Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('Training and Validation Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Learning Rate
axes[1, 0].plot(learning_rates, 'g-', marker='o')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True, alpha=0.3)

# Plot 4: Loss difference (overfitting check)
loss_diff = np.array(val_losses) - np.array(train_losses)
axes[1, 1].plot(loss_diff, 'm-', marker='o')
axes[1, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Val Loss - Train Loss')
axes[1, 1].set_title('Overfitting Check (a growing positive gap suggests overfitting)')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n📊 Training visualizations saved as 'training_results.png'")
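The fourth panel plots validation loss minus training loss. A common, informal reading is that overfitting is suspected when validation loss rises while training loss keeps falling. A small hypothetical helper that lists those epochs from the tracked histories, 1-indexed like the training printout:

```python
def diverging_epochs(train_losses, val_losses):
    """Epochs (1-indexed) where val loss went up while train loss went down."""
    flagged = []
    for i in range(1, min(len(train_losses), len(val_losses))):
        train_fell = train_losses[i] < train_losses[i - 1]
        val_rose = val_losses[i] > val_losses[i - 1]
        if train_fell and val_rose:
            flagged.append(i + 1)
    return flagged


# Hypothetical histories
train_hist = [0.60, 0.40, 0.30, 0.20, 0.15]
val_hist = [0.55, 0.45, 0.42, 0.44, 0.47]
print(diverging_epochs(train_hist, val_hist))  # [4, 5]
```

In the notebook this would be called as `diverging_epochs(train_losses, val_losses)`.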

---
## Step 15: Print Final Summary

In [None]:
print("="*80)
print("🎓 TRAINING SUMMARY")
print("="*80)

print(f"\n📊 Dataset:")
print(f"  Total images used: {len(image_paths)}")
print(f"  Plant categories: {len(set([os.path.basename(os.path.dirname(p)) for p in image_paths]))}")
print(f"  Training images: {len(train_paths)}")
print(f"  Validation images: {len(val_paths)}")

print(f"\n🤖 Model:")
print(f"  Architecture: HuggingFace ResNet-50")
print(f"  Task: Binary Classification (Healthy vs Diseased)")
print(f"  Parameters: {sum(p.numel() for p in base_model.parameters()) / 1e6:.1f}M")

print(f"\n⚙️ Training Configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Initial learning rate: {LEARNING_RATE}")
print(f"  Optimizer: AdamW")
print(f"  Loss function: Binary Cross Entropy")

print(f"\n🏆 Best Results:")
print(f"  Best validation accuracy: {best_val_acc:.2f}%")
print(f"  Achieved at epoch: {best_epoch}")
print(f"  Final train accuracy: {train_accs[-1]:.2f}%")
print(f"  Final validation accuracy: {val_accs[-1]:.2f}%")

print(f"\n📁 Saved Files:")
print(f"  ✓ best_model_weights.pth")
print(f"  ✓ final_model_weights.pth")
print(f"  ✓ complete_checkpoint.pth")
print(f"  ✓ training_results.png")

print(f"\n" + "="*80)
print("✅ All done! Your model is ready to use!")
print("="*80)

---
## BONUS: Test on Sample Images

In [None]:
def predict_image(model, image_path, processor, device):
    """
    Predict on a single image
    """
    model.eval()

    # Load and process image
    image = Image.open(image_path).convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    # Predict
    with torch.no_grad():
        output = model(pixel_values).logits
        probability = torch.sigmoid(output).item()
        prediction = 1 if probability > 0.5 else 0

    return prediction, probability

# Test on 5 random validation images
print("🧪 Testing on sample images:\n")

sample_indices = random.sample(range(len(val_paths)), min(5, len(val_paths)))

for idx in sample_indices:
    img_path = val_paths[idx]
    true_label = val_labels[idx]

    pred, prob = predict_image(base_model, img_path, processor, device)
    # prob is P(Diseased); report the confidence of whichever class was predicted
    confidence = prob if pred == 1 else 1 - prob

    folder_name = os.path.basename(os.path.dirname(img_path))
    true_class = "Healthy" if true_label == 0 else "Diseased"
    pred_class = "Healthy" if pred == 0 else "Diseased"

    print(f"📸 {folder_name}")
    print(f"   True: {true_class}")
    print(f"   Predicted: {pred_class} ({confidence:.1%} confidence)")
    print(f"   {'✅ Correct!' if pred == true_label else '❌ Wrong'}\n")
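`predict_image` returns the sigmoid output, which is the model's probability of *Diseased* (label 1). For a *Healthy* prediction, the confidence in that prediction is `1 - probability`. A tiny hypothetical helper that makes the mapping explicit:

```python
def prediction_confidence(prob_diseased, threshold=0.5):
    """Return (label, confidence) for a sigmoid output P(Diseased).

    label is 1 (Diseased) or 0 (Healthy); confidence is the probability the
    model assigns to that predicted label.
    """
    label = 1 if prob_diseased > threshold else 0
    confidence = prob_diseased if label == 1 else 1 - prob_diseased
    return label, confidence


print(prediction_confidence(0.9))   # (1, 0.9)
print(prediction_confidence(0.25))  # (0, 0.75)
```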

---
## How to Load Your Trained Model Later

```python
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch

# Load processor
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

# Load model architecture
model = AutoModelForImageClassification.from_pretrained(
    "microsoft/resnet-50",
    num_labels=1,
    ignore_mismatched_sizes=True
)

# Load trained weights (adjust the path if you load the copy from Google Drive)
model.load_state_dict(torch.load('best_model_weights.pth', map_location='cpu'))
model.eval()

# Now you can use it for predictions!
```