In [None]:
# Import core libraries
import tensorflow as tf  # Main deep learning framework
from tensorflow import keras  # High-level neural network API
import numpy as np  # Numerical computing library
import matplotlib.pyplot as plt  # Plotting library for visualization

In [None]:
# Cell 2: Define LIF Layer with Surrogate Gradient (FIXED)
@tf.custom_gradient
def spike_function(v, threshold):
    """
    Heaviside spike nonlinearity with a fast-sigmoid surrogate gradient.

    Forward:  returns 1.0 where v > threshold and 0.0 elsewhere (float32).
    Backward: the true derivative (zero almost everywhere) is replaced by
              1 / (1 + |alpha * (v - threshold)|)^2 with alpha = 10, so
              gradients can flow for potentials near the threshold.
              No gradient is propagated to `threshold`.
    """

    def grad(upstream):
        """Chain the upstream gradient through the fast-sigmoid derivative."""
        steepness = 10.0  # larger -> narrower, taller surrogate around the threshold
        distance = tf.abs(steepness * (v - threshold))
        # Gradient w.r.t. v, and None for the (constant) threshold.
        return upstream * (1.0 / (1.0 + distance) ** 2), None

    fired = tf.greater(v, threshold)
    return tf.cast(fired, tf.float32), grad


class LIFLayer(keras.layers.Layer):
    """
    Leaky Integrate-and-Fire (LIF) neuron layer trained with a surrogate gradient.

    For every time step t the layer computes
        v <- v * (1 - 1/tau) + x_t @ W + b
        s <- spike_function(v, threshold)   # 1.0 where v > threshold
        v <- v * (1 - s)                    # reset neurons that spiked to 0
    and emits s. `spike_function` uses a hard threshold in the forward pass and a
    smooth surrogate derivative in the backward pass, so the layer is trainable
    with ordinary backpropagation.

    Input:  [batch, time_steps, input_dim]
    Output: [batch, time_steps, units] float32 0/1 spikes
    """

    def __init__(self, units, output_units=None, tau=20.0, threshold=1.0, use_dfa=True, is_output_layer=False, **kwargs):
        """
        Initialize the LIF layer.

        Args:
            units: Number of neurons in this layer (must be >= 1).
            output_units: Number of output neurons of the network. Stored and
                serialized, but not used by this layer's computation.
            tau: Membrane time constant (must be > 0). The per-step leak factor
                is 1 - 1/tau: larger tau means slower leak, and tau < 1 makes
                the factor negative.
            threshold: Membrane potential above which a neuron spikes.
            use_dfa: Whether to use DFA instead of normal backprop. Stored only;
                the layer currently always trains through the surrogate gradient.
            is_output_layer: Whether this is the network's output layer. Stored only.
            **kwargs: Forwarded to keras.layers.Layer (e.g. `name`, `dtype`).

        Raises:
            ValueError: If `units` < 1 or `tau` <= 0.
        """
        # Fail fast: tau == 0 would otherwise raise ZeroDivisionError deep inside
        # call(), and tau < 0 turns the leak into exponential growth.
        if units < 1:
            raise ValueError(f"units must be >= 1, got {units}")
        if tau <= 0:
            raise ValueError(f"tau must be > 0, got {tau}")
        super(LIFLayer, self).__init__(**kwargs)
        self.units = units  # Number of neurons
        self.output_units = output_units  # Number of output neurons (unused here)
        self.tau = tau  # Membrane time constant
        self.threshold = threshold  # Spike threshold
        self.use_dfa = use_dfa  # DFA flag (unused here)
        self.is_output_layer = is_output_layer  # Output-layer flag (unused here)

    def build(self, input_shape):
        """
        Create the trainable parameters once the input feature size is known.

        Creates W with shape [input_dim, units] (Glorot-uniform) and b with
        shape [units] (zeros).
        """
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='glorot_uniform',  # Xavier uniform keeps activations well-scaled
            trainable=True,
            name='weights'
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer='zeros',
            trainable=True,
            name='bias'
        )

    def compute_output_shape(self, input_shape):
        """
        Return [batch, time_steps, units]: batch and time axes pass through
        unchanged, the feature axis becomes the number of neurons.
        """
        return (input_shape[0], input_shape[1], self.units)

    def get_config(self):
        """
        Return the constructor arguments so the layer can be rebuilt via
        `from_config`, which `model.save` / `keras.models.load_model` need.
        When loading, pass `custom_objects={"LIFLayer": LIFLayer}`.
        """
        config = super(LIFLayer, self).get_config()
        config.update({
            "units": self.units,
            "output_units": self.output_units,
            "tau": self.tau,
            "threshold": self.threshold,
            "use_dfa": self.use_dfa,
            "is_output_layer": self.is_output_layer,
        })
        return config

    def call(self, inputs, training=None):
        """
        Simulate the LIF dynamics over all time steps.

        Args:
            inputs: Tensor of shape [batch, time_steps, input_dim].
            training: Unused; accepted for Keras API compatibility.

        Returns:
            Spike tensor of shape [batch, time_steps, units].
        """
        batch_size = tf.shape(inputs)[0]
        time_steps = tf.shape(inputs)[1]  # dynamic, so variable-length input also works

        # Leak factor is constant across steps; compute it once.
        decay = 1.0 - 1.0 / self.tau

        # Membrane potential starts at rest (zero) for every neuron.
        v = tf.zeros((batch_size, self.units))

        # TensorArray collects per-step spikes inside the symbolic while_loop.
        spikes_array = tf.TensorArray(
            dtype=tf.float32,
            size=time_steps,
            dynamic_size=False
        )

        def time_step(t, v, spikes_array):
            """Integrate one step of input, emit spikes, reset spiking neurons."""
            x_t = inputs[:, t, :]  # [batch, input_dim]
            i_in = tf.matmul(x_t, self.w) + self.b  # [batch, units]

            # Euler step of dv/dt = -v/tau + I: v = v * (1 - 1/tau) + I
            v = v * decay + i_in

            # Hard threshold forward, surrogate gradient backward.
            spike = spike_function(v, self.threshold)

            # Reset-to-zero for neurons that just spiked.
            v = v * (1.0 - spike)

            spikes_array = spikes_array.write(t, spike)
            return t + 1, v, spikes_array

        _, _, spikes_array = tf.while_loop(
            cond=lambda t, *_: t < time_steps,
            body=time_step,
            loop_vars=[0, v, spikes_array]
        )

        # stack() gives [time_steps, batch, units]; move batch first.
        output = spikes_array.stack()
        output = tf.transpose(output, [1, 0, 2])
        return output

