# Hierarchical feature integration

Hierarchical integration of neuronal responses for feature-selective neurons in the last two layers (L3 and L4) of a network trained on N4P2 shapes.

**Dependencies:**

- Inference spike recordings for N4P2 (Trial #15): both before and after network training
- Depends on these workflows:
    - `./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 15 --rule inference -v`
    - `./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 15 --chkpt -1 --rule inference -v`
- Runtime (including saved file compression): ~6 min per workflow (AMD Ryzen 9 5900X, 64GB RAM).

**Plots:**

A) Responses of an informative L4 neuron and two L3 neurons to N4P2 object subsets

B) L4 Gabor traceback (negative, positive, averaged)

C) L3 #1 Gabor traceback (negative, positive, averaged)

D) L3 #2 Gabor traceback (negative, positive, averaged)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from pydantic import BaseModel

from hsnn import analysis, simulation, utils, viz
from hsnn.cluster import tasks
from hsnn.utils import handler, io
from hsnn.utils.handler import TrialView

# Shorthand alias for pandas' MultiIndex slicer (`pd.IndexSlice`).
pidx = pd.IndexSlice

# Output folder for the figures saved by this notebook (created if missing).
RESULTS_DIR = io.BASE_DIR / "out/figures/fig6"
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

def load_inference_results(
    trial: TrialView, chkpt: int | None, **kwargs
) -> xr.DataArray:
    """Load the pickled inference recordings of `trial` for checkpoint `chkpt`.

    Args:
        trial: Trial whose results directory is searched.
        chkpt: Checkpoint selector forwarded to `handler.get_results_path`
            (None or an int index in this notebook).
        **kwargs: Extra options forwarded to `handler.get_results_path`.

    Returns:
        The unpickled results object.

    Raises:
        FileNotFoundError: If the resolved results path is not an existing file.
    """
    path = handler.get_results_path(trial, chkpt, **kwargs)
    if not path.is_file():
        raise FileNotFoundError(f"'{path}'")

    print(f"Loading '{path.relative_to(trial.path)}'...")
    return utils.io.load_pickle(path)


class InferenceConfig(BaseModel):
    """Selects which inference-results variant to load.

    The fields are unpacked (``**dict(cfg)``) as keyword arguments to
    `load_inference_results`, which forwards them to
    `handler.get_results_path` to locate the results file.
    """

    # Input-noise amplitude; in this notebook a value > 0 also adds a
    # ``gaussiannoise`` transform to the dataset config for the imageset.
    amplitude: int
    # Optional results subdirectory (e.g. "onsets"); None presumably selects
    # the default results location — confirm in `handler.get_results_path`.
    subdir: str | None = None


# Checkpoint restored for each network state: None = no checkpoint restored
# (network as built from the config), -1 = last entry of `trial.checkpoints`.
state_chkpt_mapping: dict[str, int | None] = {"pre": None, "post": -1}

# Per-state containers, keyed like `state_chkpt_mapping` and filled in by the
# loading cell; they hold None until then.
records: dict[str, xr.DataArray | None] = dict.fromkeys(state_chkpt_mapping)
syn_params: dict[str, pd.DataFrame | None] = dict.fromkeys(state_chkpt_mapping)

viz.setup_journal_env({"legend.fontsize": "x-small"})

### 1) Load N4P2 recorded results

**Get a representative trial per (`E2E`, `FB`) combination**

In [None]:
# Experiment log directory, relative path passed to `handler.ExperimentHandler`.
logdir = "n4p2/train_n4p2_lrate_0_02_181023"

expt = handler.ExperimentHandler(logdir)
# Name of the log dir's parent folder — presumably the dataset name; verify.
dataset_name = expt.logdir.parent.stem

# NOTE(review): the -1 argument presumably selects the last checkpoint — confirm.
df = expt.get_summary(-1)
# Map the representative ("closest") samples of `df` to their trial directories.
closest_trials = expt.index_to_dir[handler.get_closest_samples(df)]
# Not assigned: only displayed as the cell output, with the (0, 20) entry hidden.
closest_trials.drop((0, 20), axis=0)

In [None]:
# Full index of the trial to analyse; all entries except the last are the
# values of the leading index levels of `closest_trials`.
trial_index = (20, 20, 3)

trial = expt[trial_index]
print(f"Trial selected: '{trial.name}'")

# Pair each leading index-level name with the value chosen in `trial_index`.
proj_choices = dict(zip(closest_trials.index.names[:-1], trial_index[:-1]))
print(f"Projections: {proj_choices}")

**Load pre- / post-training results**

