In [1]:
#Set seed
import random
import numpy as np
import torch

seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# For reproducibility in cuDNN (may slow down performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
# ----------------------------    Model Definition  -------------------------------

import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import mne

from spikingjelly.activation_based import encoding,neuron,surrogate,layer
from scipy.io import loadmat
import tqdm
import os

# Import spikingjelly modules
from spikingjelly.activation_based import neuron, surrogate
import torch.nn.functional as F



# ----------------------------
#  2. Reservoir SNN Module using spikingjelly
#   ----------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import neuron, surrogate

class SpikingJellyReservoirSNN(nn.Module):
    def __init__(self, n_channels, n_reservoir, sfreq, tau=0.02, threshold=1.0):
        """
        n_channels: number of EEG channels
        n_reservoir: number of reservoir neurons
        sfreq: sampling frequency in Hz
        tau: membrane time constant (seconds)
        threshold: spiking threshold
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_reservoir = n_reservoir
        self.sfreq = sfreq
        self.dt = 1.0 / sfreq
        self.tau = tau
        self.threshold = threshold

        # Fully-connected input and recurrent layers
        self.fc_in = nn.Linear(n_channels, n_reservoir)
        self.fc_rec = nn.Linear(n_reservoir, n_reservoir)

        # Freeze input and recurrent weights (reservoir remains fixed)
        for param in self.fc_in.parameters():
            param.requires_grad = False
        for param in self.fc_rec.parameters():
            param.requires_grad = False

        # Optional: Initialize weights for better reservoir dynamics
        with torch.no_grad():
            # Xavier initialization for input weights
            nn.init.xavier_uniform_(self.fc_in.weight)
            # Orthogonal initialization for recurrent weights
            nn.init.orthogonal_(self.fc_rec.weight)
            # Optionally scale to control spectral radius (e.g., 0.9)
            with torch.no_grad():
                weight = self.fc_rec.weight.data
                abs_eigenvalues = torch.abs(torch.linalg.eigvals(weight))
                spectral_radius = abs_eigenvalues.max()
                self.fc_rec.weight.data = (weight / spectral_radius) * 0.9

        # Convert tau to timesteps
        tau_timesteps = tau * sfreq
        # Create spiking neuron using spikingjelly's LIFNode
        self.lif = neuron.LIFNode(tau=tau_timesteps, surrogate_function=surrogate.ATan(), detach_reset=True)
        self.lif.v_threshold = threshold  # Set the threshold manually

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, n_channels, time_steps)
        Returns: spike counts per neuron (batch_size, n_reservoir)
        """
        batch_size, _, time_steps = x.shape
        device = x.device
        spike_counts = torch.zeros(batch_size, self.n_reservoir, device=device)
        spikes = torch.zeros(batch_size, self.n_reservoir, device=device)

        # Reset LIF state at the start of each forward pass
        self.lif.reset()

        for t in range(time_steps):
            input_t = x[:, :, t]  # shape: (batch_size, n_channels)
            # Compute input current without ReLU
            I_in = self.fc_in(input_t) + self.fc_rec(spikes)
            # Update LIF neuron state
            spikes = self.lif(I_in)
            # Accumulate spike counts
            spike_counts += spikes

        return spike_counts


 
        
