# Resilience of the network to input noise

Network robustness to Gaussian input noise measured using single neuron information analysis.

**Objectives**

- Measure information conveyed by single L4 neurons regarding a left-convex boundary element
- Compare the network performance across increasing levels of noise
- Compare pre- vs. post-trained networks for N4P2 datasets

**Dependencies:**

---

A) Gather inference spike recordings with noise:
- Inference spike recordings for N4P2: both before and after network training
- Depends on N4P2 workflows (with and without the `--chkpt -1` argument):
```bash
./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 3 7 15 --rule inference -v

for noise in 10 20; do
    ./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 3 7 15 --rule inference --subdir noise --noise $noise -v
done
```

---

B) Compute information measures:

i) For N4P2 (left-convex selectivity):
```bash
./scripts/figures/compute_information.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 ALL --side left
```
ii) For N4P2 (left-convex selectivity) with increasing levels of noise:
```bash
for noise in 10 20; do
    ./scripts/figures/compute_information.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 ALL --analysis noise --side left --subdir noise --noise $noise
done
```

---

**Figures**

- Figure 1: information plots for increasing noise levels (N4P2, the only dataset loaded in this notebook)
- Figure 2: number of selective neurons (exceeding a threshold of 2/3 bits)

In [None]:
from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import xarray as xr
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from hsnn import utils, viz
from hsnn.utils import handler, io

# Shorthand for pandas multi-index slicing.
pidx = pd.IndexSlice
# Directory holding the pre-computed information measures (read by `load_measures`).
RESULTS_DIR = io.BASE_DIR / "out/figures/information"
# Directory that receives the figures saved by this notebook.
OUTPUT_DIR = io.BASE_DIR / "out/figures/fig11"
# Create the output directory (and any missing parents); no error if it already exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

class DataSet(Enum):
    """Identifiers for the datasets referenced by this analysis."""

    # NOTE(review): only N4P2 is used by the cells visible in this notebook.
    N3P2 = 1
    N4P2 = 2


# Apply the project's plotting setup before any figure is created
# (presumably journal-style matplotlib settings -- confirm in `hsnn.viz`).
viz.setup_journal_env()


def load_measures(
    results_dir: Path, noise: int = 0
) -> dict[str, dict[str, np.ndarray]]:
    """Read the pickled left-convex information measures for one noise level.

    The file name is built by `io.formatted_name` from the stem
    "information_left_convex", the "pkl" extension and the `noise` value, and
    is looked up inside `results_dir`.

    Args:
        results_dir (Path): Directory containing the pre-computed measures.
        noise (int, optional): Noise amplitude encoded in the file name. Defaults to 0.

    Returns:
        dict[str, dict[str, np.ndarray]]: The unpickled object. Callers in this
        file index it as `records[arch][state]`; per the original notes each
        leaf is an array of ranked measures with shape
        (num_trials x num_sides, num_nrns) -- TODO confirm against the producer script.
    """
    measures_path = results_dir / io.formatted_name(
        "information_left_convex", "pkl", noise=noise
    )
    print(f"Loading inference recordings from '{measures_path}'")
    return io.load_pickle(measures_path)


def load_inference_results(
    trial: handler.TrialView, chkpt: int | None, **kwargs
) -> xr.DataArray:
    """Load the pickled inference results of a trial.

    Args:
        trial (handler.TrialView): Trial whose results are requested.
        chkpt (int | None): Checkpoint selector forwarded to `handler.get_results_path`.
        **kwargs: Extra options forwarded to `handler.get_results_path`.

    Returns:
        xr.DataArray: The unpickled results.

    Raises:
        FileNotFoundError: If the resolved results path is not an existing file.
    """
    results_path = handler.get_results_path(trial, chkpt, **kwargs)
    # Guard clause: fail early when the recording has not been generated.
    if not results_path.is_file():
        raise FileNotFoundError(f"'{results_path}'")
    print(f"Loading '{results_path}'...")
    return io.load_pickle(results_path)


