# Plot PNG counts pre- / post-training

Emergence of PNGs in hierarchical networks.

**Results:**

- Demonstrate emergence of polychronization in the network after training on datasets.
- Plot PNG counts, considering the simplest three-neuron HFB circuit
- Show this with respect to the layers and network architectures

**Dependencies:**

---

A) Significance testing:
- PNG detection and significance testing for N3P2 & N4P2: both before and after network training
- **These workflows are time-consuming to run**
- N3P2 workflows (with and without the `--chkpt -1` argument):
```bash
./scripts/run_main_workflow.py experiments/n3p2/train_n3p2_lrate_0_04_181023 1 3 5 7 9 31 --rule significance -v
```
- N4P2 workflows (with and without the `--chkpt -1` argument):
```bash
./scripts/run_main_workflow.py experiments/n4p2/train_n4p2_lrate_0_02_181023 1 3 5 7 15 29 --rule significance -v
```

---

B) Compute PNG counts:

i) For N3P2:
```bash
for arch in SEMI ALL; do
    ./scripts/figures/compute_png_counts.py ./experiments/n3p2/train_n3p2_lrate_0_04_181023 $arch -v
done
```
ii) For N4P2:
```bash
for arch in SEMI ALL; do
    ./scripts/figures/compute_png_counts.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 $arch -v
done
```

---

**Plots:**

- Shape: (1, 2)
- Columns: Dataset (N3P2, N4P2)
- For each subplot, show:
    - Untrained counts (FF + LAT + FB)
    - Trained counts (FF + LAT) and (FF + LAT + FB)

In [None]:
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd

from hsnn import viz
from hsnn.utils import io

# Shorthand for pandas.IndexSlice (not referenced by the cells visible here)
pidx = pd.IndexSlice
# Directory the PNG-count pickles are loaded from; the ANOVA table is also
# written here
RESULTS_DIR = io.BASE_DIR / "out/figures/detection"
# Directory the figure produced by this notebook is saved to
OUTPUT_DIR = io.BASE_DIR / "out/figures/fig14"
# Create the output directory up front; a no-op if it already exists
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


# Annotations
class DataSet(Enum):
    """Datasets whose PNG counts are loaded and compared in this notebook."""

    N3P2 = 1
    N4P2 = 2


# Legend label for each architecture key found in the counts dictionaries
ARCH_DESC_MAPPING = {"FF": "FF", "SEMI": "FF+LAT", "ALL": "FF+LAT+FB"}

# Plotting
# Colour list taken from matplotlib's property cycle as it is at this point.
# NOTE(review): the colours are captured *before* viz.setup_journal_env() runs;
# if that call changes the colour cycle, `colors` will not reflect it - confirm
# this ordering is intended.
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]


# Presumably configures matplotlib styling for journal figures - see hsnn.viz
viz.setup_journal_env()

### 1) Load PNG counts from both experiments (N3P2, N4P2)

In [None]:
# Sub-directory of RESULTS_DIR that holds each dataset's pickled counts.
logdirs = {DataSet.N3P2: "n3p2", DataSet.N4P2: "n4p2"}

# Nested mapping: dataset -> architecture -> state -> counts array.
# NOTE: pickle files must come from a trusted source (here: the output of
# compute_png_counts.py), since unpickling can execute arbitrary code.
polygrps_counts: dict[DataSet, dict[str, dict[str, npt.NDArray[np.int_]]]] = {}
for dataset, logdir in logdirs.items():
    counts_dir = RESULTS_DIR / logdir
    polygrps_counts[dataset] = io.load_pickle(counts_dir / "png_counts.pkl")
    print(f"Loaded '{dataset.name}' PNG counts from '{counts_dir}'")

### 2) Plot PNG counts

- Untrained: `FF + LAT + FB`
- Trained: `FF + LAT`, `FF + LAT + FB`

In [None]:
def plot_counts(
    num_polygrps: npt.ArrayLike,
    layers: npt.ArrayLike,
    ax: "plt.Axes",
    label=None,
    *,
    ddof: int = 0,
    **kwargs,
):
    """Draw the per-layer mean PNG count with standard-error bars.

    Args:
        num_polygrps: 2-D array-like of counts; axis 0 is averaged over and
            axis 1 must line up with ``layers``.
        layers: x positions, one per column of ``num_polygrps``.
        ax: Axes to draw on.
        label: Legend label for the series.
        ddof: Delta degrees of freedom used for the standard deviation that
            enters the standard error. The default ``0`` (population std)
            keeps the existing figure unchanged; pass ``1`` for the sample
            standard error, which is what pandas' ``sem`` reports.
        **kwargs: Forwarded unchanged to ``ax.errorbar`` (colour, line style,
            marker, ...).

    Returns:
        Whatever ``ax.errorbar`` returns.
    """
    counts = np.asarray(num_polygrps)
    n_trials = counts.shape[0]
    mean = counts.mean(axis=0)
    # Standard error of the mean: std across axis 0 divided by sqrt(n).
    sem = counts.std(axis=0, ddof=ddof) / np.sqrt(n_trials)
    return ax.errorbar(layers, mean, sem, label=label, **kwargs)


