# 04 — Activation Functions Comparison (Scale-Aware)

Runs the same MLP with ReLU, Sigmoid, and Tanh activations; Sigmoid and Tanh use **scale-aware** lookup tables (LUTs).

Key point: the Sigmoid/Tanh LUTs assume a fixed input scale (x/16, x/32), but after requantization the actual input scale is S_y.
Fix: rescale the int8 value before the lookup with `index = (i8_val * mult) >> shift`.

In [None]:
# Project bootstrap. This cell runs before the nano_rust_* imports in the next cell.
# NOTE(review): presumably setup_all() makes nano_rust_utils / nano_rust_py
# importable (sys.path, extension build, ...) -- confirm in _setup.
from _setup import setup_all, PROJECT_ROOT
setup_all()

In [None]:
import numpy as np
import torch
import torch.nn as nn
from nano_rust_utils import quantize_to_i8, quantize_weights, calibrate_model
import nano_rust_py

# Build one reference MLP, then for each activation run the same weights through
# PyTorch (float) and nano_rust (int8) and record how far the outputs disagree.
torch.manual_seed(42)
base = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
base.eval()
# Snapshot the reference Dense parameters so every activation variant below
# shares exactly the same weights; only the activation layer differs.
w0, b0 = base[0].weight.data.clone(), base[0].bias.data.clone()
w2, b2 = base[2].weight.data.clone(), base[2].bias.data.clone()

# One fixed input, drawn under its own seed so it does not depend on how many
# RNG draws the model construction above consumed.
torch.manual_seed(123)
input_tensor = torch.randn(1, 16)
q_input, input_scale = quantize_to_i8(input_tensor.numpy().flatten())

activations = {'ReLU': nn.ReLU(), 'Sigmoid': nn.Sigmoid(), 'Tanh': nn.Tanh()}
results = {}

for act_name, act_module in activations.items():
    model = nn.Sequential(nn.Linear(16, 32), act_module, nn.Linear(32, 10))
    # Overwrite the fresh random init with the shared reference weights.
    model[0].weight.data = w0.clone()
    model[0].bias.data = b0.clone()
    model[2].weight.data = w2.clone()
    model[2].bias.data = b2.clone()
    model.eval()

    q_weights = quantize_weights(model)
    # NOTE(review): calibration uses the same single input that is evaluated
    # below, so the measured error is a best case for this input.
    requant = calibrate_model(model, input_tensor, q_weights, input_scale)

    with torch.no_grad():
        pytorch_out = model(input_tensor).numpy().flatten()

    # NOTE(review): arena_size looks like a scratch-buffer size for the
    # nano runtime -- confirm units/limits in nano_rust_py.
    nano = nano_rust_py.PySequentialModel(input_shape=[16], arena_size=8192)

    # Layer 0: Dense(16->32); requant entry unpacked as (mult, shift, bias).
    m0r, s0r, b0r = requant['0']
    nano.add_dense_with_requant(
        q_weights['0']['weights'].flatten().tolist(), b0r, m0r, s0r)

    # Layer 1: Activation. Sigmoid/Tanh get the calibrated (mult, shift) so the
    # LUT index is rescaled to the scale the LUT expects.
    if act_name == 'ReLU':
        nano.add_relu()
    elif act_name == 'Sigmoid':
        _, sm, ss = requant['1']  # ('sigmoid', mult, shift)
        nano.add_sigmoid_scaled(sm, ss)
    elif act_name == 'Tanh':
        _, sm, ss = requant['1']  # ('tanh', mult, shift)
        nano.add_tanh_scaled(sm, ss)
    else:
        # Without this, an unhandled activation would silently yield a
        # Dense->Dense nano model and a meaningless comparison.
        raise ValueError(f'Unsupported activation for nano model: {act_name!r}')

    # Layer 2: Dense(32->10)
    m2r, s2r, b2r = requant['2']
    nano.add_dense_with_requant(
        q_weights['2']['weights'].flatten().tolist(), b2r, m2r, s2r)
    nano_out = nano.forward(q_input.tolist())

    # NOTE(review): the PyTorch output is quantized with its own scale from
    # quantize_to_i8; this only lines up with nano_out if calibrate_model chose
    # the same output scale for layer 2 -- confirm.
    q_pytorch, _ = quantize_to_i8(pytorch_out)
    nano_arr = np.array(nano_out, dtype=np.int8)
    # Widen to int32 before subtracting so the int8 difference cannot wrap.
    diff = np.abs(q_pytorch.astype(np.int32) - nano_arr.astype(np.int32))
    results[act_name] = {
        'max_diff': int(np.max(diff)), 'mean_diff': float(np.mean(diff)),
        'pytorch_class': int(np.argmax(q_pytorch)), 'nano_class': int(np.argmax(nano_arr)),
    }

# Print one row per activation: int8 output disagreement between PyTorch
# (re-quantized) and nano_rust, whether the argmax class agrees, and PASS/FAIL.
# Maximum allowed element-wise |int8 difference| for a PASS (loop invariant).
# NOTE(review): PASS/FAIL looks only at max_diff; a class mismatch with a small
# max_diff still reports PASS.
tol = 20

print('=' * 70)
print(f'{"Activation":<12} {"Max Diff":<12} {"Mean Diff":<12} {"Class Match":<12} {"Result"}')
print('-' * 70)
for name, r in results.items():
    match = 'Yes' if r['pytorch_class'] == r['nano_class'] else 'No'
    passed = '✅ PASS' if r['max_diff'] <= tol else '❌ FAIL'
    print(f'{name:<12} {r["max_diff"]:<12} {r["mean_diff"]:<12.2f} {match:<12} {passed} (tol={tol})')
print('=' * 70)