## Why "Attention-Based"?
In standard neural networks, an attention mechanism learns which parts of the data to focus on by updating its weights during training. In Reservoir Computing, however, the internal (recurrent) connections are fixed and never trained, so attention cannot be learned inside the reservoir. Instead, the EC-Var method acts as a built-in attention mechanism by observing how data naturally flows through the network. It scores every connection using three factors: how strong the connection is (|W|), how well it carries information from one time step to the next (lagged rank correlation), and how active the receiving neuron is (its standard deviation). By pruning away the low-scoring connections, EC-Var structurally forces the network to "pay attention" only to its most active and useful memory pathways. This pursues the same goal as an attention mechanism—concentrating the network's capacity on what matters most—without training any reservoir weights (pruning only removes connections).

# Regression Pruning Method Comparison
This notebook compares methods for pruning the recurrent weights of a quantized reservoir trained for regression (NARMA task), using saved weights and states.


## Imports and configuration.

In [None]:
import copy
import math
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from brevitas.nn import QuantIdentity
from reservoirpy import Node
from reservoirpy.nodes import Ridge
from reservoirpy.observables import rmse, nrmse, rsquare
from scipy.stats import rankdata, spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error, mean_absolute_error, mutual_info_score


# ── Configuration ──────────────────────────────────────────────────────────────
DATASET_DIR       = "."
NARMA_ORDER       = 20      # also used as the offset when splitting the input signal u
DATASET_TAG       = f"NARMA{NARMA_ORDER}"
DATASET_NAME      = DATASET_TAG

N_NEURONS         = 800
QUANTIZATION_BITS = 4
SEED              = 2341
WARMUP            = 1000    # leading timesteps dropped for EC/PCA/LASSO scoring and readout fits
RIDGE             = 1e-10   # ridge regularisation of every (re)fitted readout

TRAIN_LEN         = 5000
N_TIMESTEPS       = 10000   # NOTE(review): not referenced elsewhere in this notebook

# Fraction of active (non-zero) recurrent weights removed at each evaluation point.
SPARSITY_LEVELS   = [0.0, 0.15, 0.30, 0.45, 0.60, 0.70]

MAX_WEIGHTS_TO_TEST = 500   # number of active connections that get an MI score
MI_BINS           = QUANTIZATION_BITS * 2   # histogram bins per axis for MI
LASSO_ALPHA       = 1e-5

RETRAIN_READOUT   = True    # refit the Ridge readout after pruning (see each method's loop)
RUN_MI            = True    # set False to skip the Mutual-Information method

# Lowercase alias; used as the default random_seed of the random-pruning analysis.
seed = SEED
np.random.seed(SEED)

print(f"Dataset: {DATASET_TAG} | {QUANTIZATION_BITS}-bit | N={N_NEURONS}")


## Load model artifacts (states, targets, weights, scales, input signal).

In [None]:
# Layout: ./states/ holds reservoir states and targets; ./weights/<bits>bit/ holds the
# quantized weights, readout and quantization scales.
states_dir  = os.path.join(DATASET_DIR, "states")
weights_dir = os.path.join(DATASET_DIR, "weights", f"{QUANTIZATION_BITS}bit")

states_train = np.load(os.path.join(states_dir, f"states_train_{QUANTIZATION_BITS}bit.npy"))
states_test  = np.load(os.path.join(states_dir, f"states_test_{QUANTIZATION_BITS}bit.npy"))
y_train      = np.load(os.path.join(states_dir, "y_train.npy"))
y_test       = np.load(os.path.join(states_dir, "y_test.npy"))

W_res        = np.load(os.path.join(weights_dir, f"quantized_reservoir_weights_Wr_{DATASET_TAG}_{QUANTIZATION_BITS}bit.npy"))
W_in         = np.load(os.path.join(weights_dir, f"quantized_input_weights_Win_{DATASET_TAG}_{QUANTIZATION_BITS}bit.npy"))
W_out        = np.load(os.path.join(weights_dir, f"readout_weights_Wout_{DATASET_TAG}_{QUANTIZATION_BITS}bit.npy"))
readout_bias = np.load(os.path.join(weights_dir, f"readout_bias_{DATASET_TAG}_{QUANTIZATION_BITS}bit.npy"))
int_bias     = np.load(os.path.join(weights_dir, f"quantized_bias_weights_{DATASET_TAG}_{QUANTIZATION_BITS}bit.npy"))

# Quantization scales (combined into integer clipping thresholds further below).
x_scale      = np.load(os.path.join(weights_dir, f"scale_x_{QUANTIZATION_BITS}bit.npy"))
scale_Win    = np.load(os.path.join(weights_dir, f"scale_Win_{QUANTIZATION_BITS}bit.npy"))
scale_Wr     = np.load(os.path.join(weights_dir, f"scale_Wr_{QUANTIZATION_BITS}bit.npy"))

# Input signal
u = np.load(os.path.join(DATASET_DIR, f"u_{DATASET_TAG}.npy"))

# Rebuild input splits for quantized runs
# NOTE(review): the test split starts one sample after the training split ends (+1) and
# drops the final sample (-1); confirm this matches how states_test / y_test were made.
# These float splits are not referenced later; the integer splits int_x_train /
# int_x_test (same offsets) are what the pruning experiments run on.
X_train = u[NARMA_ORDER:TRAIN_LEN + NARMA_ORDER]
X_test  = u[TRAIN_LEN + NARMA_ORDER + 1:-1]

# EC/PCA use post-warmup states
if WARMUP > 0:
    states_train_ec = states_train[WARMUP:]
    y_train_ec      = y_train[WARMUP:]
else:
    states_train_ec = states_train
    y_train_ec      = y_train

print(f"Loaded {DATASET_TAG} artifacts from '{DATASET_DIR}'")
print(f"  W_res: {W_res.shape}  |  states_train: {states_train.shape}  |  u: {u.shape}")


## Baseline performance with saved readout.

In [None]:
# Rebuild the saved Ridge readout by assigning its stored weights (no refitting).
readout_baseline = Ridge(ridge=RIDGE)
readout_baseline.output_dim = 1
readout_baseline.input_dim = N_NEURONS
readout_baseline.Wout = W_out
readout_baseline.bias = readout_bias
readout_baseline.state = {"out": np.zeros((1,))}
# Presumably tells reservoirpy the node is ready so run() uses the assigned Wout/bias
# instead of initializing new ones — depends on the Node API; confirm.
readout_baseline.initialized = True

