# 03 — Deep CNN with MaxPooling

**Goal**: Build a multi-layer CNN with max pooling, quantize it to int8, and verify that NANO-RUST inference matches the PyTorch reference.

**Model**: `Conv(1→8) → ReLU → MaxPool(2×2) → Conv(8→16) → ReLU → Flatten → Dense(256→10)`

**What you'll learn**:
- Stacking Conv + Pool layers
- How `MaxPool2d(2, 2)` halves the spatial dimensions (H and W)
- Memory estimation for MCU deployment
- Sizing the activation arena from the largest intermediate activation
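`MaxPool2d` with kernel 2 and stride 2 keeps the largest value in each non-overlapping 2×2 window, so an H×W map becomes (H/2)×(W/2). A minimal pure-Python sketch of that operation on one channel (the helper `max_pool2x2` is hypothetical, written only for illustration; it is not part of `nano_rust_py` or PyTorch):

```python
# Minimal 2×2, stride-2 max pooling over a single channel (illustrative sketch).
def max_pool2x2(fmap):
    """Return the 2×2/stride-2 max-pooled version of a 2-D list `fmap`.

    Assumes H and W are even and no padding, which holds for the 8×8 maps here.
    """
    h, w = len(fmap), len(fmap[0])
    return [
        [max(fmap[i][j], fmap[i][j + 1], fmap[i + 1][j], fmap[i + 1][j + 1])
         for j in range(0, w, 2)]
        for i in range(0, h, 2)
    ]


# An 8×8 feature map filled with 0..63 in row-major order.
fmap = [[r * 8 + c for c in range(8)] for r in range(8)]
pooled = max_pool2x2(fmap)
print(len(pooled), len(pooled[0]))  # 4 4: H and W are both halved
print(pooled[0])                    # [9, 11, 13, 15]
```

Each output value is the bottom-right element of its window here only because the test map increases monotonically; with real activations any of the four positions can win.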

**Prerequisites**: `pip install nano-rust-py numpy torch`

In [None]:
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

sys.path.insert(0, str(Path.cwd().parent / 'scripts'))
from nano_rust_utils import quantize_to_i8, quantize_weights, calibrate_model
import nano_rust_py

print('✅ All imports OK')

## Step 1: Deep CNN Architecture

The input is a single-channel 8×8 image (shape `[1, 8, 8]`). Shapes flow through the network as follows:

| Layer | Input Shape | Output Shape | Params | Flash (int8, 1 B/param) |
|-------|-------------|--------------|--------|-------------------------|
| Conv2d(1→8, 3×3, pad=1) | [1, 8, 8] | [8, 8, 8] | 80 | 80 B |
| ReLU | [8, 8, 8] | [8, 8, 8] | 0 | 0 |
| MaxPool2d(2, 2) | [8, 8, 8] | [8, 4, 4] | 0 | 0 |
| Conv2d(8→16, 3×3, pad=1) | [8, 4, 4] | [16, 4, 4] | 1,168 | 1.2 KB |
| ReLU | [16, 4, 4] | [16, 4, 4] | 0 | 0 |
| Flatten | [16, 4, 4] | [256] | 0 | 0 |
| Linear(256→10) | [256] | [10] | 2,570 | 2.6 KB |
| **Total** | | | **3,818** | **~3.8 KB** |
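The table follows from two rules: each conv/pool layer produces `out = ⌊(in + 2·pad − k) / stride⌋ + 1` along each spatial axis, and each Conv2d/Linear layer has out×in×k×k weights plus one bias per output. The sketch below recomputes the table from those rules in plain Python (the helper `conv_out` is hypothetical, defined here only for illustration):

```python
# Sketch: recompute the Step 1 shape/parameter table in plain Python.
# Flash assumes 1 byte per parameter, as in the table.

def conv_out(size, k, stride=1, pad=0):
    """Output length along one spatial axis of a conv or pooling layer."""
    return (size + 2 * pad - k) // stride + 1


c, h, w = 1, 8, 8            # input shape [C, H, W]
total_params = 0

# Conv2d(1→8, 3×3, pad=1): out*in*k*k weights + one bias per output channel
p = 8 * c * 3 * 3 + 8
c, h, w = 8, conv_out(h, 3, 1, 1), conv_out(w, 3, 1, 1)
total_params += p
print('conv1  ', [c, h, w], p)            # [8, 8, 8] 80

# MaxPool2d(2, 2): no parameters, halves H and W
h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)
print('maxpool', [c, h, w], 0)            # [8, 4, 4] 0

# Conv2d(8→16, 3×3, pad=1)
p = 16 * c * 3 * 3 + 16
c, h, w = 16, conv_out(h, 3, 1, 1), conv_out(w, 3, 1, 1)
total_params += p
print('conv2  ', [c, h, w], p)            # [16, 4, 4] 1168

# Flatten, then Linear(256→10)
flat = c * h * w
p = flat * 10 + 10
total_params += p
print('dense  ', [flat], '->', [10], p)   # [256] -> [10] 2570

print(f'total params: {total_params} (~{total_params / 1000:.1f} KB at 1 B/param)')
```

Changing the input size or kernel in one place and rerunning this is a quick way to keep the `nn.Linear(16 * 4 * 4, 10)` input size in sync with the conv stack.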

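Arena estimate: activations are int8 (1 byte per element). Assuming the runtime ping-pongs between two buffers, each sized for the largest activation (the 8×8×8 Conv1 output), the arena needs `max(1×8×8, 8×8×8, 8×4×4, 16×4×4, 256, 10) × 2 = 512 × 2 = 1024 bytes`.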
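The ×2 rule is a ping-pong estimate. The sketch below compares it with a tighter estimate in which a layer only needs its own input and output alive at the same time. How `nano_rust_py` actually lays out its arena is not described in this notebook, so both numbers are assumptions; the `arena_size=32768` passed later in this notebook exceeds either one by a wide margin.

```python
# Sketch: two ways to estimate the activation arena for the Step 1 model.
# ASSUMPTIONS: int8 activations (1 byte each), ReLU runs in place, and no
# extra scratch space; nano_rust_py's real allocator may differ.
acts = {
    'input':      1 * 8 * 8,    # [1, 8, 8]
    'conv1/relu': 8 * 8 * 8,    # [8, 8, 8]
    'maxpool':    8 * 4 * 4,    # [8, 4, 4]
    'conv2/relu': 16 * 4 * 4,   # [16, 4, 4]
    'flatten':    256,          # [256], same data as the conv2 output
    'dense':      10,           # [10]
}
sizes = list(acts.values())

# Ping-pong: two buffers, each big enough for the largest activation.
ping_pong = 2 * max(sizes)

# Tighter: the largest (input + output) pair over consecutive layers.
tight = max(a + b for a, b in zip(sizes, sizes[1:]))

print('ping-pong arena:', ping_pong, 'bytes')      # 1024
print('input+output estimate:', tight, 'bytes')    # 640 (conv1 out + maxpool out)
```

The ping-pong layout is simpler to implement on an MCU because every layer writes to "the other" buffer; the tighter bound requires placing each output so it does not overlap the input still being read.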

In [None]:
torch.manual_seed(42)
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, stride=2),
    nn.Conv2d(8, 16, 3, stride=1, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 10),
)
model.eval()
print(model)
print(f'\nParams: {sum(p.numel() for p in model.parameters()):,}')

In [None]:
# Quantize weights to int8, then calibrate per-layer requant params
# (multiplier M, shift, bias) on a sample input
q_weights = quantize_weights(model)

torch.manual_seed(42)
input_tensor = torch.randn(1, 1, 8, 8)
q_input, input_scale = quantize_to_i8(input_tensor.numpy().flatten())

requant = calibrate_model(model, input_tensor, q_weights, input_scale)
print('Calibrated layers:')
for name, params in requant.items():
    if isinstance(params, tuple) and len(params) == 3:
        print(f'  Layer {name}: M={params[0]}, shift={params[1]}')
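`calibrate_model` returns, per layer, an integer multiplier `M`, a `shift`, and a bias. A common way to use such a pair is to approximate the real rescale factor (input scale × weight scale / output scale) as `M / 2**shift`, so the MCU only needs an integer multiply and a right shift. The sketch below illustrates that idea in plain Python; the helper names, the `max|x| / 127` scale rule, and the rounding convention are assumptions, not the verified behavior of `nano_rust_utils` or `nano_rust_py`.

```python
# Sketch: symmetric int8 quantization and fixed-point requantization.
# ASSUMPTIONS (not confirmed for nano_rust_utils / nano_rust_py):
#   * scale = max|x| / 127 (symmetric, per tensor)
#   * real rescale factor ≈ M / 2**shift, applied with round-half-up

def quantize_sym_i8(values):
    """Return (int8 values, scale) using symmetric per-tensor quantization."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid a zero scale
    scale = max_abs / 127.0
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale


def to_fixed_point(real_multiplier, bits=15):
    """Approximate a positive multiplier as M / 2**shift, M within ~`bits` bits."""
    assert real_multiplier > 0
    shift = 0
    # Grow the shift while the scaled multiplier still fits in `bits` bits.
    while real_multiplier * (1 << (shift + 1)) < (1 << bits):
        shift += 1
    return round(real_multiplier * (1 << shift)), shift


def requantize(acc, M, shift):
    """Map an int32 accumulator to int8: round(acc * M / 2**shift), clamped."""
    scaled = acc * M
    if shift > 0:
        scaled = (scaled + (1 << (shift - 1))) >> shift  # round half up
    return max(-128, min(127, scaled))


q_x, s_x = quantize_sym_i8([0.5, -1.27, 0.0, 1.0])
print(q_x)                          # [50, -127, 0, 100]

M, shift = to_fixed_point(0.0123)
print(M, shift)                     # 25795 21
print(requantize(1000, M, shift))   # 12  (1000 × 0.0123 = 12.3)
print(requantize(20000, M, shift))  # 127 (246 saturates to the int8 max)
```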

In [None]:
# PyTorch reference
with torch.no_grad():
    ref = model(input_tensor).numpy().flatten()
q_ref, _ = quantize_to_i8(ref)

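# NANO-RUST model: mirror the PyTorch layers one by one.
# Input shape is [C, H, W]; the 32 KB arena is far larger than the ~1 KB
# activation estimate from Step 1, leaving plenty of headroom.
nano = nano_rust_py.PySequentialModel([1, 8, 8], arena_size=32768)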

# Conv1 = model[0]: in_ch=1, out_ch=8, 3×3 kernel, stride 1, padding 1
m, s, b = requant['0']
nano.add_conv2d_with_requant(
    q_weights['0']['weights'].flatten().tolist(), b,
    1, 8, 3, 3, 1, 1, m, s)
nano.add_relu()

# MaxPool = model[2]: kernel 2, stride 2, padding 0 (halves H and W)
nano.add_max_pool2d(2, 2, 0)

# Conv2 = model[3]: in_ch=8, out_ch=16, 3×3 kernel, stride 1, padding 1
m, s, b = requant['3']
nano.add_conv2d_with_requant(
    q_weights['3']['weights'].flatten().tolist(), b,
    8, 16, 3, 3, 1, 1, m, s)
nano.add_relu()

# Flatten (model[5]) + Dense (model[6], 256→10)
nano.add_flatten()
m, s, b = requant['6']
nano.add_dense_with_requant(
    q_weights['6']['weights'].flatten().tolist(), b, m, s)

nano_out = np.array(nano.forward(q_input.tolist()), dtype=np.int8)

diff = np.abs(q_ref.astype(np.int32) - nano_out.astype(np.int32))
print(f'PyTorch: {q_ref.tolist()}')
print(f'NANO:    {nano_out.tolist()}')
print(f'Max diff: {np.max(diff)}, Classes match: {np.argmax(q_ref) == np.argmax(nano_out)}')
print('✅ PASSED!' if np.max(diff) <= 20 else '⚠️ Check calibration')
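For intuition about what the integer pipeline computes per layer, here is a minimal pure-Python sketch of an int8 dense layer with requantization, plus the same max-diff / argmax comparison used in the check above. Small per-element differences are expected because every layer rounds to int8 independently, which is why top-1 class agreement is reported alongside the tolerance. The helpers `dense_i8` and `compare`, the weight layout, and the rounding rule are assumptions for illustration; they are not the implementation inside `nano_rust_py`.

```python
# Sketch: integer-only dense layer (int8 in, int32-style accumulate, int8 out)
# plus a max-diff / argmax comparison.
# ASSUMPTIONS (hypothetical, not nano_rust_py internals):
#   * weights flattened row-major as [out_features][in_features], like
#     PyTorch's Linear.weight
#   * bias already in accumulator scale
#   * rescale factor M / 2**shift with round-half-up

def dense_i8(x, w_flat, bias, out_features, M, shift):
    in_features = len(x)
    out = []
    for o in range(out_features):
        row = w_flat[o * in_features:(o + 1) * in_features]
        acc = bias[o] + sum(xi * wi for xi, wi in zip(x, row))
        scaled = acc * M
        if shift > 0:
            scaled = (scaled + (1 << (shift - 1))) >> shift
        out.append(max(-128, min(127, scaled)))  # saturate to int8
    return out


def compare(ref, got, tol=20):
    """Return (max abs diff, top-1 class match, within tolerance)."""
    max_diff = max(abs(a - b) for a, b in zip(ref, got))
    same_class = ref.index(max(ref)) == got.index(max(got))
    return max_diff, same_class, max_diff <= tol


out = dense_i8([10, -5, 3], [1, 2, 3, -1, 0, 1], [0, 4], 2, M=1, shift=1)
print(out)                     # [5, -1]
print(compare([9, -3], out))   # (4, True, True)
```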