In [1]:
import torch
import torch.nn as nn
import numpy as np
import snntorch as snn
import matplotlib.pyplot as plt
from snntorch import spikegen
from snntorch import surrogate
from snntorch import utils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

ModuleNotFoundError: No module named 'torch'

In [14]:
# ## Hyperparameters
# Small SNN with 500 hidden LIF neurons. snnTorch's `Leaky` neuron takes the
# sum of its weighted inputs, much like an artificial neuron, but instead of
# passing it straight to an activation function it integrates the input over
# time with a leak, much like an RC circuit. When the integrated value exceeds
# a threshold, the neuron emits a spike (snnTorch docs, tutorial 2.1).

# RC-circuit (Lapicque-style) membrane constants.
R, C = 1, 1.44   # resistance, capacitance
tau = R * C      # membrane time constant
# NOTE(review): exp(-1 / tau) ≈ 0.499 for these values, so beta = 0.5 looks
# like the per-step decay for a unit time step — confirm this was intended.
beta = 0.5       # leak (decay) factor handed to each snn.Leaky layer

# Network dimensions.
num_inputs = 3    # sensor features: temperature, audio, humidity
num_hidden = 500  # hidden-layer LIF neurons
num_outputs = 2   # classes: fire detected / not detected

batch_size = 500  # NOTE(review): not referenced by the SNN class, which accepts any batch size


class SNN(nn.Module):
    """Two-layer spiking network: Linear -> Leaky (LIF) -> Linear -> Leaky (LIF).

    Layer sizes come from the module-level ``num_inputs``, ``num_hidden`` and
    ``num_outputs`` constants. Both LIF layers share the global ``beta`` decay
    factor and use a fast-sigmoid surrogate as their spike gradient.
    """

    def __init__(self):
        super().__init__()

        # Hidden layer: project the input features onto the hidden LIF neurons.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

        # Output layer: one LIF neuron per output class.
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x, mem1=None, mem2=None):
        """Simulate the network over every time step of ``x``.

        Args:
            x: Input tensor whose first dimension is time. Each slice ``x[t]``
                must have ``num_inputs`` as its last dimension, e.g. a tensor
                of shape ``(num_steps, batch, num_inputs)``.
            mem1: Optional initial membrane potential for the hidden layer.
                When ``None``, a fresh state from ``lif1.init_leaky()`` is used.
            mem2: Optional initial membrane potential for the output layer.
                When ``None``, a fresh state from ``lif2.init_leaky()`` is used.

        Returns:
            ``(spk_rec, mem_rec)``: the output layer's spikes and membrane
            potentials, stacked along a new leading time dimension, i.e.
            shape ``(num_steps, ..., num_outputs)``.

        Raises:
            RuntimeError: from ``torch.stack`` if ``x`` has no time steps.
        """
        # Reset only the state the caller did not supply, so a membrane
        # potential passed in (e.g. to continue a simulation) is honoured.
        if mem1 is None:
            mem1 = self.lif1.init_leaky()
        if mem2 is None:
            mem2 = self.lif2.init_leaky()

        spk_rec = []
        mem_rec = []

        # Iterating a tensor yields its slices along dim 0: one per time step.
        for x_t in x:
            spk1, mem1 = self.lif1(self.fc1(x_t), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)

            spk_rec.append(spk2)
            mem_rec.append(mem2)

        return torch.stack(spk_rec), torch.stack(mem_rec)

In [None]:
# ## Smoke test
# Push random inputs through an untrained network and check the output shapes.
torch.manual_seed(0)  # make the random weights and inputs reproducible

num_steps = 10
network = SNN()

# Input layout is (time, batch, features). The feature dimension must equal
# num_inputs, otherwise fc1's matrix multiply fails with a shape mismatch.
inputs = torch.rand((num_steps, 1, num_inputs))

spk_rec, mem_rec = network(inputs)

# One record per time step, each of shape (batch, num_outputs).
assert spk_rec.shape == (num_steps, 1, num_outputs), spk_rec.shape
assert mem_rec.shape == spk_rec.shape, mem_rec.shape

print(spk_rec.shape)

torch.Size([10, 1, 2])


In [None]:
# Diagnostic: report which Python interpreter backs this kernel. This helps
# when an import such as `torch` fails with ModuleNotFoundError even though the
# package is installed — it may live in a different environment.
import sys

interpreter_path = sys.executable
print(interpreter_path)