In [None]:
# Inference variant to load (see `InferenceConfig`).
inference_cfg = InferenceConfig(amplitude=0, subdir=None)
# Initial part of each recording excluded from the analysis window; no offset
# is applied for the "onsets" results.
offset = 0.0 if inference_cfg.subdir == "onsets" else 50.0

# Imageset
# NOTE(review): if `trial.config` returns a cached dict, the noise override
# below mutates the trial's config in place — confirm it returns a fresh copy.
cfg = trial.config
if inference_cfg.amplitude > 0:
    cfg["training"]["data"]["transforms"]["gaussiannoise"] = [inference_cfg.amplitude]
imageset, labels = utils.io.get_dataset(
    cfg["training"]["data"], return_annotations=True
)

# Recordings per state
for state, chkpt in state_chkpt_mapping.items():
    records[state] = load_inference_results(trial, chkpt, **dict(inference_cfg))

# Common parameters
duration: float = records["post"].item(0).duration - offset  # Observation period
reps = len(records["post"]["rep"])  # size of the `rep` dimension
input_shape = tuple(cfg["topology"]["poisson"]["EXC"])
layer_shape = tuple(cfg["topology"]["spatial"]["EXC"])

# Get synapse parameters
# Order matters: `sim` is never reset, so "pre" must be read before the "post"
# checkpoint is restored into it (for "pre", chkpt is None and nothing is restored).
sim = simulation.Simulator.from_config(cfg)
for state in ("pre", "post"):
    chkpt = state_chkpt_mapping[state]
    store_path = trial.checkpoints[chkpt].store_path if isinstance(chkpt, int) else None
    if store_path is not None:
        sim.restore(store_path)
    syn_params[state] = sim.network.get_syn_params()
# Projections to inspect: feedforward ("FF") plus every entry of `proj_choices`
# whose chosen value is > 0.
projs_plastic = tuple(["FF"] + [proj for proj, val in proj_choices.items() if val > 0])

### 2) Viz network dynamics

**Firing rate distributions**

In [None]:
# Poisson firing rates
# Layer 0 excitatory rates for image 0, repetition 0 of the post-training run.
state = "post"
# NOTE(review): `.rates[1]` is assumed to hold the per-neuron rate values — confirm.
rates = records[state].sel(img=0, rep=0, layer=0, nrn_cls="EXC").item().rates[1]

# Density histogram of the non-zero rates only (silent inputs excluded).
f, axes = plt.subplots(figsize=(4, 2))
axes.hist(rates[rates > 0], bins=60, density=True)
axes.set_xlabel("Firing rate (Hz)")
axes.set_ylabel(r"$f\;(r)$");

In [None]:
# Firing rates distributions for an image, rep
# `layer=slice(1, None)` skips layer 0 (the Poisson input layer).
state = "post"
img = 0
rep = 0

# NOTE(review): `xmax` has two entries — presumably one x-limit per neuron
# class; confirm against `viz.hist_rates`.
viz.hist_rates(
    records[state].sel(img=img, rep=rep, layer=slice(1, None)),
    bins=60,
    xmax=[200, 300],
    figsize=(6, 8),
    yticks=False,
    xlabel="Firing rate (Hz)",
);

In [None]:
# Topographic firing activity (via `viz.topographic_rates`) for layers 1 and up,
# one image and repetition of the post-training run.
state = "post"
img = 0
rep = 0

axes = viz.topographic_rates(
    records[state].sel(img=img, rep=rep, layer=slice(1, None)),
    plot_ticks=False,
    figsize=(4, 8),
);

### 3) Viz weight distributions

In [None]:
# Post-training synaptic weight histograms for the projections in `projs_plastic`.
state = "post"

# NOTE(review): the positional 40 is presumably the number of bins — confirm.
axes = viz.hist_weights(
    syn_params[state], projs_plastic, 40, annotations=True, figsize=(8, 6)
)
# Enlarge the y-axis labels of the first column of subplots.
for ax in axes[:, 0]:
    ax.set_ylabel(ax.get_ylabel(), size="x-large")

### 4) Boundary contour element selectivity

In [None]:
# Infer excitatory firing rates for layers 3 and 4, post- and pre-training,
# over the observation window defined by `duration` and `offset`.
# Selecting with a list keeps the `layer` dimension in both arrays.
layer = [3, 4]

rates_array = analysis.infer_rates(
    records["post"].sel(layer=layer, nrn_cls="EXC"), duration, offset
)
rates_array_pre = analysis.infer_rates(
    records["pre"].sel(layer=layer, nrn_cls="EXC"), duration, offset
)
print(f"Firing rates array: layer={layer}; duration={duration}; offset={offset}")

