# Week 1: Data Exploration and Visualization

**Objective:** Understand the structure, content, and class distribution of the waste classification dataset.

This notebook covers:
- Loading the dataset from disk
- Visualizing sample images from each class
- Analyzing the class distribution to check for imbalances
- Understanding image dimensions and characteristics

## 1.1 - Setup

Import necessary libraries and configure paths.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

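# Add the project root to sys.path (assumes the notebook runs from a subfolder of the project)
sys.path.append(os.path.dirname(os.getcwd()))
# Import only the names this notebook uses instead of a star import
from src.config import CLASS_NAMES, PLOT_STYLE, RAW_DATA_DIR, REPORTS_DIR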

plt.style.use(PLOT_STYLE)
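The import from `src.config` above is expected to provide `PLOT_STYLE`, `CLASS_NAMES`, `RAW_DATA_DIR`, and `REPORTS_DIR`, since those are the names used in the rest of this notebook. The actual `src/config.py` is not shown here, so the following is only a hypothetical sketch of what such a module might contain. Every path and value below is an assumption, including the directory layout and all class names other than `clothes`.

```python
from pathlib import Path

# Hypothetical layout: <project root>/data/raw/<class name>/*.jpg
# Assumes the notebook is started from a subfolder of the project root.
PROJECT_ROOT = Path.cwd().parent
RAW_DATA_DIR = PROJECT_ROOT / "data" / "raw"
REPORTS_DIR = PROJECT_ROOT / "reports"

# Placeholder class list; the real names must match the subdirectory names on disk.
CLASS_NAMES = ["clothes", "class_b", "class_c"]

# Any style name accepted by plt.style.use(); "ggplot" ships with matplotlib.
PLOT_STYLE = "ggplot"
```

Defining the directories as `pathlib.Path` objects is what allows expressions such as `data_dir / class_name` and `REPORTS_DIR.mkdir(...)` later in the notebook.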

## 1.2 - Dataset Overview

Count the images in each class folder of the raw dataset directory to see how the dataset is organized and how large each class is.

In [None]:
def get_dataset_stats(data_dir):
    """
    Calculates and returns statistics about the dataset.
    
    Arguments:
    data_dir -- Path, directory of the raw dataset.
    
    Returns:
    stats -- dict, a dictionary containing class names and image counts.
    """
    stats = {}
    for class_name in CLASS_NAMES:
        class_dir = data_dir / class_name
        # Only files ending in .jpg are counted; other extensions are ignored
        stats[class_name] = len(list(class_dir.glob('*.jpg')))
    return stats

dataset_stats = get_dataset_stats(RAW_DATA_DIR)
total_images = sum(dataset_stats.values())

print(f"Total number of classes: {len(CLASS_NAMES)}")
print(f"Total number of images: {total_images}")
print("\nImages per class:")
for class_name, count in dataset_stats.items():
    print(f"  {class_name:12s}: {count:5d} images")
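The counting cell above matches only `*.jpg`, so files named `.JPG`, `.jpeg`, or `.png` would be silently left out of the totals. Whether this dataset contains such files is not stated, so the following is a minimal sketch under that assumption. The helper name `count_images_by_class` and the extension set are hypothetical, not part of the project.

```python
from pathlib import Path

# Assumed set of extensions; adjust to whatever the dataset actually contains.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def count_images_by_class(data_dir, class_names, extensions=IMAGE_EXTENSIONS):
    """Return {class_name: image_count}, matching extensions case-insensitively."""
    data_dir = Path(data_dir)
    counts = {}
    for class_name in class_names:
        class_dir = data_dir / class_name
        if not class_dir.is_dir():
            # A missing folder is reported as zero rather than raising.
            counts[class_name] = 0
            continue
        counts[class_name] = sum(
            1 for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in extensions
        )
    return counts
```

Comparing its result with `dataset_stats` is a quick way to check whether any images are being missed by the `*.jpg` pattern.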

## 1.3 - Class Distribution

Plotting the number of images per class makes any class imbalance visible. Imbalance matters because it can bias a model toward the larger classes during training.

In [None]:
plt.figure(figsize=(14, 6))
colors = sns.color_palette('husl', len(CLASS_NAMES))
bars = plt.bar(list(dataset_stats.keys()), list(dataset_stats.values()), color=colors)
plt.title('Class Distribution of Waste Images', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Waste Category', fontsize=12)
plt.ylabel('Number of Images', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{int(height)}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

**Observation:** The classes differ in size, and some contain significantly more samples than others. An imbalance like this can bias the model toward the larger classes, so balancing strategies such as class weighting or resampling are worth considering during training.
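The degree of imbalance can be quantified directly from the per-class counts. The sketch below computes the ratio between the largest and smallest class and a set of inverse-frequency class weights, using the same formula as scikit-learn's "balanced" heuristic, `total / (n_classes * count)`. The function name `imbalance_summary` and the example counts are made up for illustration; in this notebook the input would be `dataset_stats`.

```python
def imbalance_summary(counts):
    """Return (imbalance_ratio, class_weights) for a {class_name: count} dict.

    imbalance_ratio = largest class size / smallest class size.
    class_weights follow the "balanced" heuristic: total / (n_classes * count).
    Classes with zero images are skipped.
    """
    nonzero = {name: n for name, n in counts.items() if n > 0}
    if not nonzero:
        raise ValueError("no class has any images")
    total = sum(nonzero.values())
    n_classes = len(nonzero)
    ratio = max(nonzero.values()) / min(nonzero.values())
    weights = {name: total / (n_classes * n) for name, n in nonzero.items()}
    return ratio, weights

# Made-up counts for illustration only; in the notebook, pass dataset_stats instead.
example_counts = {"class_a": 500, "class_b": 250, "class_c": 250}
ratio, weights = imbalance_summary(example_counts)
print(f"Imbalance ratio: {ratio:.1f}")
for name, weight in weights.items():
    print(f"  {name}: weight {weight:.2f}")
```

Weights above 1 mark under-represented classes. Whether to use weights like these, resampling, or nothing at all depends on how large the ratio turns out to be for the real data.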

## 1.4 - Sample Images Visualization

Let's visualize sample images from each waste category to understand what each class looks like.

In [None]:
def display_sample_images(data_dir, class_names, samples_per_class=3):
    """
    Displays sample images from each class.
    
    Arguments:
    data_dir -- Path, directory of the raw dataset.
    class_names -- list, list of class names.
    samples_per_class -- int, number of samples to display per class.
    """
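    # squeeze=False keeps axes 2-D even with a single class or a single sample
    fig, axes = plt.subplots(len(class_names), samples_per_class,
                             figsize=(15, len(class_names) * 2.5), squeeze=False)

    # Hide every axis up front so cells without an image stay blank
    for ax in axes.ravel():
        ax.axis('off')

    for i, class_name in enumerate(class_names):
        class_dir = data_dir / class_name
        # Sort for a deterministic selection; glob order is not guaranteed
        image_files = sorted(class_dir.glob('*.jpg'))[:samples_per_class]

        for j, img_path in enumerate(image_files):
            # The context manager closes the file once the image is drawn
            with Image.open(img_path) as img:
                axes[i, j].imshow(img)
            if j == 0:
                axes[i, j].set_title(f"{class_name.upper()}",
                                     fontsize=12, fontweight='bold', loc='left')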
    
    plt.suptitle('Sample Images from Each Class', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

display_sample_images(RAW_DATA_DIR, CLASS_NAMES, samples_per_class=3)

## 1.5 - Image Dimensions Analysis

Understanding the dimensions of images in the dataset helps determine appropriate preprocessing steps.

In [None]:
def analyze_image_dimensions(data_dir, class_names, sample_size=100):
    """
    Analyzes the dimensions of images in the dataset.
    
    Arguments:
    data_dir -- Path, directory of the raw dataset.
    class_names -- list, list of class names.
    sample_size -- int, maximum number of images to inspect per class.
    """
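    widths = []
    heights = []

    for class_name in class_names:
        class_dir = data_dir / class_name
        # Sort for a deterministic sample; glob order is not guaranteed
        image_files = sorted(class_dir.glob('*.jpg'))[:sample_size]

        for img_path in image_files:
            # Image.open is lazy, so reading the size does not decode the pixels;
            # the context manager closes the file handle afterwards
            with Image.open(img_path) as img:
                widths.append(img.width)
                heights.append(img.height)

    # min() and max() fail on empty lists, so stop early if nothing was found
    if not widths:
        print("No .jpg images found; nothing to analyze.")
        return

    print(f"Analyzed {len(widths)} images")
    print("\nWidth statistics:")
    print(f"  Min: {min(widths)} | Max: {max(widths)} | Mean: {np.mean(widths):.1f}")
    print("\nHeight statistics:")
    print(f"  Min: {min(heights)} | Max: {max(heights)} | Mean: {np.mean(heights):.1f}")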
    
    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    axes[0].hist(widths, bins=30, edgecolor='black', color='steelblue', alpha=0.7)
    axes[0].set_title('Distribution of Image Widths', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Width (pixels)')
    axes[0].set_ylabel('Frequency')
    axes[0].grid(alpha=0.3)
    
    axes[1].hist(heights, bins=30, edgecolor='black', color='coral', alpha=0.7)
    axes[1].set_title('Distribution of Image Heights', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Height (pixels)')
    axes[1].set_ylabel('Frequency')
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()

analyze_image_dimensions(RAW_DATA_DIR, CLASS_NAMES, sample_size=100)
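Separate width and height histograms do not show how the two dimensions pair up, which matters when choosing a resize strategy. The sketch below summarizes a list of `(width, height)` tuples by most common size, median aspect ratio, and orientation. It assumes the pairs have already been collected, for example from Pillow's `img.size`. The function name `summarize_sizes` and the example sizes are made up for illustration.

```python
from collections import Counter
from statistics import median

def summarize_sizes(sizes):
    """Summarize a list of (width, height) tuples.

    Returns a dict with the most common size, the median aspect ratio
    (width / height), and the counts of landscape, portrait, and square images.
    """
    if not sizes:
        raise ValueError("sizes is empty")
    ratios = [w / h for w, h in sizes]
    orientation = Counter(
        "square" if w == h else "landscape" if w > h else "portrait"
        for w, h in sizes
    )
    most_common_size, _ = Counter(sizes).most_common(1)[0]
    return {
        "most_common_size": most_common_size,
        "median_aspect_ratio": median(ratios),
        "orientation": dict(orientation),
    }

# Made-up sizes for illustration; in the notebook these would come from img.size.
example_sizes = [(640, 480), (640, 480), (480, 640), (512, 512)]
summary = summarize_sizes(example_sizes)
print(summary)
```

If most images share one aspect ratio, a plain resize distorts little; a wide spread of ratios suggests considering padding or cropping before resizing.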

## 1.6 - Key Findings

**Summary of Data Exploration:**

1. **Dataset Structure**: The dataset is organized into subdirectories by waste category.
2. **Class Balance**: Class sizes vary; some classes, such as 'clothes', have significantly more samples than others.
3. **Image Variety**: Each class contains diverse images with varying backgrounds and lighting conditions.
4. **Image Dimensions**: Images have varying dimensions and will need to be resized to a consistent size for model training.

**Next Steps:**
- Preprocess the data (resize, normalize)
- Split the dataset into training, validation, and test sets
- Implement data augmentation strategies
- Build and train a baseline CNN model
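Because the classes are imbalanced, the split in the second step is usually done per class (a stratified split) so that each subset keeps roughly the same class proportions. The following is a minimal sketch of that idea, not the project's actual splitting code. The function name `stratified_split`, the default fractions, and the seed are all assumptions.

```python
import random

def stratified_split(files_by_class, val_frac=0.15, test_frac=0.15, seed=42):
    """Split {class_name: [file, ...]} into train/val/test, class by class.

    Returns three dicts with the same keys. The fractions and seed are
    illustrative defaults, not values taken from the project.
    """
    rng = random.Random(seed)
    train, val, test = {}, {}, {}
    for class_name, files in files_by_class.items():
        shuffled = sorted(files)      # sort first so the shuffle is reproducible
        rng.shuffle(shuffled)
        n_val = int(len(shuffled) * val_frac)
        n_test = int(len(shuffled) * test_frac)
        val[class_name] = shuffled[:n_val]
        test[class_name] = shuffled[n_val:n_val + n_test]
        train[class_name] = shuffled[n_val + n_test:]
    return train, val, test
```

In this notebook's terms, `files_by_class` could be built from the sorted `*.jpg` paths of each class folder. Splitting before any augmentation keeps augmented copies of one image from ending up in more than one subset.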

## 1.7 - Save Visualizations (Optional)

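Redraw the class distribution plot and save it to `REPORTS_DIR` for reporting purposes. The figure from section 1.3 has already been shown, so it is recreated here before saving.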

In [None]:
# Create reports directory if it doesn't exist
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

# Save class distribution plot
plt.figure(figsize=(14, 6))
colors = sns.color_palette('husl', len(CLASS_NAMES))
bars = plt.bar(list(dataset_stats.keys()), list(dataset_stats.values()), color=colors)
plt.title('Class Distribution of Waste Images', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Waste Category', fontsize=12)
plt.ylabel('Number of Images', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{int(height)}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig(REPORTS_DIR / 'class_distribution.png', dpi=300, bbox_inches='tight')
print(f"✅ Saved class distribution plot to: {REPORTS_DIR / 'class_distribution.png'}")
plt.show()