In [1]:
import torch, torch.nn as nn
import snntorch as snn
from snntorch import functional as SF
import tonic


from tonic import datasets, transforms

dt = 30000

# encoding_dim = 100

transform = transforms.Compose(
            [
                # transforms.Downsample(spatial_factor=encoding_dim/700),
                transforms.CropTime(max=1e6), # all timestamp units in microseconds in Tonic
                transforms.ToFrame(
                    sensor_size=tonic.datasets.SHD.sensor_size,
                    time_window=dt,
                    include_incomplete=True,
                ),
            ]
        )

trainset=datasets.SHD('data', transform=transform)
testset=datasets.SHD('data', transform=transform, train=False)

In [2]:
# !rm -r cache
from tonic import DiskCachedDataset
from torch.utils.data import DataLoader

batch_size = 100
shd_trainset = DiskCachedDataset(trainset, cache_path='./cache/shd/train')
shd_testset = DiskCachedDataset(testset, cache_path='./cache/shd/test')
train_loader = DataLoader(shd_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=False, drop_last=True)
test_loader = DataLoader(shd_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=False, drop_last=True)

In [3]:
for data, labels in iter(train_loader):
    print(data.size())
    print(labels)
    break

torch.Size([34, 100, 1, 700])
tensor([11, 13,  5, 10,  1, 13,  4, 14,  6,  0, 11, 17,  5,  9, 12, 12,  5,  0,
        14,  5,  9,  5, 15, 15, 19, 12,  0, 14, 19,  9,  1, 11, 12, 16,  0, 16,
        11, 15, 13, 12,  2, 14,  6, 17,  9,  3,  4,  4,  3, 11, 15,  1,  3, 12,
         1,  5,  9,  1,  6, 11, 19,  1,  7,  6, 14,  1,  0,  6, 10, 17, 11,  1,
        13,  8, 13,  2, 19,  6, 19,  9,  8,  2, 18,  7,  4,  7,  2,  3, 13,  9,
        14, 18,  2,  1, 14,  0, 19,  9, 15,  6])


In [25]:
# Training Parameters
num_classes = 20  # MNIST has 10 output classes
# device = "cuda"
device = "mps"

# Torch Variables
dtype = torch.float

# # Torch Variables
# dtype = torch.float

# from torchvision import datasets, transforms

# # Define a transform
# transform = transforms.Compose([
#             transforms.Resize((28,28)),
#             transforms.Grayscale(),
#             transforms.ToTensor(),
#             transforms.Normalize((0,), (1,))
# ])

# mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
# mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# from torch.utils.data import DataLoader

# train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=False, drop_last=True)
# test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=True)

In [16]:
num_steps = 34
bits = num_steps
# Iterate through minibatches
data = iter(train_loader)
data_it, targets_it = next(data)

# Spiking Data
#spike_data = binary(data_it.byte(), num_steps)


In [20]:
# Network Architecture
num_inputs = 700
num_hidden = 300
num_outputs = 20

# Temporal Dynamics
beta = 1

# spike_grad = surrogate.fast_sigmoid()
spike_grad = None

