In [1]:
# ================================================================
# CELL 1: Setup, Configuration, and Data Loading (v4 - Phase 2)
# ================================================================
# Phase 2: 4-Layer Hinton Comparison
# Includes:
#   - XOR (4-class), MNIST, FashionMNIST, Pendigits, LetterRecog
#   - init_layer_weights with Gabor support
#   - Adam LR 0.01 and larger batch (256) for faster convergence
#   - Early-stop patience 60 (LR halved after 20 stagnant epochs), min 50 / max 500 epochs
# ================================================================

import os, sys, json, time, copy, random, warnings, csv
from datetime import datetime
from collections import defaultdict
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torchvision
import torchvision.transforms as transforms

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec

warnings.filterwarnings('ignore')

# ---- Mount Google Drive ----
# Best effort: outside Colab the import fails and we fall back to local paths.
# Inside Colab a failed mount also falls back, but the reason is printed
# instead of being silently swallowed. Catching Exception (not a bare except)
# lets KeyboardInterrupt / SystemExit propagate.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    IN_COLAB = True
    print('\u2713 Google Drive mounted')
except ImportError:
    IN_COLAB = False
    print('Not in Colab — using local paths')
except Exception as exc:
    IN_COLAB = False
    print(f'WARNING: Google Drive mount failed ({exc!r}) — using local paths')

# ================================================================
# GLOBAL CONFIGURATION — PHASE 2 OPTIMIZED
# ================================================================
# Global experiment configuration shared by the later cells. The derived path
# keys (data_path, results_path, models_path, figures_path, logs_path) are
# added to this dict further below, after ARCHITECTURES.
CONFIG = {
    # ----- Paths -----
    'base_path': '/content/drive/My Drive/Research/ModularFF/' if IN_COLAB else './ModularFF/',

    # ----- Which datasets to run -----
    # Names must be keys of `all_loaders` (dataset-loading section); unknown
    # names are skipped with a warning.
    #'datasets_to_run': ['XOR', 'MNIST', 'FashionMNIST', 'Pendigits', 'LetterRecog'],

    #'datasets_to_run': ['MNIST', 'FashionMNIST', 'Pendigits', 'LetterRecog'],
    'datasets_to_run': ['FashionMNIST'],


    # ----- Data Split Ratios -----
    # (train, val, test). XOR uses all three. MNIST / FashionMNIST / Pendigits /
    # LetterRecog keep a fixed test set, so only the train:val ratio matters there.
    'split_ratios': (0.70, 0.15, 0.15),

    # ----- Reproducibility -----
    'seed': 42,            # global seed passed to set_seed() at import
    'seeds': [42],         # NOTE(review): presumably the per-run seed list of a later cell — verify

    # ----- Device -----
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # ----- Training (PHASE 2 OPTIMIZED) -----
    'lr': 0.01,                    # Adam LR for 4-layer networks
    'epochs': 500,                 # Max epochs (increased from 300)
    'batch_size': 256,             # 2x larger for stability
    'early_stop_patience': 60,     # Total patience before stopping
    'min_epochs': 50,              # No early stopping before this epoch
    'lr_reduce_patience': 20,      # Reduce LR every N epochs of no improvement
    'lr_reduce_factor': 0.5,       # Halve LR on reduction
    'hinton_sgd_lr': 0.03,         # Hinton's original SGD learning rate

    # ----- FF Thresholds -----
    # Per-neuron threshold for FFLayer's local loss; tanh / step-style layers
    # replace it with their activation-specific default inside FFLayer.
    'theta_neuron': 1.0,

    # ----- Hybrid Loss (Phase 2: only 3 values) -----
    # FFLayer loss = (1 - alpha) * layer-level loss + alpha * per-neuron loss.
    'alpha_values': [0.0, 0.5, 1.0],

    # ----- Per-Layer Dropout (disabled for Phase 2) -----
    'layer_dropout': None,
    'layer_dropout_ablation': [None],  # Skip dropout ablation

    # ----- Meta-layers -----
    # Values are MetaLayer meta_type names.
    'meta_layer': 'argmax',
    'meta_layers_to_compare': ['argmax', 'calibrated', 'linear', 'mlp'],
    'use_meta_layer': True,

    # ----- First Layer Initialization -----
    # 'kaiming' | 'xavier' | 'gabor' (see init_layer_weights).
    'init_method': 'kaiming',

    # ----- ELM Mode -----
    'elm_mode': False,

    # ----- Pruning -----
    'pruning_enabled': False,
    'prune_beta': 2.0,
    'prune_after_epochs': 10,
    'prune_keep_ratio': 0.5,
}

# ----- Architecture per Dataset -----
# Phase 2: 4-layer Hinton comparison (parameter-matched)
# Hidden-layer widths per model family and dataset. ClassicFF's first layer
# sees input_dim + num_classes features because of the one-hot label overlay
# (ClassicFF.eff_dim). 'img_size' is the square image side (None for
# tabular / synthetic data) — presumably what gets passed as img_size for
# Gabor init; verify at the call site.
# NOTE(review): the "~params" comments below are approximate and were not
# re-derived here. Assuming each ModularFF specialist is a plain biased MLP
# on input_dim features, Pendigits modularff_4L comes to ~0.71M (vs ~0.77M
# for classic_ff_4L) and LetterRecog to ~1.83M (vs ~1.96M). Confirm the
# parameter matching against the actual ModularFF model.
ARCHITECTURES = {
    'XOR': {
        'modularff_archs': [[50, 50], [100, 100]],
        'classic_ff': [100, 100],
        'classic_ff_4L': [100, 100, 100, 100],
        'modularff_4L': [50, 50, 50, 50],
        'bp':         [100, 100],
        'input_dim':  2,
        'img_size':   None,
    },
    'MNIST': {
        'modularff_archs': [[50, 50], [100, 100]],
        'classic_ff': [500, 500],
        'classic_ff_4L': [2000, 2000, 2000, 2000],  # ~13.6M params (Hinton's)
        'modularff_4L': [550, 550, 550, 550],       # ~13.5M params (×10 specialists)
        'bp':         [500, 500],
        'input_dim':  784,
        'img_size':   28,
    },
    'FashionMNIST': {
        'modularff_archs': [[50, 50], [100, 100]],
        'classic_ff': [500, 500],
        'classic_ff_4L': [2000, 2000, 2000, 2000],  # ~13.6M params
        'modularff_4L': [550, 550, 550, 550],       # ~13.5M params (×10 specialists)
        'bp':         [500, 500],
        'input_dim':  784,
        'img_size':   28,
    },
    'Pendigits': {
        'modularff_archs': [[50, 50], [100, 100]],
        'classic_ff': [200, 200],
        'classic_ff_4L': [500, 500, 500, 500],      # ~775K params
        'modularff_4L': [150, 150, 150, 150],       # ~760K params (×10 specialists)
        'bp':         [200, 200],
        'input_dim':  16,
        'img_size':   None,
    },
    'LetterRecog': {
        'modularff_archs': [[50, 50], [100, 100]],
        'classic_ff': [400, 400],
        'classic_ff_4L': [800, 800, 800, 800],      # ~2.0M params
        'modularff_4L': [150, 150, 150, 150],       # ~2.0M params (×26 specialists)
        'bp':         [400, 400],
        'input_dim':  16,
        'img_size':   None,
    },
}

# Derived paths: CONFIG[<key>] = <base_path>/<Folder>/ for each output category.
for key, folder in [('data_path', 'Data'), ('results_path', 'Results'),
                    ('models_path', 'Models'), ('figures_path', 'Figures'),
                    ('logs_path', 'Logs')]:
    CONFIG[key] = os.path.join(CONFIG['base_path'], folder + '/')

# Create directory tree (idempotent thanks to exist_ok=True). Per-dataset
# Results/ and Data/ sub-folders are only created for CONFIG['datasets_to_run'].
dirs_to_create = [
    CONFIG['data_path'],
    CONFIG['results_path'],
    CONFIG['models_path'],
    CONFIG['figures_path'],
    os.path.join(CONFIG['figures_path'], 'XOR'),
    os.path.join(CONFIG['figures_path'], 'convergence'),
    CONFIG['logs_path'],
]
for ds in CONFIG['datasets_to_run']:
    dirs_to_create.append(os.path.join(CONFIG['results_path'], ds))
    dirs_to_create.append(os.path.join(CONFIG['data_path'], ds))

for d in dirs_to_create:
    os.makedirs(d, exist_ok=True)

# Summary of the active configuration.
print(f'\u2713 Config ready')
print(f'  Device: {CONFIG["device"]}')
print(f'  Base path: {CONFIG["base_path"]}')
print(f'  PHASE 2 SETTINGS:')
# NOTE(review): "(10x higher)" is relative to an earlier-phase LR that is not
# visible in this notebook — verify the label still holds.
print(f'    LR: {CONFIG["lr"]} (10x higher)')
print(f'    Batch size: {CONFIG["batch_size"]}')
print(f'    Patience: {CONFIG["early_stop_patience"]}')
print(f'    Epochs: {CONFIG["epochs"]}')
print(f'    Alpha values: {CONFIG["alpha_values"]}')


# ================================================================
# UTILITIES
# ================================================================

def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch.

    When CUDA is available, all CUDA devices are seeded too and cuDNN is put
    in deterministic mode (benchmark autotuning disabled).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed Python / NumPy / PyTorch RNGs once at import so that weight
# initialization (torch-based, and the NumPy-based Gabor filters) is reproducible.
set_seed(CONFIG['seed'])


# ================================================================
# INITIALIZATION METHODS
# ================================================================

def gabor_filter_2d(size, theta, freq, sigma, phase, center):
    """Return a unit-L2-norm size x size Gabor kernel.

    Args:
        size: side length of the square kernel (pixels).
        theta: orientation in radians.
        freq: carrier frequency in cycles per pixel along the rotated x axis.
        sigma: standard deviation of the isotropic Gaussian envelope (pixels).
        phase: carrier phase offset in radians.
        center: (cx, cy) — cx is the column (x) and cy the row (y) coordinate.

    Returns:
        float array of shape (size, size), indexed [row, col], divided by its
        L2 norm (+1e-8 for safety).
    """
    cx, cy = center
    rows, cols = np.ogrid[:size, :size]
    dx = cols - cx
    dy = rows - cy

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    u = dx * cos_t + dy * sin_t          # coordinate along the carrier
    v = -dx * sin_t + dy * cos_t         # coordinate across the carrier

    envelope = np.exp(-(u**2 + v**2) / (2 * sigma**2))
    carrier = np.cos(2 * np.pi * freq * u + phase)
    kernel = envelope * carrier
    return kernel / (np.linalg.norm(kernel) + 1e-8)


def create_gabor_weights(n_neurons, img_size=28):
    """Build a bank of random unit-norm Gabor filters, one filter per row.

    Each filter draws, from NumPy's global RNG and in this order: orientation
    theta ~ U[0, pi), frequency ~ U[0.05, 0.4) cycles/pixel, envelope width
    sigma ~ U[2, 6), phase ~ U[0, 2*pi), then an integer center (cx, cy)
    kept roughly 2*sigma away from the image border.

    Args:
        n_neurons: number of filters (rows) to create; 0 gives an empty bank.
        img_size: side length of the square image.

    Returns:
        float32 array of shape (n_neurons, img_size * img_size).
    """
    weights = np.empty((n_neurons, img_size * img_size), dtype=np.float32)
    for i in range(n_neurons):
        theta = np.random.uniform(0, np.pi)
        freq = np.random.uniform(0.05, 0.4)
        sigma = np.random.uniform(2, 6)
        phase = np.random.uniform(0, 2 * np.pi)
        # Keep the center ~2 sigma from the border, but cap the margin so the
        # sampling range stays non-empty: for small images the raw margin made
        # randint(margin, img_size - margin) raise ValueError. For img_size=28
        # the cap (13) is never reached, so draws are unchanged there.
        margin = min(int(sigma * 2), (img_size - 1) // 2)
        cx = np.random.randint(margin, img_size - margin)
        cy = np.random.randint(margin, img_size - margin)
        weights[i] = gabor_filter_2d(img_size, theta, freq, sigma, phase, (cx, cy)).ravel()
    return weights


def init_layer_weights(layer, method='kaiming', img_size=None):
    """Initialize an nn.Linear layer in place and zero its bias (if any).

    Args:
        layer: nn.Linear to initialize.
        method: 'kaiming' (uniform, ReLU gain), 'xavier' (uniform) or 'gabor'.
            'gabor' fills each weight row with a random unit-norm Gabor filter,
            falling back to Kaiming (with a printed warning) when img_size is
            None or in_features != img_size * img_size.
        img_size: square image side length, required by 'gabor'.

    Raises:
        ValueError: for an unknown method (before the layer is touched).
    """
    def kaiming():
        nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

    if method == 'kaiming':
        kaiming()
    elif method == 'xavier':
        nn.init.xavier_uniform_(layer.weight)
    elif method == 'gabor':
        expected_dim = None if img_size is None else img_size * img_size
        if expected_dim is None:
            print(f"  Warning: Gabor needs img_size. Using Kaiming.")
            kaiming()
        elif layer.in_features != expected_dim:
            print(f"  Warning: Gabor expects {expected_dim}D. Using Kaiming.")
            kaiming()
        else:
            bank = create_gabor_weights(layer.out_features, img_size)
            with torch.no_grad():
                layer.weight.copy_(torch.from_numpy(bank))
    else:
        raise ValueError(f"Unknown init method: {method}")

    if layer.bias is not None:
        nn.init.zeros_(layer.bias)


# ================================================================
# DATASET LOADERS
# ================================================================

def load_xor(n_samples=2000, gap=0.05, seed=42, split_ratios=(0.70, 0.15, 0.15)):
    """Generate the synthetic 4-class quadrant ("XOR") dataset and split it.

    Points are drawn uniformly from [-1, 1)^2 with a seeded RandomState,
    rejecting any whose |x| or |y| is <= gap, and labelled by quadrant:
    0 = (+,+), 1 = (-,+), 2 = (-,-), 3 = (+,-). The first ``n_samples`` kept
    points are split stratified into train / val / test by ``split_ratios``.

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test, num_classes=4, input_dim=2)
    """
    _, val_r, test_r = split_ratios
    rng = np.random.RandomState(seed)

    point_batches, label_batches = [], []
    n_kept = 0
    while n_kept < n_samples:
        pts = rng.uniform(-1, 1, size=(n_samples * 2, 2))
        pts = pts[(np.abs(pts[:, 0]) > gap) & (np.abs(pts[:, 1]) > gap)]
        right, left = pts[:, 0] > 0, pts[:, 0] < 0
        upper, lower = pts[:, 1] > 0, pts[:, 1] < 0
        quadrant = np.zeros(len(pts), dtype=np.int64)  # (+,+) stays 0
        quadrant[left & upper] = 1
        quadrant[left & lower] = 2
        quadrant[right & lower] = 3
        point_batches.append(pts)
        label_batches.append(quadrant)
        n_kept += len(pts)

    X = np.concatenate(point_batches)[:n_samples]
    y = np.concatenate(label_batches)[:n_samples]

    # Stage 1: carve off val+test. Stage 2: divide that hold-out into val / test.
    holdout_r = val_r + test_r
    X_train, X_hold, y_train, y_hold = train_test_split(
        X, y, test_size=holdout_r, random_state=seed, stratify=y
    )
    val_share = val_r / holdout_r
    X_val, X_test, y_val, y_test = train_test_split(
        X_hold, y_hold, test_size=(1 - val_share), random_state=seed, stratify=y_hold
    )
    return X_train, X_val, X_test, y_train, y_val, y_test, 4, 2


def load_mnist(data_path, split_ratios=(0.70, 0.15, 0.15), seed=42):
    """Load MNIST as flat float32 arrays scaled to [0, 1].

    The official 10k test set is used as-is, so only the train:val part of
    ``split_ratios`` matters (the test ratio is ignored). The 60k official
    training images are split stratified into train / val with
    ``test_size = val_r / (train_r + val_r)``.

    Args:
        data_path: directory torchvision downloads MNIST to / reads it from.
        split_ratios: (train, val, test) fractions; test is unused here.
        seed: random_state for the train/val split (default 42 reproduces
            the historical split; mirrors load_xor's ``seed``).

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test, num_classes=10, input_dim=784)
    """
    train_r, val_r, _ = split_ratios

    tr = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
    te = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())

    # Read the raw uint8 tensors directly (the transform is not applied here).
    X_full = tr.data.float().view(-1, 784) / 255.0
    y_full = tr.targets.numpy()
    X_test = te.data.float().view(-1, 784) / 255.0
    y_test = te.targets.numpy()

    val_from_train = val_r / (train_r + val_r)
    X_train, X_val, y_train, y_val = train_test_split(
        X_full.numpy(), y_full, test_size=val_from_train, random_state=seed, stratify=y_full
    )
    return X_train, X_val, X_test.numpy(), y_train, y_val, y_test, 10, 784


def load_fashion_mnist(data_path, split_ratios=(0.70, 0.15, 0.15), seed=42):
    """Load Fashion-MNIST as flat float32 arrays scaled to [0, 1].

    The official 10k test set is used as-is, so only the train:val part of
    ``split_ratios`` matters (the test ratio is ignored). The 60k official
    training images are split stratified into train / val with
    ``test_size = val_r / (train_r + val_r)``.

    Args:
        data_path: directory torchvision downloads Fashion-MNIST to / reads it from.
        split_ratios: (train, val, test) fractions; test is unused here.
        seed: random_state for the train/val split (default 42 reproduces
            the historical split; mirrors load_xor's ``seed``).

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test, num_classes=10, input_dim=784)
    """
    train_r, val_r, _ = split_ratios

    tr = torchvision.datasets.FashionMNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
    te = torchvision.datasets.FashionMNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())

    # Read the raw uint8 tensors directly (the transform is not applied here).
    X_full = tr.data.float().view(-1, 784) / 255.0
    y_full = tr.targets.numpy()
    X_test = te.data.float().view(-1, 784) / 255.0
    y_test = te.targets.numpy()

    val_from_train = val_r / (train_r + val_r)
    X_train, X_val, y_train, y_val = train_test_split(
        X_full.numpy(), y_full, test_size=val_from_train, random_state=seed, stratify=y_full
    )
    return X_train, X_val, X_test.numpy(), y_train, y_val, y_test, 10, 784


def load_pendigits(data_path, split_ratios=(0.70, 0.15, 0.15), seed=42):
    """Load UCI Pendigits (10 classes, 16 features), standardized.

    Downloads pendigits.tra / pendigits.tes into ``data_path`` if missing.
    The .tes file is the test set; the .tra file is split stratified into
    train / val with ``test_size = val_r / (train_r + val_r)`` (test ratio
    ignored). The StandardScaler is fit on the train split ONLY and then
    applied to val and test, so no validation/test statistics leak into
    preprocessing.

    Args:
        data_path: directory for the raw files (created if needed).
        split_ratios: (train, val, test) fractions; test is unused here.
        seed: random_state for the train/val split (default 42 keeps the
            historical split indices).

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test, num_classes=10, input_dim=16)
    """
    import urllib.request

    train_r, val_r, _ = split_ratios

    urls = {
        'pendigits.tra': 'https://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tra',
        'pendigits.tes': 'https://archive.ics.uci.edu/ml/machine-learning-databases/pendigits/pendigits.tes',
    }
    os.makedirs(data_path, exist_ok=True)
    for fname, url in urls.items():
        fpath = os.path.join(data_path, fname)
        if not os.path.exists(fpath):
            print(f'  Downloading {fname}...')
            # Download to a temp name and rename, so an interrupted download
            # never leaves a truncated file that later runs would trust.
            tmp_path = fpath + '.part'
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, fpath)

    train_data = np.loadtxt(os.path.join(data_path, 'pendigits.tra'), delimiter=',')
    test_data = np.loadtxt(os.path.join(data_path, 'pendigits.tes'), delimiter=',')

    # Last column is the digit label; the rest are the 16 pen coordinates.
    X_full, y_full = train_data[:, :-1], train_data[:, -1].astype(np.int64)
    X_test, y_test = test_data[:, :-1], test_data[:, -1].astype(np.int64)

    val_from_train = val_r / (train_r + val_r)
    X_train, X_val, y_train, y_val = train_test_split(
        X_full, y_full, test_size=val_from_train, random_state=seed, stratify=y_full
    )

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)
    return X_train, X_val, X_test, y_train, y_val, y_test, 10, 16


def load_letters(data_path, split_ratios=(0.70, 0.15, 0.15), seed=42):
    """Load UCI Letter Recognition (26 classes, 16 integer features), standardized.

    Downloads letter-recognition.data into ``data_path`` if missing. Each
    line is ``<letter>,<16 ints>``; the letter maps to 0..25. The first 16000
    rows form train(+val) and the remaining rows the test set; train(+val) is
    split stratified with ``test_size = val_r / (train_r + val_r)`` (test
    ratio ignored).

    The StandardScaler is fit on the train split ONLY. Previously it was fit
    on all rows, including the test set, which leaked test statistics into
    preprocessing.

    Args:
        data_path: directory for the raw file (created if needed).
        split_ratios: (train, val, test) fractions; test is unused here.
        seed: random_state for the train/val split (default 42 keeps the
            historical split indices).

    Returns:
        (X_train, X_val, X_test, y_train, y_val, y_test, num_classes=26, input_dim=16)
    """
    import urllib.request

    train_r, val_r, _ = split_ratios

    url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data'
    os.makedirs(data_path, exist_ok=True)
    fpath = os.path.join(data_path, 'letter-recognition.data')
    if not os.path.exists(fpath):
        print('  Downloading letter-recognition.data...')
        # Download to a temp name and rename, so an interrupted download
        # never leaves a truncated file that later runs would trust.
        tmp_path = fpath + '.part'
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, fpath)

    features, labels = [], []
    with open(fpath, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines (ord('') would raise)
            letter, *values = line.split(',')
            labels.append(ord(letter) - ord('A'))
            features.append([int(v) for v in values])
    X = np.asarray(features, dtype=np.float64)
    y = np.asarray(labels, dtype=np.int64)

    X_full, X_test = X[:16000], X[16000:]
    y_full, y_test = y[:16000], y[16000:]

    val_from_train = val_r / (train_r + val_r)
    X_train, X_val, y_train, y_val = train_test_split(
        X_full, y_full, test_size=val_from_train, random_state=seed, stratify=y_full
    )

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)
    return X_train, X_val, X_test, y_train, y_val, y_test, 26, 16


