In [None]:
#Set seed
import random
import numpy as np
import torch

seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# For reproducibility in cuDNN (may slow down performance)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
# ----------------------------    Model Definition  -------------------------------

import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import mne

from spikingjelly.activation_based import encoding,neuron,surrogate,layer
from scipy.io import loadmat
import tqdm
import os

# Import spikingjelly modules
from spikingjelly.activation_based import neuron, surrogate
import torch.nn.functional as F



# ----------------------------
#  2. Reservoir SNN Module using spikingjelly
#   ----------------------------
# class SpikingJellyReservoirSNN(nn.Module):
#     def __init__(self, n_channels, n_reservoir, sfreq, tau=0.02, threshold=1.0):
#         """
#         n_channels: number of EEG channels
#         n_reservoir: number of reservoir neurons
#         sfreq: sampling frequency in Hz
#         tau: membrane time constant (seconds)
#         threshold: spiking threshold
#         """
#         super().__init__()
#         self.n_channels = n_channels
#         self.n_reservoir = n_reservoir
#         self.sfreq = sfreq
#         self.dt = 1.0 / sfreq
#         self.tau = tau
#         self.threshold = threshold

#         # Fully-connected input mapping (EEG channels -> reservoir neurons)
#         self.fc_in = nn.Linear(n_channels, n_reservoir)
#         self.fc_rec = nn.Linear(n_reservoir, n_reservoir)
        
#         # Convert tau to timesteps
#         tau_timesteps = tau * sfreq  
#         # Create spiking neuron using spikingjelly's LIFNode.
#         self.lif = neuron.LIFNode(tau=tau_timesteps, surrogate_function=surrogate.ATan(), detach_reset=True)
#         # Set the threshold manually (if needed; default is 1.0)
#         self.lif.v_threshold = threshold

#     def forward(self, x):
#         """
#         x: Tensor of shape (batch_size, n_channels, time_steps)
#         Returns: spike counts per neuron (batch_size, n_reservoir)
#         """
#         batch_size, _, time_steps = x.shape
#         device = x.device
#         spike_counts = torch.zeros(batch_size, self.n_reservoir, device=device)
#         spikes = torch.zeros(batch_size, self.n_reservoir, device=device)

#         # Reset LIF state at the start of each forward pass
#         self.lif.reset()

#         for t in range(time_steps):
#             input_t = x[:, :, t]  # shape: (batch_size, n_channels)
#             I_in = F.relu(self.fc_in(input_t) + self.fc_rec(spikes))#第二项这里实现递归，将上一轮的脉冲数量传递给递归层进行计算
#             spikes = self.lif(I_in)#这里实际上算是将所有的节点都作为lif神经元进行计算了
#             spike_counts += spikes
#         return spike_counts


import torch
import torch.nn as nn
import torch.nn.functional as F

class SpikingJellyReservoirSNN(nn.Module):
    def __init__(self, input_size, reservoir_size, sparsity=0.1, tau_mem=20.0, dt=1.0, threshold=1.0, reset=0.0):
        super(LIFReservoir, self).__init__()
        self.reservoir_size = reservoir_size
        self.tau_mem = tau_mem
        self.dt = dt
        self.threshold = threshold
        self.reset = reset

        # Time constant decay factor
        self.alpha = float(torch.exp(-dt / tau_mem))  # Membrane decay
        self.beta = 1.0 * (1.0 - self.alpha)           # Input scaling

        # Input-to-reservoir weights
        W_in = torch.randn(reservoir_size, input_size) * 0.1
        self.register_buffer('W_in', W_in)

        # Reservoir-to-reservoir sparse recurrent weights
        W_res = torch.randn(reservoir_size, reservoir_size)
        mask = torch.rand_like(W_res) < sparsity
        W_res[~mask] = 0
        self.register_buffer('W_res', W_res)

        # Linear classifier on top of reservoir states
        self.classifier = nn.Linear(reservoir_size, 1)  # Adjust output size for multi-class tasks

    def forward(self, input_seq, state=None):
        """
        input_seq: shape (batch_size, seq_len, input_size)
        Returns: logits (batch_size, num_classes)
        """
        batch_size, seq_len, _ = input_seq.size()

        if state is None:
            state = {
                'v': torch.zeros(batch_size, self.reservoir_size, device=input_seq.device),
                'output': torch.zeros(batch_size, self.reservoir_size, device=input_seq.device)
            }

        v = state['v']
        outputs = []

        for t in range(seq_len):
            x = input_seq[:, t]

            # Input to reservoir
            I = F.linear(x, self.W_in) + F.linear(state['output'], self.W_res)

            # Update membrane potential using LIF dynamics
            v = self.alpha * v + self.beta * I

            # Spike condition
            z = (v >= self.threshold).float()
            v = v * (1.0 - z) + z * self.reset  # Reset voltage on spike

            outputs.append(z)
            state['v'] = v
            state['output'] = z

        # Use final step's spikes as representation
        final_output = outputs[-1]
        logits = self.classifier(final_output)
        return logits
        