def plot_grid(
    state_measures_map: dict, axes: plt.Axes, xticks: np.ndarray | None = None
) -> plt.Axes:
    """Plot the mean ranked information curves of the untrained and trained network.

    Each entry of `state_measures_map` is averaged over axis 0 and drawn against
    the neuron rank on a logarithmic x-axis.

    Args:
        state_measures_map (dict): Mapping with keys "pre" and "post"; each value
            must support `.mean(axis=0)`.
        axes (plt.Axes): Axes to draw on.
        xticks (np.ndarray | None, optional): Explicit x coordinates. When None,
            ranks 1..N are used, where N is the length of the averaged curve
            (previously hard-coded to 4096).

    Returns:
        plt.Axes: The axes that were passed in.
    """
    curves = (
        ("pre", "Untrained", ":"),
        ("post", "Trained", "-"),
    )
    for state, label, linestyle in curves:
        ys = state_measures_map[state].mean(axis=0)
        # Derive the rank axis from the data so any network size can be plotted.
        xs = np.arange(1, np.size(ys) + 1) if xticks is None else xticks
        axes.plot(xs, ys, label=label, ls=linestyle)
    axes.set_xscale("log")
    axes.set_ylim([0, 1.05])
    axes.grid(True)
    return axes


### Load measures (N4P2 only; N3P2 is not configured in `logdirs` below)

In [None]:
class DuplicateResultError(Exception):
    """Raised when a result is added for a configuration that is already stored."""

    pass


@dataclass
class ResultConfig:
    """Hashable key identifying one set of information measures.

    Two configurations are equal (and hash equally) when their dataset, noise
    amplitude, architecture and network state all match.
    """

    dataset: DataSet  # dataset the measures were computed on
    noise: int  # noise amplitude of the inference recordings
    arch: str  # network architecture label, e.g. "ALL"
    state: str  # network state, e.g. "pre" or "post"

    def _key(self) -> tuple:
        """Return the tuple of fields used for hashing and equality."""
        return (self.dataset, self.noise, self.arch, self.state)

    def __hash__(self) -> int:
        return hash(self._key())

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of raising
        # AttributeError on a missing field.
        if not isinstance(other, ResultConfig):
            return NotImplemented
        return self._key() == other._key()


class ResultStore:
    """In-memory store of result arrays keyed by `ResultConfig`."""

    def __init__(self):
        # `np.float64` replaces the `np.float_` alias, which NumPy 2.0 removed.
        self.results: dict[ResultConfig, npt.NDArray[np.float64]] = {}

    def add(self, config: ResultConfig, result: npt.NDArray[np.float64]):
        """Store `result` under `config`.

        Raises:
            DuplicateResultError: If a result is already stored for `config`.
        """
        if config in self.results:
            raise DuplicateResultError(
                f"Result for configuration {config} already exists."
            )
        self.results[config] = result

    def get(self, config: ResultConfig | dict) -> npt.NDArray[np.float64]:
        """Return the result stored for `config`.

        Args:
            config: A `ResultConfig`, or a dict of its fields from which one is built.

        Raises:
            KeyError: If nothing is stored for the configuration.
        """
        if isinstance(config, dict):
            # Look up with the converted key; the original discarded the
            # conversion and indexed with the unhashable dict.
            config = ResultConfig(**config)
        return self.results[config]


def push_results(
    datastore: ResultStore, dataset: DataSet, noise: int, records: dict
) -> None:
    """Register the "pre" and "post" measures of the "ALL" architecture.

    Args:
        datastore (ResultStore): Store that receives the entries.
        dataset (DataSet): Dataset the records belong to.
        noise (int): Noise amplitude the records were computed for.
        records (dict): Nested mapping indexed as `records[arch][state]`.
    """
    for arch in ["ALL"]:
        for state in ["pre", "post"]:
            key = ResultConfig(dataset=dataset, noise=noise, arch=arch, state=state)
            datastore.add(key, records[arch][state])

