# Feature selectivity through self-organisation

Progressive increase in the spatial extent and fraction of active excitatory neurons across successive layers after training on N4P2.

**Plots:**

- Left side: Topographic activity plots (excitatory, inhibitory)
- Right side: Firing rate distributions

In [None]:
from typing import Iterable

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import xarray as xr
from pydantic import BaseModel

from hsnn import analysis, utils, viz
from hsnn.utils import handler, io
from hsnn.utils.handler import TrialView

RESULTS_DIR = io.BASE_DIR / "out/figures/fig8"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

def plot_image(
    img: int, imageset: Iterable, show_ticks: bool = False, axes: plt.Axes | None = None
) -> plt.Axes:
    """Draw one image from `imageset` and return the axes it was drawn on.

    Args:
        img: Index used to look the image up in `imageset`.
        imageset: Indexable collection of images.
        show_ticks: Keep the x/y ticks when True; clear them otherwise.
        axes: Axes forwarded to `viz.imshow_cbar` (None is forwarded as-is).

    Returns:
        The axes returned by `viz.imshow_cbar`.
    """
    image = imageset[img]
    axes = viz.imshow_cbar(image, attach_cbar=False, cmax=255, axes=axes)
    if show_ticks:
        return axes
    # Ticks are not wanted: clear them on both axes.
    axes.set_xticks([])
    axes.set_yticks([])
    return axes


def load_inference_results(
    trial: TrialView, chkpt: int | None, **kwargs
) -> xr.DataArray:
    """Load the pickled inference results of `trial` at checkpoint `chkpt`.

    Args:
        trial: Trial whose results are loaded.
        chkpt: Checkpoint forwarded to `handler.get_results_path`.
        **kwargs: Extra keyword arguments forwarded to `handler.get_results_path`.

    Returns:
        The object read from the results file by `utils.io.load_pickle`.

    Raises:
        FileNotFoundError: If the resolved results path is not an existing file.
    """
    results_path = handler.get_results_path(trial, chkpt, **kwargs)
    if not results_path.is_file():
        raise FileNotFoundError(f"'{results_path}'")
    # Report the path relative to the trial directory to keep the message short.
    print(f"Loading '{results_path.relative_to(trial.path)}'...")
    # NOTE: results are unpickled; only load files from a trusted experiment directory.
    return utils.io.load_pickle(results_path)


class InferenceConfig(BaseModel):
    """Inference settings, unpacked as kwargs into `load_inference_results`.

    `load_inference_results` forwards them to `handler.get_results_path`.
    """

    # When > 0, also written into the dataset config as the `gaussiannoise`
    # transform argument before the imageset is loaded.
    amplitude: int
    # The value "onsets" selects a 0.0 offset (instead of 50.0) for the
    # observation window.
    subdir: str | None = None


# Recordings per training state; the values are placeholders until the results
# are loaded for the selected trial.
records: dict[str, xr.DataArray | None] = {"pre": None, "post": None}
# Checkpoint argument used when loading each state's results.
# NOTE(review): presumably None = untrained network, -1 = latest checkpoint — confirm.
state_chkpt_mapping: dict[str, int | None] = {"pre": None, "post": -1}

# Project-level plotting setup (see `viz.setup_journal_env`).
viz.setup_journal_env()

### 1) Load N4P2 recorded results

**Get representative Trial per (`E2E`, `FF`) combination**

In [None]:
# Experiment log directory handed to `handler.ExperimentHandler`.
logdir = "n4p2/train_n4p2_lrate_0_02_181023"

expt = handler.ExperimentHandler(logdir)
# Stem of the directory that contains the experiment logdir.
dataset_name = expt.logdir.parent.stem

# Summary at checkpoint -1 — presumably the latest checkpoint; TODO confirm.
df = expt.get_summary(-1)
closest_trials = expt.index_to_dir[handler.get_closest_samples(df)]
# NOTE(review): the result is not assigned, so `closest_trials` itself is left
# unchanged (assuming a pandas-style `drop` that returns a new object).
# Presumably this line only exists to display the table without the (0, 20)
# entry as the cell output — confirm.
closest_trials.drop((0, 20), axis=0)

In [None]:
# Index of the trial to analyse within the experiment.
trial_index = (20, 20, 3)

trial = expt[trial_index]
print(f"Trial selected: '{trial.name}'")
# Pair every index level name except the last with the value chosen for it.
level_names = closest_trials.index.names[:-1]
proj_choices = dict(zip(level_names, trial_index[:-1]))
print(f"Projections: {proj_choices}")

**Load pre- / post-training results**

In [None]:
inference_cfg = InferenceConfig(amplitude=0, subdir=None)
# Offset subtracted from the recording duration and passed to
# `analysis.infer_rates`; it is zero only for the "onsets" subdir.
offset = 0.0 if inference_cfg.subdir == "onsets" else 50.0

# Imageset
cfg = trial.config
if inference_cfg.amplitude > 0:
    # NOTE(review): this writes into the object returned by `trial.config`; if
    # that is not a copy, the trial's own config is modified as well — confirm.
    cfg["training"]["data"]["transforms"]["gaussiannoise"] = [inference_cfg.amplitude]
imageset, labels = utils.io.get_dataset(
    cfg["training"]["data"], return_annotations=True
)

# Recordings per state
for state, chkpt in state_chkpt_mapping.items():
    # The config fields are unpacked as keyword arguments.
    records[state] = load_inference_results(trial, chkpt, **dict(inference_cfg))