# Define Network
class Net(nn.Module):
    def __init__(self, singleSpk=False):
        super().__init__()
        self.singleSpk = singleSpk
        if singleSpk:
            reset = "none"
        else:
            reset = "subtract"
        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=0, spike_grad=spike_grad, reset_mechanism="subtract")
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=0, spike_grad=spike_grad, reset_mechanism=reset)

            
        self.hasFired1 = torch.zeros(batch_size, num_hidden).to(device)
        self.hasFired2 = torch.zeros(batch_size, num_outputs).to(device)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        # Record the final layer
        spk2_rec = []
        mem2_rec = []
        
        self.hasFired1 = torch.zeros_like(self.hasFired1)
        self.hasFired2 = torch.zeros_like(self.hasFired2)

        for step in range(x.shape[0]):
            cur1 = self.fc1(x[step].flatten(1))
            spk1, mem1 = self.lif1(cur1, mem1)
            
            if self.singleSpk:
                # spk1 *= (1 - self.hasFired1)
                spk1 = spk1 * (1 - self.hasFired1)
                self.hasFired1 = torch.max(self.hasFired1, spk1)
            
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            
            if self.singleSpk:
                # spk2 *= (1 - self.hasFired2)
                spk2 = spk2 * (1 - self.hasFired2)
                self.hasFired2 = torch.max(self.hasFired2, spk2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
#             print(cur1.shape, mem1.shape, cur2.shape, mem2.shape)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [26]:
# pass data into the network, sum the spikes over time
# and compare the neuron with the highest number of spikes
# with the target
import numpy as np

torch.manual_seed(7)
# Load the network onto CUDA if available
net = Net(singleSpk=True).to(device)

def print_batch_accuracy(data, targets, train=False):
    output, mem = net(data)
#     _, idx = output.sum(dim=0).max(1)
#     _, idx = bin2dec(output, bits).max(1)
    temp = bin2dec(output, bits, net.phase2)
    memF = mem[-1].clone()
    memF[temp!=temp.max(1)[0][:, None]] = float('-inf')
    idx = memF.max(-1)[1]
#     _, idx = mem[-1].max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")
    
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=37e-5, betas=(0.9, 0.999))
#optimizer = torch.optim.SGD(net.parameters(), lr=.1, momentum=0.9)
# data, targets = next(iter(train_loader))
# data = data.to(device)
# targets = targets.to(device)
# spk_rec, mem_rec = net(data.view(batch_size, -1))

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=.5)

In [27]:
torch.manual_seed(7)
num_epochs = 200
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device).long()

        # forward pass
        net.train()
#         spk_rec, mem_rec = net(data.view(batch_size, -1))
        spk_rec, mem_rec = net(data)

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
#         for step in range(num_steps):
#             loss_val += loss(mem_rec[step], targets)
        loss_val += loss(mem_rec[-1], targets)
  #      loss_val += loss(bin2dec(spk_rec, bits, net.phase2), targets)
        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
        

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

    scheduler.step()
    
    with torch.no_grad():
        net.eval()
        total = 0
        correct = 0
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.to(device)

            # forward pass
            test_spk, mem = net(data)

            # calculate total accuracy
#             _, predicted = test_spk.sum(dim=0).max(1)
            temp = test_spk.sum(0)
   #         temp = bin2dec(test_spk, bits, net.phase2)
            memF = mem[-1].clone()
            memF[temp!=temp.max(1)[0][:, None]] = float('-inf')
            predicted = memF.max(-1)[1]
#             _, predicted = bin2dec(test_spk, bits, net.phase2).max(1)
#             _, predicted = mem[-1].max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
        print(f"Total correctly classified test set images: {correct}/{total}")
        print(f"Test Set Accuracy: {100 * correct / total:.2f}%\n")
        with open("shd1.txt", "a") as f:
            f.write(f"{epoch}: {100 * correct / total:.2f}%\n")

Total correctly classified test set images: 162/2200
Test Set Accuracy: 7.36%

Total correctly classified test set images: 292/2200
Test Set Accuracy: 13.27%

Total correctly classified test set images: 453/2200
Test Set Accuracy: 20.59%

Total correctly classified test set images: 587/2200
Test Set Accuracy: 26.68%

Total correctly classified test set images: 647/2200
Test Set Accuracy: 29.41%

Total correctly classified test set images: 666/2200
Test Set Accuracy: 30.27%

Total correctly classified test set images: 765/2200
Test Set Accuracy: 34.77%

Total correctly classified test set images: 854/2200
Test Set Accuracy: 38.82%

Total correctly classified test set images: 823/2200
Test Set Accuracy: 37.41%

Total correctly classified test set images: 811/2200
Test Set Accuracy: 36.86%

Total correctly classified test set images: 807/2200
Test Set Accuracy: 36.68%

Total correctly classified test set images: 877/2200
Test Set Accuracy: 39.86%

Total correctly classified test set image

In [23]:
spk1.shape

NameError: name 'spk1' is not defined

In [24]:
test_spk.shape

torch.Size([33, 100, 20])