# Model Inversion (MI) Attack Evaluation – MNIST – SNN – BrainLeaks

# Necessary Imports

In [1]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.7.0-py2.py3-none-any.whl (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting nir (from snntorch)
  Downloading nir-1.0.1-py3-none-any.whl (76 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.2/76.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-1.0.1 nirtorch-1.0 snntorch-0.7.0


In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import itertools
import snntorch as snn
from snntorch import surrogate
from snntorch import utils
from snntorch import spikegen
import snntorch.spikeplot as splt
from IPython.display import HTML

# Data Preparation

In [19]:
# Network Architecture

num_inputs = 28 * 28   # pixels of one (resized) 28x28 image
num_outputs = 10       # number of classes = size of the output layer
batch_size = 100


# Temporal Dynamics
num_steps = 25      # Number of Time-Steps for Encoding the Static Input
beta = 0.7          # Leakage (Decay) Factor of LIF Neurons
spike_grad = surrogate.fast_sigmoid(slope=40)  # surrogate gradient used for backprop through spikes

# Other
# Relative directory so downloading MNIST does not require write access to the filesystem root.
data_path = './data/mnist'
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Reproducibility: the shuffled DataLoaders, the spike rate encoding and torch.bernoulli
# all draw from torch's global RNG, so seed it once before anything random happens.
SEED = 42
torch.manual_seed(SEED)


In [20]:
# Image preprocessing pipeline: a single 28x28 channel converted to a float tensor.
# Normalize((0,), (1,)) subtracts 0 and divides by 1, i.e. it leaves ToTensor's [0, 1]
# range untouched -- the range the spike rate encoder interprets as firing probabilities.
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Train split first, then test split (both downloaded on first use).
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Both loaders shuffle and drop the last incomplete batch. The training loop relies on
# the shuffled test loader to draw a different test batch from every fresh iterator.
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Convolutional Spiking Neural Network Model (Evaluation Classifier)

In [11]:
def _lif(**extra):
    """Leaky integrate-and-fire layer sharing the notebook-wide beta and surrogate gradient."""
    return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, **extra)


#  Initialize Network
# Two (conv -> max-pool -> LIF) stages followed by a fully connected LIF readout.
# Spatial size: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, hence 24 * 4 * 4 features.
# The readout layer is built with output=True; forward_pass unpacks its result as (spikes, membrane).
net = nn.Sequential(
    nn.Conv2d(1, 12, 5),
    nn.MaxPool2d(2),
    _lif(),
    nn.Conv2d(12, 24, 5),
    nn.MaxPool2d(2),
    _lif(),
    nn.Flatten(),
    nn.Linear(24 * 4 * 4, 10),
    _lif(output=True),
).to(device)

In [12]:
def forward_pass(net, num_steps, data):
    """Run `net` for `num_steps` time steps and collect the output membrane trace.

    Hidden states of all LIF neurons are reset first, so successive calls are
    independent. Step t feeds `data[t]` (the input is indexed time-first) and keeps
    the second element of the network's (spikes, membrane) output pair.

    Returns:
        The per-step membrane potentials stacked along a new leading time dimension.
    """
    utils.reset(net)  # clear hidden state left over from any previous pass

    membrane_per_step = []
    for t in range(num_steps):
        _spikes, membrane = net(data[t])
        membrane_per_step.append(membrane)
    return torch.stack(membrane_per_step)

In [16]:
# Define functions to print metrics during training loop

def print_batch_accuracy(data, targets, train=False):
    """Print the classification accuracy of the global `net` on one batch.

    The predicted class of each sample is the output neuron whose membrane
    potential, summed over all `num_steps` time steps, is largest.

    Args:
        data: time-first spike input, indexed as data[step] by forward_pass.
        targets: ground-truth labels of the batch.
        train: selects the "Train" / "Test" prefix of the printed line.

    Uses the notebook globals `net`, `num_steps` and `forward_pass`.
    """
    # Monitoring only: without no_grad this extra forward pass would build and keep
    # an autograd graph spanning every time step of the batch.
    with torch.no_grad():
        output = forward_pass(net, num_steps, data)
    idx = output.sum(dim=0).argmax(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train Set Accuracy: {acc}")
    else:
        print(f"Test Set Accuracy: {acc}")


def train_printer():
    """Print a progress report for the current minibatch of the training loop.

    Reads the training loop's notebook globals: `epoch`, `minibatch_counter`,
    `counter`, the loss histories `loss_hist` / `test_loss_hist`, and the current
    train/test spike batches with their targets.
    """
    header = "\n".join((
        f"Epoch {epoch}, Minibatch {minibatch_counter}",
        f"Train Set Loss: {loss_hist[counter]}",
        f"Test Set Loss: {test_loss_hist[counter]}",
    ))
    print(header)

    # Accuracy on the current training batch, then on the most recent test batch.
    print_batch_accuracy(spike_data, targets_it, train=True)
    print_batch_accuracy(test_spike_data, testtargets_it, train=False)
    print("\n")

The state dictionary containing the pre-trained model's learned parameters is saved in **'MNIST_SNN_Weights_Eval'** and can be loaded by running the following cell.

In [21]:
# Load the pre-trained evaluation-classifier weights. The relative path matches the file
# name used by the save cell and by the other artefacts loaded in this notebook (on Colab
# the default working directory is /content, so the file resolves to the same place there).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
net.load_state_dict(torch.load("MNIST_SNN_Weights_Eval", map_location=device))

<All keys matched successfully>

The following cell contains the training loop. If you have already loaded the learned parameters from **'MNIST_SNN_Weights_Eval'**, you don't need to run it.

In [None]:
# Training loop (optional: skip it when the pre-trained weights were loaded above).
# Loss: log-softmax of the output membrane potential at every time step, NLL summed over
# the time steps, then backpropagation through time.

# Create an Adam optimizer for training the neural network with a specified learning rate and betas.
optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))