class ComplexClassifier(nn.Module):
    def __init__(self, n_reservoir, n_classes):
        super().__init__()
        # First hidden layer with 128 neurons
        self.fc1 = nn.Linear(n_reservoir, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.dropout1 = nn.Dropout(0.5)
        # Second hidden layer with 64 neurons
        self.fc2 = nn.Linear(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        self.dropout2 = nn.Dropout(0.5)
        #Third hidden layer
        self.fc3 = nn.Linear(32, n_classes)



    def forward(self, x):
        # Pass through first hidden layer
        x = self.fc1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.dropout1(x)
        # Pass through second hidden layer
        x = self.fc2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        # Final output layer; no activation if using a loss that applies one (e.g., CrossEntropyLoss)
        x = self.fc3(x)

        return x


# ----------------------------
# 4. Combined Simplified SNN Model using spikingjelly reservoir
# ----------------------------
class SimplifiedSNNModel(nn.Module):
    def __init__(self, n_channels, n_reservoir, n_classes, sfreq, tau=0.02, threshold=1.0):
        super().__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv1d(n_channels, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU()
            
        )

        self.reservoir = SpikingJellyReservoirSNN(64, n_reservoir, sfreq, tau, threshold)
        self.classifier = ComplexClassifier(n_reservoir, n_classes)

    def forward(self, x):
        # x: (batch_size, n_channels, time_steps)

        x = self.cnn(x)
        #x = self.encoder(x)  # Encode the input using Poisson encoding

        spike_counts = self.reservoir(x)
        # Use the classifier to map the reservoir's spike counts to rate-coded outputs.
        rates = self.classifier(spike_counts)
        return rates

# ----------------------------
# 5. Custom Dataset for EEG Trials
# ----------------------------
class EEGDataset:
    def __init__(self, trials):
        """
        trials: list of tuples (segment, label)
          segment: numpy array of shape (n_channels, time_steps)
          label: integer label for sleep stage
        """
        self.trials = trials

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        segment, label = self.trials[idx]
        segment_tensor = torch.tensor(segment, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return segment_tensor, label_tensor





In [None]:
# ----------------------------  Dataset class Definition & Dataset loading ------------------------
# 6. Data Loading: EEG (.edf) and Sleep Staging (.mat)
# ----------------------------
# Update the file paths as needed.
edf_file = os.path.join(os.getcwd(), r"01_sleep_psg.edf")
try:
    raw = mne.io.read_raw_edf(edf_file, preload=True)
except Exception as e:
    print("Error reading EDF file:", e)
    raise

eeg_data = raw.get_data()  # shape: (n_channels, total_samples)
sfreq = raw.info['sfreq']

# Load sleep staging information from a MAT file
mat_file = os.path.join(os.getcwd(), r"01_SleepStages.mat")
mat_contents = loadmat(mat_file)
keys = [key for key in mat_contents.keys() if not key.startswith('__')]
if len(keys) == 0:
    raise ValueError("No valid variables found in the MAT file.")
stages_raw = mat_contents[keys[0]]
stages_raw = np.squeeze(stages_raw)
stages = [str(s) for s in stages_raw]
stages_clean = [s[0] if isinstance(s, (np.ndarray, list)) else s for s in stages]
stage_mapping = {"['NREM 1']": 0, "['NREM 2']": 1, "['NREM 3']": 2, "['Wake']": 3, "['REM']": 4}
stages_int = [stage_mapping[s] for s in stages_clean]

# Build trials: divide continuous EEG data into 30-second epochs
epoch_duration = 30  # seconds
samples_per_epoch = int(sfreq * epoch_duration)
total_samples = eeg_data.shape[1]
n_epochs_available = total_samples // samples_per_epoch
n_epochs = min(len(stages_int), n_epochs_available)
print(f"Number of epochs (trials): {n_epochs}")

trials = []
for epoch_idx in range(n_epochs):
    start_idx = epoch_idx * samples_per_epoch
    end_idx = (epoch_idx + 1) * samples_per_epoch
    segment = eeg_data[:, start_idx:end_idx]
    stage = stages_int[epoch_idx]
    trials.append((segment, stage))

# Instantiate the dataset
dataset = EEGDataset(trials)

# ----------------------------
# 6.1 Compute Global Normalization Statistics (per channel)
# ----------------------------
# Stack all segments to compute global mean and std per channel.
all_segments = [segment for segment, _ in dataset]
all_segments = np.stack(all_segments, axis=0)  # shape: (n_trials, n_channels, time_steps)
global_mean = np.mean(all_segments, axis=(0, 2))  # shape: (n_channels,)
global_std = np.std(all_segments, axis=(0, 2))    # shape: (n_channels,)

# ----------------------------
# 7. Split Dataset into Training and Evaluation Sets
# ----------------------------
# Chronological train/val split without shuffling
train_split = int(len(dataset) * 0.8)
train_indices = list(range(train_split))
eval_indices = list(range(train_split, len(dataset)))







Extracting EDF parameters from /home/fwx/work/01_sleep_psg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5557599  =      0.000 ... 27787.995 secs...
Number of epochs (trials): 926


In [4]:
#Model Initialization
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_channels = eeg_data.shape[0]
n_reservoir = 512  # e.g., 512 reservoir neurons
n_classes = len(np.unique(stages_int))

model = SimplifiedSNNModel(n_channels, n_reservoir, n_classes, sfreq, tau=0.02, threshold=1.0).to(device)

In [None]:
# ----------------------------
# 8. Training the Simplified Model with Rate Coding Loss (Manual Batching)
# ----------------------------

criterion = nn.MSELoss()  # Rate coding loss (MSELoss)
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs_train = 50  # number of training epochs
batch_size = 32

model.train()
early_stop = False
for epoch in range(n_epochs_train):
    print(f"\nTraining for epoch: {epoch+1}")
    total_loss = 0.0
    training_correct = 0
    total_train_samples = 0

    num_batches = len(train_indices) // batch_size

    for batch_idx in tqdm.tqdm(range(num_batches), desc="Training Batches"):
        batch_indices = train_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        batch_segments = []
        batch_labels = []
        for idx in batch_indices:
            segment, label = dataset[idx]
            # Normalize each segment per channel
            segment = (segment - global_mean[:, None]) / global_std[:, None]
            batch_segments.append(segment)
            batch_labels.append(label)
        batch_segments = torch.stack(batch_segments)  # shape: [batch_size, n_channels, time_steps]
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)

        batch_segments = batch_segments.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(batch_segments)
        # Create one-hot targets for rate coding
        target = torch.zeros_like(outputs)
        target.scatter_(1, batch_labels.view(-1, 1), 1.0)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_segments.size(0)

        # Compute predictions and accumulate correct predictions
        preds = outputs.argmax(dim=1)
        training_correct += (preds == batch_labels).sum().item()
        total_train_samples += batch_segments.size(0)

    epoch_loss = total_loss / len(train_indices)
    training_accuracy = training_correct / len(train_indices) * 100
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Training Accuracy: {training_accuracy:.2f}%")
    
    # Early stopping if training accuracy is above 85%
    if training_accuracy >= 85:
        print("Training accuracy above 85%. Early stopping!")
        # Print out the parameter list (state dictionary)
        print("Model parameters:")
        for name, param in model.named_parameters():
            print(f"{name}: {param.data}")
        early_stop = True
        break







Training for epoch: 1


  segment = (segment - global_mean[:, None]) / global_std[:, None]
Training Batches: 100%|██████████| 23/23 [01:16<00:00,  3.32s/it]


Epoch 1, Loss: 0.3948, Training Accuracy: 36.49%

Training for epoch: 2


Training Batches: 100%|██████████| 23/23 [00:45<00:00,  1.97s/it]


Epoch 2, Loss: 0.2889, Training Accuracy: 36.08%

Training for epoch: 3


Training Batches: 100%|██████████| 23/23 [01:16<00:00,  3.34s/it]


Epoch 3, Loss: 0.2344, Training Accuracy: 35.41%

Training for epoch: 4


Training Batches: 100%|██████████| 23/23 [00:54<00:00,  2.36s/it]


Epoch 4, Loss: 0.2049, Training Accuracy: 36.76%

Training for epoch: 5


Training Batches: 100%|██████████| 23/23 [00:59<00:00,  2.58s/it]


Epoch 5, Loss: 0.1971, Training Accuracy: 39.59%

Training for epoch: 6


Training Batches: 100%|██████████| 23/23 [01:21<00:00,  3.56s/it]


Epoch 6, Loss: 0.1806, Training Accuracy: 38.11%

Training for epoch: 7


Training Batches: 100%|██████████| 23/23 [00:26<00:00,  1.15s/it]


Epoch 7, Loss: 0.1723, Training Accuracy: 39.73%

Training for epoch: 8


Training Batches:  83%|████████▎ | 19/23 [00:48<00:11,  2.85s/it]

In [None]:
# ----------------------------
# Use CrossEntropyLoss instead of MSELoss
# ----------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
n_epochs_train = 50  # number of training epochs
batch_size = 32
# ----------------------------
# 8. Training the Simplified Model with Cross-Entropy Loss (Manual Batching)
# ----------------------------
model.train()
early_stop = False
for epoch in range(n_epochs_train):
    print(f"\nTraining for epoch: {epoch+1}")
    total_loss = 0.0
    training_correct = 0
    total_train_samples = 0

    num_batches = len(train_indices) // batch_size

    for batch_idx in tqdm.tqdm(range(num_batches), desc="Training Batches"):
        batch_indices = train_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        batch_segments = []
        batch_labels = []
        for idx in batch_indices:
            segment, label = dataset[idx]
            # Normalize using the global mean and std (per channel)
            segment = (segment - global_mean[:, None]) / global_std[:, None]
            batch_segments.append(segment)
            batch_labels.append(label)
        batch_segments = torch.stack(batch_segments)  # shape: [batch_size, n_channels, time_steps]
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)

        batch_segments = batch_segments.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(batch_segments)  # outputs: raw logits of shape [batch_size, n_classes]
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch_segments.size(0)

        # Compute predictions and accumulate correct predictions
        preds = outputs.argmax(dim=1)
        training_correct += (preds == batch_labels).sum().item()
        total_train_samples += batch_segments.size(0)

    epoch_loss = total_loss / len(train_indices)
    training_accuracy = training_correct / len(train_indices) * 100
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Training Accuracy: {training_accuracy:.2f}%")

    # Early stopping condition (optional): stop if training accuracy >= 85%
    if training_accuracy >= 85:
        print("Training accuracy above 85%. Early stopping!")
        print("Model parameters:")
        for name, param in model.named_parameters():
            print(f"{name}: {param.data}")
        early_stop = True
        break

In [None]:
# If early stopping was triggered, evaluate on the test set immediately
if early_stop:
    model.eval()
    correct = 0
    with torch.no_grad():
        num_batches_eval = len(eval_indices) // batch_size
        for batch_idx in tqdm.tqdm(range(num_batches_eval), desc="Evaluation Batches"):
            batch_indices = eval_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
            batch_segments = []
            batch_labels = []
            for idx in batch_indices:
                segment, label = dataset[idx]
                # Normalize using the same global statistics
                segment = (segment - global_mean[:, None]) / global_std[:, None]
                batch_segments.append(segment)
                batch_labels.append(label)
            batch_segments = torch.stack(batch_segments)
            batch_labels = torch.tensor(batch_labels, dtype=torch.long)
            batch_segments = batch_segments.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)
            outputs = model(batch_segments)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch_labels).sum().item()

    test_accuracy = correct / (num_batches_eval * batch_size) * 100
    print(f"\nEvaluation Accuracy (Test Set): {test_accuracy:.2f}%")

