In [1]:
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from snntorch import utils
from torch.utils.data import DataLoader

from snntorch import spikegen
import snntorch as snn

In [2]:
# Parameters
batch_size = 30
data_path = '/data/mnist'
num_classes = 10  # MNIST has 10 output classes
seed = 42  # fixed seed so shuffling, spike encoding and weight init are reproducible

dtype = torch.float

# Seed torch's global RNG before anything random happens (DataLoader shuffling,
# spike rate encoding below, and the network's weight initialisation later on).
torch.manual_seed(seed)

# Input preprocessing: force 28x28 single-channel images and convert them to
# float tensors in [0, 1]. Normalize(mean=0, std=1) leaves the values unchanged,
# which keeps them in [0, 1] as rate encoding (spikegen.rate) expects.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# The test loader is shuffled on purpose: the training loop draws a random test
# minibatch with next(iter(test_loader)) at every step.
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

num_steps = 100  # number of simulation time steps

# Grab one training minibatch to demonstrate spike encoding. A dedicated name is
# used so it is not confused with the `data` tensor of the training loop.
train_iter = iter(train_loader)
data_it, targets_it = next(train_iter)

# Rate coding: one of several ways to turn pixel intensities into spike trains.
# The result has an extra leading time dimension of length num_steps.
spike_data = spikegen.rate(data_it, num_steps=num_steps)

In [3]:
# Architecture: layer widths (flattened 28x28 image -> hidden -> one unit per class)
num_inputs, num_hidden, num_outputs = 28 * 28, 50, 10

beta = 0.95  # passed as `beta` to both snn.Leaky layers (membrane decay rate)

# Network
class Net(nn.Module):
    """Feed-forward spiking network: Linear -> Leaky -> Linear -> Leaky.

    Layer sizes come from the notebook-level ``num_inputs``, ``num_hidden`` and
    ``num_outputs`` constants, and both Leaky layers share ``beta``.

    Args:
        num_steps: number of simulation time steps per forward pass. ``None``
            (the default) falls back to the notebook-level ``num_steps`` global,
            read at call time, exactly as before this parameter existed.
    """

    def __init__(self, num_steps=None):
        super().__init__()
        self.num_steps = num_steps
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network on a static input for a fixed number of steps.

        Args:
            x: input batch fed unchanged at every time step; the training loop
                passes a flattened ``(batch, num_inputs)`` tensor.

        Returns:
            Tuple ``(spk2_rec, mem2_rec)`` of output-layer spikes and membrane
            potentials, each stacked along a new leading time dimension
            (``(steps, batch, num_outputs)`` for a ``(batch, num_inputs)`` input).
        """
        # `num_steps` here is the notebook-level global (no local of that name).
        steps = num_steps if self.num_steps is None else self.num_steps

        # Hidden states at time 0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # The input is identical at every time step, so the first layer's input
        # current is loop-invariant: compute it once instead of `steps` times.
        cur1 = self.fc1(x)

        # Record the output layer at every step
        spk2_rec = []
        mem2_rec = []

        for _ in range(steps):
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Use the GPU when CUDA is available, otherwise run on the CPU, and move the
# freshly built network there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = Net().to(device)

In [4]:
# Criterion applied to the output membrane potentials (once per time step in the
# training loop) and the Adam optimizer over all network parameters.
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=5e-4, betas=(0.9, 0.999))

In [5]:
num_epochs = 1
max_iters = 200  # total number of training minibatches across ALL epochs
loss_hist = []
test_loss_hist = []
counter = 0


def _summed_loss(mem_rec, targets):
    """Sum the criterion `loss` over every time step of a membrane recording.

    Args:
        mem_rec: membrane potentials stacked along a leading time dimension,
            as returned by `net`.
        targets: class labels for the batch.

    Returns:
        A shape-(1,) tensor holding the loss summed over time.
    """
    total = torch.zeros((1), dtype=dtype, device=device)
    for mem_step in mem_rec:  # iterates over the leading (time) dimension
        total += loss(mem_step, targets)
    return total


# Training loop
for epoch in range(num_epochs):
    # Minibatch loop
    for data, targets in train_loader:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass. Flatten using the size of the batch actually delivered,
        # so a final partial batch does not break .view().
        net.train()
        spk_rec, mem_rec = net(data.view(data.size(0), -1))

        # Loss summed over all time steps
        loss_val = _summed_loss(mem_rec, targets)

        # Compute gradients and update the weights
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Record the training loss
        loss_hist.append(loss_val.item())

        # Evaluate on one random test minibatch (test_loader shuffles)
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Forward pass on the test batch
            test_spk, test_mem = net(test_data.view(test_data.size(0), -1))

            # Test loss summed over all time steps
            test_loss = _summed_loss(test_mem, test_targets)
            test_loss_hist.append(test_loss.item())

            # Print the test loss periodically
            if counter % 50 == 0:
                print("Test loss: ", float(test_loss))

        counter += 1
        if counter >= max_iters:
            break

    # The break above only leaves the minibatch loop; stop the epoch loop too so
    # the iteration cap applies to the whole run, not just the first epoch.
    if counter >= max_iters:
        break

Test loss:  263.38800048828125
Test loss:  118.98670959472656
Test loss:  75.61820220947266
Test loss:  64.1594467163086
