# Plot HFB Performance metrics

Quantify the performance of HFBs in terms of their ability to classify convex (or concave) boundary elements.

**Methods**

- Using both experiments (datasets) and averaging over three trials of PNG detections.
- Plot F1-scores for each dataset:
    - Depicts the performance of the (FF + LAT) and (FF + LAT + FB) network architectures.
    - Each curve aggregates across all convex (or concave) boundaries and is then averaged across three selected trials.
    - The shaded error band corresponds to the standard deviation across these three trials.
- F1-score is the harmonic mean of the Precision and Recall.

**Dependencies:**

---

A) Significance testing:
- PNG detection and significance testing for N3P2 & N4P2: both before and after network training
---

B) Compute PNG F1 scores:

- Also run these with the argument `--target 0` for concave selectivity.

i) For N3P2:
```bash
for arch in SEMI ALL; do
    ./scripts/figures/compute_png_metrics.py ./experiments/n3p2/train_n3p2_lrate_0_04_181023 $arch -v
done
```
ii) and N4P2:
```bash
for arch in SEMI ALL; do
    ./scripts/figures/compute_png_metrics.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 $arch -v
done
```

---

**Results**

- Despite the intricate nature of a three-neuron HFB circuit, we observe a substantial number of selective ones.
- Compare metrics across datasets.
- Compare metrics across network architectures: (FF + LAT) vs. (FF + LAT + FB).

**Plots**

- Fig 14B (convex)
- Supplementary Fig S6 (concave).

In [None]:
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from hsnn import viz
from hsnn.utils import io

# Shorthand alias for pandas' IndexSlice helper.
pidx = pd.IndexSlice
# Directory that receives the figure saved by this notebook; created here
# (including parents) if it does not exist yet.
OUTPUT_DIR = io.BASE_DIR / "out/figures/fig14"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Directory from which the per-dataset F1-score pickles are read.
RESULTS_DIR = io.BASE_DIR / "out/figures/detection"
# Maps architecture keys (the top-level column labels of the score tables)
# to the labels shown in plot legends and the summary table.
ARCH_DESC_MAPPING = {"FF": "FF", "SEMI": "FF+LAT", "ALL": "FF+LAT+FB"}


class DataSet(Enum):
    """Datasets (experiments) whose PNG F1 scores this notebook loads and plots."""

    # Member names double as panel titles and as the "Dataset" column values.
    N3P2 = 1
    N4P2 = 2


def plot_scores(scores: pd.DataFrame, dataset_name: str, ax: plt.Axes):
    """Draw one mean curve with a +/- one standard deviation band per column group.

    Args:
        scores: Table whose top column level identifies the group (here the
            architecture key); the columns inside a group are averaged row-wise.
        dataset_name: Accepted for call-site symmetry; not used by the body.
        ax: Axes that receives the bands, lines and grid.

    Relies on the module-level ``colors`` list (one entry per group, by
    position) and ``ARCH_DESC_MAPPING`` (group key -> legend label).
    """
    # Row-wise mean and standard deviation for every group, in column order.
    stats = {
        group: (
            scores[group].mean(axis=1).values,
            scores[group].std(axis=1).values,
        )
        for group in scores.columns.unique(0)
    }

    for position, (group, (centre, spread)) in enumerate(stats.items()):
        # Rows are plotted against 1-based positions.
        ranks = np.arange(1, len(centre) + 1)
        shade = colors[position]
        ax.fill_between(
            ranks, centre - spread, centre + spread, color=shade, alpha=0.3
        )
        ax.plot(
            ranks,
            centre,
            color=shade,
            linestyle="-",
            label=ARCH_DESC_MAPPING[group],
        )
        ax.grid(color="lightgray")


# Sub-directory of RESULTS_DIR that holds each dataset's metrics file.
logdirs = {DataSet.N3P2: "n3p2", DataSet.N4P2: "n4p2"}

# Plotting: take the colour list from Matplotlib's active property cycle;
# `plot_scores` indexes into it by position, one colour per architecture.
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]

# NOTE(review): `colors` is captured before this call, so if
# `setup_journal_env` alters the colour cycle the curves would still use the
# previous palette -- confirm this ordering is intended.
viz.setup_journal_env()

