# Standard Pytorch ANN training

## Mnist set up

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

In [12]:
len(train_dataset), len(test_dataset)

(60000, 10000)

In [4]:
# Create data loaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

## FCN

In [5]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 1200, bias=False)
        self.fc2 = nn.Linear(1200, 1200, bias=False)
        self.fc3 = nn.Linear(1200, 10, bias=False)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

In [6]:
# Train multiple models and select the best one
num_trials = 5
best_accuracy = 0
best_mlp_model = None

for trial in range(num_trials):
    print(f"Trial {trial+1}")
    
    # Create a new model for each trial
    mlp_model = MLP().to(device)
    
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(mlp_model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
    
    # Train the model
    num_epochs = 50
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = mlp_model(images)
            loss = criterion(outputs, labels)
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
    
    # Evaluate the trained model on the test set
    mlp_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = mlp_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        accuracy = 100 * correct / total
        print(f"Test Accuracy: {accuracy:.2f}%")
        
        # Update the best model if the current model has higher accuracy
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_mlp_model = model

print(f"Best Test Accuracy: {best_accuracy:.2f}%")

Trial 1
Epoch [1/50], Loss: 0.1215
Epoch [2/50], Loss: 0.0826
Epoch [3/50], Loss: 0.2695
Epoch [4/50], Loss: 0.0750
Epoch [5/50], Loss: 0.0142
Epoch [6/50], Loss: 0.0726
Epoch [7/50], Loss: 0.0458
Epoch [8/50], Loss: 0.0499
Epoch [9/50], Loss: 0.0669
Epoch [10/50], Loss: 0.0283
Epoch [11/50], Loss: 0.0742
Epoch [12/50], Loss: 0.0604
Epoch [13/50], Loss: 0.0426
Epoch [14/50], Loss: 0.0229
Epoch [15/50], Loss: 0.1541
Epoch [16/50], Loss: 0.0615
Epoch [17/50], Loss: 0.0132
Epoch [18/50], Loss: 0.0034
Epoch [19/50], Loss: 0.0598
Epoch [20/50], Loss: 0.0438
Epoch [21/50], Loss: 0.0381
Epoch [22/50], Loss: 0.0393
Epoch [23/50], Loss: 0.0023
Epoch [24/50], Loss: 0.1561
Epoch [25/50], Loss: 0.0583
Epoch [26/50], Loss: 0.0105
Epoch [27/50], Loss: 0.0493
Epoch [28/50], Loss: 0.0100
Epoch [29/50], Loss: 0.1189
Epoch [30/50], Loss: 0.1174
Epoch [31/50], Loss: 0.0256
Epoch [32/50], Loss: 0.0031
Epoch [33/50], Loss: 0.0347
Epoch [34/50], Loss: 0.0467
Epoch [35/50], Loss: 0.0374
Epoch [36/50], Loss: 

## ConvNet

In [15]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 12, kernel_size=5, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(12, 64, kernel_size=5, bias=False)
        self.fc = nn.Linear(64 * 4 * 4, 10, bias=False)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = self.dropout(x)
        x = self.fc(x)
        return x

In [16]:
# Create the model and move it to the GPU
conv_model = ConvNet().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(conv_model.parameters())

# Train the model
num_epochs = 50
for epoch in range(num_epochs):
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = conv_model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Evaluate the trained model on the test set
conv_model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")

Epoch [1/50], Loss: 0.1256
Epoch [2/50], Loss: 0.0690
Epoch [3/50], Loss: 0.0619
Epoch [4/50], Loss: 0.1361
Epoch [5/50], Loss: 0.0885
Epoch [6/50], Loss: 0.0968
Epoch [7/50], Loss: 0.0608
Epoch [8/50], Loss: 0.0809
Epoch [9/50], Loss: 0.0319
Epoch [10/50], Loss: 0.0230
Epoch [11/50], Loss: 0.0769
Epoch [12/50], Loss: 0.0083
Epoch [13/50], Loss: 0.1382
Epoch [14/50], Loss: 0.0138
Epoch [15/50], Loss: 0.0357
Epoch [16/50], Loss: 0.0499
Epoch [17/50], Loss: 0.0068
Epoch [18/50], Loss: 0.0764
Epoch [19/50], Loss: 0.0113
Epoch [20/50], Loss: 0.0073
Epoch [21/50], Loss: 0.0108
Epoch [22/50], Loss: 0.0132
Epoch [23/50], Loss: 0.0152
Epoch [24/50], Loss: 0.0410
Epoch [25/50], Loss: 0.0641
Epoch [26/50], Loss: 0.0208
Epoch [27/50], Loss: 0.0031
Epoch [28/50], Loss: 0.0027
Epoch [29/50], Loss: 0.0096
Epoch [30/50], Loss: 0.0340
Epoch [31/50], Loss: 0.0026
Epoch [32/50], Loss: 0.0303
Epoch [33/50], Loss: 0.0109
Epoch [34/50], Loss: 0.0546
Epoch [35/50], Loss: 0.0706
Epoch [36/50], Loss: 0.0226
E

# ANN to SNN conversion

In [17]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

In [18]:
lif1 = snn.Leaky(beta=0.5).to(device)
lif2 = snn.Leaky(beta=0.5).to(device)
lif3 = snn.Leaky(beta=0.5).to(device)

In [20]:
fc1_weights = best_model.fc1.weight.data
fc2_weights = best_model.fc2.weight.data
fc3_weights = best_model.fc3.weight.data

snn_fc1 = nn.Conv2d(1, 1200, kernel_size=(784, 1), bias=False).to(device)
snn_fc1.weight.data = fc1_weights.view(1200, 784, 1, 1).to(device)

snn_fc2 = nn.Conv2d(1200, 1200, kernel_size=(1, 1), bias=False).to(device)
snn_fc2.weight.data = fc2_weights.view(1200, 1200, 1, 1).to(device)

snn_fc3 = nn.Conv2d(1200, 10, kernel_size=(1, 1), bias=False).to(device)
snn_fc3.weight.data = fc3_weights.view(10, 1200, 1, 1).to(device)

In [21]:
def snn_forward(x, num_steps):
    # Reshape the input tensor to have 784 channels
    x = x.view(x.size(0), -1, 1, 1)
    
    mem1 = lif1.init_leaky().to(device)
    mem2 = lif2.init_leaky().to(device)
    mem3 = lif3.init_leaky().to(device)
    
    spk1_rec = []
    spk2_rec = []
    spk3_rec = []

    for step in range(num_steps):
        cur1 = F.relu(snn_fc1(x))
        spk1, mem1 = lif1(cur1, mem1)
        spk1_rec.append(spk1)
        
        cur2 = F.relu(snn_fc2(spk1))
        spk2, mem2 = lif2(cur2, mem2)
        spk2_rec.append(spk2)
        
        cur3 = snn_fc3(spk2)
        spk3, mem3 = lif3(cur3, mem3)
        spk3_rec.append(spk3)
    
    return torch.stack(spk3_rec, dim=0).mean(dim=0)

In [48]:
num_steps = 100  # Number of time steps for spiking dynamics

# Set the model to evaluation mode
best_model.eval()

# Disable gradient computation
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        labels = labels.to(device)
        images = images.to(device)
        # Reshape input images
        images = images.view(images.size(0), -1, 1, 1)
        
        # Forward pass through the spiking network
        outputs = snn_forward(images, num_steps)
        
        # Get the predicted class
        _, predicted = torch.max(outputs.data, 1)
        
        # Update total and correct predictions
        total += labels.size(0)
        correct += (predicted == labels.reshape(100,1,1)).sum().item()
    
    # Print the accuracy
    accuracy = 100 * correct / total
    print(f"Accuracy on the test set: {accuracy:.2f}%")

Accuracy on the test set: 97.79%