# ================================================================
# LOAD DATASETS
# ================================================================
print('\n' + '='*60)
print(' Loading datasets (Phase 2)')
print('='*60)

# name -> dict of train/val/test arrays plus num_classes and input_dim,
# filled for every name in CONFIG['datasets_to_run'] that has a loader.
DATASETS = {}

# name -> (loader function, keyword arguments). Every loader returns
# (X_train, X_val, X_test, y_train, y_val, y_test, num_classes, input_dim).
all_loaders = {
    'XOR':        (load_xor, {'split_ratios': CONFIG['split_ratios']}),
    'MNIST':      (load_mnist, {'data_path': os.path.join(CONFIG['data_path'], 'MNIST'),
                                'split_ratios': CONFIG['split_ratios']}),
    'FashionMNIST': (load_fashion_mnist, {'data_path': os.path.join(CONFIG['data_path'], 'FashionMNIST'),
                                          'split_ratios': CONFIG['split_ratios']}),
    'Pendigits':  (load_pendigits, {'data_path': os.path.join(CONFIG['data_path'], 'Pendigits'),
                                    'split_ratios': CONFIG['split_ratios']}),
    'LetterRecog': (load_letters, {'data_path': os.path.join(CONFIG['data_path'], 'LetterRecog'),
                                   'split_ratios': CONFIG['split_ratios']}),
}

for name in CONFIG['datasets_to_run']:
    if name not in all_loaders:
        print(f'  WARNING: Unknown dataset {name}, skipping')
        continue
    loader, args = all_loaders[name]
    result = loader(**args)
    # NOTE: these unpacked names are module-level and are left holding the
    # LAST loaded dataset — later cells should read from DATASETS instead.
    Xtr, Xv, Xte, ytr, yv, yte, K, dim = result
    DATASETS[name] = {
        'X_train': Xtr, 'X_val': Xv, 'X_test': Xte,
        'y_train': ytr, 'y_val': yv, 'y_test': yte,
        'num_classes': K, 'input_dim': dim,
    }
    print(f'  \u2713 {name:12s}: K={K:2d}, dim={dim:4d}, '
          f'train={len(Xtr):6d}, val={len(Xv):5d}, test={len(Xte):5d}')

print('='*60)
print(f'\u2713 Loaded {len(DATASETS)} dataset(s): {list(DATASETS.keys())}')
print(f'\u2713 PHASE 2 CONFIG: LR={CONFIG["lr"]}, batch={CONFIG["batch_size"]}, patience={CONFIG["early_stop_patience"]}')

Mounted at /content/drive
✓ Google Drive mounted
✓ Config ready
  Device: cuda
  Base path: /content/drive/My Drive/Research/ModularFF/
  PHASE 2 SETTINGS:
    LR: 0.01 (10x higher)
    Batch size: 256
    Patience: 60
    Epochs: 500
    Alpha values: [0.0, 0.5, 1.0]

 Loading datasets (Phase 2)
  ✓ FashionMNIST: K=10, dim= 784, train= 49411, val=10589, test=10000
✓ Loaded 1 dataset(s): ['FashionMNIST']
✓ PHASE 2 CONFIG: LR=0.01, batch=256, patience=60


In [None]:
######### end of cell 1

In [2]:
# ================================================================
# CELL 2: Core Model Classes (v5)
# ================================================================
# v5 Changes:
#   - FFLayer: goodness uses MEAN (not sum) — scale-invariant
#   - FFLayer: theta_layer = 1.0 (not n_neurons) — fixes sigmoid saturation
#   - FFLayer: per-neuron loss uses MEAN over neurons (not sum) — balanced
#     with layer-level loss so alpha truly interpolates between them
#   - FFLayer: optimizer='adam'|'sgd' parameter
#   - ClassicFF: optimizer='adam'|'sgd' parameter
#   - ClassicFF_Additive: inline goodness also fixed to mean + theta=1.0
#   - Per-neuron theta_neuron=1.0: UNCHANGED
# ================================================================
#   - v6: activation parameter ('relu', 'gelu', 'swish'/'silu', 'tanh',
#     'hardlimit'; 'perceptron' selects PerceptronFFLayer in ClassicFF)
#     threaded through FFLayer, ClassicFF, ModularFF, BPBaseline
# ================================================================

class HardLimitSTE(nn.Module):
    """Hard-limit (step) activation with a straight-through estimator.

    Forward: 1 where x > 0, 0 where x < 0, and 0.5 exactly at x == 0
        (``torch.heaviside`` with value 0.5 at zero).
    Backward: the gradient passes through unchanged, as if the op were the
        identity (STE), via ``x + (step(x) - x).detach()``.
    """

    def forward(self, x):
        # torch.heaviside requires `values` to have the same dtype as x. A
        # hard-coded float32 scalar crashed for float64 / half-precision inputs.
        half = torch.tensor(0.5, dtype=x.dtype, device=x.device)
        return x + (torch.heaviside(x, half) - x).detach()


def make_activation(name='relu'):
    """Return a fresh activation module for a case-insensitive name.

    'gelu' -> GELU, 'swish'/'silu' -> SiLU, 'tanh' -> Tanh,
    'hardlimit'/'hardlim'/'step' -> HardLimitSTE; any other name -> ReLU.
    """
    key = name.lower()
    if key in ('hardlimit', 'hardlim', 'step'):
        return HardLimitSTE()
    factories = {
        'gelu': nn.GELU,
        'swish': nn.SiLU,
        'silu': nn.SiLU,
        'tanh': nn.Tanh,
    }
    return factories.get(key, nn.ReLU)()


def get_default_theta(activation_name):
    """Return the default layer threshold theta for an activation (case-insensitive).

    - 'tanh' -> 0.0: outputs lie in [-1, 1] and goodness is mean(h), so the
      natural boundary is 0.
    - 'hardlimit' / 'hardlim' / 'step' / 'perceptron' -> 0.5: binary outputs,
      goodness = fraction of active units, expected around one half.
    - anything else (ReLU, GELU, ...) -> 1.0: threshold on mean(h**2).
    """
    thresholds = {'tanh': 0.0}
    thresholds.update(dict.fromkeys(('hardlimit', 'hardlim', 'step', 'perceptron'), 0.5))
    return thresholds.get(activation_name.lower(), 1.0)


