In [2]:
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px 
import numpy as np
import os
import math


# Test accuracy per model (largest parameter budget)

In [None]:
csv_path = None # e.g "../exp/.." : the path to the CSV file with SHD_equiparams results 
figure_path = "../figures/"
use_log_x = False   # not used by this cell's bar chart
acc_in_pct = True   # initial hint only; overwritten by data-based detection below

# Fail early with an actionable message instead of pandas' opaque error on None.
if csv_path is None:
    raise ValueError("Set `csv_path` to the CSV file with SHD_equiparams results before running this cell.")

df = pd.read_csv(csv_path)

# Drop incomplete rows and coerce the parameter count to an integer.
df = df.dropna(subset=["model", "num_params", "final_test_acc", "seed"])
df["num_params"] = pd.to_numeric(df["num_params"], errors="coerce")
df = df.dropna(subset=["num_params"])
df["num_params"] = df["num_params"].astype(int)

# If a (model, num_params, seed) run appears more than once, keep the last row.
df = df.sort_index().drop_duplicates(subset=["model", "num_params", "seed"], keep="last")

# Aggregate over seeds: mean, sample std and number of distinct seeds per configuration.
grp = (
    df.groupby(["model", "num_params"])
      .agg(
          mean_acc=("final_test_acc", "mean"),
          std_acc=("final_test_acc", "std"),
          n=("seed", "nunique"),
      )
      .reset_index()
)

# Standard error of the mean; 0 when there is a single seed (std is NaN then).
grp["sem_acc"] = grp.apply(
    lambda r: 0.0 if (pd.isna(r["std_acc"]) or r["n"] <= 1) else r["std_acc"] / math.sqrt(r["n"]),
    axis=1
)

# Decide from the data whether accuracies are percentages (max > ~1) or
# fractions, THEN put every plotted statistic on the percent scale with the
# same factor (previously std was scaled using the pre-detection flag).
acc_in_pct = grp["mean_acc"].max() > 1.001
scale = 1.0 if acc_in_pct else 100.0
grp["mean_acc_plot"] = scale * grp["mean_acc"]
grp["std_acc_plot"] = scale * grp["std_acc"]
grp["sem_acc_plot"] = scale * grp["sem_acc"]
yaxis_title = "Test accuracy (%)"

grp = grp.sort_values(["model", "num_params"]).reset_index(drop=True)

# One bar per model: the row with the LARGEST parameter budget for that model
# (not necessarily its highest accuracy), ordered by decreasing mean accuracy.
best = grp.loc[grp.groupby("model")["num_params"].idxmax()]
best = best.sort_values("mean_acc_plot", ascending=False)
best = best.reset_index(drop=True)

def fmt_params(n):
    """Format a parameter count as a compact bar label.

    Counts of one million or more are shown in millions with two decimals
    (e.g. ``"1.25M"``). Counts of one thousand or more are shown in whole
    thousands (e.g. ``"512k"``). Smaller counts are returned verbatim.

    Args:
        n: Parameter count (Python or numpy integer).

    Returns:
        str: Human-readable label.
    """
    if n >= 1_000_000:
        return f"{n/1_000_000:.2f}M"
    if n >= 1_000:
        thousands = int(round(n / 1_000))
        # Counts just below one million (e.g. 999_600) would round to "1000k";
        # report them in millions instead.
        if thousands >= 1_000:
            return f"{n/1_000_000:.2f}M"
        return f"{thousands}k"
    return str(n)

best["params_label"] = best["num_params"].apply(fmt_params)

palette = px.colors.qualitative.Plotly
colors = [palette[i % len(palette)] for i in range(len(best))]

# Y-axis window. The floor stays at 30% unless a bar (minus its SEM) dips below
# it; then it drops to the next lower multiple of 10 (never below 0) so no bar
# is hidden. The top leaves 3 points above the highest error bar, capped at 100%.
lowest = (best["mean_acc_plot"] - best["sem_acc_plot"]).min()
y_min = 30.0 if pd.isna(lowest) else max(0.0, min(30.0, 10.0 * math.floor(lowest / 10.0)))
y_max = float((best["mean_acc_plot"] + best["sem_acc_plot"]).max())
y_top = min(100.0, y_max + 3.0)

