# Oscillation analysis: HFB onsets

Here we determine whether the activation times of PNGs are associated with background rhythmic activity of excitatory neurons.

**Dependencies:**

Spike recordings, PNG detection and significance testing:
- Note that recorded spike trains are non-deterministic: **results may differ from manuscript**
- Takes ~1 hr (AMD Ryzen 9 5900X, 64GB RAM) to run the entire workflow
- Same as the dependencies in `plot_fig19.ipynb`
```bash
./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --layers 4 --configfile config/workflow/config_onsets.yaml --chkpt -1 --subdir onsets --rule significance -v
```

**Methods**

Two integration methods are explored for determining the population activity: either $A^l(t)$ (global) or $A^l(t, \vec{p})$ at position $\vec{p}$ (localised).

**Procedure**

- Oscillations are calculated from the spatiotemporal recordings of EXC neurons in layer L3.
- A bin size / averaging window $\Delta T \in [1, 5]$ ms is selected.
- The localised activity $A^l(t, \vec{p}_k)$ is determined w.r.t. a PNG $k$ as follows:
    1. The activity is determined w.r.t. the same layer as the initial, low-level neuron $i$ in layer $l$.
    2. The weighted average of the number of spikes fired per neuron in $l$ over $\Delta T$ is taken at each time point $t \in [\Delta T, T]$, centered about $\vec{p}_i$.
    3. This is divided by $\Delta T$ to give the (instantaneous) population activity at each point in time $t$.
- This simplifies if the global activity is considered instead: there is no need to compute a separate activity for each PNG $k$, since $A^l(t)$ is shared by all PNGs starting in layer $l$.

**Plots**

- The localised population activity $A^l(t, \vec{p}_k)$ is plotted for a PNG $k$ starting in layer $l=3$, in response to two different shapes.

In [None]:
from collections.abc import Sequence
from pathlib import Path
from pprint import pprint
from typing import Iterable

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from tqdm.notebook import tqdm

import hsnn.analysis.png.db as polydb
from hsnn.analysis.png import postproc, stats
from hsnn import analysis, ops, simulation, utils, viz
from hsnn.core import SpikeRecord
from hsnn.utils import handler

# Output location for every figure produced by this notebook (created on demand)
OUTPUT_DIR = utils.io.BASE_DIR / "out/figures/fig20"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Shorthand for MultiIndex slicing with `.loc[...]`
pidx = pd.IndexSlice

# Colour cycle captured from rcParams *before* the journal style is applied below
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]

viz.setup_journal_env()

### 1) Select experiment

Load representative trial.

In [None]:
logdir = "n3p2/train_n3p2_lrate_0_04_181023"

expt = handler.ExperimentHandler(logdir)
dataset_name = Path(logdir).parent.name
print(f"Target dataset: {dataset_name}")

# Summary table for index -1 (presumably the final checkpoint) and the
# trials closest to it, used as representative samples
df = expt.get_summary(-1)
closest_idx = handler.get_closest_samples(df)
closest_trials = expt.index_to_dir[closest_idx]
# Display only (not stored): representative trials minus (0, 0), (0, 20), (20, 0)
closest_trials.drop([(0, 0), (0, 20), (20, 0)], axis=0)

### 2) Load data

Including the trial, spike records, HFB database and detected PNGs.

In [None]:
# Inspect TrialsDict for the chosen network type and analysis results
model_type = "ALL"
analysis_type = "onsets"
states = ("post",)

print(f"Selected network type: '{model_type}'; results: '{analysis_type}'\n")
trials_dict = expt.metadata.get_trials_dict(model_type)

# Resolve and list the trials registered under this analysis type
trial_names = trials_dict[analysis_type]
trials_to_analyse = [expt[name] for name in trial_names]
print("# Trials to analyse:")
pprint(trials_to_analyse)

In [None]:
trial_id = "TrainSNN_eb0d4_00031"
state = "post"
subdir = "onsets"
offset = 0.0
num_reps = 10

# Get representative Trial
trial = expt[trial_id]
print(trial)

# Spike records of the first `num_reps` repetitions for the selected state
result = handler.load_results(trial, state, subdir=subdir)[state].sel(
    rep=range(num_reps)
)
# Analysis duration: recording duration minus the initial `offset`
duration: float = result.item(0).duration - offset
# Sanity check tied to `num_reps` (not a hard-coded count) so changing the
# number of repetitions above keeps this check meaningful
assert len(result.rep) == num_reps
print(f"\nduration={duration}; offset={offset}; num_reps={num_reps}")

