<a href="https://colab.research.google.com/github/MazedaZ/SNN-for-MNIST-digits-PyTorch/blob/main/SNN_for_MNIST_digit_classification_using_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

Preprocessing MNIST Data

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize the
# single channel using the values (0.5,) and (0.5,) passed to Normalize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Keyword arguments shared by the training and test splits (same storage
# root, download enabled, same preprocessing).
mnist_kwargs = dict(root='./data', download=True, transform=transform)

# Training split, iterated in shuffled mini-batches of 64.
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=64)

# Test split, iterated in its stored order with the same batch size.
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=64)

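Creating a simple Spiking Neural Network (SNN) architecture. Note: as written, the `SNN` class below is a conventional two-layer fully connected network with a ReLU activation; it does not yet model spiking-neuron dynamics.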

In [None]:
class SNN(nn.Module):
    """Two-layer fully connected classifier operating on flattened inputs.

    Despite the class name, the forward pass uses a standard ReLU activation;
    no spiking-neuron dynamics are implemented in this class.

    Args:
        input_size: Number of features per sample after flattening.
            Defaults to 28 * 28.
        hidden_size: Width of the hidden layer. Defaults to 128.
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, input_size=28 * 28, hidden_size=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)   # Input -> hidden layer
        self.fc2 = nn.Linear(hidden_size, num_classes)  # Hidden -> output layer

    def forward(self, x):
        """Flatten ``x`` and return the raw output-layer scores.

        Args:
            x: Tensor whose total element count is a multiple of the first
                layer's input size; it is flattened to ``(-1, input_size)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding the unnormalized
            outputs of the final linear layer (no softmax is applied here).
        """
        # reshape (rather than view) so non-contiguous inputs, e.g. transposed
        # batches, are accepted; the width is read from fc1 so it always
        # matches the layer that consumes it.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))  # Hidden layer with ReLU activation
        return self.fc2(x)           # Output layer

Training Loop

In [None]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation is repeatable across "Restart & Run All".
# NOTE(review): the shuffled trainloader presumably draws from the same
# global RNG, which would make batch order repeatable too -- confirm.
SEED = 42
torch.manual_seed(SEED)

# Model, loss function and optimizer (SGD with momentum).
snn = SNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(snn.parameters(), lr=0.01, momentum=0.9)

epochs = 10

# Make training mode explicit so the loop stays correct if this cell is
# re-run after the model has been switched to eval mode.
snn.train()

for epoch in range(epochs):
    running_loss = 0.0  # Sum of per-batch losses for this epoch
    for inputs, labels in trainloader:
        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = snn(inputs)

        # Compute the loss and back-propagate
        loss = criterion(outputs, labels)
        loss.backward()

        # Optimize the network's weights
        optimizer.step()

        running_loss += loss.item()

    # Report the mean of the per-batch losses over the epoch
    print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}")

print("Finished Training")

Epoch 1, Loss: 0.3727827922566168
Epoch 2, Loss: 0.18093010937489235
Epoch 3, Loss: 0.1340298765976387
Epoch 4, Loss: 0.1068474445075019
Epoch 5, Loss: 0.09108516346089351
Epoch 6, Loss: 0.07775542111511329
Epoch 7, Loss: 0.0694889890601926
Epoch 8, Loss: 0.06229219530999009
Epoch 9, Loss: 0.05525595104313894
Epoch 10, Loss: 0.04959232907053202
Finished Training


Testing

In [None]:
# Switch to evaluation mode before measuring accuracy; this matters for
# layers such as dropout or batch norm if any are added to the model later.
snn.eval()

correct = 0  # Number of test samples whose predicted class equals the label
total = 0    # Number of test samples seen

# Gradients are not needed for evaluation.
with torch.no_grad():
    for inputs, labels in testloader:
        outputs = snn(inputs)
        # Predicted class = index of the largest output score per sample.
        # The legacy `.data` attribute is unnecessary inside no_grad().
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Accuracy as a percentage of all test samples
accuracy = 100 * correct / total
print(f"Accuracy on the test dataset: {accuracy:.2f}%")


Accuracy on the test dataset: 97.65%
