# 🌱 CAPSTONE-LAZARUS: Plant Disease Detection - Data Exploration & Model Training

## 📊 Project Overview
This notebook provides a comprehensive exploration and training pipeline for **AI-powered plant disease detection**. We're building a robust system to help farmers and agronomists quickly diagnose plant health issues from leaf images.

### Key Features:
- 🔍 **Multi-crop disease detection** (Corn, Potato, Tomato)
- 📈 **Class imbalance handling** with sophisticated techniques
- 🎯 **High recall for critical diseases** (minimize false negatives)
- 📱 **Mobile-ready deployment** with model compression
- 🧠 **Explainable AI** with saliency maps and confidence scoring
- 📡 **Uncertainty estimation** for trustworthy predictions

### Agricultural Impact:
- Early disease detection → **reduced crop losses**
- Precise diagnosis → **optimized pesticide use**
- AI-assisted decisions → **empowered smallholder farmers**

In [None]:
# 📚 Import Essential Libraries
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import cv2
from PIL import Image
import os
from pathlib import Path
from collections import Counter
import random

# Deep Learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, applications, optimizers, callbacks
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils.class_weight import compute_class_weight

# Visualization & Analysis
import shap
import lime
from lime import lime_image
from skimage.segmentation import mark_boundaries

print(f"🔥 TensorFlow Version: {tf.__version__}")

# Query the GPU list once and reuse it for both the report and the configuration
gpus = tf.config.list_physical_devices('GPU')
print(f"🎮 GPU Available: {gpus}")

# Configure GPU memory growth so TensorFlow allocates GPU memory on demand.
# The success message is only printed when every device was actually configured.
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # NOTE(review): presumably raised when the GPUs were already initialized
        # earlier in this kernel session - restart the kernel to apply the setting.
        print(f"GPU configuration error: {e}")
    else:
        print("🧠 Memory Growth Enabled")
else:
    print("🧠 No GPU detected - memory growth not applicable")

In [None]:
# 🎨 Configure Visualization Settings
# Base matplotlib style first, then the seaborn colour cycle on top of it.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Hex colours keyed by disease category (used to colour plots consistently)
DISEASE_COLORS = dict(
    healthy='#2ECC71',    # Green
    bacterial='#E74C3C',  # Red
    fungal='#8E44AD',     # Purple
    viral='#F39C12',      # Orange
    pest='#E67E22',       # Orange-red
    nutrient='#3498DB',   # Blue
    other='#95A5A6',      # Gray
)

# Global matplotlib defaults applied to every figure created afterwards
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
})

In [None]:
# 📁 Dataset Configuration & Path Setup
# The project root defaults to the original development checkout. Set the
# CAPSTONE_LAZARUS_ROOT environment variable to run the notebook from another location.
PROJECT_ROOT = Path(os.environ.get(
    "CAPSTONE_LAZARUS_ROOT",
    "C:/Users/MadScie254/Documents/GitHub/Portifolio/Capstone-Lazarus",
))
DATA_DIR = PROJECT_ROOT / "data"
MODELS_DIR = PROJECT_ROOT / "models"
EXPERIMENTS_DIR = PROJECT_ROOT / "experiments"

# Ensure output directories exist. parents=True also creates missing parent
# folders, so a fresh checkout location does not raise FileNotFoundError.
MODELS_DIR.mkdir(parents=True, exist_ok=True)
EXPERIMENTS_DIR.mkdir(parents=True, exist_ok=True)

print(f"📂 Data Directory: {DATA_DIR}")
print(f"🤖 Models Directory: {MODELS_DIR}")
print(f"🔬 Experiments Directory: {EXPERIMENTS_DIR}")

# Only report success when the input data directory is really there;
# it is never created here because it must already contain the images.
if DATA_DIR.is_dir():
    print(f"✅ Directory structure validated")
else:
    print(f"⚠️ Data directory not found: {DATA_DIR}")

In [None]:
# 🔍 Comprehensive Dataset Analysis

def analyze_dataset_structure(data_dir):
    """Summarise the class folders found directly under ``data_dir``.

    Every sub-directory is treated as one class. Folder names of the form
    ``<crop>___<condition>`` are split into a prettified crop name and the raw
    condition; other names are used as the crop with condition ``'Unknown'``.

    Args:
        data_dir: ``pathlib.Path`` of the dataset root.

    Returns:
        ``pandas.DataFrame`` with one row per class (sorted by folder name) and
        the columns ``class_name``, ``crop``, ``condition``, ``num_images``,
        ``sample_width``, ``sample_height`` and ``directory``. The frame keeps
        these columns even when no class folder exists.
    """
    # Matching on the lower-cased suffix counts each file exactly once; separate
    # '*.jpg' / '*.JPG' globs can return the same file twice on case-insensitive
    # filesystems.
    image_suffixes = {'.jpg', '.jpeg', '.png'}
    columns = ['class_name', 'crop', 'condition', 'num_images',
               'sample_width', 'sample_height', 'directory']

    # Sorted so the resulting table does not depend on filesystem listing order
    class_dirs = sorted(d for d in data_dir.iterdir() if d.is_dir())
    class_info = []
    total_images = 0

    for class_dir in class_dirs:
        image_files = sorted(
            f for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in image_suffixes
        )
        num_images = len(image_files)
        total_images += num_images

        # Parse class information
        class_name = class_dir.name
        if '___' in class_name:
            crop, condition = class_name.split('___', 1)
            crop = crop.replace('(', '').replace(')', '').replace('_', ' ').title()
        else:
            crop = class_name
            condition = 'Unknown'

        # Size of the first image as a cheap sample; the context manager closes
        # the file handle that Image.open keeps open.
        if image_files:
            with Image.open(image_files[0]) as sample_img:
                img_width, img_height = sample_img.size
        else:
            img_width, img_height = 0, 0

        class_info.append({
            'class_name': class_name,
            'crop': crop,
            'condition': condition,
            'num_images': num_images,
            'sample_width': img_width,
            'sample_height': img_height,
            'directory': class_dir
        })

    # Explicit columns keep df['crop'] etc. valid for an empty dataset
    df = pd.DataFrame(class_info, columns=columns)

    print(f"🌾 Dataset Overview:")
    print(f"   📊 Total Classes: {len(class_dirs)}")
    print(f"   🖼️  Total Images: {total_images:,}")
    print(f"   🌱 Crops: {df['crop'].nunique()} ({', '.join(df['crop'].unique())})")
    print(f"   🦠 Conditions: {df['condition'].nunique()}")

    return df

# Analyze the dataset: build the per-class summary table from DATA_DIR
dataset_df = analyze_dataset_structure(DATA_DIR)
# Last expression of the cell, so the notebook renders the first 10 rows
dataset_df.head(10)

In [None]:
# 📊 Class Distribution Visualization

