# Plot HFB rasters

Spike rasters of neuronal activity involved in three PNGs that structurally conform to three-neuron HFB circuits.

**Dependencies:**

Significance testing:
- PNG detection and significance testing for N3P2 (run after network training)
- **This workflow is time-consuming to run**
- Note that recorded spike trains and significance-tested PNGs are non-deterministic: **results may differ from manuscript**
```bash
./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --chkpt -1 --rule significance -v
```

**Plots:**

Spike rasters for three PNGs that are selective to left-, right- and top-convex feature elements.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import hsnn.analysis.png.db as polydb
from hsnn import utils, viz
from hsnn.analysis.png import postproc, stats, PNG
from hsnn.utils import handler, io

pidx = pd.IndexSlice  # shorthand for pandas MultiIndex slicing

# Every figure produced by this notebook is written below this directory,
# which is created on first use (parents included, no error if it exists).
OUTPUT_DIR = io.BASE_DIR / "out/figures/fig15"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)


def get_max_rep(polygrp: "PNG", img: int) -> int:
    """Return the repetition in which ``polygrp`` occurs most often for image ``img``.

    Args:
        polygrp: PNG whose ``imgs`` and ``reps`` arrays are aligned element-wise:
            entry ``k`` gives the image and repetition of the k-th occurrence.
        img: Image index to restrict the occurrences to.

    Returns:
        The repetition ID with the highest occurrence count, as a Python ``int``.
        Ties resolve to the smallest repetition ID, because ``np.unique`` returns
        sorted values and ``np.argmax`` picks the first maximum.

    Raises:
        ValueError: If ``polygrp`` has no occurrences for ``img``.
    """
    mask = polygrp.imgs == img
    rep_ids, counts = np.unique(polygrp.reps[mask], return_counts=True)
    if rep_ids.size == 0:
        # np.argmax would otherwise fail with an opaque "empty sequence" error
        raise ValueError(f"PNG has no occurrences for image {img}")
    return int(rep_ids[np.argmax(counts)])


# Plotting: keep a handle on matplotlib's current colour cycle and its colours
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]  # list of colour specs from the cycle

# NOTE(review): `colors` is captured *before* the journal environment is set up;
# if setup_journal_env() changes the colour cycle, `colors` will not reflect it.
viz.setup_journal_env()

### 1) Select experiment

Load the experiment and list its representative trials.

In [None]:
logdir = "n3p2/train_n3p2_lrate_0_04_181023"

# Experiment handler for the training run, plus the dataset name it belongs to
# (the parent directory of `logdir`).
expt = handler.ExperimentHandler(logdir)
dataset_name = Path(logdir).parent.name
print(f"Target dataset: {dataset_name}")

# Summary of the final checkpoint (-1) and the Trials the handler deems closest
# to representative; see handler.get_closest_samples for the exact criterion.
df = expt.get_summary(-1)
closest_sample_idx = handler.get_closest_samples(df)
closest_trials = expt.index_to_dir[closest_sample_idx]
# Displayed only (not reassigned): the closest Trials without three index
# entries. NOTE(review): confirm what (0, 0), (0, 20) and (20, 0) denote.
closest_trials.drop([(0, 0), (0, 20), (20, 0)], axis=0)

### 2) Load data

Load the Trial, its spike records, the HFB database and the detected PNGs.

In [None]:
trial_id = "TrainSNN_eb0d4_00031"
state = "post"
offset = 50.0
num_reps = 10

# Get representative Trial
trial = expt[trial_id]
print(trial)

# Get relevant spike records, keeping only the first `num_reps` repetitions
result = handler.load_results(trial, state)[state].sel(rep=range(num_reps))
# Length of the analysed window: duration of the first record minus `offset`.
# NOTE(review): assumes all records share the first record's duration.
duration = result.item(0).duration - offset
# Sanity check against `num_reps` (not a literal) so it tracks the setting above
assert len(result.rep) == num_reps, (
    f"Expected {num_reps} repetitions, got {len(result.rep)}"
)
print(f"\nduration={duration}; offset={offset}; num_reps={num_reps}")

# Load imageset and labels
imageset, labels = utils.io.get_dataset(
    trial.config["training"]["data"], return_annotations=True
)

# Get relevant HFB database; report its location before extracting PNGs from it
database = handler.load_detections(trial, state)[state]
print(f"Loading HFBs from '{database.path}'")

# Get detected PNGs
polygrps = polydb.get_polygrps(database)

### 2b) Compute F1 scores

Used to select which PNGs to plot.

In [None]:
# Occurrence array of all detected PNGs over `num_reps` repetitions of every
# image, restricted to the window defined by `offset` and `duration`.
occ_array = stats.get_occurrences_array(
    polygrps,
    num_reps,
    len(imageset),
    index=1,
    duration=duration,
    offset=offset,
)
# Per-side performance tables (keyed by side); each carries a "score" column,
# which this notebook reports as the F1-score.
metrics_side = stats.get_metrics_side(occ_array, labels, target=1)

