# Plot onset timings and regression

PNG onset timing, timing precision, and feature selectivity for three-sided shapes (N3P2).

**Dependencies:**

Spike recordings, PNG detection and significance testing:
- Note that the recorded spike trains and the PNG significance testing are non-deterministic: **results may differ from manuscript**
- Takes ~1 hr (AMD Ryzen 9 5900X, 64GB RAM) to run the entire workflow
```bash
./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --layers 4 --configfile config/workflow/config_onsets.yaml --chkpt -1 --subdir onsets --rule significance -v
```

**Plots:**

PNGs that are selective to left-, right-, and top-convex feature elements:
- Empirical distribution of PNG onset times
- Timing dispersion (standard deviation) versus mean onset time computed across repetitions for each PNG-side pairing
- Mean onset time versus convex-boundary selectivity (F1 score) for each PNG-side pairing

In [None]:
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import scipy.stats as spstats
from scipy.stats import linregress
from tqdm.notebook import tqdm

import hsnn.analysis.png.db as polydb
from hsnn import analysis, simulation, utils, viz
from hsnn.analysis.png import stats
from hsnn.utils import handler

# Shorthand for building label-based MultiIndex slices used with `.loc`
pidx = pd.IndexSlice
# Directory that receives the figures and metadata written by this notebook;
# created (with parents) if it does not exist yet.
OUTPUT_DIR = utils.io.BASE_DIR / "out/figures/fig19"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


# Plotting
# Colours of matplotlib's currently configured property cycle
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]

# NOTE(review): presumably applies the project's journal plotting defaults — see `hsnn.viz`
viz.setup_journal_env()

### 1) Select experiment

Load representative trial.

In [None]:
# Experiment log directory; its parent path component is used as the dataset name
logdir = "n3p2/train_n3p2_lrate_0_04_181023"

expt = handler.ExperimentHandler(logdir)
dataset_name = Path(logdir).parent.name
print(f"Target dataset: {dataset_name}")
# Get relevant, representative Trials
df = expt.get_summary(-1)
closest_trials = expt.index_to_dir[handler.get_closest_samples(df)]
# NOTE: the result of `drop` is not assigned, it is only shown as the cell output;
# presumably `closest_trials` is a pandas object and therefore stays unchanged.
closest_trials.drop([(0, 0), (0, 20), (20, 0)], axis=0)

### 2) Load data

Including Trial, spike records, HFB DB, detected PNGs

In [None]:
# Analysis selection: network type, result family and network states
model_type, analysis_type = "ALL", "onsets"
states = ("post",)

print(f"Selected network type: '{model_type}'; results: '{analysis_type}'\n")

# The offset is 0.0 for "onsets" results and 50.0 for every other result family
if analysis_type == "onsets":
    offset = 0.0
else:
    offset = 50.0

# Trials registered for the selected network type, keyed by result family
trials_dict = expt.metadata.get_trials_dict(model_type)
trial_names = trials_dict[analysis_type]

# Show the trials that belong to the selected result family
print("# Trials to analyse:")
pprint([expt[trial_name] for trial_name in trial_names])

In [None]:
# Trial, network state, results sub-directory and number of repetitions to analyse
trial_id = "TrainSNN_eb0d4_00031"
state = "post"
subdir = "onsets"
num_reps = 10

# Get representative Trial
trial = expt[trial_id]
print(trial)

# Get relevant spike records, restricted to the first `num_reps` repetitions
result = handler.load_results(trial, state, subdir=subdir)[state].sel(
    rep=range(num_reps)
)
# Analysed duration excludes the initial `offset`
duration = result.item(0).duration - offset
# Compare against `num_reps` (not a literal) so the check stays valid if the
# repetition count above is changed.
assert len(result.rep) == num_reps
print(f"\nduration={duration}; offset={offset}; num_reps={num_reps}")

# Load imageset and labels
cfg = trial.config
imageset, labels = utils.io.get_dataset(
    cfg["training"]["data"], return_annotations=True
)

# Get relevant HFB database
database = handler.load_detections(trial, state, subdir=subdir)[state]
print(f"Loaded HFB database '{database.path}'")

### 3) Restore Network and get refined PNGs
For inspection of ground-truth axonal conduction delays, weights, etc.

In [None]:
# Build the simulator from the trial configuration
sim = simulation.Simulator.from_config(cfg)
if state == "post":
    # Restore from the trial's last checkpoint
    sim.restore(trial.checkpoints[-1].store_path)
    print(f"Restoring from checkpoint '{trial.checkpoints[-1].store_path}'")
syn_params: pd.DataFrame = sim.network.get_syn_params()
# Keep only rows whose second index level is "FF" or "E2E", then sort by index
syn_params = syn_params.loc[pidx[slice(None), ("FF", "E2E")], :].sort_index(
    inplace=False
)

