In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import gc
import gzip
import numpy as np
import mne
import json
from sklearn.preprocessing import LabelEncoder
from torch.nn import Module, Linear , LayerNorm, BatchNorm1d, Dropout
import torch.nn.functional as F
import torch
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc, roc_auc_score
import networkx as nx
from collections import Counter
from snntorch import spikegen
import warnings
import random
import pyarrow
from snntorch import spikegen
import torch
import random
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
import snntorch
import imblearn
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE
from typing import Literal, Optional, Tuple
import joblib

mne.set_log_level("ERROR")


In [2]:
def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(False, warn_only=True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

In [3]:
# -----------------------------
# SNN that consumes [T, B, F]
# -----------------------------
class SNNModel(nn.Module):
    """Spiking Neural Network that accepts spike trains shaped [T, B, F]."""
    def __init__(self, input_size: int, hidden_size: int, output_size: int, beta: float = 0.5, num_steps: int = 25):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        self.fc3 = nn.Linear(hidden_size, output_size)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        self.num_steps = int(num_steps)

    def forward(self, x_TBF: torch.Tensor):
        assert x_TBF.dim() == 3, f"Expected [T,B,F], got {tuple(x_TBF.shape)}"
        T, B, _ = x_TBF.shape
        device = x_TBF.device

        mem1 = torch.zeros(B, self.fc1.out_features, device=device)
        mem2 = torch.zeros(B, self.fc2.out_features, device=device)
        mem3 = torch.zeros(B, self.fc3.out_features, device=device)

        spk3_rec, mem3_rec = [], []
        for t in range(T):
            x = x_TBF[t]                 # [B,F]
            cur1 = self.fc1(x)           # [B,H]
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.fc2(spk1)        # [B,H]
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.fc3(spk2)        # [B,C]
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)

        return torch.stack(spk3_rec, 0), torch.stack(mem3_rec, 0)  # [T,B,C]

In [4]:
def encode_to_spikes_eval(
    x: torch.Tensor,          # [B,F] po scalerze
    num_steps: int,
    method: str = "rate",
    gain: float = 0.3
) -> torch.Tensor:
    # MUSI BYĆ TAK SAMO jak w SNNTrainer._encode_to_spikes
    # Zakładam wersję ze stabilnym skalowaniem:
    x_norm = torch.sigmoid(x)      # 0–1, bez per-batch min/max

    if method == "rate":
        spikes = spikegen.rate(x_norm, num_steps=num_steps, gain=gain)
    elif method == "latency":
        spikes = spikegen.latency(x_norm, num_steps=num_steps, normalize=True)
    else:
        raise ValueError(f"Unknown encoding method: {method}")
    return spikes

@torch.no_grad()
def infer_logits_batched(
    model: torch.nn.Module,
    X_float: torch.Tensor,              # [N,F] (już po scalerze)
    batch_size: int = 128,
    encoding_method: str = "rate",
    num_steps: int | None = None,
    device: str | torch.device = "cpu",
) -> torch.Tensor:
    model.eval().to(device)
    X_float = X_float.to(device)
    N = X_float.shape[0]
    steps = int(num_steps or getattr(model, "num_steps", 25))

    outs = []
    for i in range(0, N, batch_size):
        xb = X_float[i:i+batch_size]                      # [B,F]

        spikes = encode_to_spikes_eval(
            xb,
            num_steps=steps,
            method=encoding_method,
        ).to(device)                                      # [T,B,F]

        spk_rec, mem_rec = model(spikes)                  # [T,B,C]

        if encoding_method == "rate":
            out = spk_rec.sum(dim=0)                      # [B,C]
        elif encoding_method == "latency":
            out = mem_rec.max(dim=0).values               # [B,C]
        else:
            raise ValueError(f"Unknown encoding method: {encoding_method}")

        outs.append(out.detach().cpu())

    return torch.cat(outs, dim=0)                         # [N,C]

In [5]:
import os
import json
import pandas as pd

