# Network sensitivity analysis: single-neuron selectivity

Sensitivity of single-neuron selectivity to key network parameters.

This notebook examines how sweeping key hyperparameters (learning rate, inhibitory competition strength, maximum synaptic delay) affects the information carried by single neurons.

**Dependencies:**

---

A) Inference recordings:
- Inference recordings that already exist on disk are skipped (not recomputed)
```bash
for exp in competition_121125 delays_121125 lrate_111125; do
    ./scripts/run_main_workflow.py "experiments/n3p2/train_n3p2_sweep_${exp}" {0..1} --rule inference --chkpt -1 -v;
done
```

B) Calculate metrics:

```bash
for exp in competition_121125 delays_121125 lrate_111125; do
    ./scripts/figures/compute_information_sweep.py "experiments/n3p2/train_n3p2_sweep_${exp}" --baseline n3p2/train_n3p2_lrate_0_04_181023 --baseline_model ALL
done
```

---

**Plots:**

- Supplementary S9A Fig
- Supplementary S9B Fig

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from hsnn import viz
from hsnn.utils import io

# Precomputed robustness results (each entry is treated as one swept hyperparameter)
RESULTS_DIR = io.BASE_DIR / "out/figures/supplementary/robustness/n3p2"

# Destination for this notebook's figures and tables; created if missing
OUTPUT_DIR = io.BASE_DIR / "out/figures/supplementary/fig_S9A_B"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
print(f"Results will be saved to: {OUTPUT_DIR}")

# Configure plotting defaults through the project's journal-environment helper
viz.setup_journal_env()

In [None]:
# === PLOTTING CONFIGURATION ===

# LaTeX labels shown for each swept hyperparameter (axes and legend titles)
HPARAM_DISPLAY = dict(
    lrate=r"$\rho$",
    I2E=r"$\lambda^{I \to E}$ [nS]",
    delay_max=r"$d_{\max}$ [ms]",
)

# Figure geometry in inches; manuscript figures are at most 7.5" wide
FIG_WIDTH_MAX = 7.5
FIG_HEIGHT_DEFAULT = 2

# Bar colour used for the trained ("post") network
POST_COLOR = "C0"

# A neuron counts as "informative" when its information is >= this value
INFO_THRESHOLD = 2 / 3
# Upper bound on information for a binary stimulus: log2(2) = 1 bit
MAX_INFO = 1.0

# Total number of L4 neurons, used to turn counts into percentages
N_NEURONS_L4 = 4096


def get_hparam_display_name(hparam: str) -> str:
    """Return the pretty (LaTeX) label for ``hparam``.

    Falls back to the raw hyperparameter name when no label is registered
    in ``HPARAM_DISPLAY``.
    """
    try:
        return HPARAM_DISPLAY[hparam]
    except KeyError:
        return hparam


def compute_informative_counts(
    measures_arr: np.ndarray,
    threshold: float = INFO_THRESHOLD,
    *,
    ddof: int = 1,
) -> tuple[float, float]:
    """Compute mean and SEM of informative-neuron counts across trials.

    A neuron is counted as informative in a trial when its information value
    is ``>= threshold``.

    Args:
        measures_arr: Information values with trials along axis 0 and neurons
            along axis 1. A 1-D array is treated as a single trial.
        threshold: Minimum information for a neuron to count as informative.
        ddof: Delta degrees of freedom for the standard deviation. The default
            (1) gives the conventional sample-based standard error of the mean;
            ``ddof=0`` reproduces the older population-std estimate.

    Returns:
        ``(mean, sem)`` of the per-trial informative counts. ``sem`` is 0.0
        when there are too few trials to estimate it (``n_trials <= ddof``).
        Both values are NaN if ``measures_arr`` contains no trials.
    """
    measures = np.atleast_2d(np.asarray(measures_arr))
    counts = np.sum(measures >= threshold, axis=1)
    n_trials = counts.size
    if n_trials == 0:
        return float("nan"), float("nan")

    mean = float(counts.mean())
    if n_trials <= ddof:
        # Spread is undefined for this many trials; report no error bar
        # rather than NaN, which would break bar/annotation placement.
        return mean, 0.0
    sem = float(counts.std(ddof=ddof) / np.sqrt(n_trials))
    return mean, sem


