# In [None]:
# Cherry Leaf Mildew Detection - Exploratory Data Analysis
#
# Objectives:
# * Create average and variability images for each class
# * Generate image montages for visual comparison
# * Analyze differences between healthy and infected leaves
# * Create visualizations for the dashboard
#
# Inputs:
# * Processed image dataset from data_collection.ipynb
#
# Outputs:
# * Visualization findings
# * Statistical analysis results
# * Generated plots for dashboard

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import itertools
from pathlib import Path

# Set up paths
# Root directory of the processed dataset (a relative path); load_images()
# reads images from <PROCESSED_DATA_DIR>/<split>/<class_name>.
PROCESSED_DATA_DIR = "../data/processed"

def load_images(class_name, split='train', max_images=100):
    """Load up to ``max_images`` JPEG images of one class as an RGB array.

    Files matching ``*.jpg`` are read from
    ``PROCESSED_DATA_DIR/<split>/<class_name>``, converted to RGB and resized
    to 256x256 so that they can be stacked into a single array.

    Args:
        class_name: Name of the class sub-directory (e.g. ``'healthy'``).
        split: Name of the dataset split sub-directory.
        max_images: Maximum number of images to load.

    Returns:
        ``uint8`` array of shape ``(n, 256, 256, 3)``, where ``n`` is the
        number of images loaded (0 when no ``*.jpg`` file is found).
    """
    image_dir = Path(PROCESSED_DATA_DIR) / split / class_name
    # Directory listing order depends on the filesystem; sorting makes the
    # selected subset identical from run to run.
    image_paths = sorted(image_dir.glob("*.jpg"))[:max_images]

    images = []
    for img_path in image_paths:
        with Image.open(img_path) as img:
            # Convert to RGB and standardize the size before stacking
            images.append(np.array(img.convert('RGB').resize((256, 256))))

    if not images:
        # Keep the documented rank and dtype even when nothing was loaded.
        return np.empty((0, 256, 256, 3), dtype=np.uint8)
    return np.array(images)

def create_average_image(images):
    """Return the element-wise mean over the first axis of ``images`` as uint8.

    The mean is truncated (not rounded) when cast to ``uint8``.
    """
    stack = np.asanyarray(images)
    mean_pixels = stack.mean(axis=0)
    return mean_pixels.astype(np.uint8)

def create_variability_image(images):
    """Return the element-wise standard deviation over the first axis as uint8.

    The standard deviation is truncated (not rounded) when cast to ``uint8``.
    """
    stack = np.asanyarray(images)
    spread = stack.std(axis=0)
    return spread.astype(np.uint8)

def create_difference_image(healthy_avg, infected_avg):
    """Create a difference image between average healthy and infected leaves.

    The signed difference ``healthy_avg - infected_avg`` is computed in
    float32 and min-max normalized to the 0-255 range: the smallest
    difference maps to 0 and the largest to 255 (values are truncated when
    cast to ``uint8``).

    Args:
        healthy_avg: Array holding the average healthy image.
        infected_avg: Array of the same shape holding the average infected image.

    Returns:
        ``uint8`` array with the shape of the inputs. When the difference is
        the same everywhere (for example, identical inputs) an all-zero image
        is returned.
    """
    difference = healthy_avg.astype(np.float32) - infected_avg.astype(np.float32)
    lowest = difference.min()
    value_range = difference.max() - lowest
    if value_range == 0:
        # A constant difference has nothing to normalize; dividing by the
        # zero range would yield inf/NaN and an undefined cast to uint8.
        return np.zeros(difference.shape, dtype=np.uint8)
    # Normalize to 0-255 range
    return ((difference - lowest) * (255.0 / value_range)).astype(np.uint8)

def create_montage(images, grid_size=(5, 5)):
    """Tile the first ``rows * cols`` images into one uint8 RGB grid image.

    Tiles are placed left-to-right, top-to-bottom; the tile size is taken
    from the first image, and unused grid cells stay black (zeros).
    """
    n_rows, n_cols = grid_size
    tile_shape = images[0].shape[:2]
    tile_h = tile_shape[0]
    tile_w = tile_shape[1]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3), dtype=np.uint8)

    for position, tile in enumerate(images[:n_rows * n_cols]):
        grid_row, grid_col = divmod(position, n_cols)
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = tile
    return canvas

