In [1]:
# Main Setup

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

# dataloader arguments
batch_size = 128
data_path='/tmp/data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Define a transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Network Architecture
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10

# Temporal Dynamics
num_steps = 25
beta = 0.70

R = 1
C = 1.44

# Define Network
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Lapicque(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Lapicque(R=R, C=C)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_lapicque()
        mem2 = self.lif2.init_lapicque()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [3]:
# Testing Setup

test_net = Net()

test_net.load_state_dict(torch.load('net_generators/ref_snn_lapicque_0.pth'))

test_net.to(device)


print("Testing Loaded Network")


def acc_test(net, mnist_test, batch_size):
    # Test Accuracy

    total = 0
    correct = 0

    # drop_last switched to False to keep all samples
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

    with torch.no_grad():
        net.eval()
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.to(device)
            
            # forward pass
            test_spk, _ = net(data.view(data.size(0), -1))

            # calculate total accuracy
            _, predicted = test_spk.sum(dim=0).max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # print(f"Total correctly classified test set images: {correct}/{total}")
    # print(f"Test Set Accuracy: {100 * correct / total:.2f}%")
    return 100 * correct / total

def mod_fc1(net, neuron_index): # neuron_index 0 to 999
    # Print the initial weights
    # print("Initial weights:", net.fc1.weight)

    # Access the parameters
    parameters = net.parameters()

    # Assuming you want to modify the weights of the first neuron in the linear layer
    # neuron_index = 1    # 0 to 999
    for param in parameters:
        if param is net.fc1.weight:
            # Modify the weights of the first neuron
            new_weights = torch.randn(param.size(1))  # Example: Initialize new weights randomly
            param.data[neuron_index] = new_weights

    # Print the modified weights
    # print("Modified weights:", net.fc1.weight)
    # print("Size of Weights: ", net.fc1.weight.size())
    return net

def mod_fc2(net, neuron_index): # neuron_index 0 to 9
    # Print the initial weights
    # print("Initial weights:", net.fc2.weight)

    # Access the parameters
    parameters = net.parameters()

    # Assuming you want to modify the weights of the first neuron in the linear layer
    # neuron_index = 4     # 0 to 9
    for param in parameters:
        if param is net.fc2.weight:
            # Modify the weights of the first neuron
            new_weights = torch.randn(param.size(1))  # Example: Initialize new weights randomly
            param.data[neuron_index] = new_weights

    # Print the modified weights
    # print("Modified weights:", net.fc2.weight)
    # print("Size of Weights: ", net.fc2.weight.size())
    return net

Testing Loaded Network


  test_net.load_state_dict(torch.load('net_generators/ref_snn_lapicque_0.pth'))


In [4]:
# Hidden Layer Corruption

fc1_acc_loss_rec = []


test_net = Net()

test_net.load_state_dict(torch.load('net_generators/ref_snn_lapicque_0.pth'))

test_net.to(device)



acc_ref = acc_test(test_net, mnist_test, batch_size)
print(f"Base Accuracy: {(acc_ref):.2f}%")
for x in range(0,1000):
    acc = acc_test(mod_fc1(test_net, x), mnist_test, batch_size)
    print(f"Accuracy loss of Hidden Layer Parameter Corruption({x}): {(acc_ref-acc):.2f}%")
    fc1_acc_loss_rec.append(acc_ref-acc)


  test_net.load_state_dict(torch.load('net_generators/ref_snn_lapicque_0.pth'))


Base Accuracy: 91.25%
Accuracy loss of Hidden Layer Parameter Corruption(0): 0.00%
Accuracy loss of Hidden Layer Parameter Corruption(1): 0.10%
Accuracy loss of Hidden Layer Parameter Corruption(2): 0.12%
Accuracy loss of Hidden Layer Parameter Corruption(3): 0.07%
Accuracy loss of Hidden Layer Parameter Corruption(4): 0.08%
Accuracy loss of Hidden Layer Parameter Corruption(5): 0.08%
Accuracy loss of Hidden Layer Parameter Corruption(6): 0.04%
Accuracy loss of Hidden Layer Parameter Corruption(7): 0.08%
Accuracy loss of Hidden Layer Parameter Corruption(8): 0.08%
Accuracy loss of Hidden Layer Parameter Corruption(9): 0.05%
Accuracy loss of Hidden Layer Parameter Corruption(10): 0.07%
Accuracy loss of Hidden Layer Parameter Corruption(11): 0.14%
Accuracy loss of Hidden Layer Parameter Corruption(12): 0.13%
Accuracy loss of Hidden Layer Parameter Corruption(13): 0.12%
Accuracy loss of Hidden Layer Parameter Corruption(14): 0.10%
Accuracy loss of Hidden Layer Parameter Corruption(15): 0.