ALL_FEATURE_SETS = ["katz", "bands_rel", "bands_abs", "f1_slope_calc_by_me", "f1_slope_calc_by_fooof", "plv", "mean_std"]


def _cache_paths(cache_prefix: str):
    return {
        "X":   f"{cache_prefix}X.parquet",
        "y":   f"{cache_prefix}y.parquet",
        "meta": f"{cache_prefix}meta.parquet",
        "chs":  f"{cache_prefix}channels.json",
        "cfg":  f"{cache_prefix}config.json",
    }


def load_extracted_features(
    features, 
    cache_root: str = "./cache/"
):
    if isinstance(features, str):
            feature_list = [features]
    else:
        feature_list = list(features)

    if not feature_list:
        raise ValueError("No features requested.")

    X_all = []
    y_ser_ref = None
    meta_ref = None
    channels_ref = None
    cfg_merged: dict = {}

    for feat_name in feature_list:
        sub_list = [feat_name]

        for fn in sub_list:
            cache_prefix = os.path.join(cache_root, fn, "")
            paths = _cache_paths(cache_prefix)

            missing = [k for k, p in paths.items() if k in ("X", "y", "meta") and not os.path.exists(p)]
            if missing:
                raise FileNotFoundError(
                    f"Missing cached files for feature '{fn}' in '{cache_prefix}': {missing}"
                )

            print(f"[cache] Loading features '{fn}' from '{cache_prefix}_*.parquet'")

            X_df  = pd.read_parquet(paths["X"])
            y_ser = pd.read_parquet(paths["y"])["label"]
            meta  = pd.read_parquet(paths["meta"])

            # channels (optional)
            if os.path.exists(paths["chs"]):
                with open(paths["chs"], "r") as f:
                    chs = json.load(f)
            else:
                chs = None

            cfg_local = {}
            if os.path.exists(paths["cfg"]):
                with open(paths["cfg"], "r") as f:
                    cfg_local = json.load(f)

            if y_ser_ref is None:
                y_ser_ref = y_ser
                meta_ref  = meta
                channels_ref = chs
            else:
                if len(y_ser) != len(y_ser_ref):
                    raise ValueError(
                        f"Feature set '{fn}' has {len(y_ser)} samples, "
                        f"but previous set has {len(y_ser_ref)}."
                    )
            X_df_prefixed = X_df.add_prefix(f"{fn}__")

            X_all.append(X_df_prefixed)

            cfg_merged.update(cfg_local)

    if not X_all:
        raise RuntimeError("No feature matrices were loaded; check 'features' and cache_root.")

    X_concat = pd.concat(X_all, axis=1)

    print(f"[cache] Final concatenated X shape: {X_concat.shape}")
    return X_concat, y_ser_ref, meta_ref, channels_ref, cfg_merged

def _cache_paths(cache_prefix: str):
    return {
        "X": f"{cache_prefix}X.parquet",
        "y": f"{cache_prefix}y.parquet",
        "meta": f"{cache_prefix}meta.parquet",
        "chs": f"{cache_prefix}channels.json",
        "cfg": f"{cache_prefix}config.json",
    }

In [22]:
df_participants = pd.read_csv("participants.tsv", sep='\t')

# Build X/y from *your* notebook data:
FEATURES = "f1_slope_calc_by_fooof" #["plv", "bands_rel", "bands_abs", "mean_std"]

X_df, y_ser, meta, used_channels, _cfg = load_extracted_features(FEATURES)

print("\nPreview X_df:")
display(X_df.head())
print("\nPreview y_ser (A->0, C->1):")
display(y_ser.head())
print("\nMeta (subject windows):")
display(meta.head())

[cache] Loading features 'f1_slope_calc_by_fooof' from './cache/f1_slope_calc_by_fooof/_*.parquet'
[cache] Final concatenated X shape: (52552, 19)

Preview X_df:


