# MNIST Data Exploration
## Hybrid Quantum-Classical Machine Learning Project

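This notebook loads MNIST through the project's `MNISTDataModule`, visualizes sample digits, checks the class balance and pixel statistics of the training data, and verifies that batches have the shape the model expects.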

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from src.data.dataset import MNISTDataModule, get_mnist_dataloaders

# Set style
# NOTE: the 'seaborn-v0_8-*' style names exist only in matplotlib >= 3.6;
# on older releases this call raises because the style is unknown.
plt.style.use('seaborn-v0_8-darkgrid')
# Render figures inline in the cell output.
%matplotlib inline

## 1. Load MNIST Dataset

In [None]:
# Build the data module; the settings are collected in one dict so they are
# easy to spot and tweak, then passed through as keyword arguments.
data_module_kwargs = {
    'data_dir': '../data/raw',
    'batch_size': 64,
    'validation_split': 0.1,
    'seed': 42,
}
data_module = MNISTDataModule(**data_module_kwargs)

# Prepare the datasets before requesting any loaders.
data_module.setup()

# One loader per split.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

# Report the number of batches each loader yields.
for split_name, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{split_name} batches: {len(split_loader)}")

## 2. Visualize Sample Images

In [None]:
# Fetch one batch from the data module and summarize it.
images, labels = data_module.get_sample_batch()

# Compute the extremes once, then emit the whole report in a single print.
lowest_pixel, highest_pixel = images.min(), images.max()
batch_report = [
    f"Batch shape: {images.shape}",
    f"Labels shape: {labels.shape}",
    f"Image value range: [{lowest_pixel:.3f}, {highest_pixel:.3f}]",
]
print("\n".join(batch_report))

In [None]:
# Draw the first 16 images of the sample batch on a 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
axes = axes.reshape(-1)  # flatten the 4x4 array of Axes for simple iteration

for idx, panel in enumerate(axes[:16]):
    # Undo the normalization for display: x * 0.3081 + 0.1307.
    # NOTE(review): presumably these are the std/mean applied in
    # src.data.dataset - confirm they match the pipeline's transform.
    img = images[idx].squeeze().numpy() * 0.3081 + 0.1307

    panel.imshow(img, cmap='gray')
    panel.set_title(f'Label: {labels[idx].item()}')
    panel.axis('off')

plt.tight_layout()
plt.show()

## 3. Dataset Statistics

In [None]:
# Count class distribution in training set.
#
# The per-batch labels are gathered under a comprehension-scoped name so the
# notebook-level `labels` tensor (from the sample batch) is left untouched;
# overwriting it would make a later re-run of the image grid cell title the
# sample images with unrelated labels. `train_labels` is also bound exactly
# once, directly to the final NumPy array.
train_labels = torch.cat(
    [batch_labels for _, batch_labels in train_loader]
).numpy()

# `unique` holds the distinct label values, `counts` how often each occurs.
unique, counts = np.unique(train_labels, return_counts=True)

print("Class distribution in training set:")
for digit, count in zip(unique, counts):
    print(f"  Digit {digit}: {count} samples ({count/len(train_labels)*100:.2f}%)")

In [None]:
# Bar chart of samples per digit class, drawn through an explicit Axes handle.
dist_fig, dist_ax = plt.subplots(figsize=(10, 6))

dist_ax.bar(unique, counts, color='skyblue', edgecolor='navy')
dist_ax.set_xlabel('Digit Class', fontsize=12)
dist_ax.set_ylabel('Number of Samples', fontsize=12)
dist_ax.set_title('MNIST Training Set Class Distribution', fontsize=14, fontweight='bold')
dist_ax.set_xticks(unique)  # one tick per class value
dist_ax.grid(axis='y', alpha=0.3)

dist_fig.tight_layout()
plt.show()

## 4. Image Statistics

In [None]:
# Pixel-level statistics of one sample batch.
sample_images, _ = data_module.get_sample_batch()
pixel_values = sample_images.numpy().flatten()

# Two side-by-side panels: all pixel values (left), per-image means (right).
stats_fig, (pixel_ax, mean_ax) = plt.subplots(1, 2, figsize=(12, 5))

pixel_ax.hist(pixel_values, bins=50, color='purple', alpha=0.7, edgecolor='black')
pixel_ax.set_xlabel('Pixel Value (normalized)', fontsize=11)
pixel_ax.set_ylabel('Frequency', fontsize=11)
pixel_ax.set_title('Distribution of Pixel Values', fontsize=12, fontweight='bold')
pixel_ax.grid(alpha=0.3)

# Average over every non-batch dimension to get one mean per image.
mean_per_image = sample_images.mean(dim=[1, 2, 3]).numpy()
mean_ax.hist(mean_per_image, bins=30, color='green', alpha=0.7, edgecolor='black')
mean_ax.set_xlabel('Mean Pixel Value', fontsize=11)
mean_ax.set_ylabel('Frequency', fontsize=11)
mean_ax.set_title('Distribution of Image Means', fontsize=12, fontweight='bold')
mean_ax.grid(alpha=0.3)

stats_fig.tight_layout()
plt.show()

# Summary numbers for the flattened pixel values.
print("Pixel value statistics:")
for stat_name, stat_value in (
    ("Mean", pixel_values.mean()),
    ("Std", pixel_values.std()),
    ("Min", pixel_values.min()),
    ("Max", pixel_values.max()),
):
    print(f"  {stat_name}: {stat_value:.4f}")

## 5. Verify Data Pipeline

In [None]:
# Test data pipeline with multiple batches.
#
# Three batches are pulled from the training loader, described, and checked.
# Batch-local names are used so the notebook-level `images` / `labels`
# tensors are not overwritten by this cell.
print("Testing data pipeline...")

num_batches_to_check = 3

# zip() asks range() for the next index first and stops as soon as it is
# exhausted, so exactly `num_batches_to_check` batches are loaded (an
# enumerate + break loop would load one extra batch just to discard it).
for i, (batch_images, batch_labels) in zip(range(num_batches_to_check), train_loader):
    print(f"\nBatch {i+1}:")
    print(f"  Images shape: {batch_images.shape}")
    print(f"  Labels shape: {batch_labels.shape}")
    print(f"  Device: {batch_images.device}")
    print(f"  Dtype: {batch_images.dtype}")
    print(f"  Sample labels: {batch_labels[:10].tolist()}")

    # Actually verify the batch before declaring success below.
    assert batch_images.shape[0] == batch_labels.shape[0], "Batch size mismatch!"
    assert batch_labels.min() >= 0 and batch_labels.max() <= 9, "Label outside digit range 0-9!"

print("\n✓ Data pipeline working correctly!")

## 6. Prepare for Model Input

In [None]:
# Take one training batch and confirm it has the layout the model consumes.
images, labels = next(iter(train_loader))

format_report = [
    "Data format for model:",
    f"  Input shape: {images.shape}  # (batch_size, channels, height, width)",
    "  Expected by ResNet18: [N, 1, 28, 28] ✓",
    f"  Labels shape: {labels.shape}  # (batch_size,)",
    "  Number of classes: 10 (digits 0-9)",
]
print("\n".join(format_report))

# Every sample must be a single-channel 28x28 image, and there must be
# exactly one label per image.
expected_sample_shape = (1, 28, 28)
assert tuple(images.shape[1:]) == expected_sample_shape, "Image shape incorrect!"
assert len(labels) == len(images), "Batch size mismatch!"

print("\n✓ Data format verified for model input!")

## Summary

✅ **Dataset loaded successfully**
- Training samples: ~54,000
- Validation samples: ~6,000
- Test samples: 10,000

✅ **Class distribution is balanced** (roughly 10% per class)

✅ **Images are properly normalized** (pixel mean ≈ 0, std ≈ 1 — see the statistics printed in Section 4)

✅ **Data pipeline is working** and ready for model training

✅ **Input format verified** for ResNet18 + Quantum circuit integration

**Next steps:**
1. Implement quantum circuit
2. Build hybrid model architecture
3. Start training!