# Instantiate the log softmax function and the negative log-likelihood loss function.
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss()

# Loss histories, one entry per training minibatch (train_printer indexes them with `counter`).
loss_hist = []
test_loss_hist = []

counter = 0
num_epochs = 2  # These are the last two epochs to show the loss and metrics through the training
# `num_steps` comes from the configuration cell and is intentionally not re-assigned here,
# so training and every later evaluation use the same number of time steps.

net.train()  # the evaluation cell calls net.eval(); switch back in case it ran first

# Outer training loop
for epoch in range(num_epochs):
    minibatch_counter = 0
    data = iter(train_loader)

    # Minibatch training loop
    for data_it, targets_it in data:
        data_it = data_it.to(device)
        targets_it = targets_it.to(device)

        # Spike generator: rate-encode the static images into num_steps frames.
        spike_data = spikegen.rate(data_it, num_steps)

        # Forward pass: output-layer membrane potential at every time step.
        mem_rec = forward_pass(net, num_steps, spike_data)
        log_p_y = log_softmax_fn(mem_rec)

        # Sum loss over time steps to perform BPTT
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss_fn(log_p_y[step], targets_it)

        # Gradient calculation, clipping and weight update
        optimizer.zero_grad()
        loss_val.backward()
        nn.utils.clip_grad_norm_(net.parameters(), 1)
        optimizer.step()

        # Store Loss history
        loss_hist.append(loss_val.item())

        # Monitoring on one test batch: no autograd graph is needed here.
        with torch.no_grad():
            # A fresh iterator over the shuffled test loader yields a random test batch.
            testdata_it, testtargets_it = next(iter(test_loader))
            testdata_it = testdata_it.to(device)
            testtargets_it = testtargets_it.to(device)

            # Test set spike conversion and forward pass
            test_spike_data = spikegen.rate(testdata_it, num_steps)
            test_mem_rec = forward_pass(net, num_steps, test_spike_data)

            # Test set loss: NLL of the time-summed log-probabilities. NLLLoss is linear in
            # its input, so this equals the per-step NLL summed over time (as in training).
            log_p_ytest = log_softmax_fn(test_mem_rec)
            log_p_ytest = log_p_ytest.sum(dim=0)
            loss_val_test = loss_fn(log_p_ytest, testtargets_it)
            test_loss_hist.append(loss_val_test.item())

            # Print test/train loss/accuracy
            if counter % 100 == 0:
                train_printer()
        minibatch_counter += 1
        counter += 1


Epoch 0, Minibatch 0
Train Set Loss: 2.0587158203125
Test Set Loss: 2.363649845123291
Train Set Accuracy: 0.98
Test Set Accuracy: 0.99


Epoch 0, Minibatch 100
Train Set Loss: 1.8622440099716187
Test Set Loss: 2.044363260269165
Train Set Accuracy: 0.97
Test Set Accuracy: 0.99


Epoch 0, Minibatch 200
Train Set Loss: 1.8321194648742676
Test Set Loss: 0.732135534286499
Train Set Accuracy: 0.99
Test Set Accuracy: 1.0


Epoch 0, Minibatch 300
Train Set Loss: 2.3062918186187744
Test Set Loss: 1.1230648756027222
Train Set Accuracy: 0.98
Test Set Accuracy: 1.0


Epoch 0, Minibatch 400
Train Set Loss: 1.0212091207504272
Test Set Loss: 2.224919080734253
Train Set Accuracy: 1.0
Test Set Accuracy: 0.96


Epoch 0, Minibatch 500
Train Set Loss: 1.2539761066436768
Test Set Loss: 1.781040072441101
Train Set Accuracy: 0.99
Test Set Accuracy: 0.98