Unnamed: 0,f1_slope_calc_by_fooof__fooof_f1_slope_C3,f1_slope_calc_by_fooof__fooof_f1_slope_C4,f1_slope_calc_by_fooof__fooof_f1_slope_Cz,f1_slope_calc_by_fooof__fooof_f1_slope_F3,f1_slope_calc_by_fooof__fooof_f1_slope_F4,f1_slope_calc_by_fooof__fooof_f1_slope_F7,f1_slope_calc_by_fooof__fooof_f1_slope_F8,f1_slope_calc_by_fooof__fooof_f1_slope_Fp1,f1_slope_calc_by_fooof__fooof_f1_slope_Fp2,f1_slope_calc_by_fooof__fooof_f1_slope_Fz,f1_slope_calc_by_fooof__fooof_f1_slope_O1,f1_slope_calc_by_fooof__fooof_f1_slope_O2,f1_slope_calc_by_fooof__fooof_f1_slope_P3,f1_slope_calc_by_fooof__fooof_f1_slope_P4,f1_slope_calc_by_fooof__fooof_f1_slope_Pz,f1_slope_calc_by_fooof__fooof_f1_slope_T3,f1_slope_calc_by_fooof__fooof_f1_slope_T4,f1_slope_calc_by_fooof__fooof_f1_slope_T5,f1_slope_calc_by_fooof__fooof_f1_slope_T6
0,2.159743,1.680378,2.206824,1.709732,1.561775,1.748864,2.085033,1.406818,2.111207,1.721701,2.1255,2.083735,2.376242,1.939174,1.996931,1.551694,1.725685,2.056437,1.806958
1,2.074079,2.131712,2.053111,1.639974,1.484552,1.865091,2.075424,1.535537,1.643439,1.879024,2.045097,2.077167,2.097744,2.083318,2.01816,1.818997,1.988955,2.066521,2.147957
2,2.113567,2.265747,2.289503,1.675787,1.777413,2.115215,2.103583,1.611884,1.952839,2.041131,2.030921,2.158749,2.005443,2.240399,2.291582,1.927433,2.018851,1.901319,2.200489
3,2.186963,2.379804,2.280436,1.509626,1.535515,1.883346,1.66994,1.504872,1.948231,1.960013,1.991852,2.038563,2.033714,2.051334,2.192498,2.05529,1.90266,1.941317,1.961843
4,2.277863,2.539529,2.367014,1.712145,1.611813,2.106452,1.933161,1.927252,2.323573,2.080833,2.300203,2.282525,2.352912,2.373809,2.38074,2.283877,1.669275,2.264911,2.123244



Preview y_ser (A->0, C->1):


0    0
1    0
2    0
3    0
4    0
Name: label, dtype: int64


Meta (subject windows):


Unnamed: 0,participant_id,window_idx
0,sub-002,0
1,sub-002,1
2,sub-002,2
3,sub-002,3
4,sub-002,4


In [7]:
from sklearn.model_selection import GroupShuffleSplit

# === Subject-wise 80/20 train-test split ===
# ensures that participants (subjects) are not mixed between train/test
assert "participant_id" in meta.columns, "meta must contain 'participant_id'."

gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X_df, y_ser, groups=meta['participant_id']))

# build the splits
X_train = X_df.iloc[train_idx].reset_index(drop=True)
y_train = y_ser.iloc[train_idx].reset_index(drop=True)
meta_train = meta.iloc[train_idx].reset_index(drop=True)

X_test  = X_df.iloc[test_idx].reset_index(drop=True)
y_test  = y_ser.iloc[test_idx].reset_index(drop=True)
meta_test = meta.iloc[test_idx].reset_index(drop=True)

print(f"Train windows: {len(X_train)} | Test windows: {len(X_test)}")
print("Train label counts:", y_train.value_counts().to_dict())
print("Test  label counts:", y_test.value_counts().to_dict())
print("Train subjects:", meta_train['participant_id'].nunique(),
      "| Test subjects:", meta_test['participant_id'].nunique())

