## Step 0: Install Required Packages

In [None]:
# Install required packages
# The leading `!` hands the line to the notebook's shell; `-q` keeps pip's log short.
!pip install -q opencv-python albumentations torch torchvision tqdm pandas numpy scikit-learn matplotlib seaborn pillow

# NOTE: this message is printed unconditionally; pip's exit status is not checked.
print("✓ All packages installed successfully!")

## Step 1: Mount Google Drive and Setup Paths

In [None]:
from google.colab import drive
import os

# Attach Google Drive to the Colab filesystem
drive.mount('/content/drive')

# Project root on Drive (MODIFY THIS PATH TO YOUR DATASET LOCATION)
DATASET_BASE_PATH = '/content/drive/MyDrive/Leaf Disease Detection'

# Every relative path used later is resolved against this folder
os.chdir(DATASET_BASE_PATH)

print(f"Working directory: {os.getcwd()}")
print(f"\nDataset folders:")
if not os.path.exists('Dataset'):
    print("⚠️ Dataset folder not found! Please upload your Dataset folder to Google Drive.")
else:
    print(os.listdir('Dataset'))

---
# PART A: DATA PREPROCESSING
---

## A1. Import Libraries for Preprocessing

In [None]:
import cv2  # image reading, resizing and writing
import numpy as np
import pandas as pd  # builds the dataset CSV
from pathlib import Path  # splits file names into stem and suffix
from tqdm import tqdm  # progress bars for the long loops
import albumentations as A  # augmentation pipeline
import os

print("Libraries imported successfully!")

## A2. Configuration for Preprocessing

In [None]:
# --- Preprocessing settings ---

# Input folder, output folder and metadata file (relative to the working directory)
DATASET_PATH = "Dataset"
OUTPUT_PATH = "Dataset_Resized"
CSV_OUTPUT = "dataset_info.csv"

# Edge length of the saved square images, and augmented copies made per source image
IMG_SIZE = 224
AUGMENTATIONS_PER_IMAGE = 2

# One sub-folder per class; the position in this list is used as the class label
CLASS_FOLDERS = [
    "Bacterial Leaf Spot", "Downy Mildew", "Healthy Leaf",
    "Mosaic Disease", "Powdery_Mildew",
]

print(
    f"Image resize target: {IMG_SIZE}x{IMG_SIZE}",
    f"Augmentations per image: {AUGMENTATIONS_PER_IMAGE}",
    f"Classes: {len(CLASS_FOLDERS)}",
    sep="\n",
)

## A3. Define Augmentation Pipeline

In [None]:
def create_augmentation_pipeline():
    """Build the albumentations pipeline that produces augmented copies.

    Returns:
        An ``A.Compose`` applying, in order: random 90-degree rotation,
        horizontal flip, vertical flip, brightness/contrast jitter,
        shift/scale/rotate, and one of Gaussian noise or Gaussian blur.
        Each step fires with its own probability.
    """
    # Exactly one of the two is applied whenever this step fires (30% of calls)
    noise_or_blur = A.OneOf(
        [A.GaussNoise(p=1), A.GaussianBlur(p=1)],
        p=0.3,
    )
    steps = [
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
        noise_or_blur,
    ]
    return A.Compose(steps)

def resize_image(image, size=224):
    """Return `image` scaled to a `size` x `size` square (area interpolation)."""
    target_shape = (size, size)
    return cv2.resize(image, target_shape, interpolation=cv2.INTER_AREA)

def save_image(image, path):
    """Write `image` to `path`, creating the parent directory when needed.

    Args:
        image: image array accepted by ``cv2.imwrite``.
        path: destination file path; its extension selects the file format.

    Returns:
        The value reported by ``cv2.imwrite``: True when the file was
        written, False otherwise. A warning is printed on failure so the
        problem is not silent.
    """
    parent_dir = os.path.dirname(path)
    # A bare file name has an empty dirname, and os.makedirs('') raises
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    written = cv2.imwrite(path, image)
    if not written:
        print(f"Warning: Could not write {path}.")
    return written