In [None]:
# Neurons analysed below: two L3 neurons (`pre_ids`) and one L4 neuron (`post_id`).
# NOTE(review): how these ids were selected is not shown in this notebook.
pre_ids = [1869, 2380]
post_id = 2176

pre_ids, post_id  # displayed as the cell output

**Subplot A**

Firing rate responses of the selected L4 neuron (top) and the two L3 neurons (middle, bottom), left-to-right:
- Concave
- Convex
- All

In [None]:
def plot_selectivity_panel(
    axes,
    rates_array,
    layer,
    nrn_id,
    labels,
    rates_array_pre=None,
    yticks=None,
    ylim=None,
):
    """Plot one neuron's contour-selectivity bars into a single Axes.

    Args:
        axes: Matplotlib Axes to draw into.
        rates_array: Firing rates with a ``layer`` dimension; `layer` is
            selected before plotting.
        layer: Layer of the neuron (also used in the panel title).
        nrn_id: Neuron id passed to `viz.plot_contour_selectivity`.
        labels: Image annotations passed to `viz.plot_contour_selectivity`.
        rates_array_pre: Optional comparison (pre-training) rates. If it has a
            ``layer`` dimension, the same `layer` is selected so that both
            arrays refer to the same neuron.
        yticks: Optional explicit y-tick positions.
        ylim: Optional y-limits; defaults to ``[0, None]`` (zero bottom,
            automatic top).
    """
    if rates_array_pre is not None and "layer" in rates_array_pre.dims:
        # Mirror the layer selection applied to `rates_array`; otherwise a
        # multi-layer comparison array would be plotted against one layer.
        rates_array_pre = rates_array_pre.sel(layer=layer)

    viz.plot_contour_selectivity(
        rates_array.sel(layer=layer),
        nrn_id,
        labels,
        rates_array_pre,
        bar_width=0.7,
        violinplot=False,
        show_xlabels=True,
        rotate_xticklabels=False,
        axes=axes,
    )
    axes.set_title(f"L{layer} neuron #{nrn_id}", fontweight="bold")
    axes.set_ylim(ylim if ylim is not None else [0, None])
    if yticks is not None:
        axes.set_yticks(yticks)

In [None]:
# Restrict the annotations to the columns used by the selectivity plot.
labels_ = labels[["image_id", "left"]].copy()

width = 2
height = 4.5

yticks = [0, 50, 100]
ylim = [0, 110]

# Panels top-to-bottom: the L4 neuron, then the two L3 neurons in `pre_ids`.
panel_specs = [(4, post_id), (3, pre_ids[0]), (3, pre_ids[1])]

f, axes = plt.subplots(len(panel_specs), 1, figsize=(width, height), sharex=True)
for ax, (panel_layer, panel_nrn) in zip(axes, panel_specs):
    plot_selectivity_panel(
        ax,
        rates_array,
        layer=panel_layer,
        nrn_id=panel_nrn,
        labels=labels_,
        rates_array_pre=rates_array_pre,
        yticks=yticks,
        ylim=ylim,
    )


def map_xlabel(xlabel):
    """Rewrite a "<a>: <b>" tick label in place as "<a> /" over "<b>"; return it."""
    text = " /\n".join(xlabel.get_text().split(": "))
    xlabel.set_text(text)
    return xlabel


# Shared x-axis: relabel the bottom panel's ticks onto two lines.
ax: plt.Axes = axes[-1]
ax.set_xticklabels([map_xlabel(lab) for lab in ax.get_xticklabels()])
axes[0].legend(loc="upper left")
f.tight_layout()

viz.save_figure(
    f,
    RESULTS_DIR / "fig_neuron_rates.pdf",
    overwrite=False,
)

### 5) Boundary contour element tracebacks

Demonstrate hierarchical feature integration

