# Neural Network Classification Tutorial

This notebook demonstrates neural-network classification on the MNIST handwritten-digit dataset.

In [None]:
import sys
sys.path.append('../..')

import torch
import matplotlib.pyplot as plt

from src.classification import NeuralNetworkClassifier, NNClassifierTrainer
from src.utils import load_mnist, get_device, set_seed, plot_training_curves, plot_confusion_matrix

# Seed via the project helper so reruns are repeatable
# (presumably covers the torch/numpy/python RNGs — see src.utils.set_seed).
set_seed(42)

# Let the project helper pick the compute device, then report it.
device = get_device()
print("Using device:", device)

## Load MNIST Dataset

In [None]:
# Load MNIST data
train_loader = load_mnist(batch_size=128, train=True, download=True)
test_loader = load_mnist(batch_size=128, train=False, download=True)

print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

## Visualize Sample Data

In [None]:
# Take the first batch from the training loader and show up to n_rows * n_cols
# of its images in a grid, each titled with its label.
n_rows, n_cols = 2, 8
dataiter = iter(train_loader)
images, labels = next(dataiter)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4))
for i, ax in enumerate(axes.flat):
    # Hide every panel up front so that, if the batch is smaller than the grid,
    # the leftover panels render blank instead of as empty framed axes.
    ax.axis('off')
    if i >= len(images):
        continue
    ax.imshow(images[i].squeeze(), cmap='gray')
    ax.set_title(f'Label: {labels[i].item()}')
plt.tight_layout()
plt.show()

## Create and Train Neural Network

In [None]:
# Classifier configuration: 28x28 images flattened to 784 inputs, three hidden
# layers, 10 outputs (digits 0-9), and dropout probability 0.2.
model_config = dict(
    input_dim=28 * 28,
    output_dim=10,
    hidden_dims=[256, 128, 64],
    dropout_rate=0.2,
)
model = NeuralNetworkClassifier(**model_config)
print(model)

In [None]:
# Train for 10 epochs at learning rate 1e-3 with verbose output.
# NOTE(review): the test set doubles as the validation set here, so the
# "validation" curves and the later test score are not independent estimates.
# Consider holding out part of the training data for validation instead.
trainer = NNClassifierTrainer(model, device=device)
history = trainer.train(
    train_loader,
    val_loader=test_loader,
    learning_rate=1e-3,
    n_epochs=10,
    verbose=True,
)

## Plot Training History

In [None]:
# Loss curves are drawn by the project helper.
plot_training_curves(history['train_losses'], history['val_losses'], title='Loss Curves')

# Accuracy curves: one line per split, indexed by 1-based epoch number.
# The "%" axis label assumes the trainer reports accuracy as a percentage.
epochs = range(1, len(history['train_accuracies']) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
for history_key, line_style, curve_label in (
    ('train_accuracies', 'b-', 'Training Accuracy'),
    ('val_accuracies', 'r-', 'Validation Accuracy'),
):
    ax.plot(epochs, history[history_key], line_style, label=curve_label, linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy (%)')
ax.set_title('Accuracy Curves')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## Evaluate on Test Set

In [None]:
# Score the trained model on test_loader and print loss and accuracy.
# The "%" suffix assumes the trainer returns accuracy as a percentage.
test_loss, test_acc = trainer.evaluate(test_loader)
print(f"Test Loss: {test_loss:.6f}\nTest Accuracy: {test_acc:.2f}%")

## Confusion Matrix

In [None]:
# Collect predicted and true labels over test_loader, then plot the confusion
# matrix with the digits 0-9 as class names.
predictions, true_labels = trainer.predict(test_loader)
class_names = list(map(str, range(10)))
plot_confusion_matrix(
    true_labels,
    predictions,
    class_names=class_names,
    figsize=(10, 8),
)

## Test on Sample Images

In [None]:
# Predict on the first test batch and show up to n_rows * n_cols images, each
# titled with its true and predicted digit (green = correct, red = wrong).
n_rows, n_cols = 2, 8
dataiter = iter(test_loader)
images, labels = next(dataiter)

# Flatten each image to one row per sample to match the model's input_dim=784.
# reshape (unlike view) also works if the batch tensor is non-contiguous.
images_flat = images.reshape(images.size(0), -1).to(device)
model.eval()  # inference mode (affects the dropout configured above)
with torch.no_grad():
    predictions = model.predict(images_flat)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4))
for i, ax in enumerate(axes.flat):
    # Hide every panel first so any panels beyond the batch size render blank.
    ax.axis('off')
    if i >= len(images):
        continue
    ax.imshow(images[i].squeeze(), cmap='gray')
    true_label = labels[i].item()
    pred_label = predictions[i].item()
    color = 'green' if true_label == pred_label else 'red'
    ax.set_title(f'True: {true_label}, Pred: {pred_label}', color=color)
plt.tight_layout()
plt.show()