# Data Exploration - PlantVillage Dataset

**Objective**: Understand the dataset's characteristics, quality, and patterns to inform preprocessing and modeling decisions.

**Contents**:
1. Dataset Overview
2. Class Distribution Analysis
3. Image Properties & Quality
4. Sample Visualization
5. Data Quality Checks
6. Key Findings & Recommendations

## 1. Setup and Imports

In [None]:
import os
import sys
from pathlib import Path
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from collections import Counter
import warnings

warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

print("‚úÖ Imports successful!")
print(f"Working directory: {Path.cwd()}")

## 2. Dataset Overview

In [None]:
# Define paths
raw_data_dir = Path('../data/raw/PlantVillage')
processed_data_dir = Path('../data/processed')

# Check if data exists
if raw_data_dir.exists():
    print(f"‚úÖ Raw data found at: {raw_data_dir}")
    data_dir = raw_data_dir
elif processed_data_dir.exists():
    print(f"‚úÖ Processed data found at: {processed_data_dir}")
    data_dir = processed_data_dir / 'train'  # Use train split for exploration
else:
    print("‚ùå Data directory not found!")
    data_dir = None

In [None]:
# Collect dataset information
def collect_dataset_info(data_path):
    """Collect information about images in the dataset"""
    if data_path is None:
        return None
    
    class_info = {}
    total_images = 0
    
    # Iterate through class folders
    for class_dir in sorted(data_path.iterdir()):
        if class_dir.is_dir() and not class_dir.name.startswith('.'):
            # Count images in this class
            image_files = [f for f in class_dir.iterdir() 
                          if f.suffix.lower() in ['.jpg', '.jpeg', '.png']]
            num_images = len(image_files)
            
            class_info[class_dir.name] = {
                'count': num_images,
                'path': str(class_dir)
            }
            total_images += num_images
    
    return class_info, total_images

# Collect info
class_info, total_images = collect_dataset_info(data_dir)

print(f"Total Images: {total_images:,}")
print(f"Number of Classes: {len(class_info)}")
print(f"\nClasses found:")
for i, class_name in enumerate(sorted(class_info.keys()), 1):
    print(f"  {i:2d}. {class_name}")

## 3. Class Distribution Analysis

In [None]:
# Create DataFrame for easier analysis
df_classes = pd.DataFrame([
    {'Class': class_name, 'Count': info['count']}
    for class_name, info in class_info.items()
]).sort_values('Count', ascending=False)

# Calculate statistics
df_classes['Percentage'] = (df_classes['Count'] / df_classes['Count'].sum() * 100).round(2)

print("Class Distribution Summary:")
print("="*70)
print(df_classes.to_string(index=False))
print("="*70)
print(f"\nStatistics:")
print(f"  Mean images per class: {df_classes['Count'].mean():.1f}")
print(f"  Median images per class: {df_classes['Count'].median():.1f}")
print(f"  Std deviation: {df_classes['Count'].std():.1f}")
print(f"  Min images in a class: {df_classes['Count'].min()}")
print(f"  Max images in a class: {df_classes['Count'].max()}")
print(f"  Imbalance ratio: {df_classes['Count'].max() / df_classes['Count'].min():.2f}x")

In [None]:
# Visualize class distribution
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Bar plot
ax1 = axes[0]
bars = ax1.bar(range(len(df_classes)), df_classes['Count'], color='steelblue', alpha=0.7)
ax1.set_xlabel('Class Index', fontsize=12, fontweight='bold')
ax1.set_ylabel('Number of Images', fontsize=12, fontweight='bold')
ax1.set_title('Class Distribution (Bar Chart)', fontsize=14, fontweight='bold')
ax1.axhline(df_classes['Count'].mean(), color='red', linestyle='--', 
            label=f"Mean: {df_classes['Count'].mean():.0f}", linewidth=2)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Pie chart (top 10 classes)
ax2 = axes[1]
top_n = 10
top_classes = df_classes.head(top_n)
other_count = df_classes.iloc[top_n:]['Count'].sum() if len(df_classes) > top_n else 0

if other_count > 0:
    labels = list(top_classes['Class']) + ['Others']
    sizes = list(top_classes['Count']) + [other_count]