class ComplexClassifier(nn.Module):
    def __init__(self, n_reservoir, n_classes):
        super().__init__()
        # First hidden layer with 128 neurons
        self.fc1 = nn.Linear(n_reservoir, 64)
        self.bn1 = nn.BatchNorm1d(64)
        self.dropout1 = nn.Dropout(0.5)
        # Second hidden layer with 64 neurons
        self.fc2 = nn.Linear(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        self.dropout2 = nn.Dropout(0.5)
        #Third hidden layer
        self.fc3 = nn.Linear(32, n_classes)



    def forward(self, x):
        # Pass through first hidden layer
        x = self.fc1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.dropout1(x)
        # Pass through second hidden layer
        x = self.fc2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        # Final output layer; no activation if using a loss that applies one (e.g., CrossEntropyLoss)
        x = self.fc3(x)

        return x


# ----------------------------
# 4. Combined Simplified SNN Model using spikingjelly reservoir
# ----------------------------
class SimplifiedSNNModel(nn.Module):
    def __init__(self, n_channels, n_reservoir, n_classes, sfreq, tau=0.02, threshold=1.0):
        super().__init__()
        
        self.encoder = encoding.PoissonEncoder()
        self.cnn = nn.Sequential(
            nn.Conv1d(n_channels, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU()
            
        )

        self.reservoir = SpikingJellyReservoirSNN(64, n_reservoir, sfreq, tau, threshold)
        self.classifier = ComplexClassifier(n_reservoir, n_classes)

    def forward(self, x):
        # x: (batch_size, n_channels, time_steps)

        x = self.cnn(x)
        #x = self.encoder(x)  # Encode the input using Poisson encoding

        spike_counts = self.reservoir(x)
        # Use the classifier to map the reservoir's spike counts to rate-coded outputs.
        rates = self.classifier(spike_counts)
        return rates

# ----------------------------
# 5. Custom Dataset for EEG Trials
# ----------------------------
class EEGDataset:
    def __init__(self, trials):
        """
        trials: list of tuples (segment, label)
          segment: numpy array of shape (n_channels, time_steps)
          label: integer label for sleep stage
        """
        self.trials = trials

    def __len__(self):
        return len(self.trials)

    def __getitem__(self, idx):
        segment, label = self.trials[idx]
        segment_tensor = torch.tensor(segment, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return segment_tensor, label_tensor





In [None]:
# ----------------------------  Dataset class Definition & Dataset loading ------------------------
# 6. Data Loading: EEG (.edf) and Sleep Staging (.mat)
# ----------------------------
# Update the file paths as needed.
edf_file = os.path.join(os.getcwd(), r"01_sleep_psg.edf")
try:
    raw = mne.io.read_raw_edf(edf_file, preload=True)
except Exception as e:
    print("Error reading EDF file:", e)
    raise

eeg_data = raw.get_data()  # shape: (n_channels, total_samples)
sfreq = raw.info['sfreq']

# Load sleep staging information from a MAT file
mat_file = os.path.join(os.getcwd(), r"01_SleepStages.mat")
mat_contents = loadmat(mat_file)
keys = [key for key in mat_contents.keys() if not key.startswith('__')]
if len(keys) == 0:
    raise ValueError("No valid variables found in the MAT file.")
stages_raw = mat_contents[keys[0]]
stages_raw = np.squeeze(stages_raw)
stages = [str(s) for s in stages_raw]
stages_clean = [s[0] if isinstance(s, (np.ndarray, list)) else s for s in stages]
stage_mapping = {"['NREM 1']": 0, "['NREM 2']": 1, "['NREM 3']": 2, "['Wake']": 3, "['REM']": 4}
stages_int = [stage_mapping[s] for s in stages_clean]

# Build trials: divide continuous EEG data into 30-second epochs
epoch_duration = 30  # seconds
samples_per_epoch = int(sfreq * epoch_duration)
total_samples = eeg_data.shape[1]
n_epochs_available = total_samples // samples_per_epoch
n_epochs = min(len(stages_int), n_epochs_available)
print(f"Number of epochs (trials): {n_epochs}")

trials = []
for epoch_idx in range(n_epochs):
    start_idx = epoch_idx * samples_per_epoch
    end_idx = (epoch_idx + 1) * samples_per_epoch
    segment = eeg_data[:, start_idx:end_idx]
    stage = stages_int[epoch_idx]
    trials.append((segment, stage))

# Instantiate the dataset
dataset = EEGDataset(trials)

# ----------------------------
# 6.1 Compute Global Normalization Statistics (per channel)
# ----------------------------
# Stack all segments to compute global mean and std per channel.
all_segments = [segment for segment, _ in dataset]
all_segments = np.stack(all_segments, axis=0)  # shape: (n_trials, n_channels, time_steps)
global_mean = np.mean(all_segments, axis=(0, 2))  # shape: (n_channels,)
global_std = np.std(all_segments, axis=(0, 2))    # shape: (n_channels,)

# ----------------------------
# 7. Split Dataset into Training and Evaluation Sets
# ----------------------------
indices = list(range(len(dataset)))
random.shuffle(indices)
split_ratio = 0.8  # 80% training, 20% evaluation
train_split = int(len(indices) * split_ratio)
train_indices = indices[:train_split]
eval_indices = indices[train_split:]







Extracting EDF parameters from /home/fwx/work/01_sleep_psg.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 5557599  =      0.000 ... 27787.995 secs...
Number of epochs (trials): 926


In [None]:
#Model Initialization
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_channels = eeg_data.shape[0]
n_reservoir = 512  # e.g., 512 reservoir neurons
n_classes = len(np.unique(stages_int))

model = SimplifiedSNNModel(n_channels, n_reservoir, n_classes, sfreq, tau=0.02, threshold=1.0).to(device)

In [None]:
# ----------------------------
# 8. Training the Simplified Model with Rate Coding Loss (Manual Batching)
# ----------------------------

criterion = nn.MSELoss()  # Rate coding loss (MSELoss)
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs_train = 50  # number of training epochs
batch_size = 32

model.train()
early_stop = False
for epoch in range(n_epochs_train):
    print(f"\nTraining for epoch: {epoch+1}")
    total_loss = 0.0
    training_correct = 0
    total_train_samples = 0

    num_batches = len(train_indices) // batch_size

    for batch_idx in tqdm.tqdm(range(num_batches), desc="Training Batches"):
        batch_indices = train_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        batch_segments = []
        batch_labels = []
        for idx in batch_indices:
            segment, label = dataset[idx]
            # Normalize each segment per channel
            segment = (segment - global_mean[:, None]) / global_std[:, None]
            batch_segments.append(segment)
            batch_labels.append(label)
        batch_segments = torch.stack(batch_segments)  # shape: [batch_size, n_channels, time_steps]
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)

        batch_segments = batch_segments.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(batch_segments)
        # Create one-hot targets for rate coding
        target = torch.zeros_like(outputs)
        target.scatter_(1, batch_labels.view(-1, 1), 1.0)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_segments.size(0)

        # Compute predictions and accumulate correct predictions
        preds = outputs.argmax(dim=1)
        training_correct += (preds == batch_labels).sum().item()
        total_train_samples += batch_segments.size(0)

    epoch_loss = total_loss / len(train_indices)
    training_accuracy = training_correct / len(train_indices) * 100
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Training Accuracy: {training_accuracy:.2f}%")
    
    # Early stopping if training accuracy is above 85%
    if training_accuracy >= 85:
        print("Training accuracy above 85%. Early stopping!")
        # Print out the parameter list (state dictionary)
        print("Model parameters:")
        for name, param in model.named_parameters():
            print(f"{name}: {param.data}")
        early_stop = True
        break







Training for epoch: 1


  segment = (segment - global_mean[:, None]) / global_std[:, None]
Training Batches: 100%|██████████| 23/23 [00:32<00:00,  1.40s/it]


Epoch 1, Loss: 0.5196, Training Accuracy: 25.00%

Training for epoch: 2


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.33s/it]


Epoch 2, Loss: 0.3320, Training Accuracy: 29.05%

Training for epoch: 3


Training Batches: 100%|██████████| 23/23 [00:16<00:00,  1.37it/s]


Epoch 3, Loss: 0.2613, Training Accuracy: 28.24%

Training for epoch: 4


Training Batches: 100%|██████████| 23/23 [00:23<00:00,  1.03s/it]


Epoch 4, Loss: 0.2428, Training Accuracy: 33.24%

Training for epoch: 5


Training Batches: 100%|██████████| 23/23 [00:24<00:00,  1.07s/it]


Epoch 5, Loss: 0.2095, Training Accuracy: 35.41%

Training for epoch: 6


Training Batches: 100%|██████████| 23/23 [00:21<00:00,  1.05it/s]


Epoch 6, Loss: 0.1978, Training Accuracy: 38.51%

Training for epoch: 7


Training Batches: 100%|██████████| 23/23 [00:24<00:00,  1.07s/it]


Epoch 7, Loss: 0.1910, Training Accuracy: 36.35%

Training for epoch: 8


Training Batches: 100%|██████████| 23/23 [00:33<00:00,  1.45s/it]


Epoch 8, Loss: 0.1728, Training Accuracy: 41.76%

Training for epoch: 9


Training Batches: 100%|██████████| 23/23 [00:29<00:00,  1.27s/it]


Epoch 9, Loss: 0.1664, Training Accuracy: 42.97%

Training for epoch: 10


Training Batches: 100%|██████████| 23/23 [00:32<00:00,  1.41s/it]


Epoch 10, Loss: 0.1667, Training Accuracy: 40.41%

Training for epoch: 11


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.31s/it]