In [None]:
def subplot_traceback(axes, img, sensitivities_df, norm_max, sel_kwargs=None):
    """Draw a single traceback panel of a neuron's input sensitivities.

    Args:
        axes: Matplotlib Axes to draw into.
        img: Image index (a column of `sensitivities_df`) whose image is drawn
            with that column's non-NaN sensitivities. If None, the
            sensitivities are averaged across all columns (NaNs skipped) and
            drawn on a uniform grey (128) background of shape
            ``input_shape[1:]``.
        sensitivities_df: Sensitivity table with one column per image.
        norm_max: Forwarded to `viz.plot_traceback` as `sensitivities_max`.
        sel_kwargs: Marker kwargs forwarded to `viz.plot_traceback`; defaults
            to ``{"s": 5, "linewidths": 0}`` only when None, so an explicit
            empty dict is respected.

    Relies on the notebook globals `imageset`, `input_shape` and `layer_shape`.
    """
    if sel_kwargs is None:
        sel_kwargs = {"s": 5, "linewidths": 0}

    if img is None:
        sensitivities = sensitivities_df.mean(axis=1, skipna=True)
        image = np.full(input_shape[1:], 128.0)
    else:
        sensitivities = sensitivities_df[img].dropna()
        image = imageset[img]

    viz.plot_traceback(
        sensitivities,
        image,
        nrn_id=None,
        layer_shape=layer_shape,
        sensitivities_max=norm_max,
        vmin=0,
        vmax=255,
        color="red",
        sel_kwargs=sel_kwargs,
        axes=axes,
    )


# Example images for the positive / negative traceback panels below.
img_pos = 1
img_neg = 6

**Layer 4**

Left-convex, left-concave, average

In [None]:
# Input sensitivities of the selected L4 neuron, computed from the
# post-training recordings and synapse parameters.
layer = 4
nrn_id = post_id

# Runtime: <1 min
sensitivities_L4 = tasks.get_sensitivities(
    nrn_id, layer, duration, offset, records["post"], syn_params["post"], input_shape
)
# Maximum over all entries (NaNs skipped); shared by every L4 traceback panel
# so the panels use the same normalisation.
norm_max_L4 = sensitivities_L4.max(skipna=True).max()

In [None]:
width = 1.2
sel_kwargs = {"s": 5, "linewidths": 0}

# Panels (left-to-right): negative example image, positive example image,
# and the mean sensitivities over all images.
f, axes = plt.subplots(1, 3, figsize=(3 * width, width))
for ax, panel_img in zip(axes, (img_neg, img_pos, None)):
    subplot_traceback(
        ax, panel_img, sensitivities_L4, norm_max_L4, sel_kwargs=sel_kwargs
    )
f.tight_layout()

f.savefig(
    RESULTS_DIR / f"fig_traceback_n4p2_L4_{post_id}.pdf",
    dpi=300,
)

**Layer 3**

Left-convex, left-concave, average

In [None]:
# Input sensitivities of each selected L3 neuron, keyed by neuron id.
layer = 3
print(pre_ids)

# Runtime: ~1.5 min
sensitivities_L3 = {
    pre_id: tasks.get_sensitivities(
        pre_id,
        layer,
        duration,
        offset,
        records["post"],
        syn_params["post"],
        input_shape,
    )
    for pre_id in pre_ids
}
# One maximum per neuron (NaNs skipped): each L3 neuron's panels are
# normalised independently, not jointly across neurons.
norm_max_L3 = {
    pre_id: sensitivities_L3[pre_id].max(skipna=True).max() for pre_id in pre_ids
}

In [None]:
# Traceback figure for the first L3 neuron, pre_ids[0]
nrn_id = pre_ids[0]
print(layer, nrn_id)

width = 1.2
sel_kwargs = {"s": 5, "linewidths": 0}

# Panels (left-to-right): negative example image, positive example image,
# and the mean sensitivities over all images.
f, axes = plt.subplots(1, 3, figsize=(3 * width, width))
for ax, panel_img in zip(axes, (img_neg, img_pos, None)):
    subplot_traceback(
        ax,
        panel_img,
        sensitivities_L3[nrn_id],
        norm_max_L3[nrn_id],
        sel_kwargs=sel_kwargs,
    )
f.tight_layout()

f.savefig(
    RESULTS_DIR / f"fig_traceback_n4p2_L3_{nrn_id}.pdf",
    dpi=300,
)

In [None]:
# Traceback figure for the second L3 neuron, pre_ids[1]
nrn_id = pre_ids[1]
print(layer, nrn_id)

width = 1.2
sel_kwargs = {"s": 5, "linewidths": 0}

# Panels (left-to-right): negative example image, positive example image,
# and the mean sensitivities over all images.
f, axes = plt.subplots(1, 3, figsize=(3 * width, width))
for ax, panel_img in zip(axes, (img_neg, img_pos, None)):
    subplot_traceback(
        ax,
        panel_img,
        sensitivities_L3[nrn_id],
        norm_max_L3[nrn_id],
        sel_kwargs=sel_kwargs,
    )
f.tight_layout()

f.savefig(
    RESULTS_DIR / f"fig_traceback_n4p2_L3_{nrn_id}.pdf",
    dpi=300,
)