Train windows: 41713 | Test windows: 11437
Train label counts: {0: 22776, 1: 18937}
Test  label counts: {0: 6305, 1: 5132}
Train subjects: 52 | Test subjects: 13


In [8]:
class SNNTrainer:
    def __init__(self, X_df, y_ser, device=None, random_state=42):
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.random_state = random_state
        
        # labels
        y = y_ser
        if y.dtype.kind not in "iu":
            y = y.astype(str).str.upper().map({"A": 0, "C": 1})
        self.y_np = y.to_numpy()

        # features
        self.X_np = X_df.to_numpy(dtype=np.float32)
        assert np.isfinite(self.X_np).all()

        print(f"Using device: {self.device}")

    def _prep_data(self, use_smote=True):
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(self.X_np).astype(np.float32)
        self.scaler_ = scaler

        if use_smote:
            smote = SMOTE(random_state=self.random_state)
            X_res, y_res = smote.fit_resample(X_scaled, self.y_np)
        else:
            X_res, y_res = X_scaled, self.y_np

        return (
            torch.tensor(X_res, dtype=torch.float32),
            torch.tensor(y_res, dtype=torch.long),
            scaler
        )

    def _encode_to_spikes(self, x, num_steps, method="rate", gain=0.3):
        return encode_to_spikes_eval(x, num_steps, method, gain)

    def train(self, num_epochs=50, batch_size=32, hidden_size=64,
              num_steps=25, beta=0.7, lr=1e-3,
              encoding_method="rate", use_smote=True,
              save_prefix="./eeg_snn"):

        X_t, y_t, scaler = self._prep_data(use_smote)
        dataset = torch.utils.data.TensorDataset(X_t, y_t)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        self.cfg_ = {"num_steps": num_steps, "encoding_method": encoding_method}
        model = SNNModel(X_t.shape[1], hidden_size, 2, beta=beta, num_steps=num_steps).to(self.device)
        self.model_ = model

        opt = torch.optim.Adam(model.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()

        print(f"Starting SNN training (encoding={encoding_method})...")

        for epoch in range(1, num_epochs + 1):
            model.train()
            epoch_loss = 0

            for xb, yb in loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                
                spikes = self._encode_to_spikes(xb, num_steps, encoding_method).to(self.device)
                spk_rec, _ = model(spikes)

                out = spk_rec.sum(0)

                loss = loss_fn(out, yb)
                opt.zero_grad()
                loss.backward()
                opt.step()

                epoch_loss += loss.item()

            if epoch % 10 == 0 or epoch == 1 or epoch == num_epochs:
                acc = self.evaluate(model, X_t, y_t, encoding_method, num_steps)
                avg_loss = epoch_loss / len(loader)
                print(f"Epoch {epoch:03d} | avg_loss={avg_loss:.4f} | acc={acc:.4f}")

        torch.save(model.state_dict(), f"{save_prefix}_model.pt")
        joblib.dump(scaler, f"{save_prefix}_scaler.pkl")
        print(f"Saved model -> {save_prefix}_model.pt ; scaler -> {save_prefix}_scaler.pkl")

        return model

    @torch.no_grad()
    def evaluate(self, model, X_all, y_all, encoding_method, num_steps):
        model.eval()
        spikes = self._encode_to_spikes(X_all.to(self.device), num_steps, encoding_method)
        spk_rec, _ = model(spikes)
        out = spk_rec.sum(0)
        pred = out.argmax(1)
        return (pred == y_all.to(self.device)).float().mean().item()

In [9]:
trainer = SNNTrainer(X_df, y_ser)

set_global_seed(42)
# pick one: "rate" or "latency"
CODING_METHOD = "rate"
BETA = 0.5
LEARNING_RATE = 5e-4
NUM_STEPS = 40
HIDDEN_SIZE = 128
NUM_EPOCHS = 50

model = trainer.train(
    num_epochs=NUM_EPOCHS,
    batch_size=64,
    hidden_size=HIDDEN_SIZE,
    num_steps=NUM_STEPS,
    beta=BETA,
    lr=LEARNING_RATE,
    encoding_method=CODING_METHOD,
    use_smote=True,
    save_prefix="./eeg_snn"
)

scaler = trainer.scaler_
model  = trainer.model_
cfg    = trainer.cfg_

Using device: cpu
Starting SNN training (encoding=rate)...
Epoch 001 | avg_loss=0.6230 | acc=0.7006
Epoch 010 | avg_loss=0.4982 | acc=0.7582


KeyboardInterrupt: 

In [None]:
# === Validation (window- and subject-level) ===
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, f1_score,
    classification_report, confusion_matrix
)
import matplotlib.pyplot as plt

