# Network sensitivity analysis: PNG counts

Sensitivity of PNG counts to key network parameters.

This notebook explores how hyperparameter sweeps (learning rate, competition, delays) affect the number of three-neuron PNGs that emerge.

**Dependencies:**

---

A) PNG detections and significance testing:
- **This workflow is time-consuming to run**
- Results already saved to disk are skipped (not recomputed)
```bash
for exp in competition_121125 delays_121125 lrate_111125; do
    ./scripts/run_main_workflow.py "experiments/n3p2/train_n3p2_sweep_${exp}" {0..1} --rule significance --layers 4 --chkpt -1 -v;
done
```

B) Calculate PNG counts:

```bash
for exp in competition_121125 delays_121125 lrate_111125; do
    ./scripts/figures/compute_png_counts_sweep.py "experiments/n3p2/train_n3p2_sweep_${exp}" --baseline n3p2/train_n3p2_lrate_0_04_181023 --baseline_model ALL -v
done
```

---

**Plots:**

- Supplementary S9C Fig

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from hsnn import viz
from hsnn.utils import io

# Input: holds one sub-directory of PNG-count results per hyperparameter sweep
RESULTS_DIR = io.BASE_DIR.joinpath("out", "figures", "supplementary", "robustness", "n3p2")

# Output: destination for this figure and its summary table (created if absent)
OUTPUT_DIR = io.BASE_DIR.joinpath("out", "figures", "supplementary", "fig_S9C")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
print(f"Results will be saved to: {OUTPUT_DIR}")

# Project plotting-style helper (presumably sets journal-ready rcParams)
viz.setup_journal_env()

In [None]:
# === PLOTTING CONFIGURATION ===

# Axis labels (LaTeX) keyed by hyperparameter / sweep-directory name
HPARAM_DISPLAY = dict(
    lrate=r"$\rho$",
    I2E=r"$\lambda^{I \to E}$ [nS]",
    delay_max=r"$d_{\max}$ [ms]",
)

# Figure size in inches; the manuscript allows a width of at most 7.5"
FIG_WIDTH_MAX, FIG_HEIGHT_DEFAULT = 7.5, 2

# Matplotlib colour-cycle entry used for the trained network
POST_COLOR = "C0"

# Number of neurons in layer L4
N_NEURONS_L4 = 4096


def get_hparam_display_name(hparam: str) -> str:
    """Return the axis label registered for ``hparam``, or ``hparam`` itself if none."""
    try:
        return HPARAM_DISPLAY[hparam]
    except KeyError:
        return hparam


def compute_count_stats(
    counts_arr: np.ndarray, layer_idx: int = 0
) -> tuple[float, float]:
    """Compute mean and standard error of the mean (SEM) of PNG counts.

    The SEM uses the sample standard deviation (``ddof=1``), which is the
    conventional estimator for error bars. The population form (``ddof=0``)
    understates the error when there are only a few trials.

    Args:
        counts_arr: Array of shape (n_trials, n_layers).
        layer_idx: Index of the layer (column) to compute stats for.

    Returns:
        Tuple of (mean, sem) as Python floats. ``mean`` is NaN when there are
        no trials. ``sem`` is NaN when there are fewer than two trials, because
        the spread is undefined in that case.
    """
    layer_counts = np.asarray(counts_arr, dtype=float)[:, layer_idx]
    n_trials = layer_counts.size
    if n_trials == 0:
        return float("nan"), float("nan")

    mean = float(layer_counts.mean())
    if n_trials < 2:
        # A single trial gives no information about spread.
        return mean, float("nan")

    sem = float(layer_counts.std(ddof=1) / np.sqrt(n_trials))
    return mean, sem

In [None]:
# === LOAD RESULTS ===

if not RESULTS_DIR.is_dir():
    # glob() on a missing directory silently yields nothing; make that visible
    print(f"WARNING: results directory does not exist: {RESULTS_DIR}")

# Discover available hyperparameter sweeps (one sub-directory each). Skip stray
# files, which a bare glob would also return.
hparam_dirs = sorted(p for p in RESULTS_DIR.glob("*") if p.is_dir())
print(f"Found hyperparameter directories: {[p.name for p in hparam_dirs]}")

# Load all sweep results (PNG counts), keyed by sweep-directory name
sweep_results: dict[str, dict] = {}
for fpath in hparam_dirs:
    hparam_name = fpath.name
    # Look for PNG counts file (default: layers_4)
    results_file = fpath / "png_counts_sweep_layers_4.pkl"
    if not results_file.exists():
        print(f"  {hparam_name}: PNG counts file not found")
        continue
    # NOTE(review): io.load_pickle presumably unpickles; only load trusted,
    # locally produced files.
    sweep_results[hparam_name] = io.load_pickle(results_file)
    hparam_values = sorted(sweep_results[hparam_name].keys())
    print(f"  {hparam_name}: {len(hparam_values)} values - {hparam_values}")

n_hparams = len(sweep_results)
print(f"\nLoaded {n_hparams} hyperparameter sweeps")

## Figure: Line Plot with Error Bars

