# Smart Agriculture Advisor for Crops & Pests
## Hybrid Intelligent System: Knowledge Engineering + Deep Learning

**Course:** DSC3113 - Knowledge Engineering  
**Project:** Group 2 - Smart Agriculture Advisor  
**Institution:** Uganda Christian University  
**Semester:** Advent 2025

---

## Table of Contents
1. [Introduction](#introduction)
2. [Setup and Dependencies](#setup)
3. [Data Collection & Loading](#data-collection)
4. [Exploratory Data Analysis (EDA)](#eda)
5. [Knowledge Engineering Setup](#knowledge-engineering)
6. [Deep Learning Setup](#deep-learning)
7. [Hybrid Integration](#hybrid-integration)
8. [Evaluation](#evaluation)
9. [Case Studies](#case-studies)
10. [Results and Discussion](#results)
11. [Conclusion and Recommendations](#conclusion)


<a id="introduction"></a>
## 1. Introduction

### Problem Statement
Farmers worldwide struggle with early identification of crop diseases and pests, leading to significant yield losses. Traditional diagnostic methods rely heavily on expert knowledge, which may not always be accessible to smallholder farmers. This project addresses this challenge by developing a hybrid intelligent system that combines:

- **Knowledge Engineering (KE)**: Expert rules and ontologies for symbolic reasoning about diseases, symptoms, and treatments
- **Deep Learning (DL)**: Convolutional Neural Networks (CNNs) for image-based disease recognition

### Objectives
1. Build an agriculture ontology connecting crops, pests, diseases, symptoms, and treatments
2. Encode at least 20 expert rules for disease diagnosis and treatment recommendations
3. Train CNN models to classify crop diseases from leaf images
4. Integrate KE reasoning with DL predictions for hybrid decision-making
5. Evaluate the system's accuracy and provide actionable recommendations

### Dataset
- **Cassava Leaf Disease Classification Dataset**: https://www.kaggle.com/competitions/cassava-disease
- **Disease Classes**: 
  - Cassava Bacterial Blight (CBB)
  - Cassava Brown Streak Disease (CBSD)
  - Cassava Green Mottle (CGM)
  - Cassava Mosaic Disease (CMD)
  - Healthy

### Relevance
This hybrid system bridges the gap between rule-based expert systems and data-driven AI, providing farmers with accessible, accurate disease diagnosis and treatment recommendations.


### System Architecture

```
┌─────────────────────────────────────────────────────────────┐
│              Smart Agriculture Advisor System                │
├─────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌──────────────┐         ┌──────────────┐                 │
│  │   CNN Model  │────────▶│  Disease     │                 │
│  │  (Image      │         │  Prediction  │                 │
│  │  Classifier) │         │              │                 │
│  └──────────────┘         └──────┬───────┘                 │
│                                   │                         │
│  ┌──────────────┐                │                         │
│  │  Knowledge   │                ▼                         │
│  │  Base        │         ┌──────────────┐                 │
│  │  (Ontology + │────────▶│  Hybrid      │                 │
│  │   Rules)     │         │  Reasoning   │                 │
│  └──────────────┘         │  Engine      │                 │
│                           └──────┬───────┘                 │
│                                   │                         │
│                                   ▼                         │
│                           ┌──────────────┐                 │
│                           │  Treatment   │                 │
│                           │  & Advice    │                 │
│                           │  Output      │                 │
│                           └──────────────┘                 │
└─────────────────────────────────────────────────────────────┘
```
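
The data flow in the diagram can be sketched in a few lines of Python. This is a minimal illustration, not the project's implementation: the function name `hybrid_advise`, the `CONFIDENCE_THRESHOLD` value of 0.6, and the one-line advice table are assumptions made for the example. The class order follows the dataset list above.

```python
from typing import Dict, List, Optional, Sequence, Union

# Class order follows the dataset description (short codes).
CLASS_CODES: List[str] = ["CBB", "CBSD", "CGM", "CMD", "Healthy"]

# Hypothetical, heavily abbreviated knowledge base: one advice line per class.
ADVICE: Dict[str, str] = {
    "CBB": "Remove and destroy infected plants immediately.",
    "CBSD": "Use certified virus-free planting material.",
    "CGM": "Control insect vectors.",
    "CMD": "Control whitefly populations.",
    "Healthy": "Continue good agricultural practices.",
}

CONFIDENCE_THRESHOLD = 0.6  # assumed cut-off, not a tuned value


def hybrid_advise(probabilities: Sequence[float]) -> Dict[str, Union[Optional[str], float]]:
    """Combine a CNN probability vector with a rule lookup."""
    if len(probabilities) != len(CLASS_CODES):
        raise ValueError("expected one probability per class")

    # Step 1 (DL): take the most probable class from the classifier output.
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    confidence = float(probabilities[best])

    # Step 2 (KE): only consult the knowledge base when the prediction is confident.
    if confidence < CONFIDENCE_THRESHOLD:
        return {
            "diagnosis": None,
            "confidence": confidence,
            "advice": "Low confidence: retake the photo or consult an extension officer.",
        }

    code = CLASS_CODES[best]
    return {"diagnosis": code, "confidence": confidence, "advice": ADVICE[code]}


result = hybrid_advise([0.05, 0.05, 0.05, 0.80, 0.05])
print(result["diagnosis"], "->", result["advice"])
```

Deferring on low-confidence predictions is one possible integration rule; the full system may weigh symptoms and rules differently.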


<a id="setup"></a>
## 2. Setup and Dependencies


In [None]:
# Install required packages
!pip install -q opendatasets owlready2 rdflib scikit-learn tensorflow keras numpy pandas matplotlib seaborn pillow opencv-python scipy


In [None]:
# Import libraries
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Data handling
import numpy as np
import pandas as pd
from pathlib import Path
import json
import pickle

# Deep Learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.applications import MobileNetV2, InceptionV3, EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Image processing
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import seaborn as sns

# Knowledge Engineering
try:
    from owlready2 import *
except ImportError:
    print("Note: owlready2 installation may require additional setup")
import rdflib
from rdflib import Graph, Namespace, Literal, URIRef
from rdflib.namespace import RDF, RDFS, OWL

# Visualization
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print("✓ All libraries imported successfully")
print(f"TensorFlow version: {tf.__version__}")
print(f"Python version: {sys.version.split()[0]}")


<a id="data-collection"></a>
## 3. Data Collection & Loading


In [None]:
# Dataset configuration
DATASET_NAME = "cassava-leaf-disease-classification"
DATASET_PATH = "./cassava-leaf-disease-classification"  # Adjust path as needed

# Disease class mappings
DISEASE_CLASSES = {
    0: "Cassava_Bacterial_Blight",
    1: "Cassava_Brown_Streak_Disease", 
    2: "Cassava_Green_Mottle",
    3: "Cassava_Mosaic_Disease",
    4: "Healthy"
}

CLASS_NAMES_SHORT = {
    0: "CBB",
    1: "CBSD",
    2: "CGM", 
    3: "CMD",
    4: "Healthy"
}

print("Disease Classes:")
for idx, name in DISEASE_CLASSES.items():
    print(f"  {idx}: {name} ({CLASS_NAMES_SHORT[idx]})")


In [None]:
# Download dataset (uncomment if needed)
# import opendatasets as od
# od.download(f"https://www.kaggle.com/competitions/{DATASET_NAME}")

# Alternative: Load from local path
# Adjust this path to your dataset location
BASE_DATA_PATH = Path("./cassava-leaf-disease-classification/data")

# If dataset is in a different location, update here
# BASE_DATA_PATH = Path("/content/drive/MyDrive/cassava-leaf-disease-classification/data")
# BASE_DATA_PATH = Path("cassava-disease/data")

print(f"Looking for dataset at: {BASE_DATA_PATH.absolute()}")

if BASE_DATA_PATH.exists():
    print("✓ Dataset folder found")
else:
    print("⚠ Dataset folder not found. Please update BASE_DATA_PATH")
    print("   Expected structure:")
    print("   cassava-leaf-disease-classification/")
    print("     └── data/")
    print("         ├── Cassava___bacterial_blight/")
    print("         ├── Cassava___brown_streak_disease/")
    print("         ├── Cassava___green_mottle/")
    print("         ├── Cassava___mosaic_disease/")
    print("         └── Cassava___healthy/")
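
Once the class folders are found, the images need to be split for training and validation without distorting the class proportions. scikit-learn's `train_test_split`, imported above, accepts a `stratify` argument for this. The plain-Python sketch below shows the same idea; the helper name `stratified_split`, the 20% validation fraction, and the toy file names are assumptions for illustration only.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Sample = Tuple[str, int]  # (file path, class index)


def stratified_split(
    samples: Sequence[Sample], val_fraction: float = 0.2, seed: int = 42
) -> Tuple[List[Sample], List[Sample]]:
    """Split samples so each class contributes about val_fraction to validation."""
    by_class: Dict[int, List[Sample]] = defaultdict(list)
    for sample in samples:
        by_class[sample[1]].append(sample)

    rng = random.Random(seed)  # local generator keeps the split reproducible
    train: List[Sample] = []
    val: List[Sample] = []
    for label in sorted(by_class):
        items = by_class[label]
        rng.shuffle(items)
        # Keep at least one validation image per class when the class has more than one image.
        n_val = max(1, round(len(items) * val_fraction)) if len(items) > 1 else 0
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val


# Toy data: 10 images of class 0 and 5 images of class 1 (made-up file names).
toy = [(f"cbb_{i}.jpg", 0) for i in range(10)] + [(f"cbsd_{i}.jpg", 1) for i in range(5)]
train_set, val_set = stratified_split(toy)
print(len(train_set), len(val_set))
```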


<a id="eda"></a>
## 4. Exploratory Data Analysis (EDA)

### 4.1 Dataset Structure Analysis


In [None]:
def analyze_dataset_structure(base_path):
    """Analyze the structure and statistics of the dataset"""
    
    if not Path(base_path).exists():
        print(f"⚠ Path {base_path} does not exist")
        return None, None
    
    class_stats = {}
    total_images = 0
    image_extensions = {'.jpg', '.jpeg', '.png'}  # compared case-insensitively
    
    # Expected folder names mapping
    folder_mapping = {
        'Cassava___bacterial_blight': 0,
        'Cassava___brown_streak_disease': 1,
        'Cassava___green_mottle': 2,
        'Cassava___mosaic_disease': 3,
        'Cassava___healthy': 4
    }
    
    for folder_name in Path(base_path).iterdir():
        if folder_name.is_dir():
            # Skip folders that are not one of the five expected classes
            if folder_name.name not in folder_mapping:
                print(f"⚠ Skipping unrecognized folder: {folder_name.name}")
                continue
            
            # Collect each image once; matching on the lower-cased suffix avoids
            # double counting on case-insensitive file systems
            images = sorted(
                p for p in folder_name.iterdir()
                if p.is_file() and p.suffix.lower() in image_extensions
            )
            image_count = len(images)
            image_sizes = []
            
            # Sample image sizes (first 10)
            for img_path in images[:10]:
                try:
                    with Image.open(img_path) as img:
                        image_sizes.append(img.size)
                except Exception:
                    pass  # unreadable file; corruption is assessed in a later cell
            
            # Map folder to class index
            class_idx = folder_mapping[folder_name.name]
            
            class_stats[folder_name.name] = {
                'count': image_count,
                'class_idx': class_idx,
                'sample_sizes': image_sizes[:5]
            }
            
            total_images += image_count
    
    return class_stats, total_images

# Analyze dataset
class_stats, total_images = analyze_dataset_structure(BASE_DATA_PATH)

if class_stats:
    print("=" * 60)
    print("DATASET STRUCTURE ANALYSIS")
    print("=" * 60)
    print(f"\nTotal Images: {total_images:,}\n")
    print("-" * 60)
    
    for folder_name, stats in sorted(class_stats.items(), key=lambda x: x[1]['class_idx']):
        class_idx = stats['class_idx']
        count = stats['count']
        percentage = (count / total_images * 100) if total_images > 0 else 0
        
        print(f"Class {class_idx}: {folder_name}")
        print(f"  Images: {count:,} ({percentage:.2f}%)")
        if stats['sample_sizes']:
            print(f"  Sample size (width, height): {stats['sample_sizes'][0]}")
        print()
else:
    print("⚠ Could not analyze dataset. Please check the path.")


### 4.2 Visual Distribution Analysis


In [None]:
if class_stats:
    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Extract data for plotting
    class_labels = []
    counts = []
    class_indices = []
    
    for folder_name, stats in sorted(class_stats.items(), key=lambda x: x[1]['class_idx']):
        # Fall back to the folder name if the class index is not a known class
        class_labels.append(CLASS_NAMES_SHORT.get(stats['class_idx'], folder_name))
        counts.append(stats['count'])
        class_indices.append(stats['class_idx'])
    
    # Bar plot
    colors = sns.color_palette("husl", len(class_labels))
    bars = axes[0].bar(class_labels, counts, color=colors, edgecolor='black', linewidth=1.5)
    axes[0].set_title('Class Distribution (Bar Chart)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Disease Class', fontsize=12)
    axes[0].set_ylabel('Number of Images', fontsize=12)
    axes[0].grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height,
                    f'{count:,}',
                    ha='center', va='bottom', fontweight='bold')
    
    # Pie chart
    wedges, texts, autotexts = axes[1].pie(counts, labels=class_labels, autopct='%1.1f%%',
                                           colors=colors, startangle=90, textprops={'fontsize': 11})
    axes[1].set_title('Class Distribution (Pie Chart)', fontsize=14, fontweight='bold')
    
    # Make percentage text bold
    for autotext in autotexts:
        autotext.set_color('white')
        autotext.set_fontweight('bold')
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print("\nClass Distribution Statistics:")
    print("-" * 60)
    for label, count, idx in zip(class_labels, counts, class_indices):
        pct = (count / total_images * 100)
        print(f"{label:8s} (Class {idx}): {count:5,} images ({pct:5.2f}%)")


### 4.3 Sample Image Visualization


In [None]:
def visualize_sample_images(base_path, class_stats, num_samples=5):
    """Visualize sample images from each class"""
    
    if not class_stats:
        print("⚠ No class statistics available")
        return
    
    # squeeze=False keeps axes two-dimensional even for one class or one sample
    fig, axes = plt.subplots(len(class_stats), num_samples,
                             figsize=(20, 4*len(class_stats)), squeeze=False)
    
    # Hide every axis up front so cells without an image stay blank
    for ax in axes.ravel():
        ax.axis('off')
    
    row = 0
    for folder_name, stats in sorted(class_stats.items(), key=lambda x: x[1]['class_idx']):
        folder_path = Path(base_path) / folder_name
        class_idx = stats['class_idx']
        class_name = CLASS_NAMES_SHORT[class_idx]
        
        # Get sample images (suffix compared case-insensitively to avoid duplicates)
        image_files = [p for p in folder_path.iterdir()
                       if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
        
        # Randomly sample
        if len(image_files) > num_samples:
            image_files = np.random.choice(image_files, num_samples, replace=False)
        else:
            image_files = image_files[:num_samples]
        
        for col, img_path in enumerate(image_files[:num_samples]):
            try:
                img = Image.open(img_path)
                axes[row, col].imshow(img)
                axes[row, col].set_title(f'{class_name}\n{img_path.name[:20]}...', 
                                       fontsize=10, fontweight='bold')
                axes[row, col].axis('off')
            except Exception as e:
                axes[row, col].text(0.5, 0.5, f'Error loading\n{img_path.name}', 
                                  ha='center', va='center')
                axes[row, col].axis('off')
        
        row += 1
    
    plt.suptitle('Sample Images from Each Disease Class', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

# Visualize samples
if class_stats:
    visualize_sample_images(BASE_DATA_PATH, class_stats, num_samples=5)


### 4.4 Image Statistics Analysis


In [None]:
def analyze_image_statistics(base_path, class_stats, sample_per_class=50):
    """Analyze image statistics including dimensions, color channels, etc."""
    
    if not class_stats:
        return None
    
    all_stats = {
        'widths': [],
        'heights': [],
        'aspect_ratios': [],
        'channels': [],
        'file_sizes': [],
        'classes': []
    }
    
    for folder_name, stats in class_stats.items():
        folder_path = Path(base_path) / folder_name
        class_idx = stats['class_idx']
        
        # Get each image once; lower-casing the suffix avoids duplicates
        # on case-insensitive file systems
        image_files = [p for p in folder_path.iterdir()
                       if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
        
        # Sample images
        if len(image_files) > sample_per_class:
            image_files = np.random.choice(image_files, sample_per_class, replace=False)
        
        for img_path in image_files:
            try:
                # Get file size
                file_size = img_path.stat().st_size / (1024 * 1024)  # MB
                
                # Open and analyze image
                with Image.open(img_path) as img:
                    width, height = img.size
                    channels = len(img.getbands()) if hasattr(img, 'getbands') else 3
                    aspect_ratio = width / height
                    
                    all_stats['widths'].append(width)
                    all_stats['heights'].append(height)
                    all_stats['aspect_ratios'].append(aspect_ratio)
                    all_stats['channels'].append(channels)
                    all_stats['file_sizes'].append(file_size)
                    all_stats['classes'].append(CLASS_NAMES_SHORT[class_idx])
            except Exception as e:
                continue
    
    return pd.DataFrame(all_stats)

# Analyze image statistics
if class_stats:
    img_stats_df = analyze_image_statistics(BASE_DATA_PATH, class_stats, sample_per_class=100)
    
    if img_stats_df is not None and len(img_stats_df) > 0:
        print("=" * 60)
        print("IMAGE STATISTICS ANALYSIS")
        print("=" * 60)
        print(f"\nTotal Images Analyzed: {len(img_stats_df):,}\n")
        print("-" * 60)
        print("\nSummary Statistics:")
        print(img_stats_df[['widths', 'heights', 'aspect_ratios', 'file_sizes']].describe())
        
        # Visualize statistics
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        
        # Width distribution
        axes[0, 0].hist(img_stats_df['widths'], bins=50, color='skyblue', edgecolor='black')
        axes[0, 0].set_title('Image Width Distribution', fontweight='bold')
        axes[0, 0].set_xlabel('Width (pixels)')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].axvline(img_stats_df['widths'].mean(), color='red', linestyle='--', 
                          label=f'Mean: {img_stats_df["widths"].mean():.0f}')
        axes[0, 0].legend()
        
        # Height distribution
        axes[0, 1].hist(img_stats_df['heights'], bins=50, color='lightcoral', edgecolor='black')
        axes[0, 1].set_title('Image Height Distribution', fontweight='bold')
        axes[0, 1].set_xlabel('Height (pixels)')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].axvline(img_stats_df['heights'].mean(), color='red', linestyle='--',
                          label=f'Mean: {img_stats_df["heights"].mean():.0f}')
        axes[0, 1].legend()
        
        # Aspect ratio distribution
        axes[1, 0].hist(img_stats_df['aspect_ratios'], bins=50, color='lightgreen', edgecolor='black')
        axes[1, 0].set_title('Aspect Ratio Distribution', fontweight='bold')
        axes[1, 0].set_xlabel('Aspect Ratio (Width/Height)')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].axvline(img_stats_df['aspect_ratios'].mean(), color='red', linestyle='--',
                          label=f'Mean: {img_stats_df["aspect_ratios"].mean():.2f}')
        axes[1, 0].legend()
        
        # File size distribution
        axes[1, 1].hist(img_stats_df['file_sizes'], bins=50, color='plum', edgecolor='black')
        axes[1, 1].set_title('File Size Distribution', fontweight='bold')
        axes[1, 1].set_xlabel('File Size (MB)')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].axvline(img_stats_df['file_sizes'].mean(), color='red', linestyle='--',
                          label=f'Mean: {img_stats_df["file_sizes"].mean():.3f} MB')
        axes[1, 1].legend()
        
        plt.tight_layout()
        plt.show()
    else:
        print("⚠ Could not analyze image statistics")


### 4.5 Color Distribution Analysis


In [None]:
def analyze_color_distribution(base_path, class_stats, samples_per_class=20):
    """Analyze color distribution across disease classes"""
    
    if not class_stats:
        return None
    
    color_stats = {name: {'mean_rgb': [], 'std_rgb': []} for name in CLASS_NAMES_SHORT.values()}
    
    for folder_name, stats in class_stats.items():
        folder_path = Path(base_path) / folder_name
        class_idx = stats['class_idx']
        class_name = CLASS_NAMES_SHORT[class_idx]
        
        # Match extensions case-insensitively, consistent with the earlier cells
        image_files = [p for p in folder_path.iterdir()
                       if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
        
        if len(image_files) > samples_per_class:
            image_files = np.random.choice(image_files, samples_per_class, replace=False)
        
        for img_path in image_files:
            try:
                # Force three channels so grayscale and RGBA files are comparable
                with Image.open(img_path) as pil_img:
                    img = np.array(pil_img.convert('RGB'))
                mean_rgb = img.mean(axis=(0, 1))
                std_rgb = img.std(axis=(0, 1))
                color_stats[class_name]['mean_rgb'].append(mean_rgb)
                color_stats[class_name]['std_rgb'].append(std_rgb)
            except Exception:
                continue  # skip unreadable images
    
    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Mean RGB values by class
    mean_data = []
    class_names_list = []
    for class_name, stats in color_stats.items():
        if stats['mean_rgb']:
            mean_rgb = np.mean(stats['mean_rgb'], axis=0)
            mean_data.append(mean_rgb)
            class_names_list.append(class_name)
    
    if mean_data:
        mean_data = np.array(mean_data)
        x = np.arange(len(class_names_list))
        width = 0.25
        
        axes[0].bar(x - width, mean_data[:, 0], width, label='Red', color='red', alpha=0.7)
        axes[0].bar(x, mean_data[:, 1], width, label='Green', color='green', alpha=0.7)
        axes[0].bar(x + width, mean_data[:, 2], width, label='Blue', color='blue', alpha=0.7)
        
        axes[0].set_xlabel('Disease Class', fontweight='bold')
        axes[0].set_ylabel('Mean Pixel Intensity', fontweight='bold')
        axes[0].set_title('Average RGB Values by Disease Class', fontweight='bold', fontsize=12)
        axes[0].set_xticks(x)
        axes[0].set_xticklabels(class_names_list)
        axes[0].legend()
        axes[0].grid(axis='y', alpha=0.3)
        
        # Color swatches
        for idx, (class_name, rgb_mean) in enumerate(zip(class_names_list, mean_data)):
            axes[1].add_patch(plt.Rectangle((0, idx), 1, 0.8, 
                                           facecolor=rgb_mean/255.0, edgecolor='black', linewidth=2))
            axes[1].text(1.1, idx + 0.4, class_name, va='center', fontweight='bold', fontsize=11)
        
        axes[1].set_xlim(-0.1, 3)
        axes[1].set_ylim(-0.5, len(class_names_list))
        axes[1].set_title('Average Color by Disease Class', fontweight='bold', fontsize=12)
        axes[1].axis('off')
    
    plt.tight_layout()
    plt.show()

# Analyze color distribution
if class_stats:
    analyze_color_distribution(BASE_DATA_PATH, class_stats, samples_per_class=30)


### 4.6 Data Quality Assessment


In [None]:
def check_data_quality(base_path, class_stats):
    """Check for corrupted images, missing files, etc."""
    
    if not class_stats:
        return None
    
    quality_report = {
        'total_checked': 0,
        'corrupted': 0,
        'valid': 0,
        'missing_files': 0,
        'corruption_by_class': {}
    }
    
    for folder_name, stats in class_stats.items():
        folder_path = Path(base_path) / folder_name
        class_idx = stats['class_idx']
        class_name = CLASS_NAMES_SHORT[class_idx]
        
        corrupted_count = 0
        valid_count = 0
        
        # Match extensions case-insensitively so counts agree with the earlier analysis
        image_files = [p for p in folder_path.iterdir()
                       if p.is_file() and p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
        
        for img_path in image_files:
            quality_report['total_checked'] += 1
            try:
                # verify() checks file integrity without decoding the pixel data;
                # the context manager closes the file even when verification fails
                with Image.open(img_path) as img:
                    img.verify()
                valid_count += 1
                quality_report['valid'] += 1
            except Exception:
                corrupted_count += 1
                quality_report['corrupted'] += 1
        
        quality_report['corruption_by_class'][class_name] = {
            'total': len(image_files),
            'valid': valid_count,
            'corrupted': corrupted_count,
            'corruption_rate': (corrupted_count / len(image_files) * 100) if len(image_files) > 0 else 0
        }
    
    return quality_report

# Check data quality
if class_stats:
    quality_report = check_data_quality(BASE_DATA_PATH, class_stats)
    
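    # Require at least one checked image so the percentages below cannot divide by zero
    if quality_report and quality_report['total_checked'] > 0: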
        print("=" * 60)
        print("DATA QUALITY ASSESSMENT")
        print("=" * 60)
        print(f"\nTotal Images Checked: {quality_report['total_checked']:,}")
        print(f"Valid Images: {quality_report['valid']:,} ({quality_report['valid']/quality_report['total_checked']*100:.2f}%)")
        print(f"Corrupted Images: {quality_report['corrupted']:,} ({quality_report['corrupted']/quality_report['total_checked']*100:.2f}%)")
        
        print("\n" + "-" * 60)
        print("Corruption Rate by Class:")
        print("-" * 60)
        for class_name, stats in quality_report['corruption_by_class'].items():
            print(f"{class_name:8s}: {stats['valid']:5,} valid / {stats['total']:5,} total "
                  f"({stats['corruption_rate']:.2f}% corrupted)")
        
        # Visualize quality
        fig, ax = plt.subplots(figsize=(10, 6))
        classes = list(quality_report['corruption_by_class'].keys())
        valid_counts = [quality_report['corruption_by_class'][c]['valid'] for c in classes]
        corrupted_counts = [quality_report['corruption_by_class'][c]['corrupted'] for c in classes]
        
        x = np.arange(len(classes))
        width = 0.35
        
        bars1 = ax.bar(x - width/2, valid_counts, width, label='Valid', color='green', alpha=0.7)
        bars2 = ax.bar(x + width/2, corrupted_counts, width, label='Corrupted', color='red', alpha=0.7)
        
        ax.set_xlabel('Disease Class', fontweight='bold')
        ax.set_ylabel('Number of Images', fontweight='bold')
        ax.set_title('Data Quality by Disease Class', fontweight='bold', fontsize=12)
        ax.set_xticks(x)
        ax.set_xticklabels(classes)
        ax.legend()
        ax.grid(axis='y', alpha=0.3)
        
        plt.tight_layout()
        plt.show()
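
Pillow's `Image.verify()` inspects the file without decoding the pixel data, and the image must be reopened before further use. A stricter check could therefore follow verification with a full decode. The sketch below assumes Pillow is installed; the helper name `is_readable_image` is hypothetical and not part of the notebook.

```python
from pathlib import Path
from typing import Union

from PIL import Image


def is_readable_image(path: Union[str, Path]) -> bool:
    """Return True only if the file passes verify() and then decodes fully."""
    try:
        # Pass 1: structural check; the image object is unusable afterwards.
        with Image.open(path) as img:
            img.verify()
        # Pass 2: reopen and decode the pixel data to catch truncated files.
        with Image.open(path) as img:
            img.load()
        return True
    except Exception:
        # Pillow raises different exception types depending on the format,
        # so any failure is treated as "not readable" here.
        return False
```

A check like this reads every file twice, so on large datasets it may be worth running once and caching the list of rejected paths.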


### 4.7 Class Imbalance Analysis


In [None]:
if class_stats:
    # Calculate class imbalance metrics
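    sorted_stats = sorted(class_stats.values(), key=lambda x: x['class_idx'])
    counts = [stats['count'] for stats in sorted_stats]
    # Derive labels from each entry's class index so they stay aligned with counts
    labels = [CLASS_NAMES_SHORT.get(stats['class_idx'], 'Unknown') for stats in sorted_stats]
    
    # Calculate imbalance ratio (guard against an empty class)
    max_count = max(counts)
    min_count = min(counts)
    imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')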
    
    print("=" * 60)
    print("CLASS IMBALANCE ANALYSIS")
    print("=" * 60)
    print(f"\nImbalance Ratio (Max/Min): {imbalance_ratio:.2f}")
    print(f"Most Frequent Class: {labels[counts.index(max_count)]} ({max_count:,} images)")
    print(f"Least Frequent Class: {labels[counts.index(min_count)]} ({min_count:,} images)")
    
    # Calculate class weights (for balancing during training)
    class_weights_array = compute_class_weight('balanced', 
                                               classes=np.arange(len(counts)),
                                               y=np.repeat(np.arange(len(counts)), counts))
    class_weights = {i: weight for i, weight in enumerate(class_weights_array)}
    
    print("\nRecommended Class Weights (for training):")
    print("-" * 60)
    for idx, weight in class_weights.items():
        print(f"Class {idx} ({labels[idx]}): {weight:.4f}")
    
    # Visualize imbalance
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Count comparison
    colors = ['red' if count < max_count * 0.5 else 'orange' if count < max_count * 0.75 else 'green' 
              for count in counts]
    bars = axes[0].bar(labels, counts, color=colors, edgecolor='black', linewidth=1.5)
    axes[0].axhline(max_count, color='green', linestyle='--', alpha=0.5, label='Max count')
    axes[0].axhline(max_count * 0.5, color='orange', linestyle='--', alpha=0.5, label='50% of max')
    axes[0].set_title('Class Distribution (Imbalance Visualization)', fontweight='bold')
    axes[0].set_xlabel('Disease Class')
    axes[0].set_ylabel('Number of Images')
    axes[0].legend()
    axes[0].grid(axis='y', alpha=0.3)
    
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height,
                    f'{count:,}',
                    ha='center', va='bottom', fontweight='bold', fontsize=9)
    
    # Class weights visualization
    weight_values = [class_weights[i] for i in range(len(class_weights))]
    axes[1].bar(labels, weight_values, color='steelblue', edgecolor='black', linewidth=1.5)
    axes[1].set_title('Recommended Class Weights', fontweight='bold')
    axes[1].set_xlabel('Disease Class')
    axes[1].set_ylabel('Class Weight')
    axes[1].grid(axis='y', alpha=0.3)
    
    for idx, weight in enumerate(weight_values):
        axes[1].text(idx, weight, f'{weight:.2f}',
                    ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
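
The `'balanced'` option of `compute_class_weight` assigns each class the weight `n_samples / (n_classes * count)`, so rarer classes receive proportionally larger weights. The sketch below reproduces that formula in plain Python. The helper name `balanced_class_weights` is hypothetical, and the counts are made-up round numbers chosen for easy arithmetic, not the dataset's actual class sizes.

```python
from typing import Dict, Sequence


def balanced_class_weights(counts: Sequence[int]) -> Dict[int, float]:
    """Weight for class i = n_samples / (n_classes * counts[i])."""
    if any(c <= 0 for c in counts):
        raise ValueError("every class needs at least one sample")
    n_samples = sum(counts)
    n_classes = len(counts)
    return {i: n_samples / (n_classes * c) for i, c in enumerate(counts)}


# Illustrative counts only (not measured from the cassava dataset).
example_counts = [100, 200, 200, 1000, 500]
weights = balanced_class_weights(example_counts)
for idx, w in weights.items():
    print(f"class {idx}: {w:.2f}")
```

With these counts the rarest class (100 images) gets weight 4.0 and the largest (1,000 images) gets 0.4, a tenfold difference that mirrors the tenfold difference in frequency.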


<a id="knowledge-engineering"></a>
## 5. Knowledge Engineering Setup

### 5.1 Agriculture Ontology Construction

We will create an ontology that connects:
- **Crops** (Cassava)
- **Diseases** (four cassava diseases, plus a healthy class)
- **Symptoms** (visual and physical symptoms)
- **Pests** (related pests)
- **Treatments** (recommended treatments and preventive measures)


In [None]:
# Initialize ontology using RDFLib
from rdflib import Graph, Namespace, Literal, URIRef
from rdflib.namespace import RDF, RDFS, OWL

# Create a new RDF graph
kg = Graph()

# Define namespaces
AGRO = Namespace("http://www.agro-ontology.org/#")
EX = Namespace("http://example.org/agriculture/")

# Bind namespaces
kg.bind("agro", AGRO)
kg.bind("ex", EX)
kg.bind("owl", OWL)
kg.bind("rdfs", RDFS)

print("✓ Knowledge Graph initialized")
print(f"  Base namespace: {AGRO}")
print(f"  Example namespace: {EX}")


In [None]:
# Define classes in the ontology
cassava = AGRO.Cassava
disease = AGRO.Disease
symptom = AGRO.Symptom
pest = AGRO.Pest
treatment = AGRO.Treatment
prevention = AGRO.Prevention

# Add class definitions
kg.add((cassava, RDF.type, OWL.Class))
kg.add((disease, RDF.type, OWL.Class))
kg.add((symptom, RDF.type, OWL.Class))
kg.add((pest, RDF.type, OWL.Class))
kg.add((treatment, RDF.type, OWL.Class))
kg.add((prevention, RDF.type, OWL.Class))

# Define disease instances
diseases = {
    'CBB': AGRO.CassavaBacterialBlight,
    'CBSD': AGRO.CassavaBrownStreakDisease,
    'CGM': AGRO.CassavaGreenMottle,
    'CMD': AGRO.CassavaMosaicDisease,
    'Healthy': AGRO.HealthyCassava
}

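for name, uri in diseases.items():
    kg.add((uri, RDF.type, disease))
    kg.add((uri, RDFS.label, Literal(name)))
    # "Healthy" is stored alongside the diseases for convenience, but it is not a disease
    comment = "Healthy cassava (no disease)" if name == 'Healthy' else f"Cassava disease: {name}"
    kg.add((uri, RDFS.comment, Literal(comment)))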

print("✓ Ontology classes and disease instances created")
print(f"  Total triples: {len(kg)}")


In [None]:
# Define symptoms for each disease
symptoms_db = {
    'CBB': [
        'Water-soaked_lesions',
        'Angular_leaf_spots',
        'Yellowing_leaves',
        'Leaf_wilting',
        'Black_stems'
    ],
    'CBSD': [
        'Brown_streaks_on_stems',
        'Yellowing_between_veins',
        'Chlorotic_mottling',
        'Root_necrosis',
        'Stunted_growth'
    ],
    'CGM': [
        'Green_mottling',
        'Irregular_leaf_patterns',
        'Reduced_photosynthesis',
        'Mild_yellowing',
        'Distorted_leaves'
    ],
    'CMD': [
        'Mosaic_patterns',
        'Leaf_distortion',
        'Reduced_leaf_size',
        'Yellow_green_mottling',
        'Severe_stunting'
    ],
    'Healthy': [
        'Uniform_green_color',
        'No_lesions',
        'Normal_growth',
        'Healthy_roots',
        'Proper_leaf_shape'
    ]
}

# Add symptoms to ontology
for disease_code, symptom_list in symptoms_db.items():
    disease_uri = diseases[disease_code]
    for symptom_name in symptom_list:
        symptom_uri = AGRO[symptom_name.replace('_', '')]
        kg.add((symptom_uri, RDF.type, symptom))
        kg.add((symptom_uri, RDFS.label, Literal(symptom_name.replace('_', ' '))))
        kg.add((disease_uri, AGRO.hasSymptom, symptom_uri))

print("✓ Symptoms added to ontology")
print(f"  Total triples: {len(kg)}")
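
One simple way to reason over the symptom lists is to score each disease by the share of its known symptoms that a farmer reports. The sketch below is an assumed scoring rule, not the project's rule base: the function name `rank_by_symptom_overlap` is hypothetical, and only two diseases are included to keep the example short.

```python
from typing import Dict, Iterable, List, Set, Tuple

# Subset of the symptom lists defined above (two diseases only).
SYMPTOMS: Dict[str, Set[str]] = {
    "CBB": {"Water-soaked_lesions", "Angular_leaf_spots", "Yellowing_leaves",
            "Leaf_wilting", "Black_stems"},
    "CMD": {"Mosaic_patterns", "Leaf_distortion", "Reduced_leaf_size",
            "Yellow_green_mottling", "Severe_stunting"},
}


def rank_by_symptom_overlap(
    observed: Iterable[str], knowledge: Dict[str, Set[str]] = SYMPTOMS
) -> List[Tuple[str, float]]:
    """Rank diseases by the fraction of their known symptoms that were observed."""
    observed_set = set(observed)
    scores = []
    for disease_code, known in knowledge.items():
        score = len(observed_set & known) / len(known) if known else 0.0
        scores.append((disease_code, score))
    # Highest score first; ties are broken alphabetically for a stable order.
    return sorted(scores, key=lambda item: (-item[1], item[0]))


ranking = rank_by_symptom_overlap(
    ["Mosaic_patterns", "Leaf_distortion", "Yellowing_leaves"]
)
print(ranking)
```

A score like this could later be combined with the CNN's class probabilities, for example to flag cases where the image prediction and the reported symptoms disagree.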


In [None]:
# Define treatments for each disease
treatments_db = {
    'CBB': [
        ('Copper_based_fungicides', 'Apply copper-based fungicides every 7-10 days'),
        ('Remove_infected_plants', 'Remove and destroy infected plants immediately'),
        ('Crop_rotation', 'Practice crop rotation with non-host crops'),
        ('Resistant_varieties', 'Use disease-resistant cassava varieties'),
        ('Sanitation', 'Maintain field sanitation and remove plant debris')
    ],
    'CBSD': [
        ('Virus_free_planting_material', 'Use certified virus-free planting material'),
        ('Remove_infected_plants', 'Remove and destroy infected plants'),
        ('Vector_control', 'Control whitefly vectors using insecticides'),
        ('Resistant_varieties', 'Plant CBSD-resistant cassava varieties'),
        ('Early_detection', 'Monitor fields regularly for early detection')
    ],
    'CGM': [
        ('Remove_infected_plants', 'Remove and destroy infected plants'),
        ('Vector_control', 'Control insect vectors'),
        ('Sanitation', 'Maintain clean field conditions'),
        ('Resistant_varieties', 'Use resistant cassava varieties'),
        ('Proper_spacing', 'Ensure proper plant spacing for air circulation')
    ],
    'CMD': [
        ('Virus_free_planting_material', 'Use certified virus-free cassava cuttings'),
        ('Remove_infected_plants', 'Remove and destroy infected plants immediately'),
        ('Vector_control', 'Control whitefly populations'),
        ('Resistant_varieties', 'Plant CMD-resistant varieties'),
        ('Early_harvesting', 'Harvest early if infection is detected')
    ],
    'Healthy': [
        ('Maintain_health', 'Continue good agricultural practices'),
        ('Regular_monitoring', 'Monitor for early signs of disease'),
        ('Preventive_measures', 'Apply preventive measures'),
        ('Proper_nutrition', 'Ensure adequate soil nutrition'),
        ('Water_management', 'Maintain proper irrigation')
    ]
}

# Add treatments to ontology
for disease_code, treatment_list in treatments_db.items():
    disease_uri = diseases[disease_code]
    for treatment_name, treatment_desc in treatment_list:
        treatment_uri = AGRO[treatment_name.replace('_', '')]
        kg.add((treatment_uri, RDF.type, treatment))
        kg.add((treatment_uri, RDFS.label, Literal(treatment_name.replace('_', ' '))))
        kg.add((treatment_uri, RDFS.comment, Literal(treatment_desc)))
        kg.add((disease_uri, AGRO.hasTreatment, treatment_uri))

print("✓ Treatments added to ontology")
print(f"  Total triples: {len(kg)}")


In [None]:
# Define pests related to cassava diseases
pests_db = {
    'Whitefly': ['CMD', 'CBSD'],
    'Aphids': ['CMD'],
    'Mealybugs': ['CMD'],
    'Thrips': ['CGM'],
    'Mites': ['CGM']
}

# Add pests to ontology
for pest_name, related_diseases in pests_db.items():
    pest_uri = AGRO[pest_name.replace(' ', '')]
    kg.add((pest_uri, RDF.type, pest))
    kg.add((pest_uri, RDFS.label, Literal(pest_name)))
    for disease_code in related_diseases:
        disease_uri = diseases[disease_code]
        kg.add((disease_uri, AGRO.vectorPest, pest_uri))

print("✓ Pests added to ontology")
print(f"  Total triples: {len(kg)}")
print(f"\nOntology Summary:")
print(f"  - Diseases: {len(diseases)}")
print(f"  - Symptoms: {sum(len(s) for s in symptoms_db.values())}")
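print(f"  - Treatments: {len({name for t in treatments_db.values() for name, _ in t})} distinct "
      f"({sum(len(t) for t in treatments_db.values())} disease-treatment links)")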
print(f"  - Pests: {len(pests_db)}")
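
Several treatments are shared between diseases, and because each URI is derived from the treatment name alone, a shared name maps to a single node in the graph. The sketch below separates the number of disease-treatment links from the number of distinct treatment names, and lists names that carry more than one description. It runs on a trimmed, illustrative copy of the treatment table (`sample_treatments`, two diseases only), and `treatment_counts` is a hypothetical helper.

```python
# Trimmed, illustrative copy of the treatment table (two diseases only).
sample_treatments = {
    'CBSD': [
        ('Virus_free_planting_material', 'Use certified virus-free planting material'),
        ('Remove_infected_plants', 'Remove and destroy infected plants'),
        ('Vector_control', 'Control whitefly vectors using insecticides'),
    ],
    'CMD': [
        ('Virus_free_planting_material', 'Use certified virus-free cassava cuttings'),
        ('Remove_infected_plants', 'Remove and destroy infected plants immediately'),
        ('Early_harvesting', 'Harvest early if infection is detected'),
    ],
}

def treatment_counts(table):
    """Return (links, distinct_names, names_with_more_than_one_description)."""
    links = sum(len(items) for items in table.values())
    descriptions = {}
    for items in table.values():
        for name, desc in items:
            descriptions.setdefault(name, set()).add(desc)
    conflicting = sorted(n for n, d in descriptions.items() if len(d) > 1)
    return links, len(descriptions), conflicting

links, distinct, conflicting = treatment_counts(sample_treatments)
print(f"links={links}, distinct={distinct}")
print("names with more than one description:", conflicting)
```

In the graph, a name with two descriptions ends up as one treatment node with two `RDFS.comment` literals, so a query for a treatment's description can return either text regardless of the disease being asked about.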


### 5.2 Expert Rules Encoding (Minimum 20 Rules)

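The rule base below encodes 25 IF-THEN rules covering diagnosis, treatment, vector control, and prevention for the four diseases and the healthy class. Each rule lists the conditions that must all hold, a single conclusion, and a confidence value between 0 and 1.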


In [None]:
# Expert Rules Base
# Rules are structured as: IF conditions THEN conclusion

class ExpertRules:
    """Expert rule base for cassava disease diagnosis and treatment"""
    
    def __init__(self):
        self.rules = []
        self.initialize_rules()
    
    def initialize_rules(self):
        """Initialize at least 20 expert rules"""
        
        # Rule 1-5: CBB (Cassava Bacterial Blight) Rules
        self.rules.append({
            'id': 1,
            'name': 'CBB_Rule_1',
            'conditions': ['Water_soaked_lesions', 'Angular_leaf_spots'],
            'conclusion': 'CBB',
            'confidence': 0.85,
            'description': 'If water-soaked lesions AND angular leaf spots → Cassava Bacterial Blight'
        })
        
        self.rules.append({
            'id': 2,
            'name': 'CBB_Rule_2',
            'conditions': ['Yellowing_leaves', 'Black_stems'],
            'conclusion': 'CBB',
            'confidence': 0.80,
            'description': 'If yellowing leaves AND black stems → Cassava Bacterial Blight'
        })
        
        self.rules.append({
            'id': 3,
            'name': 'CBB_Rule_3',
            'conditions': ['Leaf_wilting', 'Water_soaked_lesions'],
            'conclusion': 'CBB',
            'confidence': 0.75,
            'description': 'If leaf wilting AND water-soaked lesions → Cassava Bacterial Blight'
        })
        
        self.rules.append({
            'id': 4,
            'name': 'CBB_Treatment_Rule',
            'conditions': ['CBB_diagnosed'],
            'conclusion': 'Apply_copper_fungicides',
            'confidence': 0.90,
            'description': 'If CBB diagnosed → Apply copper-based fungicides'
        })
        
        self.rules.append({
            'id': 5,
            'name': 'CBB_Severity_Rule',
            'conditions': ['CBB_diagnosed', 'High_infection_rate'],
            'conclusion': 'Remove_infected_plants',
            'confidence': 0.95,
            'description': 'If CBB AND high infection rate → Remove infected plants immediately'
        })
        
        # Rule 6-10: CBSD (Cassava Brown Streak Disease) Rules
        self.rules.append({
            'id': 6,
            'name': 'CBSD_Rule_1',
            'conditions': ['Brown_streaks_on_stems', 'Yellowing_between_veins'],
            'conclusion': 'CBSD',
            'confidence': 0.88,
            'description': 'If brown streaks on stems AND yellowing between veins → CBSD'
        })
        
        self.rules.append({
            'id': 7,
            'name': 'CBSD_Rule_2',
            'conditions': ['Root_necrosis', 'Chlorotic_mottling'],
            'conclusion': 'CBSD',
            'confidence': 0.85,
            'description': 'If root necrosis AND chlorotic mottling → CBSD'
        })
        
        self.rules.append({
            'id': 8,
            'name': 'CBSD_Rule_3',
            'conditions': ['Stunted_growth', 'Brown_streaks_on_stems'],
            'conclusion': 'CBSD',
            'confidence': 0.82,
            'description': 'If stunted growth AND brown streaks on stems → CBSD'
        })
        
        self.rules.append({
            'id': 9,
            'name': 'CBSD_Treatment_Rule',
            'conditions': ['CBSD_diagnosed'],
            'conclusion': 'Use_virus_free_material',
            'confidence': 0.92,
            'description': 'If CBSD diagnosed → Use virus-free planting material'
        })
        
        self.rules.append({
            'id': 10,
            'name': 'CBSD_Vector_Rule',
            'conditions': ['CBSD_diagnosed', 'Whitefly_present'],
            'conclusion': 'Control_whiteflies',
            'confidence': 0.90,
            'description': 'If CBSD AND whiteflies present → Control whitefly vectors'
        })
        
        # Rule 11-15: CGM (Cassava Green Mottle) Rules
        self.rules.append({
            'id': 11,
            'name': 'CGM_Rule_1',
            'conditions': ['Green_mottling', 'Irregular_leaf_patterns'],
            'conclusion': 'CGM',
            'confidence': 0.80,
            'description': 'If green mottling AND irregular leaf patterns → CGM'
        })
        
        self.rules.append({
            'id': 12,
            'name': 'CGM_Rule_2',
            'conditions': ['Distorted_leaves', 'Mild_yellowing'],
            'conclusion': 'CGM',
            'confidence': 0.75,
            'description': 'If distorted leaves AND mild yellowing → CGM'
        })
        
        self.rules.append({
            'id': 13,
            'name': 'CGM_Treatment_Rule',
            'conditions': ['CGM_diagnosed'],
            'conclusion': 'Remove_infected_plants',
            'confidence': 0.85,
            'description': 'If CGM diagnosed → Remove infected plants'
        })
        
        self.rules.append({
            'id': 14,
            'name': 'CGM_Vector_Rule',
            'conditions': ['CGM_diagnosed', 'Thrips_present'],
            'conclusion': 'Control_thrips',
            'confidence': 0.88,
            'description': 'If CGM AND thrips present → Control thrips'
        })
        
        self.rules.append({
            'id': 15,
            'name': 'CGM_Prevention_Rule',
            'conditions': ['CGM_risk_high'],
            'conclusion': 'Apply_preventive_spacing',
            'confidence': 0.80,
            'description': 'If CGM risk high → Ensure proper plant spacing'
        })
        
        # Rule 16-20: CMD (Cassava Mosaic Disease) Rules
        self.rules.append({
            'id': 16,
            'name': 'CMD_Rule_1',
            'conditions': ['Mosaic_patterns', 'Leaf_distortion'],
            'conclusion': 'CMD',
            'confidence': 0.90,
            'description': 'If mosaic patterns AND leaf distortion → CMD'
        })
        
        self.rules.append({
            'id': 17,
            'name': 'CMD_Rule_2',
            'conditions': ['Yellow_green_mottling', 'Severe_stunting'],
            'conclusion': 'CMD',
            'confidence': 0.87,
            'description': 'If yellow-green mottling AND severe stunting → CMD'
        })
        
        self.rules.append({
            'id': 18,
            'name': 'CMD_Rule_3',
            'conditions': ['Reduced_leaf_size', 'Mosaic_patterns'],
            'conclusion': 'CMD',
            'confidence': 0.83,
            'description': 'If reduced leaf size AND mosaic patterns → CMD'
        })
        
        self.rules.append({
            'id': 19,
            'name': 'CMD_Treatment_Rule',
            'conditions': ['CMD_diagnosed'],
            'conclusion': 'Use_virus_free_cuttings',
            'confidence': 0.93,
            'description': 'If CMD diagnosed → Use virus-free cassava cuttings'
        })
        
        self.rules.append({
            'id': 20,
            'name': 'CMD_Vector_Rule',
            'conditions': ['CMD_diagnosed', 'Whitefly_present'],
            'conclusion': 'Control_whitefly_populations',
            'confidence': 0.91,
            'description': 'If CMD AND whiteflies present → Control whitefly populations'
        })
        
        # Additional rules (21-25) for treatment and prevention
        self.rules.append({
            'id': 21,
            'name': 'General_Prevention_Rule_1',
            'conditions': ['Early_season'],
            'conclusion': 'Use_certified_planting_material',
            'confidence': 0.85,
            'description': 'If early season → Use certified planting material'
        })
        
        self.rules.append({
            'id': 22,
            'name': 'General_Prevention_Rule_2',
            'conditions': ['High_whitefly_population'],
            'conclusion': 'Apply_preventive_insecticides',
            'confidence': 0.88,
            'description': 'If high whitefly population → Apply preventive insecticides'
        })
        
        self.rules.append({
            'id': 23,
            'name': 'Severity_Escalation_Rule',
            'conditions': ['Disease_diagnosed', 'Infection_rate_above_50'],
            'conclusion': 'Immediate_field_management',
            'confidence': 0.95,
            'description': 'If disease diagnosed AND infection rate > 50% → Immediate field management required'
        })
        
        self.rules.append({
            'id': 24,
            'name': 'Healthy_Maintenance_Rule',
            'conditions': ['Uniform_green_color', 'No_lesions', 'Normal_growth'],
            'conclusion': 'Healthy',
            'confidence': 0.95,
            'description': 'If uniform green, no lesions, normal growth → Healthy'
        })
        
        self.rules.append({
            'id': 25,
            'name': 'Integrated_Management_Rule',
            'conditions': ['Multiple_diseases_present'],
            'conclusion': 'Apply_integrated_disease_management',
            'confidence': 0.90,
            'description': 'If multiple diseases present → Apply integrated disease management'
        })
    
    def get_rules(self):
        return self.rules
    
    def find_matching_rules(self, conditions):
        """Find rules that match given conditions"""
        matching_rules = []
        for rule in self.rules:
            # Check if all rule conditions are met
            if all(cond in conditions for cond in rule['conditions']):
                matching_rules.append(rule)
        return matching_rules
    
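    def forward_chaining(self, initial_conditions):
        """Forward chaining inference engine.

        Fires every rule whose conditions are all present in the fact set,
        adds its conclusion as a new fact, and repeats until nothing changes.
        """
        disease_codes = {'CBB', 'CBSD', 'CGM', 'CMD'}
        facts = set(initial_conditions)
        conclusions = []
        applied_rules = set()
        
        changed = True
        while changed:
            changed = False
            for rule in self.rules:
                if rule['id'] in applied_rules:
                    continue
                # Check if all conditions are satisfied
                if all(cond in facts for cond in rule['conditions']):
                    # Add conclusion to facts
                    conclusion = rule['conclusion']
                    facts.add(conclusion)
                    # Diagnostic rules conclude a bare disease code, while the
                    # treatment and vector rules wait for '<code>_diagnosed',
                    # so derive those facts to let the chain continue.
                    if conclusion in disease_codes:
                        facts.add(f"{conclusion}_diagnosed")
                        facts.add('Disease_diagnosed')
                    conclusions.append({
                        'rule': rule,
                        'conclusion': conclusion,
                        'confidence': rule['confidence']
                    })
                    applied_rules.add(rule['id'])
                    changed = True
        
        return conclusions, facts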

# Initialize expert rules
expert_rules = ExpertRules()
rules = expert_rules.get_rules()

print(f"✓ Expert Rules Base initialized")
print(f"  Total rules: {len(rules)}")
print(f"\nRule Summary:")
print("-" * 70)
for rule in rules[:10]:  # Show first 10 rules
    print(f"Rule {rule['id']}: {rule['description']}")
print(f"... and {len(rules) - 10} more rules")
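
To make the inference loop concrete, here is a minimal standalone sketch of forward chaining over a two-rule toy base. It is not the `ExpertRules` class itself: `forward_chain` and `toy_rules` are hypothetical, and the toy diagnostic rule concludes `CMD_diagnosed` directly, which is a simplification made for this sketch.

```python
def forward_chain(rules, initial_facts):
    """Fire rules until no new rule applies; return (fired rule names, final facts)."""
    facts = set(initial_facts)
    fired = []
    changed = True
    while changed:
        changed = False
        for rule in rules:
            if rule['name'] in fired:
                continue
            if all(c in facts for c in rule['conditions']):
                facts.add(rule['conclusion'])
                fired.append(rule['name'])
                changed = True
    return fired, facts

toy_rules = [
    # Listed before the rule that enables it, so it can only fire on the second pass.
    {'name': 'treat', 'conditions': ['CMD_diagnosed'],
     'conclusion': 'Use_virus_free_cuttings'},
    {'name': 'diagnose', 'conditions': ['Mosaic_patterns', 'Leaf_distortion'],
     'conclusion': 'CMD_diagnosed'},
]

fired, facts = forward_chain(toy_rules, ['Mosaic_patterns', 'Leaf_distortion'])
print(fired)
print(sorted(facts))
```

The treatment rule fires only after the diagnostic rule has added its conclusion to the fact set, which is why the loop keeps iterating until a full pass produces no change. Chaining works only when a rule's conclusion string is exactly the string another rule uses as a condition.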


In [None]:
# Save ontology to file
ontology_file = "agriculture_ontology.rdf"
kg.serialize(destination=ontology_file, format='xml')
print(f"✓ Ontology saved to: {ontology_file}")

# Save rules to JSON
rules_file = "expert_rules.json"
with open(rules_file, 'w') as f:
    json.dump(rules, f, indent=2)
print(f"✓ Expert rules saved to: {rules_file}")

# Display summary
print(f"\n{'='*70}")
print("KNOWLEDGE BASE SUMMARY")
print(f"{'='*70}")
print(f"Ontology Triples: {len(kg):,}")
print(f"Expert Rules: {len(rules)}")
print(f"Diseases: {len(diseases)}")
print(f"Total Symptoms: {sum(len(s) for s in symptoms_db.values())}")
print(f"Total Treatments: {sum(len(t) for t in treatments_db.values())}")
print(f"{'='*70}")
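
Once the rules are on disk, it is worth confirming that they load back unchanged before another component depends on the file. The following is a small sketch of such a round-trip check; it writes a single hypothetical rule (`sample_rules`) to a temporary directory rather than touching `expert_rules.json`.

```python
import json
import os
import tempfile

# One illustrative rule in the same dictionary layout as the rule base.
sample_rules = [{
    'id': 1,
    'name': 'Sample_Rule',
    'conditions': ['Mosaic_patterns', 'Leaf_distortion'],
    'conclusion': 'CMD',
    'confidence': 0.90,
    'description': 'If mosaic patterns AND leaf distortion → CMD',
}]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'sample_rules.json')
    # Write and read with an explicit encoding so the result is platform-independent.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(sample_rules, f, indent=2)
    with open(path, encoding='utf-8') as f:
        reloaded = json.load(f)

round_trip_ok = reloaded == sample_rules
print("Round trip preserved the rules:", round_trip_ok)
```

The rule dictionaries contain only strings, numbers, and lists, so they map directly onto JSON types and need no custom encoder.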


### 5.3 Reasoning Engine Implementation


In [None]:
class HybridReasoningEngine:
    """Hybrid reasoning engine combining KE and ML"""
    
    def __init__(self, expert_rules, disease_treatments):
        self.expert_rules = expert_rules
        self.disease_treatments = disease_treatments
    
    def reason_from_symptoms(self, observed_symptoms):
        """Reason about disease from observed symptoms"""
        # Use forward chaining
        conclusions, facts = self.expert_rules.forward_chaining(observed_symptoms)
        
        # Extract disease predictions
        disease_predictions = {}
        for conc in conclusions:
            if conc['conclusion'] in ['CBB', 'CBSD', 'CGM', 'CMD', 'Healthy']:
                disease_predictions[conc['conclusion']] = conc['confidence']
        
        return disease_predictions, conclusions
    
    def get_treatment_recommendations(self, disease):
        """Get treatment recommendations for a diagnosed disease"""
        if disease in self.disease_treatments:
            return self.disease_treatments[disease]
        return []
    
    def hybrid_reasoning(self, ml_prediction, ml_confidence, observed_symptoms=None):
        """Combine ML prediction with KE reasoning"""
        
        # KE reasoning from symptoms (if provided)
        ke_predictions = {}
        if observed_symptoms:
            ke_predictions, ke_conclusions = self.reason_from_symptoms(observed_symptoms)
        
        # Combine ML and KE predictions
        final_prediction = ml_prediction
        final_confidence = ml_confidence
        
        # If KE also predicts the same disease, blend the two confidences.
        # The blend exceeds the ML confidence only when the KE confidence is higher.
        if ml_prediction in ke_predictions:
            ke_conf = ke_predictions[ml_prediction]
            # Weighted combination: 70% ML, 30% KE
            final_confidence = 0.7 * ml_confidence + 0.3 * ke_conf
            reasoning_type = "Hybrid (ML + KE agreement)"
        else:
            # If disagreement, prefer ML but consider KE
            if ke_predictions:
                ke_pred = max(ke_predictions, key=ke_predictions.get)
                if ke_predictions[ke_pred] > 0.8:  # Strong KE prediction
                    # Reduce ML confidence by 10% when KE disagrees strongly
                    final_confidence = ml_confidence * 0.9
                    reasoning_type = f"ML (KE suggests {ke_pred})"
                else:
                    reasoning_type = "ML (primary)"
            else:
                reasoning_type = "ML (no KE input)"
        
        # Get treatments
        treatments = self.get_treatment_recommendations(ml_prediction)
        
        return {
            'disease': final_prediction,
            'confidence': final_confidence,
            'reasoning_type': reasoning_type,
            'ml_prediction': ml_prediction,
            'ml_confidence': ml_confidence,
            'ke_predictions': ke_predictions,
            'treatments': treatments
        }

# Initialize reasoning engine
reasoning_engine = HybridReasoningEngine(expert_rules, treatments_db)
print("✓ Hybrid Reasoning Engine initialized")
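
The confidence arithmetic in `hybrid_reasoning` can be checked by hand. The sketch below isolates it in a hypothetical helper, `combine_confidence`, whose `agree` flag stands for "the KE predictions include the ML class" and whose `ke_conf` is the KE confidence used in that branch (the matching class on agreement, the strongest KE class otherwise). The input values are made up for illustration.

```python
def combine_confidence(ml_conf, ke_conf=None, agree=False,
                       ml_weight=0.7, strong_ke=0.8, penalty=0.9):
    """Mirror the confidence arithmetic of the hybrid step for one prediction."""
    if ke_conf is None:
        return ml_conf                      # no KE input: ML confidence unchanged
    if agree:
        # Weighted blend: 70% ML, 30% KE by default
        return ml_weight * ml_conf + (1 - ml_weight) * ke_conf
    if ke_conf > strong_ke:
        return ml_conf * penalty            # strong KE disagreement: 10% reduction
    return ml_conf                          # weak KE disagreement: unchanged

print(round(combine_confidence(0.60, 0.90, agree=True), 3))   # agreement, KE more confident
print(round(combine_confidence(0.95, 0.80, agree=True), 3))   # agreement, ML more confident
print(round(combine_confidence(0.95, 0.88, agree=False), 3))  # strong disagreement
```

With ML at 0.60 and KE at 0.90 the blend is 0.69, higher than ML alone. With ML at 0.95 and KE at 0.80 the blend is 0.905, lower than ML alone even though both sources agree. A strong disagreement scales 0.95 down to 0.855.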


In [None]:
# Prepare data for CNN training
def prepare_dataset(base_path, img_size=(224, 224), test_split=0.2, val_split=0.2):
    """Prepare train/validation/test splits"""
    
    if not Path(base_path).exists():
        print(f"⚠ Path {base_path} does not exist")
        # Six values, matching the six-way unpacking at the call site
        return None, None, None, None, None, None
    
    all_image_paths = []
    all_labels = []
    
    folder_mapping = {
        'Cassava___bacterial_blight': 0,
        'Cassava___brown_streak_disease': 1,
        'Cassava___green_mottle': 2,
        'Cassava___mosaic_disease': 3,
        'Cassava___healthy': 4
    }
    
    # Collect all images
    for folder_name, class_idx in folder_mapping.items():
        folder_path = Path(base_path) / folder_name
        if folder_path.exists():
            for ext in ['.jpg', '.jpeg', '.png']:
                images = list(folder_path.glob(f"*{ext}"))
                for img_path in images:
                    all_image_paths.append(str(img_path))
                    all_labels.append(class_idx)
    
    all_image_paths = np.array(all_image_paths)
    all_labels = np.array(all_labels)
    
    print(f"Total images collected: {len(all_image_paths):,}")
    
    # Split: train -> (1-test_split), test -> test_split
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        all_image_paths, all_labels, 
        test_size=test_split, 
        random_state=42, 
        stratify=all_labels
    )
    
    # Further split the remaining data into train and validation.
    # val_split is a fraction of that remainder, not of the full dataset.
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_paths, train_labels,
        test_size=val_split,
        random_state=42,
        stratify=train_labels
    )
    
    print(f"\nData Split:")
    print(f"  Training:   {len(train_paths):,} images ({len(train_paths)/len(all_image_paths)*100:.1f}%)")
    print(f"  Validation: {len(val_paths):,} images ({len(val_paths)/len(all_image_paths)*100:.1f}%)")
    print(f"  Test:       {len(test_paths):,} images ({len(test_paths)/len(all_image_paths)*100:.1f}%)")
    
    return train_paths, val_paths, test_paths, train_labels, val_labels, test_labels

# Prepare dataset
if class_stats:
    train_paths, val_paths, test_paths, train_labels, val_labels, test_labels = prepare_dataset(
        BASE_DATA_PATH, img_size=(224, 224), test_split=0.2, val_split=0.2
    )
else:
    print("⚠ Cannot prepare dataset - class_stats not available")
    train_paths = val_paths = test_paths = None
    train_labels = val_labels = test_labels = None
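
Because the split is done in two stages, `val_split` applies to what is left after the test set has been removed. A short sketch of the resulting fractions of the full dataset, using a hypothetical helper `effective_split`:

```python
import math

def effective_split(test_split=0.2, val_split=0.2):
    """Return (train, val, test) fractions of the full dataset for a two-stage split."""
    remaining = 1.0 - test_split      # share left after the test set is held out
    val = remaining * val_split       # validation is carved out of the remainder
    train = remaining - val
    return train, val, test_split

train_frac, val_frac, test_frac = effective_split()
print(f"train={train_frac:.2f}, val={val_frac:.2f}, test={test_frac:.2f}")
```

With the defaults this gives roughly 64% training, 16% validation, and 20% test. Image counts are rounded to whole numbers during splitting, so the percentages printed by `prepare_dataset` can differ slightly from these values.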


In [None]:
# Create TensorFlow datasets with data augmentation
def load_and_preprocess_image(image_path, label, img_size=(224, 224), augment=False):
    """Load and preprocess image"""
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, img_size)
    image = tf.cast(image, tf.float32) / 255.0
    
    if augment:
        # Data augmentation
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.random_brightness(image, max_delta=0.2)
        image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
        image = tf.image.random_saturation(image, lower=0.8, upper=1.2)
        image = tf.clip_by_value(image, 0.0, 1.0)
    
    return image, label

def create_tf_dataset(image_paths, labels, img_size=(224, 224), batch_size=32, 
                     augment=False, shuffle=True, buffer_size=1000):
    """Create TensorFlow dataset"""
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    
    if shuffle:
        dataset = dataset.shuffle(buffer_size=buffer_size)
    
    dataset = dataset.map(
        lambda x, y: load_and_preprocess_image(x, y, img_size, augment),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset

# Create datasets
BATCH_SIZE = 32
IMG_SIZE = (224, 224)

if train_paths is not None:
    train_ds = create_tf_dataset(train_paths, train_labels, IMG_SIZE, BATCH_SIZE, 
                                augment=True, shuffle=True)
    val_ds = create_tf_dataset(val_paths, val_labels, IMG_SIZE, BATCH_SIZE, 
                              augment=False, shuffle=False)
    test_ds = create_tf_dataset(test_paths, test_labels, IMG_SIZE, BATCH_SIZE, 
                               augment=False, shuffle=False)
    
    print("✓ TensorFlow datasets created")
    print(f"  Training batches: {len(train_ds)}")
    print(f"  Validation batches: {len(val_ds)}")
    print(f"  Test batches: {len(test_ds)}")
else:
    train_ds = val_ds = test_ds = None
    print("⚠ Cannot create datasets")
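
The batch counts printed above can be sanity-checked against the split sizes. This sketch assumes the final partial batch is kept, which is what `Dataset.batch` does when `drop_remainder` is left at its default of `False`. The helper `expected_batches` and the image counts are hypothetical.

```python
import math

def expected_batches(num_images, batch_size=32):
    """Number of batches when the final partial batch is kept."""
    return math.ceil(num_images / batch_size)

print(expected_batches(1000))  # 31 full batches plus one batch of 8 images
print(expected_batches(1024))  # divides evenly into full batches
```

If the reported number of batches does not match this calculation for the actual split sizes, the dataset was probably built from a different list of paths than intended.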


### 6.2 CNN Model Architecture - Baseline Model