Epoch 11, Loss: 0.1562, Training Accuracy: 43.51%

Training for epoch: 12


Training Batches: 100%|██████████| 23/23 [00:35<00:00,  1.52s/it]


Epoch 12, Loss: 0.1548, Training Accuracy: 42.16%

Training for epoch: 13


Training Batches: 100%|██████████| 23/23 [00:29<00:00,  1.28s/it]


Epoch 13, Loss: 0.1498, Training Accuracy: 43.11%

Training for epoch: 14


Training Batches: 100%|██████████| 23/23 [00:32<00:00,  1.43s/it]


Epoch 14, Loss: 0.1421, Training Accuracy: 45.27%

Training for epoch: 15


Training Batches: 100%|██████████| 23/23 [00:39<00:00,  1.70s/it]


Epoch 15, Loss: 0.1417, Training Accuracy: 46.08%

Training for epoch: 16


Training Batches: 100%|██████████| 23/23 [00:26<00:00,  1.17s/it]


Epoch 16, Loss: 0.1392, Training Accuracy: 48.11%

Training for epoch: 17


Training Batches: 100%|██████████| 23/23 [00:35<00:00,  1.52s/it]


Epoch 17, Loss: 0.1361, Training Accuracy: 47.57%

Training for epoch: 18


Training Batches: 100%|██████████| 23/23 [00:27<00:00,  1.18s/it]