### 1) Load performance metrics
- Precision and recall determined per PNG.
- This is used to plot the PNG F1 score as a rank-order plot.

In [None]:
# Boundary type to analyse: "convex" or "concave".
target = "convex"

# Load one metrics pickle per dataset, keyed by its DataSet member.
scores_datasets = {}
for dataset in logdirs:
    logdir = logdirs[dataset]
    fpath = RESULTS_DIR / logdir / f"png_f1_max_{target}.pkl"
    scores_datasets[dataset] = io.load_pickle(fpath)

### 2) Plot scores for each model architecture

In [None]:
# Set to True to replace a figure that was already saved to disk.
overwrite = False

# Figure size in inches for the two side-by-side panels.
width = 5.5
height = 2.5

f, axes = plt.subplots(1, 2, figsize=(width, height), sharex=True, sharey=True)

for idx, (ds, scores_ds) in enumerate(scores_datasets.items()):
    # Build one table per dataset with two-level columns:
    # (architecture key, column of the transposed score array).
    png_f1_scores = pd.concat(
        {k: pd.DataFrame(v.T) for k, v in scores_ds.items()}, axis=1
    )
    ax = axes[idx]
    plot_scores(png_f1_scores, ds.name, ax=ax)
    ax.set_xscale("log")
    ax.set_xlabel("PNG rank #")
    ax.set_xlim([1, 1e4])
    ax.set_title(ds.name.upper(), pad=10, fontweight="bold")
# The y-axis is shared, so the limits set on the left panel apply to both;
# the legend and y-label are drawn on the left panel only.
axes[0].legend(loc="lower left", frameon=True, framealpha=1)
axes[0].set_ylabel("F1 score")
axes[0].set_ylim([0, 1.05])
f.tight_layout()

# Save figure (an existing file is kept unless `overwrite` is set).
filedir = OUTPUT_DIR / f"fig_png_f1_{target}_datasets.pdf"
if overwrite or not filedir.exists():
    f.savefig(filedir, format="pdf", dpi=300)
    print(f"Saved figure to '{filedir}'")
else:
    # Make the skipped save visible so a stale file is not mistaken for fresh.
    print(f"Figure '{filedir}' already exists; set overwrite = True to replace it")

### 3) Summary Statistics
Compute key statistical summaries including the number of PNGs with high F1 scores (> 0.9).

In [None]:
# Compute summary statistics for each dataset and architecture
f1_threshold = 0.9

summary_data = []
for ds, scores_ds in scores_datasets.items():
    # Two-level columns: (architecture key, column of the transposed scores).
    png_f1_scores = pd.concat(
        {k: pd.DataFrame(v.T) for k, v in scores_ds.items()}, axis=1
    )

    for arch in png_f1_scores.columns.unique(0):
        arch_scores = png_f1_scores[arch]
        # Row-wise mean over this architecture's columns. `pd.concat` aligns
        # rows on the union of all architectures' indices, so an architecture
        # with fewer rows is padded with NaN; drop those all-NaN rows so they
        # do not inflate the total or deflate the percentage.
        # NOTE(review): rows are presumably rank positions with one column
        # per trial -- confirm against the metrics script.
        mean_scores = arch_scores.mean(axis=1).dropna()

        n_pngs = len(mean_scores)
        n_high_f1 = int((mean_scores > f1_threshold).sum())
        # Avoid a divide-by-zero warning when an architecture has no rows.
        pct_high_f1 = 100 * n_high_f1 / n_pngs if n_pngs else float("nan")

        summary_data.append({
            "Dataset": ds.name,
            "Architecture": ARCH_DESC_MAPPING[arch],
            "Total PNGs": n_pngs,
            f"PNGs with F1 > {f1_threshold}": n_high_f1,
            f"% PNGs with F1 > {f1_threshold}": f"{pct_high_f1:.1f}%",
            "Mean F1": f"{mean_scores.mean():.3f}",
            "Median F1": f"{mean_scores.median():.3f}",
            "Max F1": f"{mean_scores.max():.3f}",
        })

summary_df = pd.DataFrame(summary_data)
print(f"Summary Statistics (F1 threshold = {f1_threshold})")
print("=" * 80)
# `display` is provided by the IPython/Jupyter environment.
display(summary_df)