# Load the image stack of each class (default split and image cap)
print("Loading images...")
healthy_images, infected_images = (
    load_images(class_dir) for class_dir in ('healthy', 'powdery_mildew')
)

# Average image per class, shown next to their normalized difference
print("\nCreating average images...")
healthy_avg = create_average_image(healthy_images)
infected_avg = create_average_image(infected_images)

plt.figure(figsize=(12, 4))
average_panels = (
    (131, healthy_avg, 'Average Healthy Leaf'),
    (132, infected_avg, 'Average Infected Leaf'),
    (133, create_difference_image(healthy_avg, infected_avg), 'Difference Image'),
)
for panel_position, panel_image, panel_title in average_panels:
    plt.subplot(panel_position)
    plt.imshow(panel_image)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

# Per-pixel standard deviation image for each class, side by side
print("\nCreating variability images...")
healthy_var = create_variability_image(healthy_images)
infected_var = create_variability_image(infected_images)

plt.figure(figsize=(8, 4))
variability_panels = (
    (121, healthy_var, 'Healthy Leaf Variability'),
    (122, infected_var, 'Infected Leaf Variability'),
)
for panel_position, panel_image, panel_title in variability_panels:
    plt.subplot(panel_position)
    plt.imshow(panel_image)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

# Montage (default 5x5 grid) of sample images for each class, side by side
print("\nCreating image montages...")
healthy_montage = create_montage(healthy_images)
infected_montage = create_montage(infected_images)

plt.figure(figsize=(15, 6))
montage_panels = (
    (121, healthy_montage, 'Healthy Leaves Montage'),
    (122, infected_montage, 'Infected Leaves Montage'),
)
for panel_position, panel_image, panel_title in montage_panels:
    plt.subplot(panel_position)
    plt.imshow(panel_image)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

# Statistical Analysis: summary pixel-intensity statistics for each class
print("\nPerforming statistical analysis...")
def calculate_image_statistics(images):
    """Calculate basic intensity statistics for a stack of images.

    Reductions accumulate in float64: a float32 accumulator loses precision
    once the running sum of millions of pixel values exceeds what float32 can
    represent exactly. Passing ``dtype`` to the reductions also avoids making
    a floating-point copy of the whole stack.

    Args:
        images: Array-like stack of images. ``mean_per_channel`` reduces over
            the first three axes, so a 4-D ``(n, height, width, channels)``
            stack yields one mean per channel.

    Returns:
        dict with the keys ``mean_intensity``, ``std_intensity``,
        ``min_intensity`` and ``max_intensity`` (float64 scalars over all
        values) and ``mean_per_channel`` (float64 result of averaging over
        axes 0, 1 and 2).

    Raises:
        ValueError: If ``images`` is empty (min/max of an empty array).
    """
    pixels = np.asarray(images)

    stats = {
        'mean_intensity': np.mean(pixels, dtype=np.float64),
        'std_intensity': np.std(pixels, dtype=np.float64),
        # min/max are exact in the input dtype; cast only the scalar result.
        'min_intensity': np.float64(np.min(pixels)),
        'max_intensity': np.float64(np.max(pixels)),
        'mean_per_channel': np.mean(pixels, axis=(0, 1, 2), dtype=np.float64),
    }
    return stats

# Compute the statistics of both classes, then print them one section per class
healthy_stats = calculate_image_statistics(healthy_images)
infected_stats = calculate_image_statistics(infected_images)

print("\nImage Statistics:")
stat_sections = (
    ("\nHealthy Leaves:", healthy_stats),
    ("\nInfected Leaves:", infected_stats),
)
for section_heading, class_stats in stat_sections:
    print(section_heading)
    for stat_name, stat_value in class_stats.items():
        print(f"{stat_name}: {stat_value}")

# Export the dashboard images as PNG files (output directory is created if missing)
print("\nSaving visualizations for dashboard...")
output_dir = "../outputs/visualizations"
os.makedirs(output_dir, exist_ok=True)

# File name -> image array, written in this order
dashboard_images = {
    'healthy_average.png': healthy_avg,
    'infected_average.png': infected_avg,
    'difference_image.png': create_difference_image(healthy_avg, infected_avg),
    'healthy_montage.png': healthy_montage,
    'infected_montage.png': infected_montage,
}
for file_name, pixels in dashboard_images.items():
    plt.imsave(os.path.join(output_dir, file_name), pixels)

print("\nEDA completed! Visualizations saved in outputs/visualizations/")