In [None]:
# Cell 3: Build 2-Layer SNN Model
def create_snn_model(input_dim, hidden_units, output_units, time_steps, tau=20.0, threshold=1.0):
    """
    Create a simple 2-layer dense SNN (LIF hidden layer -> LIF output layer).

    Args:
        input_dim: Dimension of input features per time step.
        hidden_units: Number of neurons in the hidden LIF layer.
        output_units: Number of output neurons (classes).
        time_steps: Number of simulation time steps in each input sample.
        tau: Membrane time constant shared by both LIF layers (default 20.0).
        threshold: Spike threshold shared by both LIF layers (default 1.0).

    Returns:
        keras.Model mapping [batch, time_steps, input_dim] spike trains to
        [batch, time_steps, output_units] output spike trains.
    """
    inputs = keras.Input(shape=(time_steps, input_dim), name='input')

    # Hidden LIF layer: integrates the rate-coded input spikes over time.
    hidden = LIFLayer(
        units=hidden_units,
        tau=tau,
        threshold=threshold,
        name='hidden_layer'
    )(inputs)

    # Output LIF layer: one neuron per class; spike counts are the readout.
    outputs = LIFLayer(
        units=output_units,
        tau=tau,
        threshold=threshold,
        name='output_layer'
    )(hidden)

    return keras.Model(inputs=inputs, outputs=outputs, name='2layer_snn')

