# Implement a Convolutional Spiking Neural Network with Izhikevich neuron model and STDP/R-STDP

#### This implementation references the following:

(1) Eugene M. Izhikevich, "Simple Model of Spiking Neurons," IEEE TNN, 2003

(2) Mozafari et al., "SpykeTorch: Efficient Simulation of CNNs With at Most One Spike per Neuron," Frontiers in Neuroscience, 2019

# 1. Setting Up the Environment

First, import the PyTorch and utility packages used to build the convolutional SNN with Izhikevich neurons and STDP/R-STDP learning.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import numpy as np
import copy

# 2. Global Configuration Parameters

Define the basic parameters for the simulation environment and the Izhikevich neuron model.

Key parameters:
- `TMAX`: Maximum number of time steps for spiking simulation
- Izhikevich model parameters for "Regular Spiking" neurons:
  - `a` : Time scale of recovery
  - `b` : Sensitivity of recovery
  - `c` : Post-spike reset value of membrane potential
  - `d` : Post-spike reset of recovery

These define the dynamics of the spiking neurons based on the Izhikevich 2003 paper.
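For comparison, the regular-spiking values can be placed next to the other cortical presets listed in Izhikevich (2003). The sketch below collects them in a lookup table; the names `IzhParams` and `PRESETS` are illustrative assumptions and are not used elsewhere in this notebook.

```python
from collections import namedtuple

# Hypothetical container for the four Izhikevich parameters
IzhParams = namedtuple("IzhParams", ["a", "b", "c", "d"])

# Parameter presets as given in Izhikevich (2003)
PRESETS = {
    "RS": IzhParams(0.02, 0.2, -65.0, 8.0),    # regular spiking
    "IB": IzhParams(0.02, 0.2, -55.0, 4.0),    # intrinsically bursting
    "CH": IzhParams(0.02, 0.2, -50.0, 2.0),    # chattering
    "FS": IzhParams(0.10, 0.2, -65.0, 2.0),    # fast spiking
    "LTS": IzhParams(0.02, 0.25, -65.0, 2.0),  # low-threshold spiking
}

# Unpack the regular-spiking preset into the four model parameters
a_, b_, c_, d_ = PRESETS["RS"]
print(a_, b_, c_, d_)
```

Switching the key from `"RS"` to another preset would change only the neuron dynamics, not the rest of the configuration.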

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# This is the maximum number of time steps for the spiking simulation.
TMAX = 15

# Typical "Regular Spiking" (RS) parameters from Izhikevich (2003):
a_ = 0.02
b_ = 0.2
c_ = -65.0
d_ = 8.0

# 3. Converting Images to Spike Waves

The `ToSpikeWave` transform converts standard image data into temporal spike patterns.

- **Time-to-first-spike encoding**: Higher intensity pixels generate spikes earlier in the simulation
- **Accumulative spike representation**: Once a neuron spikes at time t, it remains "high" (1) for all subsequent time steps
- The output has shape [TMAX, channels, height, width], which represents the entire spike pattern over time

This is essential for processing static images through a temporal spiking network, which allows the STDP learning mechanisms to be applied.
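The encoding described above can be illustrated on single intensity values. This is a minimal plain-Python sketch, assuming intensities are already scaled to [0, 1]; the helper names `first_spike_time` and `accumulative_train` are hypothetical and exist only for this illustration.

```python
def first_spike_time(intensity, tmax=15):
    """Map an intensity in [0, 1] to a time step: brighter means earlier."""
    return round((1.0 - intensity) * (tmax - 1))

def accumulative_train(intensity, tmax=15):
    """Spike train that is 0 before the first spike and 1 from then on."""
    t0 = first_spike_time(intensity, tmax)
    return [1 if t >= t0 else 0 for t in range(tmax)]

for value in (1.0, 0.5, 0.0):
    print(value, first_spike_time(value), accumulative_train(value))
```

A pixel with intensity 1.0 spikes at step 0, one with intensity 0.5 at step 7, and one with intensity 0.0 only at the last step (14).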

In [3]:
class ToSpikeWave:
    """
    Converts a 2D (or 3D) intensity image into an accumulative spike-wave
    [TMAX, C, H, W].
    """
    def __init__(self, Tmax=15):
        self.Tmax = Tmax

    def __call__(self, img):
        if img.max() > 1.0:
            img = img / 255.0

        # First-spike time per pixel: brighter pixels fire earlier.
        # The times are already scaled to [0, Tmax - 1], so they are only rounded here.
        spike_times = (1.0 - img) * (self.Tmax - 1)
        int_times = spike_times.clamp(0, self.Tmax - 1).round().long()

        # wave[t] is 1 wherever the pixel's first spike occurred at or before t
        t_axis = torch.arange(self.Tmax, device=img.device)
        t_axis = t_axis.view(-1, *([1] * img.dim()))
        wave = (t_axis >= int_times.unsqueeze(0)).float()

        return wave

transform = transforms.Compose([
    transforms.ToTensor(),
    ToSpikeWave(Tmax=TMAX)
])

print("Spike transform created.")

Spike transform created.


# 4. Izhikevich Neuron Model Implementation

Next, implement the Izhikevich neuron model as a PyTorch module. The model is a computationally efficient way to simulate biologically plausible neural dynamics.

The key equations implemented are:
- v' = 0.04v² + 5v + 140 - u + I
- u' = a(bv - u)
- If v ≥ 30mV, then v ← c, u ← u + d

Where:
- v is the membrane potential
- u is the recovery variable
- I is the input current

Although the model is simple, it can reproduce various firing patterns seen in biological neurons (regular spiking, bursting, etc.) when a few of its parameters are adjusted.
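The equations above can be tried out on a single neuron before moving to a whole layer. The following is a minimal plain-Python sketch, assuming a forward-Euler update with a 0.5 ms step and a constant input current; the function name `simulate_izhikevich` and the step size are assumptions made for this illustration only.

```python
def simulate_izhikevich(current, steps=1000, a=0.02, b=0.2, c=-65.0, d=8.0, dt=0.5):
    """Integrate one Izhikevich neuron and return the steps at which it spiked."""
    v = c          # membrane potential starts at the reset value
    u = b * v      # recovery variable starts on its nullcline
    spike_steps = []
    for step in range(steps):
        v += dt * (0.04 * v * v + 5 * v + 140 - u + current)
        u += dt * a * (b * v - u)
        if v >= 30.0:          # spike peak reached: record and reset
            spike_steps.append(step)
            v = c
            u += d
    return spike_steps

print(len(simulate_izhikevich(0.0)), "spikes without input")
print(len(simulate_izhikevich(10.0)), "spikes with a constant current of 10")
```

Without input the neuron settles to rest and stays silent, while a constant current of 10 drives repeated spiking with the regular-spiking parameters.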

In [4]:
class IzhikevichLayer(nn.Module):
    """
    Maintains a population of Izhikevich neurons.
    Each forward pass -> single time-step update for all neurons.
    """
    def __init__(self, num_neurons, a=0.02, b=0.2, c=-65.0, d=8.0, init_v=-65.0):
        super().__init__()
        self.a = a
        self.b = b
        self.c = c
        self.d = d
        self.num_neurons = num_neurons

        self.register_buffer('v', torch.full((num_neurons,), init_v))
        self.register_buffer('u', self.b * self.v)

    def forward(self, input_current):
        # Forward-Euler step with a time step of 1 ms
        dv = 0.04*self.v*self.v + 5*self.v + 140 - self.u + input_current
        du = self.a * (self.b*self.v - self.u)

        self.v = self.v + dv
        self.u = self.u + du

        # Neurons that reach the 30 mV peak emit a spike and are reset
        spike_mask = (self.v >= 30.0)
        self.v[spike_mask] = self.c
        self.u[spike_mask] += self.d

        return spike_mask.float(), self.v

    def reset_state(self):
        self.v.fill_(self.c)
        self.u.fill_(self.b * self.c)

print("IzhikevichLayer defined.")


IzhikevichLayer defined.


# 5. Spiking Convolution and Pooling Operations

Next, implement the core processing operations of the SNN:

1. **`SpikingConv2D`**: Applies a convolution to the spike wave at every time step
   - Takes an input spike wave [T, Fin, H, W] and produces potentials [T, Fout, H', W']
   - Weights are drawn from a normal distribution with mean=0.8, std=0.02

2. **`spiking_fire`**: Converts potentials to spike waves
   - A neuron fires when its potential reaches the threshold
   - Once a neuron fires, it stays "on" for all subsequent time steps

3. **`spiking_pooling`**: Downsamples spike waves using max-pooling
   - Preserves temporal information while reducing spatial dimensions

These form the building blocks of the convolutional SNN architecture.
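The firing and pooling steps can be illustrated on small lists before applying them to tensors. This is a plain-Python sketch under simplifying assumptions (one neuron for firing, one 2D grid with even side lengths for pooling); `fire` and `max_pool_2x2` are hypothetical helper names.

```python
def fire(potentials, threshold):
    """Accumulative spike train for one neuron: 1 from the first crossing onward."""
    fired = False
    train = []
    for p in potentials:
        fired = fired or p >= threshold
        train.append(1 if fired else 0)
    return train

def max_pool_2x2(grid):
    """Non-overlapping 2x2 max-pooling over a 2D grid with even side lengths."""
    return [
        [max(grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
         for c in range(0, len(grid[0]), 2)]
        for r in range(0, len(grid), 2)
    ]

print(fire([0.0, 3.0, 6.0, 4.0, 7.0], threshold=5.0))
print(max_pool_2x2([[0, 1, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 0],
                    [0, 0, 0, 1]]))
```

Note that the potential dropping back to 4.0 does not switch the neuron off again: the spike train stays at 1 after the first crossing.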

In [5]:
class SpikingConv2D(nn.Module):
    """
    Convolution on accumulative spike-wave:
    input shape [T, Fin, H, W] => output potentials [T, Fout, H_out, W_out]
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 weight_mean=0.8, weight_std=0.02):
        super().__init__()
        w = torch.normal(mean=weight_mean, std=weight_std,
                         size=(out_channels, in_channels, kernel_size, kernel_size))
        self.weight = nn.Parameter(w)

    def forward(self, spike_wave):
        pot = F.conv2d(spike_wave, self.weight, bias=None,
                       stride=1, padding=0)
        return pot

def spiking_fire(potentials, threshold):
    """
    Convert potentials [T, F, H, W] to accumulative spike-wave [T, F, H, W].
    Once a neuron crosses threshold at time t, it's 1 for all subsequent t'.
    """
    T, C, H, W = potentials.shape
    spike_wave = torch.zeros_like(potentials)
    first_spike = torch.full((C, H, W), T, dtype=torch.long,
                             device=potentials.device)

    for t in range(T):
        above_thresh = (potentials[t] >= threshold)
        newly_spiked = above_thresh & (first_spike == T)
        first_spike[newly_spiked] = t

    for t in range(T):
        spike_wave[t] = (first_spike <= t).float()
    return spike_wave


def spiking_pooling(spike_wave, kernel_size=2, stride=2):
    """
    Basic 2D max-pooling on accumulative spike wave.
    shape => [T, C, H, W] => [T, C, H_out, W_out]
    """
    T, C, H, W = spike_wave.shape
    reshaped = spike_wave.view(T*C, 1, H, W)
    pooled = F.max_pool2d(reshaped, kernel_size, stride)
    Hout = pooled.shape[2]
    Wout = pooled.shape[3]
    pooled = pooled.view(T, C, Hout, Wout)
    return pooled

print("SpikingConv2D, spiking_fire, spiking_pooling (with fix).")

SpikingConv2D, spiking_fire, spiking_pooling (with fix).


# 6. STDP and Reward-Modulated STDP Learning Rules

Next, implement Spike-Timing-Dependent Plasticity (STDP) and its reward-modulated version (R-STDP).

- **STDP**: Synaptic weights are updated based on the relative timing of pre- and post-synaptic spikes
  - Pre-before-post timing strengthens connections (A_plus factor)
  - Post-before-pre timing weakens connections (A_minus factor)
  - Weight updates are bounded between lower bound (lb) and upper bound (ub)

- **R-STDP**: Similar to STDP but modulated by a reward/punishment signal
  - The reward factor scales the weight updates
  - Positive reward reinforces the current pattern
  - Negative reward weakens the current pattern

These learning rules enable unsupervised (STDP) and reinforcement learning (R-STDP) in the spiking neural network.
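The two rules can be written for a single synapse to make the sign conventions explicit. This is a minimal sketch, assuming the soft-bound factor (w - lb)(ub - w) and a negative `a_minus`, as in the parameter defaults used later; the names `stdp_step` and `r_stdp_step` are hypothetical.

```python
def stdp_step(w, t_pre, t_post, a_plus=0.004, a_minus=-0.003, lb=0.0, ub=1.0):
    """One STDP update of a single weight, clamped to [lb, ub]."""
    gain = a_plus if t_pre <= t_post else a_minus   # pre-before-post potentiates
    w = w + gain * (w - lb) * (ub - w)              # soft bound: small steps near lb/ub
    return min(max(w, lb), ub)

def r_stdp_step(w, t_pre, t_post, reward, **kwargs):
    """Same update, scaled by a reward (> 0) or punishment (< 0) signal."""
    lb, ub = kwargs.get("lb", 0.0), kwargs.get("ub", 1.0)
    delta = stdp_step(w, t_pre, t_post, **kwargs) - w
    return min(max(w + reward * delta, lb), ub)

print(stdp_step(0.5, t_pre=2, t_post=5))                 # potentiation
print(stdp_step(0.5, t_pre=7, t_post=5))                 # depression
print(r_stdp_step(0.5, t_pre=2, t_post=5, reward=-1.0))  # punished potentiation
```

With a negative reward, the update that would have strengthened the synapse weakens it instead, which is how punishment reverses the effect of plain STDP.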

In [7]:
def stdp_update(
    weight: torch.Tensor,
    pre_spike_time: torch.Tensor,
    post_spike_time: torch.Tensor,
    A_plus: float,
    A_minus: float,
    lb: float,
    ub: float
):
    # Convert post_spike_time to tensor if it's a float
    if isinstance(post_spike_time, (int, float)):
        post_spike_time = torch.tensor(post_spike_time, device=weight.device, dtype=torch.float)

    # Handle the dimensionality correctly
    if post_spike_time.dim() == 0:
        # Create a simple condition based on comparing with pre_spike_time
        cond = pre_spike_time <= post_spike_time
    else:
        # For non-scalar, reshape as needed
        cond = pre_spike_time <= post_spike_time.reshape(-1, 1, 1)

    factor = (weight - lb) * (ub - weight)

    weight_update = torch.zeros_like(weight)
    weight_update[cond] = A_plus * factor[cond]
    weight_update[~cond] = A_minus * factor[~cond]

    weight = weight + weight_update
    weight = torch.clamp(weight, lb, ub)
    return weight

def r_stdp_update(
    weight: torch.Tensor,
    pre_spike_time: torch.Tensor,
    post_spike_time: torch.Tensor,
    A_plus: float,
    A_minus: float,
    lb: float,
    ub: float,
    reward: float
):
    # Convert post_spike_time to tensor if it's a float
    if isinstance(post_spike_time, (int, float)):
        post_spike_time = torch.tensor(post_spike_time, device=weight.device, dtype=torch.float)

    # Handle the dimensionality correctly
    if post_spike_time.dim() == 0:
        # Create a simple condition based on comparing with pre_spike_time
        cond = pre_spike_time <= post_spike_time
    else:
        # For non-scalar, reshape as needed
        cond = pre_spike_time <= post_spike_time.reshape(-1, 1, 1)

    factor = (weight - lb) * (ub - weight)

    weight_update = torch.zeros_like(weight)
    weight_update[cond] = reward * A_plus * factor[cond]
    weight_update[~cond] = reward * A_minus * factor[~cond]

    weight = weight + weight_update
    weight = torch.clamp(weight, lb, ub)
    return weight

print("STDP and R-STDP update functions defined.")

STDP and R-STDP update functions defined.


# 7. Spike Processing and Winner Selection Functions

Next, define helper functions for processing spike data and implementing the winner-take-all mechanism:

1. **`first_spike_time_from_pot`**: Extracts the earliest time a potential crosses threshold
   - Essential for determining when a neuron first spikes

2. **`first_spike_time_from_wave`**: Extracts earliest spike times from a complete spike wave

3. **`get_k_winners`**: Implements a competitive winner-take-all mechanism
   - Selects k neurons with earliest spike times
   - Uses peak potential as a tiebreaker
   - Optional lateral inhibition with radius parameter
   - Winners will be the neurons that get to update their weights via STDP

These functions support the implementation of competitive learning and help process spike data for STDP updates.
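The selection logic can be sketched on a short list of candidate neurons. This plain-Python version is an illustration under simplifying assumptions: candidates are given as tuples, and the inhibition test compares distances to earlier winners directly, which is a simplification of the window-marking used in `get_k_winners`. The name `pick_winners` is hypothetical.

```python
def pick_winners(candidates, k, radius):
    """
    candidates: list of (feature, row, col, spike_time, peak_potential).
    Returns up to k (feature, row, col) tuples, earliest spikes first.
    """
    # Earliest spike first; among equal times, the higher peak potential wins
    ordered = sorted(candidates, key=lambda n: (n[3], -n[4]))
    winners = []
    for f, r, c, _, _ in ordered:
        if len(winners) == k:
            break
        # Lateral inhibition: skip neurons close to a winner in the same feature map
        blocked = any(
            wf == f and abs(wr - r) <= radius and abs(wc - c) <= radius
            for wf, wr, wc in winners
        )
        if not blocked:
            winners.append((f, r, c))
    return winners

candidates = [
    (0, 0, 0, 1, 5.0),
    (0, 0, 1, 1, 9.0),   # same time as above, higher peak: ranked first
    (1, 3, 3, 2, 1.0),
    (0, 4, 4, 3, 2.0),
]
print(pick_winners(candidates, k=3, radius=1))
```

Here neuron (0, 0, 0) loses the tie to its neighbour (0, 0, 1) and is then inhibited by it, so the remaining two slots go to later-spiking neurons elsewhere.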

In [8]:
def first_spike_time_from_pot(pot_t, thr=0.0):
    """
    pot_t: potentials of one neuron over time, shape [T].
    Returns the earliest t with pot_t[t] >= thr as a float, or float(T) if there is none.
    """
    idx = (pot_t >= thr).nonzero(as_tuple=True)[0]
    if len(idx) == 0:
        return float(pot_t.shape[0])
    return float(idx[0].item())

def first_spike_time_from_wave(spike_wave):
    """
    spike_wave: [T, C, H, W], return earliest spike time => shape [C, H, W].
    """
    T, C, H, W = spike_wave.shape
    out = torch.full((C, H, W), T, dtype=torch.float, device=spike_wave.device)
    for t in range(T):
        mask = (spike_wave[t] >= 0.5)
        out[mask] = torch.minimum(out[mask], torch.full_like(out[mask], float(t)))
    return out

def get_k_winners(pot, k=5, radius=2):
    """
    pot [T, C, H, W]: pick up to k neurons with the earliest spike times.
    Ties are broken by the higher peak potential. If radius > 0, a winner
    inhibits nearby neurons in the same feature map.
    """
    T, C, H, W = pot.shape
    spike_times = torch.full((C, H, W), T, device=pot.device)
    peak_pot = torch.zeros((C, H, W), device=pot.device)

    for t in range(T):
        slice_ = pot[t]
        mask = (slice_ >= 0.0)
        spike_times[mask] = torch.minimum(spike_times[mask], torch.full_like(spike_times[mask], float(t)))
        peak_pot = torch.max(peak_pot, slice_)

    st_flat = spike_times.view(-1)
    pk_flat = peak_pot.view(-1)

    coords = [(i // (H*W), (i % (H*W)) // W, (i % (H*W)) % W) for i in range(C*H*W)]
    sorted_indices = sorted(range(C*H*W),
                            key=lambda i: (st_flat[i].item(), -pk_flat[i].item()))
    winners = []
    used = torch.zeros((C, H, W), dtype=torch.bool)

    for i in sorted_indices:
        if len(winners) >= k:
            break
        f, r, c = coords[i]
        if radius>0:
            rr_min = max(0, r-radius)
            rr_max = min(H, r+radius+1)
            cc_min = max(0, c-radius)
            cc_max = min(W, c+radius+1)
            if used[f, rr_min:rr_max, cc_min:cc_max].any():
                continue
            else:
                used[f, rr_min:rr_max, cc_min:cc_max] = True
        winners.append((f, r, c))
    return winners

def get_better_winners(pot, k=5, radius=2, class_balance=True):
    """
    Enhanced winner selection with better class balance
    pot [T, C, H, W], pick k neurons with earliest spike times.
    If class_balance=True, tries to select winners from different feature maps
    """
    T, C, H, W = pot.shape
    spike_times = torch.full((C, H, W), T, device=pot.device)
    peak_pot = torch.zeros((C, H, W), device=pot.device)

    # Calculate spike times and peak potentials
    for t in range(T):
        slice_ = pot[t]
        mask = (slice_ >= 0.0)
        spike_times[mask] = torch.minimum(spike_times[mask], torch.full_like(spike_times[mask], float(t)))
        peak_pot = torch.max(peak_pot, slice_)

    st_flat = spike_times.view(-1)
    pk_flat = peak_pot.view(-1)

    # Create coordinates list for all potential neurons
    coords = [(i // (H*W), (i % (H*W)) // W, (i % (H*W)) % W) for i in range(C*H*W)]

    # Sort by spike time (primary) and peak potential (secondary)
    sorted_indices = sorted(range(C*H*W),
                           key=lambda i: (st_flat[i].item(), -pk_flat[i].item()))

    winners = []
    used = torch.zeros((C, H, W), dtype=torch.bool, device=pot.device)
    class_count = torch.zeros(C, dtype=torch.int, device=pot.device)

    max_per_class = k // 3 + 1 if class_balance else k  # Limit winners per class if balancing

    for i in sorted_indices:
        if len(winners) >= k:
            break

        f, r, c = coords[i]

        # Skip if we've reached max count for this class and we want class balance
        if class_balance and class_count[f] >= max_per_class:
            continue

        # Check for inhibition zone
        if radius > 0:
            rr_min = max(0, r-radius)
            rr_max = min(H, r+radius+1)
            cc_min = max(0, c-radius)
            cc_max = min(W, c+radius+1)

            if used[f, rr_min:rr_max, cc_min:cc_max].any():
                continue
            else:
                used[f, rr_min:rr_max, cc_min:cc_max] = True

        winners.append((f, r, c))
        class_count[f] += 1

    return winners

print("Helper functions for spikes, winners, times.")


Helper functions for spikes, winners, times.


# 8. Deep Convolutional SNN Architecture

Next, implement the complete three-layer spiking CNN architecture with:

1. **Layer Structure**:
   - Layer 1: Convolutional layer + STDP learning
   - Pooling
   - Layer 2: Convolutional layer + STDP learning
   - Pooling
   - Layer 3: Output layer + R-STDP for classification

2. **Learning Methods**:
   - `forward_inference`: Regular forward pass for inference
   - `forward_learn`: Forward pass with learning for a specific layer
   - `apply_r_stdp`: Apply reward-modulated STDP
   - `apply_r_stdp_direct`: Direct application of R-STDP to target neurons
   - `stdp_update_layer`: Core function for updating weights with STDP

This integrates all the previously defined components into a complete SNN architecture capable of unsupervised and reinforcement learning.
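The spatial size of each feature map follows from the kernel sizes, since the convolutions use stride 1 without padding and each pooling step uses a 2x2 window with stride 2. The sketch below traces the sizes, assuming a 28x28 input such as an MNIST digit; `conv_out`, `pool_out`, and `layer_sizes` are hypothetical helper names.

```python
def conv_out(size, kernel):
    """Output size of a stride-1 convolution without padding."""
    return size - kernel + 1

def pool_out(size, kernel=2, stride=2):
    """Output size of max-pooling without padding."""
    return (size - kernel) // stride + 1

def layer_sizes(input_size=28, kernels=(5, 3, 5)):
    """Trace the feature-map side length through conv1, pool, conv2, pool, conv3."""
    k1, k2, k3 = kernels
    sizes = {}
    sizes["conv1"] = conv_out(input_size, k1)
    sizes["pool1"] = pool_out(sizes["conv1"])
    sizes["conv2"] = conv_out(sizes["pool1"], k2)
    sizes["pool2"] = pool_out(sizes["conv2"])
    sizes["conv3"] = conv_out(sizes["pool2"], k3)
    return sizes

print(layer_sizes(28, (5, 3, 5)))
print(layer_sizes(28, (5, 3, 3)))
```

With a 5x5 kernel in the last layer each output map shrinks to a single neuron, whereas a 3x3 kernel leaves a 3x3 map per class.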

In [9]:
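class DeepConvSNN(nn.Module):
    """
    3-layer convolutional spiking network, following the approach of Mozafari et al. (2019).
    The Izhikevich parameters a, b, c, d are accepted for configuration but are not
    used by the layers defined in this class.
    """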
    def __init__(
        self,
        in_channels=1,
        layer1_channels=30, kernel1=5,
        layer2_channels=100, kernel2=3,
        layer3_channels=10, kernel3=5,
        a=0.02, b=0.2, c=-65.0, d=8.0,
        A_plus=0.004, A_minus=-0.003,
        A_plus_r=0.004, A_minus_r=-0.003,
        lb=0.0, ub=1.0,
        reward_val=+1.0, punish_val=-1.0
    ):
        super().__init__()
        self.conv1 = SpikingConv2D(in_channels, layer1_channels, kernel1)
        self.conv2 = SpikingConv2D(layer1_channels, layer2_channels, kernel2)
        self.conv3 = SpikingConv2D(layer2_channels, layer3_channels, kernel3)

        self.A_plus = A_plus
        self.A_minus = A_minus
        self.A_plus_r = A_plus_r
        self.A_minus_r = A_minus_r
        self.lb = lb
        self.ub = ub
        self.reward_val = reward_val
        self.punish_val = punish_val

    def forward_inference(self, spike_wave):
        pot1 = self.conv1(spike_wave)
        spk1 = spiking_fire(pot1, 10.0)
        spk1_pool = spiking_pooling(spk1, 2, 2)

        pot2 = self.conv2(spk1_pool)
        spk2 = spiking_fire(pot2, 5.0)
        spk2_pool = spiking_pooling(spk2, 2, 2)

        pot3 = self.conv3(spk2_pool)
        return pot3

    def forward_learn(self, spike_wave, layer_to_learn):
        """
        Single sample => train a single layer with STDP or final layer with R-STDP.
        """
        pot1 = self.conv1(spike_wave)
        spk1 = spiking_fire(pot1, 5.0)
        spk1_pool = spiking_pooling(spk1, 2, 2)

        if layer_to_learn == 1:
            winners = get_k_winners(pot1, 20, 2)
            self.stdp_update_layer(self.conv1, spike_wave, pot1, winners, r_stdp=False)
            return

        pot2 = self.conv2(spk1_pool)
        spk2 = spiking_fire(pot2, 3.0)
        spk2_pool = spiking_pooling(spk2, 2, 2)

        if layer_to_learn == 2:
            winners = get_k_winners(pot2, 30, 1)
            self.stdp_update_layer(self.conv2, spk1_pool, pot2, winners, r_stdp=False)
            return

        pot3 = self.conv3(spk2_pool)
        return pot3

    def apply_r_stdp(self, spike_wave, label, predicted):
        pot3 = self.forward_inference(spike_wave)
        winners = get_k_winners(pot3, k=1, radius=0)
        rew = self.reward_val if (predicted == label) else self.punish_val
        self.stdp_update_layer(self.conv3, None, pot3, winners, r_stdp=True, reward=rew)

    def apply_r_stdp_direct(self, spike_wave, pot3, winners, reward):
        """Direct training of specific neurons with stronger reward"""
        # Get the potentials from previous layers
        pot1 = self.conv1(spike_wave)
        spk1 = spiking_fire(pot1, 8.0)
        spk1_pool = spiking_pooling(spk1, 2, 2)

        pot2 = self.conv2(spk1_pool)
        spk2 = spiking_fire(pot2, 6.0)
        spk2_pool = spiking_pooling(spk2, 2, 2)

        for (fout, rr, cc) in winners:
            # Get pre-synaptic spike times for the specific kernel window
            # Extract the correct kernel-sized window from spk2_pool based on the kernel size
            kH, kW = self.conv3.weight.shape[2], self.conv3.weight.shape[3]  # Get actual kernel dimensions

            # Make sure the window is properly sized
            if rr + kH <= spk2_pool.shape[2] and cc + kW <= spk2_pool.shape[3]:
                # Extract exactly the patch that would be used in convolution
                pre_patch = spk2_pool[:, :, rr:rr+kH, cc:cc+kW]
                T_pre = first_spike_time_from_wave(pre_patch)

                # Use early spike time for the target neuron
                T_post = 5.0 if reward > 0 else 10.0

                # Apply stronger weight updates
                self.conv3.weight.data[fout] = r_stdp_update(
                    self.conv3.weight.data[fout],
                    T_pre, T_post,
                    self.A_plus_r*2, self.A_minus_r*2,
                    self.lb, self.ub,
                    reward
                )


    def stdp_update_layer(self, conv_layer, input_spike_wave, pot, winners, r_stdp=False, reward=0.0):
        W = conv_layer.weight.data
        kH = W.shape[2]
        kW = W.shape[3]

        for (fout, rr, cc) in winners:
            T_post = first_spike_time_from_pot(pot[:, fout, rr, cc])
            if input_spike_wave is not None:
                pre_patch = input_spike_wave[:, :, rr:rr+kH, cc:cc+kW]
                T_pre = first_spike_time_from_wave(pre_patch)
            else:
                T_pre = torch.zeros((W.shape[1], kH, kW), device=W.device)

            if not r_stdp:
                W[fout] = stdp_update(W[fout], T_pre, T_post,
                                      self.A_plus, self.A_minus, self.lb, self.ub)
            else:
                W[fout] = r_stdp_update(W[fout], T_pre, T_post,
                                        self.A_plus_r, self.A_minus_r, self.lb, self.ub, reward)
        conv_layer.weight.data = W

print("DeepConvSNN created (with fix).")

DeepConvSNN created (with fix).


# 9. Training Pipeline and Utilities

Now we set up the complete training pipeline for the SNN:

1. **Data Loading**:
   - MNIST dataset with our custom spike encoding transform
   - Training and testing data loaders

2. **Model Instantiation**:
   - Creates the DeepConvSNN with specific parameters
   - Configures learning rates, weight bounds, etc.

3. **Training Functions**:
   - `predict_class`: Infers the class from layer 3 potentials
   - `train_layer`: Trains a specific layer with STDP
   - `train_layer3_rstdp`: Trains the classification layer with R-STDP
   - `test_accuracy`: Evaluates model performance
   - `reset_model`: Reinitializes weights for fresh training
   - `train_on_mistakes`: Focused training on misclassified examples

These functions provide a complete pipeline for layerwise training and evaluation of the spiking neural network.
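The class decision combines how many output potentials are positive with how early the first one appears. The plain-Python sketch below mirrors that scoring on nested lists, as an illustration only; `class_score` and `predict` are hypothetical names, and each class map is flattened to a list of potentials per time step.

```python
def class_score(potentials_over_time):
    """potentials_over_time: one list of potentials per time step for a single class."""
    T = len(potentials_over_time)
    # Number of (time, position) entries with a positive potential
    active = sum(1 for step in potentials_over_time for p in step if p > 0)
    # First time step with any positive potential, or T if there is none
    first = next(
        (t for t, step in enumerate(potentials_over_time) if any(p > 0 for p in step)),
        T,
    )
    return active * (T - first)

def predict(per_class_potentials):
    """Return the index of the class with the highest score."""
    scores = [class_score(p) for p in per_class_potentials]
    return scores.index(max(scores))

class_a = [[0, 0], [1, 0], [2, 3]]   # first activity at t=1, three active entries
class_b = [[0, 0], [0, 0], [5, 5]]   # first activity at t=2, two active entries
print(class_score(class_a), class_score(class_b), predict([class_a, class_b]))
```

Class A wins even though class B reaches larger potentials, because the score depends only on the count and timing of positive entries.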

In [10]:
# Datasets
train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Instantiate net
model = DeepConvSNN(
    in_channels=1,
    layer1_channels=30, kernel1=5,
    layer2_channels=100, kernel2=3,
    layer3_channels=10, kernel3=3,
    a=a_, b=b_, c=c_, d=d_,
    A_plus=0.01, A_minus=-0.008,
    A_plus_r=0.01, A_minus_r=-0.008,
    lb=0.2, ub=0.8,
    reward_val=+1.0, punish_val=-1.0
).to(device)


def predict_class(pot3):
    """
    Score each output feature map by combining how many potentials are
    positive with how early the first one appears, and return the best class.
    """
    T, C, H, W = pot3.shape
    class_scores = torch.zeros(C, device=pot3.device)

    for f in range(C):
        # Potentials of feature map f over time, shape [T, H, W]
        channel_pot = pot3[:, f]
        spike_mask = (channel_pot > 0).float()
        spike_count = spike_mask.sum()

        # Timing factor - earlier spikes are better
        min_time = T
        for t in range(T):
            if spike_mask[t].sum() > 0:
                min_time = t
                break

        # Combine factors - more spikes and earlier timing is better
        class_scores[f] = spike_count * (T - min_time)

    return torch.argmax(class_scores).item()


def train_layer(model, loader, layer_idx, epochs=1, max_images=1000):
    """
    Train a specific layer of the model
    max_images: limit the number of training examples (set to None to use all)
    """
    model.train()
    for ep in range(epochs):
        for i, (data_spike_wave, target) in enumerate(loader):
            if i % 10 == 0:
                print(f"Processing image {i}...")

            # Process the image
            data_spike_wave = data_spike_wave.to(device)
            model.forward_learn(data_spike_wave.squeeze(0), layer_to_learn=layer_idx)

            # Limit the number of training examples per epoch for faster testing.
            # break (not return) so that the remaining epochs still run.
            if max_images is not None and i >= max_images-1:
                print(f"Reached maximum images ({max_images}), stopping epoch")
                break

def train_layer3_rstdp(model, loader, epochs=1, max_images=1000):
    model.train()
    for ep in range(epochs):
        correct_during_training = 0
        total_during_training = 0

        for i, (data_spike_wave, target) in enumerate(loader):
            if i % 10 == 0:
                print(f"Processing image {i}...")

            data_spike_wave = data_spike_wave.to(device)
            pot3 = model.forward_inference(data_spike_wave.squeeze(0))
            pred_class_ = predict_class(pot3)
            lbl = target.item()

            # Reward a correct prediction and punish a wrong one
            reward_val = 1.5 if pred_class_ == lbl else -1.5

            # The update always targets the correct-class filter, so a wrong
            # prediction applies the negative reward to that filter
            winners = [(lbl, 0, 0)]
            model.apply_r_stdp_direct(data_spike_wave.squeeze(0), pot3, winners, reward_val)

            # Track accuracy during training
            if pred_class_ == lbl:
                correct_during_training += 1
            total_during_training += 1

            if i % 100 == 99:
                print(f"Current training accuracy: {100*correct_during_training/total_during_training:.2f}%")

            # break (not return) so that the remaining epochs still run
            if max_images is not None and i >= max_images-1:
                print(f"Reached maximum images ({max_images}), stopping epoch")
                break

def train_layer3_rstdp_improved(model, loader, epochs=1, max_images=1000):
    """Enhanced R-STDP training for layer 3 with adaptive rewards"""
    model.train()

    correct_during_training = 0
    total_during_training = 0
    running_loss = 0.0

    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")

        for i, (data_spike_wave, target) in enumerate(loader):
            if i % 10 == 0:
                print(f"Processing image {i}...")

            data_spike_wave = data_spike_wave.to(device)
            pot3 = model.forward_inference(data_spike_wave.squeeze(0))
            pred_class_ = predict_class(pot3)
            lbl = target.item()

            # Use the improved direct R-STDP method with both label and prediction
            model.apply_r_stdp_direct_improved(
                data_spike_wave.squeeze(0), pot3, lbl, pred_class_
            )

            # Track performance
            if pred_class_ == lbl:
                correct_during_training += 1
            total_during_training += 1

            # More frequent progress reporting
            if i % 50 == 49:
                current_acc = 100 * correct_during_training / total_during_training
                print(f"Current training accuracy: {current_acc:.2f}% [{i+1}/{max_images}]")

            # Stop this epoch once enough images have been processed;
            # break (not return) so that the epoch summary and later epochs still run
            if max_images is not None and i >= max_images-1:
                print(f"Reached maximum images ({max_images}), stopping epoch")
                break

        epoch_acc = 100 * correct_during_training / total_during_training
        print(f"Epoch {ep+1} completed. Training accuracy: {epoch_acc:.2f}%")

def test_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (data_spike_wave, target) in enumerate(loader):
            data_spike_wave = data_spike_wave.to(device)
            pot3 = model.forward_inference(data_spike_wave.squeeze(0))
            pred_class_ = predict_class(pot3)
            if pred_class_ == target.item():
                correct += 1
            total += 1
    return 100.0 * correct / total

def reset_model(model):
    """Reset model weights for a fresh training run"""
    print("Resetting model weights for a new training attempt...")
    for m in model.modules():
        if isinstance(m, SpikingConv2D):
            # Initialize with new random weights - ensure they're on the same device as the model
            device = m.weight.device
            w = torch.normal(mean=0.8, std=0.02, size=m.weight.shape, device=device)
            m.weight.data = w
    return model

def train_on_mistakes(model, loader, max_mistakes=200):
    print("Training specifically on misclassified examples...")
    mistakes_trained = 0
    model.train()

    for data_spike_wave, target in loader:
        data_spike_wave = data_spike_wave.to(device)
        pot3 = model.forward_inference(data_spike_wave.squeeze(0))
        pred = predict_class(pot3)

        if pred != target.item():
            lbl = target.item()
            winners = [(lbl, 0, 0)]
            reward_val = 2.0

            model.apply_r_stdp_direct(data_spike_wave.squeeze(0), pot3, winners, reward_val)

            mistakes_trained += 1
            if mistakes_trained % 10 == 0:
                print(f"Trained on {mistakes_trained} mistakes...")

            if mistakes_trained >= max_mistakes:
                print(f"Reached maximum mistakes ({max_mistakes}), stopping training")
                break

    return mistakes_trained

print("Data loaded, model created, and training/testing functions defined.")



Data loaded, model created, and training/testing functions defined.


# 10. Model Training and Evaluation

Now we execute the complete training procedure for the SNN:

1. **Multi-attempt Training**:
   - Tries multiple training runs with different weight initializations
   - Keeps track of the best-performing model

2. **Layer-wise Training Process**:
   - Unsupervised STDP training for layer 1
   - Unsupervised STDP training for layer 2
   - Reward-modulated STDP training for layer 3

3. **Performance Enhancement**:
   - Focused training on misclassified examples
   - Final accuracy evaluation

The training process follows the approach from Mozafari et al. (2019), with additions to improve performance through multiple attempts and targeted training.
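The multi-attempt strategy amounts to a best-of-N loop that snapshots the model whenever the accuracy improves. The sketch below shows that pattern with stand-in callables; `best_of_n`, `train_fn`, `eval_fn`, and `reset_fn` are hypothetical names, and the accuracies are made-up values used only to exercise the loop.

```python
import copy

def best_of_n(model, train_fn, eval_fn, reset_fn, attempts=3):
    """Train several times from fresh weights and keep a copy of the best model."""
    best_acc, best_model = float("-inf"), None
    for attempt in range(attempts):
        if attempt > 0:
            reset_fn(model)            # fresh initialization for later attempts
        train_fn(model)
        acc = eval_fn(model)
        if acc > best_acc:
            # deepcopy so that later attempts cannot overwrite the snapshot
            best_acc, best_model = acc, copy.deepcopy(model)
    return best_model, best_acc

# Stand-ins for illustration: a dict as the "model" and preset accuracies
toy_model = {"attempt": 0}
preset_scores = iter([61.0, 74.5, 70.0])

def toy_train(m):
    m["attempt"] += 1

def toy_eval(m):
    return next(preset_scores)

best, acc = best_of_n(toy_model, toy_train, toy_eval, reset_fn=lambda m: None)
print(best["attempt"], acc)
```

Starting `best_acc` at negative infinity is a design choice in this sketch: it guarantees that a model is always returned, even if every attempt scores 0%.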

In [12]:
# Run multiple training attempts to find the best model
best_accuracy = 0
best_model = None

for attempt in range(3):
    print(f"\n=== Training Attempt {attempt+1} ===")

    # Reset model for new attempts
    if attempt > 0:
        model = reset_model(model)

    # Train layer by layer
    print(f"=== Unsupervised training layer 1 (Attempt {attempt+1}) ===")
    train_layer(model, train_loader, layer_idx=1, epochs=2, max_images=500)
    print(f"Done training layer 1 (Attempt {attempt+1})")

    print(f"=== Unsupervised training layer 2 (Attempt {attempt+1}) ===")
    train_layer(model, train_loader, layer_idx=2, epochs=2, max_images=500)
    print(f"Done training layer 2 (Attempt {attempt+1})")

    print(f"=== R-STDP training layer 3 (Attempt {attempt+1}) ===")
    train_layer3_rstdp(model, train_loader, epochs=2, max_images=500)
    print(f"Done training layer 3 with R-STDP (Attempt {attempt+1})")

    # Test accuracy
    acc = test_accuracy(model, test_loader)
    print(f"Training attempt {attempt+1} accuracy: {acc:.2f}%")

    if acc > best_accuracy:
        best_accuracy = acc
        best_model = copy.deepcopy(model)
        print(f"New best model found! Accuracy: {best_accuracy:.2f}%")

# Use the best model for final results
model = best_model
print(f"\nBest model accuracy: {best_accuracy:.2f}%")

# Apply focused training on mistakes to the best model
mistakes_fixed = train_on_mistakes(model, train_loader, max_mistakes=100)
print(f"Trained on {mistakes_fixed} misclassified examples")

# Final accuracy check
final_acc = test_accuracy(model, test_loader)
print(f"Final accuracy after mistake correction: {final_acc:.2f}%")


=== Training Attempt 1 ===
=== Unsupervised training layer 1 (Attempt 1) ===
Processing image 0...
Processing image 10...
Processing image 20...
Processing image 30...
Processing image 40...
Processing image 50...
Processing image 60...
Processing image 70...
Processing image 80...
Processing image 90...
Processing image 100...
Processing image 110...
Processing image 120...
Processing image 130...
Processing image 140...
Processing image 150...
Processing image 160...
Processing image 170...
Processing image 180...
Processing image 190...
Processing image 200...
Processing image 210...
Processing image 220...
Processing image 230...
Processing image 240...
Processing image 250...
Processing image 260...
Processing image 270...
Processing image 280...
Processing image 290...
Processing image 300...
Processing image 310...
Processing image 320...
Processing image 330...
Processing image 340...
Processing image 350...
Processing image 360...
Processing image 370...
Processing image 380.

# 11. Improved Training Strategy

The previous methods reached a baseline accuracy. This section tries to improve on it by training each layer several times and keeping the best result. Because GPU capacity is limited, we run only 3 attempts per layer.
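
Each new attempt in the pipeline re-draws a layer's weights from a normal distribution with mean 0.8 and standard deviation 0.02. The same reset is written out three times, so it could be factored into a small helper. The sketch below is one possible shape for it; the name `reinit_weights` is hypothetical, and a plain `nn.Conv2d` stands in for the notebook's `SpikingConv2D`, on the assumption that it exposes a `weight` tensor in the same way.

```python
import torch
import torch.nn as nn


def reinit_weights(layer, mean=0.8, std=0.02):
    """Re-draw layer.weight in place from N(mean, std); shape, dtype and device are kept."""
    with torch.no_grad():
        layer.weight.normal_(mean=mean, std=std)
    return layer


# Stand-in layer (assumption): the notebook uses SpikingConv2D instead.
conv = nn.Conv2d(1, 30, kernel_size=5, bias=False)
before = conv.weight.clone()
reinit_weights(conv)
print(conv.weight.shape, float(conv.weight.mean()))
```

Drawing in place avoids creating a tensor on the wrong device, which is what the explicit `device=` argument guards against in the pipeline code.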

In [11]:
def improved_training_pipeline(model, train_loader, test_loader,
                              layer1_attempts=3, layer1_epochs=3, layer1_samples=1000,
                              layer2_attempts=3, layer2_epochs=3, layer2_samples=1000,
                              layer3_attempts=3, layer3_epochs=3, layer3_samples=1000,
                              mistake_correction_rounds=3, mistakes_per_round=200):
    """
    Enhanced training pipeline with multiple attempts per layer and iterative mistake correction
    """
    print("Starting improved training pipeline with multiple attempts per layer...")

    # Train layer 1 with multiple attempts
    print("=== Training layer 1 ===")
    best_layer1_model = None
    best_layer1_acc = 0

    for attempt in range(layer1_attempts):
        print(f"Layer 1 - Attempt {attempt+1}/{layer1_attempts}")

        # Reset layer 1 weights for each new attempt (except first)
        if attempt > 0:
            device = model.conv1.weight.device  # Get the correct device
            w = torch.normal(mean=0.8, std=0.02, size=model.conv1.weight.shape, device=device)
            model.conv1.weight.data = w

        # Train layer 1
        train_layer(model, train_loader, layer_idx=1, epochs=layer1_epochs, max_images=layer1_samples)

        # Evaluate after layer 1
        acc = test_accuracy(model, test_loader)
        print(f"Layer 1 - Attempt {attempt+1} accuracy: {acc:.2f}%")

        # Keep track of best model
        if best_layer1_model is None or acc > best_layer1_acc:
            best_layer1_acc = acc
            best_layer1_model = copy.deepcopy(model)
            print(f"New best layer 1 model! Accuracy: {best_layer1_acc:.2f}%")

    # Use the best layer 1 model
    model = best_layer1_model
    print(f"Best layer 1 model accuracy: {best_layer1_acc:.2f}%")

    # Train layer 2 with multiple attempts
    print("\n=== Training layer 2 ===")
    best_layer2_model = None
    best_layer2_acc = 0

    for attempt in range(layer2_attempts):
        print(f"Layer 2 - Attempt {attempt+1}/{layer2_attempts}")

        # Reset layer 2 weights for each new attempt (except first)
        if attempt > 0:
            device = model.conv2.weight.device  # Get the correct device
            w = torch.normal(mean=0.8, std=0.02, size=model.conv2.weight.shape, device=device)
            model.conv2.weight.data = w

        # Train layer 2
        train_layer(model, train_loader, layer_idx=2, epochs=layer2_epochs, max_images=layer2_samples)

        # Evaluate after layer 2
        acc = test_accuracy(model, test_loader)
        print(f"Layer 2 - Attempt {attempt+1} accuracy: {acc:.2f}%")

        # Keep track of best model
        if best_layer2_model is None or acc > best_layer2_acc:
            best_layer2_acc = acc
            best_layer2_model = copy.deepcopy(model)
            print(f"New best layer 2 model! Accuracy: {best_layer2_acc:.2f}%")

    # Use the best layer 2 model
    model = best_layer2_model
    print(f"Best layer 2 model accuracy: {best_layer2_acc:.2f}%")

    # Train layer 3 with multiple attempts
    print("\n=== Training layer 3 with R-STDP ===")
    best_layer3_model = None
    best_layer3_acc = 0

    for attempt in range(layer3_attempts):
        print(f"Layer 3 - Attempt {attempt+1}/{layer3_attempts}")

        # Reset layer 3 weights for each new attempt (except first)
        if attempt > 0:
            device = model.conv3.weight.device  # Get the correct device
            w = torch.normal(mean=0.8, std=0.02, size=model.conv3.weight.shape, device=device)
            model.conv3.weight.data = w

        # Train layer 3
        train_layer3_rstdp(model, train_loader, epochs=layer3_epochs, max_images=layer3_samples)

        # Evaluate after layer 3
        acc = test_accuracy(model, test_loader)
        print(f"Layer 3 - Attempt {attempt+1} accuracy: {acc:.2f}%")

        # Keep track of best model
        if best_layer3_model is None or acc > best_layer3_acc:
            best_layer3_acc = acc
            best_layer3_model = copy.deepcopy(model)
            print(f"New best layer 3 model! Accuracy: {best_layer3_acc:.2f}%")

    # Use the best layer 3 model
    model = best_layer3_model
    print(f"Best layer 3 model accuracy: {best_layer3_acc:.2f}%")

    # Iterative mistake correction
    for round_idx in range(mistake_correction_rounds):
        print(f"\n=== Mistake correction round {round_idx+1}/{mistake_correction_rounds} ===")
        mistakes_fixed = train_on_mistakes(model, train_loader, max_mistakes=mistakes_per_round)
        print(f"Trained on {mistakes_fixed} misclassified examples")

        # Evaluate after mistake correction
        acc_after_correction = test_accuracy(model, test_loader)
        print(f"Accuracy after correction round {round_idx+1}: {acc_after_correction:.2f}%")

    # Final evaluation
    final_acc = test_accuracy(model, test_loader)
    print(f"Final model accuracy: {final_acc:.2f}%")

    return model, final_acc

# Run the improved training with multiple attempts per layer
print("\n=== Starting Improved Training Run with Multiple Attempts ===")
model, improved_acc = improved_training_pipeline(
    model,
    train_loader,
    test_loader,
    layer1_attempts=3,
    layer1_epochs=2,
    layer1_samples=1000,
    layer2_attempts=3,
    layer2_epochs=2,
    layer2_samples=1000,
    layer3_attempts=3,
    layer3_epochs=2,
    layer3_samples=1000,
    mistake_correction_rounds=2,
    mistakes_per_round=200
)
print(f"Improved training achieved {improved_acc:.2f}% accuracy")


=== Starting Improved Training Run with Multiple Attempts ===
Starting improved training pipeline with multiple attempts per layer...
=== Training layer 1 ===
Layer 1 - Attempt 1/3
Processing image 0...
Processing image 10...
Processing image 20...
Processing image 30...
Processing image 40...
Processing image 50...
Processing image 60...
Processing image 70...
Processing image 80...
Processing image 90...
Processing image 100...
Processing image 110...
Processing image 120...
Processing image 130...
Processing image 140...
Processing image 150...
Processing image 160...
Processing image 170...
Processing image 180...
Processing image 190...
Processing image 200...
Processing image 210...
Processing image 220...
Processing image 230...
Processing image 240...
Processing image 250...
Processing image 260...
Processing image 270...
Processing image 280...
Processing image 290...
Processing image 300...
Processing image 310...
Processing image 320...
Processing image 330...
Processing im

# 12. A More Flexible DeepConvSNN Supporting More Than 3 Layers

In [None]:
class FlexibleDeepConvSNN(nn.Module):
    """
    Flexible SNN with variable depth, supporting 3-5 layers
    """
    def __init__(
        self,
        in_channels=1,
        layer_channels=(30, 100, 10),  # Default is 3 layers
        kernel_sizes=(5, 3, 3),
        pool_sizes=(2, 2, 1),          # No pooling after final layer by default
        pool_strides=(2, 2, 1),
        thresholds=(5.0, 3.0, 1.0),    # Firing thresholds for each layer
        a=0.02, b=0.2, c=-65.0, d=8.0,  # Izhikevich parameters (accepted but not used by the code in this class)
        A_plus=0.004, A_minus=-0.003,
        A_plus_r=0.004, A_minus_r=-0.003,
        lb=0.0, ub=1.0,
        reward_val=+1.0, punish_val=-1.0
    ):
        super().__init__()

        # Validate inputs
        assert len(layer_channels) >= 3 and len(layer_channels) <= 5, "Model supports 3-5 layers"
        assert len(kernel_sizes) == len(layer_channels), "Must provide kernel size for each layer"
        assert len(pool_sizes) == len(layer_channels), "Must provide pool size for each layer"
        assert len(pool_strides) == len(layer_channels), "Must provide pool stride for each layer"
        assert len(thresholds) == len(layer_channels), "Must provide threshold for each layer"

        self.num_layers = len(layer_channels)
        self.layer_channels = layer_channels
        self.kernel_sizes = kernel_sizes
        self.pool_sizes = pool_sizes
        self.pool_strides = pool_strides
        self.thresholds = thresholds

        # Create convolutional layers
        self.conv_layers = nn.ModuleList()

        # First layer
        self.conv_layers.append(SpikingConv2D(in_channels, layer_channels[0], kernel_sizes[0]))

        # Remaining layers
        for i in range(1, self.num_layers):
            self.conv_layers.append(
                SpikingConv2D(layer_channels[i-1], layer_channels[i], kernel_sizes[i])
            )

        # STDP parameters
        self.A_plus = A_plus
        self.A_minus = A_minus
        self.A_plus_r = A_plus_r
        self.A_minus_r = A_minus_r
        self.lb = lb
        self.ub = ub
        self.reward_val = reward_val
        self.punish_val = punish_val

    def forward_inference(self, spike_wave):
        """
        Forward pass for inference through all layers
        """
        x = spike_wave
        potentials = []

        for i in range(self.num_layers):
            # Apply convolution
            pot = self.conv_layers[i](x)
            potentials.append(pot)

            # Convert potentials to spikes (done for every layer, including the last)
            spk = spiking_fire(pot, self.thresholds[i])

            # Apply pooling if needed
            if self.pool_sizes[i] > 1:
                x = spiking_pooling(spk, self.pool_sizes[i], self.pool_strides[i])
            else:
                x = spk

        # Return the final layer potential for classification
        return potentials[-1]

    def forward_learn(self, spike_wave, layer_to_learn):
        """
        Forward pass with learning for a specific layer
        """
        if layer_to_learn < 1 or layer_to_learn > self.num_layers:
            raise ValueError(f"Layer to learn must be between 1 and {self.num_layers}")

        x = spike_wave
        potentials = []

        # Process up to the layer we want to learn
        for i in range(layer_to_learn):
            # Apply convolution
            pot = self.conv_layers[i](x)
            potentials.append(pot)

            # If this is the layer to learn, apply STDP and return
            if i+1 == layer_to_learn:
                winners = get_better_winners(pot, k=20 if i==0 else 30, radius=2 if i==0 else 1)
                self.stdp_update_layer(self.conv_layers[i], x, pot, winners, r_stdp=False)
                return

            # Otherwise continue forward pass
            spk = spiking_fire(pot, self.thresholds[i])

            # Apply pooling if needed
            if self.pool_sizes[i] > 1:
                x = spiking_pooling(spk, self.pool_sizes[i], self.pool_strides[i])
            else:
                x = spk

    def apply_r_stdp(self, spike_wave, label, predicted):
        """
        Apply R-STDP to the final layer based on prediction correctness
        """
        pot_final = self.forward_inference(spike_wave)
        winners = get_k_winners(pot_final, k=1, radius=0)
        rew = self.reward_val if (predicted == label) else self.punish_val

        # Get the input to the final layer
        x = spike_wave
        for i in range(self.num_layers - 1):
            pot = self.conv_layers[i](x)
            spk = spiking_fire(pot, self.thresholds[i])
            if self.pool_sizes[i] > 1:
                x = spiking_pooling(spk, self.pool_sizes[i], self.pool_strides[i])
            else:
                x = spk

        self.stdp_update_layer(self.conv_layers[-1], x, pot_final, winners, r_stdp=True, reward=rew)

    def apply_r_stdp_direct(self, spike_wave, pot_final, winners, reward):
        """
        Direct training of specific neurons with stronger reward
        """
        # Process through all layers except the last one to get the input to the final layer
        x = spike_wave
        for i in range(self.num_layers - 1):
            pot = self.conv_layers[i](x)
            spk = spiking_fire(pot, self.thresholds[i])
            if self.pool_sizes[i] > 1:
                x = spiking_pooling(spk, self.pool_sizes[i], self.pool_strides[i])
            else:
                x = spk

        # Now apply R-STDP to the final layer
        for (fout, rr, cc) in winners:
            # Get pre-synaptic spike times for the specific kernel window
            kH, kW = self.conv_layers[-1].weight.shape[2], self.conv_layers[-1].weight.shape[3]

            # Make sure the window is properly sized
            if rr + kH <= x.shape[2] and cc + kW <= x.shape[3]:
                # Extract exactly the patch that would be used in convolution
                pre_patch = x[:, :, rr:rr+kH, cc:cc+kW]
                T_pre = first_spike_time_from_wave(pre_patch)

                # Use early spike time for the target neuron
                T_post = 5.0 if reward > 0 else 10.0

                # Apply stronger weight updates
                self.conv_layers[-1].weight.data[fout] = r_stdp_update(
                    self.conv_layers[-1].weight.data[fout],
                    T_pre, T_post,
                    self.A_plus_r*2, self.A_minus_r*2,
                    self.lb, self.ub,
                    reward
                )

    def stdp_update_layer(self, conv_layer, input_spike_wave, pot, winners, r_stdp=False, reward=0.0):
        """
        Core function for updating weights with STDP or R-STDP
        """
        W = conv_layer.weight.data
        kH = W.shape[2]
        kW = W.shape[3]

        for (fout, rr, cc) in winners:
            T_post = first_spike_time_from_pot(pot[:, fout, rr, cc])
            if input_spike_wave is not None:
                pre_patch = input_spike_wave[:, :, rr:rr+kH, cc:cc+kW]
                T_pre = first_spike_time_from_wave(pre_patch)
            else:
                T_pre = torch.zeros((W.shape[1], kH, kW), device=W.device)

            if not r_stdp:
                W[fout] = stdp_update(W[fout], T_pre, T_post,
                                     self.A_plus, self.A_minus, self.lb, self.ub)
            else:
                W[fout] = r_stdp_update(W[fout], T_pre, T_post,
                                       self.A_plus_r, self.A_minus_r, self.lb, self.ub, reward)

        conv_layer.weight.data = W

print("Flexible DeepConvSNN class created.")
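
When stacking more layers, it helps to check how quickly the feature map shrinks. The sketch below traces the spatial size through each convolution and pooling stage for a given `kernel_sizes` / `pool_sizes` / `pool_strides` configuration. It rests on assumptions that are not confirmed by the notebook: a 28×28 MNIST input, stride-1 convolutions, unpadded pooling, and a uniform `padding` value for every convolution (the actual padding used by `SpikingConv2D` is defined earlier). The helper name `trace_feature_sizes` is hypothetical.

```python
def trace_feature_sizes(input_size, kernel_sizes, pool_sizes, pool_strides, padding=0):
    """Return the spatial size after each layer's convolution and pooling.

    Assumes stride-1 convolutions with `padding` zeros on each side and
    unpadded pooling. Pooling is skipped when the pool size is 1, mirroring
    the `pool_sizes[i] > 1` check in FlexibleDeepConvSNN.
    """
    sizes = []
    n = input_size
    for k, p, s in zip(kernel_sizes, pool_sizes, pool_strides):
        n = n + 2 * padding - k + 1          # stride-1 convolution
        if n < 1:
            raise ValueError(f"feature map vanished at a convolution with kernel size {k}")
        if p > 1:
            n = (n - p) // s + 1             # pooling without padding
            if n < 1:
                raise ValueError(f"feature map vanished at a pooling stage with size {p}")
        sizes.append(n)
    return sizes


# Default 3-layer configuration of the class, assuming a 28x28 input and no padding.
print(trace_feature_sizes(28, [5, 3, 3], [2, 2, 1], [2, 2, 1]))  # [12, 5, 3]
```

With `padding=0`, the same helper raises for a 4-layer setup such as `kernel_sizes=[5, 3, 3, 3]` with `pool_sizes=[2, 2, 2, 1]`, so whether deeper variants fit a 28×28 input depends on the padding that `SpikingConv2D` actually applies.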

# 13. Creating and Training 4-Layer and 5-Layer Models

In [None]:
# Build 4-layer and 5-layer variants of FlexibleDeepConvSNN

# 4-Layer Model
model_4layer = FlexibleDeepConvSNN(
    in_channels=1,
    layer_channels=[30, 80, 150, 10],  # 4 layers
    kernel_sizes=[5, 3, 3, 3],
    pool_sizes=[2, 2, 2, 1],
    pool_strides=[2, 2, 2, 1],
    thresholds=[5.0, 4.0, 3.0, 2.0],
    a=a_, b=b_, c=c_, d=d_,
    A_plus=0.01, A_minus=-0.008,
    A_plus_r=0.01, A_minus_r=-0.008,
    lb=0.2, ub=0.8,
    reward_val=+1.0, punish_val=-1.0
).to(device)

# 5-Layer Model
model_5layer = FlexibleDeepConvSNN(
    in_channels=1,
    layer_channels=[30, 60, 100, 150, 10],  # 5 layers
    kernel_sizes=[5, 3, 3, 3, 3],
    pool_sizes=[2, 2, 2, 2, 1],
    pool_strides=[2, 2, 2, 2, 1],
    thresholds=[5.0, 4.5, 4.0, 3.5, 3.0],
    a=a_, b=b_, c=c_, d=d_,
    A_plus=0.01, A_minus=-0.008,
    A_plus_r=0.01, A_minus_r=-0.008,
    lb=0.2, ub=0.8,
    reward_val=+1.0, punish_val=-1.0
).to(device)

print("4-layer and 5-layer models created.")

# Training functions for the flexible models
def train_flexible_model(model, train_loader, test_loader, num_layers):
    """
    Train a flexible model with the specified number of layers
    """
    print(f"\n=== Training {num_layers}-layer model ===")

    # Train each layer sequentially
    for layer_idx in range(1, num_layers):
        print(f"=== Training layer {layer_idx} ===")
        train_layer(model, train_loader, layer_idx=layer_idx, epochs=2, max_images=500)

        # Check intermediate accuracy
        if layer_idx > 1:  # Only check after at least 2 layers are trained
            acc = test_accuracy(model, test_loader)
            print(f"Accuracy after training layer {layer_idx}: {acc:.2f}%")

    # Train the final classification layer with R-STDP
    print(f"=== Training final layer (layer {num_layers}) with R-STDP ===")

    # Adapt the existing train_layer3_rstdp function for our flexible model
    model.train()
    for ep in range(2):  # 2 epochs
        correct_during_training = 0
        total_during_training = 0

        for i, (data_spike_wave, target) in enumerate(train_loader):
            if i % 10 == 0:
                print(f"Processing image {i}...")

            data_spike_wave = data_spike_wave.to(device)
            pot_final = model.forward_inference(data_spike_wave.squeeze(0))
            pred_class_ = predict_class(pot_final)
            lbl = target.item()

            # Apply stronger reward/punishment based on correctness
            reward_val = 1.5 if pred_class_ == lbl else -1.5

            # Force activation of the correct class neuron
            winners = [(lbl, 0, 0)]  # Always update the weights for the correct class
            model.apply_r_stdp_direct(data_spike_wave.squeeze(0), pot_final, winners, reward_val)

            # Track accuracy during training
            if pred_class_ == lbl:
                correct_during_training += 1
            total_during_training += 1

            if i % 100 == 99:
                print(f"Current training accuracy: {100*correct_during_training/total_during_training:.2f}%")

            if i >= 500-1:  # Limit to 500 images
                print("Reached maximum images (500), stopping training")
                break

    # Apply mistake correction
    print("=== Applying mistake correction ===")
    mistakes_fixed = train_on_mistakes(model, train_loader, max_mistakes=100)
    print(f"Trained on {mistakes_fixed} misclassified examples")

    # Final accuracy check
    final_acc = test_accuracy(model, test_loader)
    print(f"Final {num_layers}-layer model accuracy: {final_acc:.2f}%")

    return model, final_acc

# Train the 4-layer model
model_4layer, acc_4layer = train_flexible_model(model_4layer, train_loader, test_loader, num_layers=4)

# Train the 5-layer model
model_5layer, acc_5layer = train_flexible_model(model_5layer, train_loader, test_loader, num_layers=5)

# Compare results
print("\n=== Model Comparison ===")
print(f"3-layer model accuracy: {best_accuracy:.2f}%")
print(f"4-layer model accuracy: {acc_4layer:.2f}%")
print(f"5-layer model accuracy: {acc_5layer:.2f}%")

# 14. Enhanced Prediction Function for Deeper Models

In [None]:
# Score each class by combining several spike-based metrics

def enhanced_predict_class(pot_final, confidence=False):
    """
    Enhanced prediction function that works better with deeper models
    Returns class prediction and optionally confidence scores
    """
    T, C, H, W = pot_final.shape

    # Calculate multiple metrics for each class
    class_metrics = {
        'spike_count': torch.zeros(C, device=pot_final.device),
        'earliest_spike': torch.full((C,), float(T), device=pot_final.device),  # float fill keeps dtype consistent with the other metrics
        'max_potential': torch.zeros(C, device=pot_final.device),
        'spatial_spread': torch.zeros(C, device=pot_final.device)
    }

    # Calculate metrics for each class
    for f in range(C):
        channel_pot = pot_final[:, f]

        # Spike count (number of time/location entries with positive potential)
        spike_mask = (channel_pot > 0).float()
        class_metrics['spike_count'][f] = spike_mask.sum()

        # Earliest spike time
        for t in range(T):
            if spike_mask[t].sum() > 0:
                class_metrics['earliest_spike'][f] = t
                break

        # Maximum potential reached
        class_metrics['max_potential'][f] = channel_pot.max()

        # Spatial spread (number of distinct spatial locations that spiked at any time step)
        spatial_locations = (spike_mask.reshape(T, -1).sum(dim=0) > 0).float().sum()
        class_metrics['spatial_spread'][f] = spatial_locations

    # Combine metrics into a single score
    # Earlier spikes, more spikes, higher potentials, and wider spread are better
    combined_score = (
        class_metrics['spike_count'] *
        (T - class_metrics['earliest_spike']) *
        class_metrics['max_potential'] *
        (1 + 0.1 * class_metrics['spatial_spread'])
    )

    # Get prediction and confidence
    pred_class = torch.argmax(combined_score).item()

    if confidence:
        # Calculate confidence as normalized score
        confidence_score = F.softmax(combined_score, dim=0)
        return pred_class, confidence_score
    else:
        return pred_class

# Test the enhanced prediction function
def test_with_enhanced_prediction(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (data_spike_wave, target) in enumerate(loader):
            data_spike_wave = data_spike_wave.to(device)
            pot_final = model.forward_inference(data_spike_wave.squeeze(0))
            pred_class_ = enhanced_predict_class(pot_final)
            if pred_class_ == target.item():
                correct += 1
            total += 1
    return 100.0 * correct / total

# Compare prediction methods on all models
print("\n=== Comparing Prediction Methods ===")
print("Original prediction method:")
print(f"3-layer model: {test_accuracy(model, test_loader):.2f}%")
print(f"4-layer model: {test_accuracy(model_4layer, test_loader):.2f}%")
print(f"5-layer model: {test_accuracy(model_5layer, test_loader):.2f}%")

print("\nEnhanced prediction method:")
print(f"3-layer model: {test_with_enhanced_prediction(model, test_loader):.2f}%")
print(f"4-layer model: {test_with_enhanced_prediction(model_4layer, test_loader):.2f}%")
print(f"5-layer model: {test_with_enhanced_prediction(model_5layer, test_loader):.2f}%")
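
The four per-class metrics above can also be computed without the Python loops over classes and time steps. The following is a minimal sketch, assuming `pot_final` has shape `[T, C, H, W]` and that a positive potential counts as a spike, as in `enhanced_predict_class`. The function name `class_metrics_vectorized` is hypothetical, and "spatial spread" is taken here to mean the number of distinct locations that spiked at least once.

```python
import torch


def class_metrics_vectorized(pot_final):
    """Compute spike count, earliest spike time, max potential and spatial spread per class."""
    T, C, H, W = pot_final.shape
    active = pot_final > 0                                   # [T, C, H, W] boolean spike mask

    spike_count = active.sum(dim=(0, 2, 3)).float()          # [C]

    # Earliest time step with any spike per class; T if the class never spikes.
    any_at_t = active.flatten(2).any(dim=2)                  # [T, C]
    t_idx = torch.arange(T, device=pot_final.device).unsqueeze(1).expand(T, C)
    earliest = torch.where(any_at_t, t_idx, torch.full_like(t_idx, T)).min(dim=0).values.float()

    max_potential = pot_final.amax(dim=(0, 2, 3))            # [C]

    # Number of distinct spatial locations that spiked at any time step.
    spatial_spread = active.any(dim=0).flatten(1).sum(dim=1).float()

    return spike_count, earliest, max_potential, spatial_spread


# Tiny example: 3 time steps, 2 classes, 2x2 map; only class 0 spikes.
pot = torch.zeros(3, 2, 2, 2)
pot[1, 0, 0, 0] = 2.0
pot[2, 0, 0, 0] = 3.0
pot[2, 0, 1, 1] = 1.0
count, earliest, max_pot, spread = class_metrics_vectorized(pot)
print(count.tolist(), earliest.tolist(), max_pot.tolist(), spread.tolist())
```

In this example class 0 has three active entries but only two distinct active locations, which shows why spike count and spatial spread are kept as separate metrics.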