class FFLayer(nn.Module):
    """Single Forward-Forward layer with a hybrid (layer + per-neuron) goodness objective.

    Goodness per sample is mean(h) for signed/binary activations (tanh,
    hard-limit) and mean(h**2) otherwise. Training is local: each layer owns
    its optimizer and updates from its own loss on positive / negative inputs.
    """

    # Activations whose goodness is the mean activation rather than the mean
    # squared activation. Matched case-insensitively (make_activation and
    # get_default_theta also lower-case the name).
    MEAN_GOODNESS_ACTIVATIONS = ('tanh', 'hardlimit', 'hardlim', 'step', 'perceptron')

    def __init__(self, in_features, out_features, lr=0.001, theta_neuron=1.0, learnable_theta=False,
                 init_method='kaiming', frozen=False, img_size=None,
                 optimizer='adam', activation='relu'):
        """
        Args:
            in_features, out_features: Linear layer sizes.
            lr: learning rate of this layer's local optimizer.
            theta_neuron: per-neuron threshold for the local loss (only used
                for squared-goodness activations; mean-goodness activations
                use their activation-specific default instead).
            learnable_theta: make the layer threshold an nn.Parameter.
            init_method, img_size: forwarded to init_layer_weights.
            frozen: if True, parameters get requires_grad=False and no optimizer is made.
            optimizer: 'sgd' for SGD, anything else for Adam.
            activation: activation name understood by make_activation.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.act = make_activation(activation)
        self.activation_name = activation
        self.n_neurons = out_features
        self.learnable_theta = learnable_theta
        default_theta = get_default_theta(activation)
        if learnable_theta:
            self.theta_layer = nn.Parameter(torch.tensor(default_theta))
        else:
            self.theta_layer = default_theta
        # Use the same (case-insensitive) activation set as goodness(), so the
        # per-neuron threshold always matches the goodness definition.
        self.theta_neuron = default_theta if self._uses_mean_goodness() else theta_neuron
        self.frozen = frozen
        self.optimizer_type = optimizer

        init_layer_weights(self.linear, method=init_method, img_size=img_size)

        if not frozen:
            if optimizer == 'sgd':
                self.opt = torch.optim.SGD(self.parameters(), lr=lr)
            else:
                self.opt = torch.optim.Adam(self.parameters(), lr=lr)
        else:
            self.opt = None
            for param in self.parameters():
                param.requires_grad = False

    def _uses_mean_goodness(self):
        """True if goodness is mean(h) (signed/binary outputs) instead of mean(h**2)."""
        return self.activation_name.lower() in self.MEAN_GOODNESS_ACTIVATIONS

    def forward(self, x):
        return self.act(self.linear(x))

    def forward_norm(self, x):
        """Return (activations, activations L2-normalized per sample)."""
        h = self.forward(x)
        h_n = h / (h.norm(dim=1, keepdim=True) + 1e-8)
        return h, h_n

    def goodness(self, h):
        """Per-sample goodness [B]: mean(h) for signed/binary activations, else mean(h**2)."""
        if self._uses_mean_goodness():
            return h.mean(dim=1)
        return (h ** 2).mean(dim=1)

    def _neuron_goodness(self, h):
        """Per-neuron goodness [B, n]: h for mean-goodness activations, else h**2.

        Mirrors goodness() so the neuron-level loss optimizes the same
        quantity as the layer-level loss. Previously it always used h**2,
        which for tanh with theta_neuron=0 can never be pushed below 0.
        """
        return h if self._uses_mean_goodness() else h ** 2

    def train_step(self, x_pos, x_neg, alpha=0.0, k_pct=0):
        """One local FF update on a batch of positive and negative inputs.

        loss = (1 - alpha) * layer_loss + alpha * neuron_loss, with
          layer_loss  = mean softplus(theta - g_pos) + mean softplus(g_neg - theta)
          neuron_loss = the same form per neuron with theta_neuron, averaged
                        over neurons, then over the batch.
        softplus(-z) == -log(sigmoid(z)) and softplus(z) == -log(1 - sigmoid(z)).
        Unlike the previous `-log(sigmoid(.) + 1e-8)` form, softplus keeps a
        non-zero gradient when the sigmoid saturates. In float32, sigmoid(z)
        rounds to exactly 1 for z above roughly 17, which used to zero the
        gradient of badly misclassified negatives.

        Args:
            x_pos, x_neg: input batches [B, in_features].
            alpha: weight of the neuron-level loss (0 = layer-level only).
            k_pct: if 0 < k_pct < 100, each neuron's local loss is kept with
                probability k_pct / 100 (one Bernoulli mask shared by the batch).

        Returns:
            (loss as float, detached normalized positive activations,
             detached normalized negative activations)
        """
        h_pos, h_pos_n = self.forward_norm(x_pos)
        h_neg, h_neg_n = self.forward_norm(x_neg)

        g_pos = self.goodness(h_pos)
        g_neg = self.goodness(h_neg)

        # Store goodness stats for monitoring (no grad impact)
        self._last_g_pos_mean = g_pos.mean().item()
        self._last_g_neg_mean = g_neg.mean().item()
        self._last_g_sep = self._last_g_pos_mean - self._last_g_neg_mean
        self._last_theta = self.theta_layer.item() if isinstance(self.theta_layer, nn.Parameter) else self.theta_layer

        loss_layer = (
            F.softplus(self.theta_layer - g_pos).mean()
            + F.softplus(g_neg - self.theta_layer).mean()
        )

        loss_local = torch.tensor(0.0, device=x_pos.device)
        if alpha > 0:
            gn_pos = self._neuron_goodness(h_pos)
            gn_neg = self._neuron_goodness(h_neg)
            ln_pos = F.softplus(self.theta_neuron - gn_pos)
            ln_neg = F.softplus(gn_neg - self.theta_neuron)

            if 0 < k_pct < 100:
                mask = torch.bernoulli(
                    torch.full((1, self.n_neurons), k_pct / 100.0, device=x_pos.device)
                )
                ln_pos = ln_pos * mask
                ln_neg = ln_neg * mask

            # mean over neurons, then mean over batch
            loss_local = ln_pos.mean(1).mean() + ln_neg.mean(1).mean()

        loss = (1 - alpha) * loss_layer + alpha * loss_local

        if not self.frozen and self.opt is not None:
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

        return loss.item(), h_pos_n.detach(), h_neg_n.detach()

    @torch.no_grad()
    def infer(self, x):
        """Return (goodness [B], normalized activations) without tracking gradients."""
        h, h_n = self.forward_norm(x)
        return self.goodness(h), h_n

    @torch.no_grad()
    def get_activations(self, x):
        return self.forward(x)


# ================================================================
# META-LAYER CLASS
# ================================================================

class PerceptronFFLayer(nn.Module):
    """Gradient-free FF layer using a perceptron learning rule.

    Forward: h = step(Wx + b) (0 below zero, 1 above, 0.5 exactly at zero)
    Goodness: mean(h) = fraction of active neurons
    Learning: no sigmoid loss, no gradients.
      - Positive data with goodness < theta: w += lr * x  (strengthen)
      - Negative data with goodness > theta: w -= lr * x  (weaken)

    The lr / theta_neuron / optimizer / activation arguments mirror FFLayer's
    signature; theta_layer and theta_neuron are fixed at 0.5 and `optimizer`
    and `activation` are ignored.
    """

    def __init__(self, in_features, out_features, lr=0.001, theta_neuron=0.5,
                 init_method='kaiming', frozen=False, img_size=None,
                 optimizer='adam', activation='perceptron'):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.n_neurons = out_features
        self.theta_layer = 0.5
        self.theta_neuron = 0.5
        self.frozen = frozen
        self.lr = lr
        self.activation_name = 'perceptron'
        self.optimizer_type = 'perceptron'

        init_layer_weights(self.linear, method=init_method, img_size=img_size)

        if frozen:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        # torch.heaviside requires `values` to share x's dtype. A hard-coded
        # float32 scalar crashed for float64 / half-precision inputs.
        z = self.linear(x)
        return torch.heaviside(z, torch.tensor(0.5, dtype=z.dtype, device=z.device))

    def forward_norm(self, x):
        """Return (binary activations, activations L2-normalized per sample)."""
        h = self.forward(x)
        h_n = h / (h.norm(dim=1, keepdim=True) + 1e-8)
        return h, h_n

    def goodness(self, h):
        return h.mean(dim=1)  # Fraction of active neurons

    @torch.no_grad()
    def train_step(self, x_pos, x_neg, alpha=0.0, k_pct=0):
        """Perceptron-style update: no gradients, no loss function.

        alpha controls layer-level vs neuron-level error signals:
          alpha=0: layer-level — update all neurons when mean goodness is wrong
          alpha=1: neuron-level — update each neuron based on its own firing error
          0<alpha<1: blend of both signals

        This is the perceptron analogue of FFLayer's alpha parameter,
        operating entirely without gradients or a loss function. k_pct is
        accepted for interface compatibility and ignored.

        Returns:
            (pseudo-loss = 1 - (mean g_pos - mean g_neg), normalized positive
             activations, normalized negative activations)
        """
        h_pos = self.forward(x_pos)   # [B, n_neurons], binary {0, 1}
        h_neg = self.forward(x_neg)
        h_pos_n = h_pos / (h_pos.norm(dim=1, keepdim=True) + 1e-8)
        h_neg_n = h_neg / (h_neg.norm(dim=1, keepdim=True) + 1e-8)

        g_pos = self.goodness(h_pos)   # [B], mean fraction of active neurons
        g_neg = self.goodness(h_neg)
        B = x_pos.size(0)

        # Store goodness stats for monitoring
        self._last_g_pos_mean = g_pos.mean().item()
        self._last_g_neg_mean = g_neg.mean().item()
        self._last_g_sep = self._last_g_pos_mean - self._last_g_neg_mean
        self._last_theta = self.theta_layer.item() if isinstance(self.theta_layer, nn.Parameter) else self.theta_layer

        if not self.frozen:
            # ============================================================
            # LAYER-LEVEL UPDATE (alpha=0 component)
            # If mean goodness is wrong for a sample, nudge ALL neurons
            # ============================================================
            dw_layer = torch.zeros_like(self.linear.weight.data)
            db_layer = torch.zeros_like(self.linear.bias.data)

            pos_err_layer = (g_pos < self.theta_layer).float()  # [B]
            neg_err_layer = (g_neg > self.theta_layer).float()  # [B]

            n_pos_l = pos_err_layer.sum().item()
            n_neg_l = neg_err_layer.sum().item()

            if n_pos_l > 0:
                # Every neuron receives the same batch-averaged, error-masked input
                dw_layer += (pos_err_layer.unsqueeze(1) * x_pos).mean(dim=0).unsqueeze(0).expand_as(self.linear.weight)
                db_layer += pos_err_layer.mean()

            if n_neg_l > 0:
                dw_layer -= (neg_err_layer.unsqueeze(1) * x_neg).mean(dim=0).unsqueeze(0).expand_as(self.linear.weight)
                db_layer -= neg_err_layer.mean()

            # ============================================================
            # NEURON-LEVEL UPDATE (alpha=1 component)
            # Each neuron has its own target:
            #   Positive data: neuron SHOULD fire (target=1)
            #   Negative data: neuron should NOT fire (target=0)
            # Update only the neurons that made wrong individual decisions
            # ============================================================
            dw_neuron = torch.zeros_like(self.linear.weight.data)
            db_neuron = torch.zeros_like(self.linear.bias.data)

            # Positive: neurons that didn't fire but should have
            neuron_err_pos = (1.0 - h_pos)  # [B, n_neurons], 1 where neuron failed
            # Negative: neurons that fired but shouldn't have
            neuron_err_neg = h_neg           # [B, n_neurons], 1 where neuron failed

            # Per-neuron weight update via matrix multiply:
            # dw[j, i] = mean_over_batch(error[b, j] * input[b, i])
            # = (error.T @ input) / B  -> [n_neurons, in_features]
            dw_neuron += (neuron_err_pos.t() @ x_pos) / B   # strengthen neurons that missed positives
            dw_neuron -= (neuron_err_neg.t() @ x_neg) / B   # weaken neurons that fired on negatives

            # Bias: per-neuron mean error
            db_neuron += neuron_err_pos.mean(dim=0)   # [n_neurons]
            db_neuron -= neuron_err_neg.mean(dim=0)

            # ============================================================
            # BLEND and APPLY
            # ============================================================
            dw = (1.0 - alpha) * dw_layer + alpha * dw_neuron
            db = (1.0 - alpha) * db_layer + alpha * db_neuron

            self.linear.weight.data += self.lr * dw
            self.linear.bias.data += self.lr * db

        # Return pseudo-loss for compatibility
        pseudo_loss = (1.0 - (g_pos.mean() - g_neg.mean())).item()
        return pseudo_loss, h_pos_n.detach(), h_neg_n.detach()

    @torch.no_grad()
    def infer(self, x):
        """Return (goodness [B], normalized activations)."""
        h, h_n = self.forward_norm(x)
        return self.goodness(h), h_n

    @torch.no_grad()
    def get_activations(self, x):
        return self.forward(x)


class MetaLayer:
    """Meta-layer: maps K per-class goodness scores [N, K] to class predictions.

    meta_type:
      'none'        - return the raw scores as a NumPy array
      'argmax'      - class with the largest score
      'calibrated'  - z-score each column with per-class stats from calibrate(), then argmax
      'linear'      - K -> K linear layer trained by train()
      'mlp'         - K -> hidden -> hidden -> K ReLU MLP trained by train()
      'temperature' - per-class learned temperature scaling (train()), then argmax

    Inputs may be NumPy arrays or tensors on any device. They are moved to
    ``device`` and detached; non-tensor scores become float32, and the
    trained models always receive float32 scores and int64 labels.
    """

    def __init__(self, num_classes, meta_type='argmax', device='cpu', hidden_dim=32):
        self.K = num_classes
        self.meta_type = meta_type
        self.device = device

        self.cal_mu = None
        self.cal_sigma = None
        self.linear = None
        self.mlp = None
        self.temps = None
        self.optimizer = None

        if meta_type == 'linear':
            self.linear = nn.Linear(num_classes, num_classes).to(device)
            self.optimizer = torch.optim.Adam(self.linear.parameters(), lr=0.01)
        elif meta_type == 'mlp':
            self.mlp = nn.Sequential(
                nn.Linear(num_classes, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_classes)
            ).to(device)
            self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=0.01)
        elif meta_type == 'temperature':
            self.temps = nn.Parameter(torch.ones(num_classes, device=device))
            self.optimizer = torch.optim.Adam([self.temps], lr=0.01)

    def _to_tensor(self, G, dtype=None):
        """Return G as a detached tensor on self.device.

        Non-tensor input is converted to float32. Tensor input keeps its
        dtype unless ``dtype`` is given. Detaching avoids backpropagating
        into (or failing on) an upstream autograd graph.
        """
        if torch.is_tensor(G):
            G = G.detach().to(self.device)
            return G if dtype is None else G.to(dtype)
        return torch.tensor(G, dtype=torch.float32, device=self.device)

    def calibrate(self, G, y):
        """Record per-class mean/std of the own-class score column.

        For class k, uses G[y == k, k]; sigma gets +1e-8. A class with no
        samples gets mu = 0.0 and sigma = 1.0.
        """
        G_np = G.detach().cpu().numpy() if torch.is_tensor(G) else G
        y_np = y.detach().cpu().numpy() if torch.is_tensor(y) else y
        self.cal_mu = np.zeros(self.K)
        self.cal_sigma = np.zeros(self.K)
        for k in range(self.K):
            gk = G_np[y_np == k, k]
            self.cal_mu[k] = gk.mean() if len(gk) > 0 else 0.0
            self.cal_sigma[k] = gk.std() + 1e-8 if len(gk) > 0 else 1.0

    def train(self, G, y, epochs=100):
        """Full-batch Adam + cross-entropy for `epochs` steps (trainable types only; no-op otherwise)."""
        if self.meta_type not in ['linear', 'mlp', 'temperature']:
            return
        # The trained models are float32; casting avoids dtype errors for
        # float64 tensors, and moving to self.device avoids device mismatches.
        G_t = self._to_tensor(G, torch.float32)
        if torch.is_tensor(y):
            y_t = y.detach().to(device=self.device, dtype=torch.long)
        else:
            y_t = torch.tensor(y, dtype=torch.long, device=self.device)
        crit = nn.CrossEntropyLoss()
        for _ in range(epochs):
            if self.meta_type == 'linear':
                logits = self.linear(G_t)
            elif self.meta_type == 'mlp':
                logits = self.mlp(G_t)
            elif self.meta_type == 'temperature':
                logits = G_t / (self.temps.abs() + 1e-8)
            loss = crit(logits, y_t)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    @torch.no_grad()
    def predict(self, G):
        """Return NumPy class predictions [N] (or raw scores for 'none')."""
        if self.meta_type == 'none':
            return self._to_tensor(G).cpu().numpy()
        elif self.meta_type == 'argmax':
            return self._to_tensor(G).argmax(1).cpu().numpy()
        elif self.meta_type == 'calibrated':
            G_t = self._to_tensor(G)
            mu = torch.tensor(self.cal_mu, dtype=torch.float32, device=self.device)
            sig = torch.tensor(self.cal_sigma, dtype=torch.float32, device=self.device)
            return ((G_t - mu) / sig).argmax(1).cpu().numpy()
        elif self.meta_type == 'linear':
            return self.linear(self._to_tensor(G, torch.float32)).argmax(1).cpu().numpy()
        elif self.meta_type == 'mlp':
            return self.mlp(self._to_tensor(G, torch.float32)).argmax(1).cpu().numpy()
        elif self.meta_type == 'temperature':
            G_t = self._to_tensor(G, torch.float32)
            return (G_t / (self.temps.abs() + 1e-8)).argmax(1).cpu().numpy()
        else:
            raise ValueError(f"Unknown meta_type: {self.meta_type}")


# ================================================================
# CLASSIC FF (Hinton's original) — Variant 1: One-Hot Overlay
# ================================================================

class ClassicFF(nn.Module):
    """Standard FF with one-hot label overlay (Hinton's original).

    The label is appended to the input as a one-hot vector, so the bottom
    layer sees ``input_dim + num_classes`` features. Positive samples carry
    their true label, and negative samples carry a uniformly drawn wrong one.
    Prediction tries every label and keeps the one whose overlay gives the
    highest goodness summed over all layers.
    """

    def __init__(self, input_dim, hidden_sizes, num_classes, lr=0.001, device='cpu',
                 init_method='kaiming', img_size=None, optimizer='adam', activation='relu', learnable_theta=False):
        super().__init__()
        self.input_dim = input_dim
        self.K = num_classes
        self.device = device
        self.eff_dim = input_dim + num_classes
        self.optimizer_type = optimizer
        self.activation = activation
        self.learnable_theta = learnable_theta

        # init_method / img_size are accepted for signature parity with the other
        # models only; every layer here is built with init_method='kaiming'.
        use_perceptron = activation == 'perceptron'
        layer_cls = PerceptronFFLayer if use_perceptron else FFLayer
        widths = [self.eff_dim] + hidden_sizes
        self.layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            kwargs = {'lr': lr, 'init_method': 'kaiming', 'optimizer': optimizer, 'activation': activation}
            if not use_perceptron:
                # only FFLayer takes learnable_theta
                kwargs['learnable_theta'] = learnable_theta
            self.layers.append(layer_cls(fan_in, fan_out, **kwargs))
        self.to(device)

    def _overlay(self, x, labels):
        """Concatenate ``x`` with the one-hot encoding of ``labels`` (width K)."""
        return torch.cat((x, F.one_hot(labels, self.K).float().to(x.device)), dim=1)

    def _wrong_labels(self, y):
        """Draw, per sample, a label uniformly from the K-1 classes other than ``y``."""
        # shift by 1..K-1 (mod K) so the true label can never be drawn
        offset = torch.randint(0, self.K - 1, y.shape, device=y.device) + 1
        return (y + offset) % self.K

    def train_epoch(self, loader):
        """One pass over ``loader``; returns the per-batch mean of the summed layer losses."""
        self.train()
        running, batches = 0.0, 0
        for xb, yb in loader:
            xb = xb.to(self.device)
            yb = yb.to(self.device)
            hp = self._overlay(xb, yb)
            hn = self._overlay(xb, self._wrong_labels(yb))
            for layer in self.layers:
                step_loss, hp, hn = layer.train_step(hp, hn, alpha=0.0, k_pct=0)
                running += step_loss
            batches += 1
        return running / max(batches, 1)

    def _label_goodness(self, xb, k):
        """Goodness of ``xb`` summed over all layers when every row carries label ``k``."""
        labels = torch.full((xb.size(0),), k, dtype=torch.long, device=self.device)
        h = self._overlay(xb, labels)
        total = torch.zeros(xb.size(0), device=self.device)
        for layer in self.layers:
            g, h = layer.infer(h)
            total += g
        return total

    @torch.no_grad()
    def predict(self, X, batch_size=512):
        """Return the argmax-goodness label for each row of ``X`` as a NumPy array."""
        self.eval()
        X_all = torch.tensor(X, dtype=torch.float32, device=self.device)
        out = []
        for start in range(0, len(X_all), batch_size):
            chunk = X_all[start:start + batch_size]
            scores = torch.stack([self._label_goodness(chunk, k) for k in range(self.K)])
            out.append(scores.argmax(0).cpu().numpy())
        return np.concatenate(out)

    def evaluate(self, X, y):
        """Classification accuracy on (X, y), in percent."""
        hits = self.predict(X) == y
        return hits.mean() * 100.0


# ================================================================
# CLASSIC FF — Variant 2: Learned Embedding (smaller footprint)
# ================================================================

class ClassicFF_Embed(nn.Module):
    """
    FF with learned label embedding instead of one-hot.

    The label is mapped through an ``nn.Embedding(num_classes, embed_dim)``
    and concatenated to the input, so the first-layer width grows by
    ``embed_dim`` (default 2) rather than by K.

    The embedding has its own Adam optimizer (``embed_opt``) and is stepped
    once per batch. It receives a gradient only if the first layer's
    ``train_step`` back-propagates into its input tensors; otherwise the step
    is a no-op. NOTE(review): confirm FFLayer.train_step does not detach its
    inputs.
    """

    def __init__(self, input_dim, hidden_sizes, num_classes, lr=0.001, device='cpu',
                 embed_dim=2, activation='relu', learnable_theta=False):
        super().__init__()
        self.input_dim = input_dim
        self.K = num_classes
        self.device = device
        self.embed_dim = embed_dim
        self.activation = activation

        # Learned embedding: K classes -> embed_dim dimensions
        self.label_embedding = nn.Embedding(num_classes, embed_dim)

        self.eff_dim = input_dim + embed_dim

        dims = [self.eff_dim] + hidden_sizes
        self.layers = nn.ModuleList()
        for i in range(len(hidden_sizes)):
            self.layers.append(FFLayer(dims[i], dims[i+1], lr=lr, init_method='kaiming',
                                       activation=activation, learnable_theta=learnable_theta))
        # Optimizer for embedding (layers have their own optimizers)
        self.embed_opt = torch.optim.Adam(self.label_embedding.parameters(), lr=lr)
        self.to(device)

    def _embed_label(self, x, labels):
        """Concatenate input with learned label embedding."""
        emb = self.label_embedding(labels)  # [batch, embed_dim]
        return torch.cat([x, emb], dim=1)

    def _wrong_labels(self, y):
        """Draw, per sample, a label uniformly from the K-1 classes other than ``y``."""
        wrong = torch.randint(0, self.K - 1, y.shape, device=y.device)
        return (wrong + y + 1) % self.K

    def train_epoch(self, loader):
        """One pass over ``loader``.

        Returns the per-batch mean of the summed layer losses.
        """
        self.train()
        total_loss, n = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)

            # x_pos / x_neg carry the embedding's autograd graph into layer 0.
            x_pos = self._embed_label(xb, yb)
            x_neg = self._embed_label(xb, self._wrong_labels(yb))

            # Clear last batch's embedding gradient before the layers back-propagate.
            self.embed_opt.zero_grad()

            hp, hn = x_pos, x_neg
            batch_loss = 0.0
            for layer in self.layers:
                lv, hp, hn = layer.train_step(hp, hn, alpha=0.0, k_pct=0)
                batch_loss += lv

            # Previously missing: without this step the "learned" embedding never
            # changed from its random init. Adam skips parameters whose .grad is
            # None, so this is a no-op if train_step detached its inputs.
            self.embed_opt.step()

            total_loss += batch_loss
            n += 1
        return total_loss / max(n, 1)

    @torch.no_grad()
    def predict(self, X, batch_size=512):
        """Return the argmax-goodness label for each row of ``X`` as a NumPy array."""
        self.eval()
        Xt = torch.tensor(X, dtype=torch.float32, device=self.device)
        preds = []
        for s in range(0, len(Xt), batch_size):
            xb = Xt[s:s+batch_size]
            goodness = []
            for k in range(self.K):
                lk = torch.full((xb.size(0),), k, dtype=torch.long, device=self.device)
                xo = self._embed_label(xb, lk)
                tg = torch.zeros(xb.size(0), device=self.device)
                h = xo
                for layer in self.layers:
                    g, h = layer.infer(h)
                    tg += g
                goodness.append(tg)
            preds.append(torch.stack(goodness).argmax(0).cpu().numpy())
        return np.concatenate(preds)

    def evaluate(self, X, y):
        """Classification accuracy on (X, y), in percent."""
        return (self.predict(X) == y).mean() * 100.0


# ================================================================
# CLASSIC FF — Variant 3: Additive Label at First Hidden Layer
# ================================================================

class ClassicFF_Additive(nn.Module):
    """
    FF with label signal added to first hidden layer (not input).

    Input is preserved entirely. The label enters as a learned embedding
    *added* to the first layer's output. The first layer and the embedding are
    trained jointly with the FF logistic loss on mean-square goodness
    (theta = 1.0). Later layers use their own ``train_step`` on the
    L2-normalised, detached first-layer representation.
    """

    def __init__(self, input_dim, hidden_sizes, num_classes, lr=0.001, device='cpu',
                 activation='relu', learnable_theta=False):
        super().__init__()
        self.input_dim = input_dim
        self.K = num_classes
        self.device = device
        self.hidden_sizes = hidden_sizes
        self.activation = activation

        # First layer: input -> hidden (NO label yet)
        # NOTE(review): its loss is computed by hand in train_epoch with a fixed
        # theta=1.0, so any learnable theta inside FFLayer is not used for this layer.
        self.first_layer = FFLayer(input_dim, hidden_sizes[0], lr=lr, init_method='kaiming',
                                   activation=activation, learnable_theta=learnable_theta)
        # Label embedding that matches first hidden size
        self.label_embedding = nn.Embedding(num_classes, hidden_sizes[0])
        self.embed_opt = torch.optim.Adam(self.label_embedding.parameters(), lr=lr)

        # Remaining layers
        self.layers = nn.ModuleList()
        for i in range(1, len(hidden_sizes)):
            self.layers.append(FFLayer(hidden_sizes[i-1], hidden_sizes[i], lr=lr, init_method='kaiming',
                                       activation=activation, learnable_theta=learnable_theta))
        self.to(device)

    def _wrong_labels(self, y):
        """Draw, per sample, a label uniformly from the K-1 classes other than ``y``."""
        wrong = torch.randint(0, self.K - 1, y.shape, device=y.device)
        return (wrong + y + 1) % self.K

    def _forward_with_label(self, x, labels, train=False):
        """First layer on clean input, add label embedding, L2-normalise each row.

        ``train`` is kept for backward compatibility; both modes were identical.
        """
        h = self.first_layer.forward(x) + self.label_embedding(labels)
        return h / (h.norm(dim=1, keepdim=True) + 1e-8)

    @staticmethod
    def _ff_loss(g_pos, g_neg, theta=1.0):
        """FF logistic loss: push g_pos above theta and g_neg below it.

        softplus(theta - g_pos) == -log(sigmoid(g_pos - theta)) and
        softplus(g_neg - theta) == -log(1 - sigmoid(g_neg - theta)). Unlike the
        log/sigmoid form (with a 1e-8 guard), the gradient does not vanish when
        sigmoid saturates to exactly 0 or 1 in float32.
        """
        return F.softplus(theta - g_pos).mean() + F.softplus(g_neg - theta).mean()

    def train_epoch(self, loader):
        """One pass over ``loader``; returns the per-batch mean of the summed losses."""
        self.train()
        total_loss, n = 0.0, 0

        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            y_wrong = self._wrong_labels(yb)

            self.embed_opt.zero_grad()
            self.first_layer.opt.zero_grad()

            # The clean-input pass is the same for positives and negatives, so it
            # is computed once. Both branches share it in the autograd graph.
            h_raw, _ = self.first_layer.forward_norm(xb)

            # Add label embeddings
            h_pos = h_raw + self.label_embedding(yb)
            h_neg = h_raw + self.label_embedding(y_wrong)

            # Mean-square goodness for the first layer
            g_pos = (h_pos ** 2).mean(dim=1)
            g_neg = (h_neg ** 2).mean(dim=1)

            loss_first = self._ff_loss(g_pos, g_neg, theta=1.0)

            # Update first layer and embedding jointly (single backward, no retained graph).
            loss_first.backward()
            self.first_layer.opt.step()
            self.embed_opt.step()

            total_loss += loss_first.item()

            # Normalize for next layers
            hp = h_pos.detach() / (h_pos.detach().norm(dim=1, keepdim=True) + 1e-8)
            hn = h_neg.detach() / (h_neg.detach().norm(dim=1, keepdim=True) + 1e-8)

            # Remaining layers
            for layer in self.layers:
                lv, hp, hn = layer.train_step(hp, hn, alpha=0.0, k_pct=0)
                total_loss += lv

            n += 1

        return total_loss / max(n, 1)

    @torch.no_grad()
    def predict(self, X, batch_size=512):
        """Return the argmax-goodness label for each row of ``X`` as a NumPy array."""
        self.eval()
        Xt = torch.tensor(X, dtype=torch.float32, device=self.device)
        preds = []

        for s in range(0, len(Xt), batch_size):
            xb = Xt[s:s+batch_size]
            # The clean-input pass does not depend on the candidate label.
            h_clean = self.first_layer.forward(xb)
            goodness = []

            for k in range(self.K):
                lk = torch.full((xb.size(0),), k, dtype=torch.long, device=self.device)

                # Add label embedding
                h = h_clean + self.label_embedding(lk)

                # Goodness from first layer
                tg = (h ** 2).mean(dim=1)

                # Normalize
                h = h / (h.norm(dim=1, keepdim=True) + 1e-8)

                # Remaining layers
                for layer in self.layers:
                    g, h = layer.infer(h)
                    tg += g

                goodness.append(tg)

            preds.append(torch.stack(goodness).argmax(0).cpu().numpy())

        return np.concatenate(preds)

    def evaluate(self, X, y):
        """Classification accuracy on (X, y), in percent."""
        return (self.predict(X) == y).mean() * 100.0


# ================================================================
# CLASSIC FF — Variant 4: With Local Adaptation (alpha > 0)
# ================================================================

class ClassicFF_LocalAdapt(nn.Module):
    """
    Classic one-hot-overlay FF whose layers train with ``alpha`` > 0.

    Architecture and prediction match the classic overlay model. The only
    difference is that every ``train_step`` receives ``alpha=self.alpha``, the
    weight of the per-neuron (local) term in the layer loss. This gives a fair
    control: does local adaptation alone explain ModularFF's advantage, or is
    the modular architecture key?
    """

    def __init__(self, input_dim, hidden_sizes, num_classes, lr=0.001, device='cpu',
                 init_method='kaiming', img_size=None, alpha=0.3, activation='relu', learnable_theta=False):
        super().__init__()
        self.input_dim = input_dim
        self.K = num_classes
        self.device = device
        self.alpha = alpha
        self.activation = activation
        self.eff_dim = input_dim + num_classes

        # init_method / img_size are accepted for signature parity only;
        # every layer is built with init_method='kaiming'.
        widths = [self.eff_dim] + hidden_sizes
        self.layers = nn.ModuleList(
            FFLayer(d_in, d_out, lr=lr, init_method='kaiming',
                    activation=activation, learnable_theta=learnable_theta)
            for d_in, d_out in zip(widths[:-1], widths[1:])
        )
        self.to(device)

    def _overlay(self, x, labels):
        """Append the one-hot encoding of ``labels`` (width K) to ``x``."""
        one_hot = F.one_hot(labels, self.K).float().to(x.device)
        return torch.cat([x, one_hot], dim=1)

    def _wrong_labels(self, y):
        """Per-sample uniform draw from the K-1 classes that differ from ``y``."""
        shift = torch.randint(0, self.K - 1, y.shape, device=y.device)
        return (shift + y + 1) % self.K

    def train_epoch(self, loader):
        """One pass over ``loader`` using ``self.alpha``; returns the per-batch mean summed loss."""
        self.train()
        accumulated = 0.0
        seen = 0
        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            pos = self._overlay(xb, yb)
            neg = self._overlay(xb, self._wrong_labels(yb))
            for layer in self.layers:
                # local adaptation: forward the configured alpha to every layer
                layer_loss, pos, neg = layer.train_step(pos, neg, alpha=self.alpha, k_pct=0)
                accumulated += layer_loss
            seen += 1
        return accumulated / max(seen, 1)

    @torch.no_grad()
    def predict(self, X, batch_size=512):
        """Label each row of ``X`` with the class whose overlay yields the highest summed goodness."""
        self.eval()
        data = torch.tensor(X, dtype=torch.float32, device=self.device)
        chunks = []
        for start in range(0, len(data), batch_size):
            xb = data[start:start + batch_size]
            per_class = []
            for k in range(self.K):
                h = self._overlay(xb, torch.full((xb.size(0),), k, dtype=torch.long, device=self.device))
                score = torch.zeros(xb.size(0), device=self.device)
                for layer in self.layers:
                    g, h = layer.infer(h)
                    score += g
                per_class.append(score)
            chunks.append(torch.stack(per_class).argmax(0).cpu().numpy())
        return np.concatenate(chunks)

    def evaluate(self, X, y):
        """Accuracy on (X, y) in percent."""
        correct = self.predict(X) == y
        return correct.mean() * 100.0


# ================================================================
# BP BASELINE
# ================================================================

class BPBaseline(nn.Module):
    """Standard MLP with cross-entropy, trained by back-propagation.

    Serves as the gradient-based reference for the FF variants. It exposes
    the same ``train_epoch`` / ``predict`` / ``evaluate`` surface as the FF
    models. ``learnable_theta`` is accepted for signature parity only and is
    ignored.
    """

    def __init__(self, input_dim, hidden_sizes, num_classes, lr=0.001, device='cpu',
                 activation='relu', learnable_theta=False):
        super().__init__()
        self.device = device
        self.activation = activation
        layers = []
        dims = [input_dim] + hidden_sizes
        for i in range(len(hidden_sizes)):
            layers += [nn.Linear(dims[i], dims[i+1]), make_activation(activation)]
        layers.append(nn.Linear(hidden_sizes[-1], num_classes))
        self.net = nn.Sequential(*layers)
        self.opt = torch.optim.Adam(self.parameters(), lr=lr)
        self.crit = nn.CrossEntropyLoss()
        self.to(device)

    def forward(self, x):
        """Raw class logits for a batch ``x``."""
        return self.net(x)

    def train_epoch(self, loader):
        """One optimisation pass over ``loader``; returns the mean batch loss."""
        self.train()
        total_loss, n = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(self.device), yb.to(self.device)
            loss = self.crit(self.forward(xb), yb)
            self.opt.zero_grad(); loss.backward(); self.opt.step()
            total_loss += loss.item(); n += 1
        return total_loss / max(n, 1)

    @torch.no_grad()
    def predict(self, X, batch_size=512):
        """Predicted class indices for each row of ``X`` as a NumPy array.

        Mirrors ``predict`` on the FF models so callers can treat every model
        uniformly. An empty ``X`` yields an empty int64 array.
        """
        self.eval()
        Xt = torch.tensor(X, dtype=torch.float32, device=self.device)
        preds = [self.forward(Xt[s:s + batch_size]).argmax(1).cpu().numpy()
                 for s in range(0, len(Xt), batch_size)]
        return np.concatenate(preds) if preds else np.empty(0, dtype=np.int64)

    @torch.no_grad()
    def evaluate(self, X, y, batch_size=512):
        """Classification accuracy on (X, y), in percent, as a Python float.

        ``y`` is assumed to be a host-side array-like of integer labels.
        Raises ZeroDivisionError on an empty ``y``, as before.
        """
        correct = int((self.predict(X, batch_size) == np.asarray(y)).sum())
        return correct / len(y) * 100.0


# ================================================================
# ModularFF SPECIALIST
# ================================================================

class ModularFFSpecialist(nn.Module):
    """One one-vs-rest FF specialist for class ``class_id``.

    A stack of FF layers trained so samples of its own class (positives)
    score high goodness and samples of other classes (negatives) score low.
    When ``prune_beta > 1`` the first layer starts over-provisioned by that
    factor. :meth:`prune_first_layer` later cuts it back to ``hidden_sizes[0]``.
    """

    def __init__(self, input_dim, hidden_sizes, class_id, lr=0.001,
                 theta_neuron=1.0, device='cpu',
                 init_method='kaiming', img_size=None,
                 first_layer_frozen=False, prune_beta=1.0, activation='relu', learnable_theta=False):
        super().__init__()
        self.class_id = class_id
        self.device = device
        self.input_dim = input_dim
        self.target_hidden_sizes = hidden_sizes.copy()
        self.lr = lr
        self.theta_neuron = theta_neuron
        self.init_method = init_method
        self.img_size = img_size
        self.first_layer_frozen = first_layer_frozen
        self.prune_beta = prune_beta
        self.activation = activation
        self.learnable_theta = learnable_theta
        self.pruned = False
        self.num_layers = len(hidden_sizes)
        self._build_layers()
        self.to(device)

    def _layer_spec(self):
        """Layer class and extra kwargs matching this specialist's activation."""
        if self.activation == 'perceptron':
            return PerceptronFFLayer, {}
        return FFLayer, {'learnable_theta': self.learnable_theta}

    def _build_layers(self):
        """(Re)build ``self.layers``; the first layer is widened by ``prune_beta`` before pruning."""
        hidden = self.target_hidden_sizes.copy()
        if self.prune_beta > 1.0 and not self.pruned:
            hidden[0] = int(hidden[0] * self.prune_beta)
        dims = [self.input_dim] + hidden
        self.layers = nn.ModuleList()
        LayerClass, extra = self._layer_spec()
        for i in range(len(hidden)):
            if i == 0:
                layer = LayerClass(dims[i], dims[i+1], lr=self.lr, theta_neuron=self.theta_neuron,
                                init_method=self.init_method, frozen=self.first_layer_frozen,
                                img_size=self.img_size, activation=self.activation, **extra)
            else:
                layer = LayerClass(dims[i], dims[i+1], lr=self.lr, theta_neuron=self.theta_neuron,
                                init_method='kaiming', frozen=False, activation=self.activation, **extra)
            self.layers.append(layer)
        self.current_hidden_sizes = hidden

    def train_batch(self, x_pos, x_neg, alpha=0.0, layer_dropout=None):
        """Run one FF step through every layer; returns the mean per-layer loss.

        ``layer_dropout[i]`` is passed to layer i as ``k_pct``; missing entries
        default to 0. The caller's list is copied, never modified. The
        ensemble passes the same list to all K specialists.
        """
        n_layers = len(self.layers)
        dropout = list(layer_dropout) if layer_dropout is not None else []
        dropout += [0] * (n_layers - len(dropout))
        hp, hn = x_pos, x_neg
        total_loss = 0.0
        for i, layer in enumerate(self.layers):
            lv, hp, hn = layer.train_step(hp, hn, alpha=alpha, k_pct=dropout[i])
            total_loss += lv
        return total_loss / n_layers

    @torch.no_grad()
    def total_goodness(self, x):
        """Goodness of ``x`` summed over all layers (one value per row)."""
        h = x
        tg = torch.zeros(x.size(0), device=self.device)
        for layer in self.layers:
            g, h = layer.infer(h)
            tg += g
        return tg

    @torch.no_grad()
    def compute_pruning_scores(self, X_pos, X_neg):
        """Per-neuron separability score for the first layer.

        The score is |mean_pos - mean_neg| / pooled std of squared activations.
        Near-silent neurons (average squared activity < 0.01) score 0. NaNs,
        e.g. from std over a single sample, also score 0, so they are never
        preferred by topk.
        """
        first_layer = self.layers[0]
        X_pos_t = torch.tensor(X_pos, dtype=torch.float32, device=self.device)
        X_neg_t = torch.tensor(X_neg, dtype=torch.float32, device=self.device)
        h_pos = first_layer.get_activations(X_pos_t)
        h_neg = first_layer.get_activations(X_neg_t)
        g_pos = h_pos ** 2
        g_neg = h_neg ** 2
        mean_pos = g_pos.mean(dim=0)
        mean_neg = g_neg.mean(dim=0)
        std_pos = g_pos.std(dim=0)
        std_neg = g_neg.std(dim=0)
        std_pooled = torch.sqrt((std_pos**2 + std_neg**2) / 2 + 1e-8)
        separation = torch.abs(mean_pos - mean_neg)
        scores = torch.nan_to_num(separation / std_pooled, nan=0.0)
        avg_activity = (mean_pos + mean_neg) / 2
        scores[avg_activity < 0.01] = 0.0
        return scores

    def prune_first_layer(self, X_pos, X_neg, keep_n=None):
        """Keep the ``keep_n`` most separable first-layer neurons (default: target width).

        The first layer is rebuilt with the surviving rows of its weight and
        bias. The second layer is rebuilt with the matching input columns.
        Both use this specialist's layer class and activation. Fresh
        optimizers come with the new layers.
        NOTE(review): learned per-layer theta (if any) is not copied over.
        """
        if self.pruned:
            return
        if keep_n is None:
            keep_n = self.target_hidden_sizes[0]
        scores = self.compute_pruning_scores(X_pos, X_neg)
        current_n = len(scores)
        if keep_n >= current_n:
            self.pruned = True
            return
        _, top_indices = torch.topk(scores, keep_n)
        top_indices = top_indices.sort().values
        LayerClass, extra = self._layer_spec()
        old_layer = self.layers[0]
        old_weight = old_layer.linear.weight.data[top_indices, :]
        old_bias = old_layer.linear.bias.data[top_indices] if old_layer.linear.bias is not None else None
        # Same layer class / activation as _build_layers (previously reverted to FFLayer defaults).
        new_layer = LayerClass(self.input_dim, keep_n, lr=self.lr, theta_neuron=self.theta_neuron,
                               init_method='kaiming', frozen=self.first_layer_frozen,
                               activation=self.activation, **extra)
        with torch.no_grad():
            new_layer.linear.weight.copy_(old_weight)
            if old_bias is not None:
                new_layer.linear.bias.copy_(old_bias)
        new_layer.to(self.device)
        if len(self.layers) > 1:
            old_second = self.layers[1]
            second_w = old_second.linear.weight.data
            old_second_weight = second_w[:, top_indices.to(second_w.device)]
            new_second = LayerClass(keep_n, old_second.n_neurons, lr=self.lr, theta_neuron=self.theta_neuron,
                                    init_method='kaiming', frozen=False,
                                    activation=self.activation, **extra)
            with torch.no_grad():
                new_second.linear.weight.copy_(old_second_weight)
                if old_second.linear.bias is not None:
                    new_second.linear.bias.copy_(old_second.linear.bias.data)
            new_second.to(self.device)
            self.layers[1] = new_second
        self.layers[0] = new_layer
        self.current_hidden_sizes[0] = keep_n
        self.pruned = True
        print(f"    Specialist {self.class_id}: Pruned {current_n - keep_n}/{current_n} neurons")


# ================================================================
# ModularFF ENSEMBLE — With Per-Specialist Evaluation
# ================================================================

class ModularFFEnsemble:
    """
    Full ModularFF: K one-vs-rest specialists + meta-layers + per-specialist evaluation.

    Specialist k is trained on "class k" vs "not class k". At inference every
    specialist scores every sample. A meta-layer turns the resulting (N, K)
    goodness matrix into a class prediction.

    New in v4:
    - evaluate_specialists(): binary accuracy, sensitivity, specificity per expert
    - Negative examples are resampled each epoch (confirmed)
    """

    def __init__(self, input_dim, hidden_sizes, num_classes,
                 lr=0.001, theta_neuron=1.0, device='cpu',
                 init_method='kaiming', img_size=None,
                 first_layer_frozen=False, prune_beta=1.0,
                 meta_type='argmax', use_meta_layer=True, activation='relu', learnable_theta=False):

        self.K = num_classes
        self.device = device
        self.hidden_sizes = hidden_sizes
        self.prune_beta = prune_beta
        self.use_meta_layer = use_meta_layer
        self.activation = activation
        self.learnable_theta = learnable_theta

        # learnable_theta is forwarded so the ensemble-level flag actually reaches
        # the specialists' layers (previously it was accepted and silently dropped).
        self.specs = [
            ModularFFSpecialist(input_dim, hidden_sizes, k, lr, theta_neuron, device,
                                init_method, img_size, first_layer_frozen, prune_beta,
                                activation=activation, learnable_theta=learnable_theta)
            for k in range(num_classes)
        ]

        self.meta_layers = {}
        for mt in ['argmax', 'calibrated', 'linear', 'mlp', 'temperature']:
            self.meta_layers[mt] = MetaLayer(num_classes, mt, device)

        self.default_meta = meta_type

    def total_params(self):
        """Total parameter count across all specialists (meta-layers excluded)."""
        return sum(sum(p.numel() for p in s.parameters()) for s in self.specs)

    def _get_specialist_data(self, X, y, class_id, balanced=True):
        """
        Get positive and negative examples for a specialist.

        With ``balanced=True`` the negatives are a fresh random subset (global
        NumPy RNG, no replacement) of size min(#pos, #neg). Returns
        (None, None) when either side is empty.

        NOTE: This is called EACH EPOCH, so negatives change!
        """
        pos_idx = np.where(y == class_id)[0]
        neg_idx = np.where(y != class_id)[0]
        n_pos = len(pos_idx)

        if n_pos == 0 or len(neg_idx) == 0:
            return None, None

        if balanced:
            # Sample equal number of negatives (resampled each call!)
            neg_sample = np.random.choice(neg_idx, size=min(n_pos, len(neg_idx)), replace=False)
        else:
            neg_sample = neg_idx

        return X[pos_idx], X[neg_sample]

    def train_epoch(self, X, y, alpha=0.0, layer_dropout=None, batch_size=128):
        """
        Train all specialists for one epoch; returns the mean per-specialist loss.
        Negatives are RESAMPLED each epoch (different random subset).
        """
        total_loss = 0.0
        for spec in self.specs:
            spec.train()

        for k, spec in enumerate(self.specs):
            # NOTE: _get_specialist_data resamples negatives each call!
            X_pos, X_neg = self._get_specialist_data(X, y, k, balanced=True)
            if X_pos is None:
                continue

            n = min(len(X_pos), len(X_neg))

            # Shuffle order each epoch too
            perm_pos = np.random.permutation(len(X_pos))
            perm_neg = np.random.permutation(len(X_neg))

            spec_loss, nb = 0.0, 0
            for s in range(0, n, batch_size):
                e = min(s + batch_size, n)
                xp = torch.tensor(X_pos[perm_pos[s:e]], dtype=torch.float32, device=self.device)
                xn = torch.tensor(X_neg[perm_neg[s:e]], dtype=torch.float32, device=self.device)
                mb = min(xp.size(0), xn.size(0))
                if mb == 0:
                    continue
                lv = spec.train_batch(xp[:mb], xn[:mb], alpha=alpha, layer_dropout=layer_dropout)
                spec_loss += lv
                nb += 1

            if nb > 0:
                total_loss += spec_loss / nb

        return total_loss / self.K

    def prune_all_specialists(self, X, y):
        """Prune each specialist's first layer using a balanced pos/neg sample."""
        print(f"\n  Pruning {self.K} specialists...")
        for k, spec in enumerate(self.specs):
            X_pos, X_neg = self._get_specialist_data(X, y, k)
            if X_pos is not None and X_neg is not None:
                spec.prune_first_layer(X_pos, X_neg)
        print(f"  Done. New params: {self.total_params()}")

    @torch.no_grad()
    def _all_goodness(self, X, batch_size=512):
        """(N, K) tensor: column k is specialist k's total goodness for each row of X."""
        for s in self.specs:
            s.eval()
        Xt = torch.tensor(X, dtype=torch.float32, device=self.device)
        chunks = []
        for s in range(0, len(Xt), batch_size):
            xb = Xt[s:s+batch_size]
            g = torch.stack([spec.total_goodness(xb) for spec in self.specs], dim=1)
            chunks.append(g)
        return torch.cat(chunks, dim=0)

    def train_meta_layers(self, X_val, y_val, epochs=100):
        """Calibrate / fit the meta-layers on goodness computed from (X_val, y_val)."""
        if not self.use_meta_layer:
            return
        G = self._all_goodness(X_val)
        y_t = torch.tensor(y_val, dtype=torch.long, device=self.device)
        self.meta_layers['calibrated'].calibrate(G, y_t)
        for mt in ['linear', 'mlp', 'temperature']:
            self.meta_layers[mt].train(G, y_t, epochs=epochs)

    def predict(self, X, meta=None):
        """Class predictions via meta-layer ``meta`` (default: ``self.default_meta``)."""
        if meta is None:
            meta = self.default_meta
        G = self._all_goodness(X)
        if not self.use_meta_layer or meta == 'none':
            return G.argmax(1).cpu().numpy()
        return self.meta_layers[meta].predict(G)

    def evaluate(self, X, y, meta=None):
        """Accuracy (percent) on (X, y) using meta-layer ``meta``."""
        return (self.predict(X, meta) == y).mean() * 100.0

    def evaluate_all_meta(self, X, y):
        """Accuracy (percent) for every registered meta-layer type."""
        results = {}
        for mt in self.meta_layers.keys():
            results[mt] = self.evaluate(X, y, mt)
        return results

    @torch.no_grad()
    def evaluate_specialists(self, X, y, threshold_method='mean'):
        """
        Evaluate each specialist on its binary classification task.

        For specialist k:
        - Positive examples: samples where y == k
        - Negative examples: balanced random sample where y != k (drawn from the
          global NumPy RNG, so results vary between calls)

        Returns:
            dict: {k: {'accuracy', 'sensitivity', 'specificity',
                       'mean_g_pos', 'mean_g_neg', 'threshold', 'separation'}}
            Every entry has all keys. A specialist with no positives or no
            negatives gets zeros.
        """
        results = {}

        for k, spec in enumerate(self.specs):
            spec.eval()

            # Get balanced positive/negative data
            X_pos, X_neg = self._get_specialist_data(X, y, k, balanced=True)

            if X_pos is None or X_neg is None or len(X_pos) == 0 or len(X_neg) == 0:
                # Full key set so print_specialist_performance() and the averages
                # below never KeyError on a class missing from this split.
                results[k] = {
                    'accuracy': 0, 'sensitivity': 0, 'specificity': 0,
                    'mean_g_pos': 0.0, 'mean_g_neg': 0.0,
                    'threshold': 0.0, 'separation': 0.0,
                }
                continue

            # Compute goodness
            X_pos_t = torch.tensor(X_pos, dtype=torch.float32, device=self.device)
            X_neg_t = torch.tensor(X_neg, dtype=torch.float32, device=self.device)

            g_pos = spec.total_goodness(X_pos_t).cpu().numpy()
            g_neg = spec.total_goodness(X_neg_t).cpu().numpy()

            mean_g_pos = g_pos.mean()
            mean_g_neg = g_neg.mean()

            # Determine threshold
            if threshold_method == 'mean':
                threshold = (mean_g_pos + mean_g_neg) / 2
            elif threshold_method == 'theta':
                threshold = sum(layer.theta_layer for layer in spec.layers)
            else:
                threshold = (mean_g_pos + mean_g_neg) / 2

            # Binary predictions: positive if goodness > threshold
            tp = (g_pos > threshold).sum()
            fn = (g_pos <= threshold).sum()
            tn = (g_neg <= threshold).sum()
            fp = (g_neg > threshold).sum()

            n_pos = len(g_pos)
            n_neg = len(g_neg)

            accuracy = (tp + tn) / (n_pos + n_neg) * 100
            sensitivity = tp / n_pos * 100 if n_pos > 0 else 0
            specificity = tn / n_neg * 100 if n_neg > 0 else 0

            results[k] = {
                'accuracy': round(accuracy, 2),
                'sensitivity': round(sensitivity, 2),
                'specificity': round(specificity, 2),
                'mean_g_pos': round(float(mean_g_pos), 2),
                'mean_g_neg': round(float(mean_g_neg), 2),
                'threshold': round(float(threshold), 2),
                'separation': round(float(mean_g_pos - mean_g_neg), 2),
            }

        return results

    def print_specialist_performance(self, X, y, threshold_method='mean'):
        """Pretty-print per-specialist binary performance; returns the results dict."""
        results = self.evaluate_specialists(X, y, threshold_method)

        print(f'\n  {"Spec":>4}  {"Acc":>6}  {"Sens":>6}  {"Spec":>6}  '
              f'{"G_pos":>7}  {"G_neg":>7}  {"Sep":>6}')
        print('  ' + '-' * 52)

        for k in range(self.K):
            r = results[k]
            print(f'  {k:>4}  {r["accuracy"]:>5.1f}%  {r["sensitivity"]:>5.1f}%  '
                  f'{r["specificity"]:>5.1f}%  {r["mean_g_pos"]:>7.1f}  '
                  f'{r["mean_g_neg"]:>7.1f}  {r["separation"]:>6.1f}')

        # Summary stats
        avg_acc = np.mean([results[k]['accuracy'] for k in range(self.K)])
        avg_sens = np.mean([results[k]['sensitivity'] for k in range(self.K)])
        avg_spec = np.mean([results[k]['specificity'] for k in range(self.K)])
        avg_sep = np.mean([results[k]['separation'] for k in range(self.K)])

        print('  ' + '-' * 52)
        print(f'  {"Avg":>4}  {avg_acc:>5.1f}%  {avg_sens:>5.1f}%  '
              f'{avg_spec:>5.1f}%  {"":>7}  {"":>7}  {avg_sep:>6.1f}')

        return results


# ================================================================
# Cell-2 summary: one line per notable change in this version of the classes.
print('\n'.join([
    '\u2713 All classes defined (v5):',
    '  - FFLayer: mean-goodness + theta=1.0 (scale-invariant)',
    '  - FFLayer: per-neuron loss uses mean (balanced with layer loss)',
    '  - FFLayer: supports optimizer="adam"|"sgd"',
    '  - ClassicFF: supports optimizer="adam"|"sgd"',
    '  - Alpha now truly interpolates: 0.5 = equal blend',
]))

✓ All classes defined (v5):
  - FFLayer: mean-goodness + theta=1.0 (scale-invariant)
  - FFLayer: per-neuron loss uses mean (balanced with layer loss)
  - FFLayer: supports optimizer="adam"|"sgd"
  - ClassicFF: supports optimizer="adam"|"sgd"
  - Alpha now truly interpolates: 0.5 = equal blend


In [None]:
####### end of cell 2

In [3]:
# ================================================================
# CELL 3: Training Engine (v5)
# ================================================================
# v5 Changes:
#   - Added run_2layer_experiment() for Phase 1
#   - run_4layer_experiment() for Phase 2 (Adam + SGD)
#   - Relaxed early stopping: patience=60, min_epochs=50
#   - LR reduction schedule with best-model checkpoint
# ================================================================


class ExperimentLogger:
    """Append-only CSV logger: one row per training run.

    Columns are fixed by ``FIELDS``. Keys missing from a logged record are
    written as empty strings, and keys not in ``FIELDS`` are ignored.
    """

    FIELDS = [
        'timestamp', 'dataset', 'method', 'seed',
        'alpha', 'layer_dropout', 'meta_layer',
        'hidden_sizes', 'num_specialists',
        'best_val_acc', 'test_acc', 'epochs_run',
        'train_time_sec', 'total_params',
        'init_method', 'elm_mode', 'pruning', 'prune_beta',
        'avg_specialist_acc', 'avg_specialist_sep',
    ]

    def __init__(self, log_dir):
        """Ensure ``log_dir`` exists and create the CSV (with header) if absent.

        Args:
            log_dir: directory that will hold ``experiment_logs.csv``.
        """
        # Create the directory up front; otherwise the first open() fails on a
        # fresh Drive/local checkout where the logs folder was never made.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.path = os.path.join(log_dir, 'experiment_logs.csv')
        if not os.path.exists(self.path):
            with open(self.path, 'w', newline='') as f:
                csv.DictWriter(f, fieldnames=self.FIELDS).writeheader()

    def log(self, d):
        """Append one row built from ``d`` plus the current ISO-8601 timestamp.

        Args:
            d: mapping of column name -> value. It is NOT modified; the
               timestamp is added to a private copy of the row.
        """
        row = {k: d.get(k, '') for k in self.FIELDS}
        row['timestamp'] = datetime.now().isoformat()
        with open(self.path, 'a', newline='') as f:
            csv.DictWriter(f, fieldnames=self.FIELDS).writerow(row)


# Module-level logger shared by every training function in this cell.
LOGGER = ExperimentLogger(log_dir=CONFIG['logs_path'])


def make_loader(X, y, batch_size, shuffle=True):
    """Wrap feature/label arrays in a DataLoader.

    Features are converted to a float32 tensor and labels to an int64
    (``torch.long``) tensor before batching.
    """
    dataset = TensorDataset(
        torch.tensor(X, dtype=torch.float32),
        torch.tensor(y, dtype=torch.long),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# ================================================================
# LR REDUCTION HELPER
# ================================================================

def reduce_lr_for_model(model, factor=0.5):
    """Multiply the learning rate of every optimizer owned by ``model`` by ``factor``.

    Optimizers are looked up on the attributes used by the models in this notebook:
      - ``model.layers[i].opt``   (ClassicFF / ClassicFF_LocalAdapt / ClassicFF_Additive)
      - ``model.first_layer.opt``
      - ``model.embed_opt``
      - ``model.opt``             (BPBaseline)

    Attributes that are missing or set to ``None`` are skipped. Each distinct
    optimizer object is scaled exactly once, even if it is reachable through
    several attributes (e.g. a ``first_layer`` that is also in ``layers``).

    Args:
        model: object exposing any subset of the attributes above.
        factor: multiplicative LR factor (0.5 halves the LR).
    """
    candidates = []
    layers = getattr(model, 'layers', None)
    if layers is not None:
        candidates.extend(getattr(layer, 'opt', None) for layer in layers)
    candidates.append(getattr(getattr(model, 'first_layer', None), 'opt', None))
    candidates.append(getattr(model, 'embed_opt', None))
    candidates.append(getattr(model, 'opt', None))

    seen = set()
    for opt in candidates:
        # Skip absent optimizers and ones already scaled via another attribute.
        if opt is None or id(opt) in seen:
            continue
        seen.add(id(opt))
        for pg in opt.param_groups:
            pg['lr'] *= factor


def reduce_lr_for_ensemble(ensemble, factor=0.5):
    """Scale the LR of every per-layer optimizer in every ModularFF specialist.

    Layers without an ``opt`` attribute, or whose ``opt`` is ``None``, are left alone.
    """
    layer_optimizers = (
        getattr(layer, 'opt', None)
        for spec in ensemble.specs
        for layer in spec.layers
    )
    for optim in layer_optimizers:
        if optim is None:
            continue
        for group in optim.param_groups:
            group['lr'] *= factor


def get_current_lr_str(model):
    """Return the model's current learning rate formatted as ``'%.1e'`` for logging.

    Lookup order: ``model.opt`` (BPBaseline), then the first entry of
    ``model.layers`` (ClassicFF variants), then ``model.first_layer.opt``.
    Optimizer attributes that are ``None`` are skipped instead of crashing.
    Returns ``"?"`` if no optimizer is found.
    """
    opt = getattr(model, 'opt', None)
    if opt is None:
        layers = getattr(model, 'layers', None)
        if layers is not None and len(layers) > 0:
            opt = getattr(layers[0], 'opt', None)
    if opt is None:
        opt = getattr(getattr(model, 'first_layer', None), 'opt', None)
    if opt is None:
        return "?"
    return f"{opt.param_groups[0]['lr']:.1e}"


def get_ensemble_lr_str(ensemble):
    """Format the LR of the first layer of the first specialist (``'%.1e'``), or ``"?"``."""
    specs = ensemble.specs
    if not specs or not specs[0].layers:
        return "?"
    first_opt = getattr(specs[0].layers[0], 'opt', None)
    return "?" if first_opt is None else f"{first_opt.param_groups[0]['lr']:.1e}"


# ================================================================
# TRAINING FUNCTIONS (with relaxed early stopping)
# ================================================================

def train_classic_ff(ds_name, ds, hidden, seed, cfg, optimizer='adam', lr_override=None, activation='relu', learnable_theta=False):
    """Train Classic FF (Hinton's original with one-hot label overlay).

    Trains for up to ``cfg['epochs']`` epochs. The loop checkpoints on the best
    validation accuracy, reduces LRs on plateaus, and stops early once patience
    runs out. It then restores the best checkpoint, evaluates on the test split,
    and appends a row to ``LOGGER``.

    Args:
        ds_name: dataset key; used for ``ARCHITECTURES[ds_name]['img_size']`` and logging.
        ds: dataset dict with input_dim, num_classes and X_/y_ train/val/test arrays.
        hidden: hidden layer sizes passed to ``ClassicFF``.
        seed: seed passed to ``set_seed``.
        cfg: config dict (device, lr, batch_size, epochs, early_stop_patience; optional
            min_epochs, lr_reduce_patience (<= 0 disables LR reduction),
            lr_reduce_factor, init_method).
        optimizer: 'adam' or 'sgd'.
        lr_override: if set, use this LR instead of cfg['lr'] (for Hinton SGD config).
        activation: activation name forwarded to ``ClassicFF``.
        learnable_theta: forwarded to ``ClassicFF``.

    Returns:
        (hist, te_acc, n_params). hist holds per-epoch 'train_acc', 'val_acc',
        'loss' and 'wall_time' lists.
    """
    set_seed(seed)
    dev = cfg['device']
    lr = lr_override if lr_override is not None else cfg['lr']
    opt_label = optimizer.upper()
    method_name = f'ClassicFF_{opt_label}'

    model = ClassicFF(ds['input_dim'], hidden, ds['num_classes'], lr, dev,
                      init_method=cfg.get('init_method', 'kaiming'),
                      img_size=ARCHITECTURES[ds_name].get('img_size'),
                      optimizer=optimizer, activation=activation, learnable_theta=learnable_theta)
    n_params = sum(p.numel() for p in model.parameters())
    loader = make_loader(ds['X_train'], ds['y_train'], cfg['batch_size'])

    # Early stopping config
    min_epochs = cfg.get('min_epochs', 50)
    total_patience = cfg['early_stop_patience']
    lr_reduce_patience = cfg.get('lr_reduce_patience', 20)
    lr_reduce_factor = cfg.get('lr_reduce_factor', 0.5)

    hist = {'train_acc': [], 'val_acc': [], 'loss': [], 'wall_time': []}
    best_val, patience_counter = 0.0, 0
    lr_reductions = 0
    best_state = None
    t0 = time.time()

    for ep in range(cfg['epochs']):
        loss = model.train_epoch(loader)
        tr_acc = model.evaluate(ds['X_train'], ds['y_train'])
        va_acc = model.evaluate(ds['X_val'], ds['y_val'])
        hist['loss'].append(loss)
        hist['train_acc'].append(tr_acc)
        hist['val_acc'].append(va_acc)
        hist['wall_time'].append(time.time() - t0)

        if va_acc > best_val:
            best_val = va_acc
            patience_counter = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        # LR reduction every lr_reduce_patience stale epochs, before full patience
        # is exhausted. The > 0 guard avoids a modulo-by-zero when it is disabled.
        if (lr_reduce_patience > 0 and patience_counter > 0
                and patience_counter % lr_reduce_patience == 0
                and patience_counter < total_patience):
            lr_reductions += 1
            reduce_lr_for_model(model, lr_reduce_factor)
            lr_str = get_current_lr_str(model)
            print(f'  [{method_name}] LR reduced (×{lr_reduce_factor}) -> {lr_str} at epoch {ep+1} (reduction #{lr_reductions})')

        if (ep+1) % 10 == 0 or ep == 0:
            print(f'  [{method_name}] ep {ep+1:3d}  loss={loss:.3f}  '
                  f'train={tr_acc:.1f}%  val={va_acc:.1f}%  (p={patience_counter})')

        # Early stop only after min_epochs and full patience exhausted
        if ep >= min_epochs and patience_counter >= total_patience:
            print(f'  [{method_name}] Early stop at epoch {ep+1} (best_val={best_val:.1f}%, {lr_reductions} LR reductions)')
            break

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f'  [{method_name}] Restored best checkpoint (val={best_val:.1f}%)')

    elapsed = time.time() - t0
    te_acc = model.evaluate(ds['X_test'], ds['y_test'])
    # Goodness diagnostic
    g_diag = []
    for li, layer in enumerate(model.layers):
        gp = getattr(layer, '_last_g_pos_mean', 0)
        gn = getattr(layer, '_last_g_neg_mean', 0)
        th = getattr(layer, '_last_theta', getattr(layer, 'theta_layer', '?'))
        if hasattr(th, 'item'):
            th = th.item()
        # th may be the '?' placeholder (or None); ':.3f' on a non-number raises
        # ValueError and would abort the run before it is logged.
        th_str = f'{th:.3f}' if isinstance(th, (int, float)) else str(th)
        g_diag.append(f'L{li}:g+={gp:.3f}/g-={gn:.3f}/sep={gp-gn:.3f}/th={th_str}')
    print(f'  [{method_name}] DONE  test={te_acc:.2f}%  {elapsed:.0f}s  {n_params} params  (LR={lr}, {optimizer}, act={activation})')
    print(f'    Goodness: {" | ".join(g_diag)}')

    LOGGER.log({
        'dataset': ds_name, 'method': method_name, 'seed': seed,
        'alpha': 0, 'layer_dropout': 'N/A', 'meta_layer': 'N/A',
        'hidden_sizes': str(hidden), 'num_specialists': 1,
        'best_val_acc': round(best_val, 2), 'test_acc': round(te_acc, 2),
        'epochs_run': len(hist['loss']), 'train_time_sec': round(elapsed, 1),
        'total_params': n_params,
        'init_method': cfg.get('init_method', 'kaiming'),
        'elm_mode': False, 'pruning': False, 'prune_beta': 1.0,
    })
    return hist, te_acc, n_params


def train_bp(ds_name, ds, hidden, seed, cfg, activation='relu'):
    """Train the backprop (BP) baseline using the same early-stopping protocol as the FF models.

    Uses validation checkpointing, LR reduction on plateaus and early stopping
    after ``min_epochs``. It then restores the best checkpoint, evaluates on the
    test split, and appends a row to ``LOGGER``.

    Args:
        ds_name: dataset name used for logging.
        ds: dataset dict with input_dim, num_classes and X_/y_ train/val/test arrays.
        hidden: hidden layer sizes passed to ``BPBaseline``.
        seed: seed passed to ``set_seed``.
        cfg: config dict (device, lr, batch_size, epochs, early_stop_patience; optional
            min_epochs, lr_reduce_patience (<= 0 disables LR reduction), lr_reduce_factor).
        activation: activation name forwarded to ``BPBaseline``.

    Returns:
        (hist, te_acc, n_params). hist holds per-epoch 'train_acc', 'val_acc',
        'loss' and 'wall_time' lists.
    """
    set_seed(seed)
    dev = cfg['device']

    model = BPBaseline(ds['input_dim'], hidden, ds['num_classes'], cfg['lr'], dev,
                       activation=activation)
    n_params = sum(p.numel() for p in model.parameters())
    loader = make_loader(ds['X_train'], ds['y_train'], cfg['batch_size'])

    # Early stopping config
    min_epochs = cfg.get('min_epochs', 50)
    total_patience = cfg['early_stop_patience']
    lr_reduce_patience = cfg.get('lr_reduce_patience', 20)
    lr_reduce_factor = cfg.get('lr_reduce_factor', 0.5)

    hist = {'train_acc': [], 'val_acc': [], 'loss': [], 'wall_time': []}
    best_val, patience_counter = 0.0, 0
    lr_reductions = 0
    best_state = None
    t0 = time.time()

    for ep in range(cfg['epochs']):
        loss = model.train_epoch(loader)
        tr_acc = model.evaluate(ds['X_train'], ds['y_train'])
        va_acc = model.evaluate(ds['X_val'], ds['y_val'])
        hist['loss'].append(loss)
        hist['train_acc'].append(tr_acc)
        hist['val_acc'].append(va_acc)
        hist['wall_time'].append(time.time() - t0)

        if va_acc > best_val:
            best_val = va_acc
            patience_counter = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        # LR reduction before stopping; the > 0 guard avoids a modulo-by-zero
        # when lr_reduce_patience is set to 0 to disable it.
        if (lr_reduce_patience > 0 and patience_counter > 0
                and patience_counter % lr_reduce_patience == 0
                and patience_counter < total_patience):
            lr_reductions += 1
            reduce_lr_for_model(model, lr_reduce_factor)
            lr_str = get_current_lr_str(model)
            print(f'  [BP]       LR reduced (×{lr_reduce_factor}) -> {lr_str} at epoch {ep+1} (reduction #{lr_reductions})')

        if (ep+1) % 10 == 0 or ep == 0:
            print(f'  [BP]       ep {ep+1:3d}  loss={loss:.3f}  '
                  f'train={tr_acc:.1f}%  val={va_acc:.1f}%  (p={patience_counter})')

        # Early stop only after min_epochs and full patience exhausted
        if ep >= min_epochs and patience_counter >= total_patience:
            print(f'  [BP]       Early stop at epoch {ep+1} (best_val={best_val:.1f}%, {lr_reductions} LR reductions)')
            break

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f'  [BP]       Restored best checkpoint (val={best_val:.1f}%)')

    elapsed = time.time() - t0
    te_acc = model.evaluate(ds['X_test'], ds['y_test'])
    print(f'  [BP]       DONE  test={te_acc:.2f}%  {elapsed:.0f}s  {n_params} params')

    LOGGER.log({
        'dataset': ds_name, 'method': 'BP', 'seed': seed,
        'alpha': 'N/A', 'layer_dropout': 'N/A', 'meta_layer': 'N/A',
        'hidden_sizes': str(hidden), 'num_specialists': 1,
        'best_val_acc': round(best_val, 2), 'test_acc': round(te_acc, 2),
        'epochs_run': len(hist['loss']), 'train_time_sec': round(elapsed, 1),
        'total_params': n_params,
        'init_method': 'N/A', 'elm_mode': False, 'pruning': False, 'prune_beta': 1.0,
    })
    return hist, te_acc, n_params