fig = go.Figure()

fig.add_trace(go.Bar(
    x=best["model"],
    y=best["mean_acc_plot"],
    marker_color=colors,
    text=best["params_label"],
    textposition="outside",
    textfont=dict(size=12),
    error_y=dict(
        type="data",
        array=best["sem_acc_plot"],
        visible=True,
        color="black",
        thickness=1.5,
        width=8,
    ),
    customdata=np.c_[best["num_params"], best["sem_acc_plot"]],
    hovertemplate=(
        "<b>%{x}</b><br>"
        "Params: %{customdata[0]}<br>"
        "Acc: %{y:.2f}%<br>"
        "SEM: %{customdata[1]:.2f}%<extra></extra>"
    ),
    name=""
))

# Bars appear in the reverse of `best`'s order, i.e. increasing mean accuracy left to right.
fig.update_xaxes(
    title_text="Model",
    categoryorder="array",
    categoryarray=list(best["model"])[::-1]
)

fig.update_yaxes(
    title_text=yaxis_title,
    range=[y_min, y_top],
    ticks="outside",
    tick0=10,
    dtick=10,
    ticklen=6,
    tickwidth=1.8,
    tickcolor="black",
    ticklabelposition="outside",
    tickformat=".0f",
    ticksuffix="%",
    showline=True,
    linecolor="black",
    linewidth=2,
    showgrid=True,
    gridcolor="rgba(0,0,0,0.08)",
    gridwidth=1
)

fig.update_yaxes(minor=dict(
    tickmode="linear",
    dtick=5,
    tick0=5,
    ticks="outside",
    ticklen=3,
    showgrid=True,
    gridcolor="rgba(0,0,0,0.05)",
    gridwidth=0.5
))

fig.update_layout(
    template="plotly_white",
    font=dict(family="Times New Roman, serif", size=16),
    showlegend=False,
    bargap=0.12,
    bargroupgap=0.0,
    height=500,
    width=800,
    margin=dict(l=80, r=20, t=40, b=80)
)

fig.update_traces(marker_line_color="black", marker_line_width=0.6)

os.makedirs(figure_path, exist_ok=True)

# Join onto the directory itself: os.path.dirname("../figures") would be "..",
# silently writing to the parent when figure_path has no trailing slash.
bar_svg_path = os.path.join(figure_path, "accuracy_best_bar.svg")
fig.write_image(bar_svg_path, format="svg", width=800, height=500, scale=1)
print(f"Saved SVG to: {bar_svg_path}")

fig.show()

Saved SVG to: ../figures/accuracy_best_bar.svg


# Accuracies as a function of parameters

In [None]:
# Same settings as the previous cell, re-declared here.
csv_path = None # e.g "../exp/.." : the path to the CSV file with SHD_equiparams results
figure_path = "../figures/"   # directory the SVG is written into
use_log_x = False             # if True, the parameter-count x-axis uses a log scale
acc_in_pct = True             # overwritten later in this cell by detection from the data

def hex_to_rgba(hex_color, alpha=0.18):
    """Convert a hex colour to a CSS ``rgba()`` string with the given opacity.

    Used to derive the translucent ±SEM band fill from a palette colour.

    Args:
        hex_color: ``#RRGGBB`` (or shorthand ``#RGB``/``#RGBA``) string; the
            leading ``#`` and surrounding whitespace are optional. Any alpha
            digits in the input are ignored.
        alpha: Opacity written into the result.

    Returns:
        str: e.g. ``"rgba(99,110,250,0.18)"``.

    Raises:
        ValueError: if fewer than six hex digits are available after expansion,
            or the digits are not valid hexadecimal.
    """
    h = hex_color.strip().lstrip("#")
    if len(h) in (3, 4):
        # Expand CSS shorthand: "abc" -> "aabbcc".
        h = "".join(ch * 2 for ch in h)
    if len(h) < 6:
        raise ValueError(f"Expected a #RGB or #RRGGBB hex colour, got {hex_color!r}")
    r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"