# Ensure 2D target shape for metric helpers
if y_test.ndim == 1:
    y_test_eval = y_test.reshape(-1, 1)
else:
    y_test_eval = y_test

# Baseline = saved readout applied to the saved test states (no simulation, no pruning).
baseline_pred = readout_baseline.run(states_test)
if baseline_pred.ndim == 1:
    baseline_pred = baseline_pred.reshape(-1, 1)

baseline_rmse = rmse(y_test_eval, baseline_pred)
baseline_nrmse = nrmse(y_test_eval, baseline_pred)
baseline_r2 = rsquare(y_test_eval, baseline_pred)

print("Baseline metrics")
print(f"RMSE: {baseline_rmse:.6f}")
print(f"NRMSE: {baseline_nrmse:.6f}")
print(f"R2: {baseline_r2:.6f}")


## Quantization helpers and quantized reservoir node.

In [None]:
def setup_quant_params(num_bits):
    """Return the constants describing a signed ``num_bits`` integer grid.

    Keys:
        bits            -- ``num_bits`` as given.
        min_val/max_val -- two's-complement range, -(2**(b-1)) .. 2**(b-1) - 1.
        threshold_scale -- 1 / 2**b, used to rescale integer states to real values.
        div_scale       -- 2**b, divisor applied after the integer activation.
    """
    full_range = 2 ** num_bits
    half_range = 2 ** (num_bits - 1)
    return dict(
        bits=num_bits,
        min_val=-half_range,
        max_val=half_range - 1,
        threshold_scale=1.0 / full_range,
        div_scale=full_range,
    )

def extract_Qinput(array, num_bits):
    """Quantize ``array`` with a fresh brevitas ``QuantIdentity`` of ``num_bits`` bits.

    Returns a tuple of NumPy arrays: (integer representation, scale, zero point).
    A new quantizer is built on every call, so no calibration state is shared.
    """
    quantizer = QuantIdentity(return_quant_tensor=True, bit_width=num_bits)
    quant_tensor = quantizer(torch.tensor(array, dtype=torch.float32))
    int_values = quant_tensor.int().detach().numpy()
    scale = quant_tensor.scale.detach().numpy()
    zero_point = quant_tensor.zero_point.detach().numpy()
    return int_values, scale, zero_point

def compute_integer_thresholds(scale):
    """Return int32 (lo, hi) bounds corresponding to -1.0 and +1.0 at ``scale``.

    Values are truncated toward zero, so the bounds are symmetric (lo == -hi).
    ``scale`` may be a scalar or an array; arrays yield element-wise bounds.
    """
    reciprocal = 1 / scale
    return np.int32(-reciprocal), np.int32(reciprocal)

def piecewise_linear_hard_tanh_integer(x, lo, hi, div_scale):
    """Integer hard-tanh-style activation.

    Clips ``x`` to [lo, hi], shifts it up by ``hi`` (so the result starts at 0 when
    lo == -hi), divides by ``div_scale`` and truncates to int32. The output range is
    therefore [0, 2*hi / div_scale] for symmetric bounds, not centred on zero.
    """
    shifted = np.clip(x, lo, hi) + hi
    return (shifted / div_scale).astype(np.int32)

qp = setup_quant_params(QUANTIZATION_BITS)

# Quantize the whole input signal once, then split it with the same offsets as X_train / X_test.
int_x, _, _ = extract_Qinput(u, QUANTIZATION_BITS)
int_x_train = int_x[NARMA_ORDER:TRAIN_LEN + NARMA_ORDER]
int_x_test = int_x[TRAIN_LEN + NARMA_ORDER + 1:-1]

# Combined scale factors for the input and recurrent integer products.
input_scale = scale_Win * x_scale
reservoir_scale = scale_Wr * qp['threshold_scale']

# Integer clipping bounds equivalent to ±1.0 after multiplying by the scales above;
# read as globals by QuantizedReservoirNode._step.
inp_lo, inp_hi = compute_integer_thresholds(input_scale)
res_lo, res_hi = compute_integer_thresholds(reservoir_scale)

class QuantizedReservoirNode(Node):
    """Reservoir node that runs the quantized forward pass in integer arithmetic.

    Each step maps the previous output ``s`` and the current input ``x`` to

        act(x @ Win.T, inp_lo, inp_hi) + act(s @ Wr, res_lo, res_hi) + Bias

    where ``act`` is ``piecewise_linear_hard_tanh_integer`` with ``qp['div_scale']``
    (``inp_lo``, ``inp_hi``, ``res_lo``, ``res_hi`` and ``qp`` are module globals).

    NOTE(review): the recurrent term is ``s @ Wr``, so Wr[i, j] carries neuron i's output
    into neuron j (row = source, column = destination), while the input term uses
    ``Win.T``. Confirm this matches the convention the weights were exported with.

    Parameters
    ----------
    Wr_matrix : recurrent weights, (N_NEURONS, N_NEURONS); cast to int32 on each step.
    Win_matrix : input weights; ``x @ Win.T`` implies shape (N_NEURONS, input_dim).
    bias_array : bias, stored flattened (must contain N_NEURONS values).
    name : optional node name.
    """

    def __init__(self, Wr_matrix, Win_matrix, bias_array, name=None):
        # Attributes are assigned directly; the reservoirpy Node base __init__ is not called.
        self.name = name
        self.output_dim = N_NEURONS
        self.input_dim = None          # set by initialize() from the first input seen
        self.initialized = False
        self.Wr = Wr_matrix
        self.Win = Win_matrix
        self.Bias = bias_array.flatten()

    def initialize(self, x, y=None):
        """Record the input width and reset the state to an all-zero int64 vector."""
        self.input_dim = x.shape[-1]
        self.state = {"out": np.zeros((N_NEURONS,), dtype=np.int64)}
        self.initialized = True

    def _step(self, state, x):
        """Advance one timestep and return the new state dict ``{"out": (N_NEURONS,) array}``."""
        s = state["out"].astype(np.int64).reshape(1, N_NEURONS)
        recurrent = s @ self.Wr.astype(np.int32)
        inp = x.reshape(1, -1) @ self.Win.astype(np.int32).T
        out_inp = piecewise_linear_hard_tanh_integer(inp, inp_lo, inp_hi, qp['div_scale'])
        out_rec = piecewise_linear_hard_tanh_integer(recurrent, res_lo, res_hi, qp['div_scale'])
        next_out = (out_inp + out_rec + self.Bias.reshape(1, N_NEURONS)).flatten()
        return {"out": next_out}

    def get_param(self, name):
        """Return parameter ``name`` ("Wr", "Win" or "Bias"); raise KeyError otherwise."""
        if name not in ("Wr", "Win", "Bias"):
            raise KeyError(name)
        return getattr(self, name)

    def set_param(self, name, value):
        """Replace parameter ``name`` ("Wr", "Win" or "Bias"); raise KeyError otherwise.

        Bias is flattened, matching how ``__init__`` stores it, so ``get_param("Bias")``
        always returns a 1-D array regardless of how the bias was supplied.
        """
        if name not in ("Wr", "Win", "Bias"):
            raise KeyError(name)
        if name == "Bias":
            value = np.asarray(value).flatten()
        setattr(self, name, value)