Epoch 18, Loss: 0.1350, Training Accuracy: 49.59%

Training for epoch: 19


Training Batches: 100%|██████████| 23/23 [00:27<00:00,  1.19s/it]


Epoch 19, Loss: 0.1341, Training Accuracy: 50.68%

Training for epoch: 20


Training Batches: 100%|██████████| 23/23 [00:32<00:00,  1.40s/it]


Epoch 20, Loss: 0.1301, Training Accuracy: 52.70%

Training for epoch: 21


Training Batches: 100%|██████████| 23/23 [00:35<00:00,  1.53s/it]


Epoch 21, Loss: 0.1237, Training Accuracy: 55.41%

Training for epoch: 22


Training Batches: 100%|██████████| 23/23 [00:25<00:00,  1.11s/it]


Epoch 22, Loss: 0.1210, Training Accuracy: 58.78%

Training for epoch: 23


Training Batches: 100%|██████████| 23/23 [00:34<00:00,  1.51s/it]


Epoch 23, Loss: 0.1159, Training Accuracy: 60.00%

Training for epoch: 24


Training Batches: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 24, Loss: 0.1133, Training Accuracy: 62.03%

Training for epoch: 25


Training Batches: 100%|██████████| 23/23 [00:23<00:00,  1.04s/it]