# Common parameters
duration: float = records["post"].item(0).duration - offset  # Observation period
reps = len(records["post"]["rep"])  # Number of entries along "rep"
input_shape = tuple(cfg["topology"]["poisson"]["EXC"])
# Used further down to reshape per-layer rates into 2-D maps.
layer_shape = tuple(cfg["topology"]["spatial"]["EXC"])

### 2) Visualise network dynamics

**Firing rate distributions**

In [None]:
# Poisson firing rates (layer 0, excitatory, first image and repetition)
state = "post"
poisson_record = records[state].sel(img=0, rep=0, layer=0, nrn_cls="EXC").item()
rates = poisson_record.rates[1]

f, axes = plt.subplots(figsize=(4, 2))
# Only units with a non-zero rate enter the histogram.
nonzero_rates = rates[rates > 0]
axes.hist(nonzero_rates, bins=60, density=True)
axes.set_xlabel("Firing rate (Hz)")
axes.set_ylabel(r"$f\;(r)$");

In [None]:
state = "post"

# Excitatory recordings for image 0, restricted to layers from 1 onward.
exc_records = records[state].sel(img=0, nrn_cls="EXC", layer=slice(1, None))
rates_array = analysis.infer_rates(exc_records, duration=duration, offset=offset)

**Plot reference image**

In [None]:
# Reference image, saved as the "Input" inset.
f, axes = plt.subplots(figsize=(2.25, 2.25))
# Draw onto the axes created above so the image is guaranteed to land on `f`,
# the figure that is saved below, rather than relying on whatever
# `viz.imshow_cbar` does when no axes are supplied.
axes = plot_image(img=0, imageset=imageset, axes=axes)
axes.set_title("Input", fontweight="bold", fontsize=10)
viz.save_figure(f, RESULTS_DIR / "fig_image_inset.pdf", overwrite=False)

**Plot firing rate distributions**

In [None]:
# Per-layer firing-rate histograms for the state selected in the cell above,
# each annotated with the percentage of active (rate > 0) neurons.
hist_kwargs = {"facecolor": "#1f77b4", "edgecolor": "#1f77b4", "density": True}
text_kwargs = {
    "fontsize": 8,
    "va": "top",
    "ha": "left",
    "bbox": dict(facecolor="white"),
}
ymax = 0.1 if state == "pre" else 0.06
figsize = (2.5, 5)
bins = np.arange(0, 100 + 10, 5)  # Loop-invariant, so built once
x, y = 0.45, 0.9  # Annotation anchor in axes coordinates

f, axes = plt.subplots(4, 1, figsize=figsize, sharex=True, sharey=True)
for i, (_, rates) in enumerate(rates_array.groupby("layer")):
    ax: plt.Axes = axes[i]
    rs = rates.values.ravel()
    active = rs[rs > 0]
    ax.hist(active, bins=bins, **hist_kwargs)
    # A locator instance must not be shared between Axis objects, so each
    # subplot gets its own.
    ax.yaxis.set_major_locator(MaxNLocator(nbins=3))
    ax.set_ylim([0, ymax])
    # Annotation; an empty layer reports 0 % instead of dividing by zero.
    frac_active = active.size / rs.size * 100 if rs.size else 0.0
    text = f"Active: {frac_active:.1f} %"
    ax.text(x, y, text, transform=ax.transAxes, **text_kwargs)
    ax.set_ylabel("Density")
# With a shared x-axis only the bottom subplot carries the x label.
axes[-1].set_xlabel("Firing rate [Hz]")
f.tight_layout()

f.savefig(RESULTS_DIR / f"fig_rates_distr_{state}.pdf", dpi=300)

**Plot spatial heatmaps**

In [None]:
# Per-layer spatial heatmaps of the firing rates for a single repetition.
rep = 0
vmax = 50 if state == "pre" else 150
figsize = (1.6, 5)

rates_array_ = rates_array.sel(rep=rep)

f, axes = plt.subplots(4, 1, figsize=figsize, sharex=True, sharey=True)
for row, (layer, layer_rates) in enumerate(rates_array_.groupby("layer")):
    ax: plt.Axes = axes[row]
    # Reshape the layer's rates to the layer's spatial shape.
    heatmap = np.asarray(layer_rates).reshape(layer_shape)
    ax.imshow(heatmap, cmap="inferno", vmax=vmax)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_ylabel(
        f"L{layer}", rotation=0, labelpad=15, fontweight="bold", fontsize=10
    )
f.tight_layout()

f.savefig(RESULTS_DIR / f"fig_spatial_rates_{state}.pdf", dpi=300)

## Appendix

In [None]:
# Topographic firing activity for one image and repetition
state = "post"
img = 0
rep = 0

# Layers from 1 onward.
topo_records = records[state].sel(img=img, rep=rep, layer=slice(1, None))
axes = viz.topographic_rates(
    topo_records,
    plot_ticks=False,
    figsize=(4, 8),
    vmax=None,
);

In [None]:
# Firing rate distributions for one image and repetition
state = "post"
img = 0
rep = 0

# Layers from 1 onward.
hist_records = records[state].sel(img=img, rep=rep, layer=slice(1, None))
rate_bins = np.arange(0, 300 + 10, 5)
viz.hist_rates(
    hist_records,
    bins=rate_bins,
    xmax=[100, 300],
    figsize=(6, 8),
    yticks=False,
    xlabel="Firing rate (Hz)",
);