print("✓ Helper functions defined")

## A4. Process Dataset (Augmentation + Resizing)

In [None]:
def _list_image_files(folder):
    """Return the .jpg/.jpeg/.png file names in `folder`, sorted for a stable order."""
    return sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )


def _build_record(filename, source_file, class_name, class_idx,
                  source_shape, aug_type, source_path):
    """Build the CSV row describing one saved (resized) image."""
    source_height, source_width = source_shape[:2]
    return {
        'filename': filename,
        'original_filename': source_file,
        'class_name': class_name,
        'class_label': class_idx,
        'original_width': source_width,
        'original_height': source_height,
        'resized_width': IMG_SIZE,
        'resized_height': IMG_SIZE,
        'augmentation_type': aug_type,
        'relative_path': os.path.join(class_name, filename),
        'original_image_path': source_path,
    }


def process_dataset():
    """Resize and augment every class folder, then write the metadata CSV.

    For each readable image under ``DATASET_PATH/<class>`` this saves one
    resized copy (``*_orig``) plus ``AUGMENTATIONS_PER_IMAGE`` augmented and
    resized copies (``*_aug1``, ``*_aug2``, ...) under
    ``OUTPUT_PATH/<class>``, and records one CSV row per saved file.

    Returns:
        pandas.DataFrame with one row per saved image; the same table is
        written to ``CSV_OUTPUT``. The frame has its columns even when no
        image was found.
    """
    print("Starting data preprocessing...")
    print(f"Target image size: {IMG_SIZE}x{IMG_SIZE}")
    print(f"Augmentations per image: {AUGMENTATIONS_PER_IMAGE}")

    os.makedirs(OUTPUT_PATH, exist_ok=True)
    augment = create_augmentation_pipeline()

    # Explicit column list so an empty run still yields a well-formed table
    record_columns = [
        'filename', 'original_filename', 'class_name', 'class_label',
        'original_width', 'original_height', 'resized_width', 'resized_height',
        'augmentation_type', 'relative_path', 'original_image_path',
    ]
    data_records = []

    # (file-name suffix, augmentation_type) for the original and each augmented copy
    variants = [('orig', 'original')] + [
        (f'aug{k}', f'augmented_{k}')
        for k in range(1, AUGMENTATIONS_PER_IMAGE + 1)
    ]

    for class_idx, class_name in enumerate(CLASS_FOLDERS):
        print(f"\nProcessing class: {class_name} (Label: {class_idx})")

        class_input_path = os.path.join(DATASET_PATH, class_name)
        class_output_path = os.path.join(OUTPUT_PATH, class_name)

        if not os.path.exists(class_input_path):
            print(f"Warning: Folder {class_input_path} not found. Skipping...")
            continue

        image_files = _list_image_files(class_input_path)
        print(f"Found {len(image_files)} images")

        for img_file in tqdm(image_files, desc=f"Processing {class_name}"):
            original_path = os.path.join(class_input_path, img_file)

            img = cv2.imread(original_path)
            if img is None:
                print(f"Warning: Could not read {original_path}. Skipping...")
                continue

            base_name = Path(img_file).stem
            ext = Path(img_file).suffix

            for suffix, aug_type in variants:
                # Augmentation runs on the full-size image, before resizing
                if aug_type == 'original':
                    variant_img = img
                else:
                    variant_img = augment(image=img)['image']

                out_name = f"{base_name}_{suffix}{ext}"
                out_path = os.path.join(class_output_path, out_name)

                # Only an explicit False from save_image counts as failure;
                # a None return is treated as success.
                if save_image(resize_image(variant_img, IMG_SIZE), out_path) is False:
                    print(f"Warning: {out_path} was not written. Leaving it out of the CSV.")
                    continue

                data_records.append(_build_record(
                    out_name, img_file, class_name, class_idx,
                    img.shape, aug_type, original_path,
                ))

    print(f"\nCreating CSV file: {CSV_OUTPUT}")
    df = pd.DataFrame(data_records, columns=record_columns)
    df.to_csv(CSV_OUTPUT, index=False)

    print("\n" + "="*60)
    print("PREPROCESSING SUMMARY")
    print("="*60)
    print(f"Total images processed: {len(df)}")
    print(f"CSV file saved: {CSV_OUTPUT}")
    print(f"Resized dataset folder: {OUTPUT_PATH}")
    print("\nClass distribution:")
    print(df['class_name'].value_counts().sort_index())
    print("\nAugmentation distribution:")
    print(df['augmentation_type'].value_counts())
    print("="*60)

    return df

