# Plot feature sharing across multiple HFBs

Spike rasters of the neuronal activity underlying two PNGs.

**Dependencies:**

Significance testing:
- PNG detection and significance testing for N4P2 (run after network training)
- **This workflow is time-consuming to run**
- Note that recorded spike trains and significance-tested PNGs are non-deterministic: **results may differ from manuscript**
```bash
./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 15 --chkpt -1 --rule significance -v
```

**Plots:**

Spike rasters for two PNGs that are selective for a left-convex feature element, and which share the same high-level feature neuron.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from matplotlib.patches import Rectangle

import hsnn.analysis.png.db as polydb
from hsnn import ops, utils, viz
from hsnn.analysis.png import postproc, refinery, stats, PNG
from hsnn.utils import handler, io

pidx = pd.IndexSlice  # shorthand for pandas MultiIndex slicing
# Every figure produced by this notebook is written below this directory,
# which is created (with parents) if it does not exist yet.
OUTPUT_DIR = io.BASE_DIR / "out" / "figures" / "fig17"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)


def get_max_rep(polygrp: PNG, img: int) -> int:
    """Return the repetition in which ``polygrp`` activated most often on ``img``.

    Args:
        polygrp: PNG whose ``imgs`` and ``reps`` arrays give, per activation,
            the image and the repetition the activation occurred in.
        img: Image index to restrict the count to.

    Returns:
        The repetition ID with the highest activation count on ``img``. Ties
        resolve to the smallest repetition ID (``np.unique`` sorts its output
        and ``np.argmax`` returns the first maximum).

    Raises:
        ValueError: If ``polygrp`` never activates on ``img``.
    """
    mask = polygrp.imgs == img
    rep_ids, counts = np.unique(polygrp.reps[mask], return_counts=True)
    if rep_ids.size == 0:
        # np.argmax on an empty array would fail with an opaque message
        raise ValueError(f"PNG has no activations on image {img}")
    return rep_ids[np.argmax(counts)]


# Plotting
# Colour cycle as configured at this point, i.e. captured *before*
# `viz.setup_journal_env()` runs. `colors[1]` is used later for bounding boxes.
# NOTE(review): if setup_journal_env modifies rcParams["axes.prop_cycle"],
# `colors` still holds the pre-setup cycle — confirm this ordering is intended.
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]

viz.setup_journal_env()

### 1) Select experiment

Load a representative trial.

In [None]:
logdir = 'n4p2/train_n4p2_lrate_0_02_181023'
expt = handler.ExperimentHandler(logdir)

# The dataset name is taken from the experiment's parent directory.
dataset_name = Path(logdir).parent.name
print(f"Target dataset: {dataset_name}")

# Summary at checkpoint -1 (presumably the last one), then the trial
# directories chosen by `handler.get_closest_samples` as representative.
df = expt.get_summary(-1)
closest_idx = handler.get_closest_samples(df)
closest_trials = expt.index_to_dir[closest_idx]
# Display only (result not stored): representative trials minus these entries.
closest_trials.drop([(0, 0), (0, 20), (20, 0)], axis=0)

### 2) Load data

Includes the Trial, spike records, HFB database and detected PNGs.

In [None]:
trial_id = "TrainSNN_1fdbf_00015"
state = "post"
offset = 50.0  # offset into each recording; the analysed `duration` below excludes it
num_reps = 10

# Get representative Trial
trial = expt[trial_id]
print(trial)

# Get relevant spike records: keep only the first `num_reps` repetitions
result = handler.load_results(trial, state)[state].sel(rep=range(num_reps))
duration = result.item(0).duration - offset
# Compare against `num_reps` (not a literal) so the check cannot drift from the selection
assert len(result.rep) == num_reps, (
    f"Expected {num_reps} repetitions, found {len(result.rep)}"
)
print(f"\nduration={duration}; offset={offset}; num_reps={num_reps}")

# Load imageset and labels
imageset, labels = utils.io.get_dataset(
    trial.config["training"]["data"], return_annotations=True
)

# Get relevant HFB database
database = handler.load_detections(trial, state)[state]

# Get detected PNGs
polygrps = polydb.get_polygrps(database)
print(f"Loaded HFBs from '{Path(database.path).relative_to(expt.logdir)}': {len(polygrps)} PNGs detected.")

### 3) Get F1 scores