def train_modularff(ds_name, ds, spec_hidden, seed, cfg,
                alpha=0.0, layer_dropout=None, meta='argmax',
                init_method='kaiming', elm_mode=False,
                pruning_enabled=False, prune_beta=2.0, prune_after=10,
                use_meta_layer=True, show_specialist_perf=False, activation='relu', learnable_theta=False):
    """Train a ModularFF ensemble (``ds['num_classes']`` specialists) with all features.

    Pipeline:
      1. Build ``ModularFFEnsemble`` and train it with validation checkpointing,
         LR reduction on plateaus and early stopping (optionally pruning once
         at epoch ``prune_after``).
      2. Restore the best specialist checkpoint.
      3. Optionally fit meta-layers on the validation split.
      4. Evaluate on the test split (meta-layers and per-specialist) and log to ``LOGGER``.

    Args:
        ds_name: dataset key (``ARCHITECTURES[ds_name]['img_size']``, logging).
        ds: dataset dict with input_dim, num_classes and X_/y_ train/val/test arrays.
        spec_hidden: hidden sizes of each specialist.
        seed: seed passed to ``set_seed``.
        cfg: config dict (device, lr, theta_neuron, batch_size, epochs,
            early_stop_patience; optional min_epochs, lr_reduce_patience
            (<= 0 disables LR reduction), lr_reduce_factor).
        alpha, layer_dropout: forwarded to ``ens.train_epoch``.
        meta: meta-layer whose test accuracy is reported as ``te_acc``.
        init_method, elm_mode, activation, learnable_theta: forwarded to the ensemble
            (``elm_mode`` maps to ``first_layer_frozen``).
        pruning_enabled, prune_beta, prune_after: pruning controls; ``prune_beta`` is
            only passed to the ensemble when pruning is enabled (1.0 otherwise).
        use_meta_layer: whether to train/evaluate meta-layers.
        show_specialist_perf: accepted for API compatibility; not used here.

    Returns:
        (hist, te_acc, final_params, ens, meta_results, spec_results)
    """
    set_seed(seed)
    dev = cfg['device']
    img_size = ARCHITECTURES[ds_name].get('img_size')

    effective_beta = prune_beta if pruning_enabled else 1.0

    ens = ModularFFEnsemble(
        ds['input_dim'], spec_hidden, ds['num_classes'],
        cfg['lr'], cfg['theta_neuron'], dev,
        init_method=init_method, img_size=img_size,
        first_layer_frozen=elm_mode, prune_beta=effective_beta,
        meta_type=meta, use_meta_layer=use_meta_layer,
        activation=activation,
        learnable_theta=learnable_theta
    )

    initial_params = ens.total_params()

    # Early stopping config
    min_epochs = cfg.get('min_epochs', 50)
    total_patience = cfg['early_stop_patience']
    lr_reduce_patience = cfg.get('lr_reduce_patience', 20)
    lr_reduce_factor = cfg.get('lr_reduce_factor', 0.5)

    hist = {'train_acc': [], 'val_acc': [], 'loss': [], 'phase': [], 'wall_time': []}
    best_val, patience_counter = 0.0, 0
    lr_reductions = 0
    # Checkpoint: save specialist state dicts
    best_spec_states = None
    t0 = time.time()

    total_epochs = cfg['epochs']
    pruned = False

    ld_str = str(layer_dropout) if layer_dropout else 'uniform'

    for ep in range(total_epochs):
        if pruning_enabled and not pruned and ep == prune_after:
            print(f'\n  [ModularFF] === PRUNING at epoch {ep} ===')
            ens.prune_all_specialists(ds['X_train'], ds['y_train'])
            pruned = True
            print(f'  [ModularFF] Params: {initial_params} -> {ens.total_params()}\n')
            # Pruning changes the specialists' parameters (the count above shrinks),
            # so a pre-prune state_dict may no longer match their shapes, and
            # restoring it would undo the pruning anyway. Start a fresh checkpoint
            # and patience window for the pruned ensemble.
            best_val, patience_counter = 0.0, 0
            best_spec_states = None

        phase = 'post-prune' if pruned else ('pre-prune' if pruning_enabled else 'normal')

        loss = ens.train_epoch(
            ds['X_train'], ds['y_train'],
            alpha=alpha, layer_dropout=layer_dropout, batch_size=cfg['batch_size']
        )
        tr_acc = ens.evaluate(ds['X_train'], ds['y_train'], 'argmax')
        va_acc = ens.evaluate(ds['X_val'], ds['y_val'], 'argmax')

        hist['loss'].append(loss)
        hist['train_acc'].append(tr_acc)
        hist['val_acc'].append(va_acc)
        hist['phase'].append(phase)
        hist['wall_time'].append(time.time() - t0)

        if va_acc > best_val:
            best_val = va_acc
            patience_counter = 0
            # Checkpoint all specialists
            best_spec_states = [copy.deepcopy(s.state_dict()) for s in ens.specs]
        else:
            patience_counter += 1

        # LR reduction before stopping; the > 0 guard avoids a modulo-by-zero.
        if (lr_reduce_patience > 0 and patience_counter > 0
                and patience_counter % lr_reduce_patience == 0
                and patience_counter < total_patience):
            lr_reductions += 1
            reduce_lr_for_ensemble(ens, lr_reduce_factor)
            lr_str = get_ensemble_lr_str(ens)
            print(f'  [ModularFF] LR reduced (×{lr_reduce_factor}) -> {lr_str} at epoch {ep+1} (reduction #{lr_reductions})')

        if (ep+1) % 10 == 0 or ep == 0:
            print(f'  [ModularFF a={alpha} ld={ld_str}] ep {ep+1:3d}  '
                  f'loss={loss:.3f}  train={tr_acc:.1f}%  val={va_acc:.1f}%  (p={patience_counter})')

        # Early stop only after min_epochs and full patience exhausted
        if ep >= min_epochs and patience_counter >= total_patience:
            print(f'  [ModularFF] Early stop at epoch {ep+1} (best_val={best_val:.1f}%, {lr_reductions} LR reductions)')
            break

    # Restore best specialists
    if best_spec_states is not None:
        for s, state in zip(ens.specs, best_spec_states):
            s.load_state_dict(state)
        print(f'  [ModularFF] Restored best checkpoint (val={best_val:.1f}%)')

    train_time = time.time() - t0
    final_params = ens.total_params()

    # Train meta-layers
    if use_meta_layer:
        ens.train_meta_layers(ds['X_val'], ds['y_val'])

    # Evaluate all meta-layers; fall back to a direct evaluation only when the
    # requested meta-layer has no result (avoids an extra full test pass).
    meta_results = ens.evaluate_all_meta(ds['X_test'], ds['y_test']) if use_meta_layer else {}
    if meta in meta_results:
        te_acc = meta_results[meta]
    else:
        te_acc = ens.evaluate(ds['X_test'], ds['y_test'], meta)

    # Per-specialist evaluation
    spec_results = ens.evaluate_specialists(ds['X_test'], ds['y_test'])
    avg_spec_acc = np.mean([spec_results[k]['accuracy'] for k in range(ens.K)])
    avg_spec_sep = np.mean([spec_results[k].get('separation', 0) for k in range(ens.K)])

    # Print summary
    meta_str = ', '.join([f'{m}={meta_results.get(m, 0):.1f}%' for m in ['argmax', 'calibrated', 'linear', 'mlp', 'temperature']])
    print(f'  [ModularFF a={alpha} {init_method} {"frozen" if elm_mode else "trainable"} {"prune" if pruning_enabled else "no-prune"} act={activation}] DONE')
    print(f'    Meta: {meta_str}')
    print(f'    test={te_acc:.2f}%  {train_time:.0f}s  {final_params} params')
    # Goodness diagnostic (specialist 0)
    g_diag = []
    for li, layer in enumerate(ens.specs[0].layers):
        gp = getattr(layer, '_last_g_pos_mean', 0)
        gn = getattr(layer, '_last_g_neg_mean', 0)
        th = getattr(layer, '_last_theta', getattr(layer, 'theta_layer', '?'))
        if hasattr(th, 'item'):
            th = th.item()
        # th may be the '?' placeholder (or None); ':.3f' on a non-number raises
        # ValueError and would abort the run before it is logged.
        th_str = f'{th:.3f}' if isinstance(th, (int, float)) else str(th)
        g_diag.append(f'L{li}:g+={gp:.3f}/g-={gn:.3f}/sep={gp-gn:.3f}/th={th_str}')
    print(f'    Goodness (spec0): {" | ".join(g_diag)}')
    print(f'    Avg specialist: acc={avg_spec_acc:.1f}%, separation={avg_spec_sep:.1f}')

    LOGGER.log({
        'dataset': ds_name, 'method': 'ModularFF', 'seed': seed,
        'alpha': alpha, 'layer_dropout': str(layer_dropout), 'meta_layer': meta,
        'hidden_sizes': str(spec_hidden), 'num_specialists': ds['num_classes'],
        'best_val_acc': round(best_val, 2), 'test_acc': round(te_acc, 2),
        'epochs_run': len(hist['loss']), 'train_time_sec': round(train_time, 1),
        'total_params': final_params,
        'init_method': init_method, 'elm_mode': elm_mode,
        'pruning': pruning_enabled, 'prune_beta': prune_beta,
        'avg_specialist_acc': round(avg_spec_acc, 2),
        'avg_specialist_sep': round(avg_spec_sep, 2),
    })

    return hist, te_acc, final_params, ens, meta_results, spec_results