# Get detected PNGs (final layer)
polygrps = polydb.get_polygrps(database, syn_params)

### 4) Get inference / detection metrics

Including the following:
- Single-neuron specific information
- PNG performance metrics

In [None]:
# Get single-neuron specific information measures for a target per side (convex)
target = 1

# Firing rates of EXC neuron across last two layers
rates_array = analysis.infer_rates(
    result.sel(layer=[3, 4], nrn_cls="EXC"), duration, offset
)

# Outer keys are the integer layer indices iterated below (3 and 4)
specific_measures: dict[
    int, dict[str, pd.DataFrame]
] = {}  # layer -> side -> nrn measures
for layer in tqdm([3, 4]):
    specific_measures[layer] = analysis.get_specific_measures_side(
        rates_array.sel(layer=layer), labels, target=target
    )

# Get PNG performance metrics per side
occ_array = stats.get_occurrences_array(
    polygrps, num_reps, len(imageset), index=1, duration=duration, offset=offset
)
metrics_side = stats.get_metrics_side(occ_array, labels, target)  # side -> PNG metrics

### 5) Onsets analysis

**Gather onsets across PNGs with recorded scores per side**

- Gather recordings across all side-specific measures
- Drop all PNGs without an F1 score for any side
- Get flattened onsets across all retained PNGs

In [None]:
def get_scored_onsets(
    onsets_array: xr.DataArray, metrics: pd.DataFrame, img: int | None = None
) -> xr.DataArray:
    """Restrict an onsets array to the PNGs that have an entry in `metrics`.

    Args:
        onsets_array: Onsets with a "png" coordinate (and an "img" coordinate
            when `img` is given).
        metrics: Table whose index holds the PNG identifiers to retain.
        img: Optional image to select first; when None no image selection is made.

    Returns:
        The (optionally image-selected) array containing only PNGs found in
        `metrics.index`.
    """
    selected = onsets_array
    if img is not None:
        selected = onsets_array.sel(img=img)

    png_ids = selected["png"].values
    is_scored = np.isin(png_ids, metrics.index)
    return selected.sel(png=png_ids[is_scored])


def get_onset_statistics(
    onsets_series: pd.Series, agg: list[str], min_reps: int = 3
) -> pd.DataFrame:
    """Aggregate onset times over repetitions for each (side, png) pairing.

    Pairings with fewer than `min_reps` entries are discarded before aggregating.

    Args:
        onsets_series: Onset times indexed by the levels "side", "png" and "rep".
        agg: Aggregation names forwarded to `GroupBy.agg`, e.g. ["mean", "std"].
        min_reps: Minimum number of entries a pairing needs to be retained.

    Returns:
        One row per retained (side, png) pairing and one column per entry of `agg`.

    Raises:
        AssertionError: If `onsets_series` is not a Series or its index levels
            are not exactly {"side", "png", "rep"}.
    """
    # Check the type first: the index check below already relies on Series attributes.
    assert isinstance(onsets_series, pd.Series), "onsets_series must be a pandas Series"
    assert set(onsets_series.index.names) == {"side", "png", "rep"}, (
        "onsets_series must be indexed by the levels 'side', 'png' and 'rep'"
    )

    group_keys = ["side", "png"]
    retained = onsets_series.groupby(group_keys).filter(
        lambda group: len(group) >= min_reps
    )
    return retained.groupby(group_keys).agg(agg)


# Side names are all label columns other than the image identifier
sides = labels.drop(columns="image_id").columns.tolist()

In [None]:
# Onsets of all detected PNGs
# NOTE(review): meaning of `null_img=5` is defined by `stats.get_onsets_array` — confirm
onsets_array = stats.get_onsets_array(polygrps, num_reps, len(imageset), null_img=5)

# side -> table with one row per recorded onset and the PNG's score for that side
onsets_side = {}
for side in sides:
    # Onsets of PNGs that have metrics for this side, for the image returned by
    # `utils.get_unique_id(labels, side)`
    onsets_array_sel = get_scored_onsets(
        onsets_array, metrics_side[side], img=utils.get_unique_id(labels, side)
    )
    # Flatten (png, rep) into a single "sample" dimension and drop missing onsets
    onsets_series = (
        onsets_array_sel.stack(sample=("png", "rep")).dropna("sample").to_pandas()
    )
    onsets_series.name = "onset"

    # Left join: attach each PNG's score for this side, matching on "png"
    onsets_side[side] = pd.merge(
        onsets_series,
        metrics_side[side]["score"],
        how="left",
        left_on="png",
        right_index=True,
    )

In [None]:
# 21 edges -> 20 equal-width score bins spanning [0, 1]
bins = np.linspace(0, 1, 21)

