# MicroGrad+ CIFAR-10 Example

This notebook demonstrates training a multilayer perceptron (MLP) classifier on CIFAR-10 using MicroGrad+.

**Note**: CIFAR-10 is more challenging than MNIST because:
- Images are 32x32 color (3 channels) = 3072 features
- 10 diverse object classes (not just digits)
- Requires larger networks for good performance

With an MLP, we can expect roughly 50-55% test accuracy (CNNs typically reach 80-95%).

---

## 1. Setup and Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from pathlib import Path

# Robust path resolution - works regardless of working directory
def _find_module_root():
    """Find the module root directory containing micrograd_plus."""
    current = Path.cwd()
    for parent in [current] + list(current.parents):
        if (parent / 'micrograd_plus' / '__init__.py').exists():
            return str(parent)
    return str(Path.cwd().parent)

sys.path.insert(0, _find_module_root())

from micrograd_plus import (
    Tensor, Linear, ReLU, Dropout, Sequential,
    CrossEntropyLoss, Adam
)
from micrograd_plus.utils import set_seed, DataLoader

set_seed(42)

print("MicroGrad+ CIFAR-10 Example")
print("=" * 50)

## 2. Load CIFAR-10 Dataset

CIFAR-10 contains 60,000 32x32 color images (50,000 training, 10,000 test) in 10 classes:
- airplane, automobile, bird, cat, deer
- dog, frog, horse, ship, truck

In [None]:
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

def load_cifar10():
    """
    Load CIFAR-10 dataset.
    
    Requires TensorFlow/Keras to be installed for real data.
    Falls back to synthetic data with clear warning if unavailable.
    
    Returns:
        X_train, X_test: (N, 3072) normalized images (flattened)
        y_train, y_test: (N,) integer labels 0-9
    """
    try:
        from tensorflow.keras.datasets import cifar10
        print("Loading CIFAR-10 from keras...")
        (X_train, y_train), (X_test, y_test) = cifar10.load_data()
        
        # Flatten and normalize
        X_train = X_train.reshape(-1, 3072).astype(np.float32) / 255.0
        X_test = X_test.reshape(-1, 3072).astype(np.float32) / 255.0
        y_train = y_train.flatten().astype(np.int32)
        y_test = y_test.flatten().astype(np.int32)
        
        print("CIFAR-10 loaded successfully!")
        return X_train, X_test, y_train, y_test
        
    except ImportError:
        print("=" * 60)
        print("WARNING: TensorFlow not installed!")
        print("=" * 60)
        print("")
        print("To use real CIFAR-10 data, install TensorFlow:")
        print("  pip install tensorflow")
        print("")
        print("Falling back to SYNTHETIC random data.")
        print("Training results will NOT be meaningful!")
        print("=" * 60)
        print("")
        
        # Generate synthetic data that looks somewhat like CIFAR-10
        np.random.seed(42)
        X_train = np.random.randn(5000, 3072).astype(np.float32) * 0.3 + 0.5
        X_train = np.clip(X_train, 0, 1)
        y_train = np.random.randint(0, 10, 5000).astype(np.int32)
        X_test = np.random.randn(1000, 3072).astype(np.float32) * 0.3 + 0.5
        X_test = np.clip(X_test, 0, 1)
        y_test = np.random.randint(0, 10, 1000).astype(np.int32)
        
        return X_train, X_test, y_train, y_test

X_train, X_test, y_train, y_test = load_cifar10()

print(f"\nDataset loaded:")
print(f"  Training: {X_train.shape[0]} samples")
print(f"  Test:     {X_test.shape[0]} samples")
print(f"  Image shape: 32x32x3 = 3072 features")
print(f"  Classes: {', '.join(CIFAR10_CLASSES)}")

## 3. Visualize Some Examples

In [None]:
# Display some training examples
fig, axes = plt.subplots(2, 5, figsize=(12, 5))

for i, ax in enumerate(axes.flat):
    # Reshape to 32x32x3
    img = X_train[i].reshape(32, 32, 3)
    ax.imshow(img)
    ax.set_title(CIFAR10_CLASSES[y_train[i]])
    ax.axis('off')

plt.suptitle('Sample CIFAR-10 Images', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Preprocessing

Apply normalization per channel (standard practice for CIFAR-10).

In [None]:
def normalize_cifar10(X_train, X_test):
    """
    Normalize CIFAR-10 using per-channel mean and std.
    """
    # The images were flattened from (32, 32, 3), so the channel is the last
    # (fastest-varying) axis. View the data as (N, 1024, 3) so that axis 2
    # indexes the color channel.
    X_train_reshaped = X_train.reshape(-1, 1024, 3)
    X_test_reshaped = X_test.reshape(-1, 1024, 3)
    
    # Compute mean and std per channel from the training set only
    mean = X_train_reshaped.mean(axis=(0, 1), keepdims=True)
    std = X_train_reshaped.std(axis=(0, 1), keepdims=True)
    
    # Normalize
    X_train_norm = (X_train_reshaped - mean) / (std + 1e-8)
    X_test_norm = (X_test_reshaped - mean) / (std + 1e-8)
    
    # Flatten back
    return X_train_norm.reshape(-1, 3072), X_test_norm.reshape(-1, 3072)

X_train_norm, X_test_norm = normalize_cifar10(X_train, X_test)

print(f"Normalized data:")
print(f"  Train mean: {X_train_norm.mean():.4f}")
print(f"  Train std:  {X_train_norm.std():.4f}")

## 5. Create Data Loaders

In [None]:
# Use a subset for faster training
TRAIN_SIZE = 10000  # Use 10000 for quick demo, 50000 for full
BATCH_SIZE = 64

train_loader = DataLoader(X_train_norm[:TRAIN_SIZE], y_train[:TRAIN_SIZE],
                          batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(X_test_norm, y_test,
                         batch_size=BATCH_SIZE, shuffle=False)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")
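
The batch counts printed above can be checked by hand. The sketch below assumes `DataLoader` keeps the final partial batch, which is an assumption about MicroGrad+ rather than something documented here; the `drop_last` flag is a hypothetical parameter used only in this sketch to show the alternative.

```python
import math

def num_batches(n_samples, batch_size, drop_last=False):
    """Batches per epoch. `drop_last` is a hypothetical flag for this sketch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

print(num_batches(10_000, 64))                  # 157 (last batch has 16 samples)
print(num_batches(10_000, 64, drop_last=True))  # 156
print(num_batches(50_000, 64))                  # 782 for the full training set
```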

## 6. Define the Model

We need a larger network for CIFAR-10 since it's more complex than MNIST.

In [None]:
# Build a deeper MLP for CIFAR-10
model = Sequential(
    # First hidden layer
    Linear(3072, 1024),
    ReLU(),
    Dropout(0.3),
    
    # Second hidden layer
    Linear(1024, 512),
    ReLU(),
    Dropout(0.3),
    
    # Third hidden layer
    Linear(512, 256),
    ReLU(),
    Dropout(0.2),
    
    # Output layer
    Linear(256, 10)
)

total_params = sum(p.data.size for p in model.parameters())
print(f"Model Architecture:")
print(f"  Input:  3072 (32x32x3 pixels)")
print(f"  Hidden: 1024 -> ReLU -> Dropout(0.3)")
print(f"  Hidden: 512 -> ReLU -> Dropout(0.3)")
print(f"  Hidden: 256 -> ReLU -> Dropout(0.2)")
print(f"  Output: 10 classes")
print(f"\nTotal Parameters: {total_params:,}")
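
As a sanity check on the printed parameter count, the total can be derived from the layer sizes alone. This sketch assumes each `Linear` layer holds a weight matrix plus a bias vector (an assumption about MicroGrad+); `count_mlp_params` is a helper written for this illustration.

```python
layer_sizes = [3072, 1024, 512, 256, 10]

def count_mlp_params(sizes, bias=True):
    """Sum of weights (fan_in * fan_out) and optional biases (fan_out) per layer."""
    total = 0
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        total += fan_in * fan_out + (fan_out if bias else 0)
    return total

expected = count_mlp_params(layer_sizes)
print(f"{expected:,}")  # 3,805,450
```

If `total_params` differs from this value, the layers may be created without biases, which would give 3,803,648 instead.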

## 7. Training

In [None]:
# Setup
loss_fn = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)
EPOCHS = 20

def train_epoch(model, train_loader, loss_fn, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for X_batch, y_batch in train_loader:
        X = Tensor(X_batch)  # inputs need no gradient; only the parameters are updated
        y = Tensor(y_batch)
        
        logits = model(X)
        loss = loss_fn(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * len(y_batch)
        predictions = np.argmax(logits.data, axis=1)
        correct += np.sum(predictions == y_batch)
        total += len(y_batch)
    
    return total_loss / total, correct / total

def evaluate(model, test_loader, loss_fn):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    for X_batch, y_batch in test_loader:
        X = Tensor(X_batch)
        y = Tensor(y_batch)
        
        logits = model(X)
        loss = loss_fn(logits, y)
        
        total_loss += loss.item() * len(y_batch)
        predictions = np.argmax(logits.data, axis=1)
        correct += np.sum(predictions == y_batch)
        total += len(y_batch)
    
    return total_loss / total, correct / total
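
For reference, a loss named `CrossEntropyLoss` typically computes the softmax cross-entropy from raw logits and integer labels. The NumPy sketch below shows that computation under this assumption; the MicroGrad+ implementation may differ in details such as the reduction. It also gives a useful baseline: a model that outputs identical logits for all 10 classes has a loss of ln(10), about 2.303, so the first-epoch loss should start near that value and fall.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy from raw logits (N, C) and integer labels (N,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoids overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

labels = np.array([3, 7, 0, 9])
uniform_logits = np.zeros((4, 10))  # every class equally likely
baseline = softmax_cross_entropy(uniform_logits, labels)
print(baseline, np.log(10))  # both approximately 2.3026
```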

In [None]:
# Train!
print("\nStarting Training...")
print("=" * 70)

history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}

for epoch in range(EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, loss_fn, optimizer)
    test_loss, test_acc = evaluate(model, test_loader, loss_fn)
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    
    print(f"Epoch {epoch+1:2d}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}% | "
          f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc*100:.2f}%")

print("=" * 70)
print(f"Final Test Accuracy: {history['test_acc'][-1]*100:.2f}%")

## 8. Plot Training History

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss
ax1.plot(history['train_loss'], 'b-', label='Train', linewidth=2)
ax1.plot(history['test_loss'], 'r-', label='Test', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Test Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy
ax2.plot([a*100 for a in history['train_acc']], 'b-', label='Train', linewidth=2)
ax2.plot([a*100 for a in history['test_acc']], 'r-', label='Test', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Visualize Predictions

In [None]:
model.eval()

# Get predictions for first 20 test images
X_sample = Tensor(X_test_norm[:20])
logits = model(X_sample)
predictions = np.argmax(logits.data, axis=1)

# Visualize
fig, axes = plt.subplots(2, 10, figsize=(16, 4))

for i, ax in enumerate(axes.flat):
    img = X_test[i].reshape(32, 32, 3)  # Use original (not normalized) for display
    ax.imshow(img)
    
    pred_class = CIFAR10_CLASSES[predictions[i]]
    true_class = CIFAR10_CLASSES[y_test[i]]
    
    color = 'green' if predictions[i] == y_test[i] else 'red'
    ax.set_title(f'P:{pred_class[:4]}\nT:{true_class[:4]}', 
                 color=color, fontsize=8)
    ax.axis('off')

plt.suptitle('Predictions (Green=Correct, Red=Wrong)', fontsize=14)
plt.tight_layout()
plt.show()

## 10. Per-Class Accuracy

In [None]:
# Compute per-class accuracy
model.eval()
all_predictions = []
all_labels = []

for X_batch, y_batch in test_loader:
    X = Tensor(X_batch)
    logits = model(X)
    predictions = np.argmax(logits.data, axis=1)
    all_predictions.extend(predictions)
    all_labels.extend(y_batch)

all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)

# Per-class accuracy
print("\nPer-Class Accuracy:")
print("-" * 30)

class_accuracies = []
for i, class_name in enumerate(CIFAR10_CLASSES):
    mask = all_labels == i
    if mask.sum() > 0:
        acc = (all_predictions[mask] == i).mean() * 100
        class_accuracies.append(acc)
        print(f"  {class_name:12s}: {acc:.1f}%")

print("-" * 30)
print(f"  {'Average':12s}: {np.mean(class_accuracies):.1f}%")

In [None]:
# Visualize per-class accuracy
plt.figure(figsize=(12, 5))
bars = plt.bar(CIFAR10_CLASSES, class_accuracies, color='steelblue')

# Color bars based on accuracy
for bar, acc in zip(bars, class_accuracies):
    if acc >= 60:
        bar.set_color('green')
    elif acc >= 40:
        bar.set_color('orange')
    else:
        bar.set_color('red')

plt.axhline(y=np.mean(class_accuracies), color='black', linestyle='--', label='Average')
plt.xlabel('Class')
plt.ylabel('Accuracy (%)')
plt.title('Per-Class Accuracy on CIFAR-10')
plt.xticks(rotation=45)
plt.legend()
plt.tight_layout()
plt.show()
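
Per-class accuracy shows which classes are hard, but not what they are mistaken for. A confusion matrix answers that. Below is a minimal NumPy sketch; `confusion_matrix` is a helper written for this illustration, and the toy labels are made up rather than real results.

```python
import numpy as np

def confusion_matrix(labels, predictions, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (labels, predictions), 1)  # accumulates repeated index pairs
    return cm

# Tiny 3-class example with made-up labels
toy_labels = np.array([0, 0, 1, 1, 2, 2, 2])
toy_preds = np.array([0, 1, 1, 1, 2, 0, 2])
cm = confusion_matrix(toy_labels, toy_preds, num_classes=3)
print(cm)

# Per-class accuracy is the diagonal divided by the row sums
per_class_acc = cm.diagonal() / cm.sum(axis=1)
print(per_class_acc)
```

In this notebook the same helper could be called as `confusion_matrix(all_labels, all_predictions, 10)`; large off-diagonal entries then point to the most frequently confused class pairs.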

## 11. Summary

### What We Learned:

1. **CIFAR-10 is harder than MNIST** - Color images with complex objects

2. **MLPs have limitations** - Unlike CNNs, they treat the image as a flat vector and cannot exploit spatial patterns

3. **Normalization matters** - Per-channel normalization helps training

4. **Some classes are harder** - Visually similar pairs such as cat/dog and automobile/truck are often confused

### Expected Results (approximate test accuracy):
- **Random baseline**: 10% (guessing)
- **Simple MLP**: 40-50%
- **Deep MLP with dropout**: 50-55%
- **CNN (not implemented here)**: 80-95%

### Why CNNs Would Be Better:
- Capture local spatial patterns (edges, textures)
- Parameter sharing (same filter applied everywhere)
- Translation equivariance (and approximate invariance when combined with pooling)
- Hierarchical feature learning
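
To make the parameter-sharing point concrete, here is a back-of-the-envelope comparison. The 3x3 convolution with 32 filters is a hypothetical layer chosen purely for illustration; it is not part of this example or a claim about what MicroGrad+ provides.

```python
def dense_params(fan_in, fan_out):
    """Weights plus biases of a fully connected layer."""
    return fan_in * fan_out + fan_out

def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a 2D convolution with square kernels."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels

first_dense = dense_params(3072, 1024)  # first layer of the MLP above
first_conv = conv2d_params(3, 32, 3)    # hypothetical 3x3 conv, 32 filters

print(f"Dense 3072 -> 1024: {first_dense:,} parameters")  # 3,146,752
print(f"Conv 3x3, 3 -> 32:  {first_conv:,} parameters")   # 896
```

The convolution reuses the same small filters at every image position, which is why it needs orders of magnitude fewer parameters than the first dense layer.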

This example shows the limits of MLPs on image classification tasks,
motivating the need for more advanced architectures like CNNs.