# Line/marker styling per plotted series: both trained architectures use a
# solid line, the untrained baseline a dashed line with white-faced markers.
arch_kwargs = {
    "SEMI": dict(ls="-", marker="o"),
    "ALL": dict(ls="-", marker="o"),
    "Untrained": dict(ls="--", marker="o", markerfacecolor="w"),
}
# Y-axis tick layout per dataset: ticks run from 0 to ymax in steps of dy.
# (The original cell annotated the N3P2 entry with alternative values
# ymax=3000 / dy=500.)
yaxis_kwargs = {
    DataSet.N3P2: dict(ymax=8000, dy=2000),
    DataSet.N4P2: dict(ymax=8000, dy=2000),
}
layers = list(range(1, 5))  # x positions: layers 1 to 4
width, height = 5.5, 2.5  # figure size handed to plt.subplots

# One subplot per dataset, side by side, sharing the y axis.
f, axes = plt.subplots(nrows=1, ncols=2, figsize=(width, height), sharey=True)
for col, (dataset, counts_dict) in enumerate(polygrps_counts.items()):
    ax: plt.Axes = axes[col]

    # Trained networks: one series per architecture, coloured by position.
    for colour_idx, arch in enumerate(("SEMI", "ALL")):
        plot_counts(
            counts_dict[arch]["post"],
            layers,
            ax=ax,
            label=ARCH_DESC_MAPPING[arch],
            color=colors[colour_idx],
            **arch_kwargs[arch],
        )

    # Untrained baseline: the "ALL" architecture before training, drawn last
    # and in the same colour as the trained "ALL" series.
    plot_counts(
        counts_dict["ALL"]["pre"],
        layers,
        ax=ax,
        label="Untrained",
        color=colors[1],
        **arch_kwargs["Untrained"],
    )

    # Axis cosmetics.
    tick_cfg = yaxis_kwargs[dataset]
    ymax, dy = tick_cfg["ymax"], tick_cfg["dy"]
    ax.grid(color="lightgray")
    ax.set_xticks(layers)
    ax.set_yticks(np.arange(0, ymax + dy, dy))
    ax.set_title(dataset.name.upper(), pad=10, fontweight="bold")
    ax.set_xlabel("Layer")

    # Only the left-most panel carries the y label and the legend.
    if col == 0:
        ax.set_ylabel("# PNGs")
        ax.legend(fontsize="x-small", loc="upper left", frameon=True, framealpha=1)
f.tight_layout()

# Save figure
figure_path = OUTPUT_DIR / "fig_png_counts_datasets.pdf"
viz.save_figure(f, figure_path, dpi=300, overwrite=False)

### 3) Report summary statistics

In [None]:
def _relative_error(arr: np.ndarray) -> pd.DataFrame:
    mean = np.mean(arr, axis=0)
    std = np.std(arr, axis=0)  #  / np.sqrt(arr.shape[0])
    rel_err = std / mean
    return pd.DataFrame(
        {"mean": mean, "std": std, "rel_err": rel_err, "rel_err_pct": rel_err * 100},
        index=[int(i) for i in range(1, arr.shape[1] + 1)],
    )


_KEY_COLS = ["Architecture", "State", "Layer"]
_STAT_COLS = ["mean", "std", "rel_err", "rel_err_pct"]

# First pass: one flat (un-indexed) table per dataset, stacking the per-layer
# statistics of every architecture/state combination.
flat_tables: dict[DataSet, pd.DataFrame] = {}
for dataset, counts_dict in polygrps_counts.items():
    flat_tables[dataset] = pd.concat(
        [
            _relative_error(arr)
            .assign(Architecture=arch, State=state, Layer=lambda frame: frame.index)
            .reset_index(drop=True)
            for arch, states in counts_dict.items()
            for state, arr in states.items()
        ],
        ignore_index=True,
    )

# Second pass: keep an indexed copy for later grouping and show the flat
# table, with the key columns first and rows sorted by them.
rel_error_tables: dict[DataSet, pd.DataFrame] = {}
for dataset, table in flat_tables.items():
    rel_error_tables[dataset] = table.set_index(_KEY_COLS)
    print(dataset)
    display(
        table[_KEY_COLS + _STAT_COLS].sort_values(_KEY_COLS).reset_index(drop=True)
    )

