# Supplementary: Side-resolved information analysis

Single-neuron information analysis and informative-neuron counts for convex/concave boundaries at each N3P2/N4P2 object side.

- This plots S3 and S4 Figs.
- Results are for the full (FF + LAT + FB) network architecture (`ALL`).

**Dependencies:**

---

A) Gather inference spike recordings:
- Inference spike recordings for N3P2 & N4P2: both before and after network training.
- Artifacts that have already been generated from a previous workflow run will be skipped.
- Depends on N3P2 workflows (with and without the `--chkpt -1` argument):
    - `./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 0 1 3 4 5 7 8 9 31 --rule inference -v`
- Depends on N4P2 workflows (with and without the `--chkpt -1` argument):
    - `./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 0 1 3 4 5 7 12 15 29 --rule inference -v`

---

B) Compute information measures:

- The loops below use `--target 1` (convex selectivity); run them a second time with `--target 0` to compute concave selectivity.

i) For N3P2:
```bash
for side in top left right; do
    ./scripts/figures/compute_information.py ./experiments/n3p2/train_n3p2_lrate_0_04_181023 ALL \
        --side $side  --target 1 --output_dir ./out/figures/supplementary/information -v
done
```
ii) For N4P2:
```bash
for side in top left bottom right; do
    ./scripts/figures/compute_information.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 ALL \
        --side $side  --target 1 --output_dir ./out/figures/supplementary/information -v
done
```

---

**Plots**

- N3P2 and N4P2 Figs
- Panel A: rank-order single neuron information curves
- Panel B: number of informative neurons (information at or above the 2/3-bit threshold)

In [None]:
from pathlib import Path
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd

from hsnn import viz
from hsnn.utils import io

In [None]:
# === Configuration === #
# Input pickles (written by compute_information.py) and output location for figures/CSV.
RESULTS_DIR = io.BASE_DIR / "out/figures/supplementary/information"
OUTPUT_DIR = io.BASE_DIR / "out/figures/supplementary/figs_S4_S5"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Object sides analysed per dataset; also the row order of the side-resolved grids.
SIDES_N3P2 = ["top", "left", "right"]
SIDES_N4P2 = ["top", "left", "bottom", "right"]

# Boundary conformations; also the column order of the side-resolved grids.
CONFORMATIONS = ["convex", "concave"]

# Key selecting the FF+LAT+FB network inside each information pickle.
ARCH = "ALL"

# Cut-off (2/3 bits) above which a neuron counts as informative.
INFO_THRESHOLD = 2 / 3

# Publication-quality (journal) plotting environment for all figures below.
viz.setup_journal_env()

## 1) Data Loading Functions

Load precomputed information measures following `compute_information.py` naming conventions.

In [None]:
def get_expected_filename(side: str, conformation: str, noise: int = 0) -> str:
    """Return the pickle filename compute_information.py uses for one side/conformation.

    The name is assembled by ``io.formatted_name`` from the stem
    ``information_{side}_{conformation}``, the ``pkl`` extension and the
    ``noise`` keyword. It is expected to look like
    ``information_{side}_{conformation}_noise_{noise}.pkl``; the exact format
    is owned by ``io.formatted_name``.
    """
    stem = f"information_{side}_{conformation}"
    return io.formatted_name(stem, "pkl", noise=noise)


