# Feature selectivity N4P2 (noise)

Neuronal response properties before and after training, with Gaussian noise applied to the N4P2 shape stimuli.

**Dependencies:**

- Inference spike recordings for N4P2, both before and after network training
- Generated by the N4P2 inference workflow, run twice: once as shown below (untrained network) and once with `--chkpt -1` appended (trained network):
```bash
./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 15 --rule inference --subdir noise --noise 20 -v
```

**Plots:**

- Firing rate distributions (pre vs. post)
- L4 Gabor traceback analyses

In [None]:
from copy import deepcopy
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from pydantic import BaseModel

from hsnn import analysis, simulation, utils, viz
from hsnn.cluster import tasks
from hsnn.utils import handler, io
from hsnn.utils.handler import TrialView
from hsnn import transforms

# Destination for the figures produced by this notebook (cached sensitivities are
# stored in its `.cache` sub-directory).
OUTPUT_DIR = io.BASE_DIR / "out/figures/fig12"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def load_inference_results(
    trial: TrialView, chkpt: int | None, **kwargs
) -> xr.DataArray:
    """Load the pickled inference recordings of ``trial`` at checkpoint ``chkpt``.

    Args:
        trial: Trial whose results are looked up.
        chkpt: Checkpoint selector forwarded to ``handler.get_results_path``
            (this notebook uses ``None`` for the untrained and ``-1`` for the
            trained state).
        **kwargs: Extra options forwarded to ``handler.get_results_path``
            (here ``noise`` and ``subdir``).

    Returns:
        The unpickled results object.

    Raises:
        FileNotFoundError: If the resolved results path is not an existing file.
    """
    path = handler.get_results_path(trial, chkpt, **kwargs)
    if not path.is_file():
        raise FileNotFoundError(f"'{path}'")
    print(f"Loading '{path}'...")
    return utils.io.load_pickle(path)


def load_imageset(amplitude: int, cfg: dict) -> tuple[utils.ImageSet, pd.DataFrame]:
    """Load the training image set, optionally corrupted by Gaussian noise.

    ``cfg`` is deep-copied first, so the caller's configuration is never mutated.

    Args:
        amplitude: Gaussian-noise amplitude; values ``<= 0`` leave the configured
            transforms unchanged.
        cfg: Trial configuration containing ``["training"]["data"]``.

    Returns:
        ``(imageset, annotations)`` as returned by ``utils.io.get_dataset``.
    """
    data_cfg = deepcopy(cfg)["training"]["data"]
    if amplitude > 0:
        # NOTE(review): assumes data_cfg already has a "transforms" mapping.
        data_cfg["transforms"]["gaussiannoise"] = [amplitude]
    return utils.io.get_dataset(data_cfg, return_annotations=True)


class InferenceConfig(BaseModel):
    """Keyword arguments identifying a set of inference results.

    In this notebook it is unpacked with ``dict(...)`` into
    ``load_inference_results``, which forwards it to ``handler.get_results_path``.
    """

    noise: int  # Gaussian-noise amplitude of the recordings
    subdir: str | None = None  # results sub-directory, e.g. "noise" (None presumably = default location)


# Per-state containers keyed by training state ("pre" = untrained, "post" = trained);
# both start empty and are filled in section 1.
records: dict[str, xr.DataArray | None] = {"pre": None, "post": None}
syn_params: dict[str, pd.DataFrame | None] = {"pre": None, "post": None}
# Checkpoint per state: None -> no checkpoint restored; -1 -> indexes the last
# entry of `trial.checkpoints` (presumably the final trained state).
state_chkpt_mapping: dict[str, int | None] = {"pre": None, "post": -1}

viz.setup_journal_env()

### 1) Load N4P2 recorded results

**Get representative Trial for `ALL` architecture**

In [None]:
logdir = "n4p2/train_n4p2_lrate_0_02_181023"  # same experiment as the workflow command above
trial_index = (20, 20, 3)  # representative trial (presumably a hyperparameter-grid index)

expt = handler.ExperimentHandler(logdir)
print(f"Experiment selected: '{expt.name}'")
dataset_name = expt.logdir.parent.stem  # used to name cached result files

trial = expt[trial_index]
cfg = trial.config
print(f"Trial selected: '{trial.name}'")

**Load pre / post-training results at highest noise amplitude**

In [None]:
amplitude = 20  # Gaussian-noise amplitude; must match `--noise` of the workflow runs
subdir = "noise" if amplitude > 0 else None
offset = 50.0  # start of the observation window (see `duration` below)
# Use the computed `subdir` (previously hard-coded to "noise", so amplitude=0
# would still have looked in the noise sub-directory).
inference_cfg = InferenceConfig(noise=amplitude, subdir=subdir)

# Recordings for each training state at this noise amplitude
for state, chkpt in state_chkpt_mapping.items():
    records[state] = load_inference_results(trial, chkpt, **dict(inference_cfg))

# Common parameters
duration: float = records["post"].item(0).duration - offset  # Observation period
reps = len(records["post"]["rep"])
input_shape = tuple(cfg["topology"]["poisson"]["EXC"])
layer_shape = tuple(cfg["topology"]["spatial"]["EXC"])

# Synapse parameters per state. A single simulator is reused, so order matters:
# "pre" (no checkpoint) must be read before "post" restores its checkpoint.
sim = simulation.Simulator.from_config(cfg)
for state in ("pre", "post"):
    chkpt = state_chkpt_mapping[state]
    store_path = trial.checkpoints[chkpt].store_path if isinstance(chkpt, int) else None
    if store_path is not None:
        sim.restore(store_path)
    syn_params[state] = sim.network.get_syn_params()

### 2) Visualize network dynamics

**Firing rate distributions**

In [None]:
# Poisson input (layer 0) firing-rate distribution for one image / repetition
state = "post"

record = records[state].sel(img=0, rep=0, layer=0, nrn_cls="EXC").item()
rates = record.rates[1]  # NOTE(review): second element assumed to hold rate values
active_rates = rates[rates > 0]  # drop silent neurons

f, axes = plt.subplots(figsize=(4, 2))
axes.hist(active_rates, bins=60, density=True)
axes.set_xlabel("Firing rate (Hz)")
axes.set_ylabel(r"$f\;(r)$");

In [None]:
# Firing-rate distributions (layers 1 and up) for a single image and repetition
img, rep, state = 0, 0, "post"

hist_kwargs = dict(
    bins=60,
    xmax=[200, 300],
    figsize=(6, 8),
    yticks=False,
    xlabel="Firing rate (Hz)",
)
viz.hist_rates(records[state].sel(img=img, rep=rep, layer=slice(1, None)), **hist_kwargs);

In [None]:
# Topographic maps of firing activity (layers 1 and up)
img, rep, state = 0, 0, "post"

selection = records[state].sel(img=img, rep=rep, layer=slice(1, None))
axes = viz.topographic_rates(selection, plot_ticks=False, figsize=(4, 8));

### 3) Compute sensitivities for a selected neuron

In [None]:
# Layer analysed in this section and the noise-corrupted stimulus set
layer = 4
imageset, labels = load_imageset(amplitude, trial.config)

# Preview one stimulus
f, axes = plt.subplots(figsize=(2, 2))
axes.imshow(imageset[1], cmap="gray", vmin=0, vmax=255);

**Get or compute sensitivities**

Runtime: ~2 min on the first run; afterwards the results are loaded from the cache.

In [None]:
nrn_id = 2176

# Cache keyed by dataset, noise amplitude, layer and neuron.
# NOTE(review): the trial is not part of the key, so changing `trial_index`
# silently reuses stale sensitivities -- delete the cache file in that case.
cache_dir = OUTPUT_DIR / ".cache"
cache_dir.mkdir(parents=True, exist_ok=True)
filepath = (
    cache_dir / f"sensitivities_{dataset_name}_noise_{amplitude}_L{layer}_{nrn_id}.pkl"
)
sensitivities_dict: dict[str, pd.DataFrame] = {}  # state -> pd.DataFrame
if filepath.exists():
    sensitivities_dict = io.load_pickle(filepath)
    print(f"Loaded cached sensitivities from '{filepath}'")
else:
    for state in tqdm(["pre", "post"]):
        sensitivities_dict[state] = tasks.get_sensitivities(
            nrn_id,
            layer,
            duration,
            offset,
            records[state],
            syn_params[state],
            input_shape,
        )
    io.save_pickle(sensitivities_dict, filepath)

# state -> maximum sensitivity (NaNs ignored); used to normalize the traceback plots
norm_maxes: dict[str, float] = {
    state: sensitivities_df.max(skipna=True).max()
    for state, sensitivities_df in sensitivities_dict.items()
}

### 4) Figure: Boundary contour element selectivity

In [None]:
# Layers to extract rates for. Named `layers` so the scalar `layer` used by the
# sensitivity cell above is not clobbered if that cell is re-run afterwards.
layers = [4]

rate_arrays: dict[str, xr.DataArray] = {
    state: analysis.infer_rates(
        records[state].sel(layer=layers, nrn_cls="EXC"), duration, offset
    )
    for state in ("pre", "post")
}
print(f"Firing rates array: layer={layers}; duration={duration}; offset={offset}")

In [None]:
def plot_selectivity_panel(
    axes,
    rates_array,
    layer,
    nrn_id,
    labels,
    yticks=None,
    ylim=None,
    title: str | None = None,
):
    """Draw one boundary-contour selectivity panel for a single neuron.

    Args:
        axes: Matplotlib axes to draw into.
        rates_array: Firing-rate array with a ``layer`` dimension.
        layer: Layer selected from ``rates_array``.
        nrn_id: Neuron whose rates are plotted.
        labels: Image annotations forwarded to ``viz.plot_contour_selectivity``.
        yticks: Optional explicit y-tick positions.
        ylim: Optional y-limits; ``None`` means ``[0, None]`` (lower bound at 0).
        title: Optional panel title, drawn in bold.

    Returns:
        The same ``axes`` object.
    """
    layer_rates = rates_array.sel(layer=layer)
    viz.plot_contour_selectivity(
        layer_rates,
        nrn_id,
        labels,
        violinplot=False,
        show_xlabels=True,
        rotate_xticklabels=False,
        axes=axes,
    )
    axes.set_title(title, fontweight="bold")
    axes.set_ylim([0, None] if ylim is None else ylim)
    if yticks is not None:
        axes.set_yticks(yticks)
    return axes

In [None]:
# Only the image id and the "left" label are needed for the panel x-axis
labels_ = labels[["image_id", "left"]].copy()

width = 2
height = 3.5

yticks = range(0, 50, 20)
ylim = [0, 45]

f, axes = plt.subplots(2, 1, figsize=(width, height), sharex=True, sharey=True)
# Top panel: untrained network; bottom panel: trained network
for panel_ax, panel_state, panel_title in zip(
    axes, ("pre", "post"), ("Untrained", "Trained")
):
    ax = plot_selectivity_panel(
        panel_ax,
        rate_arrays[panel_state],
        layer=4,
        nrn_id=nrn_id,
        labels=labels_,
        yticks=yticks,
        ylim=ylim,
        title=panel_title,
    )


def map_xlabel(xlabel):
    """Break a ``"name: value"`` tick label across two lines.

    Every ``": "`` separator in the label's text is replaced by ``" /"`` plus a
    newline. The label object is modified in place and also returned.
    """
    xlabel.set_text(xlabel.get_text().replace(": ", " /\n"))
    return xlabel


# Re-wrap the shared x-tick labels of the bottom panel onto two lines, then save
ax: plt.Axes = axes[-1]
xlabels = list(map(map_xlabel, ax.get_xticklabels()))
xticklabels = ax.set_xticklabels(xlabels)

f.tight_layout()
f.subplots_adjust(hspace=0.4)
f.savefig(OUTPUT_DIR / "fig_neuron_rates.pdf", dpi=300)

**Firing rate statistics**

In [None]:
# Image indices per class: positive where the "left" label is 1, negative where it is 0
pos_ids = np.flatnonzero(labels["left"] == 1)
neg_ids = np.flatnonzero(labels["left"] == 0)
pos_ids, neg_ids

In [None]:
state = "post"

# Rates of the selected neuron per class, flattened over the remaining dimensions
# (images and, presumably, repetitions).
rates_neg = rate_arrays[state].sel(img=neg_ids, nrn=nrn_id, layer=4).values.ravel()
rates_pos = rate_arrays[state].sel(img=pos_ids, nrn=nrn_id, layer=4).values.ravel()
# Keep the groups as separate arrays: np.vstack fails whenever the two classes
# contain different numbers of images, while violinplot accepts ragged input.
rates = [rates_neg, rates_pos]
print(f"Negative rates: {rates_neg.mean():.1f} ({rates_neg.std():.1f}) Hz")
print(f"Positive rates: {rates_pos.mean():.1f} ({rates_pos.std():.1f}) Hz")

plt.violinplot(rates)
plt.xticks([1, 2], ["Negative", "Positive"]);

### 5) Figure: traceback

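Traceback of the selected L4 neuron's input sensitivities.

Details:
- Left: negative image
- Middle: positive image
- Right: sensitivities averaged over all images, shown on a noise-corrupted uniform grey background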

In [None]:
# Prepare Gaussian distorted background
tsf = transforms.GaussianNoise(amplitude)
image_bg = np.full(input_shape[1:], 128.0)


def subplot_traceback(axes, img, sensitivities_df, norm_max, sel_kwargs=None):
    """Plot a neuron's input sensitivities over a stimulus image.

    Relies on the notebook globals ``imageset``, ``tsf``, ``image_bg`` and
    ``layer_shape``.

    Args:
        axes: Matplotlib axes to draw into.
        img: Image index (column of ``sensitivities_df`` and index into
            ``imageset``). If ``None``, the sensitivities are averaged over all
            columns (NaNs skipped) and drawn over ``tsf.transform(image_bg)``.
        sensitivities_df: Sensitivities with one column per image.
        norm_max: Maximum sensitivity passed to ``viz.plot_traceback`` for scaling.
        sel_kwargs: Forwarded to ``viz.plot_traceback`` as ``sel_kwargs``; defaults
            to ``{"s": 5, "linewidths": 0}`` when ``None``. An explicitly passed
            empty dict is now respected instead of being replaced by the default.
    """
    if sel_kwargs is None:
        sel_kwargs = {"s": 5, "linewidths": 0}
    if img is None:
        sensitivities = sensitivities_df.mean(axis=1, skipna=True)
        # NOTE: callers seed np.random beforehand, presumably because this
        # transform draws random noise.
        image = tsf.transform(image_bg)
    else:
        sensitivities = sensitivities_df[img].dropna()
        image = imageset[img]

    viz.plot_traceback(
        sensitivities,
        image,
        nrn_id=None,
        layer_shape=layer_shape,
        sensitivities_max=norm_max,
        vmin=0,
        vmax=255,
        color="red",
        sel_kwargs=sel_kwargs,
        axes=axes,
    )


# Example stimuli for the traceback figures (middle and left panels respectively)
img_pos = 1  # positive image (presumably "left" == 1 -- check against pos_ids)
img_neg = 6  # negative image (presumably "left" == 0 -- check against neg_ids)


**Subplot A**

State: `Pre`

In [None]:
seed = 42
state = "pre"

np.random.seed(seed)  # fixes the noise drawn for the averaged (right) panel
sensitivities = sensitivities_dict[state]
norm_max = norm_maxes[state]

width = 1.2  # panel size in inches; also reused by Subplot B
sel_kwargs = {"s": 5, "linewidths": 0}

f, axes = plt.subplots(1, 3, figsize=(3 * width, width))
# Panels, left to right: negative image, positive image, average over all images
for panel_ax, panel_img in zip(axes, (img_neg, img_pos, None)):
    subplot_traceback(panel_ax, panel_img, sensitivities, norm_max, sel_kwargs=sel_kwargs)

f.tight_layout()
f.savefig(OUTPUT_DIR / f"fig_traceback_untrained_L4_{nrn_id}.pdf", dpi=300)

**Subplot B**

State: `Post`

In [None]:
seed = 43
state = "post"

np.random.seed(seed)  # fixes the noise drawn for the averaged (right) panel
sensitivities = sensitivities_dict[state]
norm_max = norm_maxes[state]

# Same layout as Subplot A (reuses `width` and `sel_kwargs` defined there)
f, axes = plt.subplots(1, 3, figsize=(3 * width, width))
# Panels, left to right: negative image, positive image, average over all images
for panel_ax, panel_img in zip(axes, (img_neg, img_pos, None)):
    subplot_traceback(panel_ax, panel_img, sensitivities, norm_max, sel_kwargs=sel_kwargs)

f.tight_layout()
f.savefig(OUTPUT_DIR / f"fig_traceback_trained_L4_{nrn_id}.pdf", dpi=300)