def create_class_distribution_plots(df):
    """Render a 2x2 Plotly dashboard describing how images are distributed.

    Panels: images per class (bar), images per crop (pie), images per
    condition (bar) and class size against size rank (scatter).

    Args:
        df: Per-class table with 'class_name', 'crop', 'condition' and
            'num_images' columns.

    Returns:
        The Plotly figure, after it has been displayed with ``fig.show()``.
    """
    count_label = "Number of Images"

    figure = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Class Distribution', 'Crop Distribution', 
                       'Condition Distribution', 'Class Imbalance Analysis'),
        specs=[[{"type": "bar"}, {"type": "pie"}],
               [{"type": "bar"}, {"type": "scatter"}]]
    )

    # Aggregations feeding the four panels
    by_size = df.sort_values('num_images', ascending=False)
    per_crop = df.groupby('crop')['num_images'].sum().reset_index()
    per_condition = (
        df.groupby('condition')['num_images'].sum().reset_index()
        .sort_values('num_images', ascending=False)
    )
    # After reset_index the index is the size rank (0 = largest class)
    ranked = by_size.reset_index()

    panels = [
        (go.Bar(x=by_size['class_name'], y=by_size['num_images'],
                name='Images per Class', marker_color='lightblue'), 1, 1),
        (go.Pie(labels=per_crop['crop'], values=per_crop['num_images'],
                name='Crop Distribution'), 1, 2),
        (go.Bar(x=per_condition['condition'], y=per_condition['num_images'],
                name='Condition Distribution', marker_color='lightgreen'), 2, 1),
        (go.Scatter(x=ranked.index, y=ranked['num_images'],
                    mode='markers+lines', name='Class Sizes',
                    marker=dict(size=8, color='red')), 2, 2),
    ]
    for trace, row, col in panels:
        figure.add_trace(trace, row=row, col=col)

    # Only the cartesian panels have axes to label (the pie chart has none)
    x_titles = {(1, 1): "Classes", (2, 1): "Conditions", (2, 2): "Class Rank"}
    for (row, col), x_title in x_titles.items():
        figure.update_xaxes(title_text=x_title, row=row, col=col)
        figure.update_yaxes(title_text=count_label, row=row, col=col)

    figure.update_layout(height=800, title_text="🌱 Plant Disease Dataset Analysis", showlegend=False)
    figure.show()

    return figure

# Create distribution plots (the function displays the figure and also returns it)
distribution_fig = create_class_distribution_plots(dataset_df)

In [None]:
# 📈 Class Imbalance Analysis & Statistics

def analyze_class_imbalance(df):
    """Print and return summary statistics about class-size imbalance.

    Args:
        df: Per-class table with 'class_name' and 'num_images' columns.

    Returns:
        dict with 'total_images', 'imbalance_ratio' (largest / smallest class),
        'cv' (std / mean of the class sizes) and 'underrepresented_classes'
        (number of classes below half of the mean class size).
    """
    sizes = df['num_images']

    # Headline figures
    total_images = sizes.sum()
    min_class_size = sizes.min()
    max_class_size = sizes.max()
    mean_class_size = sizes.mean()
    median_class_size = sizes.median()

    # Spread measures: largest-to-smallest ratio and coefficient of variation
    imbalance_ratio = max_class_size / min_class_size
    cv = sizes.std() / mean_class_size

    for line in (
        "🔍 Class Imbalance Analysis:",
        f"   📊 Total Images: {total_images:,}",
        f"   📉 Smallest Class: {min_class_size:,} images",
        f"   📈 Largest Class: {max_class_size:,} images",
        f"   ⚖️  Imbalance Ratio: {imbalance_ratio:.1f}:1",
        f"   📊 Mean Class Size: {mean_class_size:.0f}",
        f"   📊 Median Class Size: {median_class_size:.0f}",
        f"   📊 Coefficient of Variation: {cv:.2f}",
    ):
        print(line)

    # Risk bands, checked from the most to the least severe threshold
    risk_level = "🟢 LOW"
    for threshold, label in ((100, "🔴 CRITICAL"), (10, "🟡 HIGH"), (5, "🟠 MODERATE")):
        if imbalance_ratio > threshold:
            risk_level = label
            break

    print(f"   ⚠️  Imbalance Risk: {risk_level}")

    # Classes holding fewer than half of the mean class size
    underrepresented = df[sizes < mean_class_size * 0.5]
    if len(underrepresented) > 0:
        print(f"\n🚨 Underrepresented Classes (< {mean_class_size * 0.5:.0f} images):")
        for _, row in underrepresented.iterrows():
            print(f"   - {row['class_name']}: {row['num_images']} images")

    return dict(
        total_images=total_images,
        imbalance_ratio=imbalance_ratio,
        cv=cv,
        underrepresented_classes=len(underrepresented),
    )

# Compute imbalance statistics for the whole dataset; the returned dict holds
# 'total_images', 'imbalance_ratio', 'cv' and 'underrepresented_classes'
imbalance_stats = analyze_class_imbalance(dataset_df)

In [None]:
# 🖼️ Sample Images Visualization & Quality Analysis

def visualize_sample_images(df, samples_per_class=3, figsize=(20, 12)):
    """Show one random image per selected class with basic quality metrics.

    For every crop, the first healthy class (if any) and up to two diseased
    classes are selected. Each selected class gets one subplot showing a
    randomly chosen image, titled with its crop, condition, image count, mean
    pixel value ("brightness") and pixel standard deviation ("contrast").

    Args:
        df: Per-class table with 'class_name', 'crop', 'condition',
            'num_images' and 'directory' (``pathlib.Path``) columns.
        samples_per_class: Currently unused (one image is shown per class);
            kept so existing callers keep working.
        figsize: Figure size forwarded to ``plt.subplots``.

    Returns:
        List of dicts, one per displayed class, with the same keys as the
        columns of ``df``. Empty when ``df`` has no classes.
    """
    # Lower-cased suffix matching lists each file once, whatever its case
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    # Select diverse classes for visualization. Everything is stored as a plain
    # dict so callers get one consistent record type.
    selected_classes = []
    for crop in df['crop'].unique():
        crop_classes = df[df['crop'] == crop]
        is_healthy = crop_classes['condition'].str.contains('healthy', case=False)
        selected_classes.extend(crop_classes[is_healthy].head(1).to_dict('records'))
        selected_classes.extend(crop_classes[~is_healthy].head(2).to_dict('records'))

    if not selected_classes:
        # A grid with zero rows cannot be created; nothing to draw
        print("⚠️ No classes available to visualize")
        return selected_classes

    # Create subplot grid
    n_classes = len(selected_classes)
    cols = 4
    rows = (n_classes + cols - 1) // cols

    # squeeze=False always returns a 2-D array, so flattening works for a
    # single row as well as for several rows
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for ax, class_info in zip(axes, selected_classes):
        image_files = sorted(
            f for f in class_info['directory'].iterdir()
            if f.is_file() and f.suffix.lower() in image_suffixes
        )

        # cv2.imread returns None (instead of raising) for unreadable files
        img = cv2.imread(str(random.choice(image_files))) if image_files else None

        if img is not None:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # Basic image quality metrics
            brightness = np.mean(img_rgb)
            contrast = np.std(img_rgb)

            ax.imshow(img_rgb)
            ax.set_title(f"{class_info['crop']}\n{class_info['condition']}\n"
                         f"Count: {class_info['num_images']}\n"
                         f"Brightness: {brightness:.1f}, Contrast: {contrast:.1f}",
                         fontsize=10)
        else:
            placeholder = 'Unreadable Image' if image_files else 'No Images'
            ax.text(0.5, 0.5, placeholder, ha='center', va='center')
            ax.set_title(f"{class_info['class_name']}")
        ax.axis('off')

    # Hide extra subplots
    for ax in axes[n_classes:]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle('🌱 Sample Images from Plant Disease Dataset', fontsize=16, y=1.02)
    plt.show()

    return selected_classes