In [None]:
def plot_png_counts_sweep(
    hparam_name: str,
    hparam_results: dict,
    ax: plt.Axes,
    layer_idx: int = 0,
    highlight_idx: int | None = None,
    highlight_color: str = "C1",
) -> None:
    """Draw mean PNG count (with SEM error bars) against each swept value.

    Args:
        hparam_name: Hyperparameter name, used to look up the x-axis label.
        hparam_results: Mapping of hyperparameter value -> dict that may hold
            a ``'post'`` counts array.
        ax: Axes to draw on.
        layer_idx: Column of the counts array to summarise.
        highlight_idx: Position (in sorted value order) of a point to redraw
            in ``highlight_color``, e.g. the baseline. Skipped when None or
            when it is not less than the number of values.
        highlight_color: Colour of the highlighted point.
    """
    x_vals = sorted(hparam_results)

    # Values without a 'post' entry become NaN gaps in the line.
    stats = [
        compute_count_stats(hparam_results[x]["post"], layer_idx)
        if "post" in hparam_results[x]
        else (np.nan, np.nan)
        for x in x_vals
    ]
    means = np.array([m for m, _ in stats])
    sems = np.array([s for _, s in stats])

    # Whole sweep drawn as a connected line with error bars
    ax.errorbar(
        x_vals, means, yerr=sems,
        fmt="o-", color=POST_COLOR,
        capsize=5, linewidth=1.5, markersize=6,
    )

    # Redraw the selected point on top of the line in its own colour
    show_highlight = highlight_idx is not None and highlight_idx < len(x_vals)
    if show_highlight:
        ax.errorbar(
            x_vals[highlight_idx], means[highlight_idx], yerr=sems[highlight_idx],
            fmt="o", color=highlight_color, ecolor=highlight_color,
            capsize=5, markersize=6, markeredgecolor="k", zorder=5,
        )

    ax.set_xticks(x_vals)
    ax.set_xlabel(get_hparam_display_name(hparam_name))
    ax.set_ylabel("# PNGs")
    ax.grid(True, alpha=1, color="lightgray")

In [None]:
# Create line plot figure: one panel per loaded hyperparameter sweep
if n_hparams > 0:
    nrows, ncols = 1, n_hparams
    figsize = (FIG_WIDTH_MAX, FIG_HEIGHT_DEFAULT)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True)
    # Normalise a single Axes or an array of Axes to a flat 1-D array
    axes_flat = np.atleast_1d(axes).flatten()

    # Sorted-value position of the baseline point in each sweep
    # (typically index 2, e.g. lrate=0.04)
    highlight_indices = {"lrate": 2, "I2E": 2, "delay_max": 2}

    panels = zip(axes_flat, sweep_results.items())
    for i, (ax, (hparam_name, hparam_results)) in enumerate(panels):
        highlight_idx = highlight_indices.get(hparam_name)
        plot_png_counts_sweep(
            hparam_name, hparam_results, ax, highlight_idx=highlight_idx
        )

    # Blank any surplus panels
    for spare_ax in axes_flat[i + 1 :]:
        spare_ax.axis("off")

    # Every panel set its own y-label; keep it on the first panel only
    for ax in axes_flat[1:]:
        ax.set_ylabel("")

    axes_flat[0].set_ylim(0, 5000)
    fig.tight_layout()

    # Save figure, then display it
    output_path = OUTPUT_DIR / "fig_png_counts_sweep.pdf"
    viz.save_figure(fig, output_path, overwrite=True, dpi=300)
    plt.show()

## Summary Statistics

In [None]:
def build_summary_dataframe(
    sweep_results: dict[str, dict], layer_idx: int = 0
) -> pd.DataFrame:
    """Tabulate mean and SEM PNG counts for every swept hyperparameter value.

    Values that lack a ``'post'`` entry are omitted. Each row has the columns
    ``hparam``, ``value``, ``mean``, ``sem`` and ``n_trials``. Rows are grouped
    by sweep (in dict order) and sorted by value within each sweep.
    """
    rows = []
    for name, results_by_value in sweep_results.items():
        for value in sorted(results_by_value):
            entry = results_by_value[value]
            if "post" not in entry:
                continue
            counts = entry["post"]
            mean, sem = compute_count_stats(counts, layer_idx)
            rows.append(
                dict(
                    hparam=name,
                    value=value,
                    mean=mean,
                    sem=sem,
                    n_trials=counts.shape[0],
                )
            )
    return pd.DataFrame(rows)


# Build the summary and show it as the cell output
summary_df = build_summary_dataframe(sweep_results)
summary_df

In [None]:
# Console summary: one table per hyperparameter sweep (layer index 0)
print("=" * 70)
print("PNG COUNTS ROBUSTNESS ANALYSIS SUMMARY")
print("=" * 70)

for hparam_name, hparam_results in sweep_results.items():
    print(f"\n{get_hparam_display_name(hparam_name)}")
    print("-" * 50)

    hparam_values = sorted(hparam_results.keys())

    stats_data = []
    for hval in hparam_values:
        row = {"value": hval}
        if "post" in hparam_results[hval]:
            counts = hparam_results[hval]["post"]
            mean, sem = compute_count_stats(counts, layer_idx=0)
            n_trials = counts.shape[0]
            # "mean ± SEM", rounded to whole PNGs (uses a proper U+00B1 sign)
            row["count"] = f"{mean:.0f} ± {sem:.0f}"
            row["n_trials"] = n_trials
        else:
            row["count"] = "N/A"
            row["n_trials"] = 0
        stats_data.append(row)

    stats_df = pd.DataFrame(stats_data)
    print(stats_df.to_string(index=False))

print("\n" + "=" * 70)
print(f"Results saved to: {RESULTS_DIR}")
print("=" * 70)

In [None]:
# === SAVE SUMMARY TABLE ===

# Write the CSV and echo a rounded copy, only when the summary has rows
if not summary_df.empty:
    csv_path = OUTPUT_DIR / "png_counts_summary_statistics.csv"
    summary_df.to_csv(csv_path, index=False)
    print(f"Saved summary statistics to: {csv_path}")

    # Show the summary rounded to two decimals
    print("\nPNG Counts Summary:")
    rounded_df = summary_df.round(2)
    print(rounded_df.to_string(index=False))