In [None]:
def report_rel_err_pct_mean(
    dataset: DataSet,
) -> pd.DataFrame:
    """Aggregate the percentage relative error across layers.

    Reads the table stored for ``dataset`` in ``rel_error_tables``, groups it
    by architecture and state, and returns the mean, maximum and minimum of
    ``rel_err_pct`` within each group (despite the name, not only the mean).
    """
    per_layer = rel_error_tables[dataset]
    by_arch_and_state = per_layer.groupby(["Architecture", "State"])
    return by_arch_and_state.agg({"rel_err_pct": ["mean", "max", "min"]})


# Show the aggregated table for every dataset in turn.
for dataset in DataSet:
    print(dataset)
    display(report_rel_err_pct_mean(dataset))

In [None]:
# Average, over architecture/state groups, of the per-group mean relative
# error (N4P2 only).
df = report_rel_err_pct_mean(DataSet.N4P2)
mean_rel_err_pct = df[("rel_err_pct", "mean")].mean()
print(f"Mean relative error: {mean_rel_err_pct:.1f} %")

**Statistical test**

Determine whether the post-training PNG count differs significantly between the network with feedback (FF+LAT+FB) and the network without it (FF+LAT):
- Two-way ANOVA (analysis of variance) with architecture and layer as factors

In [None]:
import pandas as pd
import statsmodels.api as sm
from statsmodels.formula.api import ols

In [None]:
# 1. Convert the nested counts dictionary to long format for the ANOVA
dataset = DataSet.N3P2
# Layers entering the test, mapped to their column in the counts arrays.
# Column 0 (layer 1) is not included. A dedicated name is used here so the
# `layers` list of the plotting cell is not overwritten, which would break
# re-running that cell.
anova_layer_columns = {"L2": 1, "L3": 2, "L4": 3}

data = {
    "PNG_Count": [],
    "Architecture": [],
    "State": [],
    "Layer": [],
}

for arch, states in polygrps_counts[dataset].items():
    for state, arr in states.items():
        # arr shape: (n_seeds, n_layers). The number of seeds is read from the
        # array itself so that every seed is used, whatever their number.
        for seed_idx in range(np.shape(arr)[0]):
            for layer_label, layer_col in anova_layer_columns.items():
                data["PNG_Count"].append(arr[seed_idx, layer_col])
                data["Architecture"].append(arch)
                data["State"].append(state)
                data["Layer"].append(layer_label)

df = pd.DataFrame(data)
# Only post-training counts enter the test, so the State column is dropped.
df = df[df["State"] == "post"].drop("State", axis=1)
df = df.sort_values(["Architecture", "Layer"], ascending=[False, True]).reset_index(
    drop=True
)
df

In [None]:
# 2. Fit the linear model
# `A * B` in the formula expands to both main effects plus their interaction:
#   C(Architecture) + C(Layer) + C(Architecture):C(Layer)
anova_formula = "PNG_Count ~ C(Architecture) * C(Layer)"
model = ols(anova_formula, data=df).fit()

# 3. Build the ANOVA table (type-2 sums of squares) and store it next to the
# other results, one file per dataset
anova_table = sm.stats.anova_lm(model, typ=2)
anova_csv = RESULTS_DIR / f"statistical_test_{dataset.name}.csv"
anova_table.to_csv(anova_csv)
print(anova_table)

### Appendix

In [None]:
# Choose the dataset explicitly instead of reading `counts_dict`, a loop
# variable left over from earlier cells. In a top-to-bottom run that variable
# holds the N4P2 counts, so the table below is unchanged; it simply no longer
# depends on which cells were executed last.
appendix_dataset = DataSet.N4P2

# One row per trial, one column per layer (numbered from 1), stacked over
# every architecture/state combination.
counts_df = pd.concat(
    {
        (arch, state): pd.DataFrame(
            arr, columns=list(range(1, np.shape(arr)[1] + 1))
        )
        for arch, states in polygrps_counts[appendix_dataset].items()
        for state, arr in states.items()
    }
)
counts_df.index.names = ["Architecture", "State", "Trial"]
counts_df

In [None]:
# Per-layer mean and standard error over trials for each architecture/state.
# pandas' `sem` uses ddof=1 by default, whereas np.std (used for the plotted
# error bars) defaults to ddof=0, so these values are not numerically
# identical to the error bars in the figure.
trials_by_group = counts_df.groupby(level=["Architecture", "State"])
counts_stats = trials_by_group.agg(["mean", "sem"])
counts_stats