def load_side_measures(
    dataset_dir: Path,
    side: str,
    conformation: str,
    arch: str = "ALL",
    noise: int = 0,
) -> dict[str, npt.NDArray[np.floating]]:
    """Read one precomputed information pickle and return one architecture's entry.

    Args:
        dataset_dir: Folder holding the per-dataset pickles (e.g. ``RESULTS_DIR / "n3p2"``).
        side: Object side of the boundary ('left', 'top', 'right', 'bottom').
        conformation: 'convex' or 'concave'. It also selects the ``--target``
            flag suggested in the error message (convex -> 1, anything else -> 0).
        arch: Top-level key extracted from the pickle (default 'ALL', i.e. FF+LAT+FB).
        noise: Noise value forwarded to the filename builder (default 0).

    Returns:
        ``data[arch]`` from the pickle. It is expected to be a dict with 'pre'
        and 'post' arrays of shape (num_trials, num_nrns).

    Raises:
        FileNotFoundError: The expected pickle does not exist.
        KeyError: ``arch`` is not a key of the loaded pickle.
    """
    fpath = dataset_dir / get_expected_filename(side, conformation, noise)

    if not fpath.exists():
        suggested_target = 1 if conformation == "convex" else 0
        raise FileNotFoundError(
            f"Expected file not found: {fpath}\n"
            f"Please run: scripts/figures/compute_information.py with --side {side} --target {suggested_target}"
        )

    # Pickle files are produced locally by this project's own scripts (trusted input).
    data = io.load_pickle(fpath)
    if arch in data:
        return data[arch]

    available = list(data.keys())
    raise KeyError(
        f"Architecture '{arch}' not found in {fpath}. Available: {available}"
    )


def load_dataset_measures(
    dataset_name: Literal["n3p2", "n4p2"],
    arch: str = "ALL",
    noise: int = 0,
) -> dict[str, dict[str, dict[str, npt.NDArray[np.floating]]]]:
    """Load all side and conformation measures for a dataset.

    Args:
        dataset_name: Dataset identifier ('n3p2' or 'n4p2'). It selects both the
            sub-directory of ``RESULTS_DIR`` and the list of object sides.
        arch: Network architecture key
        noise: Noise amplitude

    Returns:
        Nested dict: measures[side][conformation] -> {'pre': array, 'post': array}

    Raises:
        ValueError: If ``dataset_name`` is not one of the known datasets.
        FileNotFoundError: If the dataset directory or any per-side file is missing.
        KeyError: If ``arch`` is missing from a loaded file.
    """
    # Explicit lookup, so an unknown name fails loudly instead of silently
    # falling back to the N4P2 side list.
    sides_by_dataset = {"n3p2": SIDES_N3P2, "n4p2": SIDES_N4P2}
    try:
        sides = sides_by_dataset[dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of {sorted(sides_by_dataset)}"
        ) from None

    dataset_dir = RESULTS_DIR / dataset_name
    if not dataset_dir.exists():
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_dir}\n"
            f"Please run scripts/figures/compute_information.py for {dataset_name}"
        )

    measures: dict[str, dict[str, dict[str, npt.NDArray]]] = {}
    loaded_files: list[str] = []

    for side in sides:
        measures[side] = {}
        for conformation in CONFORMATIONS:
            measures[side][conformation] = load_side_measures(
                dataset_dir, side, conformation, arch, noise
            )
            loaded_files.append(get_expected_filename(side, conformation, noise))

    # Print loading summary
    print(f"\n=== Loaded {dataset_name.upper()} measures ===")
    print(f"Directory: {dataset_dir}")
    print(f"Architecture: {arch}")
    print(f"Files loaded: {len(loaded_files)}")
    for fname in loaded_files:
        print(f"  - {fname}")

    # Print shape info from the first (side, conformation) entry
    first_entry = measures[sides[0]][CONFORMATIONS[0]]
    for state in ["pre", "post"]:
        print(f"Shape ({state}): {first_entry[state].shape}")

    return measures

## 2) Plotting Functions

Styling matches `plot_information.ipynb`.

In [None]:
def plot_rank_curve(
    ax: plt.Axes,
    measures: dict[str, npt.NDArray[np.floating]],
    show_legend: bool = False,
    threshold: "float | None" = None,
) -> None:
    """Plot trial-averaged information curves for the untrained and trained network.

    Each array is averaged over its first axis (trials). The result is plotted
    against neuron index 1..num_nrns on a log-scaled x-axis. This is a rank
    curve only if every trial's values are already rank-ordered in the input.
    NOTE(review): presumably compute_information.py sorts them; confirm.

    Args:
        ax: Matplotlib axes to plot on
        measures: Dict with 'pre' and 'post' arrays of shape (num_trials, num_nrns)
        show_legend: Whether to show legend on this subplot
        threshold: Optional information value marked with a horizontal guide
            line (e.g. ``INFO_THRESHOLD``). ``None`` (default) draws no line.
    """
    num_nrns = measures["pre"].shape[-1]
    ranks = np.arange(1, num_nrns + 1)

    if threshold is not None:
        # Unlabelled and at low zorder, so it sits behind the curves and stays out of the legend.
        ax.axhline(threshold, color="0.5", ls="--", lw=0.8, zorder=1)

    # Untrained (pre) - dotted line
    ax.plot(ranks, measures["pre"].mean(axis=0), ls=":", label="Untrained", color="C0")

    # Trained (post) - solid line
    ax.plot(ranks, measures["post"].mean(axis=0), ls="-", label="Trained", color="C1")

    # Styling to match reference notebook
    ax.set_xscale("log")
    ax.set_xlim(1, num_nrns)
    ax.set_ylim(0, 1.05)
    ax.grid(True)

    if show_legend:
        ax.legend(loc="lower left", fontsize="x-small")


