# 05 — Transfer Learning: Frozen Backbone + Trainable Head

**Goal**: Demonstrate the hybrid memory pattern for MCU deployment: a frozen backbone stored in Flash feeding a small trainable head held in RAM.

**Architecture**:
```
FLASH (read-only, 0 bytes RAM):           RAM (trainable):
┌───────────────────────┐                 ┌─────────────────┐
│ FrozenDense(784→128)  │ ──→ ReLU ──→    │ TrainableDense  │
│ (100KB weights)       │                 │ (128→10, 1.3KB) │
└───────────────────────┘                 └─────────────────┘
```

**Why this matters for MCU**:
- **Backbone in Flash**: Pre-trained feature extractor. Read in place from Flash, never modified.
  Cost: N bytes Flash, **0 bytes RAM** for weights.
- **Head in RAM**: Small classifier layer. Can be retrained on-device with new data.
  Cost: `in × out + out` bytes RAM (i8 weights plus biases, one byte each).

This lets you deploy a large pre-trained model but adapt it to local conditions
(e.g., a factory sensor that needs to learn the specific machine's vibration pattern).

**Prerequisites**: `pip install nano-rust-py numpy torch`

In [None]:
import sys

import numpy as np
import torch
import torch.nn as nn

from nano_rust_py.utils import quantize_to_i8, quantize_weights, calibrate_model
import nano_rust_py

print('✅ All imports OK')

## Step 1: Memory Budget Analysis

Before building, let's plan the memory:

| Component | Flash | RAM |
|-----------|-------|-----|
| Backbone weights (784×128) | 100,352 B | 0 B |
| Backbone bias (128) | 128 B | 0 B |
| Head weights (128×10) | 0 B | 1,280 B |
| Head bias (10) | 0 B | 10 B |
| Arena (intermediates) | 0 B | ~512 B |
| **Total** | **~100KB** | **~1.8KB** |
| **ESP32 Available** | **4MB** | **520KB** |

This uses roughly 2.4% of Flash and about 0.3% of RAM, so there is plenty of room left.
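
The budget in the table above can be recomputed with a few lines of plain Python. This is a minimal sketch that assumes i8 storage (one byte per weight and per bias, as the table does) and uses the table's own figures for the ESP32 and the arena estimate; the helper name `dense_bytes` is made up for illustration.

```python
def dense_bytes(n_in: int, n_out: int, bytes_per_param: int = 1) -> int:
    """Storage for one dense layer: n_in * n_out weights plus n_out biases."""
    return (n_in * n_out + n_out) * bytes_per_param

# Figures taken from the table above (assumptions, not measured values).
FLASH_TOTAL = 4 * 1024 * 1024   # 4 MB Flash
RAM_TOTAL = 520 * 1024          # 520 KB RAM
ARENA_BYTES = 512               # rough estimate for intermediates

flash_used = dense_bytes(784, 128)              # frozen backbone
ram_used = dense_bytes(128, 10) + ARENA_BYTES   # trainable head + arena

print(f"Flash: {flash_used:,} B ({100 * flash_used / FLASH_TOTAL:.2f}% of 4 MB)")
print(f"RAM:   {ram_used:,} B ({100 * ram_used / RAM_TOTAL:.2f}% of 520 KB)")
```

If the full `arena_size=4096` passed in Step 2 is reserved in RAM instead of the ~512 B estimate, `ram_used` rises to 5,386 B, which is about 1% of RAM.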

In [None]:
# --- Build the PyTorch model (for weight extraction) ---
torch.manual_seed(42)
backbone = nn.Linear(784, 128)
head = nn.Linear(128, 10)
full_model = nn.Sequential(backbone, nn.ReLU(), head)
full_model.eval()

backbone_size = backbone.weight.numel() + backbone.bias.numel()
head_size = head.weight.numel() + head.bias.numel()
print(f'Backbone: {backbone_size:,} params → {backbone_size:,} bytes in Flash')
print(f'Head:     {head_size:,} params → {head_size:,} bytes in RAM')
print(f'Ratio:    {backbone_size/head_size:.0f}:1 (backbone is {backbone_size/head_size:.0f}× larger)')

## Step 2: Quantize Backbone, Build Hybrid Model

In [None]:
# Quantize weights to i8. quantize_weights() is given the whole model;
# only layer '0' (the backbone) is used below.
q_weights = quantize_weights(full_model)

# Calibrate
test_input = torch.randn(1, 784)
q_input, input_scale = quantize_to_i8(test_input.numpy().flatten())
requant = calibrate_model(full_model, test_input, q_weights, input_scale)

# Build NANO model with hybrid architecture.
# arena_size=4096 leaves headroom over the ~512 B estimated for intermediates in Step 1.
nano = nano_rust_py.PySequentialModel([784], arena_size=4096)

# FROZEN backbone (weights in Flash)
m, s, b = requant['0']
nano.add_dense_with_requant(
    q_weights['0']['weights'].flatten().tolist(), b, m, s
)
nano.add_relu()

# TRAINABLE head (weights in RAM)
# In real deployment: head can be retrained on-device
nano.add_trainable_dense(128, 10)

print('✅ Hybrid model built')
print('   Backbone: FrozenDense (Flash, 0 RAM)')
print('   Head: TrainableDense (RAM, retrainable)')
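
The internals of `quantize_to_i8` and `calibrate_model` are not shown in this notebook. As a rough illustration of what such a step commonly involves, here is a sketch of symmetric i8 quantization and of integer-only requantization with a multiplier and shift. It is an assumption that the `m` and `s` values unpacked above play this role; the function names below are hypothetical and are not part of `nano_rust_py`.

```python
import numpy as np

def quantize_symmetric_i8(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Map floats to int8 with a single scale so that x ≈ q * scale."""
    max_abs = float(np.max(np.abs(x)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def to_multiplier_shift(real_scale: float, shift: int = 15) -> tuple[int, int]:
    """Approximate a float scale by the fixed-point ratio multiplier / 2**shift."""
    return int(round(real_scale * (1 << shift))), shift

def requantize(acc: np.ndarray, multiplier: int, shift: int) -> np.ndarray:
    """Scale int32 accumulators into the int8 range using integer ops only.

    The right shift floors the result; real kernels often add a rounding offset.
    """
    scaled = (acc.astype(np.int64) * multiplier) >> shift
    return np.clip(scaled, -128, 127).astype(np.int8)

rng = np.random.default_rng(0)
x = rng.standard_normal(784).astype(np.float32)
q, scale = quantize_symmetric_i8(x)
max_err = float(np.max(np.abs(q.astype(np.float32) * scale - x)))
print(f"scale={scale:.5f}, max round-trip error={max_err:.5f}")

# Example: a combined scale of 0.0123 applied to three accumulator values.
mult, sh = to_multiplier_shift(0.0123)
requantized = requantize(np.array([1000, -5000, 20000], dtype=np.int32), mult, sh)
print(requantized)
```

The round-trip error is bounded by half a quantization step (`scale / 2`), and the last accumulator saturates at 127 instead of wrapping around.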

## Step 3: Run Inference

In [None]:
output = nano.forward(q_input.tolist())
predicted = nano.predict(q_input.tolist())

print(f'Raw output (i8): {output}')
print(f'Predicted class: {predicted}')
print('\nNote: Since TrainableDense initializes with small random weights,')
print('the output is essentially random. In real deployment, you would')
print('retrain the head on-device with local data.')
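
To make the retraining idea concrete, the following sketch trains only a head on top of a fixed feature extractor. It is written in float NumPy with synthetic data and does not use the `nano_rust_py` training API (which this notebook does not show); the backbone weights are random stand-ins for pre-trained ones, and every name here is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Frozen backbone: fixed random weights standing in for pre-trained ones.
W_b = rng.standard_normal((784, 128)).astype(np.float32) * 0.05
b_b = np.zeros(128, dtype=np.float32)

def backbone(x):
    """Frozen feature extractor (dense + ReLU). Its weights are never updated."""
    return np.maximum(x @ W_b + b_b, 0.0)

# Trainable head: the only parameters that change.
W_h = np.zeros((128, 10), dtype=np.float32)
b_h = np.zeros(10, dtype=np.float32)

# Synthetic "local" data: 10 class prototypes plus noise (made up for the demo).
prototypes = rng.standard_normal((10, 784)).astype(np.float32)
y = rng.integers(0, 10, size=200)
x = prototypes[y] + 0.1 * rng.standard_normal((200, 784)).astype(np.float32)

features = backbone(x)      # computed once, since the backbone is frozen
W_b_before = W_b.copy()
lr = 0.02

def loss_and_probs(feats, labels):
    """Softmax cross-entropy loss and class probabilities for the head."""
    logits = feats @ W_h + b_h
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    loss = -np.log(p[np.arange(len(labels)), labels] + 1e-12).mean()
    return float(loss), p

initial_loss, _ = loss_and_probs(features, y)
for _ in range(200):
    _, p = loss_and_probs(features, y)
    p[np.arange(len(y)), y] -= 1.0        # d(loss)/d(logits) = p - onehot
    grad = p / len(y)
    W_h -= lr * (features.T @ grad)       # update head weights only
    b_h -= lr * grad.sum(axis=0)

final_loss, p = loss_and_probs(features, y)
accuracy = float((p.argmax(axis=1) == y).mean())
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}, train accuracy: {accuracy:.2f}")
```

Because the backbone never changes, its features can be computed once per sample and only the 1,290 head parameters need gradients, which is what keeps the RAM cost small.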

## Summary: The Hybrid Memory Strategy

```
┌──────────────────────────────────────────────────────┐
│  PRE-DEPLOYMENT (PC/GPU)                             │
│  1. Train backbone on large dataset                  │
│  2. Quantize backbone → i8                           │
│  3. Export to .rs (static arrays in Flash)           │
├──────────────────────────────────────────────────────┤
│  ON-DEVICE (MCU)                                     │
│  4. Read backbone from Flash (free!)                 │
│  5. Initialize small head in RAM                     │
│  6. Collect local samples                            │
│  7. Retrain head (gradient descent in i8)            │
│  8. Run inference: backbone(x) → head(features)      │
└──────────────────────────────────────────────────────┘
```
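
Step 3 of the pre-deployment phase (export to `.rs`) can be pictured as plain text generation. The sketch below renders a list of i8 values as a Rust `static` array; the function name, array name, and layout are hypothetical and not the format produced by any particular exporter.

```python
def export_rust_i8_array(name: str, values: list[int], per_line: int = 16) -> str:
    """Render i8 values as a Rust `static` array declaration (illustrative format)."""
    if any(not -128 <= v <= 127 for v in values):
        raise ValueError("all values must fit in i8")
    rows = [
        "    " + ", ".join(str(v) for v in values[i:i + per_line]) + ","
        for i in range(0, len(values), per_line)
    ]
    body = "\n".join(rows)
    return f"pub static {name}: [i8; {len(values)}] = [\n{body}\n];\n"

print(export_rust_i8_array("BACKBONE_W", [12, -7, 0, 127, -128], per_line=4))
```

A `static` array of this kind is immutable data, which is what allows a typical embedded toolchain to place it in Flash instead of RAM.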