else:
    labels = list(top_classes['Class'])
    sizes = list(top_classes['Count'])

# Truncate long labels
labels = [label[:30] + '...' if len(label) > 30 else label for label in labels]

ax2.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
ax2.set_title(f'Class Distribution (Top {top_n} Classes)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Check for class imbalance
print("\n‚ö†Ô∏è Class Imbalance Analysis:")
if df_classes['Count'].max() / df_classes['Count'].min() > 3:
    print("  WARNING: Significant class imbalance detected!")
    print("  Consider using class weights or data augmentation.")
else:
    print("  ‚úÖ Classes are relatively balanced.")

## 4. Image Properties & Quality Analysis

In [None]:
# Sample images from each class to analyze properties
def analyze_image_properties(data_path, num_samples_per_class=10):
    """Analyze image dimensions, sizes, and color properties"""
    
    image_info = []
    
    for class_name, info in class_info.items():
        class_path = Path(info['path'])
        image_files = list(class_path.glob('*.jpg')) + list(class_path.glob('*.JPG'))
        
        # Sample images (or all if less than num_samples)
        sample_files = image_files[:min(num_samples_per_class, len(image_files))]
        
        for img_path in sample_files:
            try:
                with Image.open(img_path) as img:
                    width, height = img.size
                    mode = img.mode
                    file_size = img_path.stat().st_size / 1024  # KB
                    
                    # Calculate mean color (convert to RGB first)
                    if mode != 'RGB':
                        img = img.convert('RGB')
                    img_array = np.array(img)
                    mean_r = img_array[:, :, 0].mean()
                    mean_g = img_array[:, :, 1].mean()
                    mean_b = img_array[:, :, 2].mean()
                    
                    image_info.append({
                        'class': class_name,
                        'width': width,
                        'height': height,
                        'aspect_ratio': width / height,
                        'mode': mode,
                        'file_size_kb': file_size,
                        'mean_r': mean_r,
                        'mean_g': mean_g,
                        'mean_b': mean_b
                    })
            except Exception as e:
                print(f"Error reading {img_path}: {e}")
    
    return pd.DataFrame(image_info)

print("Analyzing image properties (sampling 10 images per class)...")
df_images = analyze_image_properties(data_dir, num_samples_per_class=10)

print(f"\n‚úÖ Analyzed {len(df_images)} images")
print("\nImage Properties Summary:")
print("="*70)
print(f"  Width  - Mean: {df_images['width'].mean():.0f}, Std: {df_images['width'].std():.0f}")
print(f"  Height - Mean: {df_images['height'].mean():.0f}, Std: {df_images['height'].std():.0f}")
print(f"  Aspect Ratio - Mean: {df_images['aspect_ratio'].mean():.2f}")
print(f"  File Size (KB) - Mean: {df_images['file_size_kb'].mean():.1f}, Median: {df_images['file_size_kb'].median():.1f}")
print(f"  Color Mode: {df_images['mode'].unique()}")
print("="*70)

In [None]:
# Visualize image properties
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Image dimensions scatter plot
ax1 = axes[0, 0]
ax1.scatter(df_images['width'], df_images['height'], alpha=0.5, c='steelblue')
ax1.set_xlabel('Width (pixels)', fontsize=11, fontweight='bold')
ax1.set_ylabel('Height (pixels)', fontsize=11, fontweight='bold')
ax1.set_title('Image Dimensions Distribution', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Aspect ratio distribution
ax2 = axes[0, 1]
ax2.hist(df_images['aspect_ratio'], bins=30, color='coral', alpha=0.7, edgecolor='black')
ax2.axvline(df_images['aspect_ratio'].mean(), color='red', linestyle='--', 
            label=f"Mean: {df_images['aspect_ratio'].mean():.2f}", linewidth=2)
ax2.set_xlabel('Aspect Ratio (W/H)', fontsize=11, fontweight='bold')
ax2.set_ylabel('Frequency', fontsize=11, fontweight='bold')
ax2.set_title('Aspect Ratio Distribution', fontsize=12, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

# File size distribution
ax3 = axes[1, 0]
ax3.hist(df_images['file_size_kb'], bins=30, color='lightgreen', alpha=0.7, edgecolor='black')
ax3.axvline(df_images['file_size_kb'].mean(), color='red', linestyle='--', 
            label=f"Mean: {df_images['file_size_kb'].mean():.1f} KB", linewidth=2)
ax3.set_xlabel('File Size (KB)', fontsize=11, fontweight='bold')
ax3.set_ylabel('Frequency', fontsize=11, fontweight='bold')
ax3.set_title('File Size Distribution', fontsize=12, fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Mean color channels
ax4 = axes[1, 1]
channels = ['R', 'G', 'B']
means = [df_images['mean_r'].mean(), df_images['mean_g'].mean(), df_images['mean_b'].mean()]
colors_bar = ['#ff6b6b', '#51cf66', '#4dabf7']
bars = ax4.bar(channels, means, color=colors_bar, alpha=0.7, edgecolor='black')
ax4.set_ylabel('Mean Pixel Value (0-255)', fontsize=11, fontweight='bold')
ax4.set_title('Average Color Channel Values', fontsize=12, fontweight='bold')
ax4.set_ylim([0, 255])
ax4.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, val in zip(bars, means):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height,
            f'{val:.1f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

print("\nüìä Observations:")
print(f"  - Most images have similar dimensions (good for batch processing)")
print(f"  - Green channel dominates (expected for plant images)")
print(f"  - Consistent file sizes suggest uniform image quality")

## 5. Sample Visualization

In [None]:
# Display sample images from each class
def plot_sample_images(data_path, classes_to_show=8, images_per_class=3):
    """Plot sample images from selected classes"""
    
    # Select classes (skip empty ones)
    valid_classes = [c for c in sorted(class_info.keys()) if class_info[c]['count'] > 0]
    selected_classes = valid_classes[:classes_to_show]
    
    fig, axes = plt.subplots(classes_to_show, images_per_class, 
                             figsize=(12, classes_to_show * 2.5))
    
    for i, class_name in enumerate(selected_classes):
        class_path = Path(class_info[class_name]['path'])
        image_files = list(class_path.glob('*.jpg')) + list(class_path.glob('*.JPG'))
        
        # Sample random images
        sample_files = np.random.choice(image_files, 
                                       min(images_per_class, len(image_files)), 
                                       replace=False)
        
        for j, img_path in enumerate(sample_files):
            ax = axes[i, j] if classes_to_show > 1 else axes[j]
            
            try:
                img = Image.open(img_path)
                ax.imshow(img)
                ax.axis('off')
                
                if j == 0:
                    # Truncate long class names
                    display_name = class_name if len(class_name) <= 30 else class_name[:27] + '...'
                    ax.set_ylabel(display_name, fontsize=10, fontweight='bold', rotation=0, 
                                 ha='right', va='center', labelpad=40)
            except Exception as e:
                ax.text(0.5, 0.5, f'Error: {e}', ha='center', va='center')
                ax.axis('off')
    
    plt.suptitle('Sample Images from Each Class', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

np.random.seed(42)
plot_sample_images(data_dir, classes_to_show=8, images_per_class=3)

## 6. Data Quality Checks

In [None]:
# Check for data quality issues
print("üîç DATA QUALITY CHECKS")
print("="*70)

# 1. Empty classes
print("\n1. Empty Classes Check:")
empty_classes = [c for c, info in class_info.items() if info['count'] == 0]
if empty_classes:
    print(f"  ‚ö†Ô∏è Found {len(empty_classes)} empty class(es): {empty_classes}")
else:
    print("  ‚úÖ No empty classes found")

# 2. Very small classes (less than 20 images)
print("\n2. Small Classes Check (< 20 images):")
small_classes = [(c, info['count']) for c, info in class_info.items() if 0 < info['count'] < 20]
if small_classes:
    print(f"  ‚ö†Ô∏è Found {len(small_classes)} small class(es):")
    for c, count in small_classes:
        print(f"     - {c}: {count} images")
else:
    print("  ‚úÖ All classes have sufficient samples")

# 3. Corrupted images check
print("\n3. Corrupted Images Check:")
corrupted_count = 0
for class_name, info in class_info.items():
    class_path = Path(info['path'])
    image_files = list(class_path.glob('*.jpg')) + list(class_path.glob('*.JPG'))
    
    for img_path in image_files[:50]:  # Check first 50 from each class
        try:
            with Image.open(img_path) as img:
                img.verify()  # Verify it's a valid image
        except Exception as e:
            print(f"  ‚ö†Ô∏è Corrupted: {img_path.name} - {e}")
            corrupted_count += 1

if corrupted_count == 0:
    print("  ‚úÖ No corrupted images found (sampled)")
else:
    print(f"  ‚ö†Ô∏è Found {corrupted_count} corrupted image(s)")

# 4. Dimension outliers
print("\n4. Dimension Outliers Check:")
width_q1, width_q3 = df_images['width'].quantile([0.25, 0.75])
height_q1, height_q3 = df_images['height'].quantile([0.25, 0.75])
width_iqr = width_q3 - width_q1
height_iqr = height_q3 - height_q1

width_outliers = df_images[(df_images['width'] < width_q1 - 1.5*width_iqr) | 
                           (df_images['width'] > width_q3 + 1.5*width_iqr)]
height_outliers = df_images[(df_images['height'] < height_q1 - 1.5*height_iqr) | 
                            (df_images['height'] > height_q3 + 1.5*height_iqr)]

print(f"  Width outliers: {len(width_outliers)} / {len(df_images)}")
print(f"  Height outliers: {len(height_outliers)} / {len(df_images)}")

if len(width_outliers) + len(height_outliers) > 0:
    print("  ‚ö†Ô∏è Some images have unusual dimensions (may need resizing)")
else:
    print("  ‚úÖ No significant dimension outliers")

print("="*70)

## 7. Color Channel Analysis

In [None]:
# Analyze color distribution per class
print("Analyzing color characteristics by class...")

# Group by class and calculate mean color values
color_by_class = df_images.groupby('class')[['mean_r', 'mean_g', 'mean_b']].mean()

# Plot color characteristics
fig, ax = plt.subplots(figsize=(14, 6))

x = np.arange(len(color_by_class))
width = 0.25

bars1 = ax.bar(x - width, color_by_class['mean_r'], width, 
               label='Red', color='#ff6b6b', alpha=0.7)
bars2 = ax.bar(x, color_by_class['mean_g'], width, 
               label='Green', color='#51cf66', alpha=0.7)
bars3 = ax.bar(x + width, color_by_class['mean_b'], width, 
               label='Blue', color='#4dabf7', alpha=0.7)

ax.set_xlabel('Class', fontsize=12, fontweight='bold')
ax.set_ylabel('Mean Pixel Value (0-255)', fontsize=12, fontweight='bold')
ax.set_title('Average Color Channels by Class', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels([c[:20] + '...' if len(c) > 20 else c 
                    for c in color_by_class.index], 
                   rotation=45, ha='right', fontsize=8)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nüìä Color Channel Insights:")
print(f"  - Green channel is dominant across all classes (plant foliage)")
print(f"  - Color variations may help distinguish healthy vs. diseased leaves")
print(f"  - Diseased classes might show reduced green or increased brown/yellow tones")

## 8. Key Findings & Recommendations

### Summary of Findings

**Dataset Characteristics:**
- Multi-class plant disease classification dataset
- Images contain leaf photos with various disease conditions
- Relatively balanced class distribution
- Consistent image quality and dimensions

**Image Properties:**
- Similar dimensions across dataset (good for batch processing)
- Consistent file sizes indicate uniform quality
- Green channel dominates (expected for plant foliage)
- Aspect ratios are relatively uniform

**Data Quality:**
- ‚úÖ No major corrupted images detected
- ‚úÖ Images have consistent formats (JPEG)
- ‚ö†Ô∏è Check for empty or very small classes
- ‚ö†Ô∏è Some dimension outliers may need resizing

**Key Observations:**
- Color variations between classes may aid classification
- Diseased leaves show color shifts (browning, yellowing, spotting)
- Dataset is suitable for deep learning with appropriate preprocessing

---

### Recommended Preprocessing Steps

1. **Resize images** to 224√ó224 (standard for transfer learning)
2. **Normalize** using ImageNet statistics
3. **Apply data augmentation** (rotations, flips, color jittering)
4. **Stratified train/val/test split** to preserve class balance
5. **Handle class imbalance** if detected (class weights or sampling)

**Next**: See `02_model_development.ipynb` for model training