Epoch 1, Minibatch 0
Train Set Loss: 2.6294121742248535
Test Set Loss: 1.7578938007354736
Train Set Accuracy: 0.98
Test Set Accuracy: 0.99


Epoch 1, Minib

In [None]:
# torch.save(net.state_dict(),'MNIST_SNN_Weights_Eval')

## Model Evaluation

In [22]:
# Evaluation
# Accuracy over the whole test split: each batch is rate-encoded into spike trains, and
# the predicted class is the output neuron whose membrane potential, summed over all
# time steps, is largest.
net.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        spike_data = spikegen.rate(data, num_steps=num_steps)

        output = forward_pass(net, num_steps, spike_data)
        predicted = output.sum(dim=0).argmax(1)

        correct += predicted.eq(targets).sum().item()
        total += targets.size(0)

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy}%")

Test Accuracy: 98.43%


# Attack Evaluation Phase

In [24]:
# Function for calculating the attack accuracy and top-k (default top-3) attack accuracy
def count_correct(lst, k=3):
    """Count how many ground-truth ranks are first and how many fall within the top k.

    Args:
        lst: iterable of 1-based ranks, one per inverted sample (the position of the
            true class in that sample's confidence ranking).
        k: size of the top-k window; the default of 3 gives the original top-3 metric.

    Returns:
        (first, top_k): the number of ranks equal to 1, and the number of ranks that
        are 1 or lie in 2..k (so top_k always includes `first`).
    """
    ranks = list(lst)  # allow one-shot iterators: the ranks are scanned twice
    first = sum(1 for rank in ranks if rank == 1)
    # Ranks 2..k; rank 1 is added back through `first`.
    runner_up = sum(1 for rank in ranks if 2 <= rank <= k)
    return first, first + runner_up


## BrainLeak V1

In [25]:
# Import Inverted Samples produced by the BrainLeak V1 attack, mapped onto the current device.
# NOTE(review): torch.load unpickles the file -- only load artefacts from a trusted source.
Inv_BL1 = torch.load(f="MNIST_SNN_Inverted_BL1", map_location=device)

In [28]:
# Attack Evaluation
# Classify every inverted sample with the evaluation network. There is one sample per class
# (num_outputs), and sample i is treated as the inversion of class i; the ranks and the
# diagonal below rely on that ordering.

with torch.no_grad():  # pure evaluation -- no autograd graph needed
    # NOTE(review): Inv_BL1 is assumed to be (sample, time, 28, 28); it is rearranged into
    # the time-first (time, sample, 1, 28, 28) layout that forward_pass indexes per step.
    inv_input = Inv_BL1.permute(1, 0, 2, 3).view([num_steps, num_outputs, 1, 28, 28])
    # Softmax over the classes at every time step, then averaged over time.
    conf_mat = F.softmax(forward_pass(net, num_steps, inv_input), dim=2).mean(0)  # Confidence Matrix

rank_mat = torch.argsort(conf_mat, descending=True)  # Ranking Matrix: classes sorted by confidence, per sample
# 1-based position of the true class i in sample i's ranking.
conf_ranks = [torch.where(rank_mat[i] == i)[0].item() + 1 for i in range(num_outputs)]

DAA_ranks = conf_mat.argmax(0)  # Out of all inverted samples, which one is the most similar to digit X


print("\n ** Predicted Labels  :\n\n",conf_mat.max(1)[1])
print("\n\n ** Predicted Labels Confidence  :\n\n",conf_mat.max(1)[0])

print("\n\n\n ** Ground Truth Ranks in Prediction  :\n\n",conf_ranks)
print("\n\n\n ** Number of corrects (First,Top-3)  :\n\n",count_correct(conf_ranks))

print("\n\n ** Confidences of the correct outputs  :\n\n",torch.diag(conf_mat))

print("\n\n ** DAA Ranks  :\n\n",DAA_ranks) # Out of all inverted samples, which one is the most similar to digit X


 ** Predicted Labels  :

 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')


 ** Predicted Labels Confidence  :

 tensor([0.9835, 0.7286, 0.9973, 0.9925, 0.9799, 0.9915, 0.9702, 0.9913, 0.9135,
        0.7624], device='cuda:0', grad_fn=<MaxBackward0>)



 ** Ground Truth Ranks in Prediction  :

 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]



 ** Number of corrects (First,Top-3)  :

 (10, 10)


 ** Confidences of the correct outputs  :

 tensor([0.9835, 0.7286, 0.9973, 0.9925, 0.9799, 0.9915, 0.9702, 0.9913, 0.9135,
        0.7624], device='cuda:0', grad_fn=<DiagonalBackward0_copy>)


 ** DAA Ranks  :

 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')


