In [1]:
import torch
import snntorch
device = torch.device("mps" if torch.cuda.is_available() else "cpu")
#torch.cuda.set_per_process_memory_fraction(0.85, device=0)

import numpy as np

import tonic
from tonic import datasets, transforms
from torch.utils.data import DataLoader
from collections import namedtuple

State = namedtuple("State", "obs labels")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class _SHD2Raster():
    """ 
    Tool for rastering SHD samples into frames. Packs bits along the temporal axis for memory efficiency. This means
        that the used will have to apply jnp.unpackbits(events, axis=<time axis>) prior to feeding the data to the network.
    """

    def __init__(self, encoding_dim, sample_T = 100):
        self.encoding_dim = encoding_dim
        self.sample_T = sample_T
        
    def __call__(self, events):
        # tensor has dimensions (time_steps, encoding_dim)
        tensor = np.zeros((events["t"].max()+1, self.encoding_dim), dtype=int)
        np.add.at(tensor, (events["t"], events["x"]), 1)
        #return tensor[:self.sample_T,:]
        tensor = tensor[:self.sample_T,:]
        tensor = np.minimum(tensor, 1)
        #tensor = np.packbits(tensor, axis=0) pytorch does not have an unpack feature.
        return tensor

In [3]:
sample_T = 128
shd_timestep = 1e-6
shd_channels = 700
net_channels = 128
net_dt = 1/sample_T
batch_size = 256

obs_shape = tuple([net_channels,])
act_shape = tuple([20,])

transform = transforms.Compose([
    transforms.Downsample(
        time_factor=shd_timestep / net_dt,
        spatial_factor=net_channels / shd_channels
    ),
    _SHD2Raster(net_channels, sample_T=sample_T)
])

train_dataset = datasets.SHD("../data", train=True, transform=transform)
test_dataset = datasets.SHD("../data", train=False, transform=transform)



In [4]:
train_dl = iter(DataLoader(train_dataset, batch_size=len(train_dataset),
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_train, y_train = next(train_dl)
x_train, y_train = x_train.to(torch.uint8), y_train.to(torch.uint8)
x_train, y_train = x_train.to(device), y_train.to(device)

In [5]:
def shuffle(dataset):
    x, y = dataset

    cutoff = y.shape[0] % batch_size

    indices = torch.randperm(y.shape[0])[:-cutoff]
    obs, labels = x[indices], y[indices]


    obs = torch.reshape(obs, (-1, batch_size) + obs.shape[1:])
    labels = torch.reshape(labels, (-1, batch_size)) # should make batch size a global

    return State(obs=obs, labels=labels)

In [6]:
test_dl = iter(DataLoader(test_dataset, batch_size=len(test_dataset),
                          collate_fn=tonic.collation.PadTensors(batch_first=True), drop_last=True, shuffle=False))
        
x_test, y_test = next(test_dl)
x_test, y_test = x_test.to(torch.uint8), y_test.to(torch.uint8)
x_test, y_test = x_test.to(device), y_test.to(device)
x_test, y_test = shuffle((x_test, y_test))

In [7]:
num_hidden = 64
# Define Network
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = torch.nn.Linear(128, num_hidden)
        self.lif1 = snntorch.Leaky(beta=torch.ones(num_hidden)*0.5, learn_beta=True)
        self.fc2 = torch.nn.Linear(num_hidden, num_hidden)
        self.lif2 = snntorch.Leaky(beta=torch.ones(num_hidden)*0.5, learn_beta=True)
        self.fc3 = torch.nn.Linear(num_hidden, 20)
        self.lif3 = snntorch.Leaky(beta=torch.ones(20)*0.5, learn_beta=True, reset_mechanism="none")

    def forward(self, x):

        x = x.float() # [batch, time, channel]
        
        x = x.permute(1,0,2) # [time, batch, channel]
        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        
        V = []

        # need to fix since data is not time leading axis...
        for i, step in enumerate(x):
            cur1 = self.fc1(step)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)
            
            V.append(mem3)

        
        return torch.stack(V, axis=0).permute(1,0,2)
        
# Load the network onto CUDA if available
net = Net().to(device)
#precompiled_net = Net().to(device)
#net = torch.compile(precompiled_net, fullgraph=True)

In [8]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.3)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
acc = lambda predictions, targets : (torch.argmax(predictions, axis=-1) == targets).sum().item() / len(targets)

In [9]:
num_epochs = 50
loss_hist = []
counter = 0


# Outer training loop
for epoch in range(num_epochs):    
    print(epoch)
    train_batch = shuffle((x_train, y_train))
    train_data, targets = train_batch
    
    
    # Minibatch training loop
    for data, targets in zip(train_data, targets):

        # forward pass
        net.train()
        out_V = net(data)
        # initialize the loss & sum over time
        loss_val = loss(torch.sum(out_V, axis=-2), targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # Store loss history for future plotting
    loss_hist.append(loss_val.item())


# Test set
with torch.no_grad():
    denominator = y_test[0]
    test_acc = 0
    batch_acc = []
    for test_data, test_targets in zip(x_test, y_test):
        net.eval()
        # Test set forward pass
        out_V = net(test_data)
        # Test set loss
        batch_acc.append( acc(torch.sum(out_V, axis=-2), test_targets) )
    
    test_acc = np.mean(batch_acc)


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


In [10]:
test_acc

0.59716796875

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Test Accuracy: 56.98%
