In [1]:
import torch, torch.nn as nn
import snntorch as snn
from snntorch import functional as SF
import tonic


from tonic import datasets, transforms

# Width of each frame, in microseconds (Tonic timestamps are in µs): 30000 µs = 30 ms.
dt = 30000

# Crop every recording at t = 1e6 µs (1 s), then bin the events into frames of width `dt`.
# include_incomplete=True keeps the trailing partial window as a final frame.
transform = transforms.Compose([
    transforms.CropTime(max=1e6),
    transforms.ToFrame(
        sensor_size=datasets.SHD.sensor_size,
        time_window=dt,
        include_incomplete=True,
    ),
])

trainset = datasets.SHD('data', transform=transform)
testset = datasets.SHD('data', transform=transform, train=False)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# !rm -r cache
from tonic import DiskCachedDataset
from torch.utils.data import DataLoader

batch_size = 100

# Cache the transformed frames on disk so later epochs skip the event-to-frame conversion.
shd_trainset = DiskCachedDataset(trainset, cache_path='./cache/shd/train')
shd_testset = DiskCachedDataset(testset, cache_path='./cache/shd/test')

# Identical loader settings for train and test: pad variable-length sequences with time as
# the leading axis (batch_first=False).
# NOTE(review): shuffle=False means every training epoch visits batches in the same order, and
# drop_last=True discards the final partial batch, so test accuracy is measured on a multiple
# of batch_size samples (2200 in the recorded run).
train_loader, test_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        collate_fn=tonic.collation.PadTensors(batch_first=False),
        shuffle=False,
        drop_last=True,
    )
    for split in (shd_trainset, shd_testset)
)

In [3]:
# Peek at the first training batch: padded tensor size, then its labels (one per line).
for data, labels in train_loader:
    print(data.size(), labels, sep="\n")
    break

torch.Size([28, 100, 1, 700])
tensor([11, 13,  5, 10,  1, 13,  4, 14,  6,  0, 11, 17,  5,  9, 12, 12,  5,  0,
        14,  5,  9,  5, 15, 15, 19, 12,  0, 14, 19,  9,  1, 11, 12, 16,  0, 16,
        11, 15, 13, 12,  2, 14,  6, 17,  9,  3,  4,  4,  3, 11, 15,  1,  3, 12,
         1,  5,  9,  1,  6, 11, 19,  1,  7,  6, 14,  1,  0,  6, 10, 17, 11,  1,
        13,  8, 13,  2, 19,  6, 19,  9,  8,  2, 18,  7,  4,  7,  2,  3, 13,  9,
        14, 18,  2,  1, 14,  0, 19,  9, 15,  6], dtype=torch.int32)


In [4]:
# Training Parameters
num_classes = 20  # SHD (Spiking Heidelberg Digits) has 20 output classes
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Torch Variables
dtype = torch.float

In [5]:
# Presumably the padded sequence length seen in the first batch above (28); `bits` aliases it.
num_steps = bits = 28
# Open a fresh training iterator and pull one (data, targets) batch for ad-hoc inspection.
data = iter(train_loader)
data_it, targets_it = next(data)

class replaceGrad(torch.autograd.Function):
    """Value substitution with a straight-through gradient.

    ``replaceGrad.apply(x, x_r)`` evaluates to ``x_r``. In the backward pass the incoming
    gradient is handed unchanged to both ``x`` and ``x_r``. This lets one tensor supply the
    forward value while another tensor's graph receives the gradient.
    """

    @staticmethod
    def forward(ctx, x, x_r):
        # `x` contributes no value; it is only here so that autograd routes gradient into it.
        return x_r

    @staticmethod
    def backward(ctx, grad_out):
        # Identity gradient for both inputs.
        return grad_out, grad_out

In [287]:
# Network Architecture
num_inputs = 700
num_hidden = 300
num_outputs = 20

# Temporal Dynamics
beta = 1

# spike_grad = surrogate.fast_sigmoid()
spike_grad = None

