# 📸 06: MNIST Digit Classification

**Train CNN on MNIST → Quantize to i8 → Verify on NANO-RUST engine**

| Property | Value |
|----------|-------|
| **Task** | Handwritten digit recognition (0–9) |
| **Dataset** | MNIST: 60k train / 10k test (28×28 grayscale) |
| **Architecture** | Conv2D(1→8) → ReLU → Pool → Conv2D(8→16) → ReLU → Pool → Flatten → Dense(784→10) |
| **Expected Accuracy** | ~97% (PyTorch), ~95–97% (NANO i8) |
| **MCU Memory** | ~13KB Flash + 32KB Arena |

> **Pipeline**: Train (GPU) → Quantize (float32→int8) → Calibrate → Build NANO model → Verify


## Step 0: Install Dependencies

In [None]:
# !pip install nano-rust-py[train] torchvision


## Step 1: Setup & GPU Detection

Training runs on CUDA when a GPU is available and falls back to the CPU otherwise. The trained model is moved to the CPU before quantization.

In [None]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from nano_rust_py.utils import quantize_to_i8, quantize_weights, calibrate_model
import nano_rust_py

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'  GPU: {torch.cuda.get_device_name(0)}')


## Step 2: Load MNIST Dataset

Each image is 28×28 grayscale, normalized with dataset-specific mean/std.

| Param | Value | Why |
|-------|-------|-----|
| Normalize mean | 0.1307 | Centers pixel distribution |
| Normalize std | 0.3081 | Scales to unit variance |
| Batch size | 256 | Good GPU utilization |
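
To see what these two constants do to the pixel range, the arithmetic can be worked through directly. The following is a minimal sketch in plain Python; the helper name `normalize_pixel` is illustrative and not part of torchvision, and it assumes pixels have already been scaled to [0, 1] by `ToTensor()`.

```python
MEAN, STD = 0.1307, 0.3081  # constants from the table above


def normalize_pixel(x: float) -> float:
    """Apply the same affine map as transforms.Normalize to one pixel in [0, 1]."""
    return (x - MEAN) / STD


lo = normalize_pixel(0.0)  # black background pixel
hi = normalize_pixel(1.0)  # fully saturated stroke pixel
print(f"normalized range: [{lo:.4f}, {hi:.4f}]")
```

The resulting range is asymmetric (roughly -0.42 to 2.82), which matters later: if the input is quantized symmetrically around zero, the scale is set by the larger positive end.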


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, pin_memory=True, num_workers=0)

print(f'Train: {len(train_dataset):,} images')
print(f'Test:  {len(test_dataset):,} images')
print(f'Image shape: {train_dataset[0][0].shape}  (C, H, W)')


## Step 3: Define & Train CNN

The architecture is designed for MCU deployment, with small kernels and few channels:

```
Input [1,28,28] → Conv(1→8, 3×3) → ReLU → Pool(2) → [8,14,14]
                → Conv(8→16, 3×3) → ReLU → Pool(2) → [16,7,7]
                → Flatten → [784] → Dense → [10]
```
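
The shapes in the diagram follow from the standard output-size formula for convolution and pooling, `out = (in + 2*padding - kernel) // stride + 1`. The sketch below traces one spatial dimension through the network, assuming the stride of 1 and padding of 1 used in the model definition in this step; the helper names are illustrative.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of max pooling (no padding) along one dimension."""
    return (size - kernel) // stride + 1


side = 28
after_conv1 = conv2d_out(side, kernel=3, stride=1, padding=1)         # 3x3 conv, padding 1
after_pool1 = pool_out(after_conv1, kernel=2, stride=2)               # halves the size
after_conv2 = conv2d_out(after_pool1, kernel=3, stride=1, padding=1)
after_pool2 = pool_out(after_conv2, kernel=2, stride=2)
flat_features = 16 * after_pool2 * after_pool2                        # 16 output channels

print(after_conv1, after_pool1, after_conv2, after_pool2, flat_features)
```

Because padding 1 with a 3×3 kernel preserves the spatial size, only the two pooling layers shrink the feature maps, which is how 28 becomes 7 and the dense layer ends up with 784 inputs.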


In [None]:
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=1, padding=1),    # [1,28,28] → [8,28,28]
    nn.ReLU(),
    nn.MaxPool2d(2, 2),                         # → [8,14,14]
    nn.Conv2d(8, 16, 3, stride=1, padding=1),   # → [16,14,14]
    nn.ReLU(),
    nn.MaxPool2d(2, 2),                         # → [16,7,7]
    nn.Flatten(),                               # → [784]
    nn.Linear(16 * 7 * 7, 10),                  # → [10]
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f'Parameters: {total_params:,}')
print(f'Float32: {total_params * 4:,} bytes → Int8: {total_params:,} bytes (4x smaller!)')
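
The parameter count printed by this cell can be cross-checked by hand. The sketch below recomputes it per layer in plain Python from the layer shapes in the model definition; the helper names are illustrative.

```python
def conv2d_params(c_in: int, c_out: int, kernel: int) -> tuple[int, int]:
    """Return (weight count, bias count) for a square-kernel Conv2d."""
    return c_out * c_in * kernel * kernel, c_out


def linear_params(n_in: int, n_out: int) -> tuple[int, int]:
    """Return (weight count, bias count) for a Linear layer."""
    return n_in * n_out, n_out


layers = {
    "conv1": conv2d_params(1, 8, 3),
    "conv2": conv2d_params(8, 16, 3),
    "dense": linear_params(16 * 7 * 7, 10),
}

total_weights = sum(w for w, _ in layers.values())
total_biases = sum(b for _, b in layers.values())
total = total_weights + total_biases

for name, (w, b) in layers.items():
    print(f"{name}: {w:,} weights + {b} biases")
print(f"total: {total:,} parameters, {total_weights:,} of them weights")
```

The dense layer dominates: 7,840 of the 9,064 weights sit in the final `Linear`. At one byte per weight, the int8 weight tensors alone account for 9,064 bytes (about 8.9 KB).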


In [None]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
EPOCHS = 3

t0 = time.time()
for epoch in range(EPOCHS):
    model.train()
    correct, total = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        correct += output.argmax(1).eq(target).sum().item()
        total += target.size(0)
    print(f'  Epoch {epoch+1}/{EPOCHS} - Acc: {100.*correct/total:.1f}%')

train_time = time.time() - t0
print(f'\nTraining complete in {train_time:.1f}s')


## Step 4: Evaluate PyTorch Baseline

This is our **float32 baseline**. The int8 NANO model should land within 2–3 percentage points of it.

In [None]:
model.eval()
correct_pt, total_pt = 0, 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        correct_pt += model(data).argmax(1).eq(target).sum().item()
        total_pt += target.size(0)

pt_accuracy = 100. * correct_pt / total_pt
print(f'✅ PyTorch Test Accuracy: {pt_accuracy:.2f}%')


## Step 5: Quantize & Calibrate

1. **`quantize_weights()`**: float32 → int8 (4× compression)
2. **`calibrate_model()`**: compute `(requant_m, requant_shift)` per layer

> Without calibration: ~85% accuracy. With calibration: **95–99%**
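
To make the first step concrete, here is a minimal sketch of symmetric max-abs int8 quantization in plain Python. It illustrates the general technique under the assumption of a single per-tensor scale; it is not a description of what `quantize_weights()` does internally, and the function names are hypothetical.

```python
def quantize_symmetric(values: list[float]) -> tuple[list[int], float]:
    """Map floats to int8 using one scale chosen so that max|v| maps to 127."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q: list[int], scale: float) -> list[float]:
    """Recover approximate float values from int8 codes."""
    return [x * scale for x in q]


weights = [0.6, -0.3, 0.1, -1.0]
q, scale = quantize_symmetric(weights)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(weights, restored))

print(q, f"scale={scale:.6f}", f"max_err={max_err:.6f}")
```

The rounding error per value is bounded by half a quantization step (`scale / 2`), so tensors with a few large outliers lose more precision on their small values. Weight quantization alone does not say how to rescale each layer's integer output, which is the part calibration supplies.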


In [None]:
model_cpu = model.cpu().eval()

# Quantize weights
q_weights = quantize_weights(model_cpu)
for name, info in q_weights.items():
    w = info['weights']
    print(f'  Layer {name} ({info["type"]}): {w.shape}, {w.nbytes} bytes, scale={info["weight_scale"]:.6f}')


In [None]:
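# Calibrate requantization parameters from one sample image.
# Note: the sample is test_dataset[0], so it is also part of the evaluation later on.
cal_image = test_dataset[0][0].unsqueeze(0)
# Only the input scale is needed here; the quantized image itself is not used.
_, cal_scale = quantize_to_i8(cal_image.numpy().flatten())
requant = calibrate_model(model_cpu, cal_image, q_weights, cal_scale)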

for name, params in requant.items():
    if isinstance(params, tuple) and len(params) == 3:
        m, s, _ = params
        print(f'  Layer {name}: requant_m={m}, requant_shift={s}')
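
A multiplier and shift pair like the one printed above is a common way to rescale an integer accumulator without floating point: a real-valued factor `M` is approximated as `m / 2**shift`. The sketch below shows that idea under stated assumptions. The scales are made-up numbers, the fixed shift of 16 is an arbitrary choice, and `calibrate_model()` may derive its values differently.

```python
def fixed_point_multiplier(real_multiplier: float, shift: int = 16) -> tuple[int, int]:
    """Approximate a real multiplier as an integer m with M ~= m / 2**shift."""
    return round(real_multiplier * (1 << shift)), shift


def requantize(acc: int, m: int, shift: int) -> int:
    """Rescale a wide accumulator to int8 using integer multiply, round, shift, clamp."""
    scaled = (acc * m + (1 << (shift - 1))) >> shift  # add half an LSB to round
    return max(-128, min(127, scaled))


# Hypothetical scales: accumulator scale is input_scale * weight_scale.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.05
real_multiplier = input_scale * weight_scale / output_scale  # about 0.002

m, shift = fixed_point_multiplier(real_multiplier)
small = requantize(12345, m, shift)      # 12345 * 0.002 is about 24.7
negative = requantize(-12345, m, shift)
saturated = requantize(100000, m, shift)  # about 200, clamped to the int8 maximum

print(m, shift, small, negative, saturated)
```

The clamp in `requantize` is why calibration matters: if the output scale is too small for the activations a layer actually produces, many values saturate at ±127 and accuracy drops.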


## Step 6: Build NANO Model & Compare

Build the same architecture in NANO-RUST with i8 weights + calibrated requantization,
then compare predictions on 1000 test images.
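
As a rough picture of what "i8 weights + calibrated requantization" means at the layer level, the sketch below implements an integer-only dense layer in plain Python. It illustrates the general scheme only and is an assumption about how such engines typically work, not NANO-RUST's actual kernel; `dense_i8` and the toy values are hypothetical.

```python
def dense_i8(x: list[int], weights: list[list[int]], bias: list[int],
             m: int, shift: int) -> list[int]:
    """Integer-only dense layer: int8 inputs and weights, int8 outputs."""
    out = []
    for row, b in zip(weights, bias):
        # Wide accumulator (int32 on a real device): bias plus dot product.
        acc = b + sum(w * xi for w, xi in zip(row, x))
        # Requantize: multiply by m / 2**shift with rounding, then clamp to int8.
        scaled = (acc * m + (1 << (shift - 1))) >> shift
        out.append(max(-128, min(127, scaled)))
    return out


x = [10, -20, 30]                    # int8 input activations
weights = [[1, 2, 3], [-4, 5, -6]]   # int8 weights, one row per output
bias = [100, -50]                    # bias expressed in accumulator units
m, shift = 1 << 13, 16               # multiplier of 0.125

y = dense_i8(x, weights, bias, m, shift)
print(y)
```

Only integer multiplies, adds, and shifts appear in the inner loop, which is what makes this style of inference practical on MCUs without a floating-point unit.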


In [None]:
def build_nano_model():
    nano = nano_rust_py.PySequentialModel(input_shape=[1, 28, 28], arena_size=131072)
    # Conv2d(1→8) + ReLU + Pool
    m, s, bc = requant['0']
    nano.add_conv2d_with_requant(
        q_weights['0']['weights'].flatten().tolist(), bc, 1, 8, 3, 3, 1, 1, m, s)
    nano.add_relu()
    nano.add_max_pool2d(2, 2, 2)
    # Conv2d(8→16) + ReLU + Pool
    m, s, bc = requant['3']
    nano.add_conv2d_with_requant(
        q_weights['3']['weights'].flatten().tolist(), bc, 8, 16, 3, 3, 1, 1, m, s)
    nano.add_relu()
    nano.add_max_pool2d(2, 2, 2)
    # Flatten + Dense(784→10)
    nano.add_flatten()
    m, s, bc = requant['7']
    nano.add_dense_with_requant(
        q_weights['7']['weights'].flatten().tolist(), bc, m, s)
    return nano


In [None]:
N_TEST = min(1000, len(test_dataset))
correct_nano, match_count = 0, 0
max_diffs = []

t0 = time.time()
for i in range(N_TEST):
    image, label = test_dataset[i]
    q_image, _ = quantize_to_i8(image.numpy().flatten())
    nano_out = build_nano_model().forward(q_image.tolist())
    nano_cls = int(np.argmax(nano_out))

    with torch.no_grad():
        pt_out = model_cpu(image.unsqueeze(0)).numpy().flatten()
    pt_cls = int(np.argmax(pt_out))

    q_pt, _ = quantize_to_i8(pt_out)
    diff = np.abs(q_pt.astype(np.int32) - np.array(nano_out, dtype=np.int8).astype(np.int32))
    max_diffs.append(int(np.max(diff)))
    if nano_cls == label: correct_nano += 1
    if nano_cls == pt_cls: match_count += 1
    if (i+1) % 250 == 0: print(f'  {i+1}/{N_TEST}...')

infer_time = time.time() - t0
print(f'Done in {infer_time:.1f}s')


## 📊 Results

In [None]:
nano_acc = 100. * correct_nano / N_TEST
agreement = 100. * match_count / N_TEST
total_flash = sum(info['weights'].nbytes for info in q_weights.values())

print('=' * 60)
print('       MNIST CLASSIFICATION RESULTS')
print('=' * 60)
print(f'PyTorch Accuracy:      {pt_accuracy:.2f}%')
print(f'NANO-RUST Accuracy:    {nano_acc:.2f}% (n={N_TEST})')
print(f'Classification Match:  {agreement:.1f}%')
print(f'Max i8 Diff (median):  {int(np.median(max_diffs))}')
print(f'Max i8 Diff (95th):    {int(np.percentile(max_diffs, 95))}')
print(f'Flash: {total_flash:,} bytes ({total_flash/1024:.1f}KB) | Arena: 32KB')
print('=' * 60)
print(f'{"✅ PASS" if agreement > 90 else "❌ FAIL"}: {agreement:.1f}% agreement')
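
The Flash figure in this summary is measured from the quantized weights, while the arena figure is a fixed number (and the model above is built with `arena_size=131072`). A rough sanity check on the 32KB budget is to estimate peak activation memory. The sketch below assumes one byte per int8 activation and that a layer's input and output buffers are both live at the same time, with no in-place ReLU; a real engine may need less or may add alignment and scratch space.

```python
# (stage name, number of int8 output elements) for each stage of the network.
stages = [
    ("input", 1 * 28 * 28),
    ("conv1", 8 * 28 * 28),
    ("relu1", 8 * 28 * 28),
    ("pool1", 8 * 14 * 14),
    ("conv2", 16 * 14 * 14),
    ("relu2", 16 * 14 * 14),
    ("pool2", 16 * 7 * 7),
    ("flatten", 16 * 7 * 7),
    ("dense", 10),
]

# Peak = largest (input buffer + output buffer) over consecutive stages.
pairs = [(a[0], b[0], a[1] + b[1]) for a, b in zip(stages, stages[1:])]
src, dst, peak_bytes = max(pairs, key=lambda p: p[2])

print(f"peak: {peak_bytes:,} bytes between {src} and {dst}")
print(f"fits in 32KB: {peak_bytes <= 32 * 1024}")
```

Under these assumptions the peak is 12,544 bytes, reached around the first convolution's output, which leaves headroom within a 32KB arena.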


## 📝 Key Takeaways

- CNN on MNIST: ~97% float32 accuracy
- i8 quantization + calibration preserves accuracy within ~2%
- Total model: ~13KB Flash, which fits on virtually any MCU
- Export to firmware via `export_to_rust()` for ESP32/STM32 deployment