In [None]:
# Experiment log directory (relative path) for each dataset to analyse.
logdirs = {
    DataSet.N4P2: "n4p2/train_n4p2_lrate_0_02_181023",
}
# Noise amplitudes for which measures are loaded (0 = no added noise).
amplitudes = [0, 10, 20]

# Collect every (dataset, noise, arch, state) result in a single store.
results_ds = ResultStore()
for dataset, logdir in logdirs.items():
    # NOTE(review): `expt` is not used by the cells visible here -- confirm
    # whether constructing the handler is needed for its side effects.
    expt = handler.ExperimentHandler(logdir)
    # First component of the log directory, e.g. "n4p2"; used as the
    # sub-directory of RESULTS_DIR that holds this dataset's measures.
    dataset_name = Path(logdir).parent.name
    for noise in amplitudes:
        results = load_measures(RESULTS_DIR / dataset_name, noise=noise)
        push_results(results_ds, dataset, noise, results)


### Figure 1: Information analysis (noise)

**Details:**

- Target `ALL` architecture: the one most relevant to this study
- Plot the trend up to the noise level at which almost no neurons are selective before training

**Layout:**

- Cols: single column (N4P2; N3P2 is not loaded in this notebook)
- Rows: noise levels

In [None]:
# Architecture and network states to extract from the result store.
arch = "ALL"
states = ["pre", "post"]

# Re-arrange the flat result store into nested dictionaries for plotting.
measures_dict: dict[
    DataSet, dict[int, dict[str, np.ndarray]]
] = {}  # DataSet -> {noise -> {state -> array}}
for dataset in [DataSet.N4P2]:
    # Create the per-dataset level on first use.
    if dataset not in measures_dict:
        measures_dict[dataset] = {}
    for noise in amplitudes:
        # Create the per-noise level on first use.
        if noise not in measures_dict[dataset]:
            measures_dict[dataset][noise] = {}
        for state in states:
            cfg = ResultConfig(dataset, noise, arch, state)
            measures_dict[dataset][noise][state] = results_ds.get(cfg)


**Setup image inset plotting**

In [None]:
# Base dataset configuration passed (as a copy) to `utils.io.get_dataset` by
# `load_imageset`, which adds a "gaussiannoise" transform for non-zero noise.
# NOTE(review): "resize" presumably rescales images to 128x128 -- confirm
# against the transform implementation.
data_cfg = {
    "name": "n4p2",
    "transforms": {
        "resize": [[128, 128]],
    },
}


def plot_image(
    image: np.ndarray, show_ticks: bool = False, axes: plt.Axes | None = None
) -> plt.Axes:
    """Draw `image` via `viz.imshow_cbar` without a colour bar.

    Args:
        image (np.ndarray): Image to display.
        show_ticks (bool, optional): Keep the axis ticks when True; otherwise
            both x and y ticks are removed. Defaults to False.
        axes (plt.Axes | None, optional): Axes forwarded to `viz.imshow_cbar`.

    Returns:
        plt.Axes: The axes returned by `viz.imshow_cbar`.
    """
    axes = viz.imshow_cbar(
        image, attach_cbar=False, cmax=255, axes=axes, rasterized=True
    )
    if show_ticks:
        return axes
    # Strip the ticks on both axes (x first, then y).
    for clear_ticks in (axes.set_xticks, axes.set_yticks):
        clear_ticks([])
    return axes


def load_imageset(amplitude: int, cfg: dict) -> tuple[utils.ImageSet, pd.DataFrame]:
    """Load the dataset described by `cfg`, optionally with a Gaussian-noise transform.

    `cfg` is deep-copied first, so the caller's dictionary is never modified.

    Args:
        amplitude (int): Noise amplitude; when greater than 0 it is stored as
            `transforms["gaussiannoise"] = [amplitude]` in the copied config.
        cfg (dict): Dataset configuration; must contain a "transforms" mapping
            when `amplitude` is greater than 0.

    Returns:
        tuple[utils.ImageSet, pd.DataFrame]: Whatever `utils.io.get_dataset`
        returns when called with `return_annotations=True`.
    """
    dataset_cfg = deepcopy(cfg)
    if amplitude > 0:
        transforms = dataset_cfg["transforms"]
        transforms["gaussiannoise"] = [amplitude]
    return utils.io.get_dataset(dataset_cfg, return_annotations=True)


