Let's compare the efficiency of **Leaky** and **LeakyParallel**.

In [None]:
!pip install snntorch



In [None]:
import time


We'll simulate time through 1000 iterations of a forward pass and compare the respective speeds.

First is LeakyParallel.


In [None]:
import torch
import torch.nn as nn
import snntorch as snn

# Benchmark configuration: layer widths, batch size and decay rate.
num_inputs, num_hidden, num_outputs = 784, 128, 10
batch_size = 128
beta = 0.5

# Uniform random input in [0, 1) of shape (batch_size, num_inputs).
# NOTE(review): this tensor is 2-D; confirm that the layers it is fed to
# interpret dim 0 as the batch and not as a time axis.
x = torch.rand(batch_size, num_inputs)

# Define Network
class Net(nn.Module):
    """Two stacked ``snn.LeakyParallel`` layers: inputs -> hidden -> outputs.

    Layer sizes and ``beta`` are read from the module-level settings
    (``num_inputs``, ``num_hidden``, ``num_outputs``, ``beta``).
    """

    def __init__(self):
        super().__init__()

        # First layer: no beta is given, so the recurrent weights are
        # randomly initialized.
        self.lif1 = snn.LeakyParallel(
            input_size=num_inputs,
            hidden_size=num_hidden,
        )
        # Second layer: learnable recurrent weights initialized at beta.
        self.lif2 = snn.LeakyParallel(
            input_size=num_hidden,
            hidden_size=num_outputs,
            beta=beta,
            learn_beta=True,
        )

    def forward(self, x):
        """Feed ``x`` through both layers and return the second layer's output."""
        hidden_out = self.lif1(x)
        return self.lif2(hidden_out)
# Time 1000 forward passes of the Net defined in this cell.
# The model is built before the clock starts so that only the forward
# passes are measured, and perf_counter() is used because it is a
# monotonic, high-resolution clock meant for measuring intervals.
fun = Net()
firTime = time.perf_counter()
for i in range(1000):
    # Call the module itself rather than .forward() so that nn.Module's
    # call machinery (e.g. registered hooks) is not bypassed.
    fun(x)
secTime = time.perf_counter()
print(f"Time:{secTime - firTime}")


Time:15.50355315208435


That took about 15.5 seconds, which is not the best.

Let's see Leaky now.

In [None]:
import torch
import torch.nn as nn
import snntorch as snn

# Benchmark configuration: layer widths, batch size and decay rate.
num_inputs, num_hidden, num_outputs = 784, 128, 10
batch_size = 128
beta = 0.5

# Uniform random input in [0, 1) of shape (batch_size, num_inputs).
x = torch.rand(batch_size, num_inputs)

# Three separate single-element zero tensors, passed as the state / spike
# arguments of the timed forward calls.
mem0, mem1, spk = (torch.zeros(1) for _ in range(3))

# Define Network
class Net(nn.Module):
    """Fully connected spiking network: Linear -> Leaky -> Linear -> Leaky.

    Layer sizes and ``beta`` are read from the module-level settings
    (``num_inputs``, ``num_hidden``, ``num_outputs``, ``beta``).
    """

    def __init__(self):
        super().__init__()

        # Hidden stage: dense projection followed by a leaky neuron layer.
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        # Output stage: same pattern, projecting hidden -> outputs.
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x, mem1, spk1, mem2):
        """Run one forward pass through both stages.

        ``mem1`` and ``mem2`` are handed to ``lif1`` and ``lif2`` as their
        incoming state. The incoming ``spk1`` value is never read; it is
        overwritten by the output of ``lif1``.

        Returns the tuple ``(mem1, spk1, mem2, spk2)`` produced by this pass.
        """
        spk1, mem1 = self.lif1(self.fc1(x), mem1)
        spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
        return mem1, spk1, mem2, spk2
# Time 1000 forward passes of the Net defined in this cell.
# The model is built before the clock starts so that only the forward
# passes are measured, and perf_counter() is used because it is a
# monotonic, high-resolution clock meant for measuring intervals.
fun = Net()
thirTime = time.perf_counter()
for i in range(1000):
    # Call the module itself rather than .forward() so that nn.Module's
    # call machinery (e.g. registered hooks) is not bypassed.
    # Argument mapping: mem0 -> mem1, spk -> spk1, mem1 -> mem2. The
    # returned state is discarded, so every pass starts from the same zeros.
    fun(x, mem0, spk, mem1)
fourTime = time.perf_counter()
print(f"Time:{fourTime - thirTime}")

Time:2.626756191253662


More than five times faster!