# 1) Scale test features with loaded scaler (if any)
X_test_scaled = scaler.transform(X_test.values).astype("float32")
X_test_t = torch.tensor(X_test_scaled, dtype=torch.float32)

# 2) Inference: logits [N, C]
logits = infer_logits_batched(
    model,
    X_test_t,
    batch_size=128,
    encoding_method=cfg["encoding_method"],
    num_steps=cfg["num_steps"],
    device=trainer.device,
)
y_pred = logits.argmax(dim=1).cpu().numpy()
y_true = np.asarray(y_test)

label_names = ["A", "C"]  # 0 -> A, 1 -> C

# ---------- Model config ----------
print("\n=== Model config ===")
print(f'CODING: {CODING_METHOD} LEARNING RATE: {LEARNING_RATE} BETA: {BETA}')
print(f'NUM_STEPS: {NUM_STEPS} HIDDEN_SIZE: {HIDDEN_SIZE} NUM_EPOCHS: {NUM_EPOCHS}')
print(f'FEATURES: {FEATURES}')
# ---------- Window-level ----------
print("\n=== Window-level ===")
acc_win = accuracy_score(y_true, y_pred)
bal_win = balanced_accuracy_score(y_true, y_pred)
f1m_win  = f1_score(y_true, y_pred, average="macro")
print(f"Accuracy: {acc_win:.4f}")
print(f"Balanced Accuracy: {bal_win:.4f}")
print(f"Macro F1: {f1m_win:.4f}")

cm_win = confusion_matrix(y_true, y_pred, labels=[0, 1])
print("\nConfusion matrix (rows=true A,C; cols=pred A,C):\n", cm_win)

print("\nClassification report:\n",
      classification_report(y_true, y_pred, target_names=label_names, digits=4))

# ---------- Subject-level majority vote with logit tie-break ----------
df_pred = meta_test.copy()
df_pred["y_true"] = y_true
df_pred["y_pred"] = y_pred
logits_np = logits.detach().cpu().numpy()
for c in range(logits_np.shape[1]):
    df_pred[f"logit_{c}"] = logits_np[:, c]

def _mode_strict(s: pd.Series) -> int:
    vc = s.value_counts()
    winners = sorted(vc[vc == vc.max()].index.tolist())
    return int(winners[0])

def _vote_with_logit_tiebreak(g: pd.DataFrame) -> int:
    vc = g["y_pred"].value_counts()
    winners = vc[vc == vc.max()].index.tolist()
    if len(winners) == 1:
        return int(winners[0])
    sums = {cls: g[f"logit_{cls}"].sum() for cls in winners}
    best = max(sums.values())
    winners2 = [cls for cls, s in sums.items() if np.isclose(s, best)]
    return int(min(winners2))

subj = df_pred.groupby("participant_id").apply(
    lambda g: pd.Series({
        "y_true": _mode_strict(g["y_true"]),
        "y_pred": _vote_with_logit_tiebreak(g),
        "n_windows": len(g)
    })
).reset_index()

