In [2]:
!pip install snntorch


Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import spikeplot as splt
from snntorch import utils

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns


In [4]:
def hard_reset(net):
    for m in net.modules():
        if hasattr(m, "mem"):
            m.mem = m.mem.detach() * 0
        if hasattr(m, "spk"):
            m.spk = m.spk.detach() * 0


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Vom rula pe:", device)

batch_size   = 128
num_steps    = 100
num_epochs   = 5
learning_rate = 1e-3


Vom rula pe: cpu


In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.,), (1.,))
])


train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True,  drop_last=True)

test_loader  = DataLoader(test_dataset,  batch_size=batch_size,
                          shuffle=False, drop_last=True)

print(f"Dimensiune train: {len(train_dataset)} imagini")
print(f"Dimensiune test : {len(test_dataset)} imagini")


100%|██████████| 9.91M/9.91M [00:00<00:00, 12.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 359kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.21MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.62MB/s]


Dimensiune train: 60000 imagini
Dimensiune test : 10000 imagini


In [7]:
data, targets = next(iter(train_loader))

data    = data.to(device)
targets = targets.to(device)

spike_data = snn.spikegen.rate(data, num_steps=num_steps)

print("Forma trenului de spike‑uri:", spike_data.shape)


Forma trenului de spike‑uri: torch.Size([100, 128, 1, 28, 28])


In [10]:
class ConvSNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 12, kernel_size=5, padding=2)
        self.pool  = nn.MaxPool2d(2)


        self.lif1  = snn.Leaky(
            beta=0.95,
            threshold=0.5,
            spike_grad=surrogate.fast_sigmoid())

        self.fc    = nn.Linear(12*14*14, 10, bias=False)
        self.lif2  = snn.Leaky(
            beta=0.95,
            threshold=0.5,
            spike_grad=surrogate.fast_sigmoid(),
            output=True)

    def forward(self, x_seq):
        spk_rec = []
        for t in range(x_seq.size(0)):
            x   = F.relu(self.conv1(x_seq[t]))
            x   = self.pool(x)
            _, s1 = self.lif1(x)
            s1f   = s1.flatten(1)
            _, s2 = self.lif2(self.fc(s1f))
            spk_rec.append(s2)
        return torch.stack(spk_rec)

In [11]:
net = ConvSNN().to(device)
loss_fn  = nn.CrossEntropyLoss()
optim    = torch.optim.Adam(net.parameters(), lr=learning_rate)


In [12]:
def train_one_epoch(ep):
    net.train(); hard_reset(net); run_loss = 0
    for data, tgt in train_loader:
        data, tgt = data.to(device), tgt.to(device)
        spikes    = snn.spikegen.rate(data, num_steps=num_steps)

        optim.zero_grad()
        out  = net(spikes)
        loss = loss_fn(out.sum(0), tgt)
        loss.backward(); optim.step()

        run_loss += loss.item(); hard_reset(net)

    avg_loss = run_loss / len(train_loader)
    all_losses.append(avg_loss)
    print(f"Epoca {ep+1} | loss mediu: {run_loss/len(train_loader):.4f}")

In [13]:
def evaluate():
    net.eval()
    hard_reset(net)

    correct = total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            spike_data    = snn.spikegen.rate(data, num_steps=num_steps)

            spk_out = net(spike_data)
            out_sum = spk_out.sum(dim=0)

            preds   = out_sum.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total   += targets.size(0)
            hard_reset(net)
    acc = 100.0 * correct / total
    all_accuracies.append(acc)
    print(f"Acuratete test: {acc:.2f}%")
    return acc


In [14]:
all_losses = []
all_accuracies = []

In [15]:
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate()


Epoca 1 | loss mediu: 41.1751
Acurateţe test: 91.03%


KeyboardInterrupt: 

In [None]:
torch.save(net.state_dict(), "snn_mnist_model.pt")
print("Model salvat ca: snn_mnist_model.pt")

In [None]:
net.load_state_dict(torch.load("snn_mnist_model.pt"))
net.to(device)
net.eval()


In [None]:
# Evolutia loss si acuratetei in timp

plt.figure(figsize=(6, 3))
plt.plot(all_losses, label="Loss", marker='o')
plt.plot(all_accuracies, label="Accuracy (%)", marker='x')
plt.title("Evolutia loss si acuratetei in timp")
plt.xlabel("Epoca")
plt.ylabel("Valoare")
plt.legend()
plt.grid(True)
plt.savefig("evolutie_loss_accuracy.png", dpi=300)
plt.show()



In [None]:
# Predictii corecte

images, labels = next(iter(test_loader))

for idx in range(10):
    img = images[idx].unsqueeze(0).to(device)
    label = labels[idx].item()

    with torch.no_grad():
        spikes = snn.spikegen.rate(img, num_steps=num_steps)
        out    = net(spikes)
        summed = out.sum(dim=0)
        predicted = summed.argmax(dim=1).item()

    plt.imshow(img.cpu().squeeze(), cmap='gray')
    plt.title(f"Real: {label} | Prezis: {predicted}", fontsize=14)
    plt.axis(False)
    plt.show()


In [None]:
# Matricea de confuzie

all_preds = []
all_labels = []

net.eval()
hard_reset(net)

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        spike_data = snn.spikegen.rate(data, num_steps=num_steps)

        spk_out = net(spike_data)
        out_sum = spk_out.sum(dim=0)
        preds = out_sum.argmax(dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

        hard_reset(net)


cm = confusion_matrix(all_labels, all_preds)


plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
            xticklabels=range(10), yticklabels=range(10))
plt.xlabel("Etichete prezise")
plt.ylabel("Etichete reale")
plt.title("Matricea de confuzie - clasificarea cifrelor MNIST")
plt.tight_layout()
plt.savefig("matrice_confuzie.png", dpi=300)
plt.show()


In [None]:
# Predictii gresite

wrong_preds = []

net.eval()
with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        spikes = snn.spikegen.rate(data, num_steps=num_steps)
        outputs = net(spikes)
        summed = outputs.sum(dim=0)
        preds = summed.argmax(dim=1)

        for i in range(len(preds)):
            if preds[i] != targets[i]:
                wrong_preds.append((data[i].cpu(), targets[i].item(), preds[i].item()))

for i in range(min(5, len(wrong_preds))):
    img, label, predicted = wrong_preds[i]
    plt.imshow(img.squeeze(), cmap='gray')
    plt.title(f"Eticheta reala: {label} | Prezis: {predicted}")
    plt.axis('off')
    plt.show()
