# 02. Dataset Exploration

This notebook explores the MNIST and CUB-200 datasets used in federated learning experiments.

**Topics covered:**
- MNIST dataset loading and visualization
- CUB-200 dataset loading and visualization
- IID vs Non-IID data partitioning

## Setup

In [None]:
import sys
import os

# When the kernel starts inside the notebooks folder, step up one level so the
# project root becomes the working directory.
launched_from_notebooks = os.path.basename(os.getcwd()) == 'notebooks'
if launched_from_notebooks:
    os.chdir(os.pardir)

# Put the project root at the front of sys.path (only once) so the `src.*`
# imports in the following cells resolve.
PROJECT_ROOT = os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from src.utils.data_loader import load_mnist, get_client_data, get_class_distribution

# Apply the 'seaborn-v0_8-whitegrid' style sheet to the matplotlib figures
# created after this point.
plt.style.use('seaborn-v0_8-whitegrid')
# IPython magic: show matplotlib figures inline in the notebook output.
%matplotlib inline

---
## 1. MNIST Dataset

In [None]:
# Fetch the MNIST train/test splits from ./data (auto-downloads if not present).
train_data, test_data = load_mnist("./data")

print(f"Training samples: {len(train_data):,}")
print(f"Test samples: {len(test_data):,}")

# Each dataset item is an (image, label) pair; the first image gives the shape.
sample_image = train_data[0][0]
print(f"Image shape: {sample_image.shape}")
print(f"Number of classes: 10 (digits 0-9)")

In [None]:
# Show the first 20 training digits on a 2 x 10 grid, each titled with its label.
fig, axes = plt.subplots(2, 10, figsize=(15, 4))
fig.suptitle("MNIST Sample Images", fontsize=14)

# axes.flat walks the grid row by row, so grid position k shows training sample k.
for sample_idx, ax in enumerate(axes.flat):
    img, label = train_data[sample_idx]
    ax.imshow(img.squeeze(), cmap='gray')
    ax.set_title(f"{label}")
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Count how many training images fall into each digit class.
labels = np.array(train_data.targets)
unique, counts = np.unique(labels, return_counts=True)

# Bar chart of the per-class counts, drawn through the explicit Axes interface.
fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(unique, counts, color='steelblue')
ax.set_xlabel('Digit Class')
ax.set_ylabel('Count')
ax.set_title('MNIST Training Set Class Distribution')
ax.set_xticks(unique)
plt.show()

# Same numbers as plain text, one line per class.
print("\nClass counts:")
for digit, n_samples in zip(unique, counts):
    print(f"  Class {digit}: {n_samples:,} samples")

---
## 2. IID vs Non-IID Partitioning

In federated learning, the training data is spread across clients. Each client's share can be:
- **IID**: Each client has a similar class distribution (random split)
- **Non-IID**: Clients have different class distributions, sampled from a Dirichlet distribution with concentration parameter α (smaller α means more heterogeneity)

In [None]:
NUM_CLIENTS = 10

# IID partitioning: gather one class-count summary per client.
banner = "=" * 50
print(banner)
print("IID PARTITIONING")
print(banner)

iid_distributions = []
for cid in range(NUM_CLIENTS):
    shard = get_client_data(train_data, cid, NUM_CLIENTS, partition="iid")
    iid_distributions.append(get_class_distribution(shard))
    print(f"Client {cid}: {len(shard):,} samples")

In [None]:
# Non-IID Partitioning (α = 0.5)
print("=" * 50)
print("NON-IID PARTITIONING (α = 0.5, moderate heterogeneity)")
print("=" * 50)

# One class-count mapping per client; consumed later by the heatmap cell.
noniid_moderate = []
for client_id in range(NUM_CLIENTS):
    client_data = get_client_data(train_data, client_id, NUM_CLIENTS, partition="noniid", alpha=0.5)
    dist = get_class_distribution(client_data)
    noniid_moderate.append(dist)
    # A skewed split may leave a client with no samples; `default` keeps max()
    # from raising ValueError on an empty distribution.
    dominant = max(dist, key=dist.get, default="n/a")
    print(f"Client {client_id}: {len(client_data):,} samples, dominant: {dominant}")

In [None]:
# Non-IID Partitioning (α = 0.1)
print("=" * 50)
print("NON-IID PARTITIONING (α = 0.1, extreme heterogeneity)")
print("=" * 50)