# ================================================================
# PHASE 1: 2-LAYER EXPERIMENTS
# ================================================================

def _fit_classic_variant(model, ds, cfg, tag, ep_tag=None):
    """Run the shared early-stopping / LR-reduction loop for one ClassicFF label-conditioning variant.

    The loop tracks validation accuracy, deep-copies the best ``state_dict``,
    and scales LRs by ``lr_reduce_factor`` every ``lr_reduce_patience`` stale
    epochs (<= 0 disables this). It stops once ``early_stop_patience`` is
    exhausted, but never before ``min_epochs``. The best state is restored
    before the test evaluation.

    Args:
        model: object with train_epoch(loader), evaluate(X, y), parameters(),
            state_dict() and load_state_dict() (the ClassicFF_* variants).
        ds: dataset dict with X_/y_ train/val/test arrays.
        cfg: experiment config dict.
        tag: label for LR-reduction and early-stop messages.
        ep_tag: label for per-epoch and DONE messages (defaults to ``tag``).

    Returns:
        (hist, best_val, te_acc, n_params, elapsed_sec)
    """
    ep_tag = ep_tag or tag
    n_params = sum(p.numel() for p in model.parameters())
    loader = make_loader(ds['X_train'], ds['y_train'], cfg['batch_size'])

    min_epochs = cfg.get('min_epochs', 50)
    total_patience = cfg['early_stop_patience']
    lr_reduce_patience = cfg.get('lr_reduce_patience', 20)
    lr_reduce_factor = cfg.get('lr_reduce_factor', 0.5)

    hist = {'train_acc': [], 'val_acc': [], 'loss': [], 'wall_time': []}
    best_val, patience_counter = 0.0, 0
    best_state = None
    lr_reductions = 0
    t0 = time.time()

    for ep in range(cfg['epochs']):
        loss = model.train_epoch(loader)
        tr_acc = model.evaluate(ds['X_train'], ds['y_train'])
        va_acc = model.evaluate(ds['X_val'], ds['y_val'])
        hist['loss'].append(loss)
        hist['train_acc'].append(tr_acc)
        hist['val_acc'].append(va_acc)
        hist['wall_time'].append(time.time() - t0)

        if va_acc > best_val:
            best_val = va_acc
            patience_counter = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        if (lr_reduce_patience > 0 and patience_counter > 0
                and patience_counter % lr_reduce_patience == 0
                and patience_counter < total_patience):
            lr_reductions += 1
            reduce_lr_for_model(model, lr_reduce_factor)
            print(f'  [{tag}] LR reduced at epoch {ep+1} (#{lr_reductions})')
        if (ep+1) % 10 == 0 or ep == 0:
            print(f'  [{ep_tag}] ep {ep+1:3d}  loss={loss:.3f}  train={tr_acc:.1f}%  val={va_acc:.1f}%  (p={patience_counter})')
        if ep >= min_epochs and patience_counter >= total_patience:
            print(f'  [{tag}] Early stop at epoch {ep+1}')
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    te_acc = model.evaluate(ds['X_test'], ds['y_test'])
    elapsed = time.time() - t0
    print(f'  [{ep_tag}] DONE  test={te_acc:.2f}%  {elapsed:.0f}s  {n_params} params')
    return hist, best_val, te_acc, n_params, elapsed


