# 📊 Data Exploration for Image Classification

## Learning Objectives
- Load and explore the CIFAR-10 dataset
- Understand image data structure and properties
- Visualize sample images and class distributions
- Analyze dataset statistics and characteristics
- Identify potential challenges and preprocessing needs

## Dataset Overview: CIFAR-10

The CIFAR-10 dataset consists of 60,000 32×32 color images in 10 classes:
- **Classes**: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- **Training set**: 50,000 images
- **Test set**: 10,000 images
- **Image size**: 32×32×3 (RGB)
- **Balanced**: 6,000 images per class (5,000 training + 1,000 test)

In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
import pandas as pd
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")
print(f"NumPy version: {np.__version__}")

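The cell above seeds NumPy's global random state once, so every later `np.random.choice` call (sample images in Section 4, the complexity sample in Section 6) depends on how many random calls ran before it. That matters when cells are re-run out of order. A minimal sketch of an alternative, assuming NumPy ≥ 1.17 for `np.random.default_rng`: give each sampling step its own seeded `Generator`. The helper name `sample_indices` is hypothetical.

```python
import numpy as np

def sample_indices(n_items, k, seed):
    """Draw k distinct indices from range(n_items) with a local, seeded Generator."""
    rng = np.random.default_rng(seed)  # independent of the global np.random state
    return rng.choice(n_items, size=k, replace=False)

# Same seed -> identical draws, regardless of other random calls in the notebook
first = sample_indices(50_000, 5, seed=42)
second = sample_indices(50_000, 5, seed=42)
print(first, second, np.array_equal(first, second))
```

Because the generator is created inside the helper, re-running a single cell reproduces its sample exactly.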
## 1. Loading the Dataset

In [None]:
# Load CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Define class names
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']

print("Dataset loaded successfully!")
print(f"Training set shape: {x_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Test set shape: {x_test.shape}")
print(f"Test labels shape: {y_test.shape}")
print(f"Number of classes: {len(class_names)}")
print(f"Class names: {class_names}")

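`load_data()` returns labels as a column vector of shape `(N, 1)`, which is why later cells call `.flatten()` before comparing labels. The short sketch below uses hypothetical labels in that same layout to show how to drop the extra axis and map IDs to the `class_names` list defined above.

```python
import numpy as np

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Hypothetical labels in the same (N, 1) uint8 layout that load_data() returns
y_example = np.array([[3], [8], [0], [9]], dtype=np.uint8)

y_flat = y_example.ravel()                 # shape (N, 1) -> (N,)
names = [class_names[i] for i in y_flat]   # NumPy integers work as list indices
print(y_example.shape, y_flat.shape)
print(names)
```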
## 2. Basic Dataset Statistics

In [None]:
# Calculate basic statistics
total_images = len(x_train) + len(x_test)
image_height, image_width, channels = x_train.shape[1:]

print("üìä Dataset Statistics")
print("=" * 40)
print(f"Total images: {total_images:,}")
print(f"Training images: {len(x_train):,}")
print(f"Test images: {len(x_test):,}")
print(f"Image dimensions: {image_height}√ó{image_width}√ó{channels}")
print(f"Data type: {x_train.dtype}")
print(f"Pixel value range: [{x_train.min()}, {x_train.max()}]")
print(f"Memory usage (training): {x_train.nbytes / (1024**2):.1f} MB")
print(f"Memory usage (test): {x_test.nbytes / (1024**2):.1f} MB")

# Label statistics
print(f"\nüè∑Ô∏è Label Statistics")
print("=" * 40)
print(f"Label data type: {y_train.dtype}")
print(f"Label range: [{y_train.min()}, {y_train.max()}]")
print(f"Unique labels: {np.unique(y_train)}")

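The memory figures printed above follow directly from the array shape and dtype: images × height × width × channels × bytes per element. The sketch below (the helper `dataset_megabytes` is hypothetical) reproduces the uint8 numbers without loading the data and shows how they grow after conversion to floating point. As in the cell above, "MB" means 1024² bytes.

```python
import numpy as np

def dataset_megabytes(n_images, height=32, width=32, channels=3, dtype=np.uint8):
    """Size in MiB of an (n_images, height, width, channels) array of the given dtype."""
    n_bytes = n_images * height * width * channels * np.dtype(dtype).itemsize
    return n_bytes / 1024**2

for dtype in (np.uint8, np.float32, np.float64):
    train_mb = dataset_megabytes(50_000, dtype=dtype)
    test_mb = dataset_megabytes(10_000, dtype=dtype)
    print(f"{np.dtype(dtype).name:>8}: train {train_mb:7.1f} MB, test {test_mb:6.1f} MB")
```

For the training set this is about 146.5 MB as uint8, four times that as float32, and eight times that as float64, which is worth keeping in mind before normalizing the whole array at once.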
## 3. Class Distribution Analysis

In [None]:
# Analyze class distribution
train_counts = Counter(y_train.flatten())
test_counts = Counter(y_test.flatten())

# Create distribution dataframe
distribution_data = []
for i, class_name in enumerate(class_names):
    distribution_data.append({
        'Class': class_name,
        'Class_ID': i,
        'Train_Count': train_counts[i],
        'Test_Count': test_counts[i],
        'Total_Count': train_counts[i] + test_counts[i],
        'Train_Percentage': (train_counts[i] / len(y_train)) * 100,
        'Test_Percentage': (test_counts[i] / len(y_test)) * 100
    })

df_distribution = pd.DataFrame(distribution_data)
print("üìà Class Distribution")
print(df_distribution.to_string(index=False))

# Visualize class distribution
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Training set distribution
axes[0].bar(range(len(class_names)), [train_counts[i] for i in range(len(class_names))], 
           color='skyblue', alpha=0.8, edgecolor='black')
axes[0].set_xlabel('Classes')
axes[0].set_ylabel('Number of Images')
axes[0].set_title('Training Set Class Distribution')
axes[0].set_xticks(range(len(class_names)))
axes[0].set_xticklabels(class_names, rotation=45, ha='right')
axes[0].grid(True, alpha=0.3)

# Add count labels on bars
for i, count in enumerate([train_counts[i] for i in range(len(class_names))]):
    axes[0].text(i, count + 50, str(count), ha='center', va='bottom')

# Test set distribution
axes[1].bar(range(len(class_names)), [test_counts[i] for i in range(len(class_names))], 
           color='lightcoral', alpha=0.8, edgecolor='black')
axes[1].set_xlabel('Classes')
axes[1].set_ylabel('Number of Images')
axes[1].set_title('Test Set Class Distribution')
axes[1].set_xticks(range(len(class_names)))
axes[1].set_xticklabels(class_names, rotation=45, ha='right')
axes[1].grid(True, alpha=0.3)

# Add count labels on bars
for i, count in enumerate([test_counts[i] for i in range(len(class_names))]):
    axes[1].text(i, count + 20, str(count), ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Check if dataset is balanced
train_std = np.std([train_counts[i] for i in range(len(class_names))])
test_std = np.std([test_counts[i] for i in range(len(class_names))])

print(f"\n⚖️ Balance Analysis")
print(f"Training set standard deviation: {train_std:.2f}")
print(f"Test set standard deviation: {test_std:.2f}")
print(f"Dataset is {'balanced' if train_std == 0 and test_std == 0 else 'imbalanced'}")

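The balance check above only distinguishes "exactly equal counts" from everything else. A ratio between the largest and smallest class says *how* skewed a dataset is, which is more useful if you later work with a subset of CIFAR-10. A minimal sketch using `np.bincount` (the helper `class_balance` and the label vectors are hypothetical):

```python
import numpy as np

def class_balance(labels, num_classes):
    """Per-class counts plus the max/min imbalance ratio (1.0 means perfectly balanced)."""
    counts = np.bincount(np.asarray(labels).ravel(), minlength=num_classes)
    ratio = counts.max() / max(counts.min(), 1)  # guard against an empty class
    return counts, ratio

# Hypothetical label vectors: one balanced, one skewed
balanced = np.repeat(np.arange(3), 4)           # 4 examples of each of 3 classes
skewed = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2])  # class 0 dominates
print(class_balance(balanced, 3))
print(class_balance(skewed, 3))
```

`np.bincount` needs non-negative integer labels, which the uint8 CIFAR-10 labels satisfy; `minlength` keeps a zero entry for any class that is missing entirely.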
## 4. Sample Image Visualization

In [None]:
# Display sample images from each class
def plot_sample_images(x_data, y_data, class_names, samples_per_class=5):
    """
    Plot sample images from each class
    """
    num_classes = len(class_names)
    fig, axes = plt.subplots(num_classes, samples_per_class, 
                            figsize=(samples_per_class * 2, num_classes * 2))
    
    for class_idx in range(num_classes):
        # Find indices for current class
        class_indices = np.where(y_data.flatten() == class_idx)[0]
        
        # Randomly sample images from this class
        sample_indices = np.random.choice(class_indices, samples_per_class, replace=False)
        
        for sample_idx in range(samples_per_class):
            img_idx = sample_indices[sample_idx]
            image = x_data[img_idx]
            ax = axes[class_idx, sample_idx]
            
            ax.imshow(image)
            # Hide ticks and frame but keep the axis itself: axis('off') would
            # also hide the y-label that carries the class name
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            
            # Add class name to first image of each row
            if sample_idx == 0:
                ax.set_ylabel(class_names[class_idx], rotation=0, ha='right', va='center')
    
    plt.suptitle('Sample Images from Each Class', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.show()

# Plot sample images
plot_sample_images(x_train, y_train, class_names, samples_per_class=8)

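`plot_sample_images` draws `samples_per_class` images per class without replacement, so it raises a `ValueError` if any class has fewer images than requested (never the case for full CIFAR-10, but possible on a filtered subset). A minimal sketch of a reusable per-class sampler that checks this up front and uses its own seeded generator; `per_class_sample` and `y_demo` are hypothetical names.

```python
import numpy as np

def per_class_sample(labels, samples_per_class, seed=0):
    """Return {class_id: indices} with `samples_per_class` distinct indices per class."""
    labels = np.asarray(labels).ravel()
    rng = np.random.default_rng(seed)
    picks = {}
    for class_id in np.unique(labels):
        class_indices = np.flatnonzero(labels == class_id)
        if len(class_indices) < samples_per_class:
            raise ValueError(f"class {class_id} has only {len(class_indices)} examples")
        picks[int(class_id)] = rng.choice(class_indices, samples_per_class, replace=False)
    return picks

# Hypothetical (N, 1) labels: class 1 has only two examples
y_demo = np.array([[0], [1], [0], [2], [1], [0], [2], [2]])
print(per_class_sample(y_demo, 2))
```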
## 5. Pixel Value Analysis

In [None]:
# Analyze pixel value distributions
def analyze_pixel_statistics(x_data, title_prefix=""):
    """
    Compute per-channel and overall pixel statistics (raw 0-255 scale).
    title_prefix is a descriptive label only and does not affect the results.
    """
    # Calculate statistics for each channel
    stats = {}
    channel_names = ['Red', 'Green', 'Blue']
    
    for i, channel in enumerate(channel_names):
        channel_data = x_data[:, :, :, i].flatten()
        stats[channel] = {
            'mean': np.mean(channel_data),
            'std': np.std(channel_data),
            'min': np.min(channel_data),
            'max': np.max(channel_data),
            'median': np.median(channel_data)
        }
    
    # Overall statistics
    overall_data = x_data.flatten()
    stats['Overall'] = {
        'mean': np.mean(overall_data),
        'std': np.std(overall_data),
        'min': np.min(overall_data),
        'max': np.max(overall_data),
        'median': np.median(overall_data)
    }
    
    return stats

# Calculate statistics for training and test sets
train_stats = analyze_pixel_statistics(x_train, "Training")
test_stats = analyze_pixel_statistics(x_test, "Test")

# Display statistics for both splits (values are on the raw 0-255 scale)
for split_name, split_stats in [("Training", train_stats), ("Test", test_stats)]:
    print(f"🎨 Pixel Value Statistics ({split_name})")
    print("=" * 60)
    print(f"{'Channel':<10} {'Mean':<8} {'Std':<8} {'Min':<5} {'Max':<5} {'Median':<8}")
    print("-" * 60)
    for channel in ['Red', 'Green', 'Blue', 'Overall']:
        stats = split_stats[channel]
        print(f"{channel:<10} {stats['mean']:<8.1f} {stats['std']:<8.1f} "
              f"{stats['min']:<5.0f} {stats['max']:<5.0f} {stats['median']:<8.1f}")
    print()

# Visualize pixel value distributions
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Overall pixel distribution
axes[0, 0].hist(x_train.flatten(), bins=50, alpha=0.7, color='skyblue', 
               edgecolor='black', density=True)
axes[0, 0].set_title('Overall Pixel Value Distribution (Training)')
axes[0, 0].set_xlabel('Pixel Value')
axes[0, 0].set_ylabel('Density')
axes[0, 0].grid(True, alpha=0.3)

# Per-channel distributions
colors = ['red', 'green', 'blue']
for i, (color, channel) in enumerate(zip(colors, ['Red', 'Green', 'Blue'])):
    channel_data = x_train[:, :, :, i].flatten()
    axes[0, 1].hist(channel_data, bins=50, alpha=0.5, color=color, 
                   label=channel, density=True)

axes[0, 1].set_title('Per-Channel Pixel Distributions')
axes[0, 1].set_xlabel('Pixel Value')
axes[0, 1].set_ylabel('Density')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Mean pixel values per class
class_means = []
for class_idx in range(len(class_names)):
    class_mask = (y_train.flatten() == class_idx)
    class_images = x_train[class_mask]
    class_mean = np.mean(class_images)
    class_means.append(class_mean)

axes[1, 0].bar(range(len(class_names)), class_means, color='lightgreen', 
              alpha=0.8, edgecolor='black')
axes[1, 0].set_title('Mean Pixel Value per Class')
axes[1, 0].set_xlabel('Class')
axes[1, 0].set_ylabel('Mean Pixel Value')
axes[1, 0].set_xticks(range(len(class_names)))
axes[1, 0].set_xticklabels(class_names, rotation=45, ha='right')
axes[1, 0].grid(True, alpha=0.3)

# Standard deviation per class
class_stds = []
for class_idx in range(len(class_names)):
    class_mask = (y_train.flatten() == class_idx)
    class_images = x_train[class_mask]
    class_std = np.std(class_images)
    class_stds.append(class_std)

axes[1, 1].bar(range(len(class_names)), class_stds, color='orange', 
              alpha=0.8, edgecolor='black')
axes[1, 1].set_title('Pixel Value Std Dev per Class')
axes[1, 1].set_xlabel('Class')
axes[1, 1].set_ylabel('Standard Deviation')
axes[1, 1].set_xticks(range(len(class_names)))
axes[1, 1].set_xticklabels(class_names, rotation=45, ha='right')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

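The per-channel loop in `analyze_pixel_statistics` flattens each channel separately. The same means and standard deviations can be obtained in one call by reducing over the image, row and column axes, leaving one value per channel. A minimal sketch (the helper `channel_stats` and the tiny `demo` batch are hypothetical); passing `dtype=np.float64` makes the reductions accumulate in float64 rather than in the uint8 input type.

```python
import numpy as np

def channel_stats(images):
    """Per-channel mean and std of an (N, H, W, C) array, accumulated in float64."""
    images = np.asarray(images)
    mean = images.mean(axis=(0, 1, 2), dtype=np.float64)  # reduce over N, H, W
    std = images.std(axis=(0, 1, 2), dtype=np.float64)
    return mean, std

# Hypothetical 2-image batch: red channel constant 255, green 255 in image 0 only, blue 0
demo = np.zeros((2, 2, 2, 3), dtype=np.uint8)
demo[..., 0] = 255
demo[0, ..., 1] = 255
mean, std = channel_stats(demo)
print(mean, std)
```

On the full training set `x_train.std(...)` still creates a temporary float64 array of roughly 1.2 GB, so computing channel by channel (as the cell above does) may be preferable on memory-constrained machines.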
## 6. Image Complexity Analysis

In [None]:
# Analyze image complexity using various metrics
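def calculate_image_complexity(images, sample_size=1000, edge_threshold=20.0):
    """
    Calculate various complexity metrics for a random sample of images.
    edge_threshold is an arbitrary gradient-magnitude cutoff (0-255 scale)
    above which a pixel is counted as an edge.
    """
    # Sample random images for analysis
    indices = np.random.choice(len(images), min(sample_size, len(images)), replace=False)
    sample_images = images[indices]
    
    complexities = []
    
    for img in sample_images:
        # Convert to grayscale (ITU-R BT.601 luma weights) for gradient metrics
        gray = np.dot(img[..., :3], [0.2989, 0.5870, 0.1140])
        
        # np.gradient returns one derivative array per axis; combine them into a magnitude
        grad_rows, grad_cols = np.gradient(gray)
        grad_mag = np.hypot(grad_rows, grad_cols)
        
        # Calculate various complexity metrics
        metrics = {
            'variance': np.var(img),
            'gradient_magnitude': np.mean(grad_mag),
            # Fraction of pixels whose gradient magnitude exceeds edge_threshold
            'edge_density': np.mean(grad_mag > edge_threshold),
            'color_diversity': len(np.unique(img.reshape(-1, 3), axis=0)),
            'brightness': np.mean(img)
        }
        complexities.append(metrics)
    
    return complexities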

# Calculate complexity for sample images
print("üîç Analyzing image complexity...")
complexities = calculate_image_complexity(x_train, sample_size=1000)

# Convert to DataFrame for easier analysis
df_complexity = pd.DataFrame(complexities)

print("\nüìä Image Complexity Statistics")
print(df_complexity.describe())

# Visualize complexity distributions
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

metrics = ['variance', 'gradient_magnitude', 'edge_density', 'color_diversity', 'brightness']
titles = ['Pixel Variance', 'Gradient Magnitude', 'Edge Density', 'Color Diversity', 'Brightness']

for i, (metric, title) in enumerate(zip(metrics, titles)):
    axes[i].hist(df_complexity[metric], bins=30, alpha=0.7, color='purple', edgecolor='black')
    axes[i].set_title(f'{title} Distribution')
    axes[i].set_xlabel(title)
    axes[i].set_ylabel('Frequency')
    axes[i].grid(True, alpha=0.3)

# Correlation heatmap
correlation_matrix = df_complexity.corr()
im = axes[5].imshow(correlation_matrix, cmap='coolwarm', aspect='auto', vmin=-1, vmax=1)
axes[5].set_title('Complexity Metrics Correlation')
axes[5].set_xticks(range(len(metrics)))
axes[5].set_yticks(range(len(metrics)))
axes[5].set_xticklabels(titles, rotation=45, ha='right')
axes[5].set_yticklabels(titles)

# Add correlation values to heatmap
for i in range(len(metrics)):
    for j in range(len(metrics)):
        text = axes[5].text(j, i, f'{correlation_matrix.iloc[i, j]:.2f}',
                           ha="center", va="center", color="black" if abs(correlation_matrix.iloc[i, j]) < 0.5 else "white")

plt.colorbar(im, ax=axes[5], shrink=0.8)
plt.tight_layout()
plt.show()

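On a 2-D array, `np.gradient` returns a list of two arrays (the derivative along rows and along columns), so a gradient magnitude has to combine them explicitly. The sketch below checks that behavior on a hypothetical 4×4 image with a single vertical edge. The helper `gradient_metrics` and the threshold of 20 are assumptions for illustration, not established values.

```python
import numpy as np

def gradient_metrics(gray, edge_threshold=20.0):
    """Per-pixel gradient magnitude, its mean, and the fraction of edge pixels."""
    grad_rows, grad_cols = np.gradient(gray.astype(np.float64))  # one array per axis
    magnitude = np.hypot(grad_rows, grad_cols)                   # per-pixel magnitude
    return magnitude, float(magnitude.mean()), float(np.mean(magnitude > edge_threshold))

# Hypothetical 4x4 image: dark left half, bright right half (a vertical edge)
step = np.zeros((4, 4))
step[:, 2:] = 255
magnitude, mean_mag, edge_density = gradient_metrics(step)
print(magnitude[0])           # central differences spread the edge over two columns
print(mean_mag, edge_density)
```

Because `np.gradient` uses central differences in the interior, a one-pixel step shows up as two columns of half height (127.5), so half of this image counts as "edge".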
## 7. Class-wise Visual Analysis

In [None]:
# Calculate average image for each class
def calculate_class_averages(x_data, y_data, class_names):
    """
    Calculate average image for each class
    """
    class_averages = []
    
    for class_idx in range(len(class_names)):
        class_mask = (y_data.flatten() == class_idx)
        class_images = x_data[class_mask]
        class_average = np.mean(class_images, axis=0)
        class_averages.append(class_average)
    
    return np.array(class_averages)

# Calculate class averages
class_averages = calculate_class_averages(x_train, y_train, class_names)

# Visualize class averages
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for i, (avg_img, class_name) in enumerate(zip(class_averages, class_names)):
    # Normalize for display
    display_img = (avg_img - avg_img.min()) / (avg_img.max() - avg_img.min())
    
    axes[i].imshow(display_img)
    axes[i].set_title(f'{class_name}\n(Average)', fontsize=10)
    axes[i].axis('off')

plt.suptitle('Average Images per Class', fontsize=16)
plt.tight_layout()
plt.show()

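# Compare classes via their average images (a coarse proxy for separability:
# distinct averages do not guarantee that individual images are separable)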
def calculate_class_distances(class_averages):
    """
    Calculate pairwise distances between class averages
    """
    n_classes = len(class_averages)
    distances = np.zeros((n_classes, n_classes))
    
    for i in range(n_classes):
        for j in range(n_classes):
            # Calculate Euclidean distance
            distances[i, j] = np.linalg.norm(class_averages[i] - class_averages[j])
    
    return distances

# Calculate distances
class_distances = calculate_class_distances(class_averages)

# Visualize class distance matrix
plt.figure(figsize=(10, 8))
im = plt.imshow(class_distances, cmap='viridis')
plt.colorbar(im, label='Euclidean Distance')
plt.title('Pairwise Distances Between Class Averages')
plt.xlabel('Class')
plt.ylabel('Class')
plt.xticks(range(len(class_names)), class_names, rotation=45, ha='right')
plt.yticks(range(len(class_names)), class_names)

# Add distance values to heatmap
for i in range(len(class_names)):
    for j in range(len(class_names)):
        text = plt.text(j, i, f'{class_distances[i, j]:.0f}',
                       ha="center", va="center", 
                       color="white" if class_distances[i, j] > np.mean(class_distances) else "black")

plt.tight_layout()
plt.show()

# Find most and least similar class pairs
# Set diagonal to infinity to ignore self-comparisons
distance_copy = class_distances.copy()
np.fill_diagonal(distance_copy, np.inf)

# Find most similar (smallest distance)
min_idx = np.unravel_index(np.argmin(distance_copy), distance_copy.shape)
most_similar = (class_names[min_idx[0]], class_names[min_idx[1]], distance_copy[min_idx])

# Find least similar (largest distance)
max_idx = np.unravel_index(np.argmax(distance_copy), distance_copy.shape)
least_similar = (class_names[max_idx[0]], class_names[max_idx[1]], distance_copy[max_idx])

print(f"\nüîç Class Similarity Analysis")
print(f"Most similar classes: {most_similar[0]} ‚Üî {most_similar[1]} (distance: {most_similar[2]:.1f})")
print(f"Least similar classes: {least_similar[0]} ‚Üî {least_similar[1]} (distance: {least_similar[2]:.1f})")
print(f"Average inter-class distance: {np.mean(distance_copy[distance_copy != np.inf]):.1f}")

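`calculate_class_distances` fills the matrix with a double loop and computes every pair twice. With broadcasting, all pairwise Euclidean distances come out of a single expression; for ten flattened 32×32×3 averages the intermediate `(10, 10, 3072)` array is only a few megabytes. A minimal sketch with hypothetical 2-D "averages" whose distances are easy to check by hand; `pairwise_euclidean` is a hypothetical helper.

```python
import numpy as np

def pairwise_euclidean(vectors):
    """All pairwise Euclidean distances between the rows of an (n, ...) array."""
    flat = np.asarray(vectors, dtype=np.float64).reshape(len(vectors), -1)
    diffs = flat[:, None, :] - flat[None, :, :]   # (n, n, d) via broadcasting
    return np.linalg.norm(diffs, axis=-1)

# Hypothetical class averages flattened to 2 values each (3-4-5 triangles)
demo = np.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
dist = pairwise_euclidean(demo)
print(dist)
```

Called as `pairwise_euclidean(class_averages)`, this should give the same matrix as the loop version.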
## 8. Data Quality Assessment

In [None]:
# Check for potential data quality issues
def assess_data_quality(x_data, y_data):
    """
    Assess various data quality metrics
    """
    issues = []
    
    # Check for missing values
    if np.isnan(x_data).any():
        nan_count = np.isnan(x_data).sum()
        issues.append(f"Found {nan_count} NaN values in images")
    
    # Check for infinite values
    if np.isinf(x_data).any():
        inf_count = np.isinf(x_data).sum()
        issues.append(f"Found {inf_count} infinite values in images")
    
    # Check for completely black images
    black_images = np.sum(x_data, axis=(1, 2, 3)) == 0
    if black_images.any():
        black_count = black_images.sum()
        issues.append(f"Found {black_count} completely black images")
    
    # Check for completely white images
    white_images = np.all(x_data == 255, axis=(1, 2, 3))
    if white_images.any():
        white_count = white_images.sum()
        issues.append(f"Found {white_count} completely white images")
    
    # Check for duplicate images
    unique_images = np.unique(x_data.reshape(len(x_data), -1), axis=0)
    if len(unique_images) < len(x_data):
        duplicate_count = len(x_data) - len(unique_images)
        issues.append(f"Found {duplicate_count} duplicate images")
    
    # Check label consistency
    if y_data.min() < 0 or y_data.max() >= len(class_names):
        issues.append(f"Labels out of expected range [0, {len(class_names)-1}]")
    
    return issues

# Assess data quality
print("üîç Data Quality Assessment")
print("=" * 40)

train_issues = assess_data_quality(x_train, y_train)
test_issues = assess_data_quality(x_test, y_test)

if not train_issues and not test_issues:
    print("‚úÖ No data quality issues detected!")
else:
    if train_issues:
        print("‚ö†Ô∏è Training set issues:")
        for issue in train_issues:
            print(f"  - {issue}")
    
    if test_issues:
        print("‚ö†Ô∏è Test set issues:")
        for issue in test_issues:
            print(f"  - {issue}")

# Additional quality metrics
print(f"\n📊 Additional Quality Metrics")
print(f"Training set memory usage: {x_train.nbytes / (1024**2):.1f} MB")
print(f"Test set memory usage: {x_test.nbytes / (1024**2):.1f} MB")
print(f"Average images per class (train): {len(x_train) // len(class_names)}")
print(f"Average images per class (test): {len(x_test) // len(class_names)}")
print(f"Train/test split ratio: {len(x_train) / len(x_test):.1f}:1")

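`assess_data_quality` looks for duplicates within one split at a time, so it cannot detect an image that appears in both the training and test sets. A minimal sketch of that cross-split check: hash each image's raw bytes and look up test hashes in the training set. This finds exact pixel-for-pixel copies only; near-duplicates (re-encoded or slightly shifted images) would need a perceptual hash, which is not shown. The helpers `image_hashes` and `cross_split_duplicates` are hypothetical.

```python
import hashlib
import numpy as np

def image_hashes(images):
    """SHA-1 digest of each image's raw bytes (exact matches only)."""
    return [hashlib.sha1(np.ascontiguousarray(img).tobytes()).hexdigest() for img in images]

def cross_split_duplicates(train_images, test_images):
    """Indices of test images whose exact pixel content also appears in the training set."""
    train_set = set(image_hashes(train_images))
    return [i for i, h in enumerate(image_hashes(test_images)) if h in train_set]

# Hypothetical tiny splits: test image 1 is an exact copy of train image 0
train_demo = np.arange(2 * 4 * 4 * 3, dtype=np.uint8).reshape(2, 4, 4, 3)
test_demo = np.stack([np.zeros((4, 4, 3), dtype=np.uint8), train_demo[0]])
print(cross_split_duplicates(train_demo, test_demo))
```

Both splits must share the same dtype for byte-level hashes to match, which holds for the uint8 arrays returned by `load_data()`.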
## 9. Preprocessing Recommendations

In [None]:
# Generate preprocessing recommendations based on analysis
def generate_preprocessing_recommendations(x_data, stats):
    """
    Generate preprocessing recommendations based on data analysis
    """
    recommendations = []
    
    # Normalization recommendation
    pixel_range = x_data.max() - x_data.min()
    if pixel_range > 1:
        recommendations.append(
            "🔧 Normalize pixel values to the [0, 1] range by dividing by 255"
        )
    
    # Standardization recommendation (statistics are on the raw 0-255 scale)
    overall_mean = stats['Overall']['mean']
    overall_std = stats['Overall']['std']
    recommendations.append(
        f"📊 Consider standardization with training-set statistics: "
        f"mean={overall_mean:.1f}, std={overall_std:.1f} on the 0-255 scale "
        f"(divide both by 255 for inputs already in [0, 1]); per-channel statistics are a common alternative"
    )
    
    # Data augmentation recommendation
    recommendations.append(
        "🔄 Apply data augmentation: small rotations and shifts, horizontal flips, zoom, brightness adjustment"
    )
    
    # Resize recommendation, only relevant for ImageNet-style pretrained backbones
    current_size = x_data.shape[1:3]
    if current_size != (224, 224):
        recommendations.append(
            f"📏 If fine-tuning an ImageNet-pretrained model, resize images to 224×224 "
            f"(current: {current_size[0]}×{current_size[1]}); this greatly increases memory and compute per image"
        )
    
    # Data type recommendation
    if x_data.dtype != np.float32:
        recommendations.append(
            "💾 Cast to float32 before normalizing: dividing uint8 data by 255.0 yields float64 "
            "(twice the memory of float32), and float32 is Keras' default floating-point type"
        )
    
    return recommendations

# Generate recommendations
recommendations = generate_preprocessing_recommendations(x_train, train_stats)

print("üí° Preprocessing Recommendations")
print("=" * 50)
for i, rec in enumerate(recommendations, 1):
    print(f"{i}. {rec}")

# Show example of normalization effect
sample_image = x_train[0]
normalized_image = sample_image.astype(np.float32) / 255.0

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].imshow(sample_image)
axes[0].set_title(f'Original\nRange: [{sample_image.min()}, {sample_image.max()}]')
axes[0].axis('off')

axes[1].imshow(normalized_image)
axes[1].set_title(f'Normalized\nRange: [{normalized_image.min():.2f}, {normalized_image.max():.2f}]')
axes[1].axis('off')

plt.suptitle('Normalization Example')
plt.tight_layout()
plt.show()

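Standardization only avoids information leakage if its statistics come from the training set and the *same* values are reused for validation and test data. A minimal sketch, assuming per-channel statistics computed after scaling to [0, 1]; `fit_channel_standardizer`, `standardize`, and the random demo arrays are hypothetical, and `eps` is a small guard against division by zero.

```python
import numpy as np

def fit_channel_standardizer(train_images):
    """Per-channel mean/std of [0, 1]-scaled training images, shape (C,) each."""
    scaled = train_images.astype(np.float32) / 255.0
    return scaled.mean(axis=(0, 1, 2)), scaled.std(axis=(0, 1, 2))

def standardize(images, mean, std, eps=1e-7):
    """Scale to [0, 1] in float32, then standardize each channel with the given stats."""
    scaled = images.astype(np.float32) / 255.0
    return (scaled - mean) / (std + eps)

# Hypothetical uint8 data standing in for x_train / x_test
rng = np.random.default_rng(0)
train_demo = rng.integers(0, 256, size=(20, 8, 8, 3)).astype(np.uint8)
test_demo = rng.integers(0, 256, size=(5, 8, 8, 3)).astype(np.uint8)

mean, std = fit_channel_standardizer(train_demo)  # statistics from training data only
train_std = standardize(train_demo, mean, std)
test_std = standardize(test_demo, mean, std)      # reuse the training statistics
print(mean, std, train_std.dtype)
```

Casting with `astype(np.float32)` before dividing keeps the result in float32; the training output then has roughly zero mean and unit variance per channel, while the test output generally will not, which is expected.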
## 📋 Summary and Next Steps

### Key Findings:

1. **Dataset Structure**: 
   - 50,000 training images, 10,000 test images
   - 10 balanced classes with 5,000 training images each
   - 32×32×3 RGB images

2. **Data Quality**: 
   - No missing or invalid values detected
   - Balanced class distribution
   - Pixel values in [0, 255] range

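3. **Complexity Analysis**:
   - Complexity metrics (variance, gradient magnitude, edge density, color diversity, brightness) vary across the sampled training images
   - Distances between class-average images suggest that some class pairs are closer than others, although average images are only a coarse proxy for separability
   - Color diversity and pixel variance span a wide range, indicating varied color and texture patterns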

### Preprocessing Pipeline:

Based on our analysis, the recommended preprocessing steps are:

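1. **Data Type**: Cast pixel arrays from uint8 to float32 (dividing uint8 data by 255.0 directly would produce float64)
2. **Normalization**: Scale pixel values to [0, 1], optionally followed by standardization with training-set statistics
3. **Data Augmentation**: Apply random transformations (e.g., horizontal flips, small shifts) to the training set only
4. **Resizing**: Resize to 224×224 only when fine-tuning ImageNet-pretrained models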
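Augmentation is usually applied on the fly, per batch, and only to training data. A minimal NumPy sketch of one such transformation, a random horizontal flip on a batch shaped `(N, H, W, C)`; the function `random_horizontal_flip` is hypothetical. In recent TensorFlow versions, the Keras preprocessing layer `RandomFlip("horizontal")` serves the same purpose inside a model. Vertical flips are usually avoided for CIFAR-10, since objects such as ships and horses normally appear upright.

```python
import numpy as np

def random_horizontal_flip(batch, rng, p=0.5):
    """Flip each (H, W, C) image left-right with probability p; returns a new array."""
    flip = rng.random(len(batch)) < p        # one coin toss per image
    out = batch.copy()                       # never modify the caller's array
    out[flip] = out[flip][:, :, ::-1, :]     # reverse the width axis
    return out

# Hypothetical 2-image batch of 2x3 single-channel images
rng = np.random.default_rng(0)
batch = np.arange(2 * 2 * 3 * 1, dtype=np.uint8).reshape(2, 2, 3, 1)
always = random_horizontal_flip(batch, rng, p=1.0)
never = random_horizontal_flip(batch, rng, p=0.0)
print(always[0, :, :, 0])
```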

### Expected Challenges:

1. **Low Resolution**: 32√ó32 images have limited detail
2. **Similar Classes**: Some classes (e.g., cat/dog) may be harder to distinguish
3. **Intra-class Variation**: High variation within classes

### Next Steps:

1. **Implement Preprocessing Pipeline** → `02_preprocessing.ipynb`
2. **Build CNN Architecture** → `03_model_building.ipynb`
3. **Train and Evaluate Models** → `04_training.ipynb`

---

**Ready to move to preprocessing?** The next notebook will implement the preprocessing pipeline based on our findings here! 🚀