In [1]:
import os
import sys
import time
import numpy as np
import h5py
import matplotlib.pyplot as plt
import math
import random
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch.utils.data import TensorDataset, DataLoader
import torchvision
import torchvision.transforms as transforms

from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent


use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
seed = 42

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [20]:
class Network_ANN(nn.Module):
    def __init__(self):
        super(Network_ANN, self).__init__()
        self.linear_dim = 8192
        
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False)
        self.HalfRect1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.subsample1 = nn.MaxPool2d(2, 2, 0)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False)
        self.HalfRect2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.subsample2 = nn.MaxPool2d(2, 2, 0)
        self.fc1 = nn.Linear(self.linear_dim, 10, bias=False)
        self.HalfRect3 = nn.ReLU()

    def to(self, device):
        self.device = device
        super().to(device)
        return self

    def forward(self, input):
        x = self.conv1(input)
        x = self.HalfRect1(x)
        x = self.dropout1(x)
        x = self.subsample1(x)
        x = self.conv2(x)
        x = self.HalfRect2(x)
        x = self.dropout2(x)
        x = self.subsample2(x)
        x = x.view(-1, self.linear_dim)
        x = self.fc1(x)
        return x


In [21]:
class Surrogate_BP_Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        out = torch.zeros_like(input).cuda()
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input * 0.3 * F.threshold(1.0 - torch.abs(input), 0, 0)
        return grad


def PoissonGen(inp, rescale_fac=2.0):
    rand_inp = torch.rand_like(inp).cuda()
    return torch.mul(torch.le(rand_inp * rescale_fac, torch.abs(inp)).float(), torch.sign(inp))


In [22]:
class Network_SNN(nn.Module):
    def __init__(self, time_window=30, max_rate=200, threshold=1.0, leak_factor=1.0) -> None:
        super(Network_SNN, self).__init__()
        
        self.image_size = 64
        self.linear_dim = 8192
        
        self.leak_factor = leak_factor
        self.threshold = threshold        
        self.spike_fn = Surrogate_BP_Function.apply
        
        self.time_window = time_window
        self.dt = 0.001
        self.max_rate = max_rate
        self.rescale_factor = 1.0/(self.dt*self.max_rate)
        
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1,bias=False)
        self.HalfRect1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.subsample1 = nn.MaxPool2d(2, 2, 0)
        self.conv2 = nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=1,padding=1,bias=False)
        self.HalfRect2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.subsample2 = nn.MaxPool2d(2, 2, 0)
        self.fc1 = nn.Linear(self.linear_dim, 10, bias=False)
        self.HalfRect3 = nn.ReLU()
        
        self.conv_list = [self.conv1, self.conv2]
        self.pool_list = [self.subsample1, self.subsample2]
        
        for m in self.modules():
            if (isinstance(m, nn.Conv2d)):
                m.threshold = self.threshold
            elif (isinstance(m, nn.Linear)):
                m.threshold = self.threshold

    def to(self, device):
        self.device = device
        super().to(device)
        return self
    
    def forward(self, input):
        infer, _input = input
        batch_size = _input.size(0)
            
        input_spksum = torch.zeros(batch_size, self.conv1.in_channels, self.image_size, self.image_size).to(device)
        mem_conv1 = torch.zeros(batch_size, self.conv1.out_channels, 64, 64).to(device)
        mem_conv2 = torch.zeros(batch_size, self.conv2.out_channels, 32, 32).to(device)
        mem_fc1 = torch.zeros(batch_size, self.fc1.out_features).to(device)
        mem_conv_list = [mem_conv1, mem_conv2]
        
        for t in range(self.time_window):
            
            if infer == 'Synth':
                _t = _input[:,t].view(batch_size, 1, self.image_size, self.image_size)
            else:
                _t = _input.view(batch_size, 1, self.image_size, self.image_size)
                
            spike_input = PoissonGen(_t.to(device), self.rescale_factor)
            input_spksum += spike_input
            out_prev = spike_input
            
            for idx in range(len(self.conv_list)):                
                mem_conv_list[idx] = self.leak_factor * mem_conv_list[idx] + self.conv_list[idx](out_prev)
                mem_thr = (mem_conv_list[idx] / self.conv_list[idx].threshold) - 1.0
                out = self.spike_fn(mem_thr)
                rst = torch.zeros_like(mem_conv_list[idx]).to(device)
                rst[mem_thr > 0] = self.conv_list[idx].threshold
                mem_conv_list[idx] = mem_conv_list[idx] - rst
                out_prev = out.clone()
                
                out_prev = self.pool_list[idx](out_prev)
            
            out_prev = out_prev.reshape(batch_size, -1)
            
            mem_fc1 = mem_fc1 + self.fc1(out_prev)

        out_voltage = mem_fc1 / self.time_window

        return input_spksum, out_voltage