Used to identify which HFBs to plot.

In [None]:
# Get performance metrics per side
occ_array = stats.get_occurrences_array(
    polygrps, num_reps, len(imageset), index=1, duration=duration, offset=offset
)
metrics_side = stats.get_metrics_side(occ_array, labels, target=1)

### 4) PNG raster plot

In [None]:
def get_trains_plotting(
    polygrp: PNG,
    imgs: list[int],
    reps: list[int],
    result: xr.DataArray,
    offset: float,
    interval: float = 100,
) -> tuple[dict, dict]:
    """Get spike trains for a null and a positive (image, rep), concatenated.

    For each of the two presentations, fetches both the full spike trains of
    the PNG's neurons and the PNG-specific (activation) trains, then joins the
    null and positive segments end-to-end (null first).

    Args:
        polygrp: PNG whose neurons' trains are extracted.
        imgs: ``[img_null, img_pos]``.
        reps: ``[rep_null, rep_pos]``, paired element-wise with ``imgs``.
        result: Spike records, passed through to the ``postproc`` extractors.
        offset: Offset passed through to the ``postproc`` extractors.
        interval: Interval passed to the extractors and to
            ``postproc.concat_spike_trains``; defaults to 100, the value that
            was previously hard-coded.

    Returns:
        ``(spike_trains, polygrp_trains)``: the null -> pos concatenation of the
        full trains and of the PNG-specific trains, respectively.

    Raises:
        ValueError: If ``imgs`` or ``reps`` does not hold exactly two entries
            (``zip`` would otherwise silently drop or miss a case).
    """
    if len(imgs) != 2 or len(reps) != 2:
        raise ValueError(
            f"Expected exactly two (img, rep) pairs (null, pos); got imgs={imgs}, reps={reps}"
        )
    spike_trains_dict = {}
    polygrp_trains_dict = {}
    for key, img, rep in zip(("null", "pos"), imgs, reps):
        spike_trains_dict[key] = postproc.get_spike_trains(
            polygrp, img, rep, result, interval, offset, relative_times=True
        )
        polygrp_trains_dict[key] = postproc.get_polygrp_trains(
            polygrp, img, rep, result, interval, offset, relative_times=True
        )
    # Concatenate the spike trains, null -> pos
    spike_trains = postproc.concat_spike_trains(
        spike_trains_dict["null"], spike_trains_dict["pos"], interval
    )
    polygrp_trains = postproc.concat_spike_trains(
        polygrp_trains_dict["null"], polygrp_trains_dict["pos"], interval
    )
    return spike_trains, polygrp_trains


def get_img_rep_ids(
    polygrp, side: str, labels: pd.DataFrame
) -> tuple[list[int], list[int]]:
    """Get the null and positive images, plus a repetition for each.

    The null image is the first image with no annotated features. The positive
    image is the first image annotated with exactly one feature, namely ``side``.

    Args:
        polygrp: PNG whose ``imgs``/``reps`` arrays record its activations.
        side: Label column naming the feature of interest.
        labels: Per-image annotations. It must contain an ``image_id`` column;
            every other column is summed as a feature indicator.

    Returns:
        ``([img_null, img_pos], [rep_null, rep_pos])``. ``rep_null`` is 0 and
        ``rep_pos`` is the repetition in which ``polygrp`` activated most often
        on ``img_pos``.

    Raises:
        ValueError: If no null image or no positive image exists
            (``idxmax`` on an all-False mask would silently return the first row).
        AssertionError: If ``polygrp`` activates on the null image.
    """
    num_features = labels.drop("image_id", axis=1).sum(axis=1)
    is_null = num_features == 0
    # Both comparisons parenthesised: `&` binds tighter than `==` in Python.
    is_pos = (num_features == 1) & (labels[side] == 1)
    if not is_null.any():
        raise ValueError("No image without annotated features found in labels")
    if not is_pos.any():
        raise ValueError(f"No image annotated solely with '{side}' found in labels")
    img_null = is_null.idxmax()
    img_pos = is_pos.idxmax()

    assert len(polygrp.reps[(polygrp.imgs == img_null)]) == 0, (
        "Prefer no spiking for null image"
    )
    rep_null = 0
    rep_pos = get_max_rep(polygrp, img_pos)
    return ([img_null, img_pos], [rep_null, rep_pos])