Epoch 25, Loss: 0.1115, Training Accuracy: 62.84%

Training for epoch: 26


Training Batches: 100%|██████████| 23/23 [00:22<00:00,  1.03it/s]


Epoch 26, Loss: 0.1101, Training Accuracy: 63.78%

Training for epoch: 27


Training Batches: 100%|██████████| 23/23 [00:20<00:00,  1.10it/s]


Epoch 27, Loss: 0.1038, Training Accuracy: 67.70%

Training for epoch: 28


Training Batches: 100%|██████████| 23/23 [00:18<00:00,  1.26it/s]


Epoch 28, Loss: 0.1056, Training Accuracy: 67.30%

Training for epoch: 29


Training Batches: 100%|██████████| 23/23 [00:20<00:00,  1.10it/s]


Epoch 29, Loss: 0.0997, Training Accuracy: 67.84%

Training for epoch: 30


Training Batches: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 30, Loss: 0.0995, Training Accuracy: 66.89%

Training for epoch: 31


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.32s/it]


Epoch 31, Loss: 0.1021, Training Accuracy: 67.84%

Training for epoch: 32


Training Batches: 100%|██████████| 23/23 [00:26<00:00,  1.17s/it]


Epoch 32, Loss: 0.0947, Training Accuracy: 70.00%

Training for epoch: 33


Training Batches: 100%|██████████| 23/23 [00:27<00:00,  1.18s/it]


Epoch 33, Loss: 0.0924, Training Accuracy: 70.54%

Training for epoch: 34


Training Batches: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]


Epoch 34, Loss: 0.0900, Training Accuracy: 74.05%

Training for epoch: 35


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.34s/it]


Epoch 35, Loss: 0.0870, Training Accuracy: 74.32%

Training for epoch: 36


Training Batches: 100%|██████████| 23/23 [00:43<00:00,  1.90s/it]


Epoch 36, Loss: 0.0777, Training Accuracy: 78.11%

Training for epoch: 37


Training Batches: 100%|██████████| 23/23 [00:35<00:00,  1.53s/it]


Epoch 37, Loss: 0.0759, Training Accuracy: 79.46%

Training for epoch: 38


Training Batches: 100%|██████████| 23/23 [00:31<00:00,  1.39s/it]


Epoch 38, Loss: 0.0774, Training Accuracy: 78.24%

Training for epoch: 39


Training Batches: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 39, Loss: 0.0765, Training Accuracy: 80.95%

Training for epoch: 40


Training Batches: 100%|██████████| 23/23 [00:31<00:00,  1.36s/it]


Epoch 40, Loss: 0.0905, Training Accuracy: 73.24%

Training for epoch: 41


Training Batches: 100%|██████████| 23/23 [00:38<00:00,  1.68s/it]


Epoch 41, Loss: 0.0926, Training Accuracy: 70.14%

Training for epoch: 42


Training Batches: 100%|██████████| 23/23 [00:33<00:00,  1.44s/it]


Epoch 42, Loss: 0.0865, Training Accuracy: 73.92%

Training for epoch: 43


Training Batches: 100%|██████████| 23/23 [00:29<00:00,  1.27s/it]