# Visualize sample images (default grid settings); keeps the records of the displayed classes
sample_classes = visualize_sample_images(dataset_df)

In [None]:
# 🔬 Advanced Image Quality Analysis

def analyze_image_quality(df, sample_size=100):
    """Estimate brightness, contrast and sharpness per class from a random sample.

    For each class up to ``sample_size`` images are read with OpenCV.
    Brightness is the mean pixel value, contrast the pixel standard deviation
    and sharpness the variance of the Laplacian of the grayscale image. The
    per-class means are plotted by crop in a 2x2 matplotlib figure.

    Args:
        df: Per-class table with 'class_name', 'crop', 'condition',
            'num_images' and 'directory' (``pathlib.Path``) columns.
        sample_size: Maximum number of images sampled per class.

    Returns:
        ``pandas.DataFrame`` with one row per class that had at least one
        readable image, containing the avg_*/*_std metrics plus
        'brightness_cv'. When no image could be read, an empty frame with the
        same columns is returned and no figure is drawn.
    """
    # Lower-cased suffix matching lists each file once, whatever its case
    image_suffixes = {'.jpg', '.jpeg', '.png'}
    result_columns = [
        'class_name', 'crop', 'condition', 'num_images',
        'avg_brightness', 'avg_contrast', 'avg_sharpness',
        'brightness_std', 'contrast_std', 'sharpness_std',
    ]

    quality_metrics = []

    for _, class_info in df.iterrows():
        image_files = sorted(
            f for f in class_info['directory'].iterdir()
            if f.is_file() and f.suffix.lower() in image_suffixes
        )
        if not image_files:
            continue

        # Sample images for analysis
        sample_files = random.sample(image_files, min(sample_size, len(image_files)))

        brightness_values, contrast_values, sharpness_values = [], [], []
        for img_file in sample_files:
            try:
                img = cv2.imread(str(img_file))
                if img is None:
                    # Unreadable or unsupported file: skip it
                    continue

                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

                brightness = np.mean(img_rgb)
                contrast = np.std(img_rgb)
                sharpness = cv2.Laplacian(img_gray, cv2.CV_64F).var()
            except Exception as e:
                # Best-effort analysis: report the problem instead of hiding it,
                # then carry on with the remaining samples
                print(f"⚠️ Skipping {img_file}: {e}")
                continue

            brightness_values.append(brightness)
            contrast_values.append(contrast)
            sharpness_values.append(sharpness)

        # Calculate aggregated metrics
        if brightness_values:
            quality_metrics.append({
                'class_name': class_info['class_name'],
                'crop': class_info['crop'],
                'condition': class_info['condition'],
                'num_images': class_info['num_images'],
                'avg_brightness': np.mean(brightness_values),
                'avg_contrast': np.mean(contrast_values),
                'avg_sharpness': np.mean(sharpness_values),
                'brightness_std': np.std(brightness_values),
                'contrast_std': np.std(contrast_values),
                'sharpness_std': np.std(sharpness_values),
            })

    if not quality_metrics:
        # Without rows there is no 'crop' column to group by, so skip plotting
        print("⚠️ No readable images found - skipping quality plots")
        return pd.DataFrame(columns=result_columns + ['brightness_cv'])

    quality_df = pd.DataFrame(quality_metrics, columns=result_columns)

    # Quality consistency (coefficient of variation of the brightness)
    quality_df['brightness_cv'] = quality_df['brightness_std'] / quality_df['avg_brightness']

    # Visualize quality metrics: (column, bar colour, title, y-label) per panel
    panels = [
        ('avg_brightness', 'gold', '🌟 Average Brightness by Crop', 'Brightness'),
        ('avg_contrast', 'purple', '📊 Average Contrast by Crop', 'Contrast'),
        ('avg_sharpness', 'green', '🔍 Average Sharpness by Crop', 'Sharpness'),
        ('brightness_cv', 'red', '📏 Brightness Consistency by Crop', 'Coefficient of Variation'),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    for ax, (column, color, title, ylabel) in zip(axes.flat, panels):
        quality_df.groupby('crop')[column].mean().plot(kind='bar', ax=ax, color=color)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    return quality_df

# Run the quality analysis on at most 50 randomly sampled images per class
print("🔬 Analyzing image quality across all classes...")
quality_df = analyze_image_quality(dataset_df, sample_size=50)
# One row per class for which at least one sampled image could be read
print(f"✅ Quality analysis completed for {len(quality_df)} classes")

In [None]:
# 🎯 Strategic Class Grouping & Disease Severity Analysis

def _classify_condition(condition):
    """Map a condition name to its taxonomy group using keyword rules."""
    name = condition.lower()
    # Order matters: the specific bacterial / viral / pest keywords are checked
    # before the generic fungal ones, because names such as 'Bacterial_spot' or
    # 'Two-spotted_spider_mite' also contain the fungal keyword 'spot'.
    rules = (
        ('healthy', ('healthy',)),
        ('bacterial_diseases', ('bacterial',)),
        ('viral_diseases', ('virus', 'mosaic', 'curl')),
        ('pest_damage', ('mite', 'spider')),
        ('fungal_diseases', ('blight', 'rust', 'spot', 'mold')),
    )
    for group, keywords in rules:
        if any(keyword in name for keyword in keywords):
            return group
    return 'other_conditions'


def create_disease_taxonomy(df):
    """Group classes into coarse disease categories and summarise each group.

    Classes are assigned to a group by keyword matching on the condition name.
    Group sizes are plotted as two bar charts and printed.

    Args:
        df: Per-class table with 'class_name', 'condition' and 'num_images'
            columns.

    Returns:
        Tuple ``(taxonomy, group_df)``: ``taxonomy`` maps every group name to
        the list of class names in it (possibly empty); ``group_df`` has one
        row per non-empty group with 'disease_group', 'num_classes',
        'total_images', 'avg_images_per_class', 'min_images', 'max_images'
        and 'classes'. With no classes, ``group_df`` is empty and nothing is
        plotted.
    """
    # 'nutrient_deficiency' has no keyword rule and therefore stays empty; the
    # key is kept so the shape of the returned mapping is unchanged.
    taxonomy = {
        'healthy': [],
        'fungal_diseases': [],
        'bacterial_diseases': [],
        'viral_diseases': [],
        'pest_damage': [],
        'nutrient_deficiency': [],
        'other_conditions': []
    }

    for class_name, condition in zip(df['class_name'], df['condition']):
        taxonomy[_classify_condition(condition)].append(class_name)

    # Calculate group statistics (non-empty groups only)
    group_stats = []
    for group, classes in taxonomy.items():
        if not classes:
            continue
        group_images = df.loc[df['class_name'].isin(classes), 'num_images']
        group_stats.append({
            'disease_group': group,
            'num_classes': len(classes),
            'total_images': group_images.sum(),
            'avg_images_per_class': group_images.mean(),
            'min_images': group_images.min(),
            'max_images': group_images.max(),
            'classes': classes
        })

    group_columns = ['disease_group', 'num_classes', 'total_images',
                     'avg_images_per_class', 'min_images', 'max_images', 'classes']
    group_df = pd.DataFrame(group_stats, columns=group_columns)

    if group_df.empty:
        # Plotting an empty table fails; report and return the empty result
        print("⚠️ No classes to group - skipping taxonomy plots")
        return taxonomy, group_df

    # Visualize disease taxonomy
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Disease group distribution
    group_df.plot(x='disease_group', y='total_images', kind='bar', ax=ax1, color='skyblue')
    ax1.set_title('📊 Images by Disease Group')
    ax1.set_ylabel('Total Images')
    ax1.tick_params(axis='x', rotation=45)

    # Classes per group
    group_df.plot(x='disease_group', y='num_classes', kind='bar', ax=ax2, color='lightcoral')
    ax2.set_title('🏷️ Classes by Disease Group')
    ax2.set_ylabel('Number of Classes')
    ax2.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    print("🎯 Disease Taxonomy Analysis:")
    for _, row in group_df.iterrows():
        print(f"\n📋 {row['disease_group'].replace('_', ' ').title()}:")
        print(f"   Classes: {row['num_classes']}")
        print(f"   Total Images: {row['total_images']:,}")
        print(f"   Avg per Class: {row['avg_images_per_class']:.0f}")
        print(f"   Range: {row['min_images']}-{row['max_images']} images")

    return taxonomy, group_df

# Create disease taxonomy: group -> class names mapping plus a per-group summary table
disease_taxonomy, disease_groups_df = create_disease_taxonomy(dataset_df)

In [None]:
# 🚀 Advanced Data Loading & Preprocessing Pipeline

class PlantDiseaseDataProcessor:
    """Load a folder-per-class image dataset, split it and build tf.data pipelines.

    Attributes (all set in ``__init__``):
        data_dir: Dataset root; every sub-directory is one class.
        img_size: ``(height, width)`` images are resized to.
        batch_size: Batch size used by ``create_tf_dataset``.
        class_names: Class names in label-index order; ``None`` until
            ``create_class_mapping`` or ``load_and_prepare_data`` has run.
        label_encoder: ``LabelEncoder`` fitted on the class names.
    """

    # File extensions treated as images; compared against the lower-cased suffix
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir, img_size=(224, 224), batch_size=32):
        self.data_dir = Path(data_dir)
        self.img_size = img_size
        self.batch_size = batch_size
        self.class_names = None
        self.label_encoder = LabelEncoder()

    def _list_image_files(self, class_dir):
        """Return the image files directly inside ``class_dir``, sorted by path.

        Each file is listed exactly once. Separate '*.jpg' / '*.JPG' globs can
        return the same file twice on case-insensitive filesystems, which would
        put duplicates of one image into different splits.
        """
        return sorted(
            p for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in self.IMAGE_EXTENSIONS
        )

    def create_class_mapping(self):
        """Create a label-index -> metadata mapping for every class.

        Uses ``self.class_names`` when it is already set (e.g. by
        ``load_and_prepare_data``) so the indices agree with the fitted label
        encoder; otherwise the sorted sub-directory names of ``data_dir`` are
        used and stored in ``self.class_names``.

        Returns:
            dict mapping index to a dict with 'class_name', 'crop',
            'condition', 'is_healthy' and 'severity'.
        """
        if self.class_names is None:
            self.class_names = sorted(
                d.name for d in self.data_dir.iterdir() if d.is_dir()
            )

        class_mapping = {}
        for i, class_name in enumerate(self.class_names):
            if '___' in class_name:
                crop, condition = class_name.split('___', 1)
                crop = crop.replace('(', '').replace(')', '').replace('_', ' ')
            else:
                crop = 'Unknown'
                condition = class_name

            class_mapping[i] = {
                'class_name': class_name,
                'crop': crop,
                'condition': condition,
                'is_healthy': 'healthy' in condition.lower(),
                'severity': self._estimate_severity(condition)
            }

        return class_mapping

    def _estimate_severity(self, condition):
        """Estimate disease severity (0-3) from keywords in the condition name.

        0 = healthy, 1 = 'early'/'minor', 3 = 'late'/'severe'/'blight',
        2 = anything else. 'early' is checked before 'blight', so
        'Early_blight' yields 1.
        """
        condition_lower = condition.lower()
        if 'healthy' in condition_lower:
            return 0
        elif any(term in condition_lower for term in ['early', 'minor']):
            return 1
        elif any(term in condition_lower for term in ['late', 'severe', 'blight']):
            return 3
        else:
            return 2  # moderate

    def load_and_prepare_data(self, validation_split=0.2, test_split=0.1, stratify=True):
        """Collect image paths, encode labels and split into train/val/test.

        Args:
            validation_split: Fraction of all images used for validation.
            test_split: Fraction of all images used for testing.
            stratify: Preserve class proportions in every split when True.

        Returns:
            dict with 'train', 'val' and 'test' entries, each a
            ``(paths, encoded_labels)`` tuple, plus 'class_mapping'.

        Raises:
            ValueError: If no image is found under ``data_dir``.
        """
        image_paths = []
        labels = []

        # Sorted traversal makes the fixed-seed split reproducible regardless of
        # the order in which the filesystem lists entries
        for class_dir in sorted(d for d in self.data_dir.iterdir() if d.is_dir()):
            for img_path in self._list_image_files(class_dir):
                image_paths.append(str(img_path))
                labels.append(class_dir.name)

        if not image_paths:
            raise ValueError(f"No images found under {self.data_dir}")

        # Encode labels; classes_ is the authoritative index -> name order
        labels_encoded = self.label_encoder.fit_transform(labels)
        self.class_names = list(self.label_encoder.classes_)

        # Step 1: carve off validation + test together
        holdout_fraction = validation_split + test_split
        X_train, X_temp, y_train, y_temp = train_test_split(
            image_paths, labels_encoded,
            test_size=holdout_fraction,
            stratify=labels_encoded if stratify else None,
            random_state=42
        )

        # Step 2: divide the hold-out part; test_size is relative to that part
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp,
            test_size=test_split / holdout_fraction,
            stratify=y_temp if stratify else None,
            random_state=42
        )

        print(f"📊 Data Split Summary:")
        print(f"   🏋️ Training: {len(X_train):,} images")
        print(f"   🔍 Validation: {len(X_val):,} images") 
        print(f"   🎯 Test: {len(X_test):,} images")
        print(f"   🏷️ Classes: {len(self.class_names)}")

        return {
            'train': (X_train, y_train),
            'val': (X_val, y_val), 
            'test': (X_test, y_test),
            'class_mapping': self.create_class_mapping()
        }

    def create_tf_dataset(self, image_paths, labels, training=False):
        """Create a batched, prefetched ``tf.data.Dataset`` of (image, label).

        Images are decoded, resized to ``img_size`` and scaled to [0, 1].
        NOTE(review): Keras EfficientNet application models are documented to
        take [0, 255] inputs - confirm this scaling matches the model trained.

        Args:
            image_paths: Sequence of image file paths.
            labels: Sequence of labels aligned with ``image_paths``.
            training: When True the dataset is shuffled, repeated indefinitely
                and augmented (callers must pass ``steps_per_epoch``).

        Returns:
            The ``tf.data.Dataset``.
        """
        target_height, target_width = self.img_size[0], self.img_size[1]

        def load_and_preprocess_image(path, label):
            """Load, decode, resize and (for training) augment one image."""
            image = tf.io.read_file(path)
            # decode_image handles both JPEG and PNG files; expand_animations=False
            # keeps the result 3-D, and set_shape restores the static channel
            # count that later layers need.
            image = tf.io.decode_image(image, channels=3, expand_animations=False)
            image.set_shape([None, None, 3])
            image = tf.image.resize(image, [target_height, target_width])
            image = tf.cast(image, tf.float32) / 255.0

            if training:
                # Geometric and photometric augmentation
                image = tf.image.random_flip_left_right(image)
                image = tf.image.random_flip_up_down(image)
                image = tf.image.random_brightness(image, max_delta=0.1)
                image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
                image = tf.image.random_saturation(image, lower=0.9, upper=1.1)
                image = tf.image.random_hue(image, max_delta=0.05)

                # Random rotation by a multiple of 90 degrees
                image = tf.image.rot90(image, tf.random.uniform(shape=[], maxval=4, dtype=tf.int32))

                # Random zoom: upscale by 10% and crop back to the target size
                image = tf.image.random_crop(
                    tf.image.resize(image, [int(target_height * 1.1), int(target_width * 1.1)]),
                    size=[target_height, target_width, 3]
                )

                # Brightness/contrast shifts can leave [0, 1]; clip so training
                # and evaluation images share the same value range
                image = tf.clip_by_value(image, 0.0, 1.0)

            return image, label

        dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

        if training:
            # Shuffle the lightweight (path, label) pairs before decoding: the
            # buffer can span the whole split without holding decoded images
            dataset = dataset.shuffle(buffer_size=max(len(image_paths), 1))
            dataset = dataset.repeat()

        dataset = dataset.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