def compute_ranked_stats(
    measures_arr: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Return the per-rank mean and (population) std across trials.

    Statistics are taken along axis 0, so each output element summarises one
    column of ``measures_arr``. Columns are presumably already ordered by
    rank by the upstream computation (TODO confirm); no sorting happens here.
    """
    per_rank_mean = np.mean(measures_arr, axis=0)
    per_rank_std = np.std(measures_arr, axis=0)
    return per_rank_mean, per_rank_std

In [None]:
# === LOAD RESULTS ===

# Discover available hyperparameter sweeps: one sub-directory per swept
# hyperparameter. Stray files in RESULTS_DIR are ignored, and a missing
# RESULTS_DIR is reported instead of silently producing an empty sweep set.
if not RESULTS_DIR.is_dir():
    print(f"Results directory not found: {RESULTS_DIR}")
measures_paths = sorted(p for p in RESULTS_DIR.glob("*") if p.is_dir())
print(f"Found hyperparameter directories: {[p.name for p in measures_paths]}")

# Load all sweep results, keyed by hyperparameter (directory) name.
# NOTE(review): io.load_pickle presumably unpickles the file; only load results
# produced by this project's own pipeline, never untrusted files.
sweep_results: dict[str, dict] = {}
for fpath in measures_paths:
    hparam_name = fpath.name
    results_file = fpath / "information_sweep_max_convex.pkl"
    if results_file.exists():
        sweep_results[hparam_name] = io.load_pickle(results_file)
        hparam_values = sorted(sweep_results[hparam_name].keys())
        print(f"  {hparam_name}: {len(hparam_values)} values - {hparam_values}")
    else:
        print(f"  {hparam_name}: results file not found")

n_hparams = len(sweep_results)
print(f"\nLoaded {n_hparams} hyperparameter sweeps")

## Figure 1: Ranked Information Curves

In [None]:
def plot_ranked_information_sweep(
    hparam_name: str,
    hparam_results: dict,
    ax: plt.Axes,
    max_neurons: int = 4096,
    cmap_name: str = "copper",
) -> None:
    """Plot ranked information curves for a hyperparameter sweep.

    Draws one line per hyperparameter value, with values sorted ascending and
    coloured along ``cmap_name`` from the smallest to the largest. Values whose
    results lack a ``"post"`` entry are skipped, but they keep their colour
    slot so colours stay aligned with the sorted values.

    Args:
        hparam_name: Hyperparameter key, used for the legend title.
        hparam_results: Mapping of hyperparameter value -> dict that may hold
            a ``"post"`` array (trials x neurons; columns presumably already
            ordered by rank upstream — TODO confirm).
        ax: Axes to draw on.
        max_neurons: Maximum number of ranks plotted; also the x-axis limit.
        cmap_name: Name of the Matplotlib colormap used for line colours.
    """
    hparam_values = sorted(hparam_results.keys())
    cmap = plt.get_cmap(cmap_name)
    n_vals = len(hparam_values)
    # Spread colours over the full colormap; max(..., 1) avoids 0/0 for n=1
    colors = [cmap(i / max(n_vals - 1, 1)) for i in range(n_vals)]

    for color, hval in zip(colors, hparam_values):
        result = hparam_results[hval]
        if "post" not in result:
            continue
        mean_vals, _ = compute_ranked_stats(result["post"])
        # Truncate first and build ranks to match, so results with fewer than
        # max_neurons neurons still plot instead of raising a length mismatch.
        mean_vals = mean_vals[:max_neurons]
        ranks = np.arange(1, len(mean_vals) + 1)
        ax.plot(
            ranks,
            mean_vals,
            color=color,
            label=f"{hval}",
            linewidth=1.2,
        )

    ax.set_xlim(1, max_neurons)
    ax.set_ylim(0, 1.05)
    ax.set_xscale("log")
    ax.set_xlabel("Neuron rank #")
    ax.set_ylabel(r"$\mathcal{I}\; (s, \vec{R})$")
    # Only the first subplot column keeps its y-axis label and tick labels.
    # Axes not created from a GridSpec (e.g. fig.add_axes) have no subplotspec.
    spec = ax.get_subplotspec()
    if spec is not None and spec.colspan.start > 0:
        ax.set_ylabel(None)
        ax.set_yticklabels([])

    ax.grid(True, alpha=0.4)
    ax.legend(
        title=get_hparam_display_name(hparam_name),
        loc="lower left",
        fontsize="x-small",
        title_fontsize="x-small",
    )

In [None]:
# Figure 1: one ranked-information panel per hyperparameter sweep, side by side
if n_hparams == 0:
    print("No hyperparameter sweeps found to plot.")
else:
    nrows, ncols = 1, n_hparams
    figsize = (FIG_WIDTH_MAX, 0.95 * FIG_HEIGHT_DEFAULT)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    # plt.subplots returns a bare Axes for a single panel; normalise to 1-D
    axes_flat = np.atleast_1d(axes).flatten()

    for ax, (hparam_name, hparam_results) in zip(axes_flat, sweep_results.items()):
        plot_ranked_information_sweep(hparam_name, hparam_results, ax)

    # Turn off any panels left without a sweep
    for ax in axes_flat[len(sweep_results):]:
        ax.axis("off")

    # Only the leftmost panel carries a y-axis label
    for ax in axes_flat[1:]:
        ax.set_ylabel("")

    fig.tight_layout()

    output_path = OUTPUT_DIR / "fig_ranked_information_sweep.pdf"
    viz.save_figure(fig, output_path, overwrite=False, dpi=300)
    plt.show()

## Figure 2: Informative Neuron Counts

In [None]:
def plot_informative_counts_sweep(
    hparam_name: str,
    hparam_results: dict,
    ax: plt.Axes,
    n_neurons: int = N_NEURONS_L4,
    plot_errors: bool = False,
    highlight_idx: int | None = None,
    highlight_color: str = "C1",
) -> None:
    """Plot informative neuron counts vs hyperparameter value.

    Draws one bar per hyperparameter value, with values sorted ascending. Each
    bar's height is the mean number of informative neurons across trials, and
    it is annotated with that mean as a percentage of ``n_neurons``.

    Args:
        hparam_name: Name of hyperparameter (used for the x-axis label).
        hparam_results: Dict mapping hparam value -> {'post': np.ndarray}.
            Values without a 'post' entry are drawn as zero-height bars.
        ax: Matplotlib axes.
        n_neurons: Total number of neurons for percentage calculation.
        plot_errors: Whether to show SEM error bars.
        highlight_idx: Index of bar to highlight (e.g., baseline), counted in
            the sorted value order. None (or out of range) disables it.
        highlight_color: Color for highlighted bar (default 'C1' = orange).
    """
    hparam_values = sorted(hparam_results.keys())
    x_pos = np.arange(len(hparam_values))

    means, sems = [], []
    for hval in hparam_values:
        result = hparam_results[hval]
        if "post" in result:
            mean, sem = compute_informative_counts(result["post"])
        else:
            # No trained-network results for this value: empty bar, no error
            mean, sem = 0, 0
        means.append(mean)
        sems.append(sem)

    # Set up colors - highlight specific bar if requested
    colors = [POST_COLOR] * len(hparam_values)
    if highlight_idx is not None and 0 <= highlight_idx < len(colors):
        colors[highlight_idx] = highlight_color

    bars = ax.bar(
        x_pos,
        means,
        yerr=(sems if plot_errors else None),
        color=colors,
        capsize=2,
        alpha=1,
    )

    # Label each non-empty bar with its mean count as a percentage of
    # n_neurons, just above the bar (and above its error bar when drawn).
    # The gap is set in points so it does not depend on the y-axis range.
    for bar, mean_val, sem_val in zip(bars, means, sems):
        if mean_val > 0:
            pct_mean = mean_val / n_neurons * 100
            top = bar.get_height() + (sem_val if plot_errors else 0)
            ax.annotate(
                f"{pct_mean:.1f}%",
                xy=(bar.get_x() + bar.get_width() / 2, top),
                xytext=(0, 2),
                textcoords="offset points",
                ha="center",
                va="bottom",
                fontsize="x-small",
            )

    ax.set_xticks(x_pos)
    ax.set_xticklabels([f"{v}" for v in hparam_values])
    ax.set_xlabel(get_hparam_display_name(hparam_name))
    ax.set_ylabel("# informative neurons")
    # Draw grid lines behind the bars
    ax.set_axisbelow(True)
    ax.grid(axis="y")

In [None]:
# Figure 2: informative-neuron bar charts, one panel per sweep, shared y-axis
if n_hparams > 0:
    nrows, ncols = 1, n_hparams
    figsize = (FIG_WIDTH_MAX, FIG_HEIGHT_DEFAULT * 1.1)

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True)
    # plt.subplots returns a bare Axes for a single panel; normalise to 1-D
    axes_flat = np.atleast_1d(axes).flatten()

    # Bar to highlight per sweep (the baseline setting), as an index into the
    # sorted hyperparameter values -- typically index 2 (e.g. lrate=0.04).
    # Verify these if the sweep grids change.
    highlight_indices = {"lrate": 2, "I2E": 2, "delay_max": 2}

    for i, (hparam_name, hparam_results) in enumerate(sweep_results.items()):
        highlight_idx = highlight_indices.get(hparam_name)
        plot_informative_counts_sweep(
            hparam_name,
            hparam_results,
            axes_flat[i],
            plot_errors=True,
            highlight_idx=highlight_idx,
        )

    # Hide unused axes
    for j in range(i + 1, len(axes_flat)):
        axes_flat[j].axis("off")

    # Hide ylabel for subplot columns other than first
    for ax in axes_flat[1:]:
        ax.set_ylabel("")

    # Fixed y-range for the manuscript layout, but never clip data: if the
    # shared autoscaled range (bars + error bars) needs more room, extend it.
    y_top = max(540, axes_flat[0].get_ylim()[1])
    axes_flat[0].set_ylim(0, y_top)
    axes_flat[0].set_yticks(np.arange(0, y_top, 100))
    fig.tight_layout()

    # Save figure
    output_path = OUTPUT_DIR / "fig_informative_counts_sweep.pdf"
    viz.save_figure(fig, output_path, overwrite=False, dpi=300)
    plt.show()
else:
    print("No hyperparameter sweeps found to plot.")

## Summary Statistics

In [None]:
def _summary_row(hval, result: dict) -> dict:
    """Build one printable summary row (value, count, fraction, max_info)."""
    if "post" not in result:
        return {"value": hval, "count": "N/A", "fraction": "N/A", "max_info": "N/A"}
    measures = result["post"]
    mean_count, sem_count = compute_informative_counts(measures)
    pct_mean = mean_count / N_NEURONS_L4 * 100
    pct_sem = sem_count / N_NEURONS_L4 * 100
    # Best single-neuron information per trial, averaged over trials
    mean_max = measures.max(axis=1).mean()
    return {
        "value": hval,
        "count": f"{mean_count:.1f} ± {sem_count:.1f}",
        "fraction": f"{pct_mean:.2f} ± {pct_sem:.2f}%",
        "max_info": f"{mean_max:.3f}",
    }


print("=" * 70)
print("HYPERPARAMETER ROBUSTNESS ANALYSIS SUMMARY")
print("=" * 70)

for hparam_name, hparam_results in sweep_results.items():
    print(f"\n{get_hparam_display_name(hparam_name)}")
    print("-" * 50)
    stats_df = pd.DataFrame(
        [_summary_row(hval, hparam_results[hval]) for hval in sorted(hparam_results)]
    )
    print(stats_df.to_string(index=False))

print("\n" + "=" * 70)
print(f"Results saved to: {OUTPUT_DIR}")
print("=" * 70)

In [None]:
# === SAVE SUMMARY TABLES ===

# Column order of the summary table. Passing it explicitly to DataFrame keeps
# the expected columns even when no sweep has "post" results; otherwise the
# frame would have no columns and the display selection below would raise
# KeyError.
summary_columns = [
    "hyperparameter",
    "value",
    "informative_count_mean",
    "informative_count_sem",
    "informative_fraction_mean",
    "informative_fraction_sem",
    "max_info_mean",
    "max_info_std",
    "mean_info_top10",
    "n_trials",
]

# Create comprehensive summary table (one row per hyperparameter value)
all_stats = []
for hparam_name, hparam_results in sweep_results.items():
    for hval in sorted(hparam_results.keys()):
        if "post" not in hparam_results[hval]:
            continue
        measures = hparam_results[hval]["post"]
        mean_count, sem_count = compute_informative_counts(measures)
        mean_info, _ = compute_ranked_stats(measures)
        # Best single-neuron information in each trial
        max_info_per_trial = measures.max(axis=1)

        all_stats.append(
            {
                "hyperparameter": hparam_name,
                "value": hval,
                "informative_count_mean": mean_count,
                "informative_count_sem": sem_count,
                "informative_fraction_mean": mean_count / N_NEURONS_L4 * 100,
                "informative_fraction_sem": sem_count / N_NEURONS_L4 * 100,
                "max_info_mean": max_info_per_trial.mean(),
                "max_info_std": max_info_per_trial.std(),
                # Assumes columns are ordered by rank (most informative first);
                # TODO confirm against the upstream computation.
                "mean_info_top10": mean_info[:10].mean(),
                "n_trials": measures.shape[0],
            }
        )

summary_table = pd.DataFrame(all_stats, columns=summary_columns)
summary_path = OUTPUT_DIR / "robustness_summary_statistics.csv"
summary_table.to_csv(summary_path, index=False)
print(f"Saved summary statistics to: {summary_path}")

# Display summary table
print("\nInformative Neuron Summary:")
display_cols = [
    "hyperparameter",
    "value",
    "informative_count_mean",
    "informative_fraction_mean",
    "n_trials",
]
print(summary_table[display_cols].round(2).to_string(index=False))