# Fail early with an actionable message instead of pandas' opaque error on None.
if csv_path is None:
    raise ValueError("Set `csv_path` to the CSV file with SHD_equiparams results before running this cell.")

df = pd.read_csv(csv_path)

# Drop incomplete rows and coerce the parameter count to an integer.
df = df.dropna(subset=["model", "num_params", "final_test_acc", "seed"])
df["num_params"] = pd.to_numeric(df["num_params"], errors="coerce")
df = df.dropna(subset=["num_params"])
df["num_params"] = df["num_params"].astype(int)

# If a (model, num_params, seed) run appears more than once, keep the last row.
df = df.sort_index().drop_duplicates(subset=["model", "num_params", "seed"], keep="last")

# Aggregate over seeds: mean, sample std and number of distinct seeds per configuration.
grp = (
    df.groupby(["model", "num_params"])
      .agg(
          mean_acc=("final_test_acc", "mean"),
          std_acc=("final_test_acc", "std"),
          n=("seed", "nunique"),
      )
      .reset_index()
)

# Standard error of the mean; 0 when there is a single seed (std is NaN then).
grp["sem_acc"] = grp.apply(
    lambda r: 0.0 if (pd.isna(r["std_acc"]) or r["n"] <= 1) else r["std_acc"] / math.sqrt(r["n"]),
    axis=1
)

# Decide from the data whether accuracies are percentages (max > ~1) or
# fractions, THEN put every plotted statistic on the percent scale with the
# same factor (previously std was scaled using the pre-detection flag).
acc_in_pct = grp["mean_acc"].max() > 1.001
scale = 1.0 if acc_in_pct else 100.0
grp["mean_acc_plot"] = scale * grp["mean_acc"]
grp["std_acc_plot"] = scale * grp["std_acc"]
grp["sem_acc_plot"] = scale * grp["sem_acc"]
yaxis_title = "Test accuracy (%)"

grp = grp.sort_values(["model", "num_params"]).reset_index(drop=True)

fig = go.Figure()
palette = px.colors.qualitative.Plotly
models = grp["model"].unique()

for i, model in enumerate(models):
    g = grp[grp["model"] == model].sort_values("num_params")
    color = palette[i % len(palette)]
    fillcol = hex_to_rgba(color, 0.18)

    x = g["num_params"].to_numpy()
    y = g["mean_acc_plot"].to_numpy()
    sem = g["sem_acc_plot"].to_numpy()

    # ±SEM band, clipped to the valid percentage range.
    y_lower = np.clip(y - sem, 0.0, 100.0)
    y_upper = np.clip(y + sem, 0.0, 100.0)

    # Invisible lower edge; the next trace fills down to it via fill="tonexty".
    fig.add_trace(go.Scatter(
        x=x, y=y_lower, mode="lines",
        line=dict(width=0), hoverinfo="skip",
        showlegend=False, legendgroup=model
    ))

    fig.add_trace(go.Scatter(
        x=x, y=y_upper, mode="lines",
        line=dict(width=0), fill="tonexty", fillcolor=fillcol,
        hoverinfo="skip", showlegend=True, legendgroup=model,
        name=f"{model} ± SEM"
    ))

    fig.add_trace(go.Scatter(
        x=x, y=y, mode="lines+markers",
        name=f"{model} mean", legendgroup=model,
        line=dict(color=color), marker=dict(color=color),
        hovertemplate=(
            "<b>%{fullData.name}</b><br>" +
            "#params: %{x}<br>" +
            "Acc: %{y:.2f}%<br>" +
            "SEM: %{customdata:.2f}%<extra></extra>"
        ),
        customdata=np.round(sem, 4),
    ))

# Label both axes so the exported figure stands on its own.
fig.update_layout(
    xaxis_title="Number of parameters",
    yaxis_title=yaxis_title,
)