def create_side_grid_figure(
    measures: dict[str, dict[str, dict[str, npt.NDArray[np.floating]]]],
    sides: list[str],
    dataset_name: str,
) -> plt.Figure:
    """Create a grid figure with rows=sides, cols=conformations.

    Args:
        measures: Nested dict measures[side][conformation] -> {'pre': array, 'post': array}
        sides: List of side names for row ordering
        dataset_name: Dataset name for the column titles

    Returns:
        Matplotlib figure
    """
    n_rows = len(sides)
    n_cols = len(CONFORMATIONS)
    last_row = n_rows - 1

    # Fixed width; height grows by 4/3 inch per side (row).
    fig_width = 5.5
    fig_height = 4.0 / 3 * n_rows

    # squeeze=False keeps `axes` 2-D for every grid shape, including a single
    # row or a single conformation column.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(fig_width, fig_height),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    for row, side in enumerate(sides):
        for col, conformation in enumerate(CONFORMATIONS):
            ax: plt.Axes = axes[row, col]

            # Legend only on the bottom-left subplot
            show_legend = row == last_row and col == 0
            plot_rank_curve(ax, measures[side][conformation], show_legend=show_legend)

            # Column titles (top row only)
            if row == 0:
                ax.set_title(
                    f"{dataset_name.upper()} - {conformation.capitalize()}",
                    fontweight="bold",
                )

            # Y-axis label and side name (left column only)
            if col == 0:
                ax.set_ylabel(r"$\mathcal{I}\; (s, \vec{R})$")
                ax.text(
                    -0.4,
                    0.5,
                    side.capitalize(),
                    size=10,
                    horizontalalignment="right",
                    verticalalignment="center",
                    transform=ax.transAxes,
                    fontweight="bold",
                )

            # X-axis label (bottom row only)
            if row == last_row:
                ax.set_xlabel("Neuron rank #")

    fig.tight_layout()
    return fig

## 3) Load Data

Load precomputed measures for both datasets.

In [None]:
# N3P2: every side x conformation for the configured architecture (noise defaults to 0)
measures_n3p2 = load_dataset_measures(dataset_name="n3p2", arch=ARCH)

In [None]:
# N4P2: every side x conformation for the configured architecture (noise defaults to 0)
measures_n4p2 = load_dataset_measures(dataset_name="n4p2", arch=ARCH)

## 4) Generate Figures

### N3P2: Side-Resolved Rank Curves (3×2 grid)

In [None]:
# One row per N3P2 side, one column per conformation
fig_n3p2 = create_side_grid_figure(
    measures=measures_n3p2, sides=SIDES_N3P2, dataset_name="n3p2"
)
plt.show()

In [None]:
# Write the N3P2 grid to PDF at 300 dpi. overwrite=False; handling of an
# existing file is up to viz.save_figure.
output_path_n3p2 = OUTPUT_DIR / "fig_n3p2_side_rank_curves.pdf"
viz.save_figure(fig_n3p2, output_path_n3p2, dpi=300, overwrite=False)

### N4P2: Side-Resolved Rank Curves (4×2 grid)

In [None]:
# One row per N4P2 side, one column per conformation
fig_n4p2 = create_side_grid_figure(
    measures=measures_n4p2, sides=SIDES_N4P2, dataset_name="n4p2"
)
plt.show()