Epoch 43, Loss: 0.0773, Training Accuracy: 77.30%

Training for epoch: 44


Training Batches: 100%|██████████| 23/23 [00:29<00:00,  1.30s/it]


Epoch 44, Loss: 0.0716, Training Accuracy: 79.32%

Training for epoch: 45


Training Batches: 100%|██████████| 23/23 [00:27<00:00,  1.20s/it]


Epoch 45, Loss: 0.0782, Training Accuracy: 78.38%

Training for epoch: 46


Training Batches: 100%|██████████| 23/23 [00:29<00:00,  1.27s/it]


Epoch 46, Loss: 0.0727, Training Accuracy: 78.38%

Training for epoch: 47


Training Batches: 100%|██████████| 23/23 [00:37<00:00,  1.61s/it]


Epoch 47, Loss: 0.0706, Training Accuracy: 81.89%

Training for epoch: 48


Training Batches: 100%|██████████| 23/23 [00:21<00:00,  1.06it/s]


Epoch 48, Loss: 0.0664, Training Accuracy: 82.16%

Training for epoch: 49


Training Batches: 100%|██████████| 23/23 [00:22<00:00,  1.04it/s]


Epoch 49, Loss: 0.0713, Training Accuracy: 80.81%

Training for epoch: 50


Training Batches: 100%|██████████| 23/23 [00:22<00:00,  1.01it/s]

Epoch 50, Loss: 0.0640, Training Accuracy: 83.11%





In [None]:
# ----------------------------
# Use CrossEntropyLoss instead of MSELoss
# ----------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
n_epochs_train = 50  # number of training epochs
batch_size = 32
# ----------------------------
# 8. Training the Simplified Model with Cross-Entropy Loss (Manual Batching)
# ----------------------------
model.train()
early_stop = False
for epoch in range(n_epochs_train):
    print(f"\nTraining for epoch: {epoch+1}")
    total_loss = 0.0
    training_correct = 0
    total_train_samples = 0

    num_batches = len(train_indices) // batch_size

    for batch_idx in tqdm.tqdm(range(num_batches), desc="Training Batches"):
        batch_indices = train_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        batch_segments = []
        batch_labels = []
        for idx in batch_indices:
            segment, label = dataset[idx]
            # Normalize using the global mean and std (per channel)
            segment = (segment - global_mean[:, None]) / global_std[:, None]
            batch_segments.append(segment)
            batch_labels.append(label)
        batch_segments = torch.stack(batch_segments)  # shape: [batch_size, n_channels, time_steps]
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)

        batch_segments = batch_segments.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(batch_segments)  # outputs: raw logits of shape [batch_size, n_classes]
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch_segments.size(0)

        # Compute predictions and accumulate correct predictions
        preds = outputs.argmax(dim=1)
        training_correct += (preds == batch_labels).sum().item()
        total_train_samples += batch_segments.size(0)

    epoch_loss = total_loss / len(train_indices)
    training_accuracy = training_correct / len(train_indices) * 100
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Training Accuracy: {training_accuracy:.2f}%")

    # Early stopping condition (optional): stop if training accuracy >= 85%
    if training_accuracy >= 85:
        print("Training accuracy above 85%. Early stopping!")
        print("Model parameters:")
        for name, param in model.named_parameters():
            print(f"{name}: {param.data}")
        early_stop = True
        break


Training for epoch: 1


  segment = (segment - global_mean[:, None]) / global_std[:, None]
Training Batches: 100%|██████████| 23/23 [00:21<00:00,  1.07it/s]


Epoch 1, Loss: 1.0935, Training Accuracy: 76.89%

Training for epoch: 2


Training Batches: 100%|██████████| 23/23 [00:23<00:00,  1.00s/it]


Epoch 2, Loss: 0.8809, Training Accuracy: 81.08%

Training for epoch: 3


Training Batches: 100%|██████████| 23/23 [00:25<00:00,  1.10s/it]


Epoch 3, Loss: 0.7411, Training Accuracy: 83.51%

Training for epoch: 4


Training Batches: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 4, Loss: 0.7032, Training Accuracy: 81.89%

Training for epoch: 5


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.32s/it]