In [None]:
# ---------------------------- Evaluation without stopping -------
# 9. Evaluate the Simplified Model with Rate Coding (Manual Batching)
# ----------------------------
model.eval()
correct = 0
with torch.no_grad():
    num_batches_eval = len(eval_indices) // batch_size
    for batch_idx in tqdm.tqdm(range(num_batches_eval), desc="Evaluation Batches"):
        batch_indices = eval_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        batch_segments = []
        batch_labels = []
        for idx in batch_indices:
            segment, label = dataset[idx]
            # Use the same global normalization
            segment = (segment - global_mean[:, None]) / global_std[:, None]
            batch_segments.append(segment)
            batch_labels.append(label)
        batch_segments = torch.stack(batch_segments)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        batch_segments = batch_segments.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        outputs = model(batch_segments)
        preds = outputs.argmax(dim=1)
        correct += (preds == batch_labels).sum().item()

accuracy = correct / (num_batches_eval * batch_size) * 100
print(f"\nEvaluation Accuracy: {accuracy:.2f}%")


In [None]:
#Multiple other classifiers
# We'll define a helper function to extract features from the reservoir.
def extract_reservoir_features(model, indices, batch_size=32):
    model.eval()
    features_list = []
    labels_list = []
    with torch.no_grad():
        num_batches = len(indices) // batch_size
        for batch_idx in tqdm.tqdm(range(num_batches), desc="Extracting Features"):
            batch_indices = indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
            batch_segments = []
            batch_labels = []
            for idx in batch_indices:
                segment, label = dataset[idx]
                segment = (segment - global_mean[:, None]) / global_std[:, None]
                batch_segments.append(segment)
                batch_labels.append(label)
            batch_segments = torch.stack(batch_segments)  # shape: [batch_size, n_channels, time_steps]
            batch_labels = torch.tensor(batch_labels, dtype=torch.long)
            batch_segments = batch_segments.to(device, non_blocking=True)
            # Extract reservoir features directly from the reservoir module.
            features = model.reservoir(batch_segments)
            features_list.append(features.cpu().numpy())
            labels_list.append(batch_labels.cpu().numpy())
    X = np.concatenate(features_list, axis=0)
    y = np.concatenate(labels_list, axis=0)
    return X, y

