In [1]:
import norse
import torch

# Creating a model

In [2]:
model = norse.torch.SequentialState(
    norse.torch.LIFBoxCell(),
    torch.nn.Conv2d(2, 8, 3),
    torch.nn.MaxPool2d(2),
    norse.torch.LIFBoxCell(),
    torch.nn.Conv2d(8, 16, 3),
    torch.nn.MaxPool2d(2),
    torch.nn.Flatten(1),
    torch.nn.Linear(784, 10)
)

In [3]:
model(torch.empty(1, 2, 34, 34))[0].shape

torch.Size([1, 10])

In [4]:
nir_graph = norse.torch.to_nir(model, torch.empty(1, 2, 34, 34))

In [5]:
nir_graph.nodes.keys()

dict_keys(['input', '0', '1', '2', '3', '4', '5', '6', '7', 'output'])

In [6]:
nir_graph.edges

[('7', 'output'),
 ('0', '1'),
 ('2', '3'),
 ('1', '2'),
 ('4', '5'),
 ('6', '7'),
 ('3', '4'),
 ('input', '0'),
 ('5', '6')]

# Training

In [7]:
import tqdm
import tonic

In [8]:
to_frame = tonic.transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, time_window = 1e3)
dataset = tonic.datasets.NMNIST(".", transform=to_frame, train=False)
loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=10, collate_fn=tonic.collation.PadTensors(batch_first=False))

In [9]:
dataset[0][0].shape

(308, 2, 34, 34)

In [10]:
optimizer = torch.optim.Adam(model.parameters())
model = model.cuda()
correct = 0
for (x, label) in tqdm.tqdm(loader):
    optimizer.zero_grad()
    state = None
    out = []
    for timestep in x:
        y, state = model(timestep.cuda(), state)
        out.append(y)
    out = torch.stack(out).mean(0)
    pred = out.argmax(0)
    correct += out.argmax(0, keepdim=True).eq(label.cuda().view_as(pred)).sum()
    loss = torch.nn.functional.cross_entropy(out, label.cuda())
    loss.backward()
    optimizer.step()
    
print(f"Correct: {100 * correct / len(dataset):.1f}%")

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:42<00:00,  2.92it/s]

Correct: 9.8%





# Export to NIR

In [11]:
norse.torch.to_nir(model.cpu(), torch.empty(1, 2, 34, 34))

NIRGraph(nodes={'input': Input(input_type={'input': array([ 1,  2, 34, 34])}), '0': LIF(tau=tensor(1.0000e-05), r=tensor(1.), v_leak=tensor(0.), v_threshold=tensor(1.)), '1': Conv2d(input_shape=None, weight=tensor([[[[ 0.2302, -0.1923,  0.1116],
          [ 0.0519,  0.1096, -0.1946],
          [-0.1547, -0.2132,  0.0488]],

         [[ 0.1668,  0.2272, -0.1020],
          [ 0.0744,  0.1548, -0.1317],
          [ 0.1251, -0.1976, -0.0381]]],


        [[[-0.1518,  0.2230, -0.0752],
          [-0.1672, -0.1807,  0.0759],
          [-0.2206,  0.1103,  0.0126]],

         [[ 0.1893,  0.0253,  0.1672],
          [ 0.1789,  0.1015,  0.2072],
          [ 0.1583, -0.0384, -0.0347]]],


        [[[ 0.1673, -0.0356, -0.0323],
          [ 0.2265,  0.2239, -0.2084],
          [-0.1985,  0.0692,  0.0279]],

         [[ 0.0436, -0.1635, -0.0721],
          [-0.0762, -0.2009, -0.0613],
          [-0.0784, -0.0248,  0.1876]]],


        [[[-0.1281, -0.2353, -0.2279],
          [ 0.0162,  0.0070, -0.01