# Define Network
class Net(nn.Module):
    """Experimental two-layer spiking network whose output layer learns from spike traces.

    The layer-2 forward value is the true current ``fc2(spk1)``. The gradient reaching
    ``fc2`` and ``spk1`` is instead computed through ``fc2(traces1)``, where ``traces1``
    accumulates hidden spikes. ``replaceGrad`` stitches the two together.

    NOTE(review): the next ``Net`` cell redefines this class, so this version is shadowed
    once that cell runs. Its ``forward`` also returns two tensors, while the training loop
    below unpacks three.

    Relies on the notebook globals ``num_inputs``, ``num_hidden``, ``num_outputs``,
    ``beta``, ``spike_grad``, ``batch_size`` and ``device``.

    Args:
        singleSpk: if truthy, neurons use ``reset_mechanism="none"`` and each hidden
            neuron may spike at most once until ``clearState`` is called. Otherwise
            ``reset_mechanism="subtract"`` is used.
    """

    def __init__(self, singleSpk=False):
        super().__init__()
        self.singleSpk = singleSpk
        reset = "none" if singleSpk else "subtract"
        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=0, spike_grad=spike_grad, reset_mechanism=reset)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=0, spike_grad=spike_grad, reset_mechanism=reset)

        # "Already spiked" masks (1 = has fired) used in singleSpk mode.
        self.hasFired1 = torch.zeros(batch_size, num_hidden).to(device)
        self.hasFired2 = torch.zeros(batch_size, num_outputs).to(device)

    def clearState(self):
        """Reset membranes, single-spike masks and spike traces before a new batch."""
        self.mem1 = self.lif1.init_leaky()
        self.mem2 = self.lif2.init_leaky()
        self.hasFired1 = torch.zeros_like(self.hasFired1)
        self.hasFired2 = torch.zeros_like(self.hasFired2)
        # Traces live on the same device as the model and use the configured layer sizes.
        # NOTE(review): traces0 (input-layer trace) is allocated but never updated in forward().
        self.traces0 = torch.zeros(batch_size, num_inputs, device=device)
        self.traces1 = torch.zeros(batch_size, num_hidden, device=device)

    def trackTrace(self, traces, spk):
        """Return ``traces * beta + spk`` computed without recording autograd history."""
        with torch.no_grad():
            traces = traces * beta + spk
        return traces

    def forward(self, x, targets, timestep=1):
        """Run the first ``timestep`` frames of ``x`` through the network.

        Args:
            x: frames indexed by time along dim 0; each ``x[step]`` is flattened from dim 1
               into ``(batch, num_inputs)``.
            targets: unused; kept for call-site compatibility.
            timestep: number of leading frames of ``x`` to process.

        Returns:
            ``(spk2, cur2)`` stacked along a new leading time dimension. Despite the local
            name ``mem2_rec``, the second tensor records the layer-2 input current
            ``newCur2``, not the membrane potential.
        """
        spk2_rec = []
        mem2_rec = []

        for step in range(timestep):
            cur1 = self.fc1(x[step].flatten(1))
            # NOTE(review): the updated layer-1 membrane is discarded (self.mem1 is never
            # reassigned), so layer 1 does not integrate across frames.
            spk1, _ = self.lif1(cur1, self.mem1)
            if self.singleSpk:
                # Silence neurons that already fired (out-of-place to keep autograd simple).
                spk1 = spk1 * (1 - self.hasFired1)
                with torch.no_grad():
                    self.hasFired1 = torch.max(self.hasFired1, spk1)

            with torch.no_grad():
                self.traces1 = self.trackTrace(self.traces1, spk1)
                # True forward current; built under no_grad, so it carries no graph.
                cur2 = self.fc2(spk1)
            # Value = traces1, gradient passed straight through to spk1.
            in_for_grad1 = replaceGrad.apply(spk1, self.traces1)
            out_for_grad1 = self.fc2(in_for_grad1)
            # Value = true current cur2, gradient taken from fc2(traces1).
            newCur2 = replaceGrad.apply(out_for_grad1, cur2)
            # Output layer uses its own neuron module (the second Leaky instance).
            spk2, self.mem2 = self.lif2(newCur2, self.mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(newCur2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [298]:
# --- Network architecture ---
num_inputs = 700   # input channels per frame
num_hidden = 300   # hidden-layer width
num_outputs = 20   # one output neuron per class

# --- Temporal dynamics ---
beta = 1           # decay factor handed to snn.Leaky

# Surrogate gradient for the spike nonlinearity; None keeps snntorch's default.
spike_grad = None

# Define Network
class Net(nn.Module):
    """Two-layer fully connected spiking network (snntorch ``Leaky`` neurons).

    Relies on the notebook globals ``num_inputs``, ``num_hidden``, ``num_outputs``,
    ``beta``, ``spike_grad``, ``batch_size`` and ``device``.

    Args:
        singleSpk: if truthy, neurons use ``reset_mechanism="none"`` and every neuron may
            emit at most one spike per sequence. This is tracked in ``hasFired1`` and
            ``hasFired2`` until ``clearState`` is called. Otherwise neurons use
            ``reset_mechanism="subtract"`` and may fire repeatedly.
    """

    def __init__(self, singleSpk=False):
        super().__init__()
        self.singleSpk = singleSpk
        reset = "none" if singleSpk else "subtract"
        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=0, spike_grad=spike_grad, reset_mechanism=reset)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=0, spike_grad=spike_grad, reset_mechanism=reset)

        # "Already spiked" masks (1 = has fired) for singleSpk mode. They are allocated for the
        # nominal batch size; _fire_once re-allocates them if a batch of another size arrives.
        self.hasFired1 = torch.zeros(batch_size, num_hidden).to(device)
        self.hasFired2 = torch.zeros(batch_size, num_outputs).to(device)

    def clearState(self):
        """Reset membrane potentials and single-spike masks before a new batch of sequences."""
        self.mem1 = self.lif1.init_leaky()
        self.mem2 = self.lif2.init_leaky()
        self.hasFired1 = torch.zeros_like(self.hasFired1)
        self.hasFired2 = torch.zeros_like(self.hasFired2)

    @staticmethod
    def _fire_once(spk, hasFired):
        """Zero the spikes of neurons already marked in ``hasFired``; return (spikes, updated mask)."""
        if hasFired.shape != spk.shape:
            # The batch size differs from the one the mask was built for (e.g. a final partial
            # batch). Start from an empty mask; clearState() must still run between batches.
            hasFired = torch.zeros_like(spk)
        spk = spk * (1 - hasFired)
        with torch.no_grad():
            hasFired = torch.max(hasFired, spk)
        return spk, hasFired

    def forward(self, x, targets, timestep=1):
        """Run the first ``timestep`` frames of ``x`` through the network.

        Args:
            x: frames indexed by time along dim 0. Each ``x[step]`` is flattened from dim 1
               into ``(batch, num_inputs)``. With the SHD loader above this is
               ``(time, batch, 1, 700)`` (TODO confirm for other callers).
            targets: unused; kept for call-site compatibility.
            timestep: number of leading frames of ``x`` to process.

        Returns:
            ``(spk2, mem2, cur2)``, each stacked along a new leading time dimension:
            output spikes, output membrane potentials, and the output layer's input
            currents (``fc2`` outputs).
        """
        spk2_rec = []
        mem2_rec = []
        cur2_rec = []

        for step in range(timestep):
            cur1 = self.fc1(x[step].flatten(1))
            # NOTE(review): the updated layer-1 membrane is discarded and self.mem1 is never
            # reassigned, so layer 1 does not integrate across frames. This may be deliberate:
            # the training loop back-propagates after every single frame.
            spk1, _ = self.lif1(cur1, self.mem1)
            if self.singleSpk:
                spk1, self.hasFired1 = self._fire_once(spk1, self.hasFired1)

            cur2 = self.fc2(spk1)
            # The output layer is driven by its own neuron module, lif2.
            spk2, self.mem2 = self.lif2(cur2, self.mem2)
            if self.singleSpk:
                spk2, self.hasFired2 = self._fire_once(spk2, self.hasFired2)

            spk2_rec.append(spk2)
            mem2_rec.append(self.mem2)
            cur2_rec.append(cur2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0), torch.stack(cur2_rec, dim=0)

