In [1]:
!pip install spikingjelly
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
import torch
import numpy as np
from spikingjelly.clock_driven import neuron, encoding, functional
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt 

### variables

In [2]:
# ---- Run configuration -----------------------------------------------------
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs end-to-end on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dataset_dir = './'        # root directory handed to the MNIST dataset
log_dir = './'            # directory handed to the TensorBoard SummaryWriter
model_output_dir = './'   # not referenced by the cells shown here; presumably for checkpoints

batch_size = 32           # samples per batch for both data loaders
lr = 1e-3                 # learning rate handed to the optimizer
T = 100                   # number of simulation timesteps run per batch
tau = 2.0                 # membrane time constant (tau_m) of the LIF neurons
train_epoch = 2           # number of passes over the training set

# Weight initialisation, data shuffling and the Poisson spike encoding are all
# random; fixing the seed makes "Restart & Run All" reproducible.
seed = 42
torch.manual_seed(seed)
# TensorBoard writer; the scalars logged during training are written under `log_dir`.
writer = SummaryWriter(log_dir=log_dir)

 
\begin{align}\begin{aligned}\tau_{m} \frac{\mathrm{d}V(t)}{\mathrm{d}t} = -(V(t) - V_{reset}) + X(t)\\\tau_{m} (V(t) - V(t-1)) = -(V(t-1) - V_{reset}) + X(t)\end{aligned}\end{align}
## Leaky Integrate-and-Fire (LIF) neuron dynamics
![lif](https://www.researchgate.net/publication/339574089/figure/fig1/AS:863726592872448@1582939875128/The-illustration-of-Leaky-Integrate-and-Fire-LIF-neuron-dynamics-The-pre-spikes-are.png)

### loading dataset

In [3]:
# MNIST train / test splits. Both are read from (and, if missing, downloaded
# into) `dataset_dir`, and every image is converted with ToTensor().
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        root=dataset_dir,
        train=is_train_split,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

# Training batches are shuffled and an incomplete final batch is dropped.
train_data_loader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
# Test batches keep dataset order and include the final, possibly smaller, batch.
test_data_loader = data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=False
)

### define model

In [4]:
# Two-layer fully connected spiking network:
# flatten -> 784x100 linear -> LIF -> 100x10 linear -> LIF (neither linear layer has a bias).
snn_layers = [
    nn.Flatten(),
    nn.Linear(28 * 28, 100, bias=False),
    neuron.LIFNode(tau=tau),
    nn.Linear(100, 10, bias=False),
    neuron.LIFNode(tau=tau),
]
net = nn.Sequential(*snn_layers).to(device)

# Adam over every trainable parameter of the network.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Encoder applied to the input images at every simulation timestep.
encoder = encoding.PoissonEncoder()

# Running counters updated by the training cell.
train_times = 0
max_test_accuracy = 0

### training

In [5]:
# Histories consumed by the learning-curve cell:
#   train_accs - accuracy of every training batch
#   test_accs  - test-set accuracy after every epoch
#   losses     - MSE loss of every training batch (detached 0-d tensors)
test_accs = []
train_accs = []
losses = []

# Clear any neuron state left over from a previous, possibly interrupted, run
# of this cell so the first batch does not start from stale state.
functional.reset_net(net)

for epoch in range(train_epoch):
    print("Epoch {}:".format(epoch))
    train_correct_sum = 0
    train_sum = 0

    # ---------------------------- training ----------------------------
    net.train()
    for img, label in tqdm(train_data_loader):
        img = img.to(device)
        label = label.to(device)
        label_one_hot = F.one_hot(label, 10).float()

        optimizer.zero_grad()

        # Run the network for T timesteps, re-encoding the image at each step
        # and summing the output spikes per class.
        out_spikes_counter = 0.
        for _ in range(T):
            out_spikes_counter += net(encoder(img).float())

        # Firing rate of each output neuron over the T timesteps.
        out_spikes_counter_frequency = out_spikes_counter / T

        # MSE between the output firing rates and the one-hot target.
        loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)

        # Backward propagation and parameter update.
        loss.backward()
        optimizer.step()

        # The neurons are stateful: reset them before the next batch.
        functional.reset_net(net)

        # Store a detached copy so the history does not keep each batch's
        # autograd graph alive for the rest of the run.
        losses.append(loss.detach())

        # Metrics: the predicted class is the output neuron with the highest rate.
        # `label` is already on `device`, so no further transfer is needed.
        batch_hits = (out_spikes_counter_frequency.max(1)[1] == label).float()
        train_correct_sum += batch_hits.sum().item()
        train_sum += label.numel()

        train_batch_accuracy = batch_hits.mean().item()
        writer.add_scalar('train_batch_accuracy', train_batch_accuracy, train_times)
        train_accs.append(train_batch_accuracy)

        train_times += 1

    # Accuracy over all training samples seen in this epoch.
    train_accuracy = train_correct_sum / train_sum

    # --------------------------- evaluation ---------------------------
    net.eval()
    with torch.no_grad():
        test_correct_sum = 0
        test_sum = 0
        for img, label in tqdm(test_data_loader):
            img = img.to(device)
            label = label.to(device)

            out_spikes_counter = 0.
            for _ in range(T):
                out_spikes_counter += net(encoder(img).float())

            # The raw spike count is used directly here; dividing by T would
            # not change which class has the maximum.
            test_correct_sum += (out_spikes_counter.max(1)[1] == label).float().sum().item()
            test_sum += label.numel()
            functional.reset_net(net)

        test_accuracy = test_correct_sum / test_sum
        writer.add_scalar('test_accuracy', test_accuracy, epoch)
        test_accs.append(test_accuracy)
        max_test_accuracy = max(max_test_accuracy, test_accuracy)

    print("train_acc = {}, test_acc={}, max_test_acc={}, train_times={}".format( train_accuracy, test_accuracy, max_test_accuracy, train_times))

### learning curves

In [6]:
# Training accuracy: one point per training batch.
plt.figure()
plt.plot(train_accs, "r")
plt.xlabel("training batch")
plt.ylabel("accuracy")
plt.title('train accuracy')

# Test accuracy: one point per epoch.
plt.figure()
plt.plot(test_accs)
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.title("test accuracy")

# Training loss: one point per training batch. float() accepts both 0-d
# tensors (on any device, with or without grad) and plain Python numbers.
plt.figure()
plt.plot([float(batch_loss) for batch_loss in losses])
plt.xlabel("training batch")
plt.ylabel("MSE loss")
plt.title("train losses");