# Net training

## Overview
This notebook trains a Spiking Neural Network (SNN) for EEG-based seizure detection. The training pipeline includes:
- Environment setup and dependency installation
- EEG data loading and preprocessing
- Spike encoding
- Network architecture configuration
- Training with validation and early stopping
- Model checkpointing and results saving


The setup is based on this work:
- P. Busia, G. Leone, A. Matticola, L. Raffo and P. Meloni, "Wearable Epilepsy Seizure Detection on FPGA With Spiking Neural Networks," in IEEE Transactions on Biomedical Circuits and Systems, vol. 19, no. 6, pp. 1175-1186, Dec. 2025, doi: 10.1109/TBCAS.2025.3575327.



## 1. Environment Setup
Install the required dependencies and clone the STPSNN repository. The cell checks whether it is running in Google Colab; if so, it installs `snntorch` and `mne` with pip. In a local environment these packages (and `gdown`) must already be installed.

In [1]:
# Detect whether the notebook runs in Google Colab (the google.colab package
# only exists there) and, if so, install the packages Colab lacks
try:
  import google.colab
  print('colab env')
  !pip install snntorch
  !pip install mne
except ImportError:
  print('local env')

# Import necessary libraries
import os
import pickle

import gdown
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
from torch.utils.data import Dataset, DataLoader

# Clone the STPSNN repository from GitHub (encoding, network, loss and training code)
!git clone https://github.com/andem25/STPSNN

colab env
Cloning into 'STPSNN'...
remote: Enumerating objects: 22, done.
remote: Counting objects: 100% (22/22), done.
remote: Compressing objects: 100% (21/21), done.
remote: Total 22 (delta 4), reused 0 (delta 0), pack-reused 0 (from 0)
Receiving objects: 100% (22/22), 14.94 KiB | 14.94 MiB/s, done.
Resolving deltas: 100% (4/4), done.
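The bare `except:` in the original version of this cell would also swallow unrelated errors (even a `KeyboardInterrupt`). A minimal sketch of a side-effect-free check, assuming only the standard library: `importlib.util.find_spec` looks up `google.colab` without importing it. The helper name `running_in_colab` is hypothetical.

```python
import importlib.util


def running_in_colab() -> bool:
    """Return True if the google.colab package can be found, without importing it."""
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # Raised when the parent package "google" itself is not installed.
        return False


if running_in_colab():
    print("colab env")
else:
    print("local env")
```

The pip installs themselves still need IPython's `!pip` (or `%pip`) syntax inside the notebook, so they are left out of this plain-Python sketch.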


## 2. Data Loading
Download the dataset from Google Drive (if not already present) and load the preprocessed training and validation data from pickle files.

In [2]:
# Check if data folder already exists, if not download it from Google Drive
if os.path.exists("data"):
    print("Dataset already present")
else:
    gdown.download_folder('https://drive.google.com/drive/folders/1EARnrSSj1DeHf0OiBmQ6_wcCJjKc8a2m?usp=sharing', output='data', quiet=False, use_cookies=False)


Dataset already present


In [3]:
# Load the preprocessed training/validation windows and their labels from pickle files
with open('data/train_routine/train_data.pkl', 'rb') as f:
    train_data = pickle.load(f)
with open('data/train_routine/valid_data.pkl', 'rb') as f:
    valid_data = pickle.load(f)
with open('data/train_routine/y_train.pkl', 'rb') as f:
    y_train = pickle.load(f)
with open('data/train_routine/y_valid.pkl', 'rb') as f:
    y_valid = pickle.load(f)
# Windows from the training recording, used only to compute the max/min encoding bounds
with open('data/test_routine/train/training_window.pkl', 'rb') as f:
    train = pickle.load(f)
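The repeated open/load pattern above can be folded into a small helper. This is a sketch only; `load_pickles` is a hypothetical name and the file stems are assumed to match those used in the cell. Since unpickling can execute arbitrary code, it should only be used on files from a trusted source such as this dataset.

```python
import os
import pickle
import tempfile


def load_pickles(base_dir, names):
    """Load several pickle files from base_dir into a dict keyed by file stem."""
    loaded = {}
    for name in names:
        path = os.path.join(base_dir, f"{name}.pkl")
        with open(path, "rb") as fh:
            loaded[name] = pickle.load(fh)
    return loaded


# Usage with the notebook's files (paths assumed unchanged):
# splits = load_pickles("data/train_routine",
#                       ["train_data", "valid_data", "y_train", "y_valid"])
# train_data, valid_data = splits["train_data"], splits["valid_data"]

# Quick self-check with throwaway files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for stem, obj in {"a": [1, 2], "b": {"k": 3}}.items():
        with open(os.path.join(tmp, f"{stem}.pkl"), "wb") as fh:
            pickle.dump(obj, fh)
    demo = load_pickles(tmp, ["a", "b"])
```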


## 3. Data Preprocessing and Encoding
Convert the raw EEG windows into spike trains using thermometer encoding. The steps are:
1. Computing per-channel normalization bounds (max/min values) from the training data
2. Encoding each sample of each channel into a 16-element binary (thermometer) vector
3. Creating PyTorch DataLoaders for batch processing
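To make steps 1 and 2 concrete, here is a minimal NumPy sketch of a thermometer encoder. It assumes linear normalization with the per-channel training bounds and 16 evenly spaced thresholds; STPSNN's `encode` may place its thresholds differently, and `thermometer_encode` is a hypothetical name. The output layout, `(windows, channels, samples, levels)`, matches the shape of `train_spk` printed further down.

```python
import numpy as np


def thermometer_encode(x, max_vals, min_vals, levels=16):
    """Thermometer-encode EEG windows.

    x: array of shape (n_windows, n_channels, n_samples)
    max_vals, min_vals: per-channel bounds, each of length n_channels
    returns: uint8 array of shape (n_windows, n_channels, n_samples, levels)
    """
    max_vals = np.asarray(max_vals, dtype=np.float64)[None, :, None]
    min_vals = np.asarray(min_vals, dtype=np.float64)[None, :, None]
    # Map each channel to [0, 1] with the training-set bounds, clipping outliers.
    x_norm = np.clip((x - min_vals) / (max_vals - min_vals), 0.0, 1.0)
    # Evenly spaced thresholds 1/levels, 2/levels, ..., 1.
    thresholds = np.arange(1, levels + 1) / levels
    # Bit k is on when the normalized value reaches threshold k, so larger
    # amplitudes switch on more consecutive bits (the "thermometer").
    return (x_norm[..., None] >= thresholds).astype(np.uint8)
```

For one channel with bounds [-1, 1], the samples -1, 0 and 1 map to codes with 0, 8 and 16 active bits respectively. Because the bounds come from the training data only, validation and test samples outside that range simply saturate at all-zeros or all-ones.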

In [None]:
# Calculate per-channel maximum and minimum values for encoding normalization.
# These values are computed from the first 5 minutes of training data
# and are used to normalize all data (training, validation and test) before spike encoding.
# squeeze/transpose put the windows in (windows, channels, samples) order.
from STPSNN.encoding_functions import return_max_min
max_train, min_train = return_max_min(np.squeeze(train[0]).transpose(0,2,1), 5)

RETURN MAX MIN
(450, 4, 2048)
massimo [np.float64(285.8119658119658), np.float64(223.2967032967033), np.float64(269.4017094017094), np.float64(164.6886446886447)]
minimo [np.float64(-285.8119658119658), np.float64(-223.2967032967033), np.float64(-269.4017094017094), np.float64(-164.6886446886447)]


In [None]:
# Encode EEG data into spike trains and create data loaders.
# Thermometer encoding maps each continuous EEG sample to a 16-element binary vector:
# the larger the amplitude, the more elements are set to 1. The time axis is kept,
# so every channel yields one binary input vector per time step.
from STPSNN.encoding_functions import encode, EEG_Dataset

# Encode training and validation data using the bounds computed from the training data
# (swapaxes puts the windows in (windows, channels, samples) order)
train_spk = np.array(encode(np.swapaxes(np.squeeze(train_data), 1, 2), max_train, min_train))
valid_spk = np.array(encode(np.swapaxes(np.squeeze(valid_data), 1, 2), max_train, min_train))

# Wrap the spike trains and labels in PyTorch Datasets
train_dataset = EEG_Dataset(train_spk, y_train)
valid_dataset = EEG_Dataset(valid_spk, y_valid)

# Create DataLoaders with the chosen batch size.
# shuffle=True changes the batch composition at every training epoch.
# drop_last=True discards the final incomplete batch so all batches have the same size;
# for validation this skips the last 540 % 32 = 28 windows.
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

massimo [np.float64(285.8119658119658), np.float64(223.2967032967033), np.float64(269.4017094017094), np.float64(164.6886446886447)]
minimo [np.float64(-285.8119658119658), np.float64(-223.2967032967033), np.float64(-269.4017094017094), np.float64(-164.6886446886447)]
massimo [np.float64(285.8119658119658), np.float64(223.2967032967033), np.float64(269.4017094017094), np.float64(164.6886446886447)]
minimo [np.float64(-285.8119658119658), np.float64(-223.2967032967033), np.float64(-269.4017094017094), np.float64(-164.6886446886447)]
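The effect of `drop_last` on the number of batches can be checked with plain arithmetic. This sketch uses a hypothetical helper, `batches_per_epoch`, that reproduces how a PyTorch `DataLoader` with a fixed `batch_size` counts batches; with 2160 training and 540 validation windows it gives the 67 and 16 steps that appear in the training log below.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last):
    """Number of batches a DataLoader with a fixed batch_size yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


n_train, n_valid, batch_size = 2160, 540, 32
print(batches_per_epoch(n_train, batch_size, drop_last=True))   # 67
print(batches_per_epoch(n_valid, batch_size, drop_last=True))   # 16
# Validation windows actually evaluated per epoch with drop_last=True:
print(batches_per_epoch(n_valid, batch_size, drop_last=True) * batch_size)  # 512 of 540
```

Setting `drop_last=False` on the validation loader would evaluate all 540 windows in 17 batches, provided the network accepts a smaller final batch.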


In [7]:
# Print shapes of encoded spike data for verification
print(train_spk.shape)
print(valid_spk.shape)


(2160, 4, 2048, 16)
(540, 4, 2048, 16)
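Each batch therefore has shape `(batch, channels, time, levels)`. A common way to feed such data to an SNN is one input vector per time step, which gives the `num_inputs = 4 * 16 = 64` used in the configuration below. The sketch assumes the network consumes time-major input with channels and levels merged; how STPSNN's `Net` actually arranges the axes is not shown here.

```python
import numpy as np

# Hypothetical random batch with the same layout as train_spk: (batch, channels, time, levels)
batch = np.random.default_rng(0).integers(0, 2, size=(32, 4, 2048, 16), dtype=np.uint8)

b, c, t, l = batch.shape
num_inputs = c * l  # 4 channels x 16 levels = 64 input lines per time step

# Move time to the front and merge channels with levels: (time, batch, channels*levels)
time_major = batch.transpose(2, 0, 1, 3).reshape(t, b, num_inputs)
```

Element `time_major[t, i, ch * 16 + k]` is bit `k` of channel `ch` at time step `t` for sample `i`.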


## 4. Model Configuration
Define network architecture and training hyperparameters:
- **Network**: 2-layer SNN with Leaky Integrate-and-Fire (LIF) neurons
- **Temporal dynamics**: Number of time steps, decay rate (beta), and firing threshold
- **Training**: Learning rate, epochs, early stopping patience, and learning rate decay schedule
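The roles of `beta` and `threshold` can be seen on a single neuron. Below is a minimal discrete-time leaky integrate-and-fire sketch in plain Python, assuming reset by subtraction; snnTorch's `Leaky` neuron follows the same idea but orders the reset slightly differently, and `lif_simulate` is a hypothetical name.

```python
def lif_simulate(inputs, beta=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron with reset by subtraction.

    inputs: sequence of input currents, one per time step
    returns: (spikes, membrane) lists with one entry per time step
    """
    mem = 0.0
    spikes, membrane = [], []
    for current in inputs:
        mem = beta * mem + current   # leak a fraction (1 - beta), then integrate the input
        spk = 1 if mem > threshold else 0
        mem -= spk * threshold       # subtract the threshold after a spike
        spikes.append(spk)
        membrane.append(mem)
    return spikes, membrane


spikes, membrane = lif_simulate([0.3] * 4)
print(spikes)  # [0, 0, 0, 1]: 0.3 -> 0.57 -> 0.813 -> 1.0317 crosses the threshold
```

A higher `beta` keeps more of the previous potential at every step, so weaker or sparser inputs can still accumulate to a spike; a higher `threshold` requires more accumulated input per spike.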

In [None]:
# Define output folder for trained model
FOLDER_OUT = 'trained_folder2'

########### Network Architecture ###########
# Number of inputs = channels * encoding levels (4 channels * 16 levels = 64)
num_inputs = train_spk.shape[1] * train_spk.shape[3]

num_outputs = 1  # Binary classification: seizure vs non-seizure

# Temporal dynamics
num_steps = 2048  # Number of simulation time steps (one per sample of a 2048-sample window)
# beta = 0.95  # Alternative decay rate
beta = 0.9  # Membrane potential decay rate (higher = slower decay, longer temporal integration)
# threshold = 40.0  # Alternative threshold
threshold = 1.0  # Spike threshold (a neuron fires when its membrane potential exceeds this value)
############################################

############ Training routine ##############
learn_treshold = False  # Whether the threshold is learned during training
learn_beta = False  # Whether beta (the decay rate) is learned during training
epoch_start = 0  # Starting epoch number (useful for resuming training)
num_epochs = 500  # Maximum number of training epochs
lr_steps = 6  # Maximum number of learning rate reductions (the lr can be reduced 6 times)
lr_decay = 1/3  # Factor applied to the learning rate when the validation loss plateaus
patience = 20  # Epochs without validation improvement before reducing the lr or stopping
best_loss = None  # Placeholder for the best validation loss (returned by the training routine)
lr_var = 0.001  # Initial learning rate (step size for weight updates)
############################################

# Refuse to overwrite an existing output folder, so previously trained models are not lost
if os.path.exists(FOLDER_OUT):
    raise FileExistsError(f"Folder {FOLDER_OUT} already exists. Please remove it or choose a different name.")
os.makedirs(FOLDER_OUT)

# Use the GPU if available, otherwise the CPU
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## 5. Network Initialization
Instantiate the SNN model and move it to the appropriate device (GPU if available, otherwise CPU).

In [None]:
# Create the spiking neural network with specified parameters
from STPSNN.net_definition import Net
print(f"num_inputs: {num_inputs}, num_outputs: {num_outputs}, num_steps: {num_steps}, "
      f"beta: {beta}, learn_treshold: {learn_treshold}, learn_beta: {learn_beta}, threshold: {threshold}")
net = Net(num_inputs=num_inputs,
          num_outputs=num_outputs,
          num_steps=num_steps,
          beta=beta,
          learn_th=learn_treshold,
          learn_b=learn_beta,
          threshold=threshold).to(device)
print("Net created successfully!")


num_inputs: 64, num_outputs: 1, num_steps: 2048, beta: 0.9, learn_treshold: False, learn_beta: False, threshold: 1.0
Net created successfully!
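The internals of STPSNN's `Net` are not shown in this notebook. As a rough picture of what a 2-layer LIF network computes, here is a NumPy sketch of a fully connected forward pass over time, using the same leak/threshold/subtractive-reset update at both layers. The hidden width of 32, the random weights and the function name `two_layer_lif_forward` are all illustrative assumptions, not the trained model.

```python
import numpy as np


def two_layer_lif_forward(x, w1, w2, beta=0.9, threshold=1.0):
    """Run a 2-layer fully connected LIF network over time.

    x:  (num_steps, batch, num_inputs) binary input spikes
    w1: (num_inputs, num_hidden), w2: (num_hidden, num_outputs)
    returns output spikes of shape (num_steps, batch, num_outputs)
    """
    num_steps, batch, _ = x.shape
    mem1 = np.zeros((batch, w1.shape[1]))
    mem2 = np.zeros((batch, w2.shape[1]))
    out_spikes = np.zeros((num_steps, batch, w2.shape[1]))
    for t in range(num_steps):
        mem1 = beta * mem1 + x[t] @ w1              # hidden layer: leak + weighted input
        spk1 = (mem1 > threshold).astype(np.float64)
        mem1 -= spk1 * threshold                    # reset by subtraction
        mem2 = beta * mem2 + spk1 @ w2              # output layer driven by hidden spikes
        spk2 = (mem2 > threshold).astype(np.float64)
        mem2 -= spk2 * threshold
        out_spikes[t] = spk2
    return out_spikes


rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=(100, 8, 64)).astype(np.float64)  # short 100-step demo batch
w1 = rng.normal(0.0, 0.2, size=(64, 32))
w2 = rng.normal(0.0, 0.5, size=(32, 1))
spk_out = two_layer_lif_forward(x, w1, w2)
rate = spk_out.mean(axis=0)  # per-sample output firing rate, shape (batch, 1)
```

The per-sample firing rate of the single output neuron is the quantity a rate-based loss compares against its targets.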


## 6. Loss Function and Optimizer Setup
Configure the spike rate loss function and optimizer:
- **Loss**: Custom `SpikeRate` loss that drives the output firing rate toward a target rate (0.35 for seizure, 0.03 for non-seizure)
- **Optimizer**: Adam optimizer with configurable learning rate
- **Sanity check**: Run one forward pass on a training batch and compute the loss of the untrained network
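One simple way to express a spike-rate objective is the mean squared error between each sample's output firing rate and its class-dependent target. This is a sketch of the general idea, not STPSNN's `SpikeRate`: the untrained loss of 0.975 reported below exceeds the largest value this formula can produce ((1 - 0.03)^2 ≈ 0.94), so the library clearly normalizes or reduces differently. The function name `spike_rate_mse` is hypothetical.

```python
import numpy as np


def spike_rate_mse(spk_rec, targets, true_rate=0.35, false_rate=0.03):
    """MSE between each sample's output firing rate and a class-dependent target rate.

    spk_rec: (num_steps, batch) spikes of the single output neuron
    targets: (batch,) labels, 1 = seizure, 0 = non-seizure
    """
    rate = spk_rec.mean(axis=0)                        # fraction of time steps with a spike
    target_rate = np.where(targets == 1, true_rate, false_rate)
    return float(np.mean((rate - target_rate) ** 2))


spk = np.zeros((20, 2))
spk[:7, 0] = 1                  # sample 0 fires on 7 of 20 steps: rate 0.35
labels = np.array([1, 0])       # sample 0 is a seizure window, sample 1 is not
print(spike_rate_mse(spk, labels))  # only sample 1 contributes: (0 - 0.03)^2 / 2
```

Using two non-zero target rates, rather than "fire always" and "never fire", keeps some activity in the output neuron for both classes, so gradients through the surrogate spike function do not vanish.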

In [9]:
# Define loss function and optimizer
from STPSNN.loss import SpikeRate

# Spike rate loss: target output firing rate 0.35 for seizure windows, 0.03 for non-seizure windows
loss = SpikeRate(true_rate = 0.35, false_rate = 0.03)

# Sanity check: compute the loss of the untrained network on one training batch
data, targets = next(iter(train_dataloader))
data = data.to(device)
targets = targets.to(device)

with torch.no_grad():  # gradients are not needed for this check
    spk_rec, mem = net(data)
    loss_val = loss(spk_rec, targets)
print(f"The loss from an untrained network is {loss_val.item():.3f}")

# Initialize the Adam optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=lr_var, betas=(0.9, 0.999))


Spikerate loss 1 out
The loss from an untrained network is 0.975


## 7. Training Execution
Run the training loop with:
- **Validation monitoring**: Track validation loss and accuracy after each epoch
- **Learning rate scheduling**: If the validation loss does not improve for `patience` epochs, multiply the learning rate by `lr_decay`
- **Early stopping**: Once the learning rate has been reduced `lr_steps` times, the next `patience`-epoch plateau stops training
- **Model checkpointing**: Save the model with the best validation loss and the last trained model
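The exact policy lives inside STPSNN's `training_routine`, which is not shown here. The sketch below replays the behaviour implied by the configuration comments (reduce the learning rate after `patience` epochs without improvement, stop after `lr_steps` reductions) over a list of validation losses. `plan_lr_schedule` is a hypothetical name, and details such as resetting the counter after each reduction are assumptions.

```python
def plan_lr_schedule(valid_losses, lr=0.001, lr_decay=1/3, patience=20, lr_steps=6):
    """Replay a patience-based lr-decay / early-stopping policy over a loss history.

    Returns (stop_epoch, lr_history); stop_epoch is None if training never stops early.
    """
    best = float("inf")
    epochs_without_improvement = 0
    decays_done = 0
    lr_history = []
    for epoch, val_loss in enumerate(valid_losses):
        lr_history.append(lr)                  # lr used during this epoch
        if val_loss < best:
            best = val_loss                    # new best: a checkpoint would be saved here
            epochs_without_improvement = 0
            continue
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            if decays_done == lr_steps:
                return epoch, lr_history       # no reductions left: stop early
            lr *= lr_decay                     # plateau: shrink the step size
            decays_done += 1
            epochs_without_improvement = 0
    return None, lr_history


stop, history = plan_lr_schedule([1.0, 0.5, 0.6, 0.6, 0.6, 0.6, 0.6], patience=2, lr_steps=1)
print(stop, history)  # stops at epoch 5 after one reduction to 0.001/3
```

With the notebook's values the smallest learning rate reachable is 0.001 × (1/3)^6 ≈ 1.4 × 10^-6. PyTorch's built-in `torch.optim.lr_scheduler.ReduceLROnPlateau` (with `factor=lr_decay` and `patience=patience`) covers the decay half of this policy, but not the stopping rule.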

In [None]:
# Execute the training routine.
# training_routine handles:
# - forward pass through the network for each batch
# - loss computation and backpropagation
# - weight updates via the optimizer
# - validation after each epoch
# - model checkpointing (saves the best and the last model)
# - early stopping with learning rate decay
from STPSNN.training_routine import training_routine


# Run training and collect the metric histories;
# the with-block closes the log file when training ends
with open('training_log.txt', 'w') as log_file:
    loss_hist, valid_loss_hist, acc_hist, valid_acc, best_loss = training_routine(
        net,
        train_dataloader,
        valid_dataloader,
        num_epochs,
        FOLDER_OUT,
        device,
        loss,
        optimizer,
        lr_decay,
        patience,
        epoch_start,
        f=log_file,  # Log file for detailed training progress
        lr_steps=lr_steps
    )

Epoch: 0, Step: 66/67 Train acc: 0.9818097014925373, Train loss: 0.40791124105453494

Epoch: 0, Step: 15/16 Valid acc: 0.98046875, Valid loss: 0.06028098240494728
Updated best model, saving in: trained_folder2/network.pt


Summary Epoch: 0, Train acc: 0.9818097014925373, Train loss: 0.4079112410545349, Valid acc: 0.98046875, Valid loss: 0.06028098240494728
Epoch: 1, Step: 66/67 Train acc: 0.9822761194029851, Train loss: 0.054573018103837975

Epoch: 1, Step: 15/16 Valid acc: 0.98046875, Valid loss: 0.04616963490843773
Updated best model, saving in: trained_folder2/network.pt


Summary Epoch: 1, Train acc: 0.9822761194029851, Train loss: 0.05457301810383797, Valid acc: 0.98046875, Valid loss: 0.04616963490843773


## 8. Save Training Results
Save the normalization parameters (max/min values) used during encoding. These are essential for preprocessing test data with the same scaling applied to training data.

In [11]:
# Save the min/max values used for encoding to a JSON file in the output folder
import json
results = {
    'maximum': max_train,
    'minimum': min_train,
}
with open(os.path.join(FOLDER_OUT, 'training_results.json'), 'w') as f:
    json.dump(results, f)
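At test time the point is to reuse these exact bounds rather than recompute them from test data. A minimal round-trip sketch, assuming the JSON layout written above; `save_bounds` and `load_bounds` are hypothetical helpers, and the numbers in the demo are illustrative values rounded from the bounds printed earlier. Converting each value with `float()` keeps the file writable even if the bounds arrive as NumPy scalar types that `json` cannot serialize (such as `np.float32`).

```python
import json
import os
import tempfile


def save_bounds(folder, maximum, minimum):
    """Write per-channel encoding bounds to <folder>/training_results.json."""
    with open(os.path.join(folder, "training_results.json"), "w") as fh:
        json.dump({"maximum": [float(v) for v in maximum],
                   "minimum": [float(v) for v in minimum]}, fh)


def load_bounds(folder):
    """Read the bounds back as two lists of floats: (maximum, minimum)."""
    with open(os.path.join(folder, "training_results.json")) as fh:
        results = json.load(fh)
    return results["maximum"], results["minimum"]


with tempfile.TemporaryDirectory() as tmp:
    save_bounds(tmp, [285.81, 223.30], [-285.81, -223.30])
    max_loaded, min_loaded = load_bounds(tmp)
```

The loaded lists can then be passed to the same encoding function used for training, so that test windows are normalized on exactly the same scale.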