In [299]:
# Build the model, loss function, optimizer and learning-rate schedule used for training.
import numpy as np

# Seed before constructing Net so the initial fc1/fc2 weights are reproducible.
torch.manual_seed(7)
# Place the network on the `device` configured earlier in the notebook.
net = Net(singleSpk=True).to(device)

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=17e-5, betas=(0.9, 0.999))

# Halve the learning rate every 10 scheduler steps (the training loop steps it once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=.5)

In [300]:
torch.manual_seed(7)
num_epochs = 200
loss_hist = []       # training loss at the last frame of every minibatch
test_loss_hist = []  # NOTE(review): not populated by this cell
counter = 0

for epoch in range(num_epochs):
    # ---- Training: online updates, one optimizer step per time frame ----
    net.train()
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.long().to(device)

        net.clearState()
        for t in range(data.shape[0]):
            # Feed a single frame. State kept on `net` (e.g. mem2, hasFired*) carries over
            # to the next call until the next clearState().
            spk_rec, mem_rec, cur_rec = net(data[t:t+1], targets)
            # The output layer's input current at this frame serves as the class logits.
            loss_val = loss(cur_rec[-1], targets)
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

    scheduler.step()

    # ---- Evaluation over the full test loader ----
    with torch.no_grad():
        net.eval()
        total = 0
        correct = 0
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.long().to(device)

            net.clearState()
            test_spk, mem, _ = net(data, targets, data.shape[0])

            # Predict the class with the most output spikes. Ties are broken by the final
            # output membrane: classes below the max spike count are masked to -inf.
            temp = test_spk.sum(0)
            memF = mem[-1].clone()
            memF[temp != temp.max(1)[0][:, None]] = float('-inf')
            predicted = memF.max(-1)[1]
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        # An empty test loader would otherwise raise ZeroDivisionError.
        accuracy = 100 * correct / total if total else float("nan")
        print(f"Total correctly classified test set samples: {correct}/{total}")
        print(f"Test Set Accuracy: {accuracy:.2f}%\n")
        with open("shd2.txt", "a") as f:
            f.write(f"{epoch}: {accuracy:.2f}%\n")

