# Robustness & Sensitivity Analysis for HFB Detection

This notebook explores:
- Sensitivity of detected PNG counts to temporal-span and timing-tolerance parameters.
- Compositional reuse of high-level feature neurons across multiple PNGs.

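**Dependencies:** significant-HFB (PNG) detections for the configured trial and checkpoint must already exist; the loading cell raises `FileNotFoundError` otherwise. Generate them with the significance-testing workflow below.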

---

Significance testing:
- Run PNG detection and significance testing for N3P2 after network training:
- **This workflow is time-consuming to run**
- Already existing detections will be skipped
```bash
./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --rule significance --chkpt -1 -v
```

---

**Analyses performed**

1. Window Sensitivity:
    - Filter significant HFBs by empirical temporal span and report counts per layer.
2. Timing Tolerance Strictness:
    - Re-apply the delay-matching criterion with tolerances at or below the 3 ms baseline (δt = 0, 1, 2, 3 ms) to assess robustness of HFB classification.
3. Composition Metric:
    - Count distinct HFB circuits per high-level neuron to provide evidence that triplets form larger assemblies.

**Plots:**

- Supplementary S7 Fig
- Supplementary S8 Fig

In [None]:
from collections import defaultdict
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import hsnn.analysis.png.db as polydb
from hsnn import viz
from hsnn.analysis.png import PNG, refinery
from hsnn.simulation import Simulator
from hsnn.utils import handler, io

from hsnn.analysis.png.refinery import FilterBySpan

# Apply the project's plotting environment before any figure is created.
viz.setup_journal_env()

# Shorthand for pandas MultiIndex slicing.
pidx = pd.IndexSlice

In [None]:
# === CONFIGURATION ===

# Experiment settings
EXPERIMENT_NAME = "n3p2/train_n3p2_lrate_0_04_181023"
TRIAL_INDEX = 31
CHECKPOINT_INDEX = -1  # same checkpoint index as `--chkpt -1` in the workflow command above

# Analysis settings
RESTRICT_LAYERS = None  # Set to [4] for quick L4-only analysis, None for all layers
SPAN_THRESHOLDS_MS = [4, 8, 12, 16, 20]  # Window sensitivity thresholds
DT_THRESHOLDS_MS = [0, 1, 2, 3]  # Timing tolerance strictness (baseline is 3 ms)

# HFB constraint parameters (baseline)
W_MIN = 0.5  # Minimum synaptic weight
TOL_BASELINE = 3.0  # Baseline delay tolerance (ms)

# Sanity-check the sweeps up front. The δt sweep may only tighten the baseline
# tolerance; relaxing it would require re-running detection (see "Limitation"
# in Section 3). Both lists must be non-empty for the summary cell.
assert SPAN_THRESHOLDS_MS and all(s >= 0 for s in SPAN_THRESHOLDS_MS), \
    "SPAN_THRESHOLDS_MS must be a non-empty list of non-negative spans (ms)"
assert DT_THRESHOLDS_MS and all(0 <= dt <= TOL_BASELINE for dt in DT_THRESHOLDS_MS), \
    f"DT_THRESHOLDS_MS must be non-empty and within [0, TOL_BASELINE={TOL_BASELINE}] ms"

# PNG database settings
POSITION = 1  # Second-firing neuron (high-level) index position
NRN_IDS = range(4096)  # All possible neuron IDs (assumes <= 4096 neurons per layer — TODO confirm)

# Setup output directory
OUTPUT_DIR = io.BASE_DIR / "out/figures/supplementary/fig_S7_S8"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Results will be saved to: {OUTPUT_DIR}")

## 1. Load Experiment Data

Load the trained network's synaptic parameters and significant HFB detections from the canonical database.

In [None]:
# Load experiment and trial
expt = handler.ExperimentHandler(EXPERIMENT_NAME)
trial = expt[TRIAL_INDEX]

print(f"Experiment: {expt.name}")
print(f"Trial: {trial.name}")

In [None]:
# Load network and synaptic parameters
store = handler.ArtifactStore(trial, ckpt_idx=CHECKPOINT_INDEX)
cfg = trial.config

sim = Simulator.from_config(cfg)
sim.restore(store.checkpoint.store_path)
print(f"Restored network from: {store.checkpoint.store_path.relative_to(expt.logdir.parent)}")

# Get synaptic parameters (weights and delays)
syn_params: pd.DataFrame = sim.network.get_syn_params(return_delays=True)

# Filter to plastic projections (FF, E2E)
projs_plastic = ('FF', 'E2E')
mask = syn_params.index.get_level_values('proj').isin(projs_plastic)
syn_params = syn_params.loc[mask].copy()

print(f"\nSynaptic parameters shape: {syn_params.shape}")
print(f"Projections: {syn_params.index.get_level_values('proj').unique().tolist()}")
print(f"Layers: {syn_params.index.get_level_values('layer').unique().tolist()}")

In [None]:
# Load significant HFB detections
db_sgnf = polydb.PNGDatabase.from_trial(trial, CHECKPOINT_INDEX, sgnf=True)

if not db_sgnf.exists:
    raise FileNotFoundError(f"Significant HFB database not found: {db_sgnf.path}")

print(f"Loaded significant HFB database: {db_sgnf.path}")

In [None]:
# Determine layers to analyse
available_layers = sorted(syn_params.index.get_level_values('layer').unique())
if RESTRICT_LAYERS is not None:
    layers_to_analyse = [layer for layer in RESTRICT_LAYERS if layer in available_layers]
else:
    layers_to_analyse = available_layers

print(f"Layers to analyze: {layers_to_analyse}")

In [None]:
# Load all significant HFBs for the layers of interest
polygrps_by_layer: dict[int, list[PNG]] = {}

for layer in layers_to_analyse:
    polygrps = db_sgnf.get_pngs(layer, NRN_IDS, POSITION)
    polygrps_by_layer[layer] = polygrps
    print(f"Layer {layer}: {len(polygrps)} significant HFBs")

total_hfbs = sum(len(p) for p in polygrps_by_layer.values())
print(f"\nTotal significant HFBs: {total_hfbs}")

## 2. Window Sensitivity Analysis

For each significant HFB circuit, compute the empirical temporal span:

$$\text{span}_{\text{ms}} = t_{\text{last}} - t_{\text{first}} = \max(\text{lags}) - \min(\text{lags})$$

Then filter by effective window thresholds and report counts per layer.

This analysis tests whether HFB detection results are robust to the choice of temporal window. If counts remain stable across different span thresholds, it suggests the detected circuits are not artifacts of the specific window choice.

In [None]:
def get_png_spans(pngs: Sequence[PNG]) -> np.ndarray:
    """Return the temporal span of each PNG, i.e. ``max(lags) - min(lags)``.

    Args:
        pngs: PNGs, each exposing a non-empty ``lags`` sequence (ms).

    Returns:
        1-D float array holding one span per PNG, in input order
        (empty if ``pngs`` is empty).
    """
    spans = []
    for png in pngs:
        lags = png.lags
        spans.append(float(np.max(lags) - np.min(lags)))
    return np.array(spans)

In [None]:
# Compute span distribution for all HFBs
all_spans = []
for layer, polygrps in polygrps_by_layer.items():
    spans = get_png_spans(polygrps)
    all_spans.extend(spans)
    print(f"Layer {layer}: span range [{spans.min():.1f}, {spans.max():.1f}] ms, "
          f"mean={spans.mean():.2f} ms, median={np.median(spans):.2f} ms")

all_spans = np.array(all_spans)
print(f"\nOverall: span range [{all_spans.min():.1f}, {all_spans.max():.1f}] ms, "
      f"mean={all_spans.mean():.2f} ms, median={np.median(all_spans):.2f} ms")

In [None]:
# Plot span distribution
fig, ax = plt.subplots(figsize=(5.5, 2))

bins = np.arange(0, max(all_spans) + 2, 1)
ax.hist(all_spans, bins=bins, edgecolor='black', alpha=1)
ax.axvline(all_spans.mean(), color='C1', linestyle='--', label=f'Mean ({all_spans.mean():.1f} ms)')

ax.set_xticks(np.arange(0, 22, 2))
ax.set_xlim(0, 20)
ax.set_xlabel('Temporal span [ms]')
ax.set_ylabel('# PNGs')
ax.legend()
ax.set_axisbelow(True)
ax.grid(True)

fig.tight_layout()
viz.save_figure(fig, OUTPUT_DIR / 'span_distribution.pdf', overwrite=False)

In [None]:
# Window sensitivity: count HFBs per layer per span threshold
span_sensitivity_results = []

for span_thresh in SPAN_THRESHOLDS_MS:
    filter_span = FilterBySpan(max_span=span_thresh)

    for layer, polygrps in polygrps_by_layer.items():
        filtered = filter_span(polygrps)
        span_sensitivity_results.append({
            'span_threshold_ms': span_thresh,
            'layer': layer,
            'count': len(filtered),
            'count_original': len(polygrps),
            'retention_pct': 100 * len(filtered) / len(polygrps) if len(polygrps) > 0 else 0
        })

span_df = pd.DataFrame(span_sensitivity_results)
print("Window Sensitivity Results:")
print(span_df.to_string(index=False))

In [None]:
# Pivot table for cleaner display
span_pivot = span_df.pivot(index='layer', columns='span_threshold_ms', values='count')
span_pivot.columns = [f'≤{c} ms' for c in span_pivot.columns]
span_pivot.index = [f'L{i}' for i in span_pivot.index]

# Add totals
span_pivot.loc['Total'] = span_pivot.sum()

print("\nHFB Counts by Span Threshold:")
print(span_pivot)

# Save to CSV
span_pivot.to_csv(OUTPUT_DIR / 'sensitivity_counts_by_span.csv')
print(f"\nSaved: {OUTPUT_DIR / 'sensitivity_counts_by_span.csv'}")

In [None]:
# Plot: HFB counts vs span threshold
fig, ax = plt.subplots(figsize=(5.5, 2.5))

for idx, layer in enumerate(layers_to_analyse):
    layer_data = span_df[span_df['layer'] == layer]
    ax.plot(layer_data['span_threshold_ms'], layer_data['count'],
            marker='o', label=f'L{layer}', linewidth=1.5, markersize=6, alpha=1)

# Add total line
total_counts = span_df.groupby('span_threshold_ms')['count'].sum()
ax.plot(total_counts.index, total_counts.values,
        marker='s', label='Total', linewidth=1.5, markersize=6,
        color='black', linestyle='--')

ax.set_xlabel('Maximum span threshold [ms]')
ax.set_ylabel('# PNGs')
ax.legend(loc='best', fontsize="small")
ax.grid(True, alpha=1)
ax.set_xticks(SPAN_THRESHOLDS_MS)
ax.set_ylim(None, 10000 * 1.11)

fig.tight_layout()
viz.save_figure(fig, OUTPUT_DIR / 'sensitivity_window.pdf', overwrite=False)

## 3. Timing Tolerance (δt) Strictness Analysis

Re-apply the delay-matching criterion with **stricter** tolerances to assess robustness of HFB classification.

**HFB delay constraint**

For a 3-neuron HFB circuit with neurons (L, H, B) in firing order:
- The composite delay relation must hold: $d_{L \to B} \approx d_{L \to H} + d_{H \to B}$ within tolerance δt
- All synaptic weights must exceed $w_{\min}$

**Limitation**

We only evaluate strictness (tightening δt from the baseline 3 ms). Relaxing δt beyond baseline could admit additional circuits not represented in the canonical merged results, which would require re-running detection before merging.

In [None]:
# Re-filter the significant HFBs with refinery.Constrained at each tolerance δt,
# keeping the weight floor at W_MIN (the largest tested δt is the 3 ms baseline).
dt_sensitivity_results = []
for dt_thresh in DT_THRESHOLDS_MS:
    filter_constrained = refinery.Constrained(syn_params, w_min=W_MIN, tol=dt_thresh)
    for layer, polygrps in polygrps_by_layer.items():
        n_total = len(polygrps)
        n_kept = len(filter_constrained(polygrps))
        dt_sensitivity_results.append(dict(
            dt_tolerance_ms=dt_thresh,
            layer=layer,
            count=n_kept,
            count_original=n_total,
            retention_pct=(100 * n_kept / n_total) if n_total > 0 else 0,
        ))

dt_df = pd.DataFrame(dt_sensitivity_results)
print("Timing Tolerance Strictness Results:")
print(dt_df.to_string(index=False))

In [None]:
# Pivot table for cleaner display
dt_pivot = dt_df.pivot(index='layer', columns='dt_tolerance_ms', values='count')
dt_pivot.columns = [f'δt≤{c} ms' for c in dt_pivot.columns]
dt_pivot.index = [f'L{i}' for i in dt_pivot.index]

# Add totals
dt_pivot.loc['Total'] = dt_pivot.sum()

print("\nHFB Counts by Timing Tolerance:")
print(dt_pivot)

# Save to CSV
dt_pivot.to_csv(OUTPUT_DIR / 'sensitivity_counts_by_dt.csv')
print(f"\nSaved: {OUTPUT_DIR / 'sensitivity_counts_by_dt.csv'}")

In [None]:
# Plot: HFB counts vs timing tolerance
fig, ax = plt.subplots(figsize=(5.5, 2.5))

for layer in layers_to_analyse:
    layer_data = dt_df[dt_df['layer'] == layer]
    ax.plot(layer_data['dt_tolerance_ms'], layer_data['count'],
            marker='o', label=f'L{layer}', linewidth=2, markersize=6)

# Add total line
total_counts = dt_df.groupby('dt_tolerance_ms')['count'].sum()
ax.plot(total_counts.index, total_counts.values,
        marker='s', label='Total', linewidth=1.5, markersize=6,
        color='black', linestyle='--')

ax.set_xlabel(r'Timing tolerance $\delta t$ [ms]')
ax.set_ylabel('# PNGs')
ax.legend(loc='best')
ax.grid(True, alpha=1)
ax.set_xticks(DT_THRESHOLDS_MS)
ax.set_ylim(None, 10000 * 1.11)

fig.tight_layout()
viz.save_figure(fig, OUTPUT_DIR / 'sensitivity_dt.pdf', overwrite=False)

## 4. Composition Analysis

For each high-level neuron (H) in significant HFB circuits, count the number of distinct circuits that share that neuron. This provides evidence that triplet HFBs can form larger assemblies.

### Interpretation

If high-level neurons participate in multiple distinct HFB circuits, it suggests that:
1. The triplet motifs are not isolated structures
2. Higher-order assemblies may emerge through shared neurons
3. The network has developed compositional representations

In [None]:
def compute_hfbs_per_high_neuron(polygrps: Sequence[PNG], position: int = 1) -> dict[int, int]:
    """Count distinct HFB circuits per high-level neuron.

    By default the high-level neuron is the second-firing neuron (index 1 of
    ``png.nrns``), matching the notebook's ``POSITION = 1`` setting.

    Args:
        polygrps: HFB circuits, each exposing an indexable ``nrns``.
        position: Index into ``png.nrns`` of the neuron to group by. Defaults to 1,
            which reproduces the original behaviour.

    Returns:
        Dictionary mapping neuron ID to the number of entries in ``polygrps`` that
        have that neuron at ``position``. Neurons with no circuits are absent, and
        duplicate entries in ``polygrps`` are each counted.

    Raises:
        IndexError: If a ``png.nrns`` (list/array-like) has no element at ``position``.
    """
    counts = defaultdict(int)
    for png in polygrps:
        counts[png.nrns[position]] += 1
    return dict(counts)

In [None]:
# Compute composition metrics for each layer
composition_results = []

for layer, polygrps in polygrps_by_layer.items():
    hfbs_per_high = compute_hfbs_per_high_neuron(polygrps)
    counts = list(hfbs_per_high.values())

    if counts:
        composition_results.append({
            'layer': layer,
            'num_high_neurons': len(counts),
            'total_hfbs': sum(counts),
            'mean_hfbs_per_high': np.mean(counts),
            'median_hfbs_per_high': np.median(counts),
            'max_hfbs_per_high': np.max(counts),
            'min_hfbs_per_high': np.min(counts),
            'std_hfbs_per_high': np.std(counts)
        })

        print(f"\nLayer {layer}:")
        print(f"  High-level neurons participating in HFBs: {len(counts)}")
        print(f"  HFBs per high neuron: mean={np.mean(counts):.2f}, "
              f"median={np.median(counts):.1f}, max={np.max(counts)}")

composition_df = pd.DataFrame(composition_results)
composition_df.set_index('layer', inplace=True)
composition_df.index = [f'L{i}' for i in composition_df.index]

In [None]:
# Display composition summary
print("\nComposition Summary:")
print(composition_df.round(2))

# Save to CSV
composition_df.to_csv(OUTPUT_DIR / 'composition_metrics.csv')
print(f"\nSaved: {OUTPUT_DIR / 'composition_metrics.csv'}")

In [None]:
# Plot distribution of HFBs per high-level neuron (combined across layers)
all_hfbs_per_high = []
for layer, polygrps in polygrps_by_layer.items():
    hfbs_per_high = compute_hfbs_per_high_neuron(polygrps)
    all_hfbs_per_high.extend(hfbs_per_high.values())

all_hfbs_per_high = np.array(all_hfbs_per_high)

fig, ax = plt.subplots(figsize=(5, 2.5))

bins = np.arange(0.5, max(all_hfbs_per_high) + 1.5, 1)
ax.hist(all_hfbs_per_high, bins=bins, edgecolor='black', alpha=1)
ax.axvline(all_hfbs_per_high.mean(), color='C1', linestyle='--',
           label=f'Mean ({all_hfbs_per_high.mean():.1f})')

ax.set_xlim(0, None)

ax.set_xlabel('# PNGs per high-level neuron')
ax.set_ylabel('Count')  # 'Count (# high-level neurons)'
ax.legend()
ax.set_axisbelow(True)
ax.grid(True, alpha=1)

fig.tight_layout()
viz.save_figure(fig, OUTPUT_DIR / 'composition_distribution.pdf', overwrite=False)

In [None]:
# Per-layer composition distributions
if len(layers_to_analyse) > 1:
    fig, axes = plt.subplots(2, 2, figsize=(5.5, 3), sharex=True, sharey=True)
    axes = axes.flatten()
    ncols = 2

    for idx, (ax, layer) in enumerate(zip(axes, layers_to_analyse)):
        hfbs_per_high = compute_hfbs_per_high_neuron(polygrps_by_layer[layer])
        counts = list(hfbs_per_high.values())

        if counts:
            mean_count = float(np.mean(counts))
            bins = np.arange(0.5, max(counts) + 1.5, 1)
            ax.hist(counts, bins=bins, edgecolor='black', linewidth=0.5, alpha=1)
            ax.axvline(mean_count, color='C1', linestyle='--', linewidth=1.5,
                       label=f'Mean ({mean_count:.1f})')
            ax.set_title(f'Layer {layer}', fontsize="medium", fontweight="bold")
            if idx >= ncols:
                ax.set_xlabel('# PNGs per high neuron')
            ax.legend()
            ax.set_axisbelow(True)
            ax.grid(True, alpha=1)

    for ax in axes[len(layers_to_analyse):]:
        ax.axis('off')

    axes[0].set_xlim(0, 20)
    axes[0].set_ylim(0, None)
    axes[0].set_ylabel('Count')
    axes[2].set_ylabel('Count')
    fig.tight_layout()
    viz.save_figure(fig, OUTPUT_DIR / 'composition_distribution_by_layer.pdf', overwrite=False)

In [None]:
# Composition analysis across span thresholds
composition_by_span = []

for span_thresh in SPAN_THRESHOLDS_MS:
    filter_span = FilterBySpan(max_span=span_thresh)

    for layer, polygrps in polygrps_by_layer.items():
        filtered = filter_span(polygrps)
        hfbs_per_high = compute_hfbs_per_high_neuron(filtered)
        counts = list(hfbs_per_high.values())

        if counts:
            composition_by_span.append({
                'span_threshold_ms': span_thresh,
                'layer': layer,
                'num_high_neurons': len(counts),
                'mean_hfbs_per_high': np.mean(counts),
                'max_hfbs_per_high': np.max(counts)
            })

composition_span_df = pd.DataFrame(composition_by_span)
print("\nComposition Metrics by Span Threshold:")
print(composition_span_df.round(2).to_string(index=False))

# Save
composition_span_df.to_csv(OUTPUT_DIR / 'composition_by_span.csv', index=False)
print(f"\nSaved: {OUTPUT_DIR / 'composition_by_span.csv'}")

## 5. Summary

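### Key Findings

The next cell prints the headline numbers from each analysis. The cell after it writes the full tables to `analysis_summary.txt` in the output directory.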

In [None]:
# Generate summary statistics
print("=" * 60)
print("ROBUSTNESS & SENSITIVITY ANALYSIS SUMMARY")
print("=" * 60)

print(f"\nExperiment: {EXPERIMENT_NAME}")
print(f"Trial: {TRIAL_INDEX}")
print(f"Layers analysed: {layers_to_analyse}")
print(f"Total significant HFBs (baseline): {total_hfbs}")

print("\n" + "-" * 60)
print("1. WINDOW SENSITIVITY")
print("-" * 60)
print(f"Span thresholds tested: {SPAN_THRESHOLDS_MS} ms")
print(f"\nRetention at smallest window ({SPAN_THRESHOLDS_MS[0]} ms):")
min_span = SPAN_THRESHOLDS_MS[0]
min_span_total = span_df[span_df['span_threshold_ms'] == min_span]['count'].sum()
print(f"  {min_span_total}/{total_hfbs} ({100*min_span_total/total_hfbs:.1f}%)")

print("\n" + "-" * 60)
print("2. TIMING TOLERANCE STRICTNESS")
print("-" * 60)
print(f"Tolerance thresholds tested: {DT_THRESHOLDS_MS} ms (baseline = {TOL_BASELINE} ms)")
print(f"\nRetention at strictest tolerance ({DT_THRESHOLDS_MS[0]} ms):")
min_dt = DT_THRESHOLDS_MS[0]
min_dt_total = dt_df[dt_df['dt_tolerance_ms'] == min_dt]['count'].sum()
print(f"  {min_dt_total}/{total_hfbs} ({100*min_dt_total/total_hfbs:.1f}%)")

print("\n" + "-" * 60)
print("3. COMPOSITION ANALYSIS")
print("-" * 60)
print(f"High-level neurons participating in HFBs: {len(all_hfbs_per_high)}")
print(f"Mean HFBs per high neuron: {all_hfbs_per_high.mean():.2f}")
print(f"Max HFBs per high neuron: {all_hfbs_per_high.max()}")
multi_hfb_neurons = np.sum(all_hfbs_per_high > 1)
print(f"Neurons with >1 HFB: {multi_hfb_neurons} ({100*multi_hfb_neurons/len(all_hfbs_per_high):.1f}%)")

print("\n" + "=" * 60)
print(f"Results saved to: {OUTPUT_DIR}")
print("=" * 60)

In [None]:
# Save summary to text file
summary_path = OUTPUT_DIR / 'analysis_summary.txt'
with open(summary_path, 'w') as f:
    f.write("ROBUSTNESS & SENSITIVITY ANALYSIS SUMMARY\n")
    f.write("=" * 50 + "\n\n")
    f.write(f"Experiment: {EXPERIMENT_NAME}\n")
    f.write(f"Trial: {TRIAL_INDEX}\n")
    f.write(f"Layers analysed: {layers_to_analyse}\n")
    f.write(f"Total significant HFBs (baseline): {total_hfbs}\n\n")

    f.write("WINDOW SENSITIVITY\n")
    f.write("-" * 30 + "\n")
    f.write(span_pivot.to_string() + "\n\n")

    f.write("TIMING TOLERANCE STRICTNESS\n")
    f.write("-" * 30 + "\n")
    f.write(dt_pivot.to_string() + "\n\n")

    f.write("COMPOSITION METRICS\n")
    f.write("-" * 30 + "\n")
    f.write(composition_df.to_string() + "\n")

print(f"Summary saved to: {summary_path}")