## Imports & Configuration

In [None]:
import numpy as np
import torch
from brevitas.nn import QuantIdentity
from reservoirpy.nodes import Reservoir, Ridge
from reservoirpy import Node
from reservoirpy.datasets import narma
from reservoirpy.observables import rmse, nrmse, rsquare
import os

# ── Configuration ─────────────────────────────────────
N                     = 800        # Size: number of neurons ("brain cells") in the reservoir
RIDGE                 = 1e-10      # Penalty: L2 regularization of the readout weights (limits overfitting)
SR                    = 0.5        # Memory: spectral radius of the recurrent weights; closer to 1 = longer-lasting echoes
INPUT_SCALING         = 0.7        # Strength: scaling of the input weights; too large saturates the neurons
INPUT_CONNECTIVITY    = 0.9        # Entry: density (fraction of non-zero entries) of the input weight matrix
RC_CONNECTIVITY       = 0.2        # Network: density of the recurrent (neuron-to-neuron) weight matrix
LR                    = 1.0        # Leak rate: how fast neurons take in new info (1.0 = no leak, instant update)
WARMUP                = 1000       # Settling: initial steps dropped before fitting so start-up transients fade
SEED                  = 2341       # Reproducibility: same random reservoir and input signal on every run


# NARMA dataset
N_TIMESTEPS           = 10000      # Data length: number of time steps in the NARMA series
NARMA_ORDER           = 20         # Task memory depth (NARMA-20)
TRAIN_LEN             = 5000       # Training: number of samples used to fit the readout

DATASET_TAG           = f"NARMA{NARMA_ORDER}"

QUANTIZATION_BITS_LIST = [4, 6, 8]  # Bit-widths evaluated for the quantized reservoir

print("Configuration:")
print(f"  Dataset:           {DATASET_TAG}")
print(f"  Reservoir neurons: {N}")
print(f"  NARMA order:       {NARMA_ORDER}")
print(f"  Training length:   {TRAIN_LEN}")
# The dataset cell's split (X_test = u[TRAIN_LEN + NARMA_ORDER + 1:-1],
# y_test = X[TRAIN_LEN + 2:]) skips one boundary sample and drops the last
# input, so the test set holds N_TIMESTEPS - TRAIN_LEN - 2 samples.
print(f"  Test length:       {N_TIMESTEPS - TRAIN_LEN - 2}")


## Hyperparameter Search (Don't Run Now)

In [None]:
# Off by default, as the cell header says: the full grid is evaluated three
# times (FP32, 6-bit, 8-bit) and is slow. Set to True to run the search.
RUN_HYPERPARAM_SEARCH = False

if RUN_HYPERPARAM_SEARCH:
    import pandas as pd
    from itertools import product
    from time import time
    import warnings

    # Suppress ill-conditioned matrix warnings for cleaner output
    warnings.filterwarnings('ignore', message='Ill-conditioned matrix')

    print("="*70)
    print("  ADVANCED HYPERPARAMETER SEARCH")
    print("  Fixed: N=500 | Tests: FP32 + 6-bit + 8-bit")
    print("="*70)

    def setup_quant_params(num_bits):
        # Signed integer range plus the 2**num_bits fixed-point divisor and its reciprocal.
        half_range = 2 ** (num_bits - 1)
        full_range = 2 ** num_bits
        return dict(
            bits=num_bits,
            min_val=-half_range,
            max_val=half_range - 1,
            threshold_scale=1.0 / full_range,
            div_scale=full_range,
        )

    def extract_Qinput(array, num_bits):
        # Quantize with a fresh Brevitas QuantIdentity; return (ints, scale, zero_point) as numpy.
        quantizer = QuantIdentity(return_quant_tensor=True, bit_width=num_bits)
        qtensor = quantizer(torch.tensor(array, dtype=torch.float32))
        int_values = qtensor.int().detach().numpy()
        scale = qtensor.scale.detach().numpy()
        zero_point = qtensor.zero_point.detach().numpy()
        return int_values, scale, zero_point

    def compute_integer_thresholds(scale):
        # Hard-tanh saturation points -1/+1 expressed in the integer domain of `scale`.
        bound = 1 / scale
        return np.int32(-bound), np.int32(bound)

    def piecewise_linear_hard_tanh_integer(x, lo, hi, div_scale):
        # Clip into [lo, hi], shift by hi, then divide and truncate to int32.
        shifted = np.clip(x, lo, hi) + hi
        return (shifted / div_scale).astype(np.int32)

    # Hyperparameter grid; the reservoir size is held fixed during the search
    N_FIXED = 500
    search_space = {
        'SR':              [0.5, 0.7, 0.9],          # Spectral radius
        'RIDGE':           [1e-10, 1e-8, 1e-6],      # Ridge regularization
        'LR':              [0.5, 0.7, 1.0],          # Leak rate
        'INPUT_SCALING':   [0.5, 0.7, 1.0],          # Input scaling
        'RC_CONNECTIVITY': [0.05, 0.1, 0.2],         # Reservoir connectivity
    }

    # Full Cartesian grid: one tuple per combination, in search_space key order
    param_names = list(search_space)
    param_values = [search_space[name] for name in param_names]
    combinations = list(product(*param_values))

    # Every combination is evaluated as FP32, 6-bit and 8-bit
    total_runs = 3 * len(combinations)
    print(f"\nTesting {len(combinations)} hyperparameter combinations × 3 models = {total_runs} runs")
    print(f"Fixed: N={N_FIXED}, INPUT_CONNECTIVITY={INPUT_CONNECTIVITY}, WARMUP={WARMUP}, SEED={SEED}\n")
    print("This will take ~10-15 minutes...\n")

    # One dataset shared by every run, built with the same seed and slicing as the dataset cell
    rng = np.random.default_rng(seed=SEED)
    u_search = rng.uniform(0, 0.5, size=(N_TIMESTEPS + NARMA_ORDER, 1))
    X_narma_search = narma(n_timesteps=N_TIMESTEPS, order=NARMA_ORDER, u=u_search)
    if isinstance(X_narma_search, tuple):
        X_search = X_narma_search[0][:N_TIMESTEPS]
    else:
        X_search = X_narma_search[:N_TIMESTEPS]

    # Train/test split with the same offsets as X_train / y_train / X_test / y_test
    train_end = TRAIN_LEN + NARMA_ORDER
    X_train_search, y_train_search = u_search[NARMA_ORDER:train_end], X_search[1:TRAIN_LEN + 1]
    X_test_search, y_test_search = u_search[train_end + 1:-1], X_search[TRAIN_LEN + 2:]

    # Helper function to quantize and evaluate
    def evaluate_quantized(esn_fp, num_bits, ridge_val):
        """Quantize a trained FP32 ESN to ``num_bits`` and evaluate it on the search split.

        Args:
            esn_fp: trained reservoirpy model; ``esn_fp.nodes[0]`` is the reservoir
                whose ``Win``, ``W`` and ``bias`` are quantized.
            num_bits: bit-width for the input signal, weights and bias.
            ridge_val: ridge penalty of the readout refitted on quantized states.

        Returns:
            dict with keys 'rmse', 'nrmse', 'r2'. On any failure all three are NaN,
            so one bad configuration does not abort the search. The error is
            printed rather than silently discarded.

        Note: the integer QuantNode below has no leak term, so the searched LR
        only influences the FP32 reference, not the quantized dynamics.
        """

        def to_dense(matrix):
            # scipy.sparse objects expose .toarray(); dense arrays pass through, so
            # this works whether reservoirpy stored the weight sparse or dense.
            return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

        def collapse_scale(scale):
            # The threshold computation expects one scalar scale; average multi-element ones.
            return np.mean(scale) if hasattr(scale, 'shape') and scale.size > 1 else scale

        try:
            qp = setup_quant_params(num_bits)
            reservoir = esn_fp.nodes[0]

            # Quantize weights and inputs
            int_x, x_scale, _ = extract_Qinput(u_search, num_bits)
            int_Win, scale_Win, _ = extract_Qinput(to_dense(reservoir.Win), num_bits)
            int_Wr, scale_Wr, _ = extract_Qinput(to_dense(reservoir.W), num_bits)

            # Bias may be a scalar (broadcast to all units), sparse, or dense
            bias_raw = reservoir.bias
            bias_array = np.full((1, N_FIXED), bias_raw) if np.ndim(bias_raw) == 0 else to_dense(bias_raw)
            int_bias, _, _ = extract_Qinput(bias_array, num_bits)

            # Collapse scales to scalars
            x_scale = collapse_scale(x_scale)
            scale_Win = collapse_scale(scale_Win)
            scale_Wr = collapse_scale(scale_Wr)

            input_scale = scale_Win * x_scale
            reservoir_scale = scale_Wr * qp['threshold_scale']
            inp_lo, inp_hi = compute_integer_thresholds(input_scale)
            res_lo, res_hi = compute_integer_thresholds(reservoir_scale)

            # Integer reservoir: quantized hard-tanh of input and recurrent drive, plus bias
            class QuantNode(Node):
                def __init__(self, Wr, Win, bias, lo_i, hi_i, lo_r, hi_r, div):
                    self.Wr_int = Wr
                    self.Win_int = Win
                    self.Bias_int = bias.flatten()
                    self.inp_lo = lo_i
                    self.inp_hi = hi_i
                    self.res_lo = lo_r
                    self.res_hi = hi_r
                    self.div_scale = div
                    self.output_dim = N_FIXED
                    self.input_dim = None
                    self.initialized = False

                def initialize(self, x, y=None):
                    self.input_dim = x.shape[-1]
                    self.state = {"out": np.zeros((N_FIXED,), dtype=np.int64)}
                    self.initialized = True

                def _step(self, state, x):
                    s = state["out"].astype(np.int64).reshape(1, N_FIXED)
                    # NOTE(review): s @ Wr is the transpose of a column-vector W @ s
                    # update (the input path does use Win.T) -- confirm intended.
                    recurrent = s @ self.Wr_int.astype(np.int32)
                    inp = x.reshape(1, -1) @ self.Win_int.astype(np.int32).T
                    out_inp = piecewise_linear_hard_tanh_integer(inp, self.inp_lo, self.inp_hi, self.div_scale)
                    out_rec = piecewise_linear_hard_tanh_integer(recurrent, self.res_lo, self.res_hi, self.div_scale)
                    return {"out": (out_inp + out_rec + self.Bias_int.reshape(1, N_FIXED)).flatten()}

            node = QuantNode(int_Wr, int_Win, int_bias, inp_lo, inp_hi, res_lo, res_hi, qp['div_scale'])

            # Train readout on quantized states (same input slice as X_train_search)
            int_x_train = int_x[NARMA_ORDER:TRAIN_LEN + NARMA_ORDER]
            states_train = node.run(int_x_train.astype(np.float64)) * qp['threshold_scale']

            readout = Ridge(ridge=ridge_val)
            readout.fit(states_train, y_train_search, warmup=WARMUP)

            # Evaluate on test set (same input slice as X_test_search)
            int_x_test = int_x[TRAIN_LEN + NARMA_ORDER + 1:-1]
            states_test = node.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
            y_pred = readout.run(states_test)

            return {
                'rmse': rmse(y_test_search, y_pred),
                'nrmse': nrmse(y_test_search, y_pred),
                'r2': rsquare(y_test_search, y_pred)
            }
        except Exception as err:
            # Best-effort by design (one failure must not stop the grid), but report why.
            print(f"    [warn] {num_bits}-bit evaluation failed: {type(err).__name__}: {err}")
            return {'rmse': np.nan, 'nrmse': np.nan, 'r2': np.nan}

    # Run search
    results = []
    start_time = time()
    run_count = 0
    nan_metrics = {'rmse': np.nan, 'nrmse': np.nan, 'r2': np.nan}

    def make_row(sr_val, ridge_val, lr_val, input_scaling_val, rc_conn_val, model, metrics):
        """Build one results-table row for a (hyperparameter combination, model) pair."""
        return {
            'N': N_FIXED,
            'SR': sr_val,
            'RIDGE': ridge_val,
            'LR': lr_val,
            'INPUT_SCALING': input_scaling_val,
            'RC_CONNECTIVITY': rc_conn_val,
            'MODEL': model,
            'RMSE': metrics['rmse'],
            'NRMSE': metrics['nrmse'],
            'R2': metrics['r2'],
        }

    for i, params in enumerate(combinations, 1):
        # Unpack parameters (order follows the search_space keys)
        sr_val, ridge_val, lr_val, input_scaling_val, rc_conn_val = params

        # ========== FP32 Baseline ==========
        run_count += 1
        try:
            reservoir_fp = Reservoir(
                units=N_FIXED,
                lr=lr_val,
                sr=sr_val,
                input_connectivity=INPUT_CONNECTIVITY,
                rc_connectivity=rc_conn_val,
                input_scaling=input_scaling_val,
                seed=SEED
            )
            readout_fp = Ridge(ridge=ridge_val)
            esn_fp = reservoir_fp >> readout_fp
            esn_fp.fit(X_train_search, y_train_search, warmup=WARMUP)

            y_pred_fp = esn_fp.run(X_test_search)
            metrics_fp = {
                'rmse': rmse(y_test_search, y_pred_fp),
                'nrmse': nrmse(y_test_search, y_pred_fp),
                'r2': rsquare(y_test_search, y_pred_fp),
            }
        except Exception as err:
            # `except Exception`, not a bare `except:`, so Ctrl-C (KeyboardInterrupt)
            # still stops the search instead of being recorded as a failed run.
            print(f"  [warn] FP32 run failed for {params}: {type(err).__name__}: {err}")
            metrics_fp = dict(nan_metrics)
            esn_fp = None
        rmse_fp, nrmse_fp, r2_fp = metrics_fp['rmse'], metrics_fp['nrmse'], metrics_fp['r2']
        results.append(make_row(*params, 'FP32', metrics_fp))

        # ========== 6-bit Quantization ==========
        run_count += 1
        metrics_6 = evaluate_quantized(esn_fp, 6, ridge_val) if esn_fp is not None else dict(nan_metrics)
        results.append(make_row(*params, '6-bit', metrics_6))

        # ========== 8-bit Quantization ==========
        run_count += 1
        metrics_8 = evaluate_quantized(esn_fp, 8, ridge_val) if esn_fp is not None else dict(nan_metrics)
        results.append(make_row(*params, '8-bit', metrics_8))

        # Progress update every 10 combinations and after the last one
        if i % 10 == 0 or i == len(combinations):
            elapsed = time() - start_time
            eta = (elapsed / run_count) * (total_runs - run_count)
            r2_6_str = f"{metrics_6['r2']:.3f}" if not np.isnan(metrics_6['r2']) else "nan"
            r2_8_str = f"{metrics_8['r2']:.3f}" if not np.isnan(metrics_8['r2']) else "nan"
            print(f"  [{run_count}/{total_runs}] SR={sr_val:.1f}, RIDGE={ridge_val:.0e}, RC_CONN={rc_conn_val:.2f} | FP32 R²={r2_fp:.3f}, 6bit R²={r2_6_str}, 8bit R²={r2_8_str} | ETA: {eta:.0f}s")

    elapsed = time() - start_time
    print(f"\n Search completed in {elapsed:.1f}s ({elapsed/60:.1f} min)")

    # Create DataFrame
    df_results = pd.DataFrame(results)
    df_results.to_csv('hyperparam_search_advanced.csv', index=False)
    print(f" Full results saved to: hyperparam_search_advanced.csv")

    # Display best for each model type
    print("\n" + "="*70)
    print("  BEST CONFIGURATION PER MODEL TYPE")
    print("="*70)

    for model_type in ['FP32', '6-bit', '8-bit']:
        df_model = df_results[df_results['MODEL'] == model_type].copy()
        df_model = df_model.dropna(subset=['R2'])  # Remove failed runs
        df_model = df_model.sort_values('R2', ascending=False)

        if len(df_model) > 0:
            best = df_model.iloc[0]
            print(f"\n{model_type}:")
            print(f"  N={int(best['N'])}, SR={best['SR']}, RIDGE={best['RIDGE']:.0e}, LR={best['LR']}, IN_SCALE={best['INPUT_SCALING']}, RC_CONN={best['RC_CONNECTIVITY']}")
            print(f"  → RMSE={best['RMSE']:.6f}, NRMSE={best['NRMSE']:.6f}, R²={best['R2']:.6f}")

            # Save to file
            with open(f'best_config_{model_type.replace("-", "")}.txt', 'w') as f:
                f.write(f"# Best {model_type} Configuration\n")
                f.write(f"N = {int(best['N'])}\n")
                f.write(f"SR = {best['SR']}\n")
                f.write(f"RIDGE = {best['RIDGE']:.0e}\n")
                f.write(f"LR = {best['LR']}\n")
                f.write(f"INPUT_SCALING = {best['INPUT_SCALING']}\n")
                f.write(f"INPUT_CONNECTIVITY = {INPUT_CONNECTIVITY}\n")
                f.write(f"RC_CONNECTIVITY = {best['RC_CONNECTIVITY']}\n")
                f.write(f"\n# Performance\n")
                f.write(f"RMSE = {best['RMSE']:.6f}\n")
                f.write(f"NRMSE = {best['NRMSE']:.6f}\n")
                f.write(f"R2 = {best['R2']:.6f}\n")

    print("\n" + "="*70)
    print(" Best configs saved to: best_config_FP32.txt, best_config_6bit.txt, best_config_8bit.txt")

else:
    print("Hyperparameter search disabled. Set RUN_HYPERPARAM_SEARCH = True to enable.")


## Load & Prepare NARMA Dataset

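NARMA (Nonlinear Auto-Regressive Moving Average) is a classic benchmark for reservoir computing. Each target value depends nonlinearly on the previous `NARMA_ORDER` target values and on inputs up to `NARMA_ORDER` steps back, so the task measures how much fading memory the reservoir provides.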

In [None]:
print(f"Generating {DATASET_TAG} dataset...")

# Random drive signal: N_TIMESTEPS + NARMA_ORDER samples, uniform in [0, 0.5)
rng = np.random.default_rng(seed=SEED)
u = rng.uniform(0, 0.5, size=(N_TIMESTEPS + NARMA_ORDER, 1))

# NARMA target driven by u. If narma() hands back a tuple, the series is its
# first element; either way keep exactly N_TIMESTEPS samples.
X_narma = narma(n_timesteps=N_TIMESTEPS, order=NARMA_ORDER, u=u)
X = X_narma[0] if isinstance(X_narma, tuple) else X_narma
X = X[:N_TIMESTEPS]

# Train/test splits. Both pair input u[k + NARMA_ORDER - 1] with target X[k];
# the test block starts two samples after the train block (one sample skipped).
train_end = TRAIN_LEN + NARMA_ORDER
X_train = u[NARMA_ORDER:train_end]       # Input for training
y_train = X[1:TRAIN_LEN + 1]             # Target for training
X_test = u[train_end + 1:-1]             # Input for testing
y_test = X[TRAIN_LEN + 2:]               # Target for testing

print(f"\nDataset prepared:")
print(f"  Full signal length: {len(X)}")
print(f"  X_train: {X_train.shape}  (input signal)")
print(f"  y_train: {y_train.shape}  (target signal)")
print(f"  X_test:  {X_test.shape}")
print(f"  y_test:  {y_test.shape}")

assert X_test.shape[0] == y_test.shape[0], f"Test shapes mismatch: X_test {X_test.shape} != y_test {y_test.shape}"
print(f"   Test shapes match!")

print(f"\n  Input range: [{X_train.min():.3f}, {X_train.max():.3f}]")
print(f"  Target range: [{y_train.min():.3f}, {y_train.max():.3f}]")


In [None]:
# Save the full input signal u to disk
u_path = f"u_{DATASET_TAG}.npy"
np.save(u_path, u)
print(f"Input signal saved: {u_path}  {u.shape}")


In [None]:
# ── Quantization utilities ────────────────────────────
def setup_quant_params(num_bits):
    """Build the integer-range constants used for a ``num_bits`` quantizer.

    Returns a dict with the bit-width, the signed range ``min_val``/``max_val``,
    ``div_scale`` = 2**num_bits and its reciprocal ``threshold_scale``.
    """
    half_range = 2 ** (num_bits - 1)
    full_range = 2 ** num_bits
    return dict(
        bits=num_bits,
        min_val=-half_range,
        max_val=half_range - 1,
        threshold_scale=1.0 / full_range,
        div_scale=full_range,
    )

def extract_Qinput(array, num_bits):
    """Quantize *array* to ``num_bits`` with a freshly built Brevitas ``QuantIdentity``.

    Returns ``(int_values, scale, zero_point)``, each converted to a numpy array.
    """
    quantizer = QuantIdentity(return_quant_tensor=True, bit_width=num_bits)
    qtensor = quantizer(torch.tensor(array, dtype=torch.float32))
    int_values = qtensor.int().detach().numpy()
    scale = qtensor.scale.detach().numpy()
    zero_point = qtensor.zero_point.detach().numpy()
    return int_values, scale, zero_point

def compute_integer_thresholds(scale):
    """Express the hard-tanh saturation points -1/+1 in the integer domain of ``scale``.

    Returns ``(lo, hi)`` as ``np.int32``; each float bound is truncated toward zero.
    """
    bound = 1 / scale
    return np.int32(-bound), np.int32(bound)

def piecewise_linear_hard_tanh_integer(x, lo, hi, div_scale):
    """Integer hard-tanh: clip into ``[lo, hi]``, shift by ``hi``, divide, truncate to int32.

    With symmetric bounds (lo == -hi) the shifted value lies in ``[0, 2*hi]``.
    """
    shifted = np.clip(x, lo, hi) + hi
    return (shifted / div_scale).astype(np.int32)

## Train FP32 Baseline ESN

In [None]:
print("Training FP32 ESN baseline...")

# Floating-point reference reservoir built from the global hyperparameters
reservoir_fp32 = Reservoir(
    units=N,
    sr=SR,
    lr=LR,
    input_scaling=INPUT_SCALING,
    input_connectivity=INPUT_CONNECTIVITY,
    rc_connectivity=RC_CONNECTIVITY,
    seed=SEED,
)
readout_fp32 = Ridge(ridge=RIDGE)

# Chain reservoir -> readout and fit, discarding the first WARMUP steps
esn_fp32 = (reservoir_fp32 >> readout_fp32).fit(X_train, y_train, warmup=WARMUP)

# Score on the held-out test split
y_pred_fp32 = esn_fp32.run(X_test)
fp32_rmse, fp32_nrmse, fp32_r2 = (
    score(y_test, y_pred_fp32) for score in (rmse, nrmse, rsquare)
)

print(f"\nFP32 Baseline Results:")
print(f"  RMSE:   {fp32_rmse:.6f}")
print(f"  NRMSE:  {fp32_nrmse:.6f}")
print(f"  R²:     {fp32_r2:.6f}")

## Quantized ESN — Build, Train, Evaluate

In [None]:
def _to_dense(matrix):
    """Return *matrix* as a dense ndarray.

    scipy.sparse objects are converted with ``toarray()``; anything else goes
    through ``np.asarray``. This works whether reservoirpy stored a weight
    sparse or dense.
    """
    return matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)


def _collapse_scale(scale):
    """Average a multi-element quantization scale to a scalar; pass others through unchanged."""
    if hasattr(scale, "shape") and scale.size > 1:
        return np.mean(scale)
    return scale


def run_quantized_esn(num_bits, esn_fp32, X_train, y_train, X_test, y_test, u_full):
    """
    Quantize the trained FP32 ESN to ``num_bits`` and evaluate it.

    The reservoir's input weights, recurrent weights and bias, and the full
    input signal, are quantized with Brevitas. A new Ridge readout is then
    fitted on the states of an integer-only reservoir node.

    Args:
        num_bits: bit-width for input, weights and bias.
        esn_fp32: trained reservoirpy model; ``esn_fp32.nodes[0]`` must expose
            ``Win``, ``W`` and ``bias``.
        X_train, X_test: not used. The quantized inputs are sliced from
            ``u_full`` with the same indices the dataset cell used to build
            X_train / X_test from ``u``.
        y_train, y_test: targets for the readout fit and for evaluation.
        u_full: full input signal the train/test inputs were sliced from.

    Returns:
        dict with 'metrics' (rmse/nrmse/r2), 'node', 'readout', the integer
        weights 'int_Wr'/'int_Win'/'int_bias', the scaled 'states_train' /
        'states_test', 'quant_params', and the scalar scales 'x_scale',
        'scale_Win', 'scale_Wr'.
    """
    print(f"\n{'='*50}")
    print(f"  {num_bits}-bit Quantization")
    print(f"{'='*50}")
    qp = setup_quant_params(num_bits)
    reservoir = esn_fp32.nodes[0]

    # Quantize inputs and weights
    print("  Quantizing inputs and weights...")
    int_x,   x_scale,   _ = extract_Qinput(u_full,                   num_bits)
    int_Win, scale_Win, _ = extract_Qinput(_to_dense(reservoir.Win), num_bits)
    int_Wr,  scale_Wr,  _ = extract_Qinput(_to_dense(reservoir.W),   num_bits)

    # Handle bias: a scalar is broadcast to all N units; sparse/dense are densified
    bias_raw = reservoir.bias
    bias_array = np.full((1, N), bias_raw) if np.ndim(bias_raw) == 0 else _to_dense(bias_raw)
    int_bias, _, _ = extract_Qinput(bias_array, num_bits)

    # Collapse scales to scalars
    x_scale   = _collapse_scale(x_scale)
    scale_Win = _collapse_scale(scale_Win)
    scale_Wr  = _collapse_scale(scale_Wr)

    input_scale     = scale_Win * x_scale
    reservoir_scale = scale_Wr  * qp['threshold_scale']

    # Integer thresholds
    inp_lo, inp_hi = compute_integer_thresholds(input_scale)
    res_lo, res_hi = compute_integer_thresholds(reservoir_scale)
    print(f"  Input thresholds:     [{inp_lo}, {inp_hi}]")
    print(f"  Reservoir thresholds: [{res_lo}, {res_hi}]")

    # Build quantized reservoir node
    class QuantizedNode(Node):
        """Integer reservoir with no leak term and no float activation.

        Next state = hardtanh_q(input drive) + hardtanh_q(recurrent drive) + bias.
        """

        def __init__(self, name=None):
            self.name        = name
            self.output_dim  = N
            self.input_dim   = None
            self.initialized = False

        def initialize(self, x, y=None):
            self.input_dim   = x.shape[-1]
            self.output_dim  = N
            self.Wr          = int_Wr
            self.Win         = int_Win
            self.Bias        = int_bias.flatten()
            self.state       = {"out": np.zeros((N,), dtype=np.int64)}
            self.initialized = True

        def _step(self, state, x):
            s         = state["out"].astype(np.int64).reshape(1, N)
            # NOTE(review): s @ Wr is the transpose of a column-vector W @ s update
            # (the input path below does apply Win.T) -- confirm this is intended.
            recurrent = s @ self.Wr.astype(np.int32)
            inp       = x.reshape(1, -1) @ self.Win.astype(np.int32).T
            out_inp   = piecewise_linear_hard_tanh_integer(inp,       inp_lo, inp_hi, qp['div_scale'])
            out_rec   = piecewise_linear_hard_tanh_integer(recurrent, res_lo, res_hi, qp['div_scale'])
            next_out  = (out_inp + out_rec + self.Bias.reshape(1, N)).flatten()
            return {"out": next_out}

    quant_node = QuantizedNode(name=f"quant_reservoir_{num_bits}bit")

    # Train readout on quantized states (same slice as X_train = u[NARMA_ORDER:TRAIN_LEN + NARMA_ORDER])
    print("  Running reservoir on training data...")
    int_x_train = int_x[NARMA_ORDER:TRAIN_LEN + NARMA_ORDER]
    states_train = quant_node.run(int_x_train.astype(np.float64)) * qp['threshold_scale']

    print("  Fitting Ridge readout...")
    quant_readout = Ridge(ridge=RIDGE)
    quant_readout.fit(states_train, y_train, warmup=WARMUP)

    # Evaluate on test set (same slice as X_test = u[TRAIN_LEN + NARMA_ORDER + 1:-1])
    print("  Running reservoir on test data...")
    int_x_test = int_x[TRAIN_LEN + NARMA_ORDER + 1:-1]
    states_test = quant_node.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
    y_pred = quant_readout.run(states_test)

    # Metrics
    test_rmse = rmse(y_test, y_pred)
    test_nrmse = nrmse(y_test, y_pred)
    test_r2 = rsquare(y_test, y_pred)

    print(f"\n  RMSE:   {test_rmse:.6f}")
    print(f"  NRMSE:  {test_nrmse:.6f}")
    print(f"  R²:     {test_r2:.6f}")

    return {
        'metrics': {
            'rmse': test_rmse,
            'nrmse': test_nrmse,
            'r2': test_r2
        },
        'node': quant_node,
        'readout': quant_readout,
        'int_Wr': int_Wr,
        'int_Win': int_Win,
        'int_bias': int_bias,
        'states_train': states_train,
        'states_test': states_test,
        'quant_params': qp,
        'x_scale': x_scale,
        'scale_Win': scale_Win,
        'scale_Wr': scale_Wr,
    }

## Run All Quantization Levels

In [None]:
# Quantize and evaluate once per bit-width (each call prints its own report)
quantized_results = {
    bits: run_quantized_esn(bits, esn_fp32, X_train, y_train, X_test, y_test, u)
    for bits in QUANTIZATION_BITS_LIST
}

# Summary table: FP32 reference first, then each quantized model
table_rows = [("FP32", fp32_rmse, fp32_nrmse, fp32_r2)] + [
    (
        f"{b}-bit",
        quantized_results[b]['metrics']['rmse'],
        quantized_results[b]['metrics']['nrmse'],
        quantized_results[b]['metrics']['r2'],
    )
    for b in QUANTIZATION_BITS_LIST
]

print("\n" + "=" * 70)
print(f"{'Model':<10} {'RMSE':<15} {'NRMSE':<15} {'R²':<15}")
print("-" * 70)
for label, row_rmse, row_nrmse, row_r2 in table_rows:
    print(f"{label:<10} {row_rmse:<15.6f} {row_nrmse:<15.6f} {row_r2:<15.6f}")
print("=" * 70)

## Save Quantized Weights

In [None]:
for bits in QUANTIZATION_BITS_LIST:
    wdir = os.path.join("weights", f"{bits}bit")
    os.makedirs(wdir, exist_ok=True)

    # Integer weights plus the float readout, one .npy per array
    r = quantized_results[bits]
    suffix = f"{DATASET_TAG}_{bits}bit"
    to_save = [
        (f"quantized_reservoir_weights_Wr_{suffix}.npy", r['int_Wr']),
        (f"quantized_input_weights_Win_{suffix}.npy",    r['int_Win']),
        (f"quantized_bias_weights_{suffix}.npy",         r['int_bias']),
        (f"readout_weights_Wout_{suffix}.npy",           r['readout'].Wout),
        (f"readout_bias_{suffix}.npy",                   r['readout'].bias),
    ]
    for fname, array in to_save:
        np.save(os.path.join(wdir, fname), array)

    # Quantization scales (needed for re-running with pruned weights), obtained
    # by re-quantizing the same input signal and weight matrices.
    qp_save = setup_quant_params(bits)  # not used below
    fp_node = esn_fp32.nodes[0]
    _, x_scale_save,   _ = extract_Qinput(u, bits)
    _, scale_Win_save, _ = extract_Qinput(fp_node.Win.todense(), bits)
    _, scale_Wr_save,  _ = extract_Qinput(fp_node.W.todense(), bits)

    # Multi-element scales are averaged to a single scalar
    if hasattr(x_scale_save, 'shape') and x_scale_save.size > 1:
        x_scale_save = np.mean(x_scale_save)
    if hasattr(scale_Win_save, 'shape') and scale_Win_save.size > 1:
        scale_Win_save = np.mean(scale_Win_save)
    if hasattr(scale_Wr_save, 'shape') and scale_Wr_save.size > 1:
        scale_Wr_save = np.mean(scale_Wr_save)

    for scale_name, scale_value in (("x", x_scale_save), ("Win", scale_Win_save), ("Wr", scale_Wr_save)):
        np.save(os.path.join(wdir, f"scale_{scale_name}_{bits}bit.npy"), scale_value)

    print(f"  {bits}-bit weights and scales saved → {wdir}/")

print(f"\nAll weights saved for {DATASET_TAG}.")


## Save States for Attention Pruning

**Important:** Save reservoir states for later use in attention-based neuron importance scoring.

In [None]:
states_dir = "states"
os.makedirs(states_dir, exist_ok=True)

# Per-bit-width reservoir states: train first, then test
for bits in QUANTIZATION_BITS_LIST:
    res = quantized_results[bits]
    for split in ("train", "test"):
        np.save(os.path.join(states_dir, f"states_{split}_{bits}bit.npy"), res[f"states_{split}"])

# Targets do not depend on the bit-width, so they are written once
for fname, target in (("y_train.npy", y_train), ("y_test.npy", y_test)):
    np.save(os.path.join(states_dir, fname), target)

summary_lines = [f"States saved for {DATASET_TAG}:"]
summary_lines += [
    f"  {b}-bit: states_train_{b}bit.npy, states_test_{b}bit.npy"
    for b in QUANTIZATION_BITS_LIST
]
summary_lines.append("  y_train.npy, y_test.npy")
print("\n".join(summary_lines))