Total correctly classified test set images: 636/2200
Test Set Accuracy: 28.91%

Total correctly classified test set images: 742/2200
Test Set Accuracy: 33.73%

Total correctly classified test set images: 832/2200
Test Set Accuracy: 37.82%

Total correctly classified test set images: 910/2200
Test Set Accuracy: 41.36%

Total correctly classified test set images: 1004/2200
Test Set Accuracy: 45.64%

Total correctly classified test set images: 1056/2200
Test Set Accuracy: 48.00%

Total correctly classified test set images: 1054/2200
Test Set Accuracy: 47.91%

Total correctly classified test set images: 1118/2200
Test Set Accuracy: 50.82%

Total correctly classified test set images: 1133/2200
Test Set Accuracy: 51.50%

Total correctly classified test set images: 1166/2200
Test Set Accuracy: 53.00%

Total correctly classified test set images: 1243/2200
Test Set Accuracy: 56.50%

Total correctly classified test set images: 1244/2200
Test Set Accuracy: 56.55%

Total correctly classified test 

KeyboardInterrupt: 

In [None]:
# Debug check: does the last frame of the output spikes from the final training call
# (`spk_rec` is set inside the training loop) carry autograd history?
spk_rec[-1].requires_grad

In [92]:
# Debug inspection: last recorded entry of `mem_rec`.
# NOTE(review): the saved output below comes from execution count 92, i.e. before the
# training cell (300) above ran, so it likely reflects an earlier version of the network.
mem_rec[-1]

_SpikeTensor([])