In [None]:
# Write the N4P2 grid to PDF at 300 dpi. overwrite=False; handling of an
# existing file is up to viz.save_figure.
output_path_n4p2 = OUTPUT_DIR / "fig_n4p2_side_rank_curves.pdf"
viz.save_figure(fig_n4p2, output_path_n4p2, dpi=300, overwrite=False)

## 5) Informative Neuron Statistics

Compute the number and proportion of neurons whose information reaches or exceeds the threshold (2/3 bits).

In [None]:
def _mean_sem(values: npt.NDArray[np.floating]) -> tuple[float, float]:
    """Return ``(mean, sem)`` of a 1-D sample.

    SEM is ``std(ddof=1) / sqrt(n)``. With fewer than two samples the SEM is
    undefined, so it is returned as NaN directly rather than letting NumPy emit
    a degrees-of-freedom RuntimeWarning. An empty sample yields a NaN mean too.
    """
    n = values.size
    if n == 0:
        return float("nan"), float("nan")
    mean = float(values.mean())
    if n < 2:
        return mean, float("nan")
    return mean, float(values.std(ddof=1) / np.sqrt(n))


def compute_informative_stats(
    measures: dict[str, dict[str, dict[str, npt.NDArray[np.floating]]]],
    sides: list[str],
    dataset_name: str,
    threshold: float = INFO_THRESHOLD,
) -> pd.DataFrame:
    """Compute statistics for neurons at or above the information threshold.

    A neuron counts as informative in a trial when its value is ``>= threshold``.
    Counts and proportions are computed per trial, then summarised as mean and
    SEM across trials.

    Args:
        measures: Nested dict measures[side][conformation] -> {'pre': array, 'post': array}
        sides: List of side names
        dataset_name: Dataset identifier
        threshold: Information threshold in bits (default: 2/3)

    Returns:
        DataFrame with columns: dataset, side, conformation, condition,
                                num_informative_mean, num_informative_sem,
                                num_total, proportion_mean, proportion_sem.
        SEM columns are NaN when only one trial is available.

    Raises:
        ValueError: If a measure array is not 1-D or 2-D.
    """
    records = []

    for side in sides:
        for conformation in CONFORMATIONS:
            side_measures = measures[side][conformation]

            for condition in ["pre", "post"]:
                # Expected shape: (num_trials, num_nrns). A 1-D array is
                # promoted to a single trial instead of failing to unpack.
                info_array = np.atleast_2d(side_measures[condition])
                if info_array.ndim != 2:
                    raise ValueError(
                        f"Expected (num_trials, num_nrns) array for "
                        f"{dataset_name}/{side}/{conformation}/{condition}, "
                        f"got shape {info_array.shape}"
                    )
                num_total = info_array.shape[1]

                # Count informative neurons per trial (per row)
                informative_per_trial = (info_array >= threshold).sum(axis=1)
                num_informative_mean, num_informative_sem = _mean_sem(
                    informative_per_trial
                )

                # Same statistics for the per-trial proportion
                proportion_mean, proportion_sem = _mean_sem(
                    informative_per_trial / num_total
                )

                records.append(
                    {
                        "dataset": dataset_name,
                        "side": side,
                        "conformation": conformation,
                        "condition": condition,
                        "num_informative_mean": num_informative_mean,
                        "num_informative_sem": num_informative_sem,
                        "num_total": num_total,
                        "proportion_mean": proportion_mean,
                        "proportion_sem": proportion_sem,
                    }
                )

    return pd.DataFrame(records)

In [None]:
# Per-dataset informative-neuron statistics (N3P2 first, then N4P2)
stats_n3p2, stats_n4p2 = (
    compute_informative_stats(dataset_measures, dataset_sides, name)
    for dataset_measures, dataset_sides, name in (
        (measures_n3p2, SIDES_N3P2, "n3p2"),
        (measures_n4p2, SIDES_N4P2, "n4p2"),
    )
)

# Stack both into one long-format table with a fresh 0..N-1 index
stats_df = pd.concat([stats_n3p2, stats_n4p2], ignore_index=True)