# Unpruned reference node. NOTE(review): `quant_node` is not referenced elsewhere in this notebook.
quant_node = QuantizedReservoirNode(W_res, W_in, int_bias, name="quant_reservoir")


## Shared metrics helpers.

In [None]:
def calculate_all_metrics_multi_dim(y_true, y_pred):
    """Compute regression metrics for each output dimension and average them.

    Parameters
    ----------
    y_true, y_pred : array-like, shape (n_samples, n_outputs) or (n_samples,)
        Targets and predictions. A 1-D array is treated as a single output column.

    Returns
    -------
    dict
        'mse', 'mae', 'rmse', 'nrmse', 'r2' -- mean of that metric over output dimensions
        (MSE/MAE from scikit-learn, RMSE/NRMSE/R² from reservoirpy.observables);
        'dim_metrics' -- dict mapping each of those names to its per-dimension list.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Accept flat vectors instead of failing on `.shape[1]`.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    dim_metrics = {'mse': [], 'mae': [], 'rmse': [], 'nrmse': [], 'r2': []}
    for dim in range(y_true.shape[1]):
        y_true_dim = y_true[:, dim]
        y_pred_dim = y_pred[:, dim]
        dim_metrics['mse'].append(mean_squared_error(y_true_dim, y_pred_dim))
        dim_metrics['mae'].append(mean_absolute_error(y_true_dim, y_pred_dim))
        dim_metrics['rmse'].append(rmse(y_true_dim, y_pred_dim))
        dim_metrics['nrmse'].append(nrmse(y_true_dim, y_pred_dim))
        dim_metrics['r2'].append(rsquare(y_true_dim, y_pred_dim))

    metrics = {name: np.mean(values) for name, values in dim_metrics.items()}
    metrics['dim_metrics'] = dim_metrics
    return metrics


## Quantized baseline metrics for fixed readout.

In [None]:
# Unpruned, integer-simulated reservoir scored with the saved (not retrained) readout.
# Unlike the baseline above, which uses the saved `states_test`, these test states are
# produced by running QuantizedReservoirNode on the integer test inputs.
quantized_node_fixed = QuantizedReservoirNode(W_res, W_in, int_bias, name="quant_reservoir_fixed")
quantized_readout = copy.deepcopy(readout_baseline)

Quantized_States_test = quantized_node_fixed.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
y_pred_quantized = quantized_readout.run(Quantized_States_test)
if y_pred_quantized.ndim == 1:
    y_pred_quantized = y_pred_quantized.reshape(-1, 1)

quantized_metrics = calculate_all_metrics_multi_dim(y_test_eval, y_pred_quantized)

# Report the simulated baseline so it can be compared with the saved-state baseline.
print("Quantized (simulated) baseline metrics")
print(f"RMSE: {quantized_metrics['rmse']:.6f}")
print(f"NRMSE: {quantized_metrics['nrmse']:.6f}")
print(f"R2: {quantized_metrics['r2']:.6f}")



## EC-Var Score Computation

The lagged correlation matrix is computed **inside** `compute_ec_var_scores()` — it is private to the EC method and not exposed as a loose variable.

**EC-Var formula:** `Score[i,j] = |W_res[i,j]| × |SpearmanCorr(xᵢ(t), xⱼ(t-1))| × std(xᵢ(t))`
- `|W_res[i,j]|` → connection is physically strong
- `|SpearmanCorr(xᵢ(t), xⱼ(t-1))|` → information flows through it over time (lagged rank correlation: a temporal association, not proof of causation)
- `std(xᵢ(t))` → destination neuron i is expressive and active

In [None]:
W_res_float = W_res.astype(np.float64)   # float copy of W_res; used for the LASSO scores below

def compute_ec_var_scores(states, W_res, N_NEURONS):
    """
    EC-Var scoring function.

    Score[i, j] = |W[i, j]| × |SpearmanCorr(x_i(t), x_j(t-1))| × std(x_i(t))

    Parameters
    ----------
    states : array, shape (T, N_NEURONS)
        Reservoir state trajectory indexed as states[t, neuron].
    W_res : array, shape (N_NEURONS, N_NEURONS)
        Recurrent weights whose connections are scored.
    N_NEURONS : int
        Number of neurons (size of the returned square matrix).

    Returns
    -------
    ndarray, shape (N_NEURONS, N_NEURONS), float64, non-negative.

    All N² lagged Spearman correlations are computed at once. Each neuron's trace is
    ranked separately within the t window (states[1:]) and the t-1 window
    (states[:-1]), with average ranks for ties, exactly as ``spearmanr`` does per pair.
    The Pearson correlation of those ranks is then taken. This replaces N² separate
    ``spearmanr`` calls. Pairs involving a constant trace (undefined correlation)
    get correlation 0, as ``nan_to_num`` did before.

    NOTE(review): QuantizedReservoirNode computes ``state @ Wr``, so Wr[i, j] feeds
    neuron i's output into neuron j (row = source). This score treats row i as the
    destination (it correlates x_i(t) with x_j(t-1) and weights by std of x_i).
    Confirm which orientation is intended before interpreting the results.
    """
    states = np.asarray(states, dtype=np.float64)

    # Rank each neuron's trace within the "now" (t) and "previous" (t-1) windows.
    ranks_now = rankdata(states[1:, :N_NEURONS], axis=0)
    ranks_prev = rankdata(states[:-1, :N_NEURONS], axis=0)
    ranks_now -= ranks_now.mean(axis=0)
    ranks_prev -= ranks_prev.mean(axis=0)

    # Pearson correlation of centred ranks == Spearman correlation.
    # lagged_correlation[i, j] = corr(rank x_i(t), rank x_j(t-1)).
    numer = ranks_now.T @ ranks_prev
    denom = np.outer(np.sqrt(np.sum(ranks_now ** 2, axis=0)),
                     np.sqrt(np.sum(ranks_prev ** 2, axis=0)))
    with np.errstate(divide="ignore", invalid="ignore"):
        lagged_correlation = np.where(denom > 0, numer / denom, 0.0)
    lagged_correlation = np.clip(lagged_correlation, -1.0, 1.0)

    # Destination neuron expressiveness. A zero-std neuron has a constant trace, so its
    # correlations are already 0; replacing its std by 1.0 only avoids a zero factor.
    dest_std      = states.std(axis=0)
    dest_std_safe = np.where(dest_std == 0, 1.0, dest_std)

    return (
        np.abs(W_res.astype(np.float64))
        * np.abs(lagged_correlation)
        * dest_std_safe.reshape(N_NEURONS, 1)
    )

# Score every recurrent connection once, on post-warmup training states; the EC-Var
# pruning loop reuses these scores at every sparsity level.
effective_connectivity = compute_ec_var_scores(states_train_ec, W_res, N_NEURONS)

# Summary statistics over active (non-zero) connections only.
print(f"EC-Var score statistics (active connections):")
print(f"  Mean:   {effective_connectivity[W_res != 0].mean():.4f}")
print(f"  Median: {np.median(effective_connectivity[W_res != 0]):.4f}")
print(f"  Std:    {effective_connectivity[W_res != 0].std():.4f}")


## Method 1 — EC-Var (Effective Connectivity with Variance)

**Score:** `|W[i,j]| × |SpearmanCorr(xᵢ(t), xⱼ(t-1))| × std(xᵢ(t))`

Three signals combined into one pruning score:
- **Weight magnitude** — the connection is physically strong
- **Lagged rank correlation** — information actually flows through it over time (a temporal association, not proof of causation)
- **Destination std** — the receiving neuron is active and expressive, not silent

Connections that score low on any of these three dimensions are pruned first.
When `RETRAIN_READOUT` is `True`, the readout is retrained after each pruning step for a fair evaluation.

In [None]:
def prune_by_effective_connectivity(W_res, connectivity_scores, sparsity):
    """Keep the highest-scoring active connections of ``W_res``.

    Only non-zero entries are candidates. ``floor(n_active * (1 - sparsity))`` of them
    (clamped to [0, n_active]) survive, chosen by the largest ``connectivity_scores``.

    Returns
    -------
    (W_pruned, mask) : ``W_res * mask`` and the boolean keep-mask (same shape as W_res).
    """
    is_active = W_res != 0
    rows, cols = np.nonzero(is_active)
    total = rows.size

    keep = min(max(int(total * (1 - sparsity)), 0), total)

    keep_mask = np.zeros(W_res.shape, dtype=bool)
    if total and keep:
        # Scores of active entries, in the same row-major order as (rows, cols).
        order = np.argsort(connectivity_scores[is_active])
        chosen = order[len(order) - keep:]
        keep_mask[rows[chosen], cols[chosen]] = True

    return W_res * keep_mask, keep_mask

# Results of the EC-Var method (named "attention" after the notebook's framing).
attention_results = {
    'removal_percentage': [],
    'rmse_values': [],
    'r2_values': []
}

# For each sparsity level: prune by EC-Var score, simulate the pruned integer reservoir,
# optionally refit the readout on its training states, then evaluate on the test inputs.
for sparsity in SPARSITY_LEVELS:
    W_res_ec, _ = prune_by_effective_connectivity(W_res, effective_connectivity, sparsity)
    node = QuantizedReservoirNode(W_res_ec, W_in, int_bias, name="quant_reservoir_ec")

    if RETRAIN_READOUT:
        # NOTE(review): the same node runs train then test inputs; whether the test run
        # continues from the end-of-training state depends on reservoirpy's run() reset
        # behaviour — confirm.
        states_train_pruned = node.run(int_x_train.astype(np.float64)) * qp['threshold_scale']
        readout_use = Ridge(ridge=RIDGE)
        readout_use.fit(states_train_pruned, y_train, warmup=WARMUP)
    else:
        readout_use = readout_baseline

    states_test_pruned = node.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
    y_pred = readout_use.run(states_test_pruned)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    metrics = calculate_all_metrics_multi_dim(y_test_eval, y_pred)
    attention_results['removal_percentage'].append(int(sparsity * 100))
    attention_results['rmse_values'].append(metrics['rmse'])
    attention_results['r2_values'].append(metrics['r2'])


## Method 2 — Random Pruning (Lower-Bound Baseline)

**Score:** None — connections are removed uniformly at random.

Every non-zero weight has equal probability of being pruned, with no knowledge of its importance. Results are averaged over 5 independent trials to reduce variance.

This serves as the **lower-bound baseline**: an informed pruning score should beat it.

In [None]:
def random_weight_removal_analysis_quantized(W_quantized, custom_node, X_train, y_train,
                                             X_test, y_test, threshold_scale,
                                             removal_percentages, num_trials=5, random_seed=seed):
    """Lower-bound baseline: prune uniformly random active recurrent weights.

    For each percentage in ``removal_percentages``, repeat ``num_trials`` times:
    - zero out a random sample of that percentage of the non-zero entries of ``W_quantized``;
    - run a deep copy of ``custom_node`` with the pruned matrix on ``X_train``;
    - fit a fresh Ridge readout (ridge=RIDGE, warmup=WARMUP);
    - evaluate on ``X_test`` against ``y_test``.

    The readout is always refit here, independent of RETRAIN_READOUT.

    Parameters
    ----------
    W_quantized : recurrent weight matrix; never modified (each trial prunes a copy).
    custom_node : node exposing run() and set_param("Wr", ...); deep-copied per trial.
    X_train, X_test : input sequences, cast to float64 before running the node.
    y_train : training targets passed to Ridge.fit.
    y_test : test targets (the caller passes the 2-D ``y_test_eval``).
    threshold_scale : multiplier applied to the node's integer states.
    removal_percentages : percentages (0-100) of active weights to remove.
    num_trials : number of random masks averaged per percentage.
    random_seed : seeds both the global ``numpy.random`` state and ``random``.
        Sampling itself uses ``random.sample``. The default is bound to the global
        ``seed`` when this cell runs.

    Returns
    -------
    dict of lists aligned with ``removal_percentages``:
    - 'removal_percentage' and 'weights_removed';
    - '<metric>_values' (mean over finite trial values) and '<metric>_std', for metric
      in mse, mae, rmse, nrmse, r2.

    If every trial of a metric is non-finite, its mean is +inf (-inf for r2) and its
    std is 0.0.
    """
    np.random.seed(random_seed)
    random.seed(random_seed)

    # Positions of all active weights (row-major); each trial samples from this list.
    non_zero_mask = W_quantized != 0
    non_zero_positions = list(zip(*np.where(non_zero_mask)))
    total_nonzero_weights = len(non_zero_positions)

    random_results = {
        'removal_percentage': [],
        'weights_removed': [],
        'mse_values': [],
        'mae_values': [],
        'rmse_values': [],
        'nrmse_values': [],
        'r2_values': [],
        'mse_std': [],
        'mae_std': [],
        'rmse_std': [],
        'nrmse_std': [],
        'r2_std': []
    }

    for removal_pct in removal_percentages:
        num_weights_to_remove = int((removal_pct / 100) * total_nonzero_weights)

        trial_results = {'mse': [], 'mae': [], 'rmse': [], 'nrmse': [], 'r2': []}

        for trial in range(num_trials):
            # Independent random mask per trial, applied to a fresh copy of the full matrix.
            W_modified = W_quantized.copy()
            if num_weights_to_remove > 0:
                weights_to_remove = random.sample(non_zero_positions, num_weights_to_remove)
                for row, col in weights_to_remove:
                    W_modified[row, col] = 0

            # Deep copy so the caller's node (and its current state) is left untouched.
            modified_node = copy.deepcopy(custom_node)
            modified_node.set_param("Wr", W_modified)

            states_train_pruned = modified_node.run(X_train.astype(np.float64)) * threshold_scale
            retrained_readout = Ridge(ridge=RIDGE)
            retrained_readout.fit(states_train_pruned, y_train, warmup=WARMUP)

            quantized_states_test = modified_node.run(X_test.astype(np.float64)) * threshold_scale
            y_pred_modified = retrained_readout.run(quantized_states_test)
            if y_pred_modified.ndim == 1:
                y_pred_modified = y_pred_modified.reshape(-1, 1)
            modified_metrics = calculate_all_metrics_multi_dim(y_test, y_pred_modified)

            for metric in trial_results.keys():
                trial_results[metric].append(modified_metrics[metric])

        random_results['removal_percentage'].append(removal_pct)
        random_results['weights_removed'].append(num_weights_to_remove)

        # Aggregate over trials, ignoring inf/NaN results so one diverged run doesn't dominate.
        for metric in trial_results.keys():
            finite_values = [val for val in trial_results[metric] if not (math.isinf(val) or math.isnan(val))]
            if finite_values:
                avg_val = np.mean(finite_values)
                std_val = np.std(finite_values) if len(finite_values) > 1 else 0.0
            else:
                avg_val = float('inf') if metric != 'r2' else -float('inf')
                std_val = 0.0
            random_results[f'{metric}_values'].append(avg_val)
            random_results[f'{metric}_std'].append(std_val)

    return random_results

# Integer percentages matching SPARSITY_LEVELS (0.15 -> 15, ...).
removal_percentages = [int(s * 100) for s in SPARSITY_LEVELS]

# The unpruned `quantized_node_fixed` serves as the template node (deep-copied per trial).
random_results = random_weight_removal_analysis_quantized(
    W_res, quantized_node_fixed, int_x_train, y_train,
    int_x_test, y_test_eval, qp['threshold_scale'],
    removal_percentages=removal_percentages,
    num_trials=5, random_seed=seed
)


## Method 3 — PCA Pruning

**Score:** PCA loading magnitude of the destination neuron on the top principal component.

Fits PCA on the reservoir training states. Neurons with high loadings on the first principal component encode the most variance in the reservoir dynamics and are considered important. All connections *into* a neuron inherit that neuron's importance as their score.

**Unsupervised** — no task signal (y) is used, only the structure of the reservoir states.

In [None]:
pca_components = 1

# Unsupervised: PCA is fitted on post-warmup training states only (no targets).
pca = PCA(n_components=pca_components)
pca.fit(states_train_ec)

# neuron_importance[k] = sum over kept components of |loading of neuron k|
# (just |PC1 loading| when pca_components == 1).
neuron_importance = np.sum(np.abs(pca.components_), axis=0)

# weight_scores_pca[i, j] = neuron_importance[j]: every connection is scored by its
# column neuron. QuantizedReservoirNode computes `state @ Wr`, so column j is the
# neuron the connection feeds into.
weight_scores_pca = np.tile(neuron_importance.reshape(1, -1), (W_res.shape[0], 1))


def prune_by_pca_scores(W_res, weight_scores, sparsity):
    """Keep the highest-scoring active connections of ``W_res`` under PCA-derived scores.

    The selection rule is the same as ``prune_by_effective_connectivity``: of the
    non-zero entries, ``floor(n_active * (1 - sparsity))`` (clamped to [0, n_active])
    with the largest ``weight_scores`` are kept. This function delegates to it instead
    of duplicating the implementation, so every method shares one pruning rule.

    Returns
    -------
    (W_pruned, mask) : ``W_res * mask`` and the boolean keep-mask.
    """
    return prune_by_effective_connectivity(W_res, weight_scores, sparsity)

pca_results = {
    'removal_percentage': [],
    'rmse_values': [],
    'r2_values': []
}

# Same evaluation protocol as the EC-Var loop, with PCA-derived scores.
for sparsity in SPARSITY_LEVELS:
    W_res_pca, _ = prune_by_pca_scores(W_res, weight_scores_pca, sparsity)
    node = QuantizedReservoirNode(W_res_pca, W_in, int_bias, name="quant_reservoir_pca")

    if RETRAIN_READOUT:
        states_train_pruned = node.run(int_x_train.astype(np.float64)) * qp['threshold_scale']
        readout_use = Ridge(ridge=RIDGE)
        readout_use.fit(states_train_pruned, y_train, warmup=WARMUP)
    else:
        readout_use = readout_baseline

    states_test_pruned = node.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
    y_pred = readout_use.run(states_test_pruned)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    metrics = calculate_all_metrics_multi_dim(y_test_eval, y_pred)
    pca_results['removal_percentage'].append(int(sparsity * 100))
    pca_results['rmse_values'].append(metrics['rmse'])
    pca_results['r2_values'].append(metrics['r2'])



## Method 4 — LASSO Pruning

**Score:** `|W[i,j]| × |lasso_coeff[i]|`

Fits an L1-regularized (LASSO) regression from training reservoir states to the target signal. The L1 penalty shrinks unimportant neuron coefficients to exactly zero. Connections into neurons that LASSO zeroed out are pruned first; connections into task-relevant neurons are protected.

**Supervised** — directly uses the target signal (y) to identify important neurons.

In [None]:
# Fit LASSO on post-warmup training states to get per-neuron importance (supervised).
lasso_model = Lasso(alpha=LASSO_ALPHA, max_iter=10000, random_state=SEED)
lasso_model.fit(states_train_ec, y_train_ec.ravel())

# |coef| is non-negative; neurons LASSO zeroed out get score 0, so connections
# into them are pruned first.
lasso_coeff = np.abs(lasso_model.coef_)

n_nonzero = np.sum(lasso_coeff > 0)
print(f"LASSO (alpha={LASSO_ALPHA}): {n_nonzero}/{N_NEURONS} neurons have non-zero coefficients")
print(f"  coeff range: [{lasso_coeff.min():.6f}, {lasso_coeff.max():.6f}]")

# score[i,j] = |W_res[i,j]| × |lasso_coeff[i]|  (row i treated as the receiving neuron)
# NOTE(review): QuantizedReservoirNode computes `state @ Wr`, so row i is the *source*
# of Wr[i, j]; confirm which neuron's coefficient is meant to weight the connection.
lasso_scores = np.abs(W_res_float) * lasso_coeff.reshape(N_NEURONS, 1)

lasso_results = {
    'removal_percentage': [],
    'rmse_values': [],
    'r2_values': []
}

# Same evaluation protocol as the EC-Var / PCA loops, including RETRAIN_READOUT.
for sparsity in SPARSITY_LEVELS:
    W_res_lasso, _ = prune_by_effective_connectivity(W_res, lasso_scores, sparsity)
    node = QuantizedReservoirNode(W_res_lasso, W_in, int_bias, name="quant_reservoir_lasso")

    if RETRAIN_READOUT:
        states_train_pruned = node.run(int_x_train.astype(np.float64)) * qp['threshold_scale']
        readout_use = Ridge(ridge=RIDGE)
        readout_use.fit(states_train_pruned, y_train, warmup=WARMUP)
    else:
        readout_use = readout_baseline

    states_test_pruned = node.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
    y_pred = readout_use.run(states_test_pruned)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    metrics = calculate_all_metrics_multi_dim(y_test_eval, y_pred)
    lasso_results['removal_percentage'].append(int(sparsity * 100))
    lasso_results['rmse_values'].append(metrics['rmse'])
    lasso_results['r2_values'].append(metrics['r2'])

    print(f"  {int(sparsity*100):>3}% sparsity | R²: {metrics['r2']:.4f} | RMSE: {metrics['rmse']:.6f}")


## Method 5 — Spearman Rank Correlation 

**Score:** `|SpearmanCorr(xᵢ(t), xⱼ(t))|`

Measures how strongly two neurons co-activate at the same timestep using rank (non-linear, monotonic) correlation. Connections between weakly correlated neuron pairs are pruned first.

In [None]:
from scipy.stats import rankdata

# Spearman standalone: Score[i,j] = |SpearmanCorr(x_i(t), x_j(t))|
# Rank every neuron's trace in one call (column-wise, average ranks for ties).
# rankdata returns float64, so tied (x.5) ranks are kept exactly even if the saved
# states have an integer dtype (np.zeros_like(states) would have truncated them).
states_ranked = rankdata(states_train_ec, axis=0)

# Contemporaneous Spearman correlation matrix (Pearson correlation of the ranks).
# Constant neurons yield NaN rows/columns, mapped to 0.
spearman_contemp = np.corrcoef(states_ranked.T)
spearman_contemp = np.nan_to_num(spearman_contemp, nan=0.0)

# Score is absolute rank correlation at same timestep; the matrix is symmetric, so the
# row/column orientation of W_res does not affect this method.
spearman_scores = np.abs(spearman_contemp)

print("Spearman standalone (contemporaneous) — active connections:")
print(f"  Mean |ρ|: {spearman_scores[W_res != 0].mean():.4f}")
print(f"  Max |ρ|:  {spearman_scores[W_res != 0].max():.4f}")
print("  (no lag, no weight magnitude — pure static co-activation)")

spearman_results = {'removal_percentage': [], 'rmse_values': [], 'r2_values': []}

# Same evaluation protocol as the EC-Var / PCA loops, including RETRAIN_READOUT.
for sparsity in SPARSITY_LEVELS:
    W_res_sp, _ = prune_by_effective_connectivity(W_res, spearman_scores, sparsity)
    node = QuantizedReservoirNode(W_res_sp, W_in, int_bias, name="quant_reservoir_spearman")

    if RETRAIN_READOUT:
        states_train_pruned = node.run(int_x_train.astype(np.float64)) * qp['threshold_scale']
        readout_use = Ridge(ridge=RIDGE)
        readout_use.fit(states_train_pruned, y_train, warmup=WARMUP)
    else:
        readout_use = readout_baseline

    states_test_pruned = node.run(int_x_test.astype(np.float64)) * qp['threshold_scale']
    y_pred = readout_use.run(states_test_pruned)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    metrics = calculate_all_metrics_multi_dim(y_test_eval, y_pred)
    spearman_results['removal_percentage'].append(int(sparsity * 100))
    spearman_results['rmse_values'].append(metrics['rmse'])
    spearman_results['r2_values'].append(metrics['r2'])

    print(f"  {int(sparsity*100):>3}% sparsity | R²: {metrics['r2']:.4f} | RMSE: {metrics['rmse']:.6f}")


## Method 6 — Mutual Information (MI) Pruning

**Score:** `MI(xᵢ(t), xⱼ(t))` — joint histogram-based mutual information between neuron pairs.

Measures the shared information between source and destination neuron activations using discretized joint histograms. Connections with low MI (neurons that share little predictable signal) are pruned first.

**Note:** Scoring every connection is expensive, so only the first `MAX_WEIGHTS_TO_TEST` active connections (in row-major order of `W_res`, i.e. not a random sample) are scored; the rest receive the mean MI score as a neutral default.  
**Unsupervised** — no task signal (y) is used.

In [None]:
def collect_reservoir_activations_quantized(custom_node, input_data, W_matrix, threshold_scale, mi_bins, mi_max_weights):
    """Run ``custom_node`` once and gather state traces for a subset of active connections.

    Parameters
    ----------
    custom_node : node whose run() returns the state trajectory, indexed [t, neuron].
    input_data : input sequence (array), cast to float64 before running.
    W_matrix : recurrent weights; its non-zero entries are the candidate connections.
    threshold_scale : multiplier applied to the node's output states.
    mi_bins : unused here; accepted for interface compatibility.
    mi_max_weights : maximum number of connections to collect.

    Returns
    -------
    (paired_activation_data, non_zero_indices)
        non_zero_indices -- list of (i, j) tuples: the FIRST ``mi_max_weights`` non-zero
            positions of ``W_matrix`` in row-major order. All collected connections
            therefore come from the lowest-numbered rows; they are not a random sample.
        paired_activation_data -- dict mapping (i, j) to
            (states[:, i] as a list, states[:, j] as a list).
    """
    scaled_states = custom_node.run(input_data.astype(np.float64)) * threshold_scale

    candidate_pairs = list(zip(*np.where(W_matrix != 0.0)))
    if len(candidate_pairs) > mi_max_weights:
        candidate_pairs = candidate_pairs[:mi_max_weights]

    activation_pairs = {
        (row, col): (scaled_states[:, row].tolist(), scaled_states[:, col].tolist())
        for row, col in candidate_pairs
    }
    return activation_pairs, candidate_pairs


def calculate_mutual_information_saliencies(paired_activation_data, non_zero_indices, n_bins):
    """Score each (i, j) connection by the mutual information of its two state traces.

    For every pair in ``non_zero_indices``, the traces stored in
    ``paired_activation_data[(i, j)]`` are binned into an ``n_bins`` × ``n_bins`` joint
    histogram. That histogram is passed to ``mutual_info_score`` as a contingency table.

    Returns
    -------
    (saliencies, indices) : parallel lists of MI values and (i, j) tuples, in input order.
    """
    saliencies = []
    scored_pairs = []
    for src, dst in non_zero_indices:
        traces = paired_activation_data[(src, dst)]
        joint_counts, _, _ = np.histogram2d(np.array(traces[0]), np.array(traces[1]), bins=n_bins)
        saliencies.append(mutual_info_score(None, None, contingency=joint_counts))
        scored_pairs.append((src, dst))
    return saliencies, scored_pairs


def mi_weight_removal_analysis_quantized(W_quantized, custom_node, X_train, y_train,
                                         X_test, y_test, removal_percentages,
                                         threshold_scale, mi_bins, quantization_bits):
    """Prune active recurrent weights in order of increasing mutual-information score.

    Scoring: ``custom_node`` is run on ``X_train`` once. The first MAX_WEIGHTS_TO_TEST
    active connections (row-major) get an MI score; every other active connection
    gets the mean of those scores. For each percentage, the lowest-scoring fraction
    of active weights is zeroed, and a deep copy of ``custom_node`` is run with the
    pruned matrix. A fresh Ridge readout (ridge=RIDGE, warmup=WARMUP) is always
    refitted on the training run, and the model is then evaluated on ``X_test``.

    ``quantization_bits`` is unused. ``custom_node`` is deep-copied *after* the
    scoring run, so each copy starts from the state that run left behind.

    Returns
    -------
    dict of lists aligned with ``removal_percentages``:
    - 'removal_percentage' and 'weights_removed';
    - '<metric>_values' for mse, mae, rmse, nrmse, r2.
    """
    active_rows, active_cols = np.where(W_quantized != 0)
    all_active_indices = list(zip(active_rows, active_cols))
    total_active = len(all_active_indices)

    paired_activation_data, scored_indices = collect_reservoir_activations_quantized(
        custom_node, X_train, W_quantized, threshold_scale, mi_bins, MAX_WEIGHTS_TO_TEST
    )
    saliencies, saliency_indices = calculate_mutual_information_saliencies(
        paired_activation_data, scored_indices, mi_bins
    )

    # Unscored connections receive the mean MI as a neutral score.
    fallback = float(np.mean(saliencies)) if saliencies else 0.0
    lookup = dict(zip(saliency_indices, saliencies))
    all_scores = np.array([lookup.get(pair, fallback) for pair in all_active_indices])
    sorted_order = np.argsort(all_scores)

    metric_names = ('mse', 'mae', 'rmse', 'nrmse', 'r2')
    mi_results = {'removal_percentage': [], 'weights_removed': []}
    mi_results.update({f'{name}_values': [] for name in metric_names})

    for removal_pct in removal_percentages:
        n_remove = int((removal_pct / 100) * total_active)
        W_modified = W_quantized.copy()
        if n_remove > 0:
            # Zero the n_remove lowest-scoring active weights in one vectorised assignment.
            doomed = sorted_order[:n_remove]
            W_modified[active_rows[doomed], active_cols[doomed]] = 0

        pruned_node = copy.deepcopy(custom_node)
        pruned_node.set_param("Wr", W_modified)

        train_states = pruned_node.run(X_train.astype(np.float64)) * threshold_scale
        readout = Ridge(ridge=RIDGE)
        readout.fit(train_states, y_train, warmup=WARMUP)

        test_states = pruned_node.run(X_test.astype(np.float64)) * threshold_scale
        y_hat = readout.run(test_states)
        if y_hat.ndim == 1:
            y_hat = y_hat.reshape(-1, 1)
        scores = calculate_all_metrics_multi_dim(y_test, y_hat)

        mi_results['removal_percentage'].append(removal_pct)
        mi_results['weights_removed'].append(n_remove)
        for name in metric_names:
            mi_results[f'{name}_values'].append(scores[name])

    return mi_results

mi_results = None   # stays None when RUN_MI is False; the summary cell then omits MI

if RUN_MI:
    # Fresh unpruned node; it is run once on the training inputs to collect MI traces.
    node = QuantizedReservoirNode(W_res, W_in, int_bias, name="quant_reservoir_mi")
    removal_percentages = [int(s * 100) for s in SPARSITY_LEVELS]

    mi_results = mi_weight_removal_analysis_quantized(
        W_res, node, int_x_train, y_train,
        int_x_test, y_test_eval, removal_percentages,
        qp['threshold_scale'], MI_BINS, QUANTIZATION_BITS
    )


## **Comparison Plots**

In [None]:
# ============================================================
# PRINT SUMMARY TABLE
# ============================================================
# Methods shown in the table and plots. List order also breaks ties in the win count.
smart_methods = [("EC-Var", attention_results), ("PCA", pca_results), ("Random", random_results)]
if mi_results is not None:
    smart_methods.append(("Mutual Info", mi_results))
smart_methods.append(("LASSO",    lasso_results))
smart_methods.append(("Spearman", spearman_results))

# Rows are matched by position: every results dict is assumed to follow SPARSITY_LEVELS order.
all_pcts = attention_results['removal_percentage']

print(f"\n{'='*100}")
print(f"  RESULTS — {DATASET_NAME}, {QUANTIZATION_BITS}-bit, N={N_NEURONS}")
print(f"  Baseline R²: {baseline_r2:.4f}  |  Baseline RMSE: {baseline_rmse:.6f}")
print(f"{'='*100}")

col_w = 15
header = f"  {'Sparsity':>8} | " + " | ".join(f"{n:^{col_w}}" for n, _ in smart_methods)
print(header)
print("  " + "-"*8 + "-+-" + ("-+-".join(["-"*col_w] * len(smart_methods))))

for i, pct in enumerate(all_pcts):
    # A method with fewer entries than all_pcts counts as -inf at the missing rows.
    # NOTE(review): a NaN R² makes max() order-dependent; consider filtering non-finite values.
    r2_row = [res['r2_values'][i] if i < len(res['r2_values']) else float('-inf') for _, res in smart_methods]
    best_r2 = max(r2_row)
    cells = []
    for j, (name, _) in enumerate(smart_methods):
        # ★ marks every method tying for the best R² (never on the unpruned 0% row).
        star = " ★" if (r2_row[j] == best_r2 and pct != 0) else "  "
        cells.append(f"R²={r2_row[j]:>6.4f}{star}")
    print(f"  {pct:>7}%  | " + " | ".join(cells))

print(f"\n  (★ = winner at that sparsity level)")
print(f"{'='*100}\n")

# Win count per method over the pruned levels. A strictly greater R² replaces the
# current best, so ties go to the method listed first in smart_methods.
win_counts = {name: 0 for name, _ in smart_methods}
n_pruned = 0
for i, pct in enumerate(all_pcts):
    if pct == 0:
        continue
    n_pruned += 1
    best_val, best_name = None, None
    for name, res in smart_methods:
        val = res['r2_values'][i] if i < len(res['r2_values']) else float('-inf')
        if best_val is None or val > best_val:
            best_val, best_name = val, name
    if best_name:
        win_counts[best_name] += 1

print("  Win counts (best R² at each pruned sparsity level):")
for name, cnt in win_counts.items():
    bar = '█' * cnt + '░' * (n_pruned - cnt)
    print(f"    {name:>14}: {cnt}/{n_pruned}  {bar}")
print()

# ============================================================
# FIGURE: 1×2  —  R²  and  RMSE
# ============================================================
# name -> (color, matplotlib format string, line width)
style_map = {
    "EC-Var":      ("tab:blue",   "o-",  3.0),
    "PCA":         ("tab:orange", "s-.", 2.5),
    "Mutual Info": ("tab:green",  "^:",  2.0),
    "Random":      ("tab:red",    "x--", 1.5),
    "LASSO":       ("tab:purple", "d-.", 2.0),
    "Spearman":    ("tab:cyan",   "v:",  2.0),
}

# Largest sparsity level as an integer percentage; bounds the plotted x range.
MAX_PCT = int(max(SPARSITY_LEVELS) * 100)

def filter_to_max(res, max_pct):
    """Return (percentages, r2_values, rmse_values) restricted to entries with pct <= max_pct.

    ``res`` is a results dict with parallel 'removal_percentage', 'r2_values' and
    'rmse_values' lists. Entries keep their original order, and each is returned as a list.
    """
    kept = [i for i, pct in enumerate(res['removal_percentage']) if pct <= max_pct]
    return (
        [res['removal_percentage'][i] for i in kept],
        [res['r2_values'][i] for i in kept],
        [res['rmse_values'][i] for i in kept],
    )

# Two panels side by side: R² (left) and RMSE (right) versus sparsity, one line per method.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle(
    f"Pruning Method Comparison  —  {DATASET_NAME}, {QUANTIZATION_BITS}-bit, N={N_NEURONS}"
    f"   |   Baseline R² = {baseline_r2:.4f}",
    fontsize=14, fontweight='bold'
)

# R²
ax = axes[0]
for name, res in smart_methods:
    # Methods missing from style_map fall back to a thin gray solid line.
    c, sty, lw = style_map.get(name, ("gray", "-", 1.5))
    x, r2, _ = filter_to_max(res, MAX_PCT)
    # Negative R² is drawn as 0 (matching ylim below), so the plot hides how far below
    # zero a method falls; see the printed table for exact values.
    r2_clipped = [max(v, 0.0) for v in r2]
    ax.plot(x, r2_clipped, sty, label=name, color=c, linewidth=lw, markersize=9)
ax.axhline(baseline_r2, color='black', linestyle=':', linewidth=1.5,
           label=f'Baseline ({baseline_r2:.3f})')
ax.set_xlim(-2, MAX_PCT + 2)
ax.set_ylim(0, 1.05)
ax.set_xlabel('Sparsity (%)', fontsize=13)
ax.set_ylabel('R²', fontsize=13)
ax.set_title('R²  vs  Sparsity', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# RMSE
ax = axes[1]
for name, res in smart_methods:
    c, sty, lw = style_map.get(name, ("gray", "-", 1.5))
    x, _, rmse_vals = filter_to_max(res, MAX_PCT)
    ax.plot(x, rmse_vals, sty, label=name, color=c, linewidth=lw, markersize=9)
ax.axhline(baseline_rmse, color='black', linestyle=':', linewidth=1.5,
           label=f'Baseline ({baseline_rmse:.4f})')
ax.set_xlim(-2, MAX_PCT + 2)
ax.set_xlabel('Sparsity (%)', fontsize=13)
ax.set_ylabel('RMSE', fontsize=13)
ax.set_title('RMSE  vs  Sparsity', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()

# Save to plots/ folder (before plt.show(), which may clear the figure in some backends)
plots_dir = "plots"
os.makedirs(plots_dir, exist_ok=True)
plot_filename = os.path.join(plots_dir, f"comparison_{DATASET_TAG}_{QUANTIZATION_BITS}bit.png")
plt.savefig(plot_filename, dpi=200, bbox_inches='tight')
plt.show()
print(f"Figure saved → {plot_filename}")
