In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from threeSpiral import make_3_spirals, train, accuracy, ShallowMLP, replace_relu_with_ftrelu, FaultTolerantReLU, FTReluConfig, MLP


# --- Reproducibility and data ------------------------------------------------
# Seeding torch here fixes the train/test permutation below; the same seed is
# also handed to make_3_spirals.
seed = 0
device = "cpu"
torch.manual_seed(seed)

# Three-class spiral data set (600 points per class, noise=0.25).
X, y = make_3_spirals(n_per_class=600, noise=0.25, seed=seed)

# One random permutation, split 80 % train / 20 % test.
idx = torch.randperm(len(X))
split = int(0.8 * len(X))
train_idx, test_idx = idx[:split], idx[split:]


def _subset_loader(rows, **loader_kwargs):
    """Build a DataLoader (batch size 256) over the given rows of the numpy arrays X and y."""
    features = torch.from_numpy(X[rows])
    labels = torch.from_numpy(y[rows])
    return DataLoader(TensorDataset(features, labels), batch_size=256, **loader_kwargs)


train_loader = _subset_loader(train_idx, shuffle=True)
test_loader = _subset_loader(test_idx)

# Train the reference network (standard ReLUs). Its weights are copied into
# every fault-tolerant variant in the sweep below.
# Place the model on `device` explicitly instead of relying on train() to do it,
# so changing `device` works for the whole notebook.
baseline = MLP(hidden=32).to(device)
train(baseline, train_loader, test_loader, device, epochs=1000, lr=1e-3)
baseline_acc = accuracy(baseline, test_loader, device)
print("Baseline accuracy:", baseline_acc)


# Robustness sweep: copy the trained baseline into a fresh MLP, swap its ReLUs
# for FaultTolerantReLU, and record test accuracy per phase-noise level.
# Only sigma_phase varies; sigma_trig, sigma_score and p_syn are held at 0.
# NOTE(review): the recorded output shows sigma=1e-15 (effectively no noise)
# scoring 0.28 against a 0.52 baseline, so the FT replacement does not reproduce
# the plain ReLU even without noise. Check x_min/x_max and S below. The
# recorded run was also interrupted after the first sigma, so the S=2001
# configuration appears to be slow.
sigma_vals = [1e-15, 2e-6, 5e-6, 1e-5]
ft_accs = []

# Value range given to FTReluConfig; presumably it must cover the hidden-layer
# pre-activations of the trained baseline — TODO confirm.
x_min, x_max = -26.0, 14.0

# Loop-invariant: the trained weights are the same for every sigma.
baseline_state = baseline.state_dict()

for s in sigma_vals:
    # Reseed so each sigma sees the same random stream regardless of its
    # position in the sweep (assumes FaultTolerantReLU draws noise from torch's
    # global RNG — TODO confirm).
    torch.manual_seed(seed)

    model = MLP(hidden=32)
    model.load_state_dict(baseline_state)

    cfg2 = FTReluConfig(
        x_min=x_min,
        x_max=x_max,
        S=2001,
        sigma_phase=s,
        sigma_trig=0.0,
        sigma_score=0.0,
        p_syn=0.0,
    )

    replace_relu_with_ftrelu(model, FaultTolerantReLU(cfg2))
    # Move after the replacement so any parameters/buffers of the inserted
    # modules follow the model. load_state_dict alone leaves it on the CPU.
    model.to(device)

    acc = accuracy(model, test_loader, device)
    ft_accs.append(acc)
    print(f"sigma={s:.1e} → FT acc={acc:.4f}")

epoch    1 | train acc 0.358 | test acc 0.336
epoch  100 | train acc 0.406 | test acc 0.400
epoch  200 | train acc 0.386 | test acc 0.364
epoch  300 | train acc 0.388 | test acc 0.389
epoch  400 | train acc 0.422 | test acc 0.425
epoch  500 | train acc 0.449 | test acc 0.439
epoch  600 | train acc 0.480 | test acc 0.461
epoch  700 | train acc 0.494 | test acc 0.475
epoch  800 | train acc 0.519 | test acc 0.497
epoch  900 | train acc 0.535 | test acc 0.506
epoch 1000 | train acc 0.540 | test acc 0.522
Baseline accuracy: 0.5222222222222223
sigma=1.0e-15 → FT acc=0.2806


KeyboardInterrupt: 