print(f"Information threshold: {INFO_THRESHOLD:.4f} bits (2/3)")
print(f"Total records: {len(stats_df)}")
stats_df

In [None]:
# Indexed view: one row per (dataset, side, conformation, condition). Each
# index combination is unique, so the (explicit, default) mean aggregation
# leaves values unchanged.
stats_df_pivot = stats_df.pivot_table(
    values=[
        "num_informative_mean",
        "num_informative_sem",
        "proportion_mean",
        "proportion_sem",
    ],
    index=["dataset", "side", "conformation", "condition"],
    aggfunc="mean",
)

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(5.5, 4.5), sharey=False)
datasets = ["n3p2", "n4p2"]
sides_by_dataset = {"n3p2": SIDES_N3P2, "n4p2": SIDES_N4P2}

# Grouped-bar geometry: each side gets an Untrained bar (left of the tick) and a
# Trained bar (right), each offset by slightly more than half a bar width.
width = 0.35
factor = 1 / 2 + 0.05
# (condition key, legend label, colour, offset direction relative to the tick)
bar_specs = [("pre", "Untrained", "C0", -1), ("post", "Trained", "C1", 1)]
# Common y-range; raised for a panel only if a percentage label would not fit.
y_top_default = 150


def _side_pivot(subset: pd.DataFrame, column: str, sides: list[str]) -> pd.DataFrame:
    """Reshape ``column`` of ``subset`` to rows=side (in ``sides`` order), cols=condition."""
    return subset.pivot(index="side", columns="condition", values=column).reindex(sides)


for row, dataset in enumerate(datasets):
    for col, conf in enumerate(CONFORMATIONS):
        ax = axes[row, col]
        subset = stats_df[
            (stats_df["dataset"] == dataset) & (stats_df["conformation"] == conf)
        ]
        sides = sides_by_dataset[dataset]

        mean_pivot = _side_pivot(subset, "num_informative_mean", sides)
        sem_pivot = _side_pivot(subset, "num_informative_sem", sides)
        prop_pivot = _side_pivot(subset, "proportion_mean", sides)

        x = np.arange(len(sides))
        label_heights = []

        for condition, label, color, direction in bar_specs:
            bars = ax.bar(
                x + direction * width * factor,
                mean_pivot[condition],
                width,
                yerr=sem_pivot[condition],
                label=label,
                color=color,
                capsize=3,
            )
            # Percentage labels sit just above the error bar. A NaN SEM (e.g. a
            # single trial) is treated as 0 so the label keeps a finite position.
            for bar, prop, sem in zip(bars, prop_pivot[condition], sem_pivot[condition]):
                pct = prop * 100
                if pct > 0:
                    y = bar.get_height() + np.nan_to_num(sem) + 1
                    ax.annotate(
                        f"{pct:.1f}%",
                        xy=(bar.get_x() + bar.get_width() / 2, y),
                        ha="center",
                        va="bottom",
                        fontsize="x-small",
                    )
                    label_heights.append(y)

        ax.set_xticks(x)
        ax.set_xticklabels([s.capitalize() for s in sides])
        # Keep the shared 0-150 range unless a label would be pushed off the panel.
        ax.set_ylim(0, max(y_top_default, 1.1 * max(label_heights, default=0)))
        ax.set_axisbelow(True)
        ax.grid(axis="y")

        ax.set_title(f"{dataset.upper()} - {conf.capitalize()}", fontweight="bold")
        ax.set_ylabel("# informative neurons" if col == 0 else "")
        ax.set_xlabel("")
        if row == 1 and col == 0:
            ax.legend(loc="upper right", fontsize="small")
        ax.tick_params(axis="x", rotation=0)

fig.tight_layout()

In [None]:
# Write the 2x2 informative-neuron bar chart to PDF at 300 dpi (overwrite=False)
output_path_informative = OUTPUT_DIR / "informative_neurons_sides_datasets.pdf"
viz.save_figure(fig, output_path_informative, dpi=300, overwrite=False)

In [None]:
# Export the long-format statistics table (without the integer index column)
csv_path = OUTPUT_DIR / "informative_neuron_stats.csv"
stats_df.to_csv(csv_path, index=False)
print(f"Saved: {csv_path}")