def plot_raster_row(spike_trains: dict, polygrp_trains: dict, polygrp: PNG,
                    ax: plt.Axes, xwidth: float, xmin: float, xlabel: str = '', **plot_kwargs):
    """Overlay a PNG's activation spikes on its neurons' spike trains.

    ``spike_trains`` is drawn with hollow grey markers. ``polygrp_trains`` is
    then drawn over it with black-faced markers. The y-axis gets one tick per
    neuron in ``polygrp``, labelled ``L<layer> #<neuron>``.

    Args:
        spike_trains: Full spike trains, passed to ``viz.plot_raster``.
        polygrp_trains: PNG-specific spike trains, passed to ``viz.plot_raster``.
        polygrp: PNG providing the ``layers``/``nrns`` used for tick labels.
        ax: Axes to draw on.
        xwidth: Width of the x-range, passed to ``viz.plot_raster``.
        xmin: Start of the x-range, passed to ``viz.plot_raster``.
        xlabel: X-axis label; only applied by the second (overlay) call.
        **plot_kwargs: Extra marker options forwarded to both raster calls.
    """
    viz.plot_raster(spike_trains, xwidth, xmin, alpha=1, axes=ax,
                    markerfacecolor='none', color='gray', xlabel='', **plot_kwargs)
    viz.plot_raster(polygrp_trains, xwidth, xmin, alpha=1, xlabel=xlabel, axes=ax,
                    markerfacecolor='black', **plot_kwargs)
    tick_labels = [f'L{layer} #{nrn}' for layer, nrn in zip(polygrp.layers, polygrp.nrns)]
    # Derive ticks from the PNG size (previously hard-coded to [0, 1, 2]) so the
    # number of ticks always matches the number of labels.
    ax.set_yticks(range(len(tick_labels)))
    ax.set_yticklabels(tick_labels)


### 5) Plot combined HFBs

In [None]:
def get_aligned_polygrps(query: PNG, position: int, database, tol: float = 3.0) -> list[PNG]:
    """Return PNGs whose neuron at ``position`` is temporally aligned with ``query``.

    The reference times are the onsets of ``query``'s neuron at ``position``
    (each activation time plus that neuron's lag). The lookup itself is
    delegated to ``polydb.query_aligned_pngs`` with tolerance ``tol``.
    """
    onsets = query.lags[position] + query.times
    target_times = list(map(float, onsets))
    focal_nrn = int(query.nrns[position])
    focal_layer = int(query.layers[position])
    return polydb.query_aligned_pngs(
        focal_nrn, focal_layer, position, target_times, database, tol
    )

In [None]:
# Get PNGs temporally-aligned with query neuron
refine = refinery.FilterIndex((4, 2176), position=1)

try:
    polygrp_sel = refine(polygrps)[0]
except IndexError:
    raise UserWarning("No PNGs found for the given query - select a different query PNG")

# Get all PNGs with timing(s) that are aligned with a focal neuron in a query PNG
polygrps_aligned = get_aligned_polygrps(polygrp_sel, 1, database)
if not len(polygrps_aligned) > 1:
    raise UserWarning("No aligned PNGs found for the given query - select a different query PNG")

# Select first two temporally-aligned PNGs
polygrps_sel = polygrps_aligned[:2]
polygrps_sel

In [None]:
# Get unique [(img, rep), ...] where there are one or more coincident activations
high1_onsets = polygrps_sel[0].times + polygrps_sel[0].lags[1]
high2_onsets = polygrps_sel[1].times + polygrps_sel[1].lags[1]
coincidence_mask = np.isclose(np.abs(high1_onsets.reshape(-1, 1) - high2_onsets).min(axis=1), 0)
eligible_idxs = np.flatnonzero(coincidence_mask)
if len(eligible_idxs) == 0:
    raise UserWarning("No coincident activations found between the two PNGs - select different PNGs.")

imgs_reps_coincident = sorted(set(zip(polygrps_sel[0].imgs[eligible_idxs], polygrps_sel[0].reps[eligible_idxs])))
imgs_reps_coincident

In [None]:
# Get spike patterns, activation patterns, for select PNGs on (img, rep)
img_id, rep_id = imgs_reps_coincident[0]
# img_id, rep_id = 7, 9
print(f"Selected (img, rep) for plotting: ({img_id}, {rep_id})")