# Initialize data processor (224x224 inputs, batches of 32)
print("🔄 Initializing Plant Disease Data Processor...")
data_processor = PlantDiseaseDataProcessor(DATA_DIR, img_size=(224, 224), batch_size=32)

# Load and prepare data: 20% validation, 10% test, remainder for training.
# data_splits holds 'train' / 'val' / 'test' (paths, labels) tuples and 'class_mapping'.
print("📥 Loading and preparing dataset...")
data_splits = data_processor.load_and_prepare_data(validation_split=0.2, test_split=0.1)

print("✅ Data preparation completed successfully!")

In [None]:
# 🎛️ Class Weight Calculation for Imbalanced Dataset

def calculate_strategic_class_weights(y_train, labels, strategy='balanced'):
    """Compute per-class weights for imbalanced training data and plot them.

    Args:
        y_train: Encoded training labels.
        labels: Class names; not used by the computation.
        strategy: One of 'balanced', 'inverse_freq', 'sqrt_inverse_freq' or
            'log_inverse_freq'. Any other value raises ``KeyError``.

    Returns:
        dict mapping each label value present in ``y_train`` to its weight.
    """
    # Distinct labels and how often each occurs
    class_ids, class_counts = np.unique(y_train, return_counts=True)
    n_samples = len(y_train)
    n_classes = len(class_ids)

    # n_samples / (n_classes * count): the building block of two strategies
    inverse_freq = n_samples / (n_classes * class_counts)

    weight_table = {
        'balanced': compute_class_weight('balanced', classes=class_ids, y=y_train),
        'inverse_freq': inverse_freq,
        'sqrt_inverse_freq': np.sqrt(inverse_freq),
        'log_inverse_freq': np.log(n_samples / class_counts),
    }

    class_weights = weight_table[strategy]
    class_weight_dict = {
        class_id: weight for class_id, weight in zip(class_ids, class_weights)
    }

    print(f"📊 Class Weight Strategy: {strategy}")
    print(f"🔢 Weight Range: {class_weights.min():.3f} - {class_weights.max():.3f}")

    # Side-by-side bars: raw class counts (left) and resulting weights (right)
    fig, (count_ax, weight_ax) = plt.subplots(1, 2, figsize=(16, 6))
    panels = (
        (count_ax, class_counts, 'lightblue', '📊 Original Class Distribution', 'Sample Count'),
        (weight_ax, class_weights, 'orange', f'⚖️ Class Weights ({strategy})', 'Weight'),
    )
    for ax, heights, color, title, ylabel in panels:
        ax.bar(range(len(heights)), heights, color=color, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel('Class Index')
        ax.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

    return class_weight_dict

# Calculate class weights from the training labels only (default 'balanced' strategy);
# the result is a dict mapping encoded label -> weight
X_train, y_train = data_splits['train']
class_weights = calculate_strategic_class_weights(y_train, data_processor.class_names)

In [None]:
# 🏗️ Advanced Model Architecture Factory

class PlantDiseaseModelFactory:
    """Factory for plant disease classification models.

    Every ``create_*`` method returns an uncompiled Keras model whose final
    layer is a softmax over ``num_classes``.

    Attributes:
        num_classes: Number of output classes.
        input_shape: ``(height, width, channels)`` of the input images.
    """

    # EfficientNet variants available in keras.applications as EfficientNet<size>
    SUPPORTED_EFFICIENTNET_SIZES = ('B0', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7')

    def __init__(self, num_classes, input_shape=(224, 224, 3)):
        self.num_classes = num_classes
        self.input_shape = input_shape

    def create_efficient_net_model(self, model_size='B0', fine_tune=True):
        """Create an ImageNet-pretrained EfficientNet with a custom classifier head.

        Args:
            model_size: EfficientNet variant, one of
                ``SUPPORTED_EFFICIENTNET_SIZES`` (e.g. 'B0', 'B3').
            fine_tune: When True only the last 20 backbone layers stay
                trainable; when False the whole backbone is frozen.

        Returns:
            ``keras.Sequential`` model.

        Raises:
            ValueError: If ``model_size`` is not a supported variant.
        """
        # Validate up front so an unknown size fails with a clear message
        # instead of leaving the backbone undefined
        if model_size not in self.SUPPORTED_EFFICIENTNET_SIZES:
            raise ValueError(
                f"Unsupported EfficientNet size {model_size!r}; "
                f"expected one of {self.SUPPORTED_EFFICIENTNET_SIZES}"
            )

        backbone_cls = getattr(applications, f"EfficientNet{model_size}")
        base_model = backbone_cls(
            weights='imagenet',
            include_top=False,
            input_shape=self.input_shape
        )

        # Add custom classifier head
        model = keras.Sequential([
            base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dropout(0.3),
            layers.Dense(512, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.5),
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(), 
            layers.Dropout(0.3),
            layers.Dense(self.num_classes, activation='softmax', name='predictions')
        ])

        if fine_tune:
            # Unfreeze only the last 20 backbone layers for fine-tuning
            base_model.trainable = True
            for layer in base_model.layers[:-20]:
                layer.trainable = False
        else:
            base_model.trainable = False

        return model

    def create_vision_transformer(self, patch_size=16, num_heads=8, num_layers=6):
        """Create a small Vision Transformer classifier.

        Args:
            patch_size: Side length of the square, non-overlapping patches.
            num_heads: Attention heads per transformer block.
            num_layers: Number of transformer blocks.

        Returns:
            ``keras.Model`` mapping images to class probabilities.
        """
        # NOTE(review): raw tf ops are applied to symbolic Keras inputs here and
        # in _extract_patches - confirm this is supported by the installed Keras.
        embed_dim = 256
        inputs = layers.Input(shape=self.input_shape)

        # Patch extraction and linear patch encoding
        patches = self._extract_patches(inputs, patch_size)
        encoded_patches = layers.Dense(embed_dim)(patches)

        # Learned positional encoding. With VALID padding the patch grid is
        # floor(H / p) x floor(W / p), which also covers non-square inputs.
        num_patches = (self.input_shape[0] // patch_size) * (self.input_shape[1] // patch_size)
        positions = tf.range(start=0, limit=num_patches, delta=1)
        position_embedding = layers.Embedding(input_dim=num_patches, output_dim=embed_dim)(positions)
        encoded_patches = encoded_patches + position_embedding

        # Transformer blocks (post-norm: residual add, then LayerNormalization)
        for _ in range(num_layers):
            # Multi-head self-attention
            attention_output = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim
            )(encoded_patches, encoded_patches)
            attention_output = layers.Dropout(0.1)(attention_output)
            attention_output = layers.LayerNormalization()(encoded_patches + attention_output)

            # Feed forward
            ffn_output = layers.Dense(512, activation='gelu')(attention_output)
            ffn_output = layers.Dense(embed_dim)(ffn_output)
            ffn_output = layers.Dropout(0.1)(ffn_output)
            encoded_patches = layers.LayerNormalization()(attention_output + ffn_output)

        # Global average pooling and classification
        representation = layers.GlobalAveragePooling1D()(encoded_patches)
        features = layers.Dense(512, activation='relu')(representation)
        features = layers.Dropout(0.3)(features)
        outputs = layers.Dense(self.num_classes, activation='softmax')(features)

        model = keras.Model(inputs=inputs, outputs=outputs)
        return model

    def _extract_patches(self, images, patch_size):
        """Split images into flattened, non-overlapping square patches.

        Returns a tensor reshaped to ``(batch, num_patches, patch_dims)``,
        where ``patch_dims`` is the last dimension produced by
        ``tf.image.extract_patches``.
        """
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, patch_size, patch_size, 1],
            strides=[1, patch_size, patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def create_hybrid_cnn_transformer(self):
        """Create a model fusing frozen EfficientNet-B0 features with one
        transformer block over raw 32x32 image patches.

        Returns:
            ``keras.Model`` mapping images to class probabilities.
        """
        # CNN backbone for feature extraction (frozen)
        base_cnn = applications.EfficientNetB0(
            weights='imagenet',
            include_top=False,
            input_shape=self.input_shape
        )
        base_cnn.trainable = False

        inputs = layers.Input(shape=self.input_shape)

        # CNN feature extraction; training=False keeps the backbone in inference mode
        cnn_features = base_cnn(inputs, training=False)
        cnn_features = layers.GlobalAveragePooling2D()(cnn_features)

        # Transformer branch
        patches = self._extract_patches(inputs, patch_size=32)
        patch_embedding = layers.Dense(256)(patches)

        # Single transformer block
        attention_output = layers.MultiHeadAttention(num_heads=8, key_dim=256)(
            patch_embedding, patch_embedding
        )
        attention_output = layers.LayerNormalization()(patch_embedding + attention_output)
        transformer_features = layers.GlobalAveragePooling1D()(attention_output)

        # Fusion layer
        combined_features = layers.Concatenate()([cnn_features, transformer_features])
        combined_features = layers.Dense(512, activation='relu')(combined_features)
        combined_features = layers.BatchNormalization()(combined_features)
        combined_features = layers.Dropout(0.4)(combined_features)

        # Final classification
        outputs = layers.Dense(self.num_classes, activation='softmax')(combined_features)

        model = keras.Model(inputs=inputs, outputs=outputs)
        return model

    def create_ensemble_model(self, models_list):
        """Create a model that averages the predictions of ``models_list``.

        Args:
            models_list: Non-empty iterable of models accepting
                ``input_shape`` images. Each model is frozen in place
                (``trainable`` is set to False).

        Returns:
            ``keras.Model`` whose output is the element-wise mean of the
            members' outputs (the member's own output for a single model).

        Raises:
            ValueError: If ``models_list`` is empty.
        """
        members = list(models_list)
        if not members:
            raise ValueError("models_list must contain at least one model")

        inputs = layers.Input(shape=self.input_shape)
        outputs = []

        for model in members:
            model.trainable = False
            outputs.append(model(inputs))

        # A single member is its own average; this avoids calling the Average
        # merge layer with a one-element list
        if len(outputs) == 1:
            ensemble_output = outputs[0]
        else:
            ensemble_output = layers.Average()(outputs)

        ensemble_model = keras.Model(inputs=inputs, outputs=ensemble_output)
        return ensemble_model

# Initialize model factory: one output unit per class known to the data processor,
# with the same 224x224 RGB input size the processor resizes images to
model_factory = PlantDiseaseModelFactory(
    num_classes=len(data_processor.class_names), 
    input_shape=(224, 224, 3)
)

print(f"🏭 Model Factory initialized for {len(data_processor.class_names)} classes")
print(f"🎯 Input shape: {model_factory.input_shape}")

In [None]:
# 🚀 Create and Compare Multiple Model Architectures

def create_and_compare_models():
    """Build the candidate architectures and print a parameter report.

    Uses the module-level ``model_factory`` to construct EfficientNet-B0 and
    EfficientNet-B3 (both with ``fine_tune=True``) and the hybrid
    CNN-Transformer, then prints total parameters, trainable parameters and a
    size estimate for each of them.

    Returns:
        dict: architecture key -> built model, in build order.
    """
    print("🔨 Building model architectures...")

    # (dict key, progress message, zero-argument builder), in build order.
    # Builders are lambdas so each model is only created after its message
    # has been printed.
    build_plan = (
        (
            'efficientnet_b0',
            "   📱 Creating EfficientNet-B0 (Mobile-friendly)...",
            lambda: model_factory.create_efficient_net_model('B0', fine_tune=True),
        ),
        (
            'efficientnet_b3',
            "   🔥 Creating EfficientNet-B3 (High performance)...",
            lambda: model_factory.create_efficient_net_model('B3', fine_tune=True),
        ),
        (
            'hybrid_cnn_transformer',
            "   🤖 Creating Hybrid CNN-Transformer...",
            lambda: model_factory.create_hybrid_cnn_transformer(),
        ),
    )

    models = {}
    for key, progress_message, build in build_plan:
        print(progress_message)
        models[key] = build()

    # Side-by-side size report
    print("\n📊 Model Architecture Comparison:")
    for name, model in models.items():
        total_params = model.count_params()
        trainable_params = sum(
            keras.backend.count_params(weight)
            for weight in model.trainable_weights
        )

        print(f"\n🏗️ {name.replace('_', ' ').title()}:")
        print(f"   📏 Total Parameters: {total_params:,}")
        print(f"   🎯 Trainable Parameters: {trainable_params:,}")
        # The estimate counts 4 bytes per parameter, reported in MiB
        print(f"   💾 Estimated Model Size: {total_params * 4 / (1024**2):.1f} MB")

    return models

# Build every candidate architecture once and keep them in a name -> model
# dict; building also prints the parameter comparison report.
models = create_and_compare_models()

In [None]:
# 📈 Advanced Training Configuration & Callbacks

def create_training_callbacks(model_name, monitor='val_accuracy'):
    """Create the standard callback set for one training run.

    Args:
        model_name: Name of the model; used as the sub-directory under
            ``MODELS_DIR`` (checkpoint and CSV log) and under
            ``EXPERIMENTS_DIR / 'tensorboard'`` (TensorBoard logs).
        monitor: Quantity watched by checkpointing and early stopping.
            Names ending in ``loss``, ``error``, ``mae`` or ``mse`` are
            minimised; everything else is maximised.

    Returns:
        list: ``ModelCheckpoint``, ``EarlyStopping``, ``ReduceLROnPlateau``
        (always driven by ``val_loss``), ``CSVLogger`` and ``TensorBoard``
        callbacks, in that order.
    """
    # Per-model output directory. parents=True so a run on a fresh checkout,
    # where MODELS_DIR itself may not exist yet, does not fail here.
    model_dir = MODELS_DIR / model_name
    model_dir.mkdir(parents=True, exist_ok=True)

    # Decide the improvement direction explicitly and use it for BOTH
    # callbacks that watch `monitor`. A fixed 'max' would checkpoint the
    # worst epoch when monitoring a loss, and Keras' name-based 'auto'
    # inference may treat custom metrics (e.g. val_recall_score) as
    # quantities to minimise.
    lower_is_better = monitor.lower().endswith(('loss', 'error', 'mae', 'mse'))
    monitor_mode = 'min' if lower_is_better else 'max'

    callbacks_list = [
        # Model checkpointing - keep only the best full model
        keras.callbacks.ModelCheckpoint(
            filepath=str(model_dir / 'best_model.h5'),
            monitor=monitor,
            save_best_only=True,
            save_weights_only=False,
            mode=monitor_mode,
            verbose=1
        ),

        # Early stopping - stop after 15 epochs without improvement and
        # roll back to the best weights seen
        keras.callbacks.EarlyStopping(
            monitor=monitor,
            mode=monitor_mode,
            patience=15,
            restore_best_weights=True,
            verbose=1
        ),

        # Learning rate reduction - halve the LR when val_loss plateaus
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=8,
            min_lr=1e-7,
            verbose=1
        ),

        # CSV logging - append so repeated runs extend the same log file
        keras.callbacks.CSVLogger(
            filename=str(model_dir / 'training_log.csv'),
            append=True
        ),

        # TensorBoard logging
        keras.callbacks.TensorBoard(
            log_dir=str(EXPERIMENTS_DIR / 'tensorboard' / model_name),
            histogram_freq=1,
            write_graph=True,
            write_images=True,
            update_freq='epoch'
        )
    ]

    return callbacks_list

def compile_model_for_agriculture(model, model_name, alpha=0.25, gamma=2.0,
                                  learning_rate=0.001):
    """Compile ``model`` with focal loss and recall-oriented metrics.

    Priorities for this application:
      1. High recall for disease detection (avoid false negatives)
      2. Calibrated confidence scores
      3. Robust optimization

    Labels are expected as one-hot vectors aligned with the model's softmax
    output (assumed shape ``(batch, num_classes)``: the metrics reduce over
    axis 0 for per-class counts and the loss sums over axis 1 per sample).

    The F1/recall/precision metrics are macro-averaged over classes and are
    evaluated per batch, so the values Keras reports are batch averages
    rather than exact epoch-level scores.

    Args:
        model: Keras model; compiled in place.
        model_name: Used in the log lines and to pick the optimizer: names
            containing ``"efficientnet"`` (case-insensitive) get AdamW with
            weight decay 1e-4, everything else plain Adam.
        alpha: Focal-loss weight applied where the label is 1; all other
            entries are weighted ``1 - alpha``.
        gamma: Focal-loss focusing exponent.
        learning_rate: Initial learning rate for the optimizer.

    Returns:
        The same ``model`` object, compiled.
    """
    eps = tf.keras.backend.epsilon()

    def _confusion_counts(y_true, y_pred):
        """Return per-class (tp, fp, fn) counts for one batch."""
        # Cast first so integer one-hot labels work in the arithmetic below.
        y_true = tf.cast(y_true, tf.float32)
        # The predicted class of a softmax output is its argmax. Rounding at
        # 0.5 would discard every sample whose top probability is <= 0.5 and
        # make these metrics disagree with 'accuracy'.
        predicted = tf.one_hot(
            tf.argmax(y_pred, axis=-1),
            depth=tf.shape(y_pred)[-1],
            dtype=tf.float32,
        )
        tp = tf.reduce_sum(y_true * predicted, axis=0)
        fp = tf.reduce_sum((1 - y_true) * predicted, axis=0)
        fn = tf.reduce_sum(y_true * (1 - predicted), axis=0)
        return tp, fp, fn

    # NOTE: the metric function names below become the Keras metric names
    # (and History keys), so they are kept unchanged.
    def f1_score(y_true, y_pred):
        """F1 score metric (macro average over classes)"""
        tp, fp, fn = _confusion_counts(y_true, y_pred)
        precision = tp / (tp + fp + eps)
        recall = tp / (tp + fn + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return tf.reduce_mean(f1)

    def recall_score(y_true, y_pred):
        """Recall score - critical for disease detection"""
        tp, _, fn = _confusion_counts(y_true, y_pred)
        return tf.reduce_mean(tp / (tp + fn + eps))

    def precision_score(y_true, y_pred):
        """Precision score"""
        tp, fp, _ = _confusion_counts(y_true, y_pred)
        return tf.reduce_mean(tp / (tp + fp + eps))

    # Focal loss for handling class imbalance
    def focal_loss(alpha=0.25, gamma=2.0):
        """Return a focal-loss function with the given alpha and gamma.

        The loss is applied element-wise to every class probability (label-1
        entries use ``y_pred``, the rest use ``1 - y_pred``) and then summed
        per sample, so non-target classes contribute as well.
        """
        def focal_loss_fixed(y_true, y_pred):
            # Cast so integer labels can be compared and multiplied safely.
            y_true = tf.cast(y_true, y_pred.dtype)
            # Clip to keep log() finite
            y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)
            is_positive = tf.equal(y_true, 1)
            p_t = tf.where(is_positive, y_pred, 1 - y_pred)
            alpha_factor = tf.ones_like(y_true) * alpha
            alpha_t = tf.where(is_positive, alpha_factor, 1 - alpha_factor)
            cross_entropy = -tf.math.log(p_t)
            # (1 - p_t)^gamma down-weights entries that are already easy
            weight = alpha_t * tf.pow((1 - p_t), gamma)
            loss = weight * cross_entropy
            return tf.reduce_mean(tf.reduce_sum(loss, axis=1))
        return focal_loss_fixed

    # Choose optimizer based on model family
    if 'efficientnet' in model_name.lower():
        optimizer = keras.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=0.0001
        )
    else:
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    # Compile model
    model.compile(
        optimizer=optimizer,
        loss=focal_loss(alpha=alpha, gamma=gamma),
        metrics=[
            'accuracy',
            keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_accuracy'),
            f1_score,
            recall_score,
            precision_score
        ]
    )

    print(f"✅ Model '{model_name}' compiled successfully")
    # Log the values actually used so the message cannot drift from the loss
    print(f"   🎯 Loss: Focal Loss (α={alpha}, γ={gamma})")
    print(f"   📊 Metrics: Accuracy, Top-3, F1, Recall, Precision")
    print(f"   ⚡ Optimizer: {optimizer.__class__.__name__}")

    return model