# Run preprocessing
# Executes the whole pipeline now; `df` is the table returned by process_dataset().
df = process_dataset()
print("\n✓ Preprocessing complete!")

---
# PART B: MODEL TRAINING & EVALUATION
---

## B1. Import Training Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy and PyTorch so repeated runs draw the same random numbers
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## B2. Load and Split Dataset (70/15/15)

In [None]:
# Load CSV
df = pd.read_csv('dataset_info.csv')
print(f"Total samples: {len(df)}")
print(f"\nClass distribution:")
print(df['class_name'].value_counts())

# Split dataset: 70% train, 15% validation, 15% test.
# The split is drawn over SOURCE images, not CSV rows: every row derived from
# the same `original_image_path` (the resized original and its augmented
# copies) lands in the same split. Splitting rows directly would let an
# augmented copy of a test image end up in the training set.
source_images = df.drop_duplicates(subset='original_image_path')[
    ['original_image_path', 'class_label']
]
train_sources, temp_sources = train_test_split(
    source_images, test_size=0.3,
    stratify=source_images['class_label'], random_state=42
)
val_sources, test_sources = train_test_split(
    temp_sources, test_size=0.5,
    stratify=temp_sources['class_label'], random_state=42
)

# Map each group of source images back to all of its rows
train_df = df[df['original_image_path'].isin(train_sources['original_image_path'])].reset_index(drop=True)
val_df = df[df['original_image_path'].isin(val_sources['original_image_path'])].reset_index(drop=True)
test_df = df[df['original_image_path'].isin(test_sources['original_image_path'])].reset_index(drop=True)

print(f"\nTrain: {len(train_df)} ({len(train_df)/len(df)*100:.1f}%)")
print(f"Val: {len(val_df)} ({len(val_df)/len(df)*100:.1f}%)")
print(f"Test: {len(test_df)} ({len(test_df)/len(df)*100:.1f}%)")

## B3. Create Dataset and DataLoader