In [30]:
# Summary metrics for the ranks, confidences and DAA computed in the evaluation cell above.
first_hits, top3_hits = count_correct(conf_ranks)
num_samples = len(conf_ranks)
print("Attack Accuracy:", first_hits / num_samples)
print("\n\nTop-3 Attack Accuracy:", top3_hits / num_samples)
print("\n\n ** Average Confidence of correct outputs", torch.diag(conf_mat).mean().item())
# DAA: fraction of classes whose most-confident inverted sample is that class's own sample.
daa_hits = (DAA_ranks == torch.arange(len(DAA_ranks), device=DAA_ranks.device)).sum().item()
print("\n\nDAA:", daa_hits / len(DAA_ranks))

Attack Accracy: 1.0


Top-3 Attack Accuracy: 1.0


 ** Average Confidence of correct outputs 0.9310759902000427


DAA: 1.0


## BrainLeak V2

In [31]:
# Import Inverted Samples produced by the BrainLeak V2 attack, mapped onto the current device.
# NOTE(review): torch.load unpickles the file -- only load artefacts from a trusted source.
Inv_BL2 = torch.load(f="MNIST_SNN_Inverted_BL2", map_location=device)

In [38]:
# Attack Evaluation
# Classify every inverted sample with the evaluation network. There is one sample per class
# (num_outputs), and sample i is treated as the inversion of class i; the ranks and the
# diagonal below rely on that ordering.

with torch.no_grad():  # pure evaluation -- no autograd graph needed
    # NOTE(review): Inv_BL2 is assumed to hold per-pixel firing probabilities in [0, 1],
    # laid out as (sample, time, 28, 28). torch.bernoulli samples binary spikes from them
    # (stochastic), and the result is rearranged into the time-first (time, sample, 1, 28, 28)
    # layout that forward_pass indexes per step.
    inv_input = torch.bernoulli(Inv_BL2).permute(1, 0, 2, 3).view([num_steps, num_outputs, 1, 28, 28])
    # Softmax over the classes at every time step, then averaged over time.
    conf_mat = F.softmax(forward_pass(net, num_steps, inv_input), dim=2).mean(0)  # Confidence Matrix

rank_mat = torch.argsort(conf_mat, descending=True)  # Ranking Matrix: classes sorted by confidence, per sample
# 1-based position of the true class i in sample i's ranking.
conf_ranks = [torch.where(rank_mat[i] == i)[0].item() + 1 for i in range(num_outputs)]

DAA_ranks = conf_mat.argmax(0)  # Out of all inverted samples, which one is the most similar to digit X

print("\n ** Predicted Labels  :\n\n",conf_mat.max(1)[1])
print("\n\n ** Predicted Labels Confidence  :\n\n",conf_mat.max(1)[0])

print("\n\n\n ** Ground Truth Ranks in Prediction  :\n\n",conf_ranks)
print("\n\n\n ** Number of corrects (First,Top-3)  :\n\n",count_correct(conf_ranks))

print("\n\n ** Confidences of the correct outputs  :\n\n",torch.diag(conf_mat))

print("\n\n ** DAA Ranks  :\n\n",DAA_ranks) # Out of all inverted samples, which one is the most similar to digit X


 ** Predicted Labels  :

 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')


 ** Predicted Labels Confidence  :

 tensor([0.9247, 0.9034, 0.9920, 0.9656, 0.9417, 0.9969, 0.9898, 0.9951, 0.9545,
        0.7955], device='cuda:0', grad_fn=<MaxBackward0>)



 ** Ground Truth Ranks in Prediction  :

 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]



 ** Number of corrects (First,Top-3)  :

 (10, 10)


 ** Confidences of the correct outputs  :

 tensor([0.9247, 0.9034, 0.9920, 0.9656, 0.9417, 0.9969, 0.9898, 0.9951, 0.9545,
        0.7955], device='cuda:0', grad_fn=<DiagonalBackward0_copy>)


 ** DAA Ranks  :

 tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], device='cuda:0')


In [39]:
# Summary metrics for the ranks, confidences and DAA computed in the evaluation cell above.
first_hits, top3_hits = count_correct(conf_ranks)
num_samples = len(conf_ranks)
print("Attack Accuracy:", first_hits / num_samples)
print("\n\nTop-3 Attack Accuracy:", top3_hits / num_samples)
print("\n\n ** Average Confidence of correct outputs", torch.diag(conf_mat).mean().item())
# DAA: fraction of classes whose most-confident inverted sample is that class's own sample.
daa_hits = (DAA_ranks == torch.arange(len(DAA_ranks), device=DAA_ranks.device)).sum().item()
print("\n\nDAA:", daa_hits / len(DAA_ranks))

Attack Accracy: 1.0


Top-3 Attack Accuracy: 1.0


 ** Average Confidence of correct outputs 0.945917546749115


DAA: 1.0