Epoch 5, Loss: 0.6386, Training Accuracy: 83.51%

Training for epoch: 6


Training Batches: 100%|██████████| 23/23 [00:28<00:00,  1.25s/it]


Epoch 6, Loss: 0.6114, Training Accuracy: 82.97%

Training for epoch: 7


Training Batches: 100%|██████████| 23/23 [00:36<00:00,  1.57s/it]


Epoch 7, Loss: 0.7393, Training Accuracy: 75.41%

Training for epoch: 8


Training Batches: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 8, Loss: 0.7258, Training Accuracy: 77.30%

Training for epoch: 9


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.34s/it]


Epoch 9, Loss: 0.6294, Training Accuracy: 79.73%

Training for epoch: 10


Training Batches: 100%|██████████| 23/23 [00:26<00:00,  1.13s/it]


Epoch 10, Loss: 0.7553, Training Accuracy: 75.41%

Training for epoch: 11


Training Batches: 100%|██████████| 23/23 [00:33<00:00,  1.46s/it]


Epoch 11, Loss: 0.6435, Training Accuracy: 78.38%

Training for epoch: 12


Training Batches: 100%|██████████| 23/23 [00:36<00:00,  1.60s/it]


Epoch 12, Loss: 0.5246, Training Accuracy: 84.46%

Training for epoch: 13


Training Batches: 100%|██████████| 23/23 [00:30<00:00,  1.35s/it]


