# Objective 1: Basic SNN for Binary Classification

Using snnTorch to train a spiking neural network (SNN) that distinguishes two MNIST digits (0 and 1).


In [13]:
# 1. Imports and setup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import snntorch as snn
from snntorch import surrogate


In [14]:
# 2. Hyperparameters
batch_size = 128    # samples per mini-batch
num_steps   = 25    # simulation time steps per forward pass
beta        = 0.95  # value passed as `beta` to every snn.Leaky layer
lr          = 5e-4  # learning rate handed to the optimizer
epochs      = 10    # number of passes over the training subset
seed        = 0     # RNG seed; change it to check sensitivity to initialisation

# Seed PyTorch's global RNG so weight initialisation and shuffle order repeat
# after Restart & Run All (some GPU kernels may still be non-deterministic).
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [15]:
# 3. Prepare binary MNIST (digits 0 and 1)
# Per-image preprocessing: resize to 28x28, convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Complete MNIST train/test splits, stored under the local 'data' directory.
full_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
full_test = datasets.MNIST(root='data', train=False, download=True, transform=transform)

def filter_digits(dataset, digits=(0, 1)):
    """Return a ``Subset`` of ``dataset`` restricted to samples labelled with ``digits``.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs. If it exposes
            a ``targets`` attribute (and no ``target_transform``), labels are read
            from there so that no image has to be loaded or transformed.
        digits: Iterable of labels to keep. Defaults to ``(0, 1)``.

    Returns:
        torch.utils.data.Subset viewing the matching samples in their original order.
    """
    wanted = tuple(digits)
    labels = getattr(dataset, "targets", None)
    if labels is not None and getattr(dataset, "target_transform", None) is None:
        # Fast path: the stored labels are exactly what __getitem__ would return,
        # so skip decoding and transforming every image just to read its label.
        labels = labels.tolist() if hasattr(labels, "tolist") else list(labels)
        idx = [i for i, t in enumerate(labels) if t in wanted]
    else:
        # Generic path: ask the dataset itself for each label.
        idx = [i for i, (_, t) in enumerate(dataset) if t in wanted]
    return Subset(dataset, idx)

# Keep only the default digit pair selected by filter_digits in each split.
train_data = filter_digits(full_train)
test_data = filter_digits(full_test)

# Training batches are reshuffled every epoch and the trailing partial batch is
# dropped; evaluation keeps dataset order and every sample.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)


In [16]:
# 4. Define Network (architecture unchanged)
num_inputs  = 28 * 28  # one input per pixel of the flattened 28x28 image
num_hidden  = 1000     # width of the hidden layer
num_outputs = 2        # binary task: one output neuron per class


class Net(nn.Module):
    """Two-layer fully connected spiking network: Linear -> Leaky -> Linear -> Leaky.

    The same static input is presented for ``num_steps`` simulation steps, and the
    output layer's spikes and membrane potentials are recorded at every step.
    """

    def __init__(self):
        super().__init__()
        self.fc1  = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2  = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        """Simulate the network for ``num_steps`` steps on a batch of inputs.

        Args:
            x: Batch of inputs; everything after the first (batch) dimension is
                flattened, so each sample must hold ``num_inputs`` values.

        Returns:
            ``(spk2_rec, mem2_rec)``: output spikes and output membrane potentials,
            each stacked along a new leading time dimension of length ``num_steps``.
        """
        x = x.view(x.size(0), -1)
        # Reset the neurons' stored membrane state before simulating a new batch.
        self.lif1.reset_mem()
        self.lif2.reset_mem()

        # x is identical at every step, so the first layer's input current is
        # loop-invariant: compute it once instead of num_steps times.
        cur1 = self.fc1(x)

        spk2_rec = []
        mem2_rec = []
        for _ in range(num_steps):
            spk1, _ = self.lif1(cur1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec), torch.stack(mem2_rec)


net = Net().to(device)


In [17]:
# 5. Loss, optimizer
# Classification loss used by the training and evaluation functions.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the network, using the configured learning rate.
optimizer = optim.Adam(params=net.parameters(), lr=lr)


In [20]:
# 6. Training and evaluation functions
def train_epoch():
    net.train()
    total_loss = total_correct = total_samples = 0
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        spk_rec, mem_rec = net(data)

        # compute loss summed over time
        loss = sum(criterion(mem_rec[t], targets) for t in range(num_steps))
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * targets.size(0)
        # rate decoding via spike counts: sum over time
        spikes = spk_rec.sum(dim=0)
        preds = spikes.argmax(dim=1)
        total_correct += (preds == targets).sum().item()
        total_samples += targets.size(0)
    return total_loss/total_samples, total_correct/total_samples


def test_epoch():
    net.eval()
    total_loss = total_correct = total_samples = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            spk_rec, mem_rec = net(data)

            loss = sum(criterion(mem_rec[t], targets) for t in range(num_steps))
            total_loss += loss.item() * targets.size(0)

            spikes = spk_rec.sum(dim=0)
            preds = spikes.argmax(dim=1)
            total_correct += (preds == targets).sum().item()
            total_samples += targets.size(0)
    return total_loss/total_samples, total_correct/total_samples


In [21]:
# 7. Run training
# One training pass followed by one evaluation pass per epoch; epochs are
# numbered from 1 so the log reads "Epoch 1 .. Epoch <epochs>".
for epoch in range(1, epochs+1):
    tr_loss, tr_acc = train_epoch()
    # Evaluate after the weight updates of this epoch have been applied.
    te_loss, te_acc = test_epoch()
    print(f"Epoch {epoch:2d} | Train loss {tr_loss:.3f}, acc {tr_acc:.3f} | Test loss {te_loss:.3f}, acc {te_acc:.3f}")


Epoch  1 | Train loss 2.437, acc 0.978 | Test loss 0.317, acc 1.000
Epoch  2 | Train loss 0.356, acc 0.998 | Test loss 0.246, acc 0.997
Epoch  3 | Train loss 0.273, acc 0.998 | Test loss 0.230, acc 0.998
Epoch  4 | Train loss 0.261, acc 0.999 | Test loss 0.223, acc 1.000
Epoch  5 | Train loss 0.183, acc 0.999 | Test loss 0.201, acc 0.999
Epoch  6 | Train loss 0.132, acc 0.999 | Test loss 0.231, acc 0.999
Epoch  7 | Train loss 0.171, acc 0.999 | Test loss 0.238, acc 0.999
Epoch  8 | Train loss 0.091, acc 1.000 | Test loss 0.182, acc 0.999
Epoch  9 | Train loss 0.121, acc 0.999 | Test loss 0.199, acc 0.997
Epoch 10 | Train loss 0.112, acc 0.999 | Test loss 0.155, acc 0.998


In [24]:
# 8. Final test accuracy
# Re-evaluate the network on the test loader and report accuracy as a percentage.
te_loss, te_acc = test_epoch()
print(f"\nFinal Test Set Accuracy: {100*te_acc:.2f}%")



Final Test Set Accuracy: 99.81%
