In [None]:
from __future__ import print_function
import math
import torchvision.transforms as transforms
import torch
import torch.utils.data as data
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from tqdm import tqdm
from utils_BNN_resnet import neg_ELBO, Logger
from BayesianResnet import resnet18
import torchvision.models as models
from model_BNN_test import CNN_lrt
import torch.nn.functional as F
from model_SNN import Net
import matplotlib.pyplot as plt

# Helper Functions

In [None]:
def get_entropy(probabilities):
    """Shannon entropy (in nats) of one or more categorical distributions.

    Args:
        probabilities: tensor of probabilities. With two or more dimensions
            the entropy is reduced along dim 1 (one value per row); a 0-D or
            1-D tensor is treated as a single distribution and reduced fully.

    Returns:
        A CPU tensor holding ``-sum(p * log(p))``: reduced over dim 1 for
        inputs with at least two dimensions, a 0-D tensor otherwise.
    """
    probabilities = probabilities.cpu()
    # torch.log keeps the computation inside PyTorch (np.log on a tensor goes
    # through NumPy conversion and breaks for tensors that require grad).
    # The 1e-16 offset keeps log() finite when a probability is exactly zero.
    weighted_log = probabilities * torch.log(1e-16 + probabilities)

    # Explicit dimensionality check instead of a bare `except:` fallback, so
    # unrelated errors are no longer silently swallowed.
    if probabilities.dim() >= 2:
        return -torch.sum(weighted_log, 1)
    return -torch.sum(weighted_log)

def get_max_entropy(probabilities):
    """Entropy (in nats) of the uniform distribution over the class axis.

    This is the largest entropy a categorical distribution with that many
    classes can reach, i.e. approximately ``log(num_classes)``.

    Args:
        probabilities: tensor whose dim 1 indexes the classes; only its shape
            is used. A 1-D tensor is treated as a single distribution whose
            only axis indexes the classes.

    Returns:
        A 0-D tensor with the entropy of the uniform distribution.
    """
    class_dim = 1 if probabilities.dim() >= 2 else 0
    num_classes = probabilities.size(class_dim)

    p_uniform_dist = torch.full((num_classes,), 1.0 / num_classes)
    # Same 1e-16 offset as get_entropy so both values are directly comparable;
    # torch.log avoids routing a tensor through a NumPy ufunc.
    max_ent = -torch.sum(p_uniform_dist * torch.log(1e-16 + p_uniform_dist))

    return max_ent

In [None]:
def entropy_vs_eps(model, device, test_loader, epsilon, num_samples):
    """Average predictive entropy of `model` on FGSM-perturbed test inputs.

    For every batch the input gradient of the NLL loss is computed, the batch
    is perturbed with `fgsm_attack` (defined outside this cell), and the
    perturbed batch is passed through the model `num_samples` times. The
    entropies of the softmax outputs are averaged over those passes and then
    over all examples in `test_loader`.

    Args:
        model: network to attack; switched to eval mode by this function.
        device: torch device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches that supports
            being wrapped by tqdm.
        epsilon: attack strength forwarded to `fgsm_attack`.
        num_samples: number of forward passes per perturbed batch (>= 1).

    Returns:
        ``(entropy_avg, max_entropy)``: the mean entropy per example as a
        Python float, and the value returned by `get_max_entropy` for the
        last batch of probabilities.

    Raises:
        ValueError: if `num_samples` is smaller than 1 or `test_loader`
            yields no examples.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    softmax = nn.Softmax(dim=1)
    model.eval()

    entropy_total = 0.0   # sum of per-example entropies
    example_count = 0
    probs = None

    for images, target in tqdm(test_loader):
        # Send the data and label to the device
        images, target = images.to(device), target.to(device)

        # The attack needs the gradient of the loss w.r.t. the input
        images.requires_grad = True

        # NOTE(review): F.nll_loss expects log-probabilities; confirm that the
        # models passed in end with log_softmax.
        output = model(images)
        loss = F.nll_loss(output, target)

        # Clear stale parameter gradients, then back-propagate to the input
        model.zero_grad()
        loss.backward()
        data_grad = images.grad.detach()

        # Call FGSM Attack
        perturbed_data = fgsm_attack(images, epsilon, data_grad)

        # Re-classify the perturbed batch. No gradients are needed here, so
        # skip building an autograd graph for every forward pass.
        batch_entropy = 0.0
        with torch.no_grad():
            for _ in range(num_samples):
                probs = softmax(model(perturbed_data))
                batch_entropy = batch_entropy + get_entropy(probs)

        # Accumulate per example so the average is also correct for batch
        # sizes other than 1.
        entropy_total += (batch_entropy / num_samples).sum().item()
        example_count += images.size(0)

    if example_count == 0:
        raise ValueError("test_loader yielded no examples")

    entropy_avg = entropy_total / example_count
    max_entropy = get_max_entropy(probs)
    print("Epsilon: {}\tEntropy = {}".format(epsilon, entropy_avg))

    # Return the mean entropy and the maximum attainable entropy
    return entropy_avg, max_entropy

# Setup

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Training hyperparameters.
learning_rate = 1e-3
batch_size = 16
num_epochs = 80

In [None]:
# Preprocessing pipeline: turn each image into a tensor, then standardize
# every channel with the given mean / standard deviation
# (presumably the CIFAR-10 per-channel statistics; verify if the data changes).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.247, 0.243, 0.261),
    ),
])

In [None]:
# CIFAR-10 training and test splits, downloaded into ./data when missing and
# preprocessed with `transform`.
train_data = datasets.CIFAR10('data', train=True,
                              download=True, transform=transform)
test_data = datasets.CIFAR10('data', train=False,
                             download=True, transform=transform)

# Training batches follow the `batch_size` hyperparameter from the setup cell
# instead of a second hard-coded 16 that could silently drift from it.
# NOTE(review): the training loader does not shuffle; confirm this is intended.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

# One image per batch, so entropy_vs_eps perturbs and scores each test image
# individually.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1)

# Frequentist NN

In [None]:
reg = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
for weight_decay in reg:
    # Initialize the network
    model = Net().to(device)

    criterion=nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)