# Information analysis: N3P2 and N4P2

Single-neuron information analysis and informative-neuron counts for convex boundary contour elements.

This notebook produces Fig 9 and the supplementary S1 Fig.

**Objectives**

- Measure information conveyed by single L4 neurons regarding convex-boundary contour elements
- Compare performance of different network architectures (combination of FF, LAT and FB)
- Compare pre- vs. post-trained networks for N3P2 and N4P2 datasets
- Count the most informative neurons (maximum information conveyed per side) across network architectures / datasets

**Dependencies:**

---

A) Gather inference spike recordings:
- Inference spike recordings for N3P2 & N4P2: both before and after network training
- Depends on N3P2 workflows (with and without the `--chkpt -1` argument):
    - `./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 0 1 3 4 5 7 8 9 31 --rule inference -v`
- Depends on N4P2 workflows (with and without the `--chkpt -1` argument):
    - `./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 0 1 3 4 5 7 12 15 29 --rule inference -v`

---

B) Compute information measures:

- Add `--target 0` to do a parallel analysis regarding concave selectivity (S1 Figure)

i) For N3P2:
```bash
for arch in FF SEMI ALL; do
    ./scripts/figures/compute_information.py ./experiments/n3p2/train_n3p2_lrate_0_04_181023 $arch
done
```
ii) and N4P2:
```bash
for arch in FF SEMI ALL; do
    ./scripts/figures/compute_information.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 $arch
done
```

---

**Plots**

- Panel A: rank-order single neuron information curves
- Panel B: number of selective neurons (exceeding 2/3 threshold)

In [None]:
from enum import Enum
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
from matplotlib.ticker import MaxNLocator

from hsnn import viz
from hsnn.utils import handler, io

# Shorthand for building pandas MultiIndex slices.
pidx = pd.IndexSlice
# Root directory of the pre-computed information measures loaded further below.
RESULTS_DIR = io.BASE_DIR / "out/figures/information"
# Destination directory for the figures saved by this notebook; created if missing.
OUTPUT_DIR = io.BASE_DIR / "out/figures/fig9"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


class DataSet(Enum):
    """Datasets analysed in this notebook; iteration order sets the subplot column order."""

    N3P2 = 1
    N4P2 = 2


def load_measures(
    results_dir: Path, noise: int = 0, target: int = 1
) -> dict[str, dict[str, np.ndarray]]:
    """Read the pickled, ranked single-neuron information measures.

    The file stem is ``information_max_`` followed by the boundary conformation
    (``convex`` when ``target == 1``, ``concave`` for any other value); the final
    file name is produced by ``io.formatted_name`` together with the noise amplitude.

    Args:
        results_dir (Path): Directory holding the pre-computed measures.
        noise (int, optional): Noise amplitude of inference recordings. Defaults to 0.
        target (int, optional): Target boundary conformation. Defaults to 1.

    Returns:
        dict[str, dict[str, np.ndarray]]: Measures keyed by network architecture, then by state.
        Each element is an NDArray[float] of ranked measures, shape (num_trials x num_sides, num_nrns)
    """
    conformation = "convex" if target == 1 else "concave"
    stem = f"information_max_{conformation}"
    return io.load_pickle(results_dir / io.formatted_name(stem, "pkl", noise=noise))


def plot_grid(state_measures_map: dict, axes: plt.Axes) -> plt.Axes:
    """Plot trial-averaged ranked information curves before and after training.

    The rank axis is sized from the data instead of assuming a fixed number of
    neurons, so arrays of any width can be plotted.

    Args:
        state_measures_map: Mapping with keys ``"pre"`` and ``"post"``. Each value is an
            array of ranked measures, either 2D (trials along axis 0, neurons along
            axis 1) or 1D for a single trial.
        axes: Axes to draw on.

    Returns:
        The same ``axes``, with a log-scaled rank axis and the y-range set to [0, 1.05].
    """
    max_rank = 1
    for state, label, ls in (("pre", "Untrained", ":"), ("post", "Trained", "-")):
        # Average over trials; a lone 1D trial is treated as a one-row array.
        curve = np.atleast_2d(state_measures_map[state]).mean(axis=0)
        # Ranks are 1-based so the first neuron stays visible on the log axis.
        ranks = np.arange(1, curve.size + 1)
        axes.plot(ranks, curve, label=label, ls=ls)
        max_rank = max(max_rank, curve.size)

    axes.set_xlim(1, max_rank)
    axes.set_xscale("log")
    axes.set_ylim([0, 1.05])
    axes.grid(True)
    return axes


# Apply the project's journal plotting setup before any figure is created.
# NOTE(review): the exact settings live in `hsnn.viz`; presumably matplotlib rcParams — confirm.
viz.setup_journal_env()

### Load measures from both experiments (N3P2, N4P2)

