# MicroGrad+ MNIST Example

This notebook demonstrates how to train a neural network on MNIST using MicroGrad+.

**Goal**: Train a multi-layer perceptron to achieve >95% accuracy on MNIST.

---

## 1. Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import gzip
import socket
from urllib import request
from pathlib import Path

# Robust path resolution - works regardless of working directory
def _find_module_root():
    """Find the module root directory containing micrograd_plus."""
    current = Path.cwd()
    for parent in [current] + list(current.parents):
        if (parent / 'micrograd_plus' / '__init__.py').exists():
            return str(parent)
    return str(Path.cwd().parent)

# Put the directory containing the micrograd_plus package at the front of the
# import path so the notebook works regardless of the directory it is launched
# from. This must run before the micrograd_plus imports below.
sys.path.insert(0, _find_module_root())

from micrograd_plus import (
    Tensor, Linear, ReLU, Dropout, Sequential,
    CrossEntropyLoss, Adam
)
from micrograd_plus.utils import set_seed, DataLoader

# Set seed for reproducibility
# NOTE(review): presumably seeds the RNG(s) used for weight init, Dropout and
# DataLoader shuffling — confirm against micrograd_plus.utils.set_seed.
set_seed(42)

print("MicroGrad+ MNIST Example")
print("=" * 50)

## 2. Load MNIST Dataset

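We'll load MNIST from local files if they are already present. Otherwise we download it from mirror URLs, falling back to sklearn or keras. Each image is flattened to a 784-dimensional vector with pixel values scaled to [0, 1].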

In [None]:
def download_mnist(path=None, timeout=30):
    """
    Download MNIST dataset with fallback URLs and timeout.

    Uses multiple mirror URLs to ensure reliability:
    1. PyTorch/OSSCI S3 mirror (most reliable)
    2. Original yann.lecun.com (backup)

    Files that already exist in ``path`` are skipped. Each missing file is
    first downloaded to ``<name>.part`` and renamed to its final name only
    after the transfer completes. A failed or interrupted download (including
    Ctrl-C) therefore never leaves a truncated archive that a later run would
    mistake for a complete one.

    Args:
        path: Target directory; defaults to ``<module root>/data``. Created if
            it does not exist.
        timeout: Socket timeout in seconds applied while each file downloads.

    Raises:
        RuntimeError: If some file could not be fetched from any mirror.
    """
    # Resolve path relative to module root if not specified
    if path is None:
        path = os.path.join(_find_module_root(), 'data')
    os.makedirs(path, exist_ok=True)

    urls = [
        'https://ossci-datasets.s3.amazonaws.com/mnist/',  # PyTorch mirror (reliable)
        'http://yann.lecun.com/exdb/mnist/',  # Original (sometimes unreliable)
    ]

    files = [
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz'
    ]

    for f in files:
        filepath = os.path.join(path, f)
        if os.path.exists(filepath):
            continue  # already downloaded

        partial = filepath + '.part'
        downloaded = False
        for base_url in urls:
            print(f"Downloading {f} from {base_url}...")
            try:
                # urlretrieve has no timeout argument, so the process-wide
                # default is set temporarily and always restored.
                old_timeout = socket.getdefaulttimeout()
                socket.setdefaulttimeout(timeout)
                try:
                    request.urlretrieve(base_url + f, partial)
                finally:
                    socket.setdefaulttimeout(old_timeout)
                # Publish the archive under its real name only once complete.
                os.replace(partial, filepath)
                downloaded = True
            except Exception as e:
                print(f"  Failed: {e}")
            finally:
                # Runs even on KeyboardInterrupt: drop any half-written file.
                if os.path.exists(partial):
                    os.remove(partial)
            if downloaded:
                print("  Success!")
                break

        if not downloaded:
            raise RuntimeError(
                f"Could not download {f}.\n"
                f"Please download manually from:\n"
                f"  https://ossci-datasets.s3.amazonaws.com/mnist/{f}"
            )

    print("MNIST data ready!")


def load_mnist_local(path=None):
    """Load MNIST from local gzipped IDX files.

    Args:
        path: Directory containing the four ``*-ubyte.gz`` archives.
            Defaults to ``<module root>/data``.

    Returns:
        X_train, X_test: float32 arrays of shape (N, rows*cols) scaled to
            [0, 1] — (N, 784) for standard 28x28 MNIST.
        y_train, y_test: int32 label arrays of shape (N,).

    Raises:
        ValueError: If a file's IDX header has an unexpected magic number, or
            its payload size disagrees with the header (e.g. a truncated or
            wrong file).
    """
    if path is None:
        path = os.path.join(_find_module_root(), 'data')

    def read_idx(filename, expected_magic, ndim):
        """Return (dims, flat uint8 payload) of a gzipped IDX file after validating its header."""
        with gzip.open(os.path.join(path, filename), 'rb') as f:
            raw = f.read()
        # Header: big-endian uint32 magic number, then one uint32 per dimension.
        header_len = 4 * (1 + ndim)
        if len(raw) < header_len:
            raise ValueError(f"{filename}: too short to contain an IDX header")
        header = np.frombuffer(raw, dtype='>u4', count=1 + ndim)
        magic = int(header[0])
        if magic != expected_magic:
            raise ValueError(
                f"{filename}: bad IDX magic number {magic} (expected {expected_magic})"
            )
        dims = tuple(int(d) for d in header[1:])
        payload = np.frombuffer(raw, dtype=np.uint8, offset=header_len)
        expected_size = int(np.prod(dims))
        if payload.size != expected_size:
            raise ValueError(
                f"{filename}: payload has {payload.size} bytes but the header "
                f"declares {expected_size} (truncated or corrupt file?)"
            )
        return dims, payload

    def load_images(filename):
        # 2051 == 0x00000803: unsigned-byte data with 3 dims (count, rows, cols).
        (count, rows, cols), pixels = read_idx(filename, 2051, 3)
        return pixels.reshape(count, rows * cols).astype(np.float32) / 255.0

    def load_labels(filename):
        # 2049 == 0x00000801: unsigned-byte data with 1 dim (count).
        _, labels = read_idx(filename, 2049, 1)
        return labels.astype(np.int32)

    X_train = load_images('train-images-idx3-ubyte.gz')
    y_train = load_labels('train-labels-idx1-ubyte.gz')
    X_test = load_images('t10k-images-idx3-ubyte.gz')
    y_test = load_labels('t10k-labels-idx1-ubyte.gz')

    return X_train, X_test, y_train, y_test


def load_mnist():
    """
    Load MNIST by trying several sources in turn until one succeeds.

    Order of attempts:
    1. Local files (when all four archives already exist in ``<module root>/data``)
    2. Download from the mirrors via ``download_mnist``
    3. sklearn's ``fetch_openml`` (slower; requires scikit-learn)
    4. tensorflow/keras (if installed)

    Any exception raised by sources 2-4 is printed and the next source is
    tried. The local-file path (1) is not wrapped in error handling.

    Returns:
        X_train, X_test: (N, 784) float32 images scaled to [0, 1]
        y_train, y_test: (N,) int32 labels 0-9

    Raises:
        RuntimeError: If every source fails.
    """
    data_dir = os.path.join(_find_module_root(), 'data')
    archives = (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )

    if all(os.path.exists(os.path.join(data_dir, name)) for name in archives):
        print("Loading MNIST from local files...")
        return load_mnist_local(data_dir)

    def via_mirrors():
        print("Downloading MNIST from mirrors...")
        download_mnist(data_dir)
        return load_mnist_local(data_dir)

    def via_sklearn():
        from sklearn.datasets import fetch_openml  # optional dependency
        print("Loading MNIST from sklearn (may be slow on first run)...")
        bunch = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
        pixels = bunch.data.astype(np.float32) / 255.0
        digits = bunch.target.astype(np.int32)
        # First 60,000 rows form the training split, the remainder the test split.
        return pixels[:60000], pixels[60000:], digits[:60000], digits[60000:]

    def via_keras():
        from tensorflow.keras.datasets import mnist  # optional dependency
        print("Loading MNIST from keras...")
        (train_imgs, train_lbls), (test_imgs, test_lbls) = mnist.load_data()
        flat_train = train_imgs.reshape(-1, 784).astype(np.float32) / 255.0
        flat_test = test_imgs.reshape(-1, 784).astype(np.float32) / 255.0
        return flat_train, flat_test, train_lbls.astype(np.int32), test_lbls.astype(np.int32)

    fallbacks = (
        (via_mirrors, "Mirror download failed"),
        (via_sklearn, "sklearn failed"),
        (via_keras, "keras failed"),
    )
    for strategy, failure_label in fallbacks:
        try:
            return strategy()
        except Exception as err:
            print(f"{failure_label}: {err}")

    raise RuntimeError(
        "Could not load MNIST from any source.\n"
        "Please download manually from:\n"
        "  https://ossci-datasets.s3.amazonaws.com/mnist/\n"
        "and place files in the data/ directory."
    )

X_train, X_test, y_train, y_test = load_mnist()

# One print call, one summary line per argument.
print(
    "\nDataset loaded:",
    f"  Training: {X_train.shape[0]} samples",
    f"  Test:     {X_test.shape[0]} samples",
    "  Image shape: 28x28 = 784 pixels",
    "  Classes: 0-9 digits",
    sep="\n",
)

## 3. Visualize Some Examples

In [None]:
# Show the first ten training digits (2x5 grid) with their labels.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))

for panel, pixels, label in zip(axes.flat, X_train, y_train):
    panel.imshow(pixels.reshape(28, 28), cmap='gray')
    panel.set_title(f'Label: {label}')
    panel.axis('off')

plt.suptitle('Sample MNIST Images', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Create Data Loaders

In [None]:
# Number of training images to use: 10000 keeps the demo quick,
# 60000 trains on the full training set (best accuracy).
TRAIN_SIZE = 10000
BATCH_SIZE = 64

X_subset, y_subset = X_train[:TRAIN_SIZE], y_train[:TRAIN_SIZE]
train_loader = DataLoader(X_subset, y_subset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(X_test, y_test, batch_size=BATCH_SIZE, shuffle=False)

for split_name, loader in (("Training", train_loader), ("Test", test_loader)):
    print(f"{split_name} batches: {len(loader)}")

## 5. Define the Model

We'll create a multi-layer perceptron (MLP) with:
- Input: 784 features (28x28 pixels)
- Hidden layers: 256 -> 128 neurons, each followed by ReLU and Dropout(0.2)
- Output: 10 classes (digits 0-9)

In [None]:
# 784 -> 256 -> 128 -> 10 MLP with Dropout(0.2) after each hidden ReLU.
layers = [
    Linear(784, 256), ReLU(), Dropout(0.2),
    Linear(256, 128), ReLU(), Dropout(0.2),
    Linear(128, 10),
]
model = Sequential(*layers)

# Total number of scalar weights across all parameter tensors.
total_params = sum(param.data.size for param in model.parameters())
print(
    "Model Architecture:",
    "  Input:  784 (28x28 pixels)",
    "  Hidden: 256 -> ReLU -> Dropout(0.2)",
    "  Hidden: 128 -> ReLU -> Dropout(0.2)",
    "  Output: 10 (digits 0-9)",
    f"\nTotal Parameters: {total_params:,}",
    sep="\n",
)

## 6. Setup Training

In [None]:
# Loss function and optimizer.
# The learning rate is defined once so the optimizer and the printed
# configuration cannot disagree.
LEARNING_RATE = 0.001

loss_fn = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Training settings
EPOCHS = 10

print("Training Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print("  Optimizer: Adam")
print("  Loss: CrossEntropyLoss")

## 7. Training Loop

In [None]:
def train_epoch(model, train_loader, loss_fn, optimizer):
    """Train ``model`` for one full pass over ``train_loader``.

    Switches the model to training mode. For every ``(X_batch, y_batch)``
    pair it then runs a forward pass, computes ``loss_fn``, backpropagates
    and applies one ``optimizer`` step.

    Args:
        model: Network to train; must provide ``train()`` and be callable on
            a ``Tensor``.
        train_loader: Iterable of ``(X_batch, y_batch)`` pairs (presumably
            NumPy arrays — ``y_batch`` is compared against ``np.argmax``
            output).
        loss_fn: Callable ``(logits, targets) -> loss`` with ``.item()``.
        optimizer: Optimizer over the model's parameters.

    Returns:
        (avg_loss, accuracy): per-sample average loss and the fraction of
        correctly classified samples. Each batch's loss is weighted by its
        size, which assumes ``loss_fn`` returns a batch mean.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for X_batch, y_batch in train_loader:
        # Inputs are data, not parameters. Nothing uses their gradient, so
        # don't make autograd compute and store one for every batch.
        X = Tensor(X_batch)
        y = Tensor(y_batch)

        # Forward pass
        logits = model(X)
        loss = loss_fn(logits, y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        total_loss += loss.item() * len(y_batch)
        predictions = np.argmax(logits.data, axis=1)
        correct += np.sum(predictions == y_batch)
        total += len(y_batch)

    return total_loss / total, correct / total


def evaluate(model, test_loader, loss_fn):
    """Measure average loss and accuracy of ``model`` on ``test_loader``.

    Puts the model in eval mode and runs forward passes only; no backward
    pass or optimizer step is performed.

    Args:
        model: Network providing ``eval()`` and callable on a ``Tensor``.
        test_loader: Iterable of ``(X_batch, y_batch)`` pairs.
        loss_fn: Callable ``(logits, targets) -> loss`` with ``.item()``.

    Returns:
        (avg_loss, accuracy): per-sample average loss (batch losses weighted
        by batch size) and the fraction of correct predictions.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for inputs, targets in test_loader:
        input_tensor = Tensor(inputs)
        target_tensor = Tensor(targets)

        logits = model(input_tensor)
        batch_loss = loss_fn(logits, target_tensor)
        batch_size = len(targets)

        loss_sum += batch_loss.item() * batch_size
        n_correct += np.sum(np.argmax(logits.data, axis=1) == targets)
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
# Training!
print("\nStarting Training...")
print("=" * 60)

# Per-epoch metrics, in the order train_epoch/evaluate return them.
history = {key: [] for key in ('train_loss', 'train_acc', 'test_loss', 'test_acc')}

for epoch_num in range(1, EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, loss_fn, optimizer)
    test_loss, test_acc = evaluate(model, test_loader, loss_fn)

    for key, value in zip(history, (train_loss, train_acc, test_loss, test_acc)):
        history[key].append(value)

    print(
        f"Epoch {epoch_num:2d}/{EPOCHS} | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}% | "
        f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc*100:.2f}%"
    )

print("=" * 60)
final_test_acc = history['test_acc'][-1]
print(f"Final Test Accuracy: {final_test_acc*100:.2f}%")

## 8. Plot Training History

In [None]:
# x-axis values start at 1 so they match the "Epoch k/EPOCHS" log lines;
# plotting the bare lists would label the first epoch as 0.
epoch_axis = range(1, len(history['train_loss']) + 1)
train_acc_pct = [a * 100 for a in history['train_acc']]
test_acc_pct = [a * 100 for a in history['test_acc']]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss plot
ax1.plot(epoch_axis, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
ax1.plot(epoch_axis, history['test_loss'], 'r-', label='Test Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Test Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy plot (percent)
ax2.plot(epoch_axis, train_acc_pct, 'b-', label='Train Accuracy', linewidth=2)
ax2.plot(epoch_axis, test_acc_pct, 'r-', label='Test Accuracy', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Visualize Predictions

In [None]:
# Predict the first 20 test digits (one per panel of the 2x10 grid below).
model.eval()

n_show = 20
sample_logits = model(Tensor(X_test[:n_show]))
predictions = np.argmax(sample_logits.data, axis=1)

fig, axes = plt.subplots(2, 10, figsize=(15, 4))

for i, ax in enumerate(axes.flat):
    pred, truth = predictions[i], y_test[i]
    ax.imshow(X_test[i].reshape(28, 28), cmap='gray')
    # Green title for a correct prediction, red for a mistake.
    ax.set_title(f'Pred: {pred}\nTrue: {truth}',
                 color='green' if pred == truth else 'red', fontsize=10)
    ax.axis('off')

plt.suptitle('Model Predictions (Green=Correct, Red=Wrong)', fontsize=14)
plt.tight_layout()
plt.show()

## 10. Confusion Matrix

In [None]:
# Gather predictions and ground-truth labels over the whole test set.
model.eval()
all_predictions, all_labels = [], []

for X_batch, y_batch in test_loader:
    batch_preds = np.argmax(model(Tensor(X_batch)).data, axis=1)
    all_predictions.extend(batch_preds)
    all_labels.extend(y_batch)

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)

# confusion[t, p] = number of test images of true class t predicted as p.
confusion = np.zeros((10, 10), dtype=np.int32)
for t, p in zip(all_labels, all_predictions):
    confusion[t, p] += 1

# Plot
plt.figure(figsize=(10, 8))
plt.imshow(confusion, cmap='Blues')
plt.colorbar(label='Count')

half_max = confusion.max() / 2  # cells above this get white text for contrast
for row in range(10):
    for col in range(10):
        count = confusion[row, col]
        plt.text(col, row, str(count), ha='center', va='center',
                 color='white' if count > half_max else 'black')

plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.xticks(range(10))
plt.yticks(range(10))
plt.tight_layout()
plt.show()

# Per-class accuracy = diagonal entry / row total (recall per digit).
print("\nPer-class Accuracy:")
row_totals = confusion.sum(axis=1)
for digit in range(10):
    class_acc = confusion[digit, digit] / row_totals[digit] * 100
    print(f"  Digit {digit}: {class_acc:.1f}%")

## 11. Summary

In this example, we:

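1. **Loaded MNIST** - 60,000 training + 10,000 test images (training used the first `TRAIN_SIZE` = 10,000 images by default)
2. **Built an MLP** - 784 -> 256 -> 128 -> 10 with dropout
3. **Trained with Adam** - Learning rate 0.001
4. **Targeted >95% accuracy** on the test set (check the final accuracy printed by the training loop)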

### Performance Note

MicroGrad+ is an educational implementation built in pure Python/NumPy, so it is typically
**10-100x slower than PyTorch**, which uses optimized C++/CUDA kernels. This is expected!

The goal is understanding, not speed. Once you understand how autograd works internally,
you can use PyTorch (Phase 2) for production workloads with full GPU acceleration.

### Tips for Higher Accuracy:
- Train on full 60,000 samples
- Add more epochs (20-30)
- Try learning rate scheduling
- Add data augmentation (shifts, rotations)
- Use deeper/wider networks

### Memory Management

When training larger models or using the full dataset, run garbage collection:
```python
import gc
gc.collect()
```

## 12. Cleanup

In [None]:
# Cleanup - release memory
# cleanup_notebook receives this notebook's global namespace.
# NOTE(review): presumably it drops large globals (data arrays, model,
# loaders) so they can be garbage-collected — confirm in micrograd_plus.utils.
from micrograd_plus.utils import cleanup_notebook
cleanup_notebook(globals())