# Feature selectivity through self-organisation

Development of boundary contour element selectivity in a final-layer neuron of a network trained on N3P2 shapes.

**Dependencies:**

- Inference spike recordings for N3P2 (Trial #31), both before and after network training
- Depends on these workflows:
    - `./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --rule inference -v`
    - `./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --chkpt -1 --rule inference -v`
- Runtime: ~1 min per workflow (AMD Ryzen 9 5900X, 64GB RAM).

**Plots:**

- Responses of an informative L4 neuron to N3P2 object subsets
- Gabor tracebacks (positive example, negative example, averaged over all images)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from pydantic import BaseModel

from hsnn import analysis, simulation, utils, viz
from hsnn.cluster import tasks
from hsnn.utils import handler, io
from hsnn.utils.handler import TrialView

pidx = pd.IndexSlice  # shorthand for pandas MultiIndex slicing

# Output directory for every figure saved by this notebook (created if missing).
RESULTS_DIR = io.BASE_DIR / "out/figures/fig5"
RESULTS_DIR.mkdir(exist_ok=True, parents=True)


def load_inference_results(
    trial: TrialView, chkpt: int | None, **kwargs
) -> xr.DataArray:
    """Load the pickled inference recordings of ``trial`` at checkpoint ``chkpt``.

    Args:
        trial: Trial whose results are looked up.
        chkpt: Checkpoint index forwarded to ``handler.get_results_path``
            (this notebook passes ``None`` for "pre" and ``-1`` for "post").
        **kwargs: Extra lookup options forwarded unchanged to
            ``handler.get_results_path`` (here, the fields of ``InferenceConfig``).

    Returns:
        The unpickled results object.

    Raises:
        FileNotFoundError: If no file exists at the resolved results path.
    """
    results_path = handler.get_results_path(trial, chkpt, **kwargs)
    if not results_path.is_file():
        raise FileNotFoundError(
            f"No inference results for trial '{trial.name}' "
            f"(chkpt={chkpt}, options={kwargs}): '{results_path}'"
        )

    # The path is only printed for information; fall back to the full path
    # rather than failing if it does not live under the trial directory.
    try:
        display_path = results_path.relative_to(trial.path)
    except ValueError:
        display_path = results_path
    print(f"Loading '{display_path}'...")
    return utils.io.load_pickle(results_path)


class InferenceConfig(BaseModel):
    """Options selecting which inference recordings to load.

    The fields are expanded with ``dict(...)`` into keyword arguments for
    ``load_inference_results``, which forwards them to
    ``handler.get_results_path``.
    """

    # Noise amplitude; when > 0 it is also written into the dataset's
    # "gaussiannoise" transform before the imageset is built.
    amplitude: int
    # Optional results subdirectory; "onsets" selects a zero analysis offset.
    subdir: str | None = None


# Per-state containers ("pre" = before training, "post" = after training);
# they start empty (None) and are filled by the loading cell.
records: dict[str, xr.DataArray | None] = {"pre": None, "post": None}
syn_params: dict[str, pd.DataFrame | None] = {"pre": None, "post": None}
# Checkpoint per state: None -> no checkpoint is restored; -1 -> last checkpoint.
state_chkpt_mapping: dict[str, int | None] = {"pre": None, "post": -1}

viz.setup_journal_env({"legend.fontsize": "x-small"})

### 1) Load N3P2 recorded results

**Get a representative trial per (`E2E`, `FB`) combination**

In [None]:
# Experiment log directory to analyse.
logdir = "n3p2/train_n3p2_lrate_0_04_181023"

expt = handler.ExperimentHandler(logdir)
dataset_name = expt.logdir.parent.stem

# Summary table for checkpoint -1 (presumably the final one — confirm), then
# map the closest (representative) samples onto their trial directories.
df = expt.get_summary(-1)
trial_dirs = expt.index_to_dir
closest_trials = trial_dirs[handler.get_closest_samples(df)]

# Notebook display only: `drop` returns a new frame, so `closest_trials`
# itself still contains the rows labelled (0, 20).
closest_trials.drop((0, 20), axis=0)

In [None]:
# All entries but the last are projection settings, matched to the index level
# names of `closest_trials` (the last entry presumably identifies the trial
# within that setting — confirm).
trial_index = (20, 20, 7)

trial = expt[trial_index]
print(f"Trial selected: '{trial.name}'")

# Pair each projection level name with its setting from `trial_index`.
proj_choices = dict(zip(closest_trials.index.names[:-1], trial_index[:-1]))
print(f"Projections: {proj_choices}")

**Load pre- / post-training results**

In [None]:
inference_cfg = InferenceConfig(amplitude=0, subdir=None)
# Start of the analysis window: 0 for "onsets" recordings, otherwise 50.0.
offset = 50.0
if inference_cfg.subdir == "onsets":
    offset = 0.0

# --- Imageset ---
cfg = trial.config
# NOTE(review): this writes into the object returned by `trial.config`; if that
# object is cached on the trial, the noise transform persists beyond this cell.
if inference_cfg.amplitude > 0:
    cfg["training"]["data"]["transforms"]["gaussiannoise"] = [inference_cfg.amplitude]
imageset, labels = utils.io.get_dataset(
    cfg["training"]["data"], return_annotations=True
)

# --- Recordings, one per training state ---
inference_kwargs = dict(inference_cfg)
for state, chkpt in state_chkpt_mapping.items():
    records[state] = load_inference_results(trial, chkpt, **inference_kwargs)

# --- Parameters shared by all later cells ---
post_records = records["post"]
duration: float = post_records.item(0).duration - offset  # observation period
reps = len(post_records["rep"])
input_shape = tuple(cfg["topology"]["poisson"]["EXC"])
layer_shape = tuple(cfg["topology"]["spatial"]["EXC"])

# --- Synapse parameters per state ---
# A single simulator is reused: "pre" (no checkpoint) is read first, before any
# checkpoint is restored into it, so it reflects the config's initial state.
sim = simulation.Simulator.from_config(cfg)
for state in ("pre", "post"):
    chkpt = state_chkpt_mapping[state]
    store_path = None
    if isinstance(chkpt, int):
        store_path = trial.checkpoints[chkpt].store_path
    if store_path is not None:
        sim.restore(store_path)
    syn_params[state] = sim.network.get_syn_params()

# "FF" is always included; other projections only when their setting is > 0.
projs_plastic = ("FF", *(proj for proj, val in proj_choices.items() if val > 0))

### 2) Visualise network dynamics

**Firing rate distributions**

In [None]:
# Distribution of non-zero Poisson (layer 0) firing rates for image 0, rep 0.
state = "post"
poisson_rec = records[state].sel(img=0, rep=0, layer=0, nrn_cls="EXC").item()
rates = poisson_rec.rates[1]
active_rates = rates[rates > 0]

f, axes = plt.subplots(figsize=(4, 2))
axes.hist(active_rates, bins=60, density=True)
axes.set_xlabel("Firing rate (Hz)")
axes.set_ylabel(r"$f\;(r)$");

In [None]:
# Firing-rate histograms for layers 1 onwards, for a single image / repetition.
state = "post"
img, rep = 0, 0

layer_records = records[state].sel(img=img, rep=rep, layer=slice(1, None))
viz.hist_rates(
    layer_records,
    bins=60,
    xmax=[200, 300],
    figsize=(6, 8),
    yticks=False,
    xlabel="Firing rate (Hz)",
);

In [None]:
# Topographic (spatial) firing-rate maps for layers 1 onwards, one image / rep.
state = "post"
img, rep = 0, 0

topo_records = records[state].sel(img=img, rep=rep, layer=slice(1, None))
axes = viz.topographic_rates(topo_records, plot_ticks=False, figsize=(4, 8));

### 3) Visualise weight distributions

In [None]:
state = "post"

# Weight histograms for each plastic projection (40 is presumably the bin count).
axes = viz.hist_weights(
    syn_params[state], projs_plastic, 40, annotations=True, figsize=(8, 6)
)
# Enlarge the y-axis labels of the first subplot column.
for ax in axes[:, 0]:
    label = ax.get_ylabel()
    ax.set_ylabel(label, size="x-large")

### 4) Boundary contour element selectivity

In [None]:
# Excitatory firing rates of one layer, after ("post") and before ("pre") training.
layer = 4

rates_array, rates_array_pre = (
    analysis.infer_rates(
        records[which].sel(layer=layer, nrn_cls="EXC"), duration, offset
    )
    for which in ("post", "pre")
)
print(f"Firing rates array: layer={layer}; duration={duration}; offset={offset}")

**Subplot A (i)**

Target boundary contour conformation: left-side convexity

In [None]:
# Target boundary conformation to highlight.
side = "left"
target = 1  # 0: concave; 1: convex

width = 3.5
height = 3 / 4 * width
axes = viz.plot_shape_columns(
    imageset,
    labels,
    side,
    target,
    cmap="gray",
    vmin=0,
    vmax=255,
    linewidth=2,
    figsize=(width, height),
)

# Break each bottom-row x-label at ": " so it fits under its narrow column.
ax: plt.Axes
for ax in axes[-1, :]:
    wrapped = ax.get_xlabel().replace(": ", " /\n")
    ax.set_xlabel(wrapped, fontsize=8, ha="left")
    ax.xaxis.set_label_coords(0, -0.1)

f = plt.gcf()
f.subplots_adjust(wspace=0.2, hspace=0.05)
viz.save_figure(f, RESULTS_DIR / "fig_target_shapes.pdf")

**Subplot A (ii)**

Firing-rate responses of the selected L4 neuron

In [None]:
nrn_id = 3012  # the L4 neuron analysed in this section

width = 4.21
axes = viz.plot_contour_selectivity(
    rates_array,
    nrn_id,
    labels,
    rates_array_pre,
    bar_width=0.7,
    violinplot=False,
    figsize=(width, 1 / 3 * width),
)

# Clean up: no title, no x tick labels, fixed y ticks.
axes.set(title="", yticks=[0, 20, 40, 60], xticklabels=[])
f = plt.gcf()
f.tight_layout()
# Applied after tight_layout, exactly as before.
axes.set_ylim([0, 60])
axes.legend(loc="upper left")

viz.save_figure(f, RESULTS_DIR / "fig_subset_rates.pdf", overwrite=True)

### 5) Boundary contour element tracebacks

In [None]:
# Traceback of the same L4 neuron using post-training recordings and weights.
layer = 4
nrn_id = 3012
state = "post"

sensitivities_df = tasks.get_sensitivities(
    nrn_id, layer, duration, offset, records[state], syn_params[state], input_shape
)
# Global maximum (NaNs skipped) shared by all traceback panels so their
# colour scales are comparable.
norm_max = sensitivities_df.max().max()

**Subplot B**

Traceback of the selected L4 neuron (left-side convexity)

In [None]:
img = 0  # positive example (left-side convexity)

sensitivities = sensitivities_df[img].dropna()
image = imageset[img]

traceback_kwargs = dict(
    nrn_id=None,
    layer_shape=layer_shape,
    figsize=(2, 2),
    sensitivities_max=norm_max,
    vmin=0,
    vmax=255,
    color="red",
    sel_kwargs={"s": 10, "linewidths": 0},
)
axes = viz.plot_traceback(sensitivities, image, **traceback_kwargs)
f = plt.gcf()
f.tight_layout()

viz.save_figure(f, RESULTS_DIR / "fig_traceback_n3p2_pos.pdf")

**Subplot C**

Traceback of the selected L4 neuron (left-side concavity)

In [None]:
img = 4  # negative example (left-side concavity)

sensitivities = sensitivities_df[img].dropna()
image = imageset[img]

traceback_kwargs = dict(
    nrn_id=None,
    layer_shape=layer_shape,
    figsize=(2, 2),
    sensitivities_max=norm_max,
    vmin=0,
    vmax=255,
    color="red",
    sel_kwargs={"s": 10, "linewidths": 0},
)
axes = viz.plot_traceback(sensitivities, image, **traceback_kwargs)
f = plt.gcf()
f.tight_layout()

viz.save_figure(f, RESULTS_DIR / "fig_traceback_n3p2_neg.pdf")

**Subplot D**

Traceback of the selected L4 neuron (averaged over all images)

In [None]:
img = None  # no single image: average over all of them

# Mean sensitivity per row across image columns (NaNs skipped), drawn on a
# uniform mid-grey canvas shaped like input_shape[1:].
sensitivities = sensitivities_df.mean(axis=1, skipna=True)
image = np.full(input_shape[1:], 128.0)

traceback_kwargs = dict(
    nrn_id=None,
    layer_shape=layer_shape,
    figsize=(2, 2),
    sensitivities_max=norm_max,
    vmin=0,
    vmax=255,
    color="red",
    sel_kwargs={"s": 10, "linewidths": 0},
)
axes = viz.plot_traceback(sensitivities, image, **traceback_kwargs)
f = plt.gcf()
f.tight_layout()

viz.save_figure(f, RESULTS_DIR / "fig_traceback_n3p2_all.pdf")