
# Spiking Neural Network Components and Their Implementation Using PyTorch
Spiking Neural Networks (SNNs) are a biologically inspired model of neural networks where information is transmitted via discrete spikes rather than continuous values, mimicking how neurons in the brain communicate. This makes them suitable for neuromorphic computing and energy-efficient AI systems. The key components of SNNs include encoding methods, neuron models, and learning algorithms, which can all be implemented using PyTorch. Here's an overview of these components and their implementation:

**Encoding:** BSA (Ben's Spiker Algorithm)
The encoding mechanism in SNNs converts continuous input data into spike trains. Ben's Spiker Algorithm (BSA) does this with a finite impulse response (FIR) filter: it emits a spike wherever subtracting the filter from the upcoming stretch of the signal reduces the error by at least a threshold, so the signal can later be approximately recovered by filtering the spike train. The resulting spikes are then processed by the spiking neurons.

**Neuron Model:** Leaky Integrate-and-Fire (LIF)
The Leaky Integrate-and-Fire (LIF) neuron model is one of the most widely used in SNNs. It integrates incoming input over time, and once the membrane potential crosses a threshold, it emits a spike and the potential is reset. Between inputs, the potential "leaks" back toward its resting value, so inputs that arrive far apart in time do not accumulate indefinitely. Implementing an LIF model in PyTorch involves simulating this integration, leak, and spike-and-reset behavior with PyTorch tensors and operations.

**Learning Algorithm:** Spike-Timing-Dependent Plasticity (STDP)
Spike-Timing-Dependent Plasticity (STDP) is a biologically inspired learning rule used in SNNs. It adjusts synaptic weights based on the relative timing of spikes from the pre- and post-synaptic neurons. If a pre-synaptic neuron fires shortly before a post-synaptic neuron, the synapse is strengthened (long-term potentiation); if the pre-synaptic neuron fires after the post-synaptic neuron, the synapse is weakened (long-term depression). This timing-based learning supports pattern recognition in SNNs and can be implemented in PyTorch by updating synaptic weights as a function of the time difference between pre- and post-synaptic spikes.

## Overview of Implementation Steps:
**Encoding using BSA:**

Create a function to encode continuous inputs into spikes based on intensity and timing using PyTorch.

**Neuron Model (LIF):**

Define the LIF model in PyTorch, including the membrane potential, threshold, and leakage over time.
Simulate spike generation and potential reset upon reaching the threshold.

**Learning with STDP:**

Implement the STDP rule to adjust synaptic weights by calculating the time difference between pre- and post-synaptic spikes.
Update the weights directly with tensor operations; because STDP is a local, unsupervised rule, it does not need autograd or backpropagation.
This framework will serve as the foundation for designing SNNs in PyTorch, enabling you to simulate biologically inspired networks for energy-efficient AI tasks.


# Encoder Algorithm
Ben's Spiker Algorithm (BSA) is an encoding method for converting analog input signals into binary spike trains. This is particularly useful in Spiking Neural Networks (SNNs), which operate on discrete spikes instead of continuous signals. BSA aims to produce a spike train that captures both the temporal structure and the magnitude of the signal.

The core idea behind BSA is to slide a predefined FIR filter along the input signal. At each position, the algorithm compares two errors: the error that would remain if the filter were subtracted from the upcoming samples, and the error if nothing were subtracted. If subtracting the filter lowers the error by at least a threshold, a spike is emitted and the filter is subtracted from the signal; otherwise the signal is left unchanged.
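Written out, the test performed at each position $i$ of the `BSA_encoding` function below is as follows, where $s$ is the normalized signal, $h$ the filter of length $L$, $\theta$ the threshold, and $b$ the output spike train (this notation is introduced here for illustration):

$$
e_1(i) = \sum_{j=0}^{L-1} \bigl| s[i+j] - h[j] \bigr|, \qquad
e_2(i) = \sum_{j=0}^{L-1} \bigl| s[i+j] \bigr|
$$

$$
b[i] =
\begin{cases}
1, \ \text{then } s[i+j] \leftarrow s[i+j] - h[j] \ \ (j = 0,\dots,L-1), & e_1(i) \le e_2(i) - \theta \\
0, & \text{otherwise}
\end{cases}
$$

Because each spike accounts for one copy of the filter, the signal can be approximately reconstructed by convolving the spike train with the same filter: $\hat{s}[n] = \sum_{i} b[i]\, h[n-i]$, with $h[k] = 0$ outside $0 \le k < L$.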



In [None]:

import torch
import matplotlib.pyplot as plt

# Define the BSA encoding function
def BSA_encoding(input_data, filter_type=1, threshold=0.1):
    """Encode each column of a (time, channels) tensor into a binary spike train.

    Returns a list with one (time, 1) spike tensor per input column.
    """
    spike_trains = []

    # Define the FIR filter based on filter_type
    if filter_type == 1:
        Filter = torch.tensor([0.8, 0.2])
    elif filter_type == 2:
        Filter = torch.tensor([0.05] * 6)
    elif filter_type == 3:
        Filter = torch.tensor([0.8] * 6)
    else:
        raise ValueError(f"Unknown filter_type: {filter_type}")

    filter_length = len(Filter)
    num_samples = input_data.shape[1]

    # Encode each input column (channel) independently
    for s in range(num_samples):
        sample = input_data[:, s].unsqueeze(1).float()
        timelength, inputnum = sample.shape  # inputnum is 1 for a single column
        total_error = torch.zeros(inputnum)
        SpikeTrain_temp = torch.zeros(timelength, inputnum)

        # Signal normalization to [0, 1]; clamp avoids 0/0 for a constant signal
        min_data = sample.min(0, keepdim=True).values
        max_data = sample.max(0, keepdim=True).values
        EncodingSignal = (sample - min_data) / (max_data - min_data).clamp(min=1e-12)

        # Per-channel threshold
        BSA_Threshold_rowVector = torch.ones(inputnum) * threshold

        # Perform BSA encoding
        for f in range(inputnum):
            for i in range(timelength - filter_length + 1):
                window = EncodingSignal[i:i + filter_length, f]
                error1 = torch.abs(window - Filter).sum()  # error if we spike here
                error2 = torch.abs(window).sum()           # error if we do not

                if error1 <= (error2 - BSA_Threshold_rowVector[f]):  # Spike criterion
                    SpikeTrain_temp[i, f] = 1
                    # Subtract the filter so later positions see the residual
                    EncodingSignal[i:i + filter_length, f] -= Filter
                    total_error[f] += error1
                else:
                    total_error[f] += error2

        spike_trains.append(SpikeTrain_temp)

    return spike_trains

# Create a Poisson 2D tensor as input data
def generate_poisson_tensor(shape, rate):
    return torch.poisson(torch.full(shape, rate, dtype=torch.float))

# Plot the Poisson signal and the encoded spike train
def plot_poisson_and_encoded(poisson_tensor, spike_trains):
    fig, axes = plt.subplots(2, 1, figsize=(10, 6))

    # Plot original Poisson signal
    axes[0].plot(poisson_tensor[:, 0].numpy(), color='blue')
    axes[0].set_title('Original Poisson Signal')

    # Plot spike trains
    axes[1].stem(spike_trains[0][:, 0].numpy(), linefmt='r-', markerfmt='ro', basefmt=" ")
    axes[1].set_title('Spike Train (BSA Encoding)')

    plt.tight_layout()
    plt.show()

# Generate Poisson input data
input_data = generate_poisson_tensor((100, 1), rate=10)

# Perform BSA encoding
spike_trains = BSA_encoding(input_data, filter_type=1, threshold=0.1)

# Plot results
plot_poisson_and_encoded(input_data, spike_trains)
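The cell above only encodes. To check how much of the signal a BSA spike train retains, it helps to decode it again. The sketch below is a plain-Python, single-channel version of the same spike test plus a decoder that convolves the spike train with the filter; `bsa_encode` and `bsa_decode` are hypothetical helper names, and the input is assumed to be already normalized to [0, 1].

```python
import math

def bsa_encode(signal, fir, threshold):
    """Encode one channel, already normalized to [0, 1], into a 0/1 spike list."""
    s = list(signal)          # work on a copy; the caller's list is not modified
    L = len(fir)
    spikes = [0] * len(s)
    for i in range(len(s) - L + 1):
        err_spike = sum(abs(s[i + j] - fir[j]) for j in range(L))  # error if we spike
        err_none = sum(abs(s[i + j]) for j in range(L))            # error if we do not
        if err_spike <= err_none - threshold:
            spikes[i] = 1
            for j in range(L):
                s[i + j] -= fir[j]  # remove this spike's contribution
    return spikes

def bsa_decode(spikes, fir):
    """Reconstruct the signal by convolving the spike train with the same FIR filter."""
    out = [0.0] * len(spikes)
    for i, b in enumerate(spikes):
        if b:
            for j, h in enumerate(fir):
                if i + j < len(out):
                    out[i + j] += h
    return out

fir = [0.8, 0.2]  # same values as filter_type=1
signal = [0.5 + 0.5 * math.sin(2 * math.pi * n / 50) for n in range(100)]
spikes = bsa_encode(signal, fir, threshold=0.1)
recon = bsa_decode(spikes, fir)
mae = sum(abs(a - b) for a, b in zip(signal, recon)) / len(signal)
print(f"{sum(spikes)} spikes, mean absolute reconstruction error {mae:.3f}")
```

A signal that is exactly a sum of shifted filter copies, such as `[0.8, 1.0, 0.2, 0.0]` (spikes at positions 0 and 1), is recovered exactly; smoother signals are only approximated, and the quality depends on the filter shape and threshold.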


# Neuron Model
The Leaky Integrate-and-Fire (LIF) neuron model is a simple yet widely used model in Spiking Neural Networks (SNNs). It mimics the behavior of biological neurons by integrating incoming input signals (synaptic currents) over time.
## LIF Characteristics:

**Integration:**
The neuron continuously integrates the incoming input (current) into its membrane potential. The potential increases as it receives inputs.

**Leak:**
Over time, the membrane potential gradually "leaks" back toward a resting potential if there are no inputs, simulating the natural decay of potential in biological neurons.

**Firing (Spike Generation):**
Once the membrane potential reaches a certain threshold, the neuron fires, generating a spike (action potential).
After firing, the membrane potential is reset to its resting state, and the process starts again.
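In equation form, using the parameter names of the `LIFNeuron` class below (`tau_m` as $\tau_m$, `r_m` as $R_m$, `v_reset` as $V_{\text{reset}}$, `v_threshold` as $V_{\text{th}}$, `dt` as $\Delta t$), and noting that the code uses $V_{\text{reset}}$ as both the resting and the reset potential, the model and its forward-Euler update are:

$$
\tau_m \frac{dV}{dt} = -\bigl(V - V_{\text{reset}}\bigr) + R_m I(t)
$$

$$
V[t+1] = V[t] + \frac{\Delta t}{\tau_m}\Bigl(-\bigl(V[t] - V_{\text{reset}}\bigr) + R_m I[t]\Bigr),
\qquad \text{if } V[t+1] \ge V_{\text{th}}:\ \text{spike, then } V[t+1] \leftarrow V_{\text{reset}}
$$

For a constant current, the continuous-time model climbs from reset to threshold in

$$
T = \tau_m \ln \frac{R_m I}{R_m I - (V_{\text{th}} - V_{\text{reset}})}, \qquad R_m I > V_{\text{th}} - V_{\text{reset}},
$$

and never fires if $R_m I \le V_{\text{th}} - V_{\text{reset}}$. With the defaults used below ($\tau_m = 20$ ms, $R_m = 1$, $V_{\text{reset}} = 0$, $V_{\text{th}} = 1$) and $I = 1.2$, this gives $T = 20 \ln 6 \approx 35.8$ ms.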

In [None]:
import torch
import matplotlib.pyplot as plt

# Define the Leaky Integrate-and-Fire (LIF) neuron model
class LIFNeuron:
    def __init__(self, tau_m=20.0, v_reset=0.0, v_threshold=1.0, dt=1.0, r_m=1.0):
        self.tau_m = tau_m           # Membrane time constant (ms)
        self.v_reset = v_reset       # Reset membrane potential (after spike)
        self.v_threshold = v_threshold # Firing threshold
        self.dt = dt                 # Time step (ms)
        self.r_m = r_m               # Membrane resistance
        self.v_membrane = v_reset    # Initial membrane potential

    def reset(self):
        self.v_membrane = self.v_reset

    def simulate(self, I, time_steps):
        v_trace = []  # To store the membrane potential over time
        spikes = []   # To store the spike events

        for t in range(time_steps):
            current = float(I[t])  # use a Python float so v_membrane stays a float
            dv = (-(self.v_membrane - self.v_reset) + self.r_m * current) / self.tau_m
            self.v_membrane += dv * self.dt

            if self.v_membrane >= self.v_threshold:
                spikes.append(1)
                self.v_membrane = self.v_reset  # Reset the potential after spike
            else:
                spikes.append(0)

            v_trace.append(self.v_membrane)  # recorded after any reset

        return torch.tensor(v_trace), torch.tensor(spikes)

# Function to simulate LIF neuron for both constant and varying input currents
def simulate_lif():
    time_steps = 100
    dt = 1.0

    # Create LIF neuron
    neuron = LIFNeuron(dt=dt)

    # Constant current input
    constant_current = torch.full((time_steps,), 1.2)  # Constant current
    v_const, spikes_const = neuron.simulate(constant_current, time_steps)

    # Reset neuron
    neuron.reset()

    # Varying current input (sinusoidal pattern)
    t = torch.arange(0, time_steps, dt)
    varying_current = 1.5 * torch.sin(0.1 * t) + 1.0  # Varying current
    v_var, spikes_var = neuron.simulate(varying_current, time_steps)

    # Plot input currents and membrane potentials (each spike shows up as a drop to v_reset)
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Plot constant current
    axes[0, 0].plot(t.numpy(), constant_current.numpy(), color='blue')
    axes[0, 0].set_title('Constant Current Input')

    axes[0, 1].plot(t.numpy(), v_const.numpy(), color='red')
    axes[0, 1].set_title('Membrane Potential (Constant Current)')

    # Plot varying current
    axes[1, 0].plot(t.numpy(), varying_current.numpy(), color='green')
    axes[1, 0].set_title('Varying Current Input')

    axes[1, 1].plot(t.numpy(), v_var.numpy(), color='orange')
    axes[1, 1].set_title('Membrane Potential (Varying Current)')

    plt.tight_layout()
    plt.show()

# Run the simulation and plot results
simulate_lif()
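As a quick sanity check on the simulation above, the sketch below repeats the same forward-Euler update in plain Python (no PyTorch) for a constant input and compares the resulting spike times with the continuous-time threshold-crossing time. The helper names `lif_spike_times` and `analytic_period` are hypothetical; the default parameters are copied from `LIFNeuron`.

```python
import math

def lif_spike_times(current, steps, tau_m=20.0, v_reset=0.0, v_threshold=1.0, dt=1.0, r_m=1.0):
    """Forward-Euler LIF with a constant input current; returns spike time indices."""
    v = v_reset
    spike_times = []
    for t in range(steps):
        v += dt * (-(v - v_reset) + r_m * current) / tau_m  # same update as LIFNeuron
        if v >= v_threshold:
            spike_times.append(t)
            v = v_reset
    return spike_times

def analytic_period(current, tau_m=20.0, v_reset=0.0, v_threshold=1.0, r_m=1.0):
    """Continuous-time time from reset to threshold; math.inf if the neuron never fires."""
    drive = r_m * current
    gap = v_threshold - v_reset
    if drive <= gap:
        return math.inf
    return tau_m * math.log(drive / (drive - gap))

print(lif_spike_times(1.2, 100))        # [34, 69]
print(round(analytic_period(1.2), 2))   # 35.84
```

With `dt = 1.0` the Euler simulation needs 35 updates to reach threshold (spikes at indices 34 and 69), close to the continuous-time value of about 35.8 ms; the gap comes from the discretization and from spikes being registered only on whole time steps. A current of 1.0 sits exactly at the firing boundary, so `lif_spike_times(1.0, 100)` returns no spikes.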


# Learning Mechanism: Spike-Timing-Dependent Plasticity (STDP)

STDP is a biological learning rule used in spiking neural networks (SNNs) that adjusts the strength of synapses (connections between neurons) based on the precise timing of spikes.

## STDP Principles:
**Causal Relationship (LTP - Long-Term Potentiation):**
If a presynaptic neuron (sending neuron) fires shortly before a postsynaptic neuron (receiving neuron), the synapse is strengthened. This encourages the connection between neurons, increasing synaptic weight.

**Anti-Causal Relationship (LTD - Long-Term Depression):**
If the presynaptic neuron fires after the postsynaptic neuron, the synapse is weakened, decreasing synaptic weight.

**Time-Dependent Rule:**
The amount of change in synaptic strength depends on the time difference between the pre- and postsynaptic spikes. The closer in time the spikes are, the greater the change.
STDP allows SNNs to learn temporal patterns and correlations in spiking activity, making it an important mechanism for synaptic plasticity and learning in neuromorphic computing and brain-inspired systems.
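The rule can be written as a piecewise exponential window. The symbols below match the parameter names in the code that follows (`A_plus`, `A_minus`, `tau_plus`, `tau_minus`, `w_min`, `w_max`), with $\Delta t = t_{\text{post}} - t_{\text{pre}}$ so that positive values mean the presynaptic spike came first; assigning $\Delta t = 0$ to the depression branch follows that code:

$$
\Delta w(\Delta t) =
\begin{cases}
A_{+}\, e^{-\Delta t/\tau_{+}}, & \Delta t > 0 \quad \text{(LTP)}\\
-A_{-}\, e^{\Delta t/\tau_{-}}, & \Delta t \le 0 \quad \text{(LTD)}
\end{cases}
\qquad
w \leftarrow \min\bigl(\max(w + \Delta w,\ w_{\min}),\ w_{\max}\bigr)
$$

For example, with $A_{+} = 0.005$ and $\tau_{+} = 20$ ms, a pairing with $\Delta t = 10$ ms gives $\Delta w = 0.005\, e^{-0.5} \approx 0.0030$, and the mirror case $\Delta t = -10$ ms gives $\Delta w \approx -0.0030$.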

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Parameters for STDP
tau_plus = 20.0  # Time constant for LTP (ms)
tau_minus = 20.0  # Time constant for LTD (ms)
A_plus = 0.005   # Maximum weight change for LTP
A_minus = 0.005  # Maximum weight change for LTD
w_max = 1.0  # Maximum synaptic weight
w_min = 0.0  # Minimum synaptic weight

# Initial synaptic weight
w = 0.5  # Initial synaptic weight (start at middle value)

# STDP Function
def stdp(delta_t):
    """STDP weight change for one spike pair, with delta_t = t_post - t_pre (ms)."""
    if delta_t > 0:
        # Pre-synaptic neuron fires before post-synaptic (LTP)
        return A_plus * np.exp(-delta_t / tau_plus)
    else:
        # Post-synaptic neuron fires before (or together with) pre-synaptic (LTD)
        return -A_minus * np.exp(delta_t / tau_minus)

# Time differences to evaluate (-50 ms to 50 ms)
delta_t_values = np.arange(-50, 51, 1)

# Resulting weight for each delta_t (each change is applied to the initial weight, not accumulated)
weight_changes = []

# Apply STDP for each delta_t
for delta_t in delta_t_values:
    delta_w = stdp(delta_t)
    w_new = w + delta_w

    # Clip the weight within [w_min, w_max]
    w_new = np.clip(w_new, w_min, w_max)

    weight_changes.append(w_new)

# Plot the STDP behavior
plt.figure(figsize=(8, 6))
plt.plot(delta_t_values, weight_changes, color='b', marker='o')
plt.axhline(0.5, color='k', linestyle='--', label='Initial Weight')
plt.title("STDP Weight Change Based on Timing Difference")
plt.xlabel("Time Difference Δt = t_post - t_pre [ms]")
plt.ylabel("Synaptic Weight")
plt.grid(True)
plt.legend()
plt.show()
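The cell above evaluates the window for one spike pair at a time. To apply it to whole spike trains, a common choice is all-to-all pairing, in which every presynaptic spike is paired with every postsynaptic spike and the individual changes are summed (nearest-neighbour pairing is another common variant). The sketch below is a plain-Python version of that scheme; `stdp_window` and `all_pairs_stdp` are hypothetical names, the parameter values are copied from the cell above, and clipping once after summing is an assumption of this sketch.

```python
import math

# Same parameter values as the cell above
TAU_PLUS, TAU_MINUS = 20.0, 20.0
A_PLUS, A_MINUS = 0.005, 0.005
W_MIN, W_MAX = 0.0, 1.0

def stdp_window(delta_t):
    """Weight change for one spike pair, delta_t = t_post - t_pre (ms)."""
    if delta_t > 0:
        return A_PLUS * math.exp(-delta_t / TAU_PLUS)
    return -A_MINUS * math.exp(delta_t / TAU_MINUS)

def all_pairs_stdp(w, pre_times, post_times):
    """Sum the window over every (pre, post) pair, then clip the weight once."""
    dw = sum(stdp_window(t_post - t_pre) for t_pre in pre_times for t_post in post_times)
    return min(max(w + dw, W_MIN), W_MAX)

# Pre leads post by 5 ms twice; the cross pairs add smaller LTP and LTD terms
print(all_pairs_stdp(0.5, pre_times=[10, 40], post_times=[15, 45]))
```

In the example, the two 5 ms "pre before post" pairings dominate, so the weight ends slightly above its starting value of 0.5.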


# Spiking Neural Network with BSA Encoding and STDP Learning: A Comparative Analysis of Poisson Signal Responses

This code defines a Spiking Neural Network (SNN) built from Leaky Integrate-and-Fire (LIF) neurons that process inputs encoded with Ben's Spiker Algorithm (BSA). The network has 2 inputs (the two BSA-encoded spike channels, fed in directly), 10 hidden LIF neurons, and 1 output LIF neuron. Input-to-hidden synapses (`synaptic_weights_ih`) and hidden-to-output synapses (`synaptic_weights_ho`) are initialized with random weights.

The inputs are two Poisson-distributed signals, A (mean 10) and B (mean 20), each with 2 channels, and a separate network is trained on each. Note that `BSA_encoding` min-max normalizes every channel to [0, 1], so the overall offset and scale of each signal are removed before encoding; what distinguishes A from B afterwards is mainly the shape of their fluctuations. Training uses a simplified, STDP-inspired Hebbian update of the hidden-to-output weights that rewards hidden and output spikes occurring in the same time step; unlike full STDP, it ignores spike order and never weakens a synapse, and the input-to-hidden weights stay fixed. After training, the output spike trains for Signal A and Signal B are plotted, and their spike rates (fraction of time steps with an output spike) are compared. Because the two networks start from different random weights, part of any difference in rate may come from initialization rather than from the signals.
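In matrix form, writing $s_{\text{in}}[t] \in \{0,1\}^{1 \times 2}$ for the encoded input at time step $t$, $W_{ih}$ and $W_{ho}$ for `synaptic_weights_ih` (2×10) and `synaptic_weights_ho` (10×1), and $\eta$ for `stdp_eta` (notation introduced here), one forward pass and one weight update of the code below amount to:

$$
I_h[t] = s_{\text{in}}[t]\, W_{ih}, \qquad s_h = \mathrm{LIF}(I_h), \qquad
I_o[t] = s_h[t]\, W_{ho}, \qquad s_o = \mathrm{LIF}(I_o)
$$

$$
W_{ho} \leftarrow W_{ho} + \eta\, S_h^{\top} S_o, \qquad
r = \frac{1}{T} \sum_{t=1}^{T} s_o[t]
$$

Here $S_h \in \{0,1\}^{T \times 10}$ and $S_o \in \{0,1\}^{T \times 1}$ stack the hidden and output spikes over all $T$ time steps, so entry $k$ of $S_h^{\top} S_o$ counts the steps on which hidden neuron $k$ and the output neuron spike together, and $r$ is the value returned by `calculate_spike_rate`.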


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# BSA Encoding Function
def BSA_encoding(input_data, filter_type=1, threshold=0.1):
    spike_trains = []
    if filter_type == 1:
        Filter = torch.tensor([0.8, 0.2])
    elif filter_type == 2:
        Filter = torch.tensor([0.05] * 6)
    elif filter_type == 3:
        Filter = torch.tensor([0.8] * 6)
    else:
        raise ValueError(f"Unknown filter_type: {filter_type}")

    filter_length = len(Filter)
    num_samples = input_data.shape[1]

    for s in range(num_samples):
        sample = input_data[:, s].unsqueeze(1)
        timelength = sample.shape[0]
        inputnum = sample.shape[1]
        total_error = torch.zeros(inputnum)
        SpikeTrain_temp = torch.zeros(timelength, inputnum)
        min_data = sample.min(0, keepdim=True).values
        max_data = sample.max(0, keepdim=True).values
        EncodingSignal = (sample - min_data) / (max_data - min_data).clamp(min=1e-12)  # avoid 0/0 for a constant signal
        BSA_Threshold_rowVector = torch.ones(inputnum) * threshold

        for f in range(inputnum):
            for i in range(timelength - filter_length + 1):
                error1 = 0
                error2 = 0
                for j in range(filter_length):
                    error1 += abs(EncodingSignal[i + j, f] - Filter[j])
                    error2 += abs(EncodingSignal[i + j, f])
                if error1 <= (error2 - BSA_Threshold_rowVector[f]):
                    SpikeTrain_temp[i, f] = 1
                    for j in range(filter_length):
                        EncodingSignal[i + j, f] -= Filter[j]
        spike_trains.append(SpikeTrain_temp)
    return torch.cat(spike_trains, dim=1)

# Generate Poisson 2D Tensor
def generate_poisson_tensor(shape, rate):
    return torch.poisson(torch.full(shape, rate, dtype=torch.float))

# Leaky Integrate-and-Fire Neuron layer
class LIFNeuron(nn.Module):
    def __init__(self, n_in, tau_m=10.0, v_reset=0.0, v_threshold=0.5, dt=1.0, r_m=2.0):
        super(LIFNeuron, self).__init__()
        self.tau_m = tau_m
        self.v_reset = v_reset
        self.v_threshold = v_threshold
        self.dt = dt
        self.r_m = r_m
        self.n_in = n_in  # number of neurons in the layer (stored for reference only)

    def forward(self, I):
        # I has shape (time_steps, n_neurons): input current at each time step
        time_steps, n_neurons = I.shape
        # One membrane potential per neuron, carried over from step to step
        v_membrane = torch.full((n_neurons,), self.v_reset)
        spikes = torch.zeros(time_steps, n_neurons)  # To record spike events

        for t in range(time_steps):
            dv = (-(v_membrane - self.v_reset) + self.r_m * I[t]) / self.tau_m
            v_membrane = v_membrane + dv * self.dt

            spikes[t] = (v_membrane >= self.v_threshold).float()
            v_membrane[spikes[t] > 0] = self.v_reset  # Reset neurons that fired

        return spikes

# SNN Network Definition
class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()
        # The BSA spike trains are fed in directly as the 2-channel input layer,
        # so only the hidden and output layers use LIF dynamics.
        self.hidden_neurons = LIFNeuron(10)
        self.output_neuron = LIFNeuron(1)

        # Random weights in [0, 2): scaled up so input spikes drive the neurons harder
        self.synaptic_weights_ih = torch.rand(2, 10) * 2  # input (2) -> hidden (10)
        self.synaptic_weights_ho = torch.rand(10, 1) * 2  # hidden (10) -> output (1)

        # Learning rate for the STDP-like weight update
        self.stdp_eta = 0.005

    def forward(self, input_spikes):
        # Input -> Hidden Layer
        hidden_inputs = torch.matmul(input_spikes, self.synaptic_weights_ih)
        hidden_spikes = self.hidden_neurons(hidden_inputs)

        # Hidden -> Output Layer
        output_inputs = torch.matmul(hidden_spikes, self.synaptic_weights_ho)
        output_spikes = self.output_neuron(output_inputs)

        return hidden_spikes, output_spikes

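    def apply_stdp(self, pre_spikes, post_spikes):
        # Simplified, Hebbian-style stand-in for STDP: each weight grows by
        # stdp_eta for every time step on which its pre- and post-synaptic
        # neurons spike together. Spike order is ignored and there is no depression.
        delta_w = self.stdp_eta * torch.matmul(pre_spikes.T, post_spikes)
        return delta_w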

# Training the network with the STDP-like update
def train_snn(snn, input_data, n_steps=100):
    # n_steps is the number of training iterations; each one runs the full
    # encoded sequence through the network
    for _ in range(n_steps):
        # Forward pass
        pre_spikes, output_spikes = snn(input_data)

        # Update only the hidden -> output weights
        dw_ho = snn.apply_stdp(pre_spikes, output_spikes)
        snn.synaptic_weights_ho += dw_ho

    return output_spikes

# Spike Rate Calculation
def calculate_spike_rate(spike_train):
    return torch.sum(spike_train, dim=0) / spike_train.shape[0]

# Generate two different Poisson signals
input_data_A = generate_poisson_tensor((100, 2), rate=10)  # Signal A
input_data_B = generate_poisson_tensor((100, 2), rate=20)  # Signal B

# Encode both signals using BSA encoding
encoded_input_A = BSA_encoding(input_data_A)
encoded_input_B = BSA_encoding(input_data_B)

# Initialize SNN and train separately for each signal
snn_A = SNN()  # Separate SNN for Signal A
output_spikes_A = train_snn(snn_A, encoded_input_A)

snn_B = SNN()  # Separate SNN for Signal B
output_spikes_B = train_snn(snn_B, encoded_input_B)

# Plot output neuron spike trains for Signal A and Signal B
plt.figure(figsize=(12, 6))

# Plot for Signal A
plt.subplot(1, 2, 1)
plt.plot(output_spikes_A.detach().numpy(), color='blue', marker='o')
plt.title('Output Neuron Spike Train for Signal A')
plt.xlabel('Time Steps')
plt.ylabel('Spikes')

# Plot for Signal B
plt.subplot(1, 2, 2)
plt.plot(output_spikes_B.detach().numpy(), color='red', marker='o')
plt.title('Output Neuron Spike Train for Signal B')
plt.xlabel('Time Steps')
plt.ylabel('Spikes')

plt.tight_layout()
plt.show()

# Calculate spike rates for both signals
spike_rate_A = calculate_spike_rate(output_spikes_A)
spike_rate_B = calculate_spike_rate(output_spikes_B)

print(f"Spike Rate for Signal A: {spike_rate_A.item()}")
print(f"Spike Rate for Signal B: {spike_rate_B.item()}")
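The `apply_stdp` method above computes `stdp_eta * pre_spikes.T @ post_spikes`, which only rewards spikes that coincide in the same time step and never depresses a weight, so spike order plays no role. A timing-dependent alternative is sketched below with exponentially decaying spike traces, a common way to compute all-to-all pair-based STDP online. `trace_stdp` is a hypothetical helper written in plain Python so it can be checked without PyTorch; its defaults are borrowed from the STDP cell earlier in this notebook, and time steps are assumed to be 1 ms apart.

```python
import math

def trace_stdp(pre_spikes, post_spikes, a_plus=0.005, a_minus=0.005,
               tau_plus=20.0, tau_minus=20.0, dt=1.0):
    """Pair-based (all-to-all) STDP weight change computed with spike traces.

    pre_spikes:  T rows, each a list of n_pre values (0 or 1)
    post_spikes: T rows, each a list of n_post values (0 or 1)
    Returns dW as n_pre rows of n_post floats.
    """
    n_pre, n_post = len(pre_spikes[0]), len(post_spikes[0])
    decay_pre = math.exp(-dt / tau_plus)    # per-step decay of presynaptic traces
    decay_post = math.exp(-dt / tau_minus)  # per-step decay of postsynaptic traces
    x = [0.0] * n_pre
    y = [0.0] * n_post
    dw = [[0.0] * n_post for _ in range(n_pre)]

    for pre_t, post_t in zip(pre_spikes, post_spikes):
        x = [v * decay_pre for v in x]
        y = [v * decay_post for v in y]
        # LTP: a post spike pairs with strictly earlier pre spikes (current x)
        for k in range(n_post):
            if post_t[k]:
                for i in range(n_pre):
                    dw[i][k] += a_plus * x[i]
                y[k] += 1.0
        # LTD: a pre spike pairs with earlier or simultaneous post spikes (current y)
        for i in range(n_pre):
            if pre_t[i]:
                for k in range(n_post):
                    dw[i][k] -= a_minus * y[k]
                x[i] += 1.0
    return dw

# 5 time steps, 2 presynaptic neurons, 1 postsynaptic neuron
pre = [[1, 0], [0, 0], [0, 1], [0, 0], [1, 0]]
post = [[0], [1], [1], [0], [0]]
print(trace_stdp(pre, post))
```

The traces give the same result as summing the exponential window over every spike pair, but in one pass over time. With the network above, one possible (untested) use is `snn.synaptic_weights_ho += torch.tensor(trace_stdp(pre_spikes.tolist(), output_spikes.tolist()))` inside `train_snn` in place of the `apply_stdp` call.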


# Guide to Simulating a Simple Spiking Neural Network using Brian2

In [1]:
%pip install brian2



In [None]:
from brian2 import *

# Define model parameters
tau = 20*ms  # Membrane time constant
synaptic_strength = 0.5  # Increment added to the postsynaptic v per presynaptic spike
threshold_value = 0.8  # Threshold for spiking

# Define model equations
eqs = '''
dv/dt = -v / tau : 1  # leaky decay toward 0; there is no input current term
'''

# Create a neuron group
N = 8  # number of neurons
G = NeuronGroup(N, eqs, threshold=f'v>{threshold_value}', reset='v=0', method='exact')
G.v = 'rand()'  # initialize membrane potential with random values

# Connect neurons
S = Synapses(G, G, on_pre=f'v_post += {synaptic_strength}')  # each presynaptic spike raises v_post
S.connect(p=0.2)  # connect each (pre, post) pair with probability 0.2

# Monitor spikes
spike_monitor = SpikeMonitor(G)

# Record the membrane potentials of all neurons for the voltage plot
state_monitor = StateMonitor(G, 'v', record=True)

# Create a network and add all components
net = Network(G, S, spike_monitor, state_monitor)

# Run the simulation
net.run(2*second)

# Visualize the spike raster plot
figure(figsize=(12, 4))

subplot(121)
plot(spike_monitor.t/ms, spike_monitor.i, '.k')
xlabel('Time (ms)')
ylabel('Neuron index')
title('Spike Raster Plot')

# Plot membrane potentials for all neurons
subplot(122)
for i in range(N):
    plot(state_monitor.t/ms, state_monitor.v[i])
xlabel('Time (ms)')
ylabel('Membrane potential (v)')
title('Neuron Membrane Potentials')

show()
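The equation `dv/dt = -v / tau` is linear, which is why `method='exact'` can be used: over each time step $\Delta t$ the update has the closed form below. As a worked example with $\tau = 20$ ms, a value of 0.8 decays to 0.3, the level at or below which a single increment of 0.5 can no longer push $v$ above the threshold, after $\tau \ln(0.8/0.3) \approx 19.6$ ms.

$$
\frac{dv}{dt} = -\frac{v}{\tau} \quad\Longrightarrow\quad v(t + \Delta t) = v(t)\, e^{-\Delta t/\tau}
$$

Because the model has no input current, $v$ can only increase through the synaptic increment `v_post += 0.5`. Activity is therefore limited to an initial transient started by neurons whose random initial value exceeds 0.8; once that transient dies out, nothing can raise $v$ again and the network stays silent for the rest of the 2-second run. A sustained drive, such as a constant input term added to the equation, would be needed for ongoing activity.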