print("\n=== Subject-level ===")
acc_subj = accuracy_score(subj["y_true"], subj["y_pred"])
bal_subj = balanced_accuracy_score(subj["y_true"], subj["y_pred"])
f1m_subj = f1_score(subj["y_true"], subj["y_pred"], average="macro")
print(f"Subject Accuracy: {acc_subj:.4f}")
print(f"Subject Balanced Accuracy: {bal_subj:.4f}")
print(f"Subject Macro F1: {f1m_subj:.4f}")

cm_subj = confusion_matrix(subj["y_true"], subj["y_pred"], labels=[0, 1])
print("\nSubject-level confusion matrix (rows=true A,C; cols=pred A,C):\n", cm_subj)

# Pretty table
cm_df = pd.DataFrame(cm_subj, index=["True A", "True C"], columns=["Pred A", "Pred C"])
display(cm_df)

# ---------- Optional: Matplotlib confusion matrices ----------
def plot_cm(cm, title="Confusion Matrix", ticklabels=("A","C")):
    plt.figure(figsize=(5,4))
    plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    plt.colorbar()
    ticks = np.arange(len(ticklabels))
    plt.xticks(ticks, ticklabels)
    plt.yticks(ticks, ticklabels)
    thresh = cm.max() / 2.0 if cm.max() > 0 else 0.5
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], "d"),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()
    plt.show()

plot_cm(cm_win,  title="Window-level Confusion Matrix",  ticklabels=label_names)
plot_cm(cm_subj, title="Subject-level Confusion Matrix", ticklabels=label_names)

In [None]:
def export_snn_for_spinnaker(
    model,
    scaler=None,
    cfg=None,
    beta=None,
    out_dir: str = "./spinnaker_export"
):
    os.makedirs(out_dir, exist_ok=True)

    model_cpu = model.to("cpu").eval()

    fc1_W = model_cpu.fc1.weight.detach().numpy()  # [H1, F]
    fc1_b = model_cpu.fc1.bias.detach().numpy()    # [H1]

    fc2_W = model_cpu.fc2.weight.detach().numpy()  # [H2, H1]
    fc2_b = model_cpu.fc2.bias.detach().numpy()    # [H2]

    fc3_W = model_cpu.fc3.weight.detach().numpy()  # [C, H2]
    fc3_b = model_cpu.fc3.bias.detach().numpy()    # [C]

    np.save(os.path.join(out_dir, "fc1_W.npy"), fc1_W)
    np.save(os.path.join(out_dir, "fc1_b.npy"), fc1_b)
    np.save(os.path.join(out_dir, "fc2_W.npy"), fc2_W)
    np.save(os.path.join(out_dir, "fc2_b.npy"), fc2_b)
    np.save(os.path.join(out_dir, "fc3_W.npy"), fc3_W)
    np.save(os.path.join(out_dir, "fc3_b.npy"), fc3_b)

    meta = {
        "input_size": int(fc1_W.shape[1]),
        "hidden1_size": int(fc1_W.shape[0]),
        "hidden2_size": int(fc2_W.shape[0]),
        "output_size": int(fc3_W.shape[0]),
    }

    if cfg is not None:
        meta["num_steps"] = int(cfg.get("num_steps", -1))
        meta["encoding_method"] = str(cfg.get("encoding_method", "rate"))

    if beta is not None:
        meta["beta"] = float(beta)

    meta_path = os.path.join(out_dir, "meta.json")
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    if scaler is not None:
        scaler_path = os.path.join(out_dir, "scaler.pkl")
        joblib.dump(scaler, scaler_path)

    print("Exported SNN for SpiNNaker to:", out_dir)
    print("  - fc1_W.npy, fc1_b.npy")
    print("  - fc2_W.npy, fc2_b.npy")
    print("  - fc3_W.npy, fc3_b.npy")
    print("  - meta.json")
    if scaler is not None:
        print("  - scaler.pkl")

export_snn_for_spinnaker(
    model,
    scaler=scaler,
    cfg=cfg,
    beta=BETA,
    out_dir="./spinnaker_export"
)