In [None]:
target = 0  # Set target=1 for convex, or target=0 for concave
# Loaded measures, keyed by dataset, then network architecture, then state.
measures_dict: dict[DataSet, dict[str, dict[str, np.ndarray]]] = {}

# Experiment log directories, written as "<dataset folder>/<run name>".
logdirs = {
    DataSet.N3P2: "n3p2/train_n3p2_lrate_0_04_181023",
    DataSet.N4P2: "n4p2/train_n4p2_lrate_0_02_181023",
}
for dataset, logdir in logdirs.items():
    # NOTE(review): `expt` is never read in the cells shown here; presumably the
    # constructor is kept for a side effect such as validating the experiment — confirm
    # before removing.
    expt = handler.ExperimentHandler(logdir)
    # The parent folder of the log directory ("n3p2" / "n4p2") names the results subfolder.
    dataset_name = Path(logdir).parent.name
    measures_dict[dataset] = load_measures(RESULTS_DIR / dataset_name, target=target)
    print(f"Loaded '{dataset_name}' measures from '{RESULTS_DIR / dataset_name}'")

### Figure 1: Information analysis

**Styles:**

- Aim for easy comparison across architectures per dataset
- Try plotting all trends into one plot with untrained network as FF + LAT + FB
- If this is unclear, use a separate subplot per architecture and highlight pre- vs post-training
- Try with / without error bars

**Layout:**

- Cols: N3P2, N4P2
- Rows: different architectures

In [None]:
archs = ["FF", "SEMI", "ALL"]
arch_desc_mapping = {"FF": "FF", "SEMI": "FF+LAT", "ALL": "FF+LAT+FB"}
width = 5.5
height = 4

# The grid follows the data: one row per architecture, one column per dataset.
datasets = list(DataSet)
num_rows, num_cols = len(archs), len(datasets)

# squeeze=False keeps `axes` two-dimensional even for a single row or column.
f, axes = plt.subplots(num_rows, num_cols, figsize=(width, height), squeeze=False)
for row, arch in enumerate(archs):
    for col, dataset in enumerate(datasets):
        ax: plt.Axes = axes[row, col]
        state_measures_map = measures_dict[dataset][arch]
        plot_grid(state_measures_map, ax)
        # Titles only on the top row; x labels only on the bottom row.
        if row == 0:
            ax.set_title(dataset.name, size=10, pad=10, fontweight="bold")
        if row < num_rows - 1:
            ax.set_xticklabels([])
        else:
            ax.set_xlabel("Neuron rank #")
        # y labels and the architecture name only on the left-most column.
        if col > 0:
            ax.set_yticklabels([])
        else:
            ax.set_ylabel(r"$\mathcal{I}\; (s, \vec{R})$")
            ax.text(
                -0.4,
                0.5,
                arch_desc_mapping[arch],
                size=10,
                horizontalalignment="right",
                verticalalignment="center",
                transform=ax.transAxes,
                fontweight="bold",
            )
        # A single legend in the bottom-left panel, placed away from the curves.
        if row == num_rows - 1 and col == 0:
            if target == 1:
                ax.legend(fontsize="x-small", loc="lower left")
            else:
                ax.legend(fontsize="x-small", loc="upper right")
        # Span the full rank range; the neuron count is the last axis of the measures.
        num_nrns = np.shape(state_measures_map["post"])[-1]
        ax.set_xlim([1, num_nrns])
f.tight_layout()

# File name encodes the analysed conformation: fig_information_<convex|concave>_datasets.pdf
conformation = "convex" if target == 1 else "concave"
parts = ["fig", "information", conformation, "datasets"]
fname = "_".join(parts) + ".pdf"
fpath = OUTPUT_DIR / fname
# NOTE(review): overwrite=False presumably leaves an existing file untouched — confirm in `viz`.
viz.save_figure(f, fpath, overwrite=False)

### Figure 2: number of most selective neurons

**Styles:**

- Bar charts: N3P2, N4P2
- Plot the number of neurons whose I(s, R) exceeds a threshold (0.9 or 2/3) for each network architecture
- Untrained architecture can be FF + LAT + FB

In [None]:
def _get_num_selective(
    info_measures: np.ndarray, threshold: float = 2 / 3.0
) -> tuple[float, float]:
    """Compute mean and SEM of informative neuron counts across trials.

    Args:
        info_measures: Array of shape (n_trials, n_neurons).
        threshold: Information threshold for "informative" neurons.

    Returns:
        Tuple of (mean count, SEM).
    """
    if info_measures.ndim == 2:
        counts = np.sum(info_measures >= threshold, axis=1)
        return counts.mean(), counts.std() / np.sqrt(len(counts))
    elif info_measures.ndim == 1:
        count = (info_measures > threshold).sum()
        return float(count), 0.0
    else:
        raise ValueError(info_measures.shape)


def get_num_selective_dataset(
    measures_arch: dict[str, dict[str, np.ndarray]], threshold: float = 2 / 3.0
) -> tuple[list[str], dict[str, list[float]], dict[str, list[float]]]:
    """Collect informative-neuron statistics per architecture and state.

    Args:
        measures_arch: Measures keyed by architecture, then by state
            (``"pre"`` / ``"post"``).
        threshold: Information threshold forwarded to ``_get_num_selective``.

    Returns:
        Tuple of (archs, means_dict, sems_dict). ``archs`` lists the architecture
        keys in iteration order; each dict maps a state to one value per
        architecture, in that same order.
    """
    state_means: dict[str, list[float]] = {"pre": [], "post": []}
    state_sems: dict[str, list[float]] = {"pre": [], "post": []}
    for per_state in measures_arch.values():
        for state, ranked in per_state.items():
            avg, err = _get_num_selective(ranked, threshold)
            state_means[state].append(avg)
            state_sems[state].append(err)
    return list(measures_arch), state_means, state_sems


def plot_number_selective_neurons(
    num_selective_nrns: dict[str, npt.ArrayLike],
    sems: dict[str, npt.ArrayLike],
    dataset_name: str,
    axes: plt.Axes | None = None,
    labels: list[str] | None = None,
    num_total: int = 4096,
) -> plt.Axes:
    """Draw grouped bars of informative-neuron counts before and after training.

    Args:
        num_selective_nrns: Mean counts per state (``"pre"`` / ``"post"``), one value
            per architecture.
        sems: Error-bar sizes per state, aligned with ``num_selective_nrns``.
        dataset_name: Text used as the panel title.
        axes: Axes to draw on, passed through ``viz.base.setup_axes``.
            NOTE(review): presumably a new Axes is created when None — confirm.
        labels: Tick label per architecture. When None, falls back to the
            module-level ``archs`` mapped through ``arch_desc_mapping``.
        num_total: Neuron count used as the denominator of the percentage labels.

    Returns:
        The Axes that was drawn on.
    """
    axes = viz.base.setup_axes(axes)
    axes.grid(axis="y")
    axes.yaxis.set_major_locator(MaxNLocator(nbins=5, integer=True))

    if labels is None:
        # Backward-compatible default: read the architecture list from module globals.
        labels = [arch_desc_mapping[arch] for arch in archs]
    x = np.arange(len(labels))

    bar_width = 0.4
    # Slightly more than half a bar width, leaving a small gap between paired bars.
    offset = bar_width * (1 / 2 + 0.05)

    for state, label, shift in (
        ("pre", "Untrained", -offset),
        ("post", "Trained", offset),
    ):
        heights = num_selective_nrns[state]
        errors = sems[state]
        bars = axes.bar(
            x + shift,
            heights,
            bar_width,
            yerr=errors,
            label=label,
            zorder=3,
            capsize=3,
        )
        # Percentage of all neurons, written just above the error bar; zero bars stay bare.
        for bar, val, sem_val in zip(bars, heights, errors):
            pct = val / num_total * 100
            if pct > 0:
                axes.annotate(
                    f"{pct:.1f}%",
                    xy=(
                        bar.get_x() + bar.get_width() / 2,
                        bar.get_height() + sem_val + 2,
                    ),
                    ha="center",
                    va="bottom",
                    fontsize="x-small",
                )

    axes.set_title(f"{dataset_name}", fontweight="bold", fontsize=10)
    axes.set_xticks(x, labels, fontsize=10, rotation=45, horizontalalignment="right")
    return axes

In [None]:
threshold = 2 / 3

# One panel per dataset; squeeze=False plus ravel keeps `axes` indexable for any count.
f, axes = plt.subplots(
    1, len(measures_dict), figsize=(5.5, 3), sharey=True, squeeze=False
)
axes = axes.ravel()

for i, (dataset, measures_arch) in enumerate(measures_dict.items()):
    # `archs` stays a module-level name: plot_number_selective_neurons reads it by default.
    archs, num_selective_nrns, sems = get_num_selective_dataset(
        measures_arch, threshold=threshold
    )
    ax = plot_number_selective_neurons(
        num_selective_nrns, sems, dataset.name, axes=axes[i]
    )
    if i == 0:
        ax.set_ylabel("# informative neurons")

# The legend goes on the last panel only. This was previously a for/else clause, which
# always runs when the loop has no `break`; a plain statement says the same thing.
ax.legend(loc=("upper right" if target == 1 else "upper left"))
# With sharey=True, setting the limits on one panel applies them to all panels.
ax.set_ylim(0, (420 if target == 1 else 100))
f.tight_layout()

# Recomputed here so this cell does not depend on the Figure 1 cell having been run.
conformation = "convex" if target == 1 else "concave"
parts = ["fig", "informative_neurons", conformation, "datasets"]
fname = "_".join(parts) + ".pdf"
fpath = OUTPUT_DIR / fname
# NOTE(review): overwrite=False presumably leaves an existing file untouched — confirm in `viz`.
viz.save_figure(f, fpath, overwrite=False)