In [None]:
# Cell 4: Create Rate Coding Function
def rate_encode(data, time_steps, max_rate=1.0, rng=None):
    """
    Convert static data to spike trains using Bernoulli rate coding.

    At every time step each feature fires independently with probability
    clip(value, 0, 1) * max_rate, so larger input values spike more often.

    Args:
        data: Array-like of shape [batch, features]. Values are clipped to
            [0, 1] (not rescaled), so inputs should already be normalized.
        time_steps: Number of time steps to encode.
        max_rate: Per-step firing probability for an input value of 1.0.
            Values above 1 saturate (every step fires).
        rng: Optional numpy.random.Generator for reproducible encoding. When
            None, the global legacy np.random state is used (original behavior).

    Returns:
        float32 array of shape [batch, time_steps, features] containing 0/1 spikes.

    Raises:
        ValueError: If `data` is not 2-D.
    """
    # asarray first so pandas/list inputs behave like ndarrays after clipping.
    data = np.clip(np.asarray(data), 0, 1)
    if data.ndim != 2:
        raise ValueError(f"data must be 2-D [batch, features], got shape {data.shape}")

    shape = (data.shape[0], time_steps, data.shape[1])
    rand = np.random.rand(*shape) if rng is None else rng.random(shape)

    # Broadcast per-feature probabilities across the time axis; a spike occurs
    # wherever the uniform draw falls below the firing probability.
    spikes = (rand < data[:, np.newaxis, :] * max_rate).astype(np.float32)

    return spikes

In [None]:
# Cell 5: Load and Prepare MNIST Dataset
# Load MNIST handwritten digit dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten each 28x28 image into a 784-vector and scale intensities from
# [0, 255] into [0, 1] (train: [60000, 784], test: [10000, 784]).
x_train, x_test = (
    images.reshape(-1, 784).astype('float32') / 255.0
    for images in (x_train, x_test)
)

# One-hot encode the digit labels into 10 classes.
y_train, y_test = (
    keras.utils.to_categorical(labels, 10)
    for labels in (y_train, y_test)
)

# Keep only the first 1000 training examples so the SNN trains quickly.
num_samples = 1000
x_train, y_train = x_train[:num_samples], y_train[:num_samples]

print(f"Training data shape: {x_train.shape}")
print(f"Training labels shape: {y_train.shape}")

Training data shape: (1000, 784)
Training labels shape: (1000, 10)


In [None]:
# Cell 6: Encode Data as Spike Trains
time_steps = 20  # Number of discrete simulation steps per sample (abstract steps, not seconds)

# Seed Python, NumPy and the Keras/TensorFlow backend so the stochastic rate
# encoding below, and the weight initialization / batch shuffling in later
# cells, are repeatable on Restart & Run All (GPU kernels may still add
# nondeterminism).
SEED = 42
keras.utils.set_random_seed(SEED)

# Rate coding: each pixel becomes a random spike train whose average firing
# rate over the time window encodes its intensity.
x_train_spikes = rate_encode(x_train, time_steps, max_rate=0.8)
print(f"Spike train shape: {x_train_spikes.shape}")  # Should be [1000, 20, 784]

# NOTE: the full 10,000-image test set as float32 spikes takes
# 10000 * 20 * 784 * 4 bytes ≈ 627 MB; subset x_test first if memory is tight.
x_test_spikes = rate_encode(x_test, time_steps, max_rate=0.8)
print(f"Test spike train shape: {x_test_spikes.shape}")  # Should be [10000, 20, 784]

Spike train shape: (1000, 20, 784)
Test spike train shape: (10000, 20, 784)


In [None]:
# Cell 7: Create and Compile Model with Surrogate Gradient Training
# Create SNN with specified architecture
model = create_snn_model(
    time_steps=time_steps,  # simulation length; must match the encoded spike trains
    output_units=10,        # one output neuron per digit class 0-9
    hidden_units=128,       # size of the spiking hidden layer
    input_dim=784,          # flattened 28x28 MNIST image
)