if use_log_x:
    fig.update_xaxes(type="log")

os.makedirs(figure_path, exist_ok=True)

# Join onto the directory itself: os.path.dirname("../figures") would be "..",
# silently writing to the parent when figure_path has no trailing slash.
svg_path = os.path.join(figure_path, "accuracy_vs_params.svg")
fig.write_image(svg_path, format="svg", width=800, height=600, scale=1)
print(f"Saved SVG to: {svg_path}")
fig.show()

Saved SVG to: ../figures/accuracy_vs_params.svg


# Accuracy vs. spike rate under spike penalization

In [None]:
csv_path = None # e.g "../exp/.." : the path to the CSV file with SHD_penalize_spikes results
figure_path = "../figures/"
use_log_x = False

x_limit = 0.09         # keep only points with mean spike rate <= x_limit (None = no limit)
models_to_plot = ['SNN', 'SNN_recurrent_delays', 'SNN_feedforward_delays']  # None = all models

# Fail early with an actionable message instead of pandas' opaque error on None.
if csv_path is None:
    raise ValueError("Set `csv_path` to the CSV file with SHD_penalize_spikes results before running this cell.")

df = pd.read_csv(csv_path)

need_cols = ["model", "seed", "lambda_spike", "final_test_acc", "mean_spikes_test"]
missing = [c for c in need_cols if c not in df.columns]
if missing:
    raise ValueError(f"CSV is missing required columns: {missing}")

# Drop incomplete rows; if a (model, lambda, seed) run appears twice, keep the last row.
df = df.dropna(subset=need_cols)
df = df.sort_index().drop_duplicates(subset=["model", "lambda_spike", "seed"], keep="last")

# Aggregate accuracy and spike rate over seeds for each (model, lambda) pair.
grp = (
    df.groupby(["model", "lambda_spike"])
      .agg(
          n=("seed", "nunique"),
          mean_acc=("final_test_acc", "mean"),
          std_acc=("final_test_acc", "std"),
          mean_spk=("mean_spikes_test", "mean"),
          std_spk=("mean_spikes_test", "std"),
      )
      .reset_index()
)

# Standard errors of the mean; 0 when there is a single seed (std is NaN then).
grp["sem_acc"] = grp.apply(
    lambda r: 0.0 if (pd.isna(r["std_acc"]) or r["n"] <= 1) else r["std_acc"] / math.sqrt(r["n"]),
    axis=1
)
grp["sem_spk"] = grp.apply(
    lambda r: 0.0 if (pd.isna(r["std_spk"]) or r["n"] <= 1) else r["std_spk"] / math.sqrt(r["n"]),
    axis=1
)

# Detect percent vs. fraction accuracies and put plotted values on the percent scale.
acc_in_pct = grp["mean_acc"].max() > 1.001
scale = 1.0 if acc_in_pct else 100.0
grp["mean_acc_plot"] = scale * grp["mean_acc"]
grp["sem_acc_plot"] = scale * grp["sem_acc"]
yaxis_title = "Test accuracy (%)"

# Optional filters. The config variable models_to_plot is left untouched so
# re-running the cell sees the same value.
if models_to_plot is not None:
    grp = grp[grp["model"].isin(set(models_to_plot))].copy()

if x_limit is not None:
    grp = grp[grp["mean_spk"] <= float(x_limit)].copy()

if grp.empty:
    print("Warning: no rows left after applying models_to_plot / x_limit; the figure will be empty.")

grp = grp.sort_values(["model", "mean_spk"]).reset_index(drop=True)