In [None]:
class LeafDiseaseDataset(Dataset):
    """Dataset that loads images listed in a metadata DataFrame.

    Args:
        dataframe: table with a `relative_path` and a `class_label` column.
        root_dir: folder that `relative_path` values are relative to.
        transform: optional callable applied to the RGB image array.
    """

    def __init__(self, dataframe, root_dir='Dataset_Resized', transform=None):
        self.df = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(image, label)` for the row at position `idx`.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable.
        """
        # Positional access keeps this correct even if the frame's index
        # was not reset after filtering or splitting.
        img_path = os.path.join(self.root_dir, self.df['relative_path'].iloc[idx])
        label = int(self.df['class_label'].iloc[idx])

        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads BGR; convert to RGB before the transform
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Preprocessing applied to every sample: array -> PIL image -> tensor,
# followed by per-channel normalisation
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# One dataset per split, all sharing the same transform
train_dataset = LeafDiseaseDataset(train_df, transform=transform)
val_dataset = LeafDiseaseDataset(val_df, transform=transform)
test_dataset = LeafDiseaseDataset(test_df, transform=transform)

# Only the training loader shuffles; validation and test keep row order
BATCH_SIZE = 32
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=2, shuffle=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=2, shuffle=False
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, num_workers=2, shuffle=False
)

print(f"✓ DataLoaders created")

## B4. Define AlexNet Simplified Model

In [None]:
class AlexNet_simplified(nn.Module):
    """AlexNet-style CNN with three auxiliary early-exit classifiers.

    Five convolutional stages feed a fully connected classifier. Extra
    linear heads after stages 2, 3 and 4 predict from intermediate features.

    Args:
        num_classes: number of output classes of every head.
        exit_threshold: minimum softmax confidence required to stop at an
            early exit when `forward` is called with ``inference=True``.
    """

    def __init__(self, num_classes=5, exit_threshold=0.90):
        super().__init__()
        self.exit_threshold = exit_threshold

        # Conv Layer 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)
        )

        # Conv Layer 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)
        )

        # Early exit after stage 2: global average pool + linear head
        self.exit1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes)
        )

        # Conv Layer 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU()
        )

        # Early exit after stage 3
        self.exit2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(384, num_classes)
        )

        # Conv Layer 4
        self.conv4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU()
        )

        # Early exit after stage 4
        self.exit3 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(384, num_classes)
        )

        # Conv Layer 5
        self.conv5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Fixes the feature map at 6x6 so the classifier input size is constant
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Final Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes)
        )

    def _is_confident(self, logits):
        """Return True when every sample's top softmax probability meets the threshold."""
        top_prob = F.softmax(logits, dim=1).max(dim=1).values
        # min() over the batch: for a single sample this is that sample's
        # confidence; for a batch, all samples must be confident to exit.
        return top_prob.min().item() >= self.exit_threshold

    def forward(self, x, inference=False):
        """Run the network.

        Args:
            x: input image batch with 3 channels.
            inference: when False, return the logits of all four heads.
                When True, stop at the first head whose confidence reaches
                `exit_threshold` and skip the remaining layers.

        Returns:
            ``(out1, out2, out3, out_final)`` when ``inference`` is False;
            ``(logits, exit_name)`` when True, with `exit_name` one of
            "Exit1", "Exit2", "Exit3", "Final".
        """
        x = self.conv1(x)
        x = self.conv2(x)
        out1 = self.exit1(x)
        if inference and self._is_confident(out1):
            return out1, "Exit1"

        x = self.conv3(x)
        out2 = self.exit2(x)
        if inference and self._is_confident(out2):
            return out2, "Exit2"

        x = self.conv4(x)
        out3 = self.exit3(x)
        if inference and self._is_confident(out3):
            return out3, "Exit3"

        x = self.conv5(x)
        x = self.avgpool(x)
        out_final = self.classifier(x)

        if inference:
            return out_final, "Final"

        return out1, out2, out3, out_final

# Build the network, then move its parameters to the selected device
model = AlexNet_simplified(num_classes=5, exit_threshold=0.90)
model = model.to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## B5. Training Setup

In [None]:
# Training length and early-stopping patience (in epochs)
NUM_EPOCHS = 30
EARLY_STOP_PATIENCE = 7

# Loss, optimiser, and a scheduler that halves the learning rate after
# 3 epochs without improvement of the monitored (minimised) value
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print(
    "Training configuration:",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Learning rate: 0.001",
    sep="\n",
)

## B6. Training Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over `dataloader`.

    All four heads contribute to the loss; accuracy is measured on the
    final classifier only.

    Returns:
        (mean loss per batch, accuracy of the final head in percent)
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in tqdm(dataloader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        exit1, exit2, exit3, final = model(batch_images, inference=False)

        # Weighted sum of the per-head losses; deeper heads weigh more
        loss = (
            0.1 * criterion(exit1, batch_labels)
            + 0.2 * criterion(exit2, batch_labels)
            + 0.3 * criterion(exit3, batch_labels)
            + 0.4 * criterion(final, batch_labels)
        )

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = final.argmax(dim=1)
        n_seen += batch_labels.size(0)
        n_correct += (predicted == batch_labels).sum().item()

    return loss_sum / len(dataloader), 100 * n_correct / n_seen

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` without updating any weights.

    Only the final classifier is scored; the early-exit outputs are ignored.

    Returns:
        (mean loss per batch, accuracy of the final head in percent)
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Validating"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            _, _, _, final_logits = model(batch_images, inference=False)
            loss_sum += criterion(final_logits, batch_labels).item()

            predicted = final_logits.argmax(dim=1)
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    return loss_sum / len(dataloader), 100 * n_correct / n_seen