### 3) PNG raster plots (three examples)

In [None]:
# Pick, for every side (left-, right- and top-convex), the PNG ID ranked first
# by descending "score" in that side's metrics table.
side_polygrp_id_mapping: dict[str, int] = {}
for side in metrics_side.keys():
    ranked_by_score = metrics_side[side].sort_values("score", ascending=False)
    side_polygrp_id_mapping[side] = ranked_by_score.index[0]
side_polygrp_id_mapping

# (Alternatively) select by known indices
# side_indices_mapping = {
#     "left": [(3, 2889), (4, 1990), (4, 1927)],
#     "right": [(3, 2623), (4, 2802), (4, 2737)],
#     "top": [(3, 1129), (4, 1526), (4, 1591)],
# }
# side_polygrp_id_mapping = {
#     side: polydb.find_matching_index(indices, polygrps) for side, indices in side_indices_mapping.items()
# }

In [None]:
# Report the F1-score of each selected left-, right- and top-convex selective PNG
for side, polygrp_id in side_polygrp_id_mapping.items():
    f1_score = metrics_side[side].loc[polygrp_id, "score"]
    side_label = side.capitalize()
    print(f"{side_label}-selective PNG (ID: {polygrp_id}): F1-score = {f1_score:.4f}")

In [None]:
def get_trains_plotting(
    polygrp: PNG,
    imgs: list[int],
    reps: list[int],
    result: xr.DataArray,
    offset: float,
    interval: float = 100,
) -> tuple[dict, dict]:
    """Build concatenated (null -> positive) spike trains for raster plotting.

    For each of the two (image, rep) pairs, extracts the full spike trains and
    the PNG-specific trains via ``postproc`` (with times relative to the window),
    then concatenates the null segment followed by the positive segment.

    Args:
        polygrp: PNG whose neurons are plotted.
        imgs: Exactly two image indices, ordered ``[null, positive]``.
        reps: Exactly two repetition indices, aligned with ``imgs``.
        result: Spike records passed through to ``postproc``.
        offset: Time offset passed through to ``postproc``.
        interval: Per-segment interval passed to ``postproc`` (presumably the
            window length in ms). Defaults to 100, the previously hard-coded
            value; it should match the interval used when plotting.

    Returns:
        ``(spike_trains, polygrp_trains)``: the concatenated full trains and the
        concatenated PNG-specific trains.

    Raises:
        ValueError: If ``imgs`` or ``reps`` does not contain exactly two entries.
    """
    # zip() would silently drop or skip entries, surfacing later as a KeyError
    if len(imgs) != 2 or len(reps) != 2:
        raise ValueError(
            f"Expected two (null, pos) images and reps; got imgs={imgs}, reps={reps}"
        )

    spike_trains_dict = {}
    polygrp_trains_dict = {}
    for key, img, rep in zip(("null", "pos"), imgs, reps):
        spike_trains_dict[key] = postproc.get_spike_trains(
            polygrp, img, rep, result, interval, offset, relative_times=True
        )
        polygrp_trains_dict[key] = postproc.get_polygrp_trains(
            polygrp, img, rep, result, interval, offset, relative_times=True
        )
    # Concatenate the spike trains, null -> pos
    spike_trains = postproc.concat_spike_trains(
        spike_trains_dict["null"], spike_trains_dict["pos"], interval
    )
    polygrp_trains = postproc.concat_spike_trains(
        polygrp_trains_dict["null"], polygrp_trains_dict["pos"], interval
    )
    return spike_trains, polygrp_trains


def get_img_rep_ids(
    polygrp, side: str, labels: pd.DataFrame
) -> tuple[list[int], list[int]]:
    """Select the (image, rep) pairs used as null and positive raster examples.

    The null image is the first image with no labelled feature. The positive
    image is the first image whose only labelled feature is ``side``.

    Args:
        polygrp: PNG with aligned ``imgs``/``reps`` occurrence arrays.
        side: Name of the feature column in ``labels`` (e.g. "left").
        labels: One row per image, with an "image_id" column plus one column
            per feature. NOTE(review): image IDs are taken from the row index
            (via ``idxmax``), not the "image_id" column — assumes they coincide.

    Returns:
        ``([img_null, img_pos], [rep_null, rep_pos])``. ``rep_null`` is always 0,
        and ``rep_pos`` is the repetition in which ``polygrp`` occurs most often
        for ``img_pos``.

    Raises:
        ValueError: If no null image, or no image labelled only with ``side``,
            exists.
        AssertionError: If ``polygrp`` occurs for the null image.
    """
    num_features = labels.drop("image_id", axis=1).sum(axis=1)
    null_mask = num_features == 0
    # Both comparisons need parentheses: `&` binds tighter than `==` in Python
    pos_mask = (num_features == 1) & (labels[side] == 1)
    # idxmax on an all-False mask would silently return the first image
    if not null_mask.any():
        raise ValueError("No image without any labelled feature (null image)")
    if not pos_mask.any():
        raise ValueError(f"No image labelled with '{side}' as its only feature")
    # idxmax on a boolean mask yields the index label of the first True entry
    img_null = null_mask.idxmax()
    img_pos = pos_mask.idxmax()

    assert len(polygrp.reps[(polygrp.imgs == img_null)]) == 0, (
        "Prefer no spiking for null image"
    )
    rep_null = 0
    rep_pos = get_max_rep(polygrp, img_pos)
    return ([img_null, img_pos], [rep_null, rep_pos])


def plot_raster_row(
    spike_trains: dict,
    polygrp_trains: dict,
    polygrp: PNG,
    ax: plt.Axes,
    xwidth: float,
    xmin: float,
    xlabel: str = "",
    interval: float = 100,
    **plot_kwargs,
):
    """Draw one raster row onto ``ax``.

    The full spike trains are drawn first with hollow gray markers. The
    PNG-specific spikes are then overlaid with markers filled black. A dashed
    vertical line is drawn at ``x = interval`` without changing the y-limits.
    The three y-ticks are labelled ``L<layer> #<neuron>`` from the PNG's
    ``layers``/``nrns``.

    Args:
        spike_trains: Full spike trains to draw as background.
        polygrp_trains: PNG-specific spike trains to overlay.
        polygrp: PNG providing ``layers`` and ``nrns`` for the tick labels.
        ax: Target axes.
        xwidth: Width of the x-range, forwarded to ``viz.plot_raster``.
        xmin: Start of the x-range, forwarded to ``viz.plot_raster``.
        xlabel: x-axis label, applied by the overlay call only.
        interval: x-position of the dashed divider (presumably the boundary
            between the concatenated null and positive segments).
        **plot_kwargs: Extra marker options forwarded to both raster calls.
    """
    shared = {"alpha": 1, "axes": ax}

    # Background layer: full trains, hollow gray markers, no x-label
    viz.plot_raster(
        spike_trains, xwidth, xmin,
        markerfacecolor="none", color="gray", xlabel="",
        **shared, **plot_kwargs,
    )
    # Overlay layer: PNG spikes, markers filled black
    viz.plot_raster(
        polygrp_trains, xwidth, xmin,
        markerfacecolor="black", xlabel=xlabel,
        **shared, **plot_kwargs,
    )

    # Divider spanning the current y-range; restore limits so vlines can't widen them
    y_lo, y_hi = ax.get_ylim()
    ax.vlines([interval], y_lo, y_hi, colors="k", linestyles="dashed")
    ax.set_ylim(y_lo, y_hi)

    ax.set_yticks([0, 1, 2])
    tick_labels = [f"L{layer} #{nrn}" for layer, nrn in zip(polygrp.layers, polygrp.nrns)]
    ax.set_yticklabels(tick_labels)


In [None]:
# Collect, per side (left-, right- and top-selective PNG), the full and
# PNG-specific concatenated trains plus the PNG itself; `sides` keeps the order.
spike_trains_side, polygrp_trains_side, polygrp_side = {}, {}, {}
sides = []
for side, png_id in side_polygrp_id_mapping.items():
    polygrp = polygrps[png_id]
    imgs, reps = get_img_rep_ids(polygrp, side, labels)
    trains_all, trains_png = get_trains_plotting(polygrp, imgs, reps, result, offset)
    spike_trains_side[side] = trains_all
    polygrp_trains_side[side] = trains_png
    polygrp_side[side] = polygrp
    sides.append(side)

In [None]:
# Plot the concatenated spike trains: entire sequences and PNG-specific ones.
# One row per selected side; the row count is derived from `sides` rather than
# hard-coded, so only the bottom row gets the time label whatever the count.
interval = 100  # per-segment length; must match the interval used to extract the trains
width = 6
row_height = 1.5  # 3 rows -> 4.5, the original figure height
height = row_height * len(sides)
tick_step = 50  # spacing of x-ticks [ms]

xmin = 0
xwidth = 2 * interval  # null segment followed by positive segment
plot_kwargs = {"markeredgewidth": 2 / 3, "marker": "o", "markersize": 3}

num_rows = len(sides)
# squeeze=False keeps `axes` indexable even for a single row; flatten to 1-D
f, axes = plt.subplots(num_rows, 1, figsize=(width, height), sharex=False, squeeze=False)
axes = axes[:, 0]
for i, side in enumerate(sides):
    xlabel = "Time [ms]" if i == num_rows - 1 else ""
    plot_raster_row(
        spike_trains_side[side],
        polygrp_trains_side[side],
        polygrp_side[side],
        axes[i],
        xwidth,
        xmin,
        xlabel=xlabel,
        interval=interval,
        **plot_kwargs,
    )
    axes[i].set_xticks(np.arange(xmin, xwidth + tick_step, tick_step))
f.tight_layout()
viz.save_figure(f, OUTPUT_DIR / "fig15_png_rasters.pdf")