def run_2layer_experiment(ds_name, seed, cfg, show_specialist_perf=False, activation='relu', learnable_theta=False):
    """
    Run all 2-layer experiments for a dataset.

    Runs:
    - ClassicFF (One-Hot) with Adam and SGD
    - ClassicFF_Embed
    - ClassicFF_Additive
    - ClassicFF_LocalAdapt (alpha=0.5)
    - BP Baseline
    - ModularFF for each architecture × alpha combination

    ClassicFF_Embed / _Additive / _LocalAdapt and BP are skipped for the
    'perceptron' activation. Results are written to
    ``<results_path>/<ds_name>/results_2layer_seed<seed>.json``; the directory
    is created if missing.

    Returns:
        dict mapping method key -> {'history', 'test_acc', 'params', ...}.
    """
    ds = DATASETS[ds_name]
    arch = ARCHITECTURES[ds_name]
    results = {}

    init_method = cfg.get('init_method', 'kaiming')
    use_meta_layer = cfg.get('use_meta_layer', True)
    hinton_lr = cfg.get('hinton_sgd_lr', 0.03)

    classic_hidden = arch.get('classic_ff')
    bp_hidden = arch.get('bp')
    modularff_archs = arch.get('modularff_archs', [[50, 50]])

    print(f'\n{"="*70}')
    print(f'  {ds_name} | seed={seed} | K={ds["num_classes"]} | dim={ds["input_dim"]}')
    print(f'  2-LAYER EXPERIMENTS')
    print(f'  Classic FF: {classic_hidden}')
    print(f'  ModularFF archs: {modularff_archs}')
    print(f'  Adam LR={cfg["lr"]}, SGD LR={hinton_lr}, batch={cfg["batch_size"]}')
    print(f'{"="*70}')

    # 1. Classic FF (One-Hot) — Adam
    if classic_hidden:
        print('\n--- Classic FF (One-Hot, Adam) ---')
        h, a, p = train_classic_ff(ds_name, ds, classic_hidden, seed, cfg, optimizer='adam', activation=activation, learnable_theta=learnable_theta)
        results['ClassicFF_Adam'] = {'history': h, 'test_acc': a, 'params': p}

        # Classic FF (One-Hot) — SGD
        print(f'\n--- Classic FF (One-Hot, SGD, LR={hinton_lr}) ---')
        h, a, p = train_classic_ff(ds_name, ds, classic_hidden, seed, cfg,
                                    optimizer='sgd', lr_override=hinton_lr, activation=activation, learnable_theta=learnable_theta)
        results['ClassicFF_SGD'] = {'history': h, 'test_acc': a, 'params': p}

    # 2–4. ClassicFF label-conditioning variants (not defined for perceptron activation).
    # Each variant gets its own set_seed() and its own early-stopping settings via
    # _fit_classic_variant, so no loop state leaks between sections.
    if classic_hidden and activation not in ('perceptron',):
        dev = cfg['device']
        variants = [
            # (section header, LR/stop tag, epoch/DONE tag, results key, logged method, alpha, builder)
            ('Classic FF (Learned Embedding)', 'ClassicFF-Embed', 'ClassicFF-Embed',
             'ClassicFF_Embed', 'ClassicFF_Embed', 0,
             lambda: ClassicFF_Embed(ds['input_dim'], classic_hidden, ds['num_classes'], cfg['lr'], dev,
                                     activation=activation, learnable_theta=learnable_theta)),
            ('Classic FF (Additive Hidden)', 'ClassicFF-Additive', 'ClassicFF-Additive',
             'ClassicFF_Additive', 'ClassicFF_Additive', 0,
             lambda: ClassicFF_Additive(ds['input_dim'], classic_hidden, ds['num_classes'], cfg['lr'], dev,
                                        activation=activation, learnable_theta=learnable_theta)),
            ('Classic FF (LocalAdapt alpha=0.5)', 'ClassicFF+LA', 'ClassicFF+LA a=0.5',
             'ClassicFF_LocalAdapt_a0.5', 'ClassicFF_LocalAdapt', 0.5,
             lambda: ClassicFF_LocalAdapt(ds['input_dim'], classic_hidden, ds['num_classes'],
                                          cfg['lr'], dev, alpha=0.5, activation=activation,
                                          learnable_theta=learnable_theta)),
        ]
        for header, tag, ep_tag, key, method, variant_alpha, build in variants:
            print(f'\n--- {header} ---')
            set_seed(seed)
            model = build()
            hist, best_val, te_acc, n_params, elapsed = _fit_classic_variant(model, ds, cfg, tag, ep_tag)
            results[key] = {'history': hist, 'test_acc': te_acc, 'params': n_params}
            LOGGER.log({'dataset': ds_name, 'method': method, 'seed': seed,
                        'alpha': variant_alpha, 'hidden_sizes': str(classic_hidden),
                        'best_val_acc': round(best_val, 2), 'test_acc': round(te_acc, 2),
                        'epochs_run': len(hist['loss']), 'train_time_sec': round(elapsed, 1),
                        'total_params': n_params})

    # 5. BP Baseline
    if bp_hidden:
        print('\n--- BP Baseline ---')
        if activation != 'perceptron':
            h, a, p = train_bp(ds_name, ds, bp_hidden, seed, cfg, activation=activation)
            results['BP'] = {'history': h, 'test_acc': a, 'params': p}
        else:
            print('  [BP] Skipped (perceptron activation is FF-specific)')

    # 6. ModularFF — all architectures × alpha values
    for spec_hidden in modularff_archs:
        arch_str = '_'.join(map(str, spec_hidden))
        for alpha in cfg.get('alpha_values', [0.0, 0.5, 1.0]):
            print(f'\n--- ModularFF arch={spec_hidden} (alpha={alpha}) ---')
            h, a, p, ens, mr, sr = train_modularff(
                ds_name, ds, spec_hidden, seed, cfg,
                alpha=alpha, layer_dropout=None, meta='argmax',
                init_method=init_method, elm_mode=False, activation=activation,
                pruning_enabled=False, use_meta_layer=use_meta_layer,
                show_specialist_perf=show_specialist_perf,
                learnable_theta=learnable_theta
            )
            results[f'ModularFF_{arch_str}_a{alpha}'] = {
                'history': h, 'test_acc': a, 'params': p,
                'meta_results': mr, 'specialist_results': sr,
                'architecture': spec_hidden
            }

    # Save results (create the per-dataset folder first; open() does not).
    results_dir = os.path.join(cfg.get('results_path', CONFIG['results_path']), ds_name)
    os.makedirs(results_dir, exist_ok=True)
    save_path = os.path.join(results_dir, f'results_2layer_seed{seed}.json')

    serializable = {}
    for key, val in results.items():
        entry = {
            'test_acc': val['test_acc'], 'params': val['params'],
            'train_acc': val['history']['train_acc'],
            'val_acc': val['history']['val_acc'],
            'wall_time': val['history'].get('wall_time', []),
        }
        if 'meta_results' in val:
            entry['meta_results'] = val['meta_results']
        if 'specialist_results' in val:
            entry['specialist_results'] = val['specialist_results']
        if 'architecture' in val:
            entry['architecture'] = val['architecture']
        serializable[key] = entry

    with open(save_path, 'w') as f:
        # numpy/torch scalars and arrays expose tolist(); anything else falls back
        # to str so a long run is never lost to a TypeError at save time.
        json.dump(serializable, f, indent=2,
                  default=lambda o: o.tolist() if hasattr(o, 'tolist') else str(o))
    print(f'\n\u2713 2-Layer results saved: {save_path}')

    # Print summary
    print(f'\n  --- {ds_name} 2-Layer Summary ---')
    for key, val in results.items():
        print(f'  {key}: {val["test_acc"]:.1f}% ({val["params"]:,} params)')

    return results


# ================================================================
# PHASE 2: 4-LAYER HINTON COMPARISON
# ================================================================

def run_4layer_experiment(ds_name, seed, cfg, show_specialist_perf=False, activation='relu', learnable_theta=False):
    """
    Run ONLY the 4-layer Hinton-style comparison experiments for one dataset and seed.

    Trains, in order:
    - Classic FF 4L with Adam (same optimizer as ModularFF, LR = cfg['lr'])
    - Classic FF 4L with SGD (Hinton's optimizer, LR = cfg.get('hinton_sgd_lr', 0.03))
    - ModularFF 4L once per alpha in cfg.get('alpha_values', [0.0, 0.3, 0.5, 1.0])
      (the Phase 2 plan is [0.0, 0.5, 1.0]; set cfg['alpha_values'] to skip 0.3)

    Args:
        ds_name: dataset name; must be a key of DATASETS and ARCHITECTURES.
        seed: seed forwarded to the trainers; also used in the output filename.
        cfg: config dict. Reads 'lr', 'batch_size', 'early_stop_patience' and the
            optional 'init_method', 'use_meta_layer', 'hinton_sgd_lr', 'min_epochs',
            'alpha_values', 'results_path' (falls back to CONFIG['results_path']).
        show_specialist_perf: forwarded to train_modularff.
        activation: forwarded to both trainers and stored in every saved entry.
        learnable_theta: forwarded to both trainers and stored in every saved entry.

    Returns:
        dict: method key ('ClassicFF_4L_Adam', 'ClassicFF_4L_SGD',
        'ModularFF_4L_<arch>_a<alpha>') -> {'history', 'test_acc', 'params', ...}.
        Empty (and nothing saved) when the dataset defines no 4-layer architectures.

    Side effects:
        Writes <results_path>/<ds_name>/results_4layer_seed<seed>.json.
    """
    ds = DATASETS[ds_name]
    arch = ARCHITECTURES[ds_name]
    results = {}

    init_method = cfg.get('init_method', 'kaiming')
    use_meta_layer = cfg.get('use_meta_layer', True)
    hinton_lr = cfg.get('hinton_sgd_lr', 0.03)

    # 4-layer architectures are optional per dataset
    classic_ff_4L = arch.get('classic_ff_4L')
    modularff_4L = arch.get('modularff_4L')

    if classic_ff_4L is None or modularff_4L is None:
        print(f'  WARNING: 4-layer architectures not defined for {ds_name}, skipping')
        return results

    print(f'\n{"="*70}')
    print(f'  {ds_name} | seed={seed} | K={ds["num_classes"]} | dim={ds["input_dim"]}')
    print(f'  4-LAYER HINTON COMPARISON (Phase 2)')
    print(f'  Classic FF 4L: {classic_ff_4L}')
    print(f'  ModularFF 4L: {modularff_4L} × {ds["num_classes"]} specialists')
    print(f'  Adam LR={cfg["lr"]}, SGD LR={hinton_lr}, batch={cfg["batch_size"]}, patience={cfg["early_stop_patience"]}, min_ep={cfg.get("min_epochs", 50)}')
    print(f'{"="*70}')

    # 1a. Classic FF 4-Layer with Adam (same optimizer as ModularFF)
    print('\n--- Classic FF 4-Layer (Adam) ---')
    h, a, p = train_classic_ff(ds_name, ds, classic_ff_4L, seed, cfg,
                                optimizer='adam', activation=activation, learnable_theta=learnable_theta)
    results['ClassicFF_4L_Adam'] = {'history': h, 'test_acc': a, 'params': p}

    # 1b. Classic FF 4-Layer with SGD (Hinton's original optimizer)
    print(f'\n--- Classic FF 4-Layer (SGD, LR={hinton_lr}) ---')
    h, a, p = train_classic_ff(ds_name, ds, classic_ff_4L, seed, cfg,
                                optimizer='sgd', lr_override=hinton_lr, activation=activation, learnable_theta=learnable_theta)
    results['ClassicFF_4L_SGD'] = {'history': h, 'test_acc': a, 'params': p}

    # 2. ModularFF 4-Layer (parameter-matched, Adam), one run per alpha
    arch_str = '_'.join(map(str, modularff_4L))  # loop-invariant part of the result key
    for alpha in cfg.get('alpha_values', [0.0, 0.3, 0.5, 1.0]):
        print(f'\n--- ModularFF 4L arch={modularff_4L} (alpha={alpha}) ---')
        h, a, p, _ens, mr, sr = train_modularff(
            ds_name, ds, modularff_4L, seed, cfg,
            alpha=alpha, layer_dropout=None, meta='argmax',
            init_method=init_method, elm_mode=False, activation=activation,
            pruning_enabled=False, use_meta_layer=use_meta_layer,
            show_specialist_perf=show_specialist_perf,
            learnable_theta=learnable_theta
        )
        results[f'ModularFF_4L_{arch_str}_a{alpha}'] = {
            'history': h, 'test_acc': a, 'params': p,
            'meta_results': mr, 'specialist_results': sr,
            'architecture': modularff_4L
        }

    # Save results. Honour the cfg that was passed in; fall back to the global CONFIG
    # so callers passing a cfg without 'results_path' behave as before.
    results_root = cfg['results_path'] if 'results_path' in cfg else CONFIG['results_path']
    save_dir = os.path.join(results_root, ds_name)
    os.makedirs(save_dir, exist_ok=True)  # open(..., 'w') fails if the folder is missing
    # NOTE(review): the filename does not encode `activation`, so in an activation
    # sweep each activation overwrites the previous file for the same seed.
    save_path = os.path.join(save_dir, f'results_4layer_seed{seed}.json')

    serializable = {}
    for key, val in results.items():
        history = val['history']
        entry = {
            'test_acc': val['test_acc'], 'params': val['params'],
            'train_acc': history['train_acc'],
            'val_acc': history['val_acc'],
            'wall_time': history.get('wall_time', []),
            # load_all_results() groups entries by 'activation' (defaulting to
            # 'relu' when absent), so record the activation actually used.
            'activation': activation,
            'learnable_theta': learnable_theta,
        }
        for optional in ('meta_results', 'specialist_results', 'architecture'):
            if optional in val:
                entry[optional] = val[optional]
        serializable[key] = entry

    with open(save_path, 'w') as f:
        json.dump(serializable, f, indent=2)
    print(f'\n\u2713 4-Layer results saved: {save_path}')

    # Print comparison summary
    print(f'\n  --- {ds_name} Summary ---')
    adam_acc = results.get('ClassicFF_4L_Adam', {}).get('test_acc', 0)
    sgd_acc = results.get('ClassicFF_4L_SGD', {}).get('test_acc', 0)
    best_mod_acc = max(
        (v.get('test_acc', 0) for k, v in results.items() if k.startswith('ModularFF')),
        default=0
    )
    print(f'  ClassicFF 4L (Adam):  {adam_acc:.1f}%')
    print(f'  ClassicFF 4L (SGD):   {sgd_acc:.1f}%')
    print(f'  ModularFF 4L (best):  {best_mod_acc:.1f}%')

    return results


# ================================================================
# Report the active training-engine settings (values read from CONFIG / LOGGER).
print(
    '\u2713 Training engine ready (v5):',
    f'  - Patience: {CONFIG["early_stop_patience"]} (reduce LR every {CONFIG.get("lr_reduce_patience", 20)} epochs)',
    f'  - Min epochs: {CONFIG.get("min_epochs", 50)}',
    '  - Best-model checkpoint: ON',
    '  - 2-layer experiments: run_2layer_experiment()',
    '  - 4-layer Hinton comparison: run_4layer_experiment()',
    '  - Classic FF: Adam and SGD (Hinton config)',
    f'  - Logger: {LOGGER.path}',
    sep='\n',
)

✓ Training engine ready (v5):
  - Patience: 60 (reduce LR every 20 epochs)
  - Min epochs: 50
  - Best-model checkpoint: ON
  - 2-layer experiments: run_2layer_experiment()
  - 4-layer Hinton comparison: run_4layer_experiment()
  - Classic FF: Adam and SGD (Hinton config)
  - Logger: /content/drive/My Drive/Research/ModularFF/Logs/experiment_logs.csv


In [None]:
######### end of cell 3

In [None]:
# ================================================================
# CELL 4a: Run 2-Layer Experiments (v6 — Activation Sweep)
# ================================================================
# Runs all 2-layer experiments for all datasets × activations:
#   - ClassicFF (Adam + SGD), Embed, Additive, LocalAdapt
#   - BP Baseline
#   - ModularFF: all architectures × alpha values
#   - Activations: gelu, tanh, hardlimit, perceptron
#   - Loop order: DATASET (outer) × ACTIVATION (inner)
#     → completes all activations per dataset before moving on
# ================================================================

# Activations to sweep (full set: 'gelu', 'tanh', 'hardlimit', 'perceptron').
ACTIVATIONS = ['perceptron']

# True -> theta is learned; False -> fixed per-activation default theta.
LEARNABLE_THETA = True

# Forwarded to run_2layer_experiment as show_specialist_perf (verbose table).
SHOW_SPECIALIST_PERF = False

# (ds_name, activation) -> {seed -> results dict from run_2layer_experiment}
ALL_RESULTS = {}

print(
    '\n' + '#' * 70,
    '#  2-LAYER EXPERIMENTS (with Activation Sweep)',
    f"#  Datasets: {CONFIG['datasets_to_run']}",
    f'#  Activations: {ACTIVATIONS}',
    f"#  Alpha values: {CONFIG.get('alpha_values', [0.0, 0.5, 1.0])}",
    f"#  Adam LR: {CONFIG['lr']} | SGD LR: {CONFIG.get('hinton_sgd_lr', 0.03)} "
    f"| Batch: {CONFIG['batch_size']} | Patience: {CONFIG['early_stop_patience']}",
    '#  Loop order: DATASET (outer) × ACTIVATION (inner)',
    '#' * 70,
    sep='\n',
)

for ds_name in CONFIG['datasets_to_run']:
    if ds_name not in DATASETS:
        print(f'WARNING: {ds_name} not loaded, skipping')
        continue

    print('\n' + '='*70)
    print(f'  DATASET: {ds_name}')
    print('='*70)

    for act in ACTIVATIONS:
        print('\n' + '#'*70)
        print(f'#  DATASET: {ds_name} | ACTIVATION: {act}')
        print('#'*70)

        result_key = (ds_name, act)
        ALL_RESULTS[result_key] = {}

        for seed in CONFIG['seeds']:
            ALL_RESULTS[result_key][seed] = run_2layer_experiment(
                ds_name, seed, CONFIG,
                show_specialist_perf=SHOW_SPECIALIST_PERF,
                activation=act,
                learnable_theta=LEARNABLE_THETA
            )

    # --- Per-dataset summary: one row per (activation, seed) ---
    # The Seed column keeps rows distinguishable when CONFIG['seeds'] has several entries.
    print('\n' + '='*90)
    print(f' {ds_name} — 2-LAYER CROSS-ACTIVATION SUMMARY')
    print('='*90)
    print(f'{"Activation":<15} {"Seed":>6} {"BP":>8} {"FF best":>8} {"ModularFF":>10} {"Δ vs FF":>10}')
    print('-'*90)

    for act in ACTIVATIONS:
        per_seed = ALL_RESULTS.get((ds_name, act), {})
        for seed in CONFIG['seeds']:
            if seed not in per_seed:
                continue
            res = per_seed[seed]

            bp_acc = res.get('BP', {}).get('test_acc', 0)
            # Best over these ClassicFF variants (LocalAdapt only at alpha=0.5)
            ff_best = max(
                res.get(k, {}).get('test_acc', 0)
                for k in ('ClassicFF_Adam', 'ClassicFF_SGD', 'ClassicFF_Embed',
                          'ClassicFF_Additive', 'ClassicFF_LocalAdapt_a0.5')
            )
            best_mod = max(
                (v.get('test_acc', 0) for k, v in res.items() if k.startswith('ModularFF')),
                default=0
            )
            delta = best_mod - ff_best
            delta_str = f'+{delta:.1f}%' if delta > 0 else f'{delta:.1f}%'
            bp_str = f'{bp_acc:.1f}%' if bp_acc > 0 else '  N/A'
            print(f'{act:<15} {seed!s:>6} {bp_str:>8} {ff_best:>7.1f}% {best_mod:>9.1f}% {delta_str:>10}')

    print('='*90)


# ================================================================
# FINAL SUMMARY (all datasets × all activations)
# ================================================================
def _test_acc(res, method):
    """Return test_acc of `method` within one run's results dict, or 0 if absent."""
    return res.get(method, {}).get('test_acc', 0)


print('\n' + '='*70)
print(' 2-LAYER EXPERIMENTS COMPLETE (all activations)')
print('='*70)

for ds_name in CONFIG['datasets_to_run']:
    print(f'\n{"="*70}')
    print(f'  DATASET: {ds_name}')
    print(f'{"="*70}')

    for act in ACTIVATIONS:
        result_key = (ds_name, act)
        if result_key not in ALL_RESULTS:
            continue
        print(f'\n  {act.upper()}:')

        for seed in CONFIG['seeds']:
            if seed not in ALL_RESULTS[result_key]:
                continue
            res = ALL_RESULTS[result_key][seed]

            bp_acc = _test_acc(res, 'BP')
            ff_adam = _test_acc(res, 'ClassicFF_Adam')
            ff_sgd = _test_acc(res, 'ClassicFF_SGD')
            ff_embed = _test_acc(res, 'ClassicFF_Embed')
            ff_add = _test_acc(res, 'ClassicFF_Additive')
            ff_la = _test_acc(res, 'ClassicFF_LocalAdapt_a0.5')

            # Highest-scoring ModularFF entry; the first one wins ties. (0, '') if none.
            best_mod_acc, best_mod_key = 0, ''
            for key, val in res.items():
                if key.startswith('ModularFF') and val.get('test_acc', 0) > best_mod_acc:
                    best_mod_acc = val['test_acc']
                    best_mod_key = key

            print(f'    seed={seed}:')
            bp_str = f'{bp_acc:.1f}%' if bp_acc > 0 else 'N/A'
            print(f'      BP:                  {bp_str}')
            print(f'      ClassicFF (Adam):    {ff_adam:.1f}%')
            print(f'      ClassicFF (SGD):     {ff_sgd:.1f}%')
            print(f'      ClassicFF (Embed):   {ff_embed:.1f}%')
            print(f'      ClassicFF (Additive):{ff_add:.1f}%')
            print(f'      ClassicFF (LA a=0.5):{ff_la:.1f}%')
            print(f'      ModularFF (best):    {best_mod_acc:.1f}% [{best_mod_key}]')

