In [2]:
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

from spikenet.network import Network
from spikenet.image_to_spike_convertor import ImageToSpikeConvertor, SpikePlotter
from IPython.display import HTML

# Shared spikenet plotting helper; its `animate` method is used below to render spike trains.
splt = SpikePlotter()

In [None]:
# from matplotlib.animation import FuncAnimation


# class SpikePlotter2(SpikePlotter):
#     def animate(self, data: torch.Tensor, vmin: float = 0, vmax: float = 1, cmap: str = "gray", title: str | None = None):
#         fig, ax = plt.subplots()
#         if (isinstance(data, torch.Tensor)):
#             data = data.detach().numpy()
#         t, w, h = data.shape
#         img = ax.imshow(np.ones((w, h)), cmap=cmap, vmin=vmin, vmax=vmax)
#         if title:
#             ax.set_title(title)

#         def init():
#             return img,

#         def update(frame):
#             img.set_data(data[frame])
#             return img,

#         ani = FuncAnimation(
#             fig, update, frames=np.arange(0, t), init_func=init, blit=False
#         )
#         res = ani.to_jshtml()
#         plt.close()
#         return res

#     def plot_spikes_history_1d(self, spikes, title: str | None = None):
#         plt.figure()
#         plt.plot(spikes)
#         plt.title(title or "1D Spikes")
#         plt.show()


# splt = SpikePlotter2()

In [3]:
class MyData(ImageToSpikeConvertor):
    """MNIST train/test splits wrapped by ``ImageToSpikeConvertor``.

    ``x_transform`` flattens every 28x28 frame produced by the base conversion
    into a 784-long vector, giving tensors shaped ``(-1, time_scale, 784)``.
    """

    def __init__(self):
        # Both splits are stored under ./data; only the training split is
        # constructed with download=True.
        to_tensor = transforms.ToTensor()
        train_split = torchvision.datasets.MNIST(
            root="./data", train=True, transform=to_tensor, download=True
        )
        test_split = torchvision.datasets.MNIST(
            root="./data", train=False, transform=to_tensor
        )
        super().__init__(train_data=train_split, test_data=test_split)

    def x_transform(self, x):
        """Run the base-class conversion, then flatten each frame to 28 * 28 pixels."""
        spikes = super().x_transform(x)
        return spikes.reshape(-1, self.time_scale, 28 * 28)


# Build the MNIST spike dataset; describe() prints the x / y shapes and the batch size.
data = MyData()
data.describe()

shape:
  - x: torch.Size([16, 784])
  - y: torch.Size([])
  - batch_size: 128


In [None]:
def create_net(input_shape, output_shape, hidden_dim: int = 100, output_dim: int = 10):
    """Build a two-layer spiking network: a dense hidden layer and a dense readout.

    Args:
        input_shape: per-sample input shape; its last entry is used as the
            feature count (784 for the flattened MNIST frames built by MyData).
        output_shape: currently unused. The label shape reported by
            ``data.shape`` is a scalar, so the class count cannot be derived
            from it; the argument is kept so existing calls keep working.
        hidden_dim: number of neurons in the hidden layer "l1".
        output_dim: number of neurons in the readout layer "l2" (one per class).

    Returns:
        The network produced by the ``Network().add_layer(...)`` chain, not yet trained.
    """
    # Local import keeps this cell self-contained when run on its own.
    from spikenet.layers.spiking_dense import SpikingDenseLayer

    net = (
        Network()
        .add_layer(
            SpikingDenseLayer,
            name="l1",
            input_dim=input_shape[-1],  # features are the last axis of a sample
            output_dim=hidden_dim,
        )
        .add_layer(
            SpikingDenseLayer,
            name="l2",
            input_dim=hidden_dim,
            output_dim=output_dim,
            # Readout collapses the time axis using spikenet's "SpikeRate" reduction.
            time_reduction="SpikeRate",
        )
    )
    return net

In [None]:
# Seed torch's global RNG so initialisation and training are reproducible on
# Restart & Run All (assumes spikenet draws its randomness from torch — TODO confirm).
SEED = 0
EPOCHS = 50
torch.manual_seed(SEED)

input_shape, output_shape = data.shape

net = create_net(input_shape, output_shape)

net = net.fit(data, epochs=EPOCHS)

In [None]:
# Draw one sample and animate its spike train frame by frame.
test_data = data.sample()
data_x, data_y = test_data

# -1 lets the frame count follow data.time_scale instead of a hard-coded 16;
# int() turns the scalar label tensor into a plain number for the title.
res = splt.animate(data_x.reshape(-1, 28, 28), title=f"y= {int(data_y)}")
HTML(res)

In [None]:
# Feed the single sample through the network with a leading batch axis of size 1,
# so every layer has a "mem" history to inspect in the next cells.
net.forward(data_x.unsqueeze(0))

In [None]:
# Hidden layer "l1": heat map of all neurons' membrane potential, then the
# traces of the first few neurons.
mem = net.layers[0].get_history("mem")
n_neurons = mem.shape[-1]  # 100 for the default network from create_net

# Rows are the flattened leading axes (batch of 1 here, so rows are time steps).
plt.imshow(mem.reshape(-1, n_neurons), cmap="gray", aspect="auto")
plt.title("Layer l1 membrane potential")
plt.xlabel("Neurons")
plt.ylabel("Time")
plt.colorbar()
plt.show()

for i in range(min(10, n_neurons)):
    plt.plot(mem[0, :, i], label=f"Neuron {i}")
plt.title("Layer l1 membrane potential (first neurons)")
plt.xlabel("Time")
plt.ylabel("Membrane Potential")
plt.legend()
plt.show()

