# PNG detections: Retention across pipeline stages

Stage-wise retention of three-neuron PNGs through the HFB detection pipeline.

**Pipeline stages**

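1. Unconstrained (`hfb_unconstrained.db`): All detected triplet PNG candidates with layer structure `[L-1, L, L]`. Before counting, this notebook additionally keeps only candidates with valid synaptic connectivity (resolving tied-lag ambiguities).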

2. Eligible (`hfb.db`): PNGs **after** applying manuscript eligibility criteria (i.e., delay-matched pathways to within 3 ms, synaptic weights w ≥ 0.5).

3. Significant (`hfb_sgnf.db`): Final retained PNGs **after** statistical significance testing (surrogate-based).

**Dependencies:**

---

A) Significance testing:
- Runs PNG detection and significance testing for N3P2 (requires a trained network)
- **This workflow is time-consuming to run**
- Already existing detections will be skipped
```bash
./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 31 --rule significance --chkpt -1 -v
```

B) Unconstrained PNG detections:

```bash
./scripts/analysis/detection_unconstrained.py ./experiments/n3p2/train_n3p2_lrate_0_04_181023/ 31 --chkpt -1 -v
```

---

**Plots**

- N3P2 figure: single-panel grouped bar chart of PNG counts at each pipeline stage (Detected, Eligible, Significant) for each analysed layer
- Inset box: overall retention percentages (E/D, S/E, S/D)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import ScalarFormatter
from tqdm.notebook import tqdm

import hsnn.analysis.png.db as polydb
from hsnn import viz, pipeline
from hsnn.analysis.png import PNG
from hsnn.simulation import Simulator
from hsnn.utils import handler, io

# Shorthand for pandas MultiIndex slicing (usable as df.loc[pidx[...], :]).
# NOTE(review): not referenced by any cell of this notebook — candidate for removal.
pidx = pd.IndexSlice

## Configuration

In [None]:
# Experiment (relative to the experiments root) and checkpoint to analyse;
# a negative checkpoint index counts from the end (-1 = last checkpoint).
EXPERIMENT_NAME = "n3p2/train_n3p2_lrate_0_04_181023"
CHECKPOINT_INDEX = -1

# Layers L whose triplet PNGs ([L-1, L, L]) are counted
LAYERS = list(range(2, 5))

# Arguments forwarded to PNGDatabase.get_pngs
NRN_IDS = range(0, 4096)
POSITION = 1

# Figure / table destination, created up front so later cells can write to it
OUTPUT_DIR = io.BASE_DIR / "out" / "figures" / "supplementary" / "fig_S5"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Journal-style matplotlib defaults
viz.setup_journal_env()

## 1. Load Experiment and Identify Detection Trials

In [None]:
# Open the experiment and report its log directory relative to the repo root
expt = handler.ExperimentHandler(EXPERIMENT_NAME)
relative_logdir = expt.logdir.relative_to(io.BASE_DIR)
print(f"Experiment: {relative_logdir}")

# Trial view(s) to analyse; add names here to pool over more trials
selected_trial_names = ["TrainSNN_eb0d4_00031"]
trials = [expt[name] for name in selected_trial_names]
print(f"Selected trial(s): {trials}")

## 2. Define Database Loading Functions

In [None]:
def load_db(
    trial: handler.TrialView, db_name: str, chkpt_idx: int = -1
) -> polydb.PNGDatabase | None:
    """Open one PNG database stored in a trial checkpoint, if present.

    Args:
        trial: Trial view whose checkpoints are searched.
        db_name: Database file name, e.g. 'hfb_unconstrained.db', 'hfb.db'
            or 'hfb_sgnf.db'.
        chkpt_idx: Index into ``trial.checkpoints`` (-1 selects the last one).

    Returns:
        A ``PNGDatabase`` opened on ``<checkpoint.path>/<db_name>``, or ``None``
        when that file does not exist.
    """
    db_path = trial.checkpoints[chkpt_idx].path / db_name
    return polydb.PNGDatabase(db_path) if db_path.exists() else None


def load_syn_params(trial: handler.TrialView, chkpt_idx: int = -1) -> pd.DataFrame:
    """Restore a checkpointed network and return its synaptic parameters.

    Args:
        trial: Trial view providing the simulator config and checkpoints.
        chkpt_idx: Index into ``trial.checkpoints`` (-1 selects the last one).

    Returns:
        The network's synaptic parameter table, including delays
        (indexed by [layer, proj, pre, post] per the original docstring).
    """
    checkpoint = trial.checkpoints[chkpt_idx]
    simulator = Simulator.from_config_restore(
        trial.config, file_path=checkpoint.store_path
    )
    network = simulator.network
    return network.get_syn_params(return_delays=True)


def get_png_set(
    db: polydb.PNGDatabase, layer: int, nrn_ids: range, position: int
) -> set[PNG]:
    """Return the distinct PNGs that ``db`` holds for ``layer``.

    Duplicates collapse via PNG equality/hashing, which (per ``PNG.__hash__``)
    depends only on the (layers, nrns) tuples.
    """
    return set(db.get_pngs(layer, nrn_ids, position))


def get_filtered_png_set(
    db: polydb.PNGDatabase,
    layer: int,
    nrn_ids: range,
    position: int,
    syn_params: pd.DataFrame,
) -> tuple[set[PNG], dict]:
    """Return the distinct PNGs for ``layer`` that pass the HFB connectivity filter.

    The raw query results are passed through
    ``pipeline.filter_valid_hfb_triplets``, which resolves tied-lag ambiguity
    and checks the triplets against the synaptic connectivity.

    Args:
        db: PNG database to query.
        layer: Layer passed to ``db.get_pngs``.
        nrn_ids: Neuron IDs passed to ``db.get_pngs``.
        position: Position index passed to ``db.get_pngs``.
        syn_params: Synaptic parameters used for the connectivity check.

    Returns:
        ``(png_set, filter_stats)``: the surviving PNGs as a set, and the
        statistics dict reported by the filter.
    """
    candidates = db.get_pngs(layer, nrn_ids, position)
    survivors, filter_stats = pipeline.filter_valid_hfb_triplets(
        candidates, syn_params
    )
    return set(survivors), filter_stats


def validate_layer_structure(png: PNG, expected_layer: int) -> bool:
    """Return True iff ``png.layers`` equals ``[expected_layer - 1, expected_layer, expected_layer]``."""
    source_layer = expected_layer - 1
    return list(png.layers) == [source_layer, expected_layer, expected_layer]

## 3. Load Databases for All Trials

In [None]:
# Pipeline stage -> database file name inside each checkpoint directory
DB_NAMES = {
    "unconstrained": "hfb_unconstrained.db",
    "eligible": "hfb.db",
    "significant": "hfb_sgnf.db",
}

# trial name -> {stage -> open PNGDatabase}; only trials that have every stage
trial_dbs: dict[str, dict[str, polydb.PNGDatabase]] = {}

for trial in tqdm(trials, desc="Loading trial databases"):
    trial_id = trial.name
    dbs = {}

    for stage, db_name in DB_NAMES.items():
        db = load_db(trial, db_name, CHECKPOINT_INDEX)
        if db is not None:
            dbs[stage] = db
            print(f"Trial {trial_id}: Loaded {stage} ({db_name})")
        else:
            print(f"Trial {trial_id}: Missing {stage} ({db_name})")

    if len(dbs) == len(DB_NAMES):  # Only include trials with every stage's DB
        trial_dbs[trial_id] = dbs
    else:
        # The trial is excluded, so close whatever was opened; the cleanup cell
        # at the end of the notebook only closes databases kept in trial_dbs.
        for db in dbs.values():
            db.close()
        print(f"Trial {trial_id}: Skipped (incomplete DB set)")

print(f"\nTrials with complete DB sets: {len(trial_dbs)}")

## 4. Extract PNG Sets and Compute Statistics

In [None]:
def compute_layer_stats(
    dbs: dict,
    layer: int,
    nrn_ids: range,
    position: int,
    syn_params: pd.DataFrame,
) -> dict:
    """Count the PNGs retained at each pipeline stage for a single layer.

    Unconstrained detections are first passed through the HFB connectivity
    filter, so that all three stages are compared on the same footing.

    Args:
        dbs: Mapping with keys 'unconstrained', 'eligible' and 'significant',
            each an open PNGDatabase.
        layer: Layer L whose [L-1, L, L] triplets are counted.
        nrn_ids: Neuron IDs passed to the database queries.
        position: Position index passed to the database queries.
        syn_params: Synaptic parameters for the connectivity filter.

    Returns:
        Dict with ``n_unconstrained``, ``n_eligible``, ``n_significant``,
        ``frac_eligible`` (E/U), ``frac_significant_given_eligible`` (S/E),
        ``filter_stats`` (from the connectivity filter) and the PNG sets
        ``U``, ``E``, ``S``. A fraction is NaN when its denominator is zero.

    Raises:
        AssertionError: If a PNG lacks the [L-1, L, L] structure, or a later
            stage holds more PNGs than the stage before it.
    """
    U, filter_stats = get_filtered_png_set(
        dbs["unconstrained"], layer, nrn_ids, position, syn_params
    )
    E = get_png_set(dbs["eligible"], layer, nrn_ids, position)
    S = get_png_set(dbs["significant"], layer, nrn_ids, position)

    # Every PNG at every stage must have the [L-1, L, L] layer layout
    stage_sets = {"unconstrained": U, "eligible": E, "significant": S}
    for stage_name, png_set in stage_sets.items():
        for png in png_set:
            assert validate_layer_structure(png, layer), (
                f"Invalid layer structure in {stage_name}: {png.layers} "
                f"(expected [{layer - 1}, {layer}, {layer}])"
            )

    n_unconstrained, n_eligible, n_significant = len(U), len(E), len(S)

    # Each stage is applied after the previous one, so counts may only shrink
    assert n_eligible <= n_unconstrained, (
        f"Layer {layer}: N_eligible ({n_eligible}) > N_unconstrained ({n_unconstrained})"
    )
    assert n_significant <= n_eligible, (
        f"Layer {layer}: N_significant ({n_significant}) > N_eligible ({n_eligible})"
    )

    def _ratio(numerator: int, denominator: int) -> float:
        """numerator / denominator, or NaN for an empty denominator."""
        if denominator > 0:
            return numerator / denominator
        return np.nan

    return {
        "n_unconstrained": n_unconstrained,
        "n_eligible": n_eligible,
        "n_significant": n_significant,
        "frac_eligible": _ratio(n_eligible, n_unconstrained),
        "frac_significant_given_eligible": _ratio(n_significant, n_eligible),
        "filter_stats": filter_stats,  # Stats from connectivity filtering
        "U": U,
        "E": E,
        "S": S,
    }

In [None]:
# Takes ~2.5 min to run (AMD Ryzen 9 5900X, 64GB RAM)
# Cache path for computed statistics
cache_path = OUTPUT_DIR / ".cache/spade_proportions_stats.pkl"

# Settings the cached results depend on. A cache written under different
# settings, or by an older version of this cell that stored no key, is
# recomputed instead of being silently reused.
cache_key = {
    "experiment": EXPERIMENT_NAME,
    "checkpoint_index": CHECKPOINT_INDEX,
    "trial_ids": sorted(trial_dbs),
    "layers": list(LAYERS),
    "nrn_ids": NRN_IDS,
    "position": POSITION,
}

# NOTE: pickle must only be loaded from trusted files; this cache is written
# by this notebook itself.
cached = io.load_pickle(cache_path) if cache_path.exists() else None

if cached is not None and cached.get("cache_key") == cache_key:
    print(f"Loading cached results from: {cache_path}")
    results_df = cached["results_df"]
    pooled_sets = cached["pooled_sets"]
    filter_stats_all = cached.get("filter_stats_all", [])
    print(f"Loaded statistics for {len(results_df)} (trial, layer) combinations")
else:
    if cached is not None:
        print(f"Ignoring stale cache (settings changed): {cache_path}")

    # Compute statistics for each trial and layer
    all_results = []
    filter_stats_all = []

    # PNG sets accumulated for the pooled analysis. Entries are keyed by
    # (trial_id, PNG): PNG hashing uses only the (layers, nrns) tuples, so
    # without the trial id, PNGs with the same neuron indices from different
    # trials (different networks) would be merged, and the pooled counts would
    # disagree with the per-layer sums across trials.
    pooled_sets = {"U": set(), "E": set(), "S": set()}

    trials_by_name = {t.name: t for t in trials}

    for trial_id, dbs in tqdm(trial_dbs.items(), desc="Computing statistics"):
        # Load synaptic parameters for this trial
        syn_params = load_syn_params(trials_by_name[trial_id], CHECKPOINT_INDEX)
        print(f"Trial {trial_id}: Loaded syn_params")

        for layer in LAYERS:
            try:
                stats = compute_layer_stats(dbs, layer, NRN_IDS, POSITION, syn_params)
            except Exception as e:
                # Report which (trial, layer) failed, then propagate
                print(f"Trial {trial_id}, Layer {layer}: Error - {e}")
                raise

            all_results.append(
                {
                    "trial_id": trial_id,
                    "layer": layer,
                    "n_unconstrained": stats["n_unconstrained"],
                    "n_eligible": stats["n_eligible"],
                    "n_significant": stats["n_significant"],
                    "frac_eligible": stats["frac_eligible"],
                    "frac_significant_given_eligible": stats[
                        "frac_significant_given_eligible"
                    ],
                }
            )

            # Store filtering stats
            filter_stats_all.append(
                {"trial_id": trial_id, "layer": layer, **stats["filter_stats"]}
            )

            # Accumulate for pooled analysis (trial-qualified, see above)
            for stage_key in ("U", "E", "S"):
                pooled_sets[stage_key].update(
                    (trial_id, png) for png in stats[stage_key]
                )

    results_df = pd.DataFrame(all_results)
    print(f"\nComputed statistics for {len(results_df)} (trial, layer) combinations")

    # Save to cache together with the settings it was computed under
    io.save_pickle(
        {
            "cache_key": cache_key,
            "results_df": results_df,
            "pooled_sets": pooled_sets,
            "filter_stats_all": filter_stats_all,
        },
        cache_path,
        parents=True,
    )
    print(f"Saved results to cache: {cache_path}")

results_df.head(10)

In [None]:
# Show how the connectivity filter treated the unconstrained detections
if filter_stats_all:
    filter_df = pd.DataFrame(filter_stats_all)
    print("Connectivity filtering of unconstrained PNGs:")
    print(filter_df.to_string(index=False))
    print()

    # Column -> printed label; labels are padded so that the totals line up
    total_labels = {
        "n_total": "  Raw detections:        ",
        "n_valid": "  Valid (after filter):  ",
        "n_tied_lags": "  - Tied H/B lags:       ",
        "n_tied_lag_duplicates": "  - Tied-lag duplicates: ",
        "n_no_connectivity": "  - No connectivity:     ",
    }
    totals = filter_df[list(total_labels)].sum()
    print("Aggregated filtering totals:")
    for column, label in total_labels.items():
        print(f"{label}{totals[column]:,}")

## 5. Aggregate Statistics by Layer

In [None]:
# Aggregate by layer: sum counts across trials
layer_agg = (
    results_df.groupby("layer")
    .agg({"n_unconstrained": "sum", "n_eligible": "sum", "n_significant": "sum"})
    .reset_index()
)

# Compute fractions from aggregated counts
layer_agg["frac_eligible"] = layer_agg["n_eligible"] / layer_agg["n_unconstrained"]
layer_agg["frac_significant_given_eligible"] = (
    layer_agg["n_significant"] / layer_agg["n_eligible"]
)

layer_agg

In [None]:
# Compute pooled statistics using set union across all trials/layers
n_unconstrained_pooled = len(pooled_sets["U"])
n_eligible_pooled = len(pooled_sets["E"])
n_significant_pooled = len(pooled_sets["S"])

frac_eligible_pooled = (
    n_eligible_pooled / n_unconstrained_pooled if n_unconstrained_pooled > 0 else np.nan
)
frac_significant_pooled = (
    n_significant_pooled / n_eligible_pooled if n_eligible_pooled > 0 else np.nan
)

print(
    f"Pooled (L2-4): U={n_unconstrained_pooled}, E={n_eligible_pooled}, S={n_significant_pooled}"
)
print(
    f"frac_eligible={frac_eligible_pooled:.4f}, frac_significant={frac_significant_pooled:.4f}"
)

## 6. Build Summary Table

In [None]:
# Build final summary table
summary_rows = []

for _, row in layer_agg.iterrows():
    summary_rows.append(
        {
            "layer": f"L{int(row['layer'])}",
            "N_unconstrained": int(row["n_unconstrained"]),
            "N_eligible": int(row["n_eligible"]),
            "N_significant": int(row["n_significant"]),
            "frac_eligible": row["frac_eligible"],
            "frac_significant_given_eligible": row["frac_significant_given_eligible"],
        }
    )

# Add pooled row
summary_rows.append(
    {
        "layer": "All (2-4)",
        "N_unconstrained": n_unconstrained_pooled,
        "N_eligible": n_eligible_pooled,
        "N_significant": n_significant_pooled,
        "frac_eligible": frac_eligible_pooled,
        "frac_significant_given_eligible": frac_significant_pooled,
    }
)

summary_df = pd.DataFrame(summary_rows)
summary_df = summary_df.set_index("layer")

print("\n" + "=" * 80)
print("PNG RETENTION BY PIPELINE STAGE")
print("=" * 80)
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Trials analysed: {len(trial_dbs)}")
print(f"Layers: {LAYERS}")
print("=" * 80)
summary_df.round(4)

## 7. Save Summary Table as CSV

In [None]:
# Persist the summary table (indexed by layer label) alongside the figures
csv_path = OUTPUT_DIR.joinpath("hfb_retention_fractions_by_stage.csv")
summary_df.to_csv(csv_path)
print("Saved summary table: " + str(csv_path))

## 8. Generate Figure

A single-panel grouped bar chart showing PNG counts at each pipeline stage (Detected, Eligible, Significant), broken down by layer, with an inset box summarising the overall retention percentages.

In [None]:
# Prepare data for plotting
layer_labels = list(summary_df.index)
frac_eligible = summary_df["frac_eligible"].values
frac_significant = summary_df["frac_significant_given_eligible"].values

# Get numerators/denominators for annotations
n_u = summary_df["N_unconstrained"].values
n_e = summary_df["N_eligible"].values
n_s = summary_df["N_significant"].values

In [None]:
# Prepare per-layer data (excluding the pooled "All" row)
layer_data = summary_df.iloc[:-1].copy()  # L2, L3, L4 only
pooled_data = summary_df.iloc[-1]  # "All (2-4)" row

# Extract layer indices and counts
layers = [int(l[1]) for l in layer_data.index]  # [2, 3, 4]
layer_labels_plot = [f"L{l}" for l in layers]

# Counts per stage per layer
counts_U = layer_data["N_unconstrained"].astype(int).values
counts_E = layer_data["N_eligible"].astype(int).values
counts_S = layer_data["N_significant"].astype(int).values

# Pooled totals
total_U = int(pooled_data["N_unconstrained"])
total_E = int(pooled_data["N_eligible"])
total_S = int(pooled_data["N_significant"])

# Retention percentages
pct_e_of_u = total_E / total_U * 100
pct_s_of_e = total_S / total_E * 100
pct_s_of_u = total_S / total_U * 100

print(f"Total counts: U={total_U:,}, E={total_E:,}, S={total_S:,}")
print(
    f"Retention: E/U = {pct_e_of_u:.1f}%, S/E = {pct_s_of_e:.1f}%, S/U = {pct_s_of_u:.1f}%"
)

In [None]:
# Create grouped bar chart with legend/stats box on the right
fig, ax = plt.subplots(figsize=(5.5, 2))

# Add a small gap between layer groups
x = np.arange(len(layer_labels_plot)) * 1.2  # L2, L3, L4 only
width = 0.28
bar_offset = width * 1.1  # adds a small gap between adjacent bars within each group

# Create bars with standard matplotlib colors (per-layer only, no Total)
bars_U = ax.bar(
    x - bar_offset,
    counts_U,
    width,
    label="Detected (D)",
    color="C0",
    edgecolor="black",
    linewidth=0.5,
)
bars_E = ax.bar(
    x,
    counts_E,
    width,
    label="Eligible (E)",
    color="C1",
    edgecolor="black",
    linewidth=0.5,
)
bars_S = ax.bar(
    x + bar_offset,
    counts_S,
    width,
    label="Significant (S)",
    color="C2",
    edgecolor="black",
    linewidth=0.5,
)

# Scientific notation formatter for y-axis

ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
ax.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))

# Formatting
ax.set_xlabel("Layer")
ax.set_ylabel("# PNGs")
ax.set_xticks(x)
ax.set_xticklabels(layer_labels_plot)
ax.set_axisbelow(True)
ax.set_ylim(0, 5e4)
ax.grid(axis="y", alpha=1)

# Make room on the right for legend and stats box
fig.subplots_adjust(right=0.68)

# Place legend outside axes, aligned to top-right
legend = ax.legend(
    loc="upper left",
    bbox_to_anchor=(1.02, 1.0),
    bbox_transform=ax.transAxes,
    borderaxespad=0.0,
    fontsize="small",
    frameon=True,
    edgecolor="gray",
)

# Add retention percentage annotation box below legend, outside axes
textstr = (
    f"Overall (n = {total_U:,}):\n"
    f"  E/D = {pct_e_of_u:.1f}%\n"
    f"  S/E = {pct_s_of_e:.1f}%\n"
    f"  S/D = {pct_s_of_u:.1f}%"
)
props = dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray")
ax.text(
    1.03,
    0.55,
    textstr,
    transform=ax.transAxes,
    fontsize=8,
    verticalalignment="top",
    bbox=props,
    # family="monospace",
)

# Save figure
pdf_path_supp = OUTPUT_DIR / "fig_hfb_proportions.pdf"
viz.save_figure(fig, pdf_path_supp, overwrite=False, bbox_inches="tight")

## 9. Summary

In [None]:
heavy_rule = "=" * 70
light_rule = "-" * 70

# Header: what was analysed
print(heavy_rule)
print("SUMMARY: PNG Retention Across Pipeline Stages")
print(heavy_rule)
print(f"\nExperiment: {EXPERIMENT_NAME}")
print(f"Trials analysed: {len(trial_dbs)}")
print(f"Layers: {LAYERS}")
print()

# Legend of the three database stages
print("Database stages:")
for stage_line in (
    "  U = Unconstrained (hfb_unconstrained.db) - all detected triplets",
    "  E = Eligible (hfb.db) - after HFB criteria (w>=0.5, δt=±3ms)",
    "  S = Significant (hfb_sgnf.db) - after surrogate-based testing",
):
    print(stage_line)
print()

# One paragraph per summary row (each layer, then the pooled row)
print(light_rule)
print("Results by Layer:")
print(light_rule)

for label, entry in summary_df.iterrows():
    n_unc = int(entry["N_unconstrained"])
    n_elig = int(entry["N_eligible"])
    n_sig = int(entry["N_significant"])
    print(f"\n{label}:")
    print(f"  Unconstrained (U): {n_unc:,}")
    print(f"  Eligible (E):      {n_elig:,} ({entry['frac_eligible']:.1%} of U)")
    print(
        f"  Significant (S):   {n_sig:,} ({entry['frac_significant_given_eligible']:.1%} of E)"
    )

# Closing interpretation notes
print("\n" + heavy_rule)
print("Interpretation:")
for note in (
    "- Eligibility filtering retains PNGs with strengthened synapses (w>=0.5)",
    "  and timing consistent with axonal delays (δt=±3ms)",
    "- Significance testing removes PNGs that could arise by chance",
    "  in shuffled surrogate data",
):
    print(note)
print(heavy_rule)

In [None]:
# Close database connections
for trial_id, dbs in trial_dbs.items():
    for stage, db in dbs.items():
        db.close()
print("Closed all database connections")