In [None]:
import sinabs
import sinabs.layers as sl
import torch
import torch.nn as nn
import numpy as np
import tonic
import matplotlib.pyplot as plt
import torchvision
import numpy.lib.recfunctions as rf

In [None]:
# Load the orientation filters -- one .npy file per 45-degree step, 0 through 315
# (8 files) -- and stack them into a float32 tensor of shape (n_filters, rows, cols).
# NOTE(review): for a NumPy image array the two trailing axes are most likely
# (height, width), not (width, height) as previously described -- confirm.
filter_angles = range(0, 360, 45)
filter_stack = np.stack([np.load(f"VMfilters/{angle}_grad.npy") for angle in filter_angles])
filters = torch.from_numpy(filter_stack.astype(np.float32))

In [None]:
# Show the loaded filters on a 2x4 grid (row-major: filters 0-3 on top, 4-7 below).
# The filters were loaded in 45-degree steps starting at 0, so panel i shows 45*i degrees.
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
for i, (ax, kernel) in enumerate(zip(axes.flat, filters)):
    ax.imshow(kernel)
    ax.set_title(f"{45 * i}°")
fig.tight_layout()

In [None]:
# Define a single-layer spiking network: a bias-free convolution whose kernels are
# the loaded orientation filters, followed by integrate-and-fire (IAF) neurons.
#
# The frames are fed as (time, 1, H, W), so Conv2d treats the time steps as its batch
# dimension and returns (time, n_filters, H', W'). A plain sl.IAF expects 5-D
# (batch, time, channel, H, W) input and would misread that 4-D tensor, treating the
# filter channels as time steps. IAFSqueeze un-flattens (batch*time) back into
# (batch, time); batch_size=1 because the whole recording is a single sample.
net = nn.Sequential(
    nn.Conv2d(
        in_channels=1,
        out_channels=filters.shape[0],
        kernel_size=tuple(filters.shape[1:]),  # (kH, kW): also works for non-square filters
        bias=False,
    ),
    sl.IAFSqueeze(batch_size=1),
)
# Load the filters as conv weights of shape (out_channels, in_channels=1, kH, kW).
# copy_ (unlike assigning .data) raises if the shapes do not match.
with torch.no_grad():
    net[0].weight.copy_(filters.unsqueeze(1))

In [None]:
# Load the recording and convert it to a structured NumPy array of events.
# NOTE(review): assumes the columns of twoobjects.npy are (x, y, polarity, t in seconds) -- confirm.
recording = np.load("twoobjects.npy")
# Convert time from seconds to microseconds. Round first: the float -> int cast done by
# unstructured_to_structured truncates, so e.g. 122.99999999999999 would become 122.
recording[:, 3] = np.round(recording[:, 3] * 1e6)
# Timestamps use an explicit int64. Plain `int` is 32-bit on some platforms
# (e.g. Windows with NumPy < 2), and microsecond timestamps overflow that after ~35.8 min.
event_dtype = np.dtype([('x', np.int16), ('y', np.int16), ('p', bool), ('t', np.int64)])
rec = rf.unstructured_to_structured(recording, dtype=event_dtype)

In [None]:
# Derive the sensor resolution from the largest coordinates that occur in the recording.
# int(...) yields plain Python ints, so sensor_size is a clean tuple of ints. Under
# NumPy >= 2, np.int64 scalars inside a tuple would print as `np.int64(...)`.
max_x = int(rec['x'].max())
max_y = int(rec['y'].max())
# Use only a single polarity: every event is assigned p = False (0), which matches
# the single polarity channel declared in sensor_size.
rec['p'] = False
# Presumably tonic's (width, height, n_polarities) order -- x spans the width.
sensor_size = (max_x + 1, max_y + 1, 1)
print(f"sensor size is {sensor_size}")

In [None]:
# The network consumes frames, not raw events, so the event stream is converted with
# tonic (https://tonic.readthedocs.io/en/latest/) plus torchvision:
#   1. bin the events into 20 ms (20000 us) time windows,
#   2. turn the resulting NumPy array into a torch tensor,
#   3. crop a 300 x 400 (rows x cols) region from the centre of every frame.
TIME_WINDOW_US = 20000
CROP_SIZE = (300, 400)
transforms = torchvision.transforms.Compose(
    [
        tonic.transforms.ToFrame(sensor_size=sensor_size, time_window=TIME_WINDOW_US),
        torch.tensor,
        torchvision.transforms.CenterCrop(CROP_SIZE),
    ]
)

In [None]:
# Bin the event stream into (cropped) frames.
frames = transforms(rec)
# Sanity check: the network's Conv2d takes a single input channel, so the frames must be
# 4-D with exactly one polarity channel, i.e. (time bins, 1, height, width).
assert frames.ndim == 4 and frames.shape[1] == 1, f"unexpected frame shape {tuple(frames.shape)}"

In [None]:
# Shape of the cropped frame tensor, presumably (time bins, polarity, height, width).
# The number of time bins depends on the recording's duration and the 20 ms bin size
# (the original run produced roughly 337).
frames.size()

In [None]:
# Inspect a single binned event frame (one time step, the only polarity channel).
frame_idx = 10
fig, ax = plt.subplots()
ax.imshow(frames[frame_idx, 0])
ax.set(title=f"Event frame at time step {frame_idx}", xlabel="x (px)", ylabel="y (px)");

In [None]:
# Now feed the data to the network. To keep memory use low, only a window of
# 10 consecutive time steps (100-109) is processed.
start_step, n_steps = 100, 10
# sinabs spiking layers keep their membrane state between forward calls. Reset it so
# re-running this cell starts from rest instead of continuing from the previous run.
sinabs.reset_states(net)
with torch.no_grad():
    output = net(frames[start_step:start_step + n_steps].float())

In [None]:
# Shape of the network output for the fed time steps: presumably
# (time steps, n_filters, H', W'), where H' and W' shrink by (kernel size - 1)
# because the convolution uses no padding.
output.size()

In [None]:
# Finally, plot the network's spike output for the first fed time step (output[0]),
# one panel per output channel (row-major on a 2x4 grid).
# Assumes output is (time, channel, H', W'), where channel i is the response of
# filter i (i.e. the 45*i-degree filter).
fig, axes = plt.subplots(2, 4, figsize=(10, 5))
for i, (ax, spikes) in enumerate(zip(axes.flat, output[0])):
    ax.imshow(spikes)
    ax.set_title(f"channel {i} ({45 * i}°)")
fig.tight_layout()