In [None]:
# Readout layer "l2": membrane-potential trace of every output neuron over time.
mem = net.layers[1].get_history("mem")
for i in range(mem.shape[-1]):
    plt.plot(mem[0, :, i], label=f"Neuron {i}")
plt.title("Layer l2 membrane potential")
# This is a line plot of potential against time, so label the axes accordingly.
plt.xlabel("Time")
plt.ylabel("Membrane Potential")
plt.legend()
plt.show()

In [None]:
# Heat map of the first layer's weight matrix
# (assumes w is laid out as (input, neurons) — TODO confirm against SpikingDenseLayer).
fig, ax = plt.subplots()
ax.imshow(net.layers[0].w.detach().numpy())
ax.set_xlabel("Neurons")
ax.set_ylabel("Input")
plt.show()

In [None]:
# NOTE(review): this cell repeats the imports and SpikePlotter setup from the
# first cell. The cells that follow are a separate experiment on random spike
# input; consider moving them to their own notebook or removing this duplicate.
import torch
import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

from spikenet.network import Network
from spikenet.image_to_spike_convertor import ImageToSpikeConvertor, SpikePlotter

splt = SpikePlotter()

In [None]:
# import numpy as np
# import torch
# from spikenet.layers.spiking_base import SpikingNeuron
# from spikenet.tools.configs import EPSILON

# from spikenet.layers.spiking_dense import SpikingDenseLayer


# class SpikingReadoutLayer(SpikingNeuron):
#     def __init__(self, **kwargs) -> None:
#         super().__init__(**kwargs)

#         self.__time_reduction = kwargs.get("time_reduction", "max")
#         assert self.__time_reduction in ["max", "time"]

#     def _max_reduction(self, x: torch.Tensor) -> torch.Tensor:
#         res = torch.max(x, 1)[0]
#         return res

#     def _time_reduction(self, x: torch.Tensor) -> torch.Tensor:
#         max_data = torch.max(x, 1)[1]
#         return max_data.min(1)[1]

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         match self.__time_reduction:
#             case "max":
#                 return self._max_reduction(x)
#             case "time":
#                 return self._time_reduction(x)
#             case _:
#                 raise ValueError("Invalid time reduction method")

#     def initialize_parameters(self):
#         pass

#     def clamp(self):
#         pass

In [None]:
# Create a random spike raster shaped (batch_size, t, input_dim): each entry is 1
# with probability 1 - spike_threshold (about 10% here), otherwise 0.
t = 100
batch_size = 24
input_dim = 10
spike_threshold = 0.9

# Dedicated, seeded generator so the raster is reproducible without touching
# torch's global RNG state.
spike_rng = torch.Generator().manual_seed(0)
input_tensor = torch.rand((batch_size, t, input_dim), generator=spike_rng)
input_tensor = (input_tensor > spike_threshold).float()

print("shape:", input_tensor.shape)

# First sample, transposed so channels are rows and time runs along the x-axis.
plt.imshow(input_tensor[0].T, cmap="gray", aspect="auto")
plt.title("Input Spikes")
plt.xlabel("Time")
plt.ylabel("Input channel")
plt.show()

In [None]:
# Time-to-first-spike readout prototype: for every sample, find each channel's
# first spike time, then one-hot encode the channel that fires earliest.
spike_peak, max_data = torch.max(input_tensor, 1)  # reduce over the time axis
# For an all-zero (never spiking) channel torch.max reports index 0, which would
# make silent channels look like the earliest spikers. Push them past the last step.
max_data = max_data.masked_fill(spike_peak <= 0, input_tensor.shape[1])
# NOTE(review): a sample with no spikes at all still selects channel 0 here.
max_min_data = max_data.min(1)[1]
torch.nn.functional.one_hot(max_min_data, num_classes=input_tensor.shape[2]).to(torch.float32)

In [None]:
from spikenet.layers.spiking_dense import SpikingDenseLayer

# Layer sizes for the toy experiment; the input size follows the raster built above.
hidden_dim = 2
readout_dim = 2

layer_1 = SpikingDenseLayer(name="dense_layer 1", input_dim=input_dim, output_dim=hidden_dim)
readout = SpikingDenseLayer(
    name="readout_layer", input_dim=hidden_dim, output_dim=readout_dim, time_reduction="time"
)

layer_1.initialize_parameters()
readout.initialize_parameters()

# Forward the random spike raster through both layers so their "mem" histories are populated.
layer_1_output = layer_1.forward(input_tensor)
pred = readout.forward(layer_1_output)

In [None]:
# For each layer, plot every neuron's membrane potential (first sample) and mark
# the time steps where that neuron's output is non-zero.
for ly, output in zip([layer_1, readout], [layer_1_output, pred]):
    mem = ly.get_history("mem")
    # Loop-invariant: convert the outputs once, not once per neuron.
    # Assumes output is indexed (batch, time, neuron) like mem — TODO confirm for the readout.
    spikes = output.detach().numpy()
    # Markers are drawn at spike_value * peak potential so they sit on the traces.
    marker_height = float(mem.max())

    for i in range(mem.shape[2]):
        plt.plot(mem[0, :, i], f"C{i}", label=f"Neuron {i}")
        spike_times = np.flatnonzero(spikes[0, :, i] > 0)
        if spike_times.size > 0:
            plt.plot(
                spike_times,
                spikes[0, spike_times, i] * marker_height,
                f"C{i}x",
                label=f"Neuron {i} Output",
            )

    # NOTE(review): pred.max() is the largest readout value, not an index, and it is
    # shown for every layer — confirm this is the intended "winner".
    plt.title(f"Layer {ly.name} - winner: {pred.max()}")
    plt.xlabel("Time")
    plt.ylabel("Membrane Potential")
    plt.legend()
    plt.show()