# Layer-by-layer output shapes and parameter counts.
model.summary()

# Define custom loss function for spike-based classification
def spike_categorical_crossentropy(y_true, y_pred):
    """
    Cross-entropy between one-hot labels and softmax(spike counts).

    The output spike trains are summed over time (axis 1) and the counts are
    used directly as logits. Passing them with `from_logits=True` lets Keras
    compute a numerically stable log-softmax. The previous softmax-then-
    crossentropy path clipped probabilities at epsilon, which zeroes the
    gradient for confidently-wrong samples; count gaps of up to `time_steps`
    make that clipping reachable.

    Args:
        y_true: One-hot labels, [batch, num_classes].
        y_pred: Output spike trains, [batch, time_steps, num_classes].

    Returns:
        Per-sample loss tensor of shape [batch]; Keras reduces it to a scalar.
    """
    spike_counts = tf.reduce_sum(y_pred, axis=1)  # [batch, num_classes], used as logits
    return keras.losses.categorical_crossentropy(y_true, spike_counts, from_logits=True)

# Define custom accuracy metric for spike-based classification
def spike_categorical_accuracy(y_true, y_pred):
    """
    Per-sample accuracy of the spike-count readout.

    The predicted class is the output neuron with the most spikes summed over
    the time axis (axis 1). Ties, including samples where no output neuron
    fires at all, are resolved by `tf.argmax`.

    Returns:
        float32 tensor of shape [batch]: 1.0 where the predicted class matches
        the one-hot label, else 0.0.
    """
    predicted = tf.argmax(tf.reduce_sum(y_pred, axis=1), axis=-1)
    target = tf.argmax(y_true, axis=-1)
    return tf.cast(tf.equal(predicted, target), tf.float32)

# Compile model with optimizer, loss, and surrogate gradient training
model.compile(
    loss=spike_categorical_crossentropy,          # loss on time-summed output spikes
    metrics=[spike_categorical_accuracy],         # argmax of time-summed output spikes
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
)

# Summarize the training setup (spike_function: hard threshold forward,
# fast-sigmoid surrogate with alpha = 10 backward).
print("\n✓ Model compiled with surrogate gradient training!")
print("Forward pass: Hard spike threshold")
print("Backward pass: Smooth fast sigmoid gradient (alpha=10)")


✓ Model compiled with surrogate gradient training!
Forward pass: Hard spike threshold
Backward pass: Smooth fast sigmoid gradient (alpha=10)


In [None]:
# Cell 8: Train the SNN
# Train model on spike-encoded MNIST data
history = model.fit(
    x=x_train_spikes,       # rate-coded inputs, [1000, 20, 784]
    y=y_train,              # one-hot digit labels, [1000, 10]
    epochs=10,              # full passes over the training split
    batch_size=32,          # samples per gradient update
    validation_split=0.2,   # hold out 20% of the samples for validation
    verbose=1,              # progress bar per epoch
)

Epoch 1/10
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 85ms/step - loss: 2.5630 - spike_categorical_accuracy: 0.2765 - val_loss: 1.1714 - val_spike_categorical_accuracy: 0.6600
Epoch 2/10
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 38ms/step - loss: 0.9031 - spike_categorical_accuracy: 0.7070 - val_loss: 0.9908 - val_spike_categorical_accuracy: 0.7250
Epoch 3/10
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 69ms/step - loss: 0.5991 - spike_categorical_accuracy: 0.7900 - val_loss: 0.7237 - val_spike_categorical_accuracy: 0.7900
Epoch 4/10
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 58ms/step - loss: 0.3048 - spike_categorical_accuracy: 0.9124 - val_loss: 0.7231 - val_spike_categorical_accuracy: 0.8100
Epoch 5/10
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 45ms/step - loss: 0.1242 - spike_categorical_accuracy: 0.9673 - val_loss: 0.6850 - val_spike_categorical_accuracy: 0.8000
Epoch 6/10
[1m