In [9]:
def OrdinEval(model, params, train_loader=None, test_loader=None, arc='ANN'):
    model.eval()
    acc_record = list([])
    correct = 0
    total = 0

    with torch.no_grad():
        for split, loader in [('Train', train_loader), ('Test', test_loader)]:
            if loader is not None:
                for batch_idx, (inputs, targets) in enumerate(loader):
                    batch_sz = inputs.size(0)            
                    inputs = inputs.float().to(device)
                    labels_ = torch.zeros(batch_sz, 10).scatter_(1, targets.view(-1, 1), 1).to(device)

                    if arc == 'SNN':
                        _, outputs = model((params['infer_type'], inputs))
                    else:
                        outputs = model(inputs)
                    targets = targets.to(device)

                    loss = params['criterion'](outputs, labels_)
                    _, predicted = outputs.max(1)
                    total += float(targets.size(0))
                    correct += float(predicted.eq(targets).sum().cpu().item())
                print(arc, split, "Acc: %.3f" % (100 * correct / total), end=' | ')
                acc = 100. * float(correct) / float(total)
                acc_record.append(acc)


def OrdinTrainNEval(model, params, train_loader, test_loader, arc='ANN', verbose_interval=4):
    print(model, '\nParameter Count:', sum(p.numel() for p in model.parameters() if p.requires_grad),
          '\n\n','**********',arc,'Training **********')
    verbose_at = [int(len(train_loader)*i/verbose_interval) for i in range(1, verbose_interval+1)]

    for epoch in range(params['num_epochs']):
        model.train()
        running_loss = 0
        start_time = time.time()

        for i, (inputs, targets) in enumerate(train_loader):
            batch_sz = inputs.size(0)
            inputs = inputs.float().to(device)
            labels_ = torch.zeros(batch_sz, 10).scatter_(1, targets.view(-1, 1), 1).to(device)

            params['optimizer'].zero_grad()
            if arc == 'SNN':
                _, outputs = model((params['infer_type'], inputs))
            else:
                outputs = model(inputs)
                
            loss = params['criterion'](outputs, labels_)
            running_loss += loss.cpu().item()
            loss.backward()
            params['optimizer'].step()
            if i+1 in verbose_at:
                print ('\nEpoch [%d/%d], Step [%d/%d], Training Loss: %.5f Time elasped:%.2f s'
                      %(epoch+1, params['num_epochs'], i+1, len(train_loader),running_loss,time.time()-start_time), end=' | ')
                running_loss = 0
        OrdinEval(model, params, train_loader, test_loader, arc)


In [7]:
def PerturbedOrdinEval(substitute, target, target_params, dataloader, arc='ANN',
    attack_info={'epsilon' : 0.3, 'alpha' : 0.01, 'iterations' : 40}) -> None:
    acc_hist = {}
    
    eps = attack_info['epsilon']
    alpha = attack_info['alpha']
    iters = attack_info['iterations']

    target_params['infer_type'] = 'Ordin'
    for adv_func in ['fast_gradient_method', 'projected_gradient_descent']:
        target.eval()
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            batch_sz = inputs.size(0)            
            inputs = inputs.float().to(device)
            if adv_func == 'fast_gradient_method':
                perturbed_inputs = fast_gradient_method(substitute, inputs, eps, np.inf)
            else:
                perturbed_inputs = projected_gradient_descent(substitute, inputs, eps, alpha, iters, np.inf)
            labels_ = torch.zeros(batch_sz, 10).scatter_(1, targets.view(-1, 1), 1).to(device)

            if arc == 'SNN':
                _, outputs = target((target_params['infer_type'], perturbed_inputs))
            else:
                outputs = target(perturbed_inputs)
            targets = targets.to(device)

            loss = target_params['criterion'](outputs, labels_)
            _, predicted = outputs.max(1)
            total += float(targets.size(0))
            correct += float(predicted.eq(targets).sum().cpu().item())

        print(arc, '('+adv_func+')','Perturbed Data', "Accuracy: %.3f" % (100 * correct / total))
        acc = 100. * float(correct) / float(total)
        acc_hist[adv_func] = acc
    
    return acc_hist


def AttackNetworks(substitute, target, target_params, epsilon_range=np.arange(0, 0.35, 0.05), step_size=0.01, iters=40, arc='ANN'):
    acc_hist = {'fast_gradient_method' : [], 'projected_gradient_descent' : []}

    for eps in epsilon_range:
        print('Perturbation Strength:', eps)
        inter_hist = PerturbedOrdinEval(substitute, target, target_params, mnist_test_loader,
                           attack_info={'epsilon' : eps, 'alpha' : step_size, 'iterations' : iters}, arc=arc); print()
        for key in inter_hist:
            acc_hist[key].append(inter_hist[key])    
    print(acc_hist)