# One class-count mapping per client; consumed later by the heatmap cell.
noniid_extreme = []
for client_id in range(NUM_CLIENTS):
    client_data = get_client_data(train_data, client_id, NUM_CLIENTS, partition="noniid", alpha=0.1)
    dist = get_class_distribution(client_data)
    noniid_extreme.append(dist)
    # A heavily skewed split may leave a client with no samples; `default`
    # keeps max() from raising ValueError on an empty distribution.
    dominant = max(dist, key=dist.get, default="n/a")
    print(f"Client {client_id}: {len(client_data):,} samples, dominant: {dominant}")

In [None]:
# Visualize partitioning differences
def plot_distribution(distributions, title, num_classes=10):
    """Plot a client-by-class heatmap of each client's class percentages.

    Args:
        distributions: Sequence with one mapping per client; each mapping goes
            from a class index (used as a column index) to a sample count.
        title: Text placed before the " (% of each client's data)" suffix.
        num_classes: Number of class columns in the heatmap. Defaults to 10.

    Every row is normalized to percentages of that client's total. A client
    without any samples is drawn as a row of zeros instead of NaN.
    """
    matrix = np.zeros((len(distributions), num_classes))
    for i, dist in enumerate(distributions):
        for cls, count in dist.items():
            matrix[i, cls] = count

    # Normalize to percentages. Rows whose total is zero are skipped by
    # `where`, so they keep the zeros from `out` and no division by zero occurs.
    row_totals = matrix.sum(axis=1, keepdims=True)
    percentages = np.divide(
        matrix, row_totals, out=np.zeros_like(matrix), where=row_totals != 0
    ) * 100

    fig, ax = plt.subplots(figsize=(10, 4))
    sns.heatmap(percentages, annot=True, fmt='.0f', cmap='Blues',
                xticklabels=range(num_classes),
                yticklabels=[f"Client {i}" for i in range(len(distributions))],
                ax=ax)
    ax.set_xlabel('Class')
    ax.set_ylabel('Client')
    ax.set_title(f"{title} (% of each client's data)")
    plt.show()

# Draw one heatmap per partitioning scheme, in the order they were computed.
for client_distributions, heading in (
    (iid_distributions, "IID Distribution"),
    (noniid_moderate, "Non-IID (α=0.5) Distribution"),
    (noniid_extreme, "Non-IID (α=0.1) Distribution"),
):
    plot_distribution(client_distributions, heading)

---
## 3. CUB-200 Dataset

CUB-200 is a fine-grained bird classification dataset with 200 species.

**Note:** CUB-200 requires manual download. See `RUNPOD_SETUP_GUIDE.md` for instructions.

In [None]:
from src.utils.cub200_loader import load_cub200, CUB200Dataset

# Try to load the CUB-200 splits from ./data. download=False is passed, so the
# files are presumably expected to be on disk already (see the note above).
try:
    train_cub, test_cub = load_cub200("./data", download=False)
    print(f"CUB-200 Training samples: {len(train_cub):,}")
    print(f"CUB-200 Test samples: {len(test_cub):,}")
    print(f"Number of classes: 200 bird species")
    # Flag checked by the visualization cell to decide whether to plot.
    CUB_AVAILABLE = True
except RuntimeError as e:
    # Only RuntimeError is handled here: report it, point to the setup guide,
    # and let the rest of the notebook run without CUB-200.
    print(f"CUB-200 not available: {e}")
    print("\nTo download CUB-200, follow the instructions in RUNPOD_SETUP_GUIDE.md")
    CUB_AVAILABLE = False

In [None]:
# Visualize CUB-200 samples (if available)
if CUB_AVAILABLE:
    # Per-channel statistics used to undo the input normalization
    # (loop-invariant, so built once up front).
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])

    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    fig.suptitle("CUB-200 Sample Images", fontsize=14)

    # axes.flat walks the 2 x 5 grid row by row.
    for slot, ax in enumerate(axes.flat):
        # Sample from different classes by stepping through the data in strides of 100.
        img, label = train_cub[slot * 100]

        # Move the first axis last (presumably channels-first -> channels-last),
        # then denormalize and clamp to [0, 1] for display.
        pixels = img.permute(1, 2, 0).numpy()
        pixels = np.clip(pixels * channel_std + channel_mean, 0, 1)

        ax.imshow(pixels)
        ax.set_title(f"Class {label}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

---
## Summary

| Dataset | Classes | Train Samples | Test Samples | Use Case |
|---------|---------|---------------|--------------|----------|
| MNIST | 10 | 60,000 | 10,000 | Fast experiments, baseline |
| CUB-200 | 200 | ~5,994 | ~5,794 | Complex, realistic scenario |

Proceed to **03_quick_experiment.ipynb** to run a quick test.