# Stack the per-side tables; the dict keys become the outer index level "side"
onsets_concat: pd.DataFrame = pd.concat(onsets_side, names=["side"])
# Flat copy without the index, used for the pooled histogram
onsets_df = onsets_concat.reset_index(drop=True)
# NOTE(review): with `include_lowest=False` a score of exactly 0 falls outside the
# first bin and gets a NaN `score_bin` — confirm this is intended.
onsets_df["score_bin"] = pd.cut(onsets_df["score"], bins=bins, include_lowest=False)

**i) Plot histogram and mean vs stdev**

- Show distribution of individual onset times: peak at short times
- Get mean w.r.t. each individual PNG across all reps (reps > occ_thr) and stdev

In [None]:
def plot_av_stdev(
    onset_means: np.ndarray, onset_stdevs: np.ndarray, axes: plt.Axes
) -> plt.Axes:
    """Scatter onset dispersion against mean onset on the given axes.

    Args:
        onset_means: Mean onset per point, in ms (horizontal axis).
        onset_stdevs: Onset dispersion per point, in ms (vertical axis).
        axes: Axes to draw on.

    Returns:
        The same `axes`, with fixed ticks, limits, labels and a grid applied.
    """
    x_ticks = np.arange(0, 250, 50)
    y_ticks = np.arange(0, 150, 50)

    axes.scatter(onset_means, onset_stdevs, s=3)

    # Keep the grid behind the markers
    axes.set_axisbelow(True)
    axes.grid()

    # Horizontal axis: mean onset
    axes.set_xticks(x_ticks)
    axes.set_xlim([0, 200])
    axes.set_xlabel("Mean onset [ms]")

    # Vertical axis: dispersion
    axes.set_yticks(y_ticks)
    axes.set_ylim([0, 100])
    axes.set_ylabel("Dispersion [ms]")

    return axes

In [None]:
# Get mean and stdev of onsets per PNG-side paring, for PNGs with at least 3 reps
onset_png_stats = get_onset_statistics(onsets_concat["onset"], ["mean", "std"])
onset_png_stats.head(10)

In [None]:
# Plot distribution of onsets across PNGs and mean vs stdev of onsets per PNG
num_bins = 60
ymax = 0.04
dy = 0.02

f, axes = plt.subplots(1, 2, figsize=(5.5, 5.5 / 3))

ax: plt.Axes = axes[0]
ax.hist(
    onsets_df["onset"].values, num_bins, density=True, edgecolor="C0", linewidth=0.5
)
ax.set_axisbelow(True)
ax.grid()
ax.set_xlim([0, 200])
ax.set_ylim([0, ymax])
ax.set_yticks(np.arange(0, ymax + dy, dy))
ax.set_ylabel("Frequency")
ax.set_xlabel("Onset [ms]")

ax = plot_av_stdev(onset_png_stats["mean"], onset_png_stats["std"], axes=axes[-1])

f.tight_layout()

filedir = OUTPUT_DIR / f"fig_png_onset_distr_n3p2_{model_type}.pdf"
viz.save_figure(f, filedir, overwrite=False)

In [None]:
# Panel A) stats
counts, bin_edges = np.histogram(onsets_df["onset"], num_bins)

arg_max = counts.argmax()
print(f"Mode: {np.mean(bin_edges[arg_max : arg_max + 2]):.1f} ms")

cpr = counts.cumsum() / counts.cumsum()[-1]
arg_mid = np.abs(0.5 - cpr).argmin()
print(f"Median: {bin_edges[arg_mid]:.1f} ms")

In [None]:
# Panel B) select stats
mask = onset_png_stats["mean"].between(0, 30)
onset_png_stats["mean"][mask].mean(), onset_png_stats["mean"][mask].std()
print(
    f"Mean of mean onsets between 0 and 30 ms: {onset_png_stats['mean'][mask].mean():.1f} ms"
)
print(
    f"Standard deviation of mean onsets between 0 and 30 ms: {onset_png_stats['mean'][mask].std():.1f} ms"
)

**ii) Plot mean onset vs. F1 score**

- Demonstrate significant trend: earlier onset time with higher F1 score
- Do linear regression
- Ensure data points are aligned with the above scatter plot

In [None]:
# Get onset vs score per PNG-side pairing (for PNGs with at least 3 reps)
mask = onsets_concat.index.droplevel("rep").isin(onset_png_stats.index)
png_shape_onset_score = (
    onsets_concat.loc[mask]
    .groupby(["side", "png"])
    .agg(onset=("onset", "mean"), score=("score", "mean"))
)
png_shape_df = png_shape_onset_score.reset_index(drop=True)

In [None]:
# Regress against individual data points
xs = png_shape_df["score"].values
ys = png_shape_df["onset"].values