# Load the dataset images and their annotations (labels) from the training config
cfg = trial.config
imageset, labels = utils.io.get_dataset(
    cfg["training"]["data"], return_annotations=True
)

# Get relevant HFB database
database = handler.load_detections(trial, state, subdir=subdir)[state]
print(f"Loaded HFB database '{Path(database.path).relative_to(utils.io.BASE_DIR)}'")

### 3) Restore Network and get refined PNGs
For inspection of ground-truth axonal conduction delays, weights, etc.

In [None]:
sim = simulation.Simulator.from_config(cfg)
if state == "post":
    # Restore the network from the trial's last checkpoint
    ckpt_path = trial.checkpoints[-1].store_path
    sim.restore(ckpt_path)
    print(f"Restoring from checkpoint '{ckpt_path.relative_to(utils.io.BASE_DIR)}'")

# Synapse parameters restricted to the "FF" and "E2E" entries of the second
# index level, sorted by index
ff_e2e_rows = pidx[slice(None), ("FF", "E2E")]
syn_params: pd.DataFrame = sim.network.get_syn_params()
syn_params = syn_params.loc[ff_e2e_rows, :].sort_index()

# Detected PNGs, built from the HFB database and the selected synapse parameters
polygrps = polydb.get_polygrps(database, syn_params)

### 4) Get inference / detection metrics

Including the following:
- Single-neuron specific information
- PNG performance metrics

In [None]:
# Get single-neuron specific information measures for a target per side (convex)
target = 1
layers = [3, 4]  # last two layers; used for both rates and measures below

# Firing rates of EXC neurons across the last two layers
rates_array = analysis.infer_rates(
    result.sel(layer=layers, nrn_cls="EXC"), duration, offset
)

# Keyed by layer number (int), then by the string keys returned per side
specific_measures: dict[int, dict[str, pd.DataFrame]] = {}
for layer in tqdm(layers):
    specific_measures[layer] = analysis.get_specific_measures_side(
        rates_array.sel(layer=layer), labels, target=target
    )

# Get PNG performance metrics per side
occ_array = stats.get_occurrences_array(
    polygrps, num_reps, len(imageset), index=1, duration=duration, offset=offset
)
metrics_side = stats.get_metrics_side(occ_array, labels, target)

### 5) Inspect a PNG

In [None]:
side = "left"

# Ten best-scoring PNGs for the selected side (displayed as cell output)
ranked = metrics_side[side].sort_values(by="score", ascending=False)
ranked.head(10)

In [None]:
# Locate the PNG made of these (layer, neuron) members among the detected groups
png_id = polydb.find_matching_index([(3, 2889), (4, 1990), (4, 1927)], polygrps)
png_score = metrics_side[side].loc[png_id]["score"]
# Halt the notebook if the chosen PNG is not a high-scoring one
if png_score < 0.9:
    raise UserWarning(
        f"Selected PNG with ID {png_id} has a score below 0.9: {png_score:.3f}"
    )

polygrp = polygrps[png_id]
occ_array.sel(png=png_id).plot()
plt.show()

### 6) Population activity

- Calculate global activity at each time point.
- Select (`img`, `rep`) pairs for a focal PNG with a high F1-score

**Explore individual activity**

**Plot activity: positive and negative examples**

With / without a left-convex boundary element

In [None]:
def plot_image(
    img: int,
    imageset: Sequence,
    nrn_id: int | None = None,
    show_ticks: bool = False,
    axes: plt.Axes | None = None,
    **plot_kwargs,
) -> plt.Axes:
    """Show ``imageset[img]`` and optionally mark a neuron's location on it.

    Args:
        img: Index of the image within ``imageset``.
        imageset: Indexable collection of images (accessed as ``imageset[img]``).
        nrn_id: If given, a marker is drawn at this neuron's coordinates.
        show_ticks: If False, the x and y ticks are removed.
        axes: Axes to draw into; forwarded to ``viz.imshow_cbar``.
        **plot_kwargs: Overrides for the neuron marker style passed to
            ``Axes.plot`` (only used when ``nrn_id`` is given).

    Returns:
        The axes returned by ``viz.imshow_cbar``.
    """
    # No colour bar; colour scale capped at 255
    axes = viz.imshow_cbar(imageset[img], attach_cbar=False, cmax=255, axes=axes)
    if nrn_id is not None:
        # Default marker style, overridable via **plot_kwargs
        marker_kwargs = {
            "marker": "o",
            "markerfacecolor": colors[1],
            "color": "k",
            "markersize": 8,
            **plot_kwargs,
        }
        # NOTE(review): neuron coordinates are scaled by 2 to image pixels —
        # presumably the neuron grid is half the image resolution; confirm.
        coord = np.asarray(analysis.get_coords(nrn_id)) * 2
        axes.plot(*coord, **marker_kwargs)
    if not show_ticks:
        axes.set_xticks([])
        axes.set_yticks([])
    return axes