# ================================================================
# GRAND COMPARISON TABLE
# ================================================================
for ds_name in CONFIG['datasets_to_run']:
    # The run loop skips datasets missing from DATASETS; without this guard
    # DATASETS[ds_name] raises KeyError after all training has finished.
    if ds_name not in DATASETS:
        continue
    K = DATASETS[ds_name]['num_classes']
    print(f'\n{"="*90}')
    print(f' 2-LAYER RESULTS — {ds_name} (K={K}) — ALL ACTIVATIONS')
    print(f'{"="*90}')
    print(f'{"Activation":<15} {"Seed":>6} {"BP":>8} {"FF best":>8} {"ModularFF":>10} {"Δ vs FF":>10}')
    print('-'*90)

    for act in ACTIVATIONS:
        per_seed = ALL_RESULTS.get((ds_name, act), {})
        for seed in CONFIG['seeds']:
            if seed not in per_seed:
                continue
            res = per_seed[seed]

            bp_acc = res.get('BP', {}).get('test_acc', 0)
            # Best over these ClassicFF variants (LocalAdapt only at alpha=0.5)
            ff_best = max(
                res.get(k, {}).get('test_acc', 0)
                for k in ('ClassicFF_Adam', 'ClassicFF_SGD', 'ClassicFF_Embed',
                          'ClassicFF_Additive', 'ClassicFF_LocalAdapt_a0.5')
            )

            best_mod = max(
                (v.get('test_acc', 0) for k, v in res.items() if k.startswith('ModularFF')),
                default=0
            )

            delta = best_mod - ff_best
            delta_str = f'+{delta:.1f}%' if delta > 0 else f'{delta:.1f}%'
            bp_str = f'{bp_acc:.1f}%' if bp_acc > 0 else '  N/A'
            print(f'{act:<15} {seed!s:>6} {bp_str:>8} {ff_best:>7.1f}% {best_mod:>9.1f}% {delta_str:>10}')

    print('='*90)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m