# Get the target image (index `img_id`) at each noise level; element [0] of the
# `load_imageset` result is the image set.
img_id = 5
images = {noise: load_imageset(noise, data_cfg)[0][img_id] for noise in amplitudes}

# Preview the image at a single noise level.
noise = 20
f, axes = plt.subplots(figsize=(2.5, 2.5))
plot_image(images[noise], axes=axes)
f.tight_layout()

**Information plot**

In [None]:
# Figure 1: one panel per noise amplitude, each showing the mean ranked
# information curves (untrained vs. trained) with the noisy stimulus as an inset.
dataset = DataSet.N4P2

overwrite = False
# Figure size passed to matplotlib (inches).
width = 3.5
height = 4

# NOTE(review): the row count (3) and the `row < 2` test below assume exactly
# three entries in `amplitudes` -- update both if the list changes.
f, axes = plt.subplots(3, 1, figsize=(width, height))
for row, noise in enumerate(amplitudes):
    ax: plt.Axes = axes[row]
    state_measures_map = measures_dict[dataset][noise]
    plot_grid(state_measures_map, ax)
    # Inset image: the stimulus at this noise level, placed inside the panel.
    inset_axes_ = inset_axes(
        ax,
        width="50%",
        height="50%",
        bbox_to_anchor=(0.33, 0.2, 1, 1),
        bbox_transform=ax.transAxes,
        loc="center",
    )
    plot_image(images[noise], axes=inset_axes_)
    # ticks, labels
    # if row == 0:
    #     ax.set_title(dataset.name, size=10, pad=10, fontweight='bold')
    # Only the bottom panel keeps its x tick labels and the x-axis label.
    if row < 2:
        ax.set_xticklabels([])
    else:
        ax.set_xlabel("Neuron rank #")
    ax.set_ylabel(r"$\mathcal{I}\; (s, \vec{R})$")
    # Noise amplitude label placed to the left of the panel (axes coordinates).
    ax.text(
        -0.4,
        0.5,
        rf"$\sigma={noise}$",
        size=10,
        horizontalalignment="right",
        verticalalignment="center",
        transform=ax.transAxes,
    )
    # Legend on the first panel only.
    if ax == axes[0]:
        ax.legend(fontsize="x-small", loc="lower left")
    # ax.set_xlim([None, 1.5e3])
    # NOTE(review): 4096 is presumably the number of neurons per curve -- confirm.
    ax.set_xlim([1, 4096])
f.tight_layout()
fname = OUTPUT_DIR / "fig_information_noise.pdf"
viz.save_figure(f, fname, overwrite=overwrite, dpi=600)

### Figure 2: number of most selective neurons (noise)

**Styles:**

- Bar chart: N4P2
- Plot number of neurons with I(s, R) > 2/3 per noise level

In [None]:
def _get_num_selective(info_measures: np.ndarray, threshold: float = 2 / 3.0) -> int:
    """Count the entries whose information strictly exceeds `threshold`.

    A 2-D input is first averaged over axis 0; a 1-D input is used as is.

    Raises:
        ValueError: If `info_measures` is neither 1-D nor 2-D (carries the shape).
    """
    if info_measures.ndim not in (1, 2):
        raise ValueError(info_measures.shape)
    ranked = info_measures.mean(0) if info_measures.ndim == 2 else info_measures
    return (ranked > threshold).sum()