# Compile every architecture and store the compiled model back under its key.
# Iterating over a snapshot of the keys keeps the loop independent of the
# assignments made inside it.
print("⚙️ Compiling models for agricultural applications...")
for name in list(models):
    models[name] = compile_model_for_agriculture(models[name], name)

print("✅ All models compiled successfully!")

In [None]:
# 🏋️‍♂️ Create TensorFlow Datasets for Training

print("🔄 Creating TensorFlow datasets...")

# One dataset per split, built in train -> val -> test order. Only the
# training split is created with training=True.
train_dataset, val_dataset, test_dataset = (
    data_processor.create_tf_dataset(
        data_splits[split][0],
        data_splits[split][1],
        training=(split == 'train'),
    )
    for split in ('train', 'val', 'test')
)

# Whole batches per pass over each split (floor division, so a trailing
# partial batch is not counted)
steps_per_epoch, validation_steps = (
    len(data_splits[split][0]) // data_processor.batch_size
    for split in ('train', 'val')
)

print(f"📊 Dataset Statistics:")
print(f"   🏋️ Steps per Epoch: {steps_per_epoch}")
print(f"   🔍 Validation Steps: {validation_steps}")
print(f"   💿 Batch Size: {data_processor.batch_size}")

# Visualize a batch of training data
def visualize_training_batch(dataset, class_names, num_images=12):
    """Plot up to ``num_images`` images from the first batch of ``dataset``.

    Images are laid out four per row; the number of rows follows
    ``num_images`` (three rows for the default of 12). Cells that receive no
    image, because the batch is smaller than ``num_images`` or the grid is
    not completely filled, are hidden.

    Args:
        dataset: Dataset supporting ``take`` that yields ``(images, labels)``
            batches. Each label is reduced with ``argmax`` to obtain the
            class index (presumably one-hot encoded).
        class_names: Sequence mapping class index to display name.
        num_images: Maximum number of images to draw.

    Returns:
        None. The figure is shown with ``plt.show()``. Nothing is drawn if
        the dataset yields no batch.
    """
    n_cols = 4
    # Ceiling division; always keep at least one row so subplots() is valid.
    n_rows = max(1, -(-num_images // n_cols))

    # take(1) yields at most one batch, so this body runs at most once
    for images, labels in dataset.take(1):
        n_shown = max(0, min(num_images, len(images)))

        # squeeze=False guarantees a 2-D axes array even for a single row
        _, axes = plt.subplots(
            n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False
        )
        axes = axes.flatten()

        for i in range(n_shown):
            img = images[i].numpy()
            label_idx = tf.argmax(labels[i]).numpy()
            label_name = class_names[label_idx]

            # NOTE(review): imshow expects floats in [0, 1] or uint8 in
            # [0, 255]; confirm the dataset's pixel range matches.
            axes[i].imshow(img)
            axes[i].set_title(f"{label_name}", fontsize=10)
            axes[i].axis('off')

        # Hide every cell that did not receive an image, counting from the
        # number actually drawn rather than the number requested.
        for unused_ax in axes[n_shown:]:
            unused_ax.axis('off')

        plt.suptitle('🌱 Training Data with Augmentations', fontsize=16)
        plt.tight_layout()
        plt.show()

# Show one batch from the training pipeline as a visual sanity check
print("👀 Visualizing training data with augmentations...")
visualize_training_batch(
    dataset=train_dataset,
    class_names=data_processor.class_names,
)

# 🧭 Notebook Relocated

This notebook has been superseded by a new dedicated EDA notebook and a separate training/evaluation notebook to better align with the project goals:

- `01_eda_plant_disease.ipynb` – Comprehensive EDA with interactive Plotly visuals
- `02_training_evaluation.ipynb` – Model training, evaluation, and calibration

You can safely switch to the new notebooks in the `notebooks/` folder.