slope, intercept, r_value, p_value, std_err = linregress(xs, ys)
xs_est = np.array([0, 1])
ys_est = slope * xs_est + intercept

print(
    f"Gradient: {slope:.1f}; intercept: {intercept:.1f}; r^2: {r_value**2:.3f}; "
    f"p_value: {p_value:0.0e}; std_err: {std_err:.1f}; num pts: {len(xs)}"
)

In [None]:
# Figure: mean onset versus F1 score with the fitted line and its 95% CI band
width = 5.5

axes: plt.Axes
f, axes = plt.subplots(figsize=(width, width * 1 / 2))
axes.scatter(
    png_shape_df["score"].values,
    png_shape_df["onset"].values,
    s=8,
    alpha=0.6,
    color="C0",
)

# Regression line and 95% CI band
xs_line = np.linspace(0, 1.0, 200)
ys_line = slope * xs_line + intercept
n = len(xs)
xbar = xs.mean()
ssxx = np.sum((xs - xbar) ** 2)
residuals = ys - (slope * xs + intercept)
# Residual standard error with n - 2 degrees of freedom
syx = np.sqrt(np.sum(residuals**2) / (n - 2))
# Two-sided 95% critical value of Student's t
t_crit = spstats.t.ppf(0.975, df=n - 2)
# Standard error of the fitted mean response at each x of the line
se_mean = syx * np.sqrt(1 / n + (xs_line - xbar) ** 2 / ssxx)
ci_lower = ys_line - t_crit * se_mean
ci_upper = ys_line + t_crit * se_mean

axes.plot(xs_line, ys_line, "C1-", linewidth=1.5)
axes.fill_between(xs_line, ci_lower, ci_upper, color="C1", alpha=0.2, linewidth=0)

# Regression statistics, expressed per 0.1 increase in F1 score
beta1_per_0_1 = slope * 0.1
# Reuse the critical value computed above for the slope's 95% CI
slope_margin = t_crit * std_err
ci_lower_0_1 = (slope - slope_margin) * 0.1
ci_upper_0_1 = (slope + slope_margin) * 0.1
# NOTE: `p_str` is prepared here but not drawn on the figure.
if p_value < 0.001:
    # The p-value can underflow to exactly 0.0, for which log10 is -inf and the
    # integer conversion raises; clamp to the smallest positive normal float.
    p_floor = max(p_value, np.finfo(float).tiny)
    p_str = f"p < 10$^{{{int(np.floor(np.log10(p_floor)))}}}$"
else:
    p_str = f"p = {p_value:.3f}"
axes.text(
    0.98,
    0.95,
    "Linear fit: $\\beta_1$ = "
    f"{beta1_per_0_1:.1f} ms per 0.1 F1 "
    f"(95% CI {ci_lower_0_1:.1f}, {ci_upper_0_1:.1f})",
    transform=axes.transAxes,
    ha="right",
    va="top",
    fontsize="small",
    bbox=dict(facecolor="white", alpha=1),
)

axes.set_axisbelow(True)
axes.grid()
axes.set_xlim([0, 1.0])
axes.set_ylim([0, 200])
axes.set_xticks(np.arange(0, 1.2, 0.2))
axes.set_yticks(np.arange(0, 220, 40))
axes.set_ylabel("Mean onset [ms]")
axes.set_xlabel("F1 score")
f.tight_layout()

# Save figure
filedir = OUTPUT_DIR / f"fig_png_onset_f1_n3p2_{model_type}.pdf"
viz.save_figure(f, filedir, overwrite=False)

In [None]:
# Calculate 95% Confidence Interval
n = len(xs)
ci_interval = 0.95
t_crit = spstats.t.ppf(
    1 - (1 - ci_interval) / 2, df=n - 2
)  # Two-tailed t-score for 95%
ci_margin = t_crit * std_err

ci_lower = slope - ci_margin
ci_upper = slope + ci_margin

print(f"Report: Slope = {slope:.2f} Â± {std_err:.2f} (SE)")
print(f"Report: Slope = {slope:.2f} [95% CI: {ci_lower:.2f}, {ci_upper:.2f}]")

In [None]:
# Print regression metadata
regression_meta = {
    "model_type": model_type,
    "analysis_type": analysis_type,
    "n_samples": int(len(xs)),
    "slope": float(slope),
    "intercept": float(intercept),
    "r_value": float(r_value),
    "r_squared": float(r_value**2),
    "p_value": float(p_value),
    "std_err": float(std_err),
    "ci_interval": float(ci_interval),
    "ci_lower": float(ci_lower),
    "ci_upper": float(ci_upper),
}

meta_path = OUTPUT_DIR / f"regression_onset_f1_n3p2_{model_type}.json"
pd.Series(regression_meta).to_json(meta_path, indent=4)
print(f"Saved regression metadata -> {meta_path}")