def hex_to_rgba(hx, alpha=0.18):
    """Convert a hex colour to a CSS ``rgba()`` string with the given opacity.

    Used to derive the translucent ±SEM band fill from a palette colour.

    Args:
        hx: ``#RRGGBB`` (or shorthand ``#RGB``/``#RGBA``) string; the leading
            ``#`` and surrounding whitespace are optional. Any alpha digits in
            the input are ignored.
        alpha: Opacity written into the result.

    Returns:
        str: e.g. ``"rgba(99,110,250,0.18)"``.

    Raises:
        ValueError: if fewer than six hex digits are available after expansion,
            or the digits are not valid hexadecimal.
    """
    h = hx.strip().lstrip("#")
    if len(h) in (3, 4):
        # Expand CSS shorthand: "abc" -> "aabbcc".
        h = "".join(ch * 2 for ch in h)
    if len(h) < 6:
        raise ValueError(f"Expected a #RGB or #RRGGBB hex colour, got {hx!r}")
    r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r},{g},{b},{alpha})"

fig = go.Figure()
palette = px.colors.qualitative.Plotly
models = grp["model"].unique()

for i, model in enumerate(models):
    g = grp[grp["model"] == model].sort_values("mean_spk")
    if g.empty:
        continue

    color = palette[i % len(palette)]
    band = hex_to_rgba(color, 0.20)

    x = g["mean_spk"].to_numpy()
    y = g["mean_acc_plot"].to_numpy()
    sem_y = g["sem_acc_plot"].to_numpy()

    # Closed polygon for the ±SEM band: lower edge left->right, then upper edge
    # right->left, filled with "toself"; values clipped to [0, 100].
    x_band = np.concatenate([x, x[::-1]])
    y_band = np.concatenate([np.clip(y - sem_y, 0, 100), np.clip(y + sem_y, 0, 100)[::-1]])
    fig.add_trace(go.Scatter(
        x=x_band, y=y_band,
        fill="toself", fillcolor=band,
        line=dict(width=0),
        hoverinfo="skip",
        showlegend=False,
        legendgroup=model
    ))

    fig.add_trace(go.Scatter(
        x=x, y=y,
        mode="lines+markers",
        name=model,
        line=dict(color=color, width=2),
        marker=dict(color=color, size=6),
        legendgroup=model,
        customdata=np.stack([g["lambda_spike"], g["n"], g["sem_spk"]], axis=1),
        hovertemplate=(
            "<b>%{fullData.name}</b><br>"
            "λ: %{customdata[0]}<br>"
            "Mean spikes: %{x:.6f} ± %{customdata[2]:.6f}<br>"
            "Accuracy: %{y:.2f}% (mean ± SEM shown)<br>"
            "Seeds: %{customdata[1]}<extra></extra>"
        ),
    ))

fig.update_layout(
    # Plotly breaks title lines with "<br>"; a plain "\n" is not rendered as a line break.
    title="Accuracy vs. mean spike rate (per neuron · per timestep)<br>(mean ± s.e.m. over seeds)",
    xaxis_title="Mean spikes per neuron per timestep",
    yaxis_title=yaxis_title,
    legend_title="Model",
    hovermode="x unified",
    template="plotly_white",
)

if use_log_x:
    fig.update_xaxes(type="log")
    if x_limit is not None and not grp.empty:
        # Log axes take ranges in log10 units and cannot include 0.
        pos_x = grp.loc[grp["mean_spk"] > 0, "mean_spk"]
        if not pos_x.empty:
            xmin = pos_x.min()
            fig.update_xaxes(range=[np.log10(xmin), np.log10(x_limit)])
else:
    if x_limit is not None and not grp.empty:
        x_pad = 0.002  # fixed left margin so the first marker is not clipped
        xmin = grp["mean_spk"].min()
        fig.update_xaxes(range=[xmin - x_pad, x_limit])

fig.update_yaxes(
    ticks="outside", ticklen=6, tickwidth=1.5, tickcolor="black",
    showline=True, linecolor="black", linewidth=2
)

os.makedirs(figure_path, exist_ok=True)

# Join onto the directory itself: os.path.dirname("../figures") would be "..",
# silently writing to the parent when figure_path has no trailing slash.
svg_path = os.path.join(figure_path, "accuracy_vs_spikes.svg")
fig.write_image(svg_path, format="svg", width=800, height=600, scale=1)
print(f"Saved SVG to: {svg_path}")
fig.show()

Saved SVG to: ../figures/accuracy_vs_spikes.svg