In [5]:
def mnist_transform(img_size, default=28):
    pad_xy = (img_size - default) // 2
    return transforms.Compose([
        transforms.Pad((pad_xy, pad_xy)),
        transforms.ToTensor()
    ])


img_size = 64
train_batch_size = 100
test_batch_size = train_batch_size * 2

mnist_train_dataset = torchvision.datasets.MNIST(root="D:\Dataset\mnist", train=True, download=True, transform=mnist_transform(img_size))
mnist_train_loader = torch.utils.data.DataLoader(mnist_train_dataset, batch_size=train_batch_size, shuffle=True)

mnist_test_dataset = torchvision.datasets.MNIST(root="D:\Dataset\mnist", train=False, download=True, transform=mnist_transform(img_size))
mnist_test_loader = torch.utils.data.DataLoader(mnist_test_dataset, batch_size=test_batch_size, shuffle=False)

# import shutil
# shutil.rmtree('/kaggle/working/D:\Dataset\mnist')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to D:\Dataset\mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 376516556.88it/s]


Extracting D:\Dataset\mnist/MNIST/raw/train-images-idx3-ubyte.gz to D:\Dataset\mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to D:\Dataset\mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 38455775.82it/s]


Extracting D:\Dataset\mnist/MNIST/raw/train-labels-idx1-ubyte.gz to D:\Dataset\mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to D:\Dataset\mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 226564828.72it/s]


Extracting D:\Dataset\mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to D:\Dataset\mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to D:\Dataset\mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 19479068.27it/s]

Extracting D:\Dataset\mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to D:\Dataset\mnist/MNIST/raw






In [23]:
Substitute_ANN = Network_ANN()
Substitute_ANN.to(device)

Substitute_ANN_Params = {'num_epochs' : 8, 'optimizer' : optim.Adam(Substitute_ANN.parameters(), lr=1e-3),
    'criterion' : nn.CrossEntropyLoss(reduction='mean').to(device), 'best_acc' : 0}


OrdinTrainNEval(Substitute_ANN, Substitute_ANN_Params, mnist_train_loader, mnist_test_loader, verbose_interval=1)

Network_ANN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (HalfRect1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (subsample1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (HalfRect2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (subsample2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=8192, out_features=10, bias=False)
  (HalfRect3): ReLU()
) 
Parameter Count: 86672 

 ********** ANN Training **********

Epoch [1/8], Step [600/600], Training Loss: 161.12518 Time elasped:11.44 s | ANN Train Acc: 97.003 | ANN Test Acc: 97.029 | 
Epoch [2/8], Step [600/600], Training Loss: 61.36725 Time elasped:11.95 s | ANN Train Acc: 98.102 | ANN Test Acc: 98.104 | 
Epoch [3/8], Step [600/600], Training Loss: 48.44498 Time elasped:10.95 s | ANN Train Acc: 98.348

In [24]:
time_window = 30
max_rate = 800
num_epochs = 2


SNN = Network_SNN(time_window=time_window, max_rate=max_rate)
SNN.to(device)

SNN_Params = {
    'num_epochs' : num_epochs, 'optimizer' : optim.Adam(SNN.parameters()), 'criterion' : nn.MSELoss().to(device),
    'infer_type' : 'Ordin', 'best_acc' : 0
}


OrdinTrainNEval(SNN, SNN_Params, mnist_train_loader, mnist_test_loader, arc='SNN', verbose_interval=1)

Network_SNN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (HalfRect1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (subsample1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (HalfRect2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (subsample2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=8192, out_features=10, bias=False)
  (HalfRect3): ReLU()
) 
Parameter Count: 86672 

 ********** SNN Training **********

Epoch [1/2], Step [600/600], Training Loss: 12.35645 Time elasped:131.16 s | SNN Train Acc: 97.548 | SNN Test Acc: 97.549 | 
Epoch [2/2], Step [600/600], Training Loss: 7.83382 Time elasped:131.17 s | SNN Train Acc: 98.048 | SNN Test Acc: 98.056 | 

In [25]:
epsilon_range = np.arange(0.15, 0.2, 0.1)

AttackNetworks(Substitute_ANN, SNN, SNN_Params, epsilon_range, arc='SNN')

Perturbation Strength: 0.15
SNN (fast_gradient_method) Perturbed Data Accuracy: 86.740
SNN (projected_gradient_descent) Perturbed Data Accuracy: 45.600

{'fast_gradient_method': [86.74], 'projected_gradient_descent': [45.6]}
