# Data Exploration for Image Classification

This notebook explores the chosen dataset (CIFAR-10 or Animals10) for image classification.

## Objectives:
- Load and inspect the dataset
- Visualize sample images from each class
- Analyze class distribution
- Understand image dimensions and characteristics
- Plan preprocessing strategy

## 1. Setup and Imports

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os

# Deep learning libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import cifar10

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

TensorFlow version: 2.15.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


## 2. Dataset Selection and Loading

Choose between CIFAR-10 or Animals10 dataset.

In [2]:
# Dataset choice - modify this to switch between datasets
DATASET_CHOICE = "animals10"  # Options: "cifar10" or "animals10"

if DATASET_CHOICE == "cifar10":
    # Load CIFAR-10 dataset
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    
    # CIFAR-10 class names
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                   'dog', 'frog', 'horse', 'ship', 'truck']
    
    print("CIFAR-10 Dataset Loaded")
    print(f"Training set: {x_train.shape}")
    print(f"Test set: {x_test.shape}")
    
elif DATASET_CHOICE == "animals10":
    import kagglehub
    print("Download from: https://www.kaggle.com/datasets/alessiocorrado99/animals10/data")
    path = kagglehub.dataset_download("alessiocorrado99/animals10")
    print("Path to dataset files:", path)

    # Placeholder - Animals10 loading is not implemented yet
    # data_dir = "../data/animals10/"
    # Load using tf.keras.utils.image_dataset_from_directory

    # NOTE: hand-written English labels; verify them against the class folders
    # in the download (and their order) before using them as label names.
    class_names = ['dog', 'cat', 'horse', 'spider', 'butterfly',
                   'chicken', 'sheep', 'cow', 'squirrel', 'elephant']

Download from: https://www.kaggle.com/datasets/alessiocorrado99/animals10/data
Downloading from https://www.kaggle.com/api/v1/datasets/download/alessiocorrado99/animals10?dataset_version_number=2...


100%|██████████| 586M/586M [08:11<00:00, 1.25MB/s] 

Extracting files...





Path to dataset files: /Users/smithn5/.cache/kagglehub/datasets/alessiocorrado99/animals10/versions/2
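
The Animals10 branch above only downloads the archive and leaves loading as a placeholder. A first inspection step could be sketched as below, using only the standard library. It assumes that each class lives in its own subfolder and that the folders carry Italian names; the `ITALIAN_TO_ENGLISH` mapping and the `count_images_per_class` helper are illustrative names, not part of the dataset or of `kagglehub`.

```python
from collections import Counter
from pathlib import Path

# Assumed folder-name translation; verify against the downloaded folders.
ITALIAN_TO_ENGLISH = {
    "cane": "dog", "cavallo": "horse", "elefante": "elephant",
    "farfalla": "butterfly", "gallina": "chicken", "gatto": "cat",
    "mucca": "cow", "pecora": "sheep", "ragno": "spider",
    "scoiattolo": "squirrel",
}
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(image_root):
    """Count image files in each class subfolder of image_root."""
    counts = Counter()
    for class_dir in sorted(Path(image_root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files next to the class folders
        # Fall back to the raw folder name if no translation is known.
        label = ITALIAN_TO_ENGLISH.get(class_dir.name, class_dir.name)
        counts[label] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

With the download above, a call might look like `count_images_per_class(os.path.join(path, "raw-img"))`, where the `raw-img` subfolder is an assumption about the archive layout and should be adjusted if it differs. If the images are later loaded with `tf.keras.utils.image_dataset_from_directory`, inferred labels follow the alphanumeric order of the folder names, so the hand-written `class_names` list should be checked against that order.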


## 3. Dataset Overview and Statistics

In [None]:
if DATASET_CHOICE == "cifar10":
    print("Dataset Statistics:")
    print(f"Number of training samples: {len(x_train)}")
    print(f"Number of test samples: {len(x_test)}")
    print(f"Image shape: {x_train[0].shape}")
    print(f"Number of classes: {len(class_names)}")
    print(f"Pixel value range: [{x_train.min()}, {x_train.max()}]")
    print(f"Data type: {x_train.dtype}")

## 4. Class Distribution Analysis

In [None]:
if DATASET_CHOICE == "cifar10":
    # Analyze class distribution
    train_class_counts = Counter(y_train.flatten())
    test_class_counts = Counter(y_test.flatten())
    
    # Create distribution plots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Training set distribution
    train_counts = [train_class_counts[i] for i in range(len(class_names))]
    ax1.bar(class_names, train_counts)
    ax1.set_title('Training Set Class Distribution')
    ax1.set_xlabel('Classes')
    ax1.set_ylabel('Number of Images')
    ax1.tick_params(axis='x', rotation=45)
    
    # Test set distribution
    test_counts = [test_class_counts[i] for i in range(len(class_names))]
    ax2.bar(class_names, test_counts)
    ax2.set_title('Test Set Class Distribution')
    ax2.set_xlabel('Classes')
    ax2.set_ylabel('Number of Images')
    ax2.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    print("\nClass distribution:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name}: Train={train_class_counts[i]}, Test={test_class_counts[i]}")
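
Bar charts make imbalance visible but not easy to compare across datasets. One simple way to put a number on it is the ratio between the largest and smallest class count. The sketch below assumes integer labels such as `y_train.flatten()`; `imbalance_ratio` is a hypothetical helper name.

```python
from collections import Counter


def imbalance_ratio(labels):
    """Return (largest class count) / (smallest class count).

    A value of 1.0 means perfectly balanced classes. Classes that never
    appear in `labels` are not counted, so check the number of distinct
    labels separately.
    """
    counts = Counter(int(label) for label in labels)
    if not counts:
        raise ValueError("labels is empty")
    return max(counts.values()) / min(counts.values())


# Toy example: class 0 appears four times, class 1 twice.
print(imbalance_ratio([0, 0, 0, 0, 1, 1]))  # 2.0
```

For CIFAR-10, `imbalance_ratio(y_train.flatten())` should return 1.0, since every class has 5,000 training images.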

## 5. Sample Image Visualization

In [None]:
if DATASET_CHOICE == "cifar10":
    # Function to display sample images
    def plot_sample_images(X, y, class_names, samples_per_class=5):
        # squeeze=False keeps `axes` 2-D even with one class or one sample per class
        fig, axes = plt.subplots(len(class_names), samples_per_class,
                                 figsize=(samples_per_class*2, len(class_names)*2),
                                 squeeze=False)
        labels = y.flatten()

        for class_idx, class_name in enumerate(class_names):
            # Find indices for current class
            class_indices = np.where(labels == class_idx)[0]

            for sample_idx in range(samples_per_class):
                ax = axes[class_idx, sample_idx]
                ax.axis('off')  # hide ticks; also blanks cells with no image
                if sample_idx < len(class_indices):
                    ax.imshow(X[class_indices[sample_idx]])
                    ax.set_title(class_name)

        plt.tight_layout()
        plt.show()
    
    # Display sample images
    print("Sample images from each class:")
    plot_sample_images(x_train, y_train, class_names, samples_per_class=5)

## 6. Pixel Value Distribution Analysis

In [None]:
if DATASET_CHOICE == "cifar10":
    # Analyze pixel value distributions
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # RGB channel distributions
    colors = ['red', 'green', 'blue']
    channel_names = ['Red', 'Green', 'Blue']
    
    for i in range(3):
        channel_data = x_train[:, :, :, i].flatten()
        axes[i].hist(channel_data, bins=50, color=colors[i], alpha=0.7)
        axes[i].set_title(f'{channel_names[i]} Channel Distribution')
        axes[i].set_xlabel('Pixel Value')
        axes[i].set_ylabel('Frequency')
    
    plt.tight_layout()
    plt.show()
    
    # Calculate channel statistics
    print("\nChannel Statistics:")
    for i, channel in enumerate(channel_names):
        channel_data = x_train[:, :, :, i]
        print(f"{channel} - Mean: {channel_data.mean():.2f}, Std: {channel_data.std():.2f}")
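
The per-channel means and standard deviations printed above can also drive per-channel standardization, a common alternative to plain [0, 1] scaling. A minimal NumPy sketch, assuming an image batch of shape (N, H, W, 3) with uint8 pixels; `standardize_per_channel` is an illustrative name, and random data stands in for `x_train`:

```python
import numpy as np


def standardize_per_channel(images, eps=1e-7):
    """Scale to [0, 1], then subtract each channel's mean and divide by its std.

    Returns the standardized float32 array plus the mean and std, so the
    same statistics can be reused on validation and test images.
    """
    scaled = images.astype(np.float32) / 255.0
    mean = scaled.mean(axis=(0, 1, 2))  # one value per colour channel
    std = scaled.std(axis=(0, 1, 2))
    return (scaled - mean) / (std + eps), mean, std


# Random uint8 data standing in for x_train.
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
standardized, mean, std = standardize_per_channel(fake_batch)
print(standardized.shape, mean.shape)  # (8, 32, 32, 3) (3,)
```

The statistics should be computed on the training set only and then applied unchanged to validation and test images.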

## 7. Key Insights and Preprocessing Strategy

### Observations:

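**Dataset Characteristics** (CIFAR-10; the Animals10 values still need to be measured once loading is implemented):
- Image dimensions: 32x32x3
- Pixel values: 0-255 (uint8)
- Classes: balanced, with 5,000 training and 1,000 test images per class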

**Preprocessing Strategy:**
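1. **Normalization**: Scale pixel values to the [0, 1] range by dividing by 255
2. **Data Augmentation**: Apply random rotation, flip, zoom, and shift to training images only
3. **Train/Val Split**: Hold out 20% of the training set for validation when a test set already exists (as in CIFAR-10); otherwise use a 70/15/15 train/val/test split
4. **One-hot Encoding**: Convert integer labels to one-hot vectors for categorical crossentropy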
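
Steps 1, 3, and 4 of the strategy above can be sketched in plain NumPy as follows. This is an illustration under stated assumptions rather than the final pipeline: `preprocess` is a hypothetical helper, the split is random rather than stratified, and augmentation (step 2) is left out because it is usually applied on the fly during training.

```python
import numpy as np


def preprocess(images, labels, num_classes, val_fraction=0.2, seed=42):
    """Normalize, shuffle, split into train/val, and one-hot encode labels."""
    x = images.astype(np.float32) / 255.0            # step 1: scale to [0, 1]
    y = np.asarray(labels).reshape(-1)                # (N, 1) -> (N,)
    one_hot = np.eye(num_classes, dtype=np.float32)[y]  # step 4: one-hot rows

    # Step 3: random train/validation split with a fixed seed.
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))
    n_val = int(len(x) * val_fraction)
    val_idx, train_idx = order[:n_val], order[n_val:]
    return (x[train_idx], one_hot[train_idx]), (x[val_idx], one_hot[val_idx])


# Random data standing in for (x_train, y_train) with CIFAR-10-like shapes.
rng = np.random.default_rng(0)
fake_x = rng.integers(0, 256, size=(10, 32, 32, 3), dtype=np.uint8)
fake_y = rng.integers(0, 10, size=(10, 1))
(x_tr, y_tr), (x_val, y_val) = preprocess(fake_x, fake_y, num_classes=10)
print(x_tr.shape, y_val.shape)  # (8, 32, 32, 3) (2, 10)
```

In a Keras workflow, `tf.keras.utils.to_categorical` produces the same one-hot encoding. Keeping integer labels and training with sparse categorical crossentropy is an alternative that skips step 4.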

**Next Steps:**
- Implement preprocessing pipeline
- Design custom CNN architecture
- Implement transfer learning approach
- Compare model performances

In [None]:
# Summarize key variables for the next notebooks (printed only, nothing is written to disk)
print(f"Dataset choice: {DATASET_CHOICE}")
print(f"Number of classes: {len(class_names)}")
print(f"Class names: {class_names}")
if DATASET_CHOICE == "cifar10":
    print(f"Image shape: {x_train.shape[1:]}")
    print(f"Training samples: {len(x_train)}")
    print(f"Test samples: {len(x_test)}")