In [None]:
imgs_reps = [(2, 2), (4, 0)]  # (img, rep) pairs to compare
nrn_idx = 0  # index of the focal neuron within the PNG
# Window used for these plots (presumably ms). Kept under its own name so the
# recording `duration` computed earlier is not clobbered when cells are re-run.
plot_duration = 200.0
bin_size = 2
radius = 10
nrn_id = polygrp.nrns[nrn_idx]
layer = polygrp.layers[nrn_idx]

population = analysis.Population(plot_duration, bin_size=bin_size)
# Neurons within `radius` of the focal neuron, as returned by get_local_ids
nrn_ids = sorted(population.get_local_ids(nrn_id, radius))

recordings_dict = {}
activity_dict = {}
polygrp_trains_dict = {}
for img, rep in imgs_reps:
    key = (img, rep)
    record: SpikeRecord = result.sel(
        rep=rep, img=img, layer=layer, nrn_cls="EXC"
    ).item()
    # Spike events of the local neighbourhood only (for the raster)
    recordings_dict[key] = ops.select_nrns(record.spike_events, nrn_ids)
    # Localised population activity centred on the focal neuron
    activity_dict[key] = population.local_activity(
        record.spike_events, nrn_id, radius=radius
    )
    polygrp_trains_dict[key] = postproc.get_polygrp_trains(
        polygrp, img, rep, result, plot_duration
    )

In [None]:
ydelims = min(nrn_ids), max(nrn_ids)
xmax = 100
dx = 20
dy = 100
plot_kwargs = {"marker": "o"}

f, axes = plt.subplots(2, 2, figsize=(5.5, 3), sharex=True)

# Row 0: localised activity. Plot every trace first so that one common upper
# y-limit (the maximum over all panels) is applied to *all* panels; applying a
# running maximum inside the loop left earlier panels with a smaller range.
ax: plt.Axes
ymax = 0
for ax, key in zip(axes[0], imgs_reps):
    ax.plot(population.time_points, activity_dict[key], color=colors[0])
    ymax = max(ymax, ax.get_ylim()[-1])

for i, (ax, key) in enumerate(zip(axes[0], imgs_reps)):
    img = key[0]
    ax.set_ylim([0, ymax])
    # Inset: stimulus image with the focal neuron's position marked
    inset_ax = inset_axes(
        ax,
        width="40%",
        height="40%",
        bbox_to_anchor=(0.35, 0.25, 1, 1),
        bbox_transform=ax.transAxes,
        loc="center",
    )
    plot_image(img, imageset, nrn_id=nrn_id, axes=inset_ax, markersize=4)
    if i > 0:
        ax.set_yticklabels([])  # y-range is shared, so label the first panel only
    else:
        ax.set_ylabel("Activity [Hz]")
    # Dashed lines at the focal neuron's PNG spike times (presumably onsets)
    ax.vlines(
        polygrp_trains_dict[key][nrn_idx],
        0,
        ymax,
        colors=colors[1],
        linestyles="dashed",
    )

# Row 1: rasters of the local neighbourhood (grey) with the focal neuron's PNG
# spikes highlighted; y-range rounded out to multiples of `dy`
raster_ylim = ((ydelims[0] // dy) * dy, (ydelims[1] // dy + 1) * dy)
for i, (ax, key) in enumerate(zip(axes[1], imgs_reps)):
    viz.plot_raster(
        recordings_dict[key],
        xmax,
        markerfacecolor="gray",
        color="gray",
        markeredgewidth=0,
        markersize=2,
        axes=ax,
        **plot_kwargs,
    )
    viz.plot_raster(
        {nrn_id: polygrp_trains_dict[key][nrn_idx]},
        xmax,
        markerfacecolor=colors[1],
        color="k",
        markeredgewidth=2 / 3,
        axes=ax,
        markersize=4,
        **plot_kwargs,
    )
    ax.set_ylim(*raster_ylim)
    if i > 0:
        ax.set_yticklabels([])
        ax.set_ylabel("")
    else:
        # sharex=True: these x-limits/ticks apply to every panel
        ax.set_xlim([0, xmax])
        ax.set_xticks(np.arange(0, xmax + dx, dx))

# Save figure
filepath = OUTPUT_DIR / "fig_png_onset_activity_n3p2.pdf"
viz.save_figure(f, filepath, overwrite=False)