print("✓ Training functions defined")

## B7. Train the Model

In [None]:
# Per-epoch history, used for the plots afterwards
train_losses, train_accs = [], []
val_losses, val_accs = [], []

# Early-stopping state: lowest validation loss so far, and the number of
# consecutive epochs that failed to beat it
best_val_loss, early_stop_counter = float('inf'), 0

print("Starting training...\n")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}", "-" * 50, sep="\n")

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # The scheduler reacts to the validation loss of this epoch
    scheduler.step(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    if not (val_loss < best_val_loss):
        # No improvement this epoch
        early_stop_counter += 1
        print(f"Early stopping: {early_stop_counter}/{EARLY_STOP_PATIENCE}")
    else:
        # New best: remember it and checkpoint the weights
        best_val_loss = val_loss
        early_stop_counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
        print("✓ Model saved!")

    if early_stop_counter >= EARLY_STOP_PATIENCE:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

print("\n✓ Training completed!")

## B8. Plot Training History

In [None]:
def _draw_history(ax, train_series, val_series, train_label, val_label, y_label, title):
    """Plot one train/validation curve pair on `ax` with legend and grid."""
    ax.plot(train_series, label=train_label, marker='o')
    ax.plot(val_series, label=val_label, marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)


# Left panel: loss curves. Right panel: accuracy curves.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
_draw_history(
    axes[0], train_losses, val_losses,
    'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss',
)
_draw_history(
    axes[1], train_accs, val_accs,
    'Train Acc', 'Val Acc', 'Accuracy (%)', 'Training and Validation Accuracy',
)

plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Best validation accuracy: {max(val_accs):.2f}%")

## B9. Test on Original Images

In [None]:
# Load best model
# map_location remaps the saved tensors onto the current device, so a
# checkpoint written on a GPU runtime also loads on a CPU-only runtime.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

def predict_original_image(model, image_path, device, img_size=224):
    """Classify one image file with the final head of `model`.

    Args:
        model: network returning ``(out1, out2, out3, out_final)`` when
            called with ``inference=False``.
        image_path: path of the image file to classify.
        device: torch device the input tensor is moved to.
        img_size: edge length the image is resized to before inference.

    Returns:
        ``(predicted_label, confidence, probabilities, rgb_image)`` where
        `probabilities` is the softmax output as a NumPy vector and
        `rgb_image` is the full-size image in RGB order.

    Raises:
        FileNotFoundError: if the image file is missing or unreadable.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # INTER_AREA matches resize_image(), which produced the training images;
    # the cv2.resize default (INTER_LINEAR) would preprocess test inputs
    # differently from the data the model was trained on.
    img_resized = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_AREA)

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # unsqueeze(0) adds the batch dimension
    img_tensor = transform(img_resized).unsqueeze(0).to(device)

    with torch.no_grad():
        out1, out2, out3, out_final = model(img_tensor, inference=False)
        probabilities = F.softmax(out_final, dim=1)
        confidence, predicted = torch.max(probabilities, 1)

    return predicted.item(), confidence.item(), probabilities.cpu().numpy()[0], img

# Evaluate only the non-augmented rows of the test split
test_original_images = test_df[test_df['augmentation_type'] == 'original'].copy()

# Display names, indexed by class label
class_names = [
    "Bacterial Leaf Spot", "Downy Mildew", "Healthy Leaf",
    "Mosaic Disease", "Powdery Mildew",
]

predictions, true_labels, confidences = [], [], []

for idx, row in tqdm(test_original_images.iterrows(), desc="Testing", total=len(test_original_images)):
    # Predictions are made from the full-size source file, not the resized copy
    pred_label, confidence, probs, _ = predict_original_image(
        model, row['original_image_path'], device
    )
    true_labels.append(row['class_label'])
    predictions.append(pred_label)
    confidences.append(confidence)

test_accuracy = 100 * accuracy_score(true_labels, predictions)
print(f"\nTest Accuracy: {test_accuracy:.2f}%")
print(f"Average Confidence: {np.mean(confidences)*100:.2f}%")

## B10. Confusion Matrix and Classification Report

In [None]:
print("Classification Report:")
print("="*60)
# Passing `labels` pins the report to all classes. Without it,
# classification_report raises when a class is missing from both the true
# labels and the predictions, because `target_names` no longer matches.
# zero_division=0 reports 0 for classes with no samples or predictions.
print(classification_report(
    true_labels, predictions,
    labels=list(range(len(class_names))),
    target_names=class_names,
    zero_division=0,
))

# Same label list here so the matrix always has one row/column per class
# and lines up with the tick labels below.
cm = confusion_matrix(true_labels, predictions, labels=list(range(len(class_names))))

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix - Test on Original Images', fontsize=14, fontweight='bold')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

## B11. Visualize Sample Predictions

In [None]:
# Pick up to 9 distinct test images at random
sample_indices = np.random.choice(len(test_original_images), size=min(9, len(test_original_images)), replace=False)

fig, axes = plt.subplots(3, 3, figsize=(15, 15))
axes = axes.ravel()

for i, idx in enumerate(sample_indices):
    row = test_original_images.iloc[idx]
    pred_label, confidence, probs, img = predict_original_image(model, row['original_image_path'], device)
    
    axes[i].imshow(img)
    axes[i].axis('off')
    
    # Green title for a correct prediction, red for a wrong one
    color = 'green' if pred_label == row['class_label'] else 'red'
    title = f"True: {class_names[row['class_label']]}\n"
    title += f"Pred: {class_names[pred_label]}\n"
    title += f"Conf: {confidence*100:.1f}%"
    axes[i].set_title(title, fontsize=10, color=color, fontweight='bold')

# With fewer than 9 samples the remaining panels would show empty axes
# frames; hide them so the grid only contains real predictions.
for i in range(len(sample_indices), len(axes)):
    axes[i].axis('off')

plt.suptitle('Sample Predictions on Original Test Images', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('sample_predictions.png', dpi=300, bbox_inches='tight')
plt.show()

## B12. Final Summary

In [None]:
# Header and model facts
print("\n" + "="*60)
print("FINAL MODEL SUMMARY", "="*60, sep="\n")
print(f"Model: AlexNet Simplified with Early Exits")
print(f"Classes: {len(class_names)}")
print(f"Input Size: 224x224x3")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Size of each split
print(f"\nDataset Sizes:")
print(f"  Train: {len(train_df)}")
print(f"  Validation: {len(val_df)}")
print(f"  Test: {len(test_df)}")

# Headline metrics
print(f"\nPerformance:")
print(f"  Best Val Accuracy: {max(val_accs):.2f}%")
print(f"  Test Accuracy: {test_accuracy:.2f}%")
print(f"  Avg Confidence: {np.mean(confidences)*100:.2f}%")

# Accuracy restricted to the samples of each class; classes with no
# test sample are left out
print("\nClass-wise Accuracy:")
for i, class_name in enumerate(class_names):
    class_mask = np.array(true_labels) == i
    if class_mask.sum() <= 0:
        continue
    class_acc = 100 * accuracy_score(
        np.array(true_labels)[class_mask],
        np.array(predictions)[class_mask],
    )
    print(f"  {class_name}: {class_acc:.2f}%")
print("="*60)
print("\n✓ All tasks completed successfully!")