--- ModularFF arch=[100, 100] (alpha=0.0) ---
  [ModularFF a=0.0 ld=uniform] ep   1  loss=0.968  train=86.4%  val=86.2%  (p=0)
  [ModularFF a=0.0 ld=uniform] ep  10  loss=0.719  train=91.2%  val=90.5%  (p=0)
  [ModularFF a=0.0 ld=uniform] ep  20  loss=0.708  train=91.9%  val=91.2%  (p=0)
  [ModularFF a=0.0 ld=uniform] ep  30  loss=0.702  train=92.2%  val=91.2%  (p=0)
  [ModularFF a=0.0 ld=uniform] ep  40  loss=0.701  train=92.6%  val=91.3%  (p=3)
  [ModularFF a=0.0 ld=uniform] ep  50  loss=0.696  train=92.8%  val=91.6%  (p=2)
  [ModularFF a=0.0 ld=uniform] ep  60  loss=0.694  train=92.9%  val=91.6%  (p=2)
  [ModularFF a=0.0 ld=uniform] ep  70  loss=0.693  train=93.1%  val=91.9%  (p=0)
  [ModularFF a=0.0 ld=uniform] ep  80  loss=0.691  train=93.2%  val=91.6%  (p=10)
  [ModularFF] LR reduced (×0.5) -> 5.0e-03 at epoch 90 (reduction #1)
  [ModularFF a=0.0 ld=uniform] ep  90  loss=0.688  train=93.3%  val=91.8%  (p=20)
  [Mod

In [4]:
# ================================================================
# CELL 4: Run Experiments (v6 — 4-Layer + Activation Sweep)
# ================================================================
# Phase 2: 4-Layer Hinton Comparison with Activation Sweep
# - Datasets: CONFIG['datasets_to_run'] (those with 4-layer configs in ARCHITECTURES)
# - Alpha: CONFIG['alpha_values'] (printed below)
# - Classic FF: both Adam and SGD (Hinton config)
# - Activations: gelu, tanh, hardlimit, perceptron
# - Loop order: DATASET (outer) × ACTIVATION (inner)
# ================================================================

# Activations to sweep (full set: 'gelu', 'tanh', 'hardlimit', 'perceptron').
ACTIVATIONS = ['perceptron']
LEARNABLE_THETA = True  # Set False for fixed theta (default values per activation)

# ================================================================
# EXPERIMENT FLAGS
# ================================================================

RUN_4LAYER = True              # Run 4-layer Hinton comparison
SHOW_SPECIALIST_PERF = False   # Print per-specialist table (verbose)

# ================================================================
# RUN EXPERIMENTS
# ================================================================

# NOTE: this rebinds ALL_RESULTS, discarding any in-memory results from cell 4a
# (run_2layer_experiment also saved them as results_2layer_seed*.json).
ALL_RESULTS = {}  # key: (ds_name, activation) -> {seed -> results}

print('\n' + '#'*70)
print('#  PHASE 2: 4-LAYER HINTON COMPARISON (with Activation Sweep)')
print('#  Datasets:', CONFIG['datasets_to_run'])
print('#  Activations:', ACTIVATIONS)
# Print the alphas run_4layer_experiment will actually use (same default as there)
# instead of a hard-coded list that can drift from CONFIG.
print('#  Alpha values:', CONFIG.get('alpha_values', [0.0, 0.3, 0.5, 1.0]))
print('#  Adam LR:', CONFIG['lr'], '| SGD LR:', CONFIG.get('hinton_sgd_lr', 0.03),
      '| Batch:', CONFIG['batch_size'], '| Patience:', CONFIG['early_stop_patience'])
print('#  Loop order: DATASET (outer) × ACTIVATION (inner)')
print('#'*70)

# RUN_4LAYER gates the 4-layer run; before this it was defined but never read.
if not RUN_4LAYER:
    print('RUN_4LAYER=False: skipping 4-layer experiments')

for ds_name in (CONFIG['datasets_to_run'] if RUN_4LAYER else []):
    if ds_name not in DATASETS:
        print(f'WARNING: {ds_name} not loaded, skipping')
        continue

    print('\n' + '='*70)
    print(f'  DATASET: {ds_name}')
    print('='*70)

    for act in ACTIVATIONS:
        print('\n' + '#'*70)
        print(f'#  DATASET: {ds_name} | ACTIVATION: {act}')
        print('#'*70)

        result_key = (ds_name, act)
        ALL_RESULTS[result_key] = {}

        for seed in CONFIG['seeds']:
            ALL_RESULTS[result_key][seed] = run_4layer_experiment(
                ds_name, seed, CONFIG,
                show_specialist_perf=SHOW_SPECIALIST_PERF,
                activation=act,
                learnable_theta=LEARNABLE_THETA
            )

    # --- Per-dataset summary: one row per (activation, seed) ---
    print('\n' + '='*90)
    print(f' {ds_name} — 4-LAYER CROSS-ACTIVATION SUMMARY')
    print('='*90)
    print(f'{"Activation":<15} {"Seed":>6} {"FF(Adam)":>10} {"FF(SGD)":>10} {"ModularFF":>10} {"Δ vs best FF":>14}')
    print('-'*90)

    for act in ACTIVATIONS:
        per_seed = ALL_RESULTS.get((ds_name, act), {})
        for seed in CONFIG['seeds']:
            if seed not in per_seed:
                continue
            res = per_seed[seed]

            ff_adam = res.get('ClassicFF_4L_Adam', {}).get('test_acc', 0)
            ff_sgd = res.get('ClassicFF_4L_SGD', {}).get('test_acc', 0)
            ff_best = max(ff_adam, ff_sgd)
            best_acc = max(
                (v.get('test_acc', 0) for k, v in res.items() if k.startswith('ModularFF_4L')),
                default=0
            )
            delta = best_acc - ff_best
            delta_str = f'+{delta:.1f}%' if delta > 0 else f'{delta:.1f}%'
            print(f'{act:<15} {seed!s:>6} {ff_adam:>9.1f}% {ff_sgd:>9.1f}% {best_acc:>9.1f}% {delta_str:>14}')

    print('='*90)


# ================================================================
# FINAL SUMMARY
# ================================================================
# Per dataset/activation/seed: both ClassicFF optimizers, the best ModularFF_4L
# entry, and which side won.
print('\n' + '=' * 70)
print(' PHASE 2: 4-LAYER HINTON COMPARISON COMPLETE (all activations)')
print('=' * 70)

for ds_name in CONFIG['datasets_to_run']:
    print('\n' + '=' * 70)
    print(f'  DATASET: {ds_name}')
    print('=' * 70)

    for act in ACTIVATIONS:
        per_seed = ALL_RESULTS.get((ds_name, act))
        if per_seed is None:
            continue
        print(f'\n  {act.upper()}:')

        for seed in (s for s in CONFIG['seeds'] if s in per_seed):
            res = per_seed[seed]

            adam_entry = res.get('ClassicFF_4L_Adam', {})
            sgd_entry = res.get('ClassicFF_4L_SGD', {})
            ff_adam, ff_adam_params = adam_entry.get('test_acc', 0), adam_entry.get('params', 0)
            ff_sgd, ff_sgd_params = sgd_entry.get('test_acc', 0), sgd_entry.get('params', 0)
            ff_best = max(ff_adam, ff_sgd)

            # Highest-scoring ModularFF_4L entry; first one wins ties.
            best_acc, best_key, best_params = 0, '', 0
            for key, val in res.items():
                if not key.startswith('ModularFF_4L') or 'test_acc' not in val:
                    continue
                if val['test_acc'] > best_acc:
                    best_acc, best_key, best_params = val['test_acc'], key, val.get('params', 0)

            print(f'    seed={seed}:')
            print(f'      ClassicFF 4L (Adam): {ff_adam:.1f}% ({ff_adam_params:,} params)')
            print(f'      ClassicFF 4L (SGD):  {ff_sgd:.1f}% ({ff_sgd_params:,} params)')
            print(f'      ModularFF 4L (best): {best_acc:.1f}% ({best_params:,} params) [{best_key}]')

            if best_acc > ff_best:
                verdict = f'      --> ModularFF wins by +{best_acc - ff_best:.1f}% (vs best ClassicFF)'
            elif ff_best > best_acc:
                winner = 'Adam' if ff_adam >= ff_sgd else 'SGD'
                verdict = f'      --> ClassicFF ({winner}) wins by +{ff_best - best_acc:.1f}%'
            else:
                verdict = '      --> TIE'
            print(verdict)

print(f'\n\u2713 Results: {CONFIG["results_path"]}')
print(f'\u2713 Logs: {LOGGER.path}')

# ================================================================
# GRAND COMPARISON TABLE
# ================================================================
for ds_name in CONFIG['datasets_to_run']:
    # The run loop skips datasets missing from DATASETS; without this guard
    # DATASETS[ds_name] raises KeyError after all training has finished.
    if ds_name not in DATASETS:
        continue
    K = DATASETS[ds_name]['num_classes']
    print(f'\n{"="*90}')
    print(f' 4-LAYER RESULTS — {ds_name} (K={K}) — ALL ACTIVATIONS')
    print(f'{"="*90}')
    print(f'{"Activation":<15} {"Seed":>6} {"FF(Adam)":>10} {"FF(SGD)":>10} {"ModularFF":>10} {"Δ vs best FF":>14}')
    print('-'*90)

    for act in ACTIVATIONS:
        per_seed = ALL_RESULTS.get((ds_name, act), {})
        for seed in CONFIG['seeds']:
            if seed not in per_seed:
                continue
            res = per_seed[seed]

            ff_adam = res.get('ClassicFF_4L_Adam', {}).get('test_acc', 0)
            ff_sgd = res.get('ClassicFF_4L_SGD', {}).get('test_acc', 0)
            ff_best = max(ff_adam, ff_sgd)

            best_acc = max(
                (v.get('test_acc', 0) for k, v in res.items() if k.startswith('ModularFF_4L')),
                default=0
            )

            delta = best_acc - ff_best
            delta_str = f'+{delta:.1f}%' if delta > 0 else f'{delta:.1f}%'

            print(f'{act:<15} {seed!s:>6} {ff_adam:>9.1f}% {ff_sgd:>9.1f}% {best_acc:>9.1f}% {delta_str:>14}')

    print('='*90)



######################################################################
#  PHASE 2: 4-LAYER HINTON COMPARISON (with Activation Sweep)
#  Datasets: ['FashionMNIST']
#  Activations: ['perceptron']
#  Alpha values: [0.0, 0.5, 1.0]
#  Adam LR: 0.01 | SGD LR: 0.03 | Batch: 256 | Patience: 60
#  Loop order: DATASET (outer) × ACTIVATION (inner)
######################################################################

  DATASET: FashionMNIST

######################################################################
#  DATASET: FashionMNIST | ACTIVATION: perceptron
######################################################################

  FashionMNIST | seed=42 | K=10 | dim=784
  4-LAYER HINTON COMPARISON (Phase 2)
  Classic FF 4L: [2000, 2000, 2000, 2000]
  ModularFF 4L: [550, 550, 550, 550] × 10 specialists
  Adam LR=0.01, SGD LR=0.03, batch=256, patience=60, min_ep=50

--- Classic FF 4-Layer (Adam) ---
  [ClassicFF_ADAM] ep   1  loss=4.000  train=10.2%  val=10.4%  (p=0)
  [ClassicFF_ADAM] ep  10  

In [None]:
########## end of cell 4

In [None]:
# ================================================================
# CELL 5: Analysis — Activation Sweep Results (v5)
# ================================================================
# Handles activation dimension + 2-layer / 4-layer split
# File naming: results_2layer_seed42.json, results_4layer_seed42.json
# Each JSON contains keys like "ClassicFF_Adam", "ModularFF_4L_50_50_50_50_a0.0"
# with an "activation" field inside each entry
# ================================================================

import pandas as pd
import re

# ================================================================
# LOAD ALL RESULTS
# ================================================================

def load_all_results():
    """Load every saved result JSON under CONFIG['results_path'].

    Scans <results_path>/<dataset>/ for each dataset in CONFIG['datasets_to_run'].
    Filenames 'results_<N>layer_seed<S>.json' give depth '<N>layer' and seed S;
    legacy 'results_seed<S>.json' files are treated as depth '2layer'. Other
    filenames are ignored, as are non-dict top-level values inside a file.

    Returns:
        dict keyed by (dataset, depth, activation, seed) -> {method_key: entry}.
        An entry's activation comes from its 'activation' field, or 'relu' if absent.
    """
    depth_pattern = re.compile(r'results_(\d+layer)_seed(\d+)\.json')
    legacy_pattern = re.compile(r'results_seed(\d+)\.json')

    grouped = {}
    for ds in CONFIG['datasets_to_run']:
        ds_dir = os.path.join(CONFIG['results_path'], ds)
        if not os.path.exists(ds_dir):
            continue
        for fname in os.listdir(ds_dir):
            if not fname.endswith('.json'):
                continue

            hit = depth_pattern.match(fname)
            if hit is not None:
                depth, seed = hit.group(1), int(hit.group(2))
            else:
                legacy = legacy_pattern.match(fname)
                if legacy is None:
                    continue
                depth, seed = '2layer', int(legacy.group(1))

            with open(os.path.join(ds_dir, fname), 'r') as fh:
                payload = json.load(fh)

            for method_key, entry in payload.items():
                if isinstance(entry, dict):
                    act = entry.get('activation', 'relu')  # older files lack the field
                    grouped.setdefault((ds, depth, act, seed), {})[method_key] = entry

    return grouped


print('Loading results...')
ALL_RES = load_all_results()
print(f'✓ Loaded {len(ALL_RES)} (dataset, depth, activation, seed) combinations\n')

# Index what exists per dataset: the depths, activations and seeds that were saved.
available = {}
for ds, depth, act, seed in ALL_RES:
    if ds not in available:
        available[ds] = {'depths': set(), 'activations': set(), 'seeds': set()}
    bucket = available[ds]
    bucket['depths'].add(depth)
    bucket['activations'].add(act)
    bucket['seeds'].add(seed)

for ds in sorted(available):
    info = available[ds]
    print(f'  {ds}: depths={sorted(info["depths"])}, '
          f'activations={sorted(info["activations"])}, '
          f'seeds={sorted(info["seeds"])}')


# ================================================================
# HELPER FUNCTIONS
# ================================================================

def get_acc(all_res, ds, depth, act, seed, method_key):
    """Return 'test_acc' of one method within a (ds, depth, act, seed) group.

    An exact method-key match wins; otherwise the first key (in dict order) that
    contains `method_key` as a substring is used. Returns None when the group
    is missing or no key matches.
    """
    group = all_res.get((ds, depth, act, seed))
    if group is None:
        return None
    if method_key in group:
        return group[method_key].get('test_acc')
    partial = next((entry for name, entry in group.items() if method_key in name), None)
    return None if partial is None else partial.get('test_acc')


def get_meta_accs(all_res, ds, depth, act, seed, method_key):
    """Return the 'meta_results' dict of one method, or {} if not found.

    Matching mirrors get_acc: an exact method key wins; otherwise the first key
    (in dict order) containing `method_key` as a substring is used.
    """
    group = all_res.get((ds, depth, act, seed))
    if not group:
        return {}
    # Exact match first, so an earlier, longer key that merely contains
    # method_key cannot shadow the entry actually requested.
    if method_key in group:
        return group[method_key].get('meta_results', {})
    for k, v in group.items():
        if method_key in k:
            return v.get('meta_results', {})
    return {}


def get_best_acc(all_res, ds, depth, act, seed, method_prefix):
    """Return the best accuracy among methods whose key starts with `method_prefix`.

    Both each entry's 'test_acc' and every value in its 'meta_results' are
    considered. Returns None when the group is missing or nothing exceeds 0.
    """
    group = all_res.get((ds, depth, act, seed))
    if group is None:
        return None
    candidates = [0]
    for name, entry in group.items():
        if name.startswith(method_prefix):
            candidates.append(entry.get('test_acc', 0))
            candidates.extend(entry.get('meta_results', {}).values())
    top = max(candidates)
    return top if top > 0 else None


def get_goodness_info(all_res, ds, depth, act, seed, method_key):
    """Return the 'goodness_diagnostics' value of one method, or None if not found.

    Matching mirrors get_acc: an exact method key wins; otherwise the first key
    (in dict order) containing `method_key` as a substring is used.
    """
    group = all_res.get((ds, depth, act, seed))
    if not group:
        return None
    # Exact match first so a longer key containing method_key cannot shadow it.
    if method_key in group:
        return group[method_key].get('goodness_diagnostics')
    for k, v in group.items():
        if method_key in k:
            return v.get('goodness_diagnostics')
    return None


def fmt(val):
    """Format an accuracy as 'xx.x%', or '---' when it is None."""
    return '---' if val is None else f'{val:.1f}%'


def fmt_delta(val, ref):
    """Format val - ref as a signed percentage ('+x.x%' when positive), or '---' if either is None."""
    if val is None or ref is None:
        return '---'
    diff = val - ref
    return f'+{diff:.1f}%' if diff > 0 else f'{diff:.1f}%'


# ================================================================
# TABLE A: ACTIVATION COMPARISON — 2-Layer and 4-Layer
# For each activation: BP, ClassicFF (Adam/SGD), best ModularFF, delta vs best ClassicFF
# ================================================================
# ClassicFF result keys differ by depth: 2-layer runs save 'ClassicFF_Adam' /
# 'ClassicFF_SGD', 4-layer runs save 'ClassicFF_4L_Adam' / 'ClassicFF_4L_SGD'.
# 'ClassicFF_Adam' is not a substring of 'ClassicFF_4L_Adam', so a single key
# left the 4-layer FF(Adam)/FF(SGD) columns permanently '---'.
CLASSIC_FF_KEYS_BY_DEPTH = {
    '2layer': ('ClassicFF_Adam', 'ClassicFF_SGD'),
    '4layer': ('ClassicFF_4L_Adam', 'ClassicFF_4L_SGD'),
}

for depth, depth_label in [('2layer', '2-LAYER'), ('4layer', '4-LAYER')]:
    print(f'\n{"="*90}')
    print(f' TABLE: {depth_label} ACTIVATION COMPARISON')
    print(f'{"="*90}')

    # Activations present at this depth for any dataset
    all_acts = sorted({a for (_d, dp, a, _s) in ALL_RES if dp == depth})
    if not all_acts:
        print('  No results found for this depth.')
        continue

    adam_key, sgd_key = CLASSIC_FF_KEYS_BY_DEPTH[depth]

    for ds in CONFIG['datasets_to_run']:
        if ds not in available:
            continue
        seed = sorted(available[ds]['seeds'])[0]  # Use first seed for display

        ds_label = {'XOR': 'XOR (K=4)'}.get(ds, ds)

        print(f'\n  {ds_label}:')
        header = f'  {"Activation":12s}  {"BP":>8s}  {"FF(Adam)":>10s}  {"FF(SGD)":>10s}  {"ModularFF":>10s}  {"Δ vs FF":>10s}'
        print(header)
        print('  ' + '-' * (len(header) - 2))

        for act in all_acts:
            bp = get_acc(ALL_RES, ds, depth, act, seed, 'BP')
            ff_adam = get_acc(ALL_RES, ds, depth, act, seed, adam_key)
            ff_sgd = get_acc(ALL_RES, ds, depth, act, seed, sgd_key)

            # Best ClassicFF over every saved variant (2-layer also has Embed,
            # Additive, LocalAdapt); None when nothing positive was recorded.
            group = ALL_RES.get((ds, depth, act, seed), {})
            ff_best = max([0] + [v.get('test_acc', 0) for k, v in group.items()
                                 if k.startswith('ClassicFF')])
            ff_best = ff_best if ff_best > 0 else None

            # Best ModularFF (including all meta-layers)
            mod_best = get_best_acc(ALL_RES, ds, depth, act, seed, 'ModularFF')
            delta = fmt_delta(mod_best, ff_best)  # '---' when either side is None

            row = f'  {act:12s}  {fmt(bp):>8s}  {fmt(ff_adam):>10s}  {fmt(ff_sgd):>10s}  {fmt(mod_best):>10s}  {delta:>10s}'
            print(row)


# ================================================================
# TABLE B: GOODNESS DIAGNOSTICS — Learnable Theta Values
# ================================================================
print(f'\n{"="*90}')
print(f' TABLE: LEARNABLE THETA VALUES (4-Layer, per activation)')
print(f'{"="*90}')


def _fmt_diag(x):
    """Format a diagnostic number to 3 decimals, or '?' when missing / non-numeric."""
    # The original f'{"?":.3f}' fallback raised ValueError for missing fields.
    return f'{x:.3f}' if isinstance(x, (int, float)) else '?'


for ds in CONFIG['datasets_to_run']:
    if ds not in available:
        continue
    seed = sorted(available[ds]['seeds'])[0]

    # set(): ALL_RES has one key per seed, so a plain generator repeats activations.
    all_acts = sorted({a for (d, dp, a, s) in ALL_RES if d == ds and dp == '4layer'})
    if not all_acts:
        continue

    print(f'\n  {ds} — ClassicFF 4L (Adam):')

    for act in all_acts:
        entries = ALL_RES.get((ds, '4layer', act, seed))
        if entries is None:
            continue
        for k, v in entries.items():
            # 4-layer runs save 'ClassicFF_4L_Adam'; 'ClassicFF_Adam' kept for older files.
            if 'ClassicFF_4L_Adam' in k or 'ClassicFF_Adam' in k:
                gd = v.get('goodness_diagnostics', {})
                if gd:
                    if isinstance(gd, list):
                        theta_str = ' | '.join(
                            f'L{i}: th={_fmt_diag(l.get("theta"))}, sep={_fmt_diag(l.get("separation"))}'
                            for i, l in enumerate(gd)
                        )
                    else:
                        theta_str = str(gd)
                    print(f'    {act:12s}: {theta_str}')
                break


# ================================================================
# TABLE C: DETAILED METHOD COMPARISON (per activation)
# ================================================================
for depth, depth_label in [('2layer', '2-LAYER'), ('4layer', '4-LAYER')]:
    print(f'\n{"="*90}')
    print(f' TABLE: {depth_label} DETAILED RESULTS BY ACTIVATION')
    print(f'{"="*90}')

    for ds in CONFIG['datasets_to_run']:
        if ds not in available:
            continue
        seed = sorted(available[ds]['seeds'])[0]  # Use first seed for display

        # set(): ALL_RES has one key per seed, so a plain generator repeats
        # activations and printed the same block once per seed.
        all_acts = sorted({a for (d, dp, a, s) in ALL_RES if d == ds and dp == depth})
        if not all_acts:
            continue

        for act in all_acts:
            entries = ALL_RES.get((ds, depth, act, seed))
            if entries is None:
                continue

            print(f'\n  {ds} | {depth} | {act}:')
            for method_key in sorted(entries):
                val = entries[method_key]
                acc = val.get('test_acc', 0)
                # The experiment runners save the count under 'params';
                # 'total_params' is checked first for compatibility.
                params = val.get('total_params', val.get('params', '?'))
                mr = val.get('meta_results', {})
                meta_str = ''
                if mr:
                    best_meta = max(mr, key=mr.get)
                    meta_str = f'  (best meta: {best_meta}={mr[best_meta]:.1f}%)'
                print(f'    {method_key:45s}  {acc:6.1f}%  [{params} params]{meta_str}')


# ================================================================
# SUMMARY CSV
# ================================================================
print(f'\n{"-"*40}')

summary_rows = []
for (ds, depth, act, seed), entries in ALL_RES.items():
    for method_key, val in entries.items():
        row = {
            'dataset': ds, 'depth': depth, 'activation': act, 'seed': seed,
            'method': method_key, 'test_acc': val.get('test_acc'),
            # The experiment runners save the count under 'params'; without the
            # fallback this column is empty for their results.
            'total_params': val.get('total_params', val.get('params')),
        }
        row.update({f'meta_{m}': v for m, v in val.get('meta_results', {}).items()})
        summary_rows.append(row)

if summary_rows:
    df = pd.DataFrame(summary_rows)
    csv_path = os.path.join(CONFIG['results_path'], 'activation_sweep_summary.csv')
    df.to_csv(csv_path, index=False)
    print(f'✓ Summary saved: {csv_path}')
    print(f'  {len(df)} entries across {df["activation"].nunique()} activations, '
          f'{df["depth"].nunique()} depths, {df["dataset"].nunique()} datasets')
else:
    print('  No results loaded; summary CSV not written.')

print('\n✓ Analysis complete.')


In [None]:
######### end of cell 5

In [None]:
# ================================================================
# CELL 6: Visualization (v4)
# ================================================================

# Display labels for figures. FashionMNIST was missing even though it is one of
# the datasets the experiment cells run (K=10 per the run log).
DS_LABELS = {'XOR': '2D XOR (K=4)', 'MNIST': 'MNIST (K=10)',
             'FashionMNIST': 'FashionMNIST (K=10)',
             'Pendigits': 'Pendigits (K=10)', 'LetterRecog': 'Letters (K=26)'}


def load_results_for_viz():
    """Load the per-seed result JSON files for every dataset being run.

    For each dataset in ``CONFIG['datasets_to_run']`` this scans
    ``<results_path>/<dataset>/`` for ``.json`` files. The seed is parsed from
    names of the form ``results_seed<SEED>[_suffix].json``. Files whose seed
    part is not an integer (e.g. a summary JSON) are skipped rather than
    aborting the whole load.

    Returns:
        dict: ``{dataset: {seed (int): parsed JSON}}``. Every requested dataset
        gets an entry; it is an empty dict when the directory is missing or
        holds no parseable per-seed files.
    """
    all_res = {}
    for ds in CONFIG['datasets_to_run']:
        all_res[ds] = {}
        result_dir = os.path.join(CONFIG['results_path'], ds)
        if not os.path.isdir(result_dir):
            continue
        # Sorted so that, if several files map to the same seed, the one kept
        # is deterministic (the last in name order) rather than os-dependent.
        for fname in sorted(os.listdir(result_dir)):
            if not fname.endswith('.json'):
                continue
            seed_token = fname.replace('results_seed', '').replace('.json', '').split('_')[0]
            try:
                seed = int(seed_token)
            except ValueError:
                # Not a per-seed results file; ignore it.
                continue
            with open(os.path.join(result_dir, fname), 'r') as f:
                all_res[ds][seed] = json.load(f)
    return all_res


# ================================================================
# FIGURE 1: XOR Analysis
# ================================================================

def fig_xor_analysis(epochs=60, alpha=0.3, grid_res=100):
    """Train ModularFF and Classic FF on the XOR task and plot their behaviour.

    Draws a 2x3 figure. The left 2x2 block shows each specialist's goodness
    surface over the square [-1.2, 1.2]^2. The right column shows the
    ModularFF (argmax) and Classic FF decision regions with the test points
    overlaid. The figure is saved to ``<figures_path>/XOR/xor_analysis.png``.
    Does nothing (prints a note) if XOR is not in ``DATASETS``.

    Args:
        epochs: Training epochs for each model (default 60, the former constant).
        alpha: ``alpha`` passed to ``ModularFFEnsemble.train_epoch`` (default 0.3).
        grid_res: Points per axis of the evaluation grid (default 100).
    """
    if 'XOR' not in DATASETS:
        print('  XOR not loaded, skipping')
        return

    ds = DATASETS['XOR']
    arch = ARCHITECTURES['XOR']
    set_seed(42)
    dev = CONFIG['device']
    # The 2x2 specialist layout below is built for the 4-class XOR task.
    n_cls = 4

    ens = ModularFFEnsemble(ds['input_dim'], arch['modularff'], ds['num_classes'],
                            CONFIG['lr'], CONFIG['theta_neuron'], dev,
                            use_meta_layer=False)
    for _ in range(epochs):
        ens.train_epoch(ds['X_train'], ds['y_train'], alpha=alpha)

    cff = ClassicFF(ds['input_dim'], arch['classic_ff'], ds['num_classes'], CONFIG['lr'], dev)
    loader = make_loader(ds['X_train'], ds['y_train'], CONFIG['batch_size'])
    for _ in range(epochs):
        cff.train_epoch(loader)

    xr = np.linspace(-1.2, 1.2, grid_res)
    xx, yy = np.meshgrid(xr, xr)
    grid = np.column_stack([xx.ravel(), yy.ravel()]).astype(np.float32)

    G = ens._all_goodness(grid).cpu().numpy()
    modularff_pred = ens.predict(grid, 'argmax')
    ff_pred = cff.predict(grid)
    modularff_acc = ens.evaluate(ds['X_test'], ds['y_test'], 'argmax')
    ff_acc = cff.evaluate(ds['X_test'], ds['y_test'])

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    cmap = plt.cm.Set1
    # Paint class k's region with exactly cmap(k), the colour used for class
    # k's scatter points. With `levels=4` and the full 9-colour Set1 map the
    # bands were spread across the whole colormap, so regions and points did
    # not share colours. Half-integer levels put one band on each class id.
    region_cmap = mcolors.ListedColormap([cmap(k) for k in range(n_cls)])
    region_levels = np.arange(n_cls + 1) - 0.5

    def _decision_panel(ax, pred, title):
        """Draw predicted class regions on `ax` with the test points on top."""
        ax.contourf(xx, yy, pred.reshape(grid_res, grid_res), levels=region_levels,
                    cmap=region_cmap, alpha=0.4)
        for k in range(n_cls):
            m = ds['y_test'] == k
            ax.scatter(ds['X_test'][m, 0], ds['X_test'][m, 1], c=[cmap(k)], s=10)
        ax.set_title(title)

    for k in range(n_cls):
        ax = axes[k // 2, k % 2]
        im = ax.contourf(xx, yy, G[:, k].reshape(grid_res, grid_res), levels=20, cmap='hot')
        ax.axhline(0, color='white', ls='--', alpha=0.5)
        ax.axvline(0, color='white', ls='--', alpha=0.5)
        ax.set_title(f'Specialist {k}')
        plt.colorbar(im, ax=ax, shrink=0.8)

    _decision_panel(axes[0, 2], modularff_pred, f'ModularFF ({modularff_acc:.1f}%)')
    _decision_panel(axes[1, 2], ff_pred, f'Classic FF ({ff_acc:.1f}%)')

    plt.suptitle('XOR Analysis', fontsize=13, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(CONFIG['figures_path'], 'XOR', 'xor_analysis.png')
    # The XOR sub-folder is not created in this cell; make sure it exists.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.savefig(path, dpi=150, bbox_inches='tight')
    print(f'\u2713 Saved: {path}')
    plt.show()


# ================================================================
# FIGURE 2: Convergence
# ================================================================

def fig_convergence(seed=42):
    """Plot per-epoch validation accuracy of each method for a single seed.

    One subplot per dataset in ``CONFIG['datasets_to_run']`` that has results
    for ``seed``. A curve is drawn for each of BP, Classic FF and ModularFF
    (α = 0, 0.3, 1.0) whose stored result contains a ``'val_acc'`` list. The
    figure is saved to ``<figures_path>/convergence/convergence.png``.

    Args:
        seed: Which seed's runs to plot (default 42, the former hard-coded seed).
    """
    all_res = load_results_for_viz()
    ds_list = [ds for ds in CONFIG['datasets_to_run'] if ds in all_res and seed in all_res[ds]]

    if not ds_list:
        print('  No results found')
        return

    n = len(ds_list)
    fig, axes = plt.subplots(1, n, figsize=(6*n, 5))
    if n == 1:
        axes = [axes]

    # (result key, legend label, colour, line style)
    curve_specs = [
        ('BP', 'BP', 'black', '-'),
        ('ClassicFF', 'Classic FF', 'blue', '--'),
        ('ModularFF_a0.0', 'ModularFF α=0', 'green', ':'),
        ('ModularFF_a0.3', 'ModularFF α=0.3', 'red', '-'),
        ('ModularFF_a1.0', 'ModularFF α=1.0', 'purple', '-.'),
    ]

    for ax, ds in zip(axes, ds_list):
        res = all_res[ds][seed]

        for key, label, color, ls in curve_specs:
            if key in res and 'val_acc' in res[key]:
                v = res[key]['val_acc']
                ax.plot(range(1, len(v)+1), v, label=label, color=color, ls=ls, lw=1.5)

        ax.set_xlabel('Epoch')
        ax.set_ylabel('Val Accuracy (%)')
        ax.set_title(DS_LABELS.get(ds, ds))
        # Only add a legend when at least one curve was drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.2)

    plt.suptitle('Convergence', fontsize=13, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(CONFIG['figures_path'], 'convergence', 'convergence.png')
    # The convergence sub-folder is not created in this cell; make sure it exists.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.savefig(path, dpi=150, bbox_inches='tight')
    print(f'\u2713 Saved: {path}')
    plt.show()


# ================================================================
# FIGURE 3: Alpha Effect
# ================================================================

def fig_alpha_effect():
    """Bar chart of mean ModularFF test accuracy for each α in CONFIG['alpha_values'].

    One subplot per dataset in ``CONFIG['datasets_to_run']`` that has results.
    Each bar is the mean ``test_acc`` of ``ModularFF_a{α}`` over the seeds in
    ``CONFIG['seeds']`` that actually contain that run. An α with no runs at
    all is drawn as 0. The mean Classic FF test accuracy, when available, is
    drawn as a dashed reference line. Saved to ``<figures_path>/alpha_effect.png``.
    """
    all_res = load_results_for_viz()
    # load_results_for_viz() adds a (possibly empty) entry for every requested
    # dataset, so check for actual results instead of mere key membership.
    ds_list = [ds for ds in CONFIG['datasets_to_run'] if all_res.get(ds)]

    if not ds_list:
        print('  No results found')
        return

    n = len(ds_list)
    fig, axes = plt.subplots(1, n, figsize=(5*n, 4))
    if n == 1:
        axes = [axes]

    alphas = CONFIG['alpha_values']
    x = np.arange(len(alphas))

    def _mean_test_acc(by_seed, method_key):
        """Mean test_acc of `method_key` over configured seeds that ran it, else None."""
        # Seeds without this run are skipped, not counted as 0 %.
        accs = [by_seed[s][method_key]['test_acc'] for s in CONFIG['seeds']
                if s in by_seed and 'test_acc' in by_seed[s].get(method_key, {})]
        return np.mean(accs) if accs else None

    for ax, ds in zip(axes, ds_list):
        by_seed = all_res[ds]
        means = []
        for a in alphas:
            mean_acc = _mean_test_acc(by_seed, f'ModularFF_a{a}')
            means.append(0 if mean_acc is None else mean_acc)

        ax.bar(x, means, color=plt.cm.viridis(np.linspace(0.2, 0.9, len(alphas))))
        ax.set_xticks(x)
        ax.set_xticklabels([str(a) for a in alphas])
        ax.set_xlabel('α')
        ax.set_ylabel('Test Acc (%)')
        ax.set_title(DS_LABELS.get(ds, ds))

        ff_mean = _mean_test_acc(by_seed, 'ClassicFF')
        if ff_mean is not None:
            ax.axhline(ff_mean, color='blue', ls='--', label='Classic FF')
            ax.legend(fontsize=8)

    plt.suptitle('Effect of α', fontsize=13, fontweight='bold')
    plt.tight_layout()
    os.makedirs(CONFIG['figures_path'], exist_ok=True)
    path = os.path.join(CONFIG['figures_path'], 'alpha_effect.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    print(f'\u2713 Saved: {path}')
    plt.show()


# ================================================================
# FIGURE 4: Meta-Layer Comparison
# ================================================================

def fig_meta_comparison():
    """Compare meta-layer readouts for the best-performing α of each dataset.

    For every dataset in ``CONFIG['datasets_to_run']`` that has results, the α
    with the highest mean ModularFF test accuracy is selected. Means are taken
    only over seeds that contain that run, and α = 0.3 is used if none has any
    runs. The bars show, for that α, the mean accuracy of each meta-layer type
    stored under ``'meta_results'`` (0 when a type was never recorded). Saved
    to ``<figures_path>/meta_comparison.png``.
    """
    all_res = load_results_for_viz()
    # load_results_for_viz() adds a (possibly empty) entry for every requested
    # dataset, so check for actual results instead of mere key membership.
    ds_list = [ds for ds in CONFIG['datasets_to_run'] if all_res.get(ds)]

    if not ds_list:
        print('  No results found')
        return

    n = len(ds_list)
    fig, axes = plt.subplots(1, n, figsize=(5*n, 4))
    if n == 1:
        axes = [axes]

    meta_types = ['argmax', 'calibrated', 'linear', 'mlp', 'temperature']
    x = np.arange(len(meta_types))
    colors = {'argmax': 'gray', 'calibrated': 'lightblue',
              'linear': 'green', 'mlp': 'orange', 'temperature': 'purple'}

    for ax, ds in zip(axes, ds_list):
        by_seed = all_res[ds]
        seeds = [s for s in CONFIG['seeds'] if s in by_seed]

        # Best α by mean test accuracy; seeds missing a run are skipped rather
        # than averaged in as 0 %.
        best_a, best_acc = 0.3, -np.inf
        for a in CONFIG['alpha_values']:
            key = f'ModularFF_a{a}'
            accs = [by_seed[s][key]['test_acc'] for s in seeds
                    if 'test_acc' in by_seed[s].get(key, {})]
            if accs and np.mean(accs) > best_acc:
                best_acc = np.mean(accs)
                best_a = a

        key = f'ModularFF_a{best_a}'
        means = []
        for mt in meta_types:
            vals = [by_seed[s][key]['meta_results'][mt] for s in seeds
                    if mt in by_seed[s].get(key, {}).get('meta_results', {})]
            means.append(np.mean(vals) if vals else 0)

        ax.bar(x, means, color=[colors[m] for m in meta_types])
        ax.set_xticks(x)
        ax.set_xticklabels(meta_types, rotation=45, ha='right')
        ax.set_ylabel('Test Acc (%)')
        ax.set_title(f'{DS_LABELS.get(ds, ds)}\n(α={best_a})')

    plt.suptitle('Meta-Layer Comparison', fontsize=13, fontweight='bold')
    plt.tight_layout()
    os.makedirs(CONFIG['figures_path'], exist_ok=True)
    path = os.path.join(CONFIG['figures_path'], 'meta_comparison.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    print(f'\u2713 Saved: {path}')
    plt.show()


# ================================================================
# FIGURE 5: Per-Specialist Performance
# ================================================================

def fig_specialist_performance():
    """Plot per-specialist accuracy, sensitivity/specificity and goodness separation.

    For each dataset in ``CONFIG['datasets_to_run']`` that has results, the α
    with the highest mean ModularFF test accuracy across seeds is selected.
    This is the same rule as the meta-layer comparison figure, so both report
    the same α. The figure then shows the ``'specialist_results'`` of the first
    configured seed that has them for that α. Datasets with no such seed are
    skipped. One figure per dataset is saved to
    ``<figures_path>/<dataset>_specialist_perf.png``.
    """
    all_res = load_results_for_viz()
    # load_results_for_viz() adds a (possibly empty) entry for every requested
    # dataset, so check for actual results instead of mere key membership.
    ds_list = [ds for ds in CONFIG['datasets_to_run'] if all_res.get(ds)]

    if not ds_list:
        print('  No results found')
        return

    os.makedirs(CONFIG['figures_path'], exist_ok=True)

    for ds in ds_list:
        by_seed = all_res[ds]
        seeds = [s for s in CONFIG['seeds'] if s in by_seed]

        # Best α by mean test accuracy over the seeds that ran it.
        best_a, best_acc = 0.3, -np.inf
        for a in CONFIG['alpha_values']:
            key = f'ModularFF_a{a}'
            accs = [by_seed[s][key]['test_acc'] for s in seeds
                    if 'test_acc' in by_seed[s].get(key, {})]
            if accs and np.mean(accs) > best_acc:
                best_acc = np.mean(accs)
                best_a = a

        key = f'ModularFF_a{best_a}'
        # First configured seed that actually recorded specialist results for this α.
        seed = next((s for s in seeds
                     if by_seed[s].get(key, {}).get('specialist_results')), None)
        if seed is None:
            continue

        sr = by_seed[seed][key]['specialist_results']
        K = len(sr)
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # JSON object keys are strings; order specialists numerically.
        specialists = sorted(sr.keys(), key=lambda k: int(k))
        acc = [sr[k]['accuracy'] for k in specialists]
        sens = [sr[k]['sensitivity'] for k in specialists]
        spec = [sr[k]['specificity'] for k in specialists]
        sep = [sr[k]['separation'] for k in specialists]

        x = np.arange(K)

        # Accuracy
        ax = axes[0]
        ax.bar(x, acc, color='steelblue')
        ax.axhline(np.mean(acc), color='red', ls='--', label=f'Mean={np.mean(acc):.1f}%')
        ax.set_xticks(x)
        ax.set_xticklabels(specialists)
        ax.set_xlabel('Specialist')
        ax.set_ylabel('Binary Accuracy (%)')
        ax.set_title('Per-Specialist Accuracy')
        ax.legend()

        # Sensitivity vs Specificity
        ax = axes[1]
        width = 0.35
        ax.bar(x - width/2, sens, width, label='Sensitivity', color='green')
        ax.bar(x + width/2, spec, width, label='Specificity', color='orange')
        ax.set_xticks(x)
        ax.set_xticklabels(specialists)
        ax.set_xlabel('Specialist')
        ax.set_ylabel('%')
        ax.set_title('Sensitivity vs Specificity')
        ax.legend()

        # Separation: green bars where positive goodness exceeds negative.
        ax = axes[2]
        bar_colors = ['green' if s > 0 else 'red' for s in sep]
        ax.bar(x, sep, color=bar_colors)
        ax.axhline(0, color='black', lw=0.5)
        ax.axhline(np.mean(sep), color='red', ls='--', label=f'Mean={np.mean(sep):.1f}')
        ax.set_xticks(x)
        ax.set_xticklabels(specialists)
        ax.set_xlabel('Specialist')
        ax.set_ylabel('G_pos - G_neg')
        ax.set_title('Goodness Separation')
        ax.legend()

        plt.suptitle(f'{DS_LABELS.get(ds, ds)} — Per-Specialist Performance '
                     f'(α={best_a}, seed={seed})',
                     fontsize=13, fontweight='bold')
        plt.tight_layout()
        path = os.path.join(CONFIG['figures_path'], f'{ds}_specialist_perf.png')
        plt.savefig(path, dpi=150, bbox_inches='tight')
        print(f'\u2713 Saved: {path}')
        plt.show()


# ================================================================
# RUN ALL FIGURES
# ================================================================
print('Generating figures...\n')

# Draw every figure in turn. A failure in one figure is reported on a single
# line and the remaining figures are still attempted.
for _banner, _draw_figure in (
    ('--- Figure 1: XOR ---', fig_xor_analysis),
    ('\n--- Figure 2: Convergence ---', fig_convergence),
    ('\n--- Figure 3: Alpha Effect ---', fig_alpha_effect),
    ('\n--- Figure 4: Meta-Layer ---', fig_meta_comparison),
    ('\n--- Figure 5: Specialist Performance ---', fig_specialist_performance),
):
    print(_banner)
    try:
        _draw_figure()
    except Exception as exc:
        print(f'  Error: {exc}')

print(f'\n\u2713 Figures saved to: {CONFIG["figures_path"]}')

In [None]:
######## end of cell 6 and codebase