Epoch 13, Loss: 0.4897, Training Accuracy: 86.08%
Training accuracy above 85%. Early stopping!
Model parameters:
cnn.0.weight: tensor([[[ 0.0645,  0.0694, -0.0294,  0.0577, -0.0710],
         [ 0.0186, -0.0334,  0.0479,  0.0703, -0.0517],
         [ 0.0977,  0.0463,  0.0880,  0.0424,  0.0686],
         ...,
         [ 0.0661,  0.0165,  0.0161, -0.0386,  0.0300],
         [-0.0047, -0.0280,  0.0532,  0.0052, -0.0963],
         [-0.0316,  0.0031,  0.0305, -0.0310, -0.0976]],

        [[ 0.0191, -0.0115,  0.0188,  0.0503, -0.0516],
         [ 0.0112,  0.0326,  0.0071, -0.0336,  0.0328],
         [ 0.0694,  0.0579,  0.0163,  0.0095,  0.0376],
         ...,
         [-0.0783, -0.0549,  0.0130, -0.0472, -0.0434],
         [ 0.0384,  0.0271, -0.0465, -0.0756, -0.0415],
         [ 0.0045,  0.1180, -0.0447, -0.0589,  0.0379]],

        [[ 0.0057,  0.0263,  0.0050, -0.0104, -0.1205],
         [ 0.0114,  0.0505, -0.0021, -0.0396,  0.0769],
         [-0.0499, -0.0647, -0.0899, -0.0184,  0.0117],
 

In [None]:
# If early stopping was triggered, evaluate on the test set immediately
if early_stop:
    model.eval()
    correct = 0
    with torch.no_grad():
        num_batches_eval = len(eval_indices) // batch_size
        for batch_idx in tqdm.tqdm(range(num_batches_eval), desc="Evaluation Batches"):
            batch_indices = eval_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
            batch_segments = []
            batch_labels = []
            for idx in batch_indices:
                segment, label = dataset[idx]
                # Normalize using the same global statistics
                segment = (segment - global_mean[:, None]) / global_std[:, None]
                batch_segments.append(segment)
                batch_labels.append(label)
            batch_segments = torch.stack(batch_segments)
            batch_labels = torch.tensor(batch_labels, dtype=torch.long)
            batch_segments = batch_segments.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)
            outputs = model(batch_segments)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch_labels).sum().item()

    test_accuracy = correct / (num_batches_eval * batch_size) * 100
    print(f"\nEvaluation Accuracy (Test Set): {test_accuracy:.2f}%")

  segment = (segment - global_mean[:, None]) / global_std[:, None]
Evaluation Batches: 100%|██████████| 5/5 [00:04<00:00,  1.04it/s]


Evaluation Accuracy (Test Set): 67.50%





In [None]:
# ---------------------------- Evaluation without stopping -------
# 9. Evaluate the Simplified Model with Rate Coding (Manual Batching)
# ----------------------------
model.eval()
correct = 0
with torch.no_grad():
    num_batches_eval = len(eval_indices) // batch_size
    for batch_idx in tqdm.tqdm(range(num_batches_eval), desc="Evaluation Batches"):
        batch_indices = eval_indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        batch_segments = []
        batch_labels = []
        for idx in batch_indices:
            segment, label = dataset[idx]
            # Use the same global normalization
            segment = (segment - global_mean[:, None]) / global_std[:, None]
            batch_segments.append(segment)
            batch_labels.append(label)
        batch_segments = torch.stack(batch_segments)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        batch_segments = batch_segments.to(device, non_blocking=True)
        batch_labels = batch_labels.to(device, non_blocking=True)
        outputs = model(batch_segments)
        preds = outputs.argmax(dim=1)
        correct += (preds == batch_labels).sum().item()

accuracy = correct / (num_batches_eval * batch_size) * 100
print(f"\nEvaluation Accuracy: {accuracy:.2f}%")


  segment = (segment - global_mean[:, None]) / global_std[:, None]
Evaluation Batches: 100%|██████████| 5/5 [00:05<00:00,  1.16s/it]


Evaluation Accuracy: 67.50%





In [None]:
#Multiple other classifiers
# We'll define a helper function to extract features from the reservoir.
def extract_reservoir_features(model, indices, batch_size=32):
    model.eval()
    features_list = []
    labels_list = []
    with torch.no_grad():
        num_batches = len(indices) // batch_size
        for batch_idx in tqdm.tqdm(range(num_batches), desc="Extracting Features"):
            batch_indices = indices[batch_idx * batch_size : (batch_idx + 1) * batch_size]
            batch_segments = []
            batch_labels = []
            for idx in batch_indices:
                segment, label = dataset[idx]
                segment = (segment - global_mean[:, None]) / global_std[:, None]
                batch_segments.append(segment)
                batch_labels.append(label)
            batch_segments = torch.stack(batch_segments)  # shape: [batch_size, n_channels, time_steps]
            batch_labels = torch.tensor(batch_labels, dtype=torch.long)
            batch_segments = batch_segments.to(device, non_blocking=True)
            # Extract reservoir features directly from the reservoir module.
            features = model.reservoir(batch_segments)
            features_list.append(features.cpu().numpy())
            labels_list.append(batch_labels.cpu().numpy())
    X = np.concatenate(features_list, axis=0)
    y = np.concatenate(labels_list, axis=0)
    return X, y

# Extract features from training and evaluation sets
X_train, y_train = extract_reservoir_features(model, train_indices, batch_size)
X_eval, y_eval = extract_reservoir_features(model, eval_indices, batch_size)

# Train classical classifiers using scikit-learn.
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# --- 1. Logistic Regression ---
clf_lr = LogisticRegression(max_iter=1000)
clf_lr.fit(X_train, y_train)
y_pred_lr = clf_lr.predict(X_eval)
accuracy_lr = accuracy_score(y_eval, y_pred_lr)
print(f"\nLogistic Regression Accuracy: {accuracy_lr:.2f}")

# --- 2. Random Forest ---
clf_rf = RandomForestClassifier(n_estimators=100)
clf_rf.fit(X_train, y_train)
y_pred_rf = clf_rf.predict(X_eval)
accuracy_rf = accuracy_score(y_eval, y_pred_rf)
print(f"Random Forest Accuracy: {accuracy_rf:.2f}")

# --- 3. Support Vector Machine ---
clf_svm = SVC(kernel='linear')
clf_svm.fit(X_train, y_train)
y_pred_svm = clf_svm.predict(X_eval)
accuracy_svm = accuracy_score(y_eval, y_pred_svm)
print(f"SVM Accuracy: {accuracy_svm:.2f}")

  segment = (segment - global_mean[:, None]) / global_std[:, None]
Extracting Features:   0%|          | 0/23 [00:01<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x35 and 64x512)