def get_num_selective_dataset(
    measures_noise: dict[int, dict[str, np.ndarray]], threshold: float = 2 / 3.0
) -> tuple[list[int], dict[str, list[int]]]:
    """Count the selective neurons per noise amplitude and network state.

    Args:
        measures_noise: Mapping noise amplitude -> {state -> measures}; the
            states must be "pre" and/or "post".
        threshold: Information value a neuron must strictly exceed to count.

    Returns:
        The noise amplitudes in iteration order, and a dict with keys "pre" and
        "post" holding one count per amplitude that provides that state.
    """
    counts = {"pre": [], "post": []}
    noise_levels = list(measures_noise)
    for per_state in measures_noise.values():
        for state, ranked_measures in per_state.items():
            counts[state].append(_get_num_selective(ranked_measures, threshold))
    return noise_levels, counts


def plot_number_selective_neurons(
    num_selective_nrns: dict[str, npt.ArrayLike],
    dataset_name: str,
    labels: list,
    axes: plt.Axes | None = None,
    num_total: int = 4096,
) -> plt.Axes:
    """Draw a grouped bar chart of selective-neuron counts per noise amplitude.

    For every entry of `labels` an "Untrained" and a "Trained" bar are drawn
    side by side; each non-zero bar is annotated with its share of `num_total`
    as a percentage.

    Args:
        num_selective_nrns: Counts under the keys "pre" and "post", one value
            per entry of `labels`.
        dataset_name: Currently unused; kept so existing callers keep working.
        labels: Tick labels of the bar groups (the noise amplitudes).
        axes: Axes forwarded to `viz.base.setup_axes`.
        num_total: Neuron count used as the denominator of the percentage
            labels. Defaults to 4096, the value that was previously hard-coded.

    Returns:
        plt.Axes: The axes the chart was drawn on.

    Raises:
        ValueError: If `num_total` is not positive.
    """
    if num_total <= 0:
        raise ValueError(f"num_total must be positive, got {num_total}")

    axes = viz.base.setup_axes(axes)
    axes.grid(axis="y")

    x = np.arange(len(labels))
    counts_pre = num_selective_nrns["pre"]
    counts_post = num_selective_nrns["post"]

    # Bar width and the offset of each bar from its group centre; the extra
    # 0.05 leaves a small gap between the two bars of a group.
    width = 0.4
    offset = width * (1 / 2 + 0.05)

    # Draw both bar series before annotating, so the artist order is unchanged.
    bars_pre = axes.bar(x - offset, counts_pre, width, label="Untrained", zorder=3)
    bars_post = axes.bar(x + offset, counts_post, width, label="Trained", zorder=3)

    # Add percentage labels above non-empty bars.
    for bars, counts in ((bars_pre, counts_pre), (bars_post, counts_post)):
        for bar, count in zip(bars, counts):
            pct = count / num_total * 100
            if pct > 0:
                axes.annotate(
                    f"{pct:.1f}%",
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                    ha="center",
                    va="bottom",
                    fontsize="x-small",
                )

    # A single call sets both the tick positions and their labels.
    axes.set_xticks(x, labels)
    axes.set_xlabel(r"Noise amplitude ($\sigma$)")
    return axes

In [None]:
# Figure 2: number of neurons whose information exceeds `threshold`, per noise
# amplitude, for the untrained and trained network.
threshold = 2 / 3
overwrite = False

f, axes = plt.subplots(1, figsize=(2.5, 3), sharey=True)

# Uses `dataset` as set by an earlier cell. Note that this rebinds the
# notebook-level `amplitudes` to the keys of `measures_dict[dataset]`.
amplitudes, num_selective_nrns = get_num_selective_dataset(
    measures_dict[dataset], threshold=threshold
)
ax = plot_number_selective_neurons(
    num_selective_nrns, dataset.name, amplitudes, axes=axes
)
ax.set_ylabel("# informative neurons")
ax.legend()
f.tight_layout()

fname = OUTPUT_DIR / "fig_informative_neurons_noise.pdf"
viz.save_figure(f, fname, overwrite=overwrite)