# Extract features from training and evaluation sets
X_train, y_train = extract_reservoir_features(model, train_indices, batch_size)
X_eval, y_eval = extract_reservoir_features(model, eval_indices, batch_size)

# Train classical classifiers using scikit-learn.
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# --- 1. Logistic Regression ---
clf_lr = LogisticRegression(max_iter=1000)
clf_lr.fit(X_train, y_train)
y_pred_lr = clf_lr.predict(X_eval)
accuracy_lr = accuracy_score(y_eval, y_pred_lr)
print(f"\nLogistic Regression Accuracy: {accuracy_lr:.2f}")

# --- 2. Random Forest ---
clf_rf = RandomForestClassifier(n_estimators=100)
clf_rf.fit(X_train, y_train)
y_pred_rf = clf_rf.predict(X_eval)
accuracy_rf = accuracy_score(y_eval, y_pred_rf)
print(f"Random Forest Accuracy: {accuracy_rf:.2f}")

# --- 3. Support Vector Machine ---
clf_svm = SVC(kernel='linear')
clf_svm.fit(X_train, y_train)
y_pred_svm = clf_svm.predict(X_eval)
accuracy_svm = accuracy_score(y_eval, y_pred_svm)
print(f"SVM Accuracy: {accuracy_svm:.2f}")