In [3]:
import snntorch as snn
import torch
import torchvision

#Tarining Parameters 
batch_size = 128 #means in one irteration 128 image are processing
data_path = 'tmp/data/mnist' 
num_classes = 10 #Mnist has 10 output classes 

dtype = torch.float

from torchvision import datasets, transforms #torchvision is used to manipulate images or datasets

transform = transforms.Compose([
             transforms.Resize((28,28)),
             transforms.Grayscale(),
             transforms.ToTensor(), # it converts the image data (usually in the form of a PIL Image) into a PyTorch tensor.
             transforms.Normalize((0,),(1,))])

mnist_train = datasets.MNIST(data_path, train = True, download = True, transform = transform)

#for reduced dataset into subsets
from snntorch import utils

subset = 10
mnsit_train = utils.data_subset(mnist_train, subset)
print(f"Mnist_train is {len(mnist_train)}")

#DataLoaders
from torch.utils.data import DataLoader

train_loader = DataLoader(mnist_train, batch_size = batch_size, shuffle = True)

num_steps = 10
raw_vector = torch.ones(num_steps)*0
rate_coded_vector = torch.bernoulli(raw_vector)
print(f"Converted vector: {rate_coded_vector}")

rate_coded_vector = torch.randint(0, 2, size=(num_steps,))
print(f"Converted vector: {rate_coded_vector}")


Mnist_train is 6000
Converted vector: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Converted vector: tensor([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpikingNeuronLayer(nn.Module):
    def __init__(self, num_inputs, num_neurons, threshold=1.0, decay=0.9):
        super(SpikingNeuronLayer, self).__init__()
        self.num_inputs = num_inputs
        self.num_neurons = num_neurons
        self.threshold = threshold
        self.decay = decay

        # Weight matrix and threshold for each neuron
        self.weights = nn.Parameter(torch.rand(num_neurons, num_inputs))
        self.membrane_potential = torch.zeros(num_neurons, requires_grad=False)

    def forward(self, spike_train):
        # Ensure the data type of spike_train matches the weights
        spike_train = spike_train.float()

        # Input spikes multiplied by weights
        weighted_input = torch.mm(spike_train, self.weights.t())

        # Update membrane potential with decay
        self.membrane_potential = self.membrane_potential * self.decay

        # Integrate input and update membrane potential
        self.membrane_potential += weighted_input

        # Check for spikes and reset membrane potential
        spikes = F.relu(torch.sign(self.membrane_potential - self.threshold))
        self.membrane_potential = torch.where(spikes > 0, torch.zeros_like(self.membrane_potential), self.membrane_potential)

        return spikes

class SNN5to1(nn.Module):
    def __init__(self):
        super(SNN5to1, self).__init__()
        self.layer1 = SpikingNeuronLayer(num_inputs=5, num_neurons=1)

    def forward(self, input_spikes):
        spikes = self.layer1(input_spikes)
        return spikes

# Example usage
batch_size = 10  # Define the batch size
model = SNN5to1()
input_spikes = torch.randint(2, size=(batch_size, 5)).float()  # Generate random spikes
output_spikes = model(input_spikes)

# Print output spikes
print("Input Spikes:\n", input_spikes)
print("Output Spikes:\n", output_spikes)


RuntimeError: output with shape [1] doesn't match the broadcast shape [10, 1]

In [2]:
import snntorch as snn
from snntorch import spikegen

import torch

def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
    spk = (mem > threshold).float()
    mem = beta * mem + w * x - spk * threshold
    return spk, mem

num_neurons = 5
num_output_neurons = 1
num_steps = 50

x = torch.cat((torch.zeros(num_neurons), torch.ones(num_steps - num_neurons) * 0.5), 0)
mem = torch.zeros(num_output_neurons)
spk_out = torch.zeros(num_output_neurons)
mem_rec = []
spk_rec = []

w = 0.4
beta = 0.819

num_output_spikes = 0  # Variable to count the number of spikes from the output neuron

for step in range(num_steps):
    input_spike, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
    mem_rec.append(mem.item())
    spk_rec.append(input_spike.item())
    
    num_output_spikes += input_spike.item()  # Counting the spikes from the output neuron

print(f"Number of spikes from the 5-to-1 neuron: {num_output_spikes}")

Number of spikes from the 5-to-1 neuron: 3.0


In [4]:
import torch

def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
    spk = (mem > threshold).float()
    mem = beta * mem + torch.matmul(w, x) - spk * threshold
    return spk, mem

num_neurons = 5
num_output_neurons = 1
num_steps = 10

# Each column of x corresponds to the input from one neuron
x = torch.zeros(num_steps, num_neurons)
x[num_neurons:, :] = 0.5  # Assuming after 'num_neurons' steps, all inputs are 0.5

mem = torch.zeros(num_output_neurons)
spk_out = torch.zeros(num_output_neurons)
mem_rec = []
spk_rec = []

# w should be a vector of size num_neurons
w = torch.full((num_neurons,), 0.4)
beta = 0.819

num_output_spikes = 0

for step in range(num_steps):
    input_spike, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
    mem_rec.append(mem.item())
    spk_rec.append(input_spike.item())
    
    num_output_spikes += input_spike.item()

print(f"Number of spikes from the neuron: {num_output_spikes}")


Number of spikes from the neuron: 3.0