spike_patterns = {}
for i, polygrp in enumerate(polygrps_sel):
    spike_patterns[i] = postproc.get_spike_trains(
        polygrp, img_id, rep_id, result, duration, offset, True
    )

polygrp_patterns = {}
for i, polygrp in enumerate(polygrps_sel):
    polygrp_patterns[i] = postproc.get_polygrp_trains(
        polygrp, img_id, rep_id, result, duration, offset, True
    )
polygrp_patterns

In [None]:
# Find coincident activation times of neuron 1 (key=1) between the two PNGs
times_0 = polygrp_patterns[0][1]
times_1 = polygrp_patterns[1][1]

# Match times where neuron 1 fires at the same time in both PNGs
coincident_times = []
for t0 in times_0:
    diffs = np.abs(times_1 - t0)
    idx = np.argmin(diffs)
    if diffs[idx] < 1.0:
        coincident_times.append(t0)

coincident_times = np.array(coincident_times)
print(f"Coincident activation times of neuron 1: {coincident_times}")

In [None]:
# Build bounding boxes for each coincident activation, matched to PNG temporal span
dy = 0.08

bb_coords: list[list[dict]] = []  # bb_coords[i_png][j_activation]
for i, polygrp in enumerate(polygrps_sel[:2]):
    bbs = []
    times_i = polygrp_patterns[i][1]
    for t_coinc in coincident_times:
        # Find the activation index in this PNG closest to the coincident time
        idx = np.argmin(np.abs(times_i - t_coinc))
        # Get spike times for all 3 neurons at this activation index
        nrn_times = [polygrp_patterns[i][n][idx] for n in range(3)]
        t_min = min(nrn_times)
        t_max = max(nrn_times)
        width = t_max - t_min
        bbs.append({'xy': (t_min - 2, 0 - dy), 'width': width + 4, 'height': 2 + 2 * dy})
    bb_coords.append(bbs)

In [None]:
f, axes = plt.subplots(2, 1, figsize=(5.5, 3))
axes = np.atleast_1d(axes)

xmin = 0
xwidth = duration
plot_kwargs = {
    'markeredgewidth': 2/3, 'marker': 'o', 'markersize': 3
}

focal_color = 'C2'  # colour of the shared focal neuron (PNG position 1)
for i, polygrp in enumerate(polygrps_sel[:len(axes)]):
    polygrp_id = polydb.find_idx(polygrp, polygrps)

    ax = axes[i]
    # Only the bottom-most axis carries the time label
    xlabel = 'Time [ms]' if i == len(axes) - 1 else ''
    # Focal neuron (position 1): hollow markers for all spikes, filled for PNG spikes
    viz.plot_raster(
        ops.select_nrns(spike_patterns[i], [1]), xwidth, xmin, alpha=1, axes=ax,
        markerfacecolor='none', color=focal_color, xlabel='', **plot_kwargs
    )
    viz.plot_raster(
        ops.select_nrns(polygrp_patterns[i], [1]), xwidth, xmin, alpha=1, axes=ax,
        markerfacecolor=focal_color, color=focal_color, xlabel='', **plot_kwargs
    )
    # Remaining neurons in the default grey/black style, plus y-tick labels
    plot_raster_row(
        ops.select_nrns(spike_patterns[i], [0, 2]),
        ops.select_nrns(polygrp_patterns[i], [0, 2]),
        polygrp, ax, xwidth, xmin,
        xlabel=xlabel, **plot_kwargs
    )
    # Highlight the focal neuron's tick label
    for ytick, label in zip(ax.get_yticks(), ax.get_yticklabels()):
        if ytick == 1:
            label.set_color(focal_color)
    # Draw bounding boxes for all coincident activations
    for bb in bb_coords[i]:
        rect = Rectangle(**bb, linewidth=1, edgecolor=colors[1], facecolor='none')
        ax.add_patch(rect)
    ax.set_xticks(np.arange(xmin, xwidth + 50, 50))
    # Attach PNG ID just right of the last bounding box (its right edge is x + width)
    if bb_coords[i]:
        last_bb = bb_coords[i][-1]
        x_right = last_bb['xy'][0] + last_bb['width']
        ax.text(x_right + 4, 1.2, f'PNG ID: #{polygrp_id}', fontsize='x-small', ha='left')

f.tight_layout()
viz.save_figure(f, OUTPUT_DIR / "fig17_png_rasters.pdf")