# 05 — Transfer Learning: Frozen Backbone + Trainable Head

- **Frozen backbone** (Flash): `Conv2D(1→4) → ReLU`
- **Trainable head** (RAM): `Flatten → Dense(400→10)`

In [None]:
from _setup import setup_all, PROJECT_ROOT
# Run the notebook bootstrap from _setup before the project modules
# (nano_rust_utils, nano_rust_py) are imported in the next cell.
# NOTE(review): what setup_all() actually configures is defined in _setup,
# which is not visible here — confirm before relying on any side effect.
setup_all()

In [None]:
import numpy as np
import torch
import torch.nn as nn
from nano_rust_utils import quantize_to_i8, quantize_weights, calibrate_model
import nano_rust_py

# Seed before constructing the layers so the randomly initialised weights
# are identical on every run.
torch.manual_seed(42)
backbone_layers = [
    nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=0),  # [1,12,12]→[4,10,10]
    nn.ReLU(),
]
head_layers = [
    nn.Flatten(),
    nn.Linear(400, 10),  # 4*10*10 = 400 flattened features → 10 outputs
]
# Positional construction keeps the submodule keys '0'..'3'; the later cells
# look up q_weights['0'] / requant['3'] by those keys.
full_model = nn.Sequential(*backbone_layers, *head_layers)
full_model.eval()

q_weights = quantize_weights(full_model)

# Separate seed for the sample input, so it stays reproducible on its own.
torch.manual_seed(123)
input_tensor = torch.randn(1, 1, 12, 12)
flat_input = input_tensor.numpy().flatten()
q_input, input_scale = quantize_to_i8(flat_input)

requant = calibrate_model(full_model, input_tensor, q_weights, input_scale)
# Each requant entry unpacks into three values; only the first two are shown
# here (the third is used in later cells as the layer's bias argument).
requant_summary = {
    layer_key: (req_m, req_s)
    for layer_key, (req_m, req_s, _unused) in requant.items()
}
print('Requant:', requant_summary)

In [None]:
# Float reference: run the PyTorch model with autograd disabled and flatten
# the result to a 1-D NumPy array.
with torch.no_grad():
    reference = full_model(input_tensor)
    pytorch_out = reference.numpy().flatten()

# Rebuild the same Conv → ReLU → Flatten → Dense pipeline in the NANO-RUST runtime.
nano = nano_rust_py.PySequentialModel(input_shape=[1, 12, 12], arena_size=65536)

# Layer '0' (conv): flattened quantized weights plus its calibrated requant triple.
m0, s0, b0 = requant['0']
conv_weights = q_weights['0']['weights'].flatten().tolist()
# NOTE(review): the positional values 1, 4, 3, 3, 1, 0 presumably mirror
# Conv2d(1, 4, kernel_size=3, stride=1, padding=0) — confirm against the
# binding's signature.
nano.add_conv2d_with_requant(conv_weights, b0, 1, 4, 3, 3, 1, 0, m0, s0)
nano.add_relu()
nano.add_flatten()

# Layer '3' (dense head): same pattern as the conv layer.
m3, s3, b3 = requant['3']
dense_weights = q_weights['3']['weights'].flatten().tolist()
nano.add_dense_with_requant(dense_weights, b3, m3, s3)

nano_out = nano.forward(q_input.tolist())

# Quantize the float reference so both outputs can be compared as integers;
# the second return value of quantize_to_i8 is not needed here.
q_pytorch, _ = quantize_to_i8(pytorch_out)
nano_arr = np.array(nano_out, dtype=np.int8)
# Widen both sides to int32 before subtracting so the difference cannot wrap.
expected_i32 = q_pytorch.astype(np.int32)
actual_i32 = nano_arr.astype(np.int32)
diff = np.abs(expected_i32 - actual_i32)

print(f'PyTorch (i8):  {q_pytorch.tolist()}')
print(f'NANO-RUST:     {nano_arr.tolist()}')
print(f'Max diff: {np.max(diff)}, Mean diff: {np.mean(diff):.2f}')
print(f'Class: PyTorch={np.argmax(q_pytorch)}, NANO-RUST={np.argmax(nano_arr)}')
# Pass criterion: the worst per-element difference must not exceed 20.
print(f'\n{"✅ PASS" if np.max(diff) <= 20 else "❌ FAIL"} (tolerance=20)')

In [None]:
# Memory budget
def _stored_bytes(layer_key):
    """Return the weight-array byte size plus the bias entry count for a layer."""
    # NOTE(review): len(...) counts bias entries, not bytes; the sum is a byte
    # total only if each bias entry occupies one byte — confirm.
    weight_bytes = q_weights[layer_key]['weights'].nbytes
    bias_entries = len(requant[layer_key][2])
    return weight_bytes + bias_entries


conv_bytes = _stored_bytes('0')
dense_bytes = _stored_bytes('3')
print(f'\nBackbone (Flash): {conv_bytes:>6} bytes ({conv_bytes/1024:.1f} KB)')
print(f'Head (RAM):       {dense_bytes:>6} bytes ({dense_bytes/1024:.1f} KB)')
print(f'Total:            {conv_bytes+dense_bytes:>6} bytes ({(conv_bytes+dense_bytes)/1024:.1f} KB)')
# Only the head is compared against the RAM sizes below; the backbone is
# reported as Flash-resident above.
# NOTE(review): the inference arena (arena_size=65536 in the model cell) is
# not included in these checks — confirm whether it should count toward RAM.
print(f'\nFits in ESP32 (520KB RAM)? {"✅ Yes" if dense_bytes < 520*1024 else "❌ No"}')
print(f'Fits in STM32F4 (192KB RAM)? {"✅ Yes" if dense_bytes < 192*1024 else "❌ No"}')