# 02 — MLP Classification

**Goal**: Build a multi-layer perceptron (Dense→ReLU→Dense) and verify that its quantized i8 output matches the float32 PyTorch reference.

**Model**: `Linear(784→128) → ReLU → Linear(128→10)`

This is about the simplest architecture that is still useful in practice: a
2-layer MLP that classifies 784-dimensional inputs (such as flattened 28×28
MNIST images) into 10 classes.

**What you'll learn**:
- How weight quantization works (float32 → i8)
- How calibrated vs uncalibrated requantization affects accuracy
- Memory budget: in i8 these weights take only about 100 KB of Flash on an MCU, and RAM is needed only for a small activation arena

**Prerequisites**: `pip install nano-rust-py numpy torch`

In [None]:
import sys

import numpy as np
import torch
import torch.nn as nn

from nano_rust_py.utils import quantize_to_i8, quantize_weights, calibrate_model
import nano_rust_py

print('✅ All imports OK')

## Step 1: Define MLP Model

| Layer | Shape | Params | Flash (i8) | RAM |
|-------|-------|--------|------------|-----|
| Linear(784→128) | [784] → [128] | 100,480 | 100KB | 0 |
| ReLU | [128] → [128] | 0 | 0 | 0 |
| Linear(128→10) | [128] → [10] | 1,290 | 1.3KB | 0 |
| **Total** | | **101,770** | **~100KB** | **0** |

Arena buffer needed: `2 × max(128, 10) × 1 = 256 bytes`
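
The table and arena figures above can be reproduced with a few lines of arithmetic. The sketch below is illustrative only: `dense_params` is a hypothetical helper, and it assumes every parameter (biases included) is stored as a single byte, which may not match how a real deployment stores biases.

```python
# Sketch: reproduce the parameter, Flash, and arena numbers from the table.
# Assumption: 1 byte per parameter (i8), including biases.

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

layers = [(784, 128), (128, 10)]          # (inputs, outputs) of each Linear layer
params = [dense_params(i, o) for i, o in layers]
total_params = sum(params)

flash_bytes = total_params * 1            # i8: one byte per parameter
arena_bytes = 2 * max(o for _, o in layers) * 1  # formula from the text above

print(f'Per-layer params: {params}')
print(f'Total params:     {total_params:,}')
print(f'Flash (i8):       {flash_bytes / 1024:.1f} KiB')
print(f'Arena:            {arena_bytes} bytes')
```

Note that 101,770 bytes is about 99.4 KiB, which is where the "~100KB" figure comes from.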

In [None]:
torch.manual_seed(42)
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
model.eval()

total = sum(p.numel() for p in model.parameters())
print(model)
print(f'\nParameters: {total:,}')
print(f'Float32: {total * 4 / 1024:.1f} KB')
print(f'Int8:    {total / 1024:.1f} KB (4× compression)')

## Step 2: Quantize & Calibrate

The quantization pipeline:
1. `quantize_weights(model)` → extracts every weight matrix and scales each one to i8
2. `quantize_to_i8(input)` → scales the input tensor to i8 and returns the scale it used
3. `calibrate_model()` → runs the float model on the sample input to derive each layer's requantization parameters (multiplier, shift, and corrected bias)
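
As a rough picture of what steps 1 and 2 do, symmetric max-abs quantization maps the largest magnitude in a tensor to 127 and rounds everything else onto the same grid. The sketch below assumes that scheme; `symmetric_quantize_i8` and `dequantize` are hypothetical names and are not part of `nano_rust_py`, whose `quantize_to_i8` may differ in detail.

```python
import numpy as np

def symmetric_quantize_i8(x):
    """Quantize a float array to int8 with one shared scale (assumed scheme)."""
    x = np.asarray(x, dtype=np.float32)
    max_abs = float(np.max(np.abs(x)))
    # Map the largest magnitude to 127; fall back to 1.0 for an all-zero input.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map int8 values back to approximate float values."""
    return q.astype(np.float32) * scale

# Stand-in for a flattened 28x28 input
rng = np.random.default_rng(0)
x = rng.standard_normal(784).astype(np.float32)

q, scale = symmetric_quantize_i8(x)
max_err = float(np.max(np.abs(dequantize(q, scale) - x)))
print(f'scale={scale:.6f}, max round-trip error={max_err:.6f}')
```

With this scheme the round-trip error per element is bounded by about half a quantization step (`scale / 2`), which is why a tensor with a few large outliers quantizes less precisely than one with a narrow range.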

In [None]:
q_weights = quantize_weights(model)

# Random test input (784 features, like a flattened 28×28 MNIST image)
torch.manual_seed(42)
input_tensor = torch.randn(1, 784)
q_input, input_scale = quantize_to_i8(input_tensor.numpy().flatten())
print(f'Input quantized: {q_input.shape}, scale={input_scale:.6f}')

requant = calibrate_model(model, input_tensor, q_weights, input_scale)
for name, (m, s, bc) in requant.items():
    print(f'Layer {name}: requant_m={m}, shift={s}, corrected_bias_len={len(bc)}')
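
The `requant_m` and `shift` values printed above suggest a fixed-point rescale: an integer accumulator is multiplied by an integer and shifted right, so that no floating-point math is needed on the device. The sketch below shows that general technique under stated assumptions. `to_fixed_point` and `requantize` are hypothetical helpers, the 16-bit shift and the rounding mode are arbitrary choices, and the actual arithmetic inside NANO-RUST may differ.

```python
import numpy as np

def to_fixed_point(real_scale, shift=16):
    """Approximate a positive real scale as multiplier / 2**shift."""
    multiplier = int(round(real_scale * (1 << shift)))
    return multiplier, shift

def requantize(acc, multiplier, shift):
    """Rescale integer accumulators to int8 using only integer operations."""
    acc = np.asarray(acc, dtype=np.int64)
    rounding = 1 << (shift - 1)           # round to nearest, ties toward +infinity
    scaled = (acc * multiplier + rounding) >> shift
    return np.clip(scaled, -128, 127).astype(np.int8)

# Example: a combined scale of 0.01, which in a typical integer pipeline would
# come from input_scale * weight_scale / output_scale (an assumption here).
m, s = to_fixed_point(0.01)
out = requantize([1000, -1000, 100000], m, s)
print(m, s, out.tolist())  # → 655 16 [10, -10, 127]
```

The last accumulator saturates at 127, which illustrates why calibration matters: if the output scale is chosen poorly, many values clip and accuracy drops.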

## Step 3: Build NANO-RUST Model & Compare

In [None]:
# PyTorch reference
with torch.no_grad():
    pytorch_out = model(input_tensor).numpy().flatten()
q_pytorch, _ = quantize_to_i8(pytorch_out)

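# NANO-RUST model (arena_size=4096 is well above the 256-byte estimate from Step 1)
nano = nano_rust_py.PySequentialModel(input_shape=[784], arena_size=4096)

# Keys '0' and '2' are the nn.Sequential indices of the two Linear layers;
# index '1' is the ReLU, which has no weights to quantize.
m0, s0, b0 = requant['0']
nano.add_dense_with_requant(
    q_weights['0']['weights'].flatten().tolist(), b0, m0, s0
)
nano.add_relu()

m2, s2, b2 = requant['2']
nano.add_dense_with_requant(
    q_weights['2']['weights'].flatten().tolist(), b2, m2, s2
)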

nano_output = nano.forward(q_input.tolist())
nano_arr = np.array(nano_output, dtype=np.int8)

# Compare
diff = np.abs(q_pytorch.astype(np.int32) - nano_arr.astype(np.int32))
print(f'PyTorch (i8): {q_pytorch.tolist()}')
print(f'NANO-RUST:    {nano_arr.tolist()}')
print(f'Max diff: {int(np.max(diff))}, Mean diff: {float(np.mean(diff)):.2f}')
print(f'Classes match: {np.argmax(q_pytorch) == np.argmax(nano_arr)}')
print('\n✅ PASSED!' if np.max(diff) <= 10 else '\n⚠️ Large diff, check calibration')