In [2]:
# clear memory
from IPython import get_ipython
get_ipython().magic('reset -sf') 

import numpy as np
import torch
import time
timer = 0

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F

import advertorch.attacks as attacks
from attacks.deepfool import DeepfoolLinfAttack
import torch.nn as nn
from autoattack import AutoAttack

from advertorch.context import ctx_noparamgrad_and_eval
from torch.utils.tensorboard import SummaryWriter

import foolbox as fb

import os, random


# import argparse

# argument_parser = argparse.ArgumentParser()

# argument_parser.add_argument("--lr_init", type=float, help="Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.")

# parsed_args = argument_parser.parse_args()


# Make sure validation splits are the same at all time (e.g. even after loading)
seed = 0

def seed_init_fn(seed=seed):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    The default is the module-level ``seed`` as it was when this function
    was defined. Returns None.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): num_workers is defined here but is not passed to any DataLoader below.
num_workers = 0
# Make sure test_data is a multiple of batch_size_test
batch_size_train_and_valid = 128
batch_size_test = 200

# proportion of full training set used for validation
valid_size = 0.2




# MNIST train/test sets, converted to tensors by ToTensor (downloaded into ./data if missing).
transform = transforms.ToTensor()
train_and_valid_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transform)
test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transform)

# Split the training set into train/validation; the generator is seeded with `seed`
# so the split is the same on every run.
num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))
num_train_samples = len(train_and_valid_data) - num_valid_samples
train_data, valid_data = torch.utils.data.random_split(train_and_valid_data, [num_train_samples, num_valid_samples], generator=torch.Generator().manual_seed(seed))

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid)
# NOTE(review): PyTorch documents worker_init_fn as being called with the worker id in each
# worker subprocess, so seed_init_fn's `seed` parameter would receive that id rather than the
# module-level seed; with the default number of workers it is presumably never called - confirm.
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, worker_init_fn=seed_init_fn)


class Net(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 512 -> 10 with ReLU activations."""

    def __init__(self):
        super().__init__()
        hidden_width = 512
        self.fc1 = nn.Linear(28 * 28, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, 10)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the 10 raw output scores per row."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        # No activation on the output layer: raw scores are returned.
        return self.fc3(hidden)


model = Net()
# model.to(device)


# map_location makes the checkpoint loadable on a CPU-only machine even if it was
# saved from a CUDA device; the weights are then moved together with the model.
model.load_state_dict(torch.load('model_no_dropout.pt', map_location=device))
model.to(device)


# if str(device) == "cuda" and torch.cuda.device_count() > 1:
#     print("Using DataParallel")
#     model = torch.nn.DataParallel(model)
# model.to(device)








# divided by 10 eps, eps_iter and CW's lr, added as input binary_search_steps to CW attacks


# Attack objects used by generate_domains(). All of them wrap the module-level `model`
# and clip adversarial inputs to [0, 1].

# 'PGD_Linf_std' domain.
adversary_PGD_Linf_std = attacks.LinfPGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.3,
    nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0.0,
    clip_max=1.0, targeted=False)

# 'CW_base' domain.
adversary_CW = attacks.CarliniWagnerL2Attack(
    model, num_classes=10, max_iterations=20, learning_rate=0.1,
    binary_search_steps=5, clip_min=0.0, clip_max=1.0)

# 'Deepfool_base' domain.
adversary_deepfool = DeepfoolLinfAttack(
        model, num_classes=10, nb_iter=30, eps=0.11, clip_min=0.0, clip_max=1.0)

# Unseen attacks used for validation, has bigger learning rate and number of iterations
# 'CW_mod' domain.
adversary_CW_unseen = attacks.CarliniWagnerL2Attack(
    model, num_classes=10, max_iterations=30, learning_rate=0.12,
    binary_search_steps=7, clip_min=0.0, clip_max=1.0)

# 'PGD_Linf_mod' domain: larger eps and eps_iter than adversary_PGD_Linf_std, same nb_iter.
adversary_PGD_Linf_unseen = attacks.LinfPGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=0.4,
    nb_iter=40, eps_iter=0.033, rand_init=True, clip_min=0.0,
    clip_max=1.0, targeted=False)

# 'Deepfool_mod' domain.
adversary_deepfool_unseen = DeepfoolLinfAttack(
        model, num_classes=10, nb_iter=50, eps=0.4, clip_min=0.0, clip_max=1.0)

# 'Autoattack' domain.
adversary_autoattack_unseen = AutoAttack(model, norm='Linf', eps=.3, 
        version='standard', seed=None, verbose=False)

# 'PGD_L2_std' domain.
adversary_PGD_L2_std = attacks.L2PGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=2.,
    nb_iter=40, eps_iter=0.1, rand_init=True, clip_min=0.0,
    clip_max=1.0, targeted=False)

# 'PGD_L1_std' domain.
adversary_PGD_L1_std = attacks.L1PGDAttack(
    model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=10.,
    nb_iter=40, eps_iter=0.5, rand_init=True, clip_min=0.0,
    clip_max=1.0, targeted=False)

def get_fb_attack(attack_name):
    """Build the Foolbox attack registered under ``attack_name``.

    Returns:
        ``(fb_attack, metric)`` where ``metric`` is the string ``'L1'`` or ``'L2'``.

    Raises:
        ValueError: if ``attack_name`` is not one of the known names.
    """
    # The two pointwise variants share a constructor; only the distance
    # assigned afterwards (and the reported metric) differs.
    if attack_name == 'PA_L1' or attack_name == 'PA_L2':
        wants_l1 = attack_name == 'PA_L1'
        pointwise = fb.attacks.PointwiseAttack()
        pointwise._distance = fb.distances.l1 if wants_l1 else fb.distances.l2
        return pointwise, ('L1' if wants_l1 else 'L2')

    # Every remaining attack reports the L2 metric.
    if attack_name == 'BA_L2':
        return fb.attacks.BoundaryAttack(steps=5000), 'L2'
    if attack_name == "VAT":
        return fb.attacks.VirtualAdversarialAttack(steps=1000), 'L2'
    if attack_name == 'InvL2':
        return fb.attacks.InversionAttack(distance=fb.distances.l2), 'L2'
    if attack_name == 'LinContL2':
        return fb.attacks.LinearSearchContrastReductionAttack(distance=fb.distances.l2), 'L2'

    raise ValueError("Invalid fb attack:", attack_name)

def generate_domains(domain_name, data, label, batch_size=batch_size_test, bool_correct_preds_per_domain=None):
    """Return the inputs of one domain (clean or adversarial) for the still-unbroken samples.

    Args:
        domain_name: 'clean' or the name of one of the module-level attacks.
        data: input batch, indexed as ``data[mask, :, :, :]``.
        label: labels matching ``data``.
        batch_size: accepted for compatibility; not used in this function.
        bool_correct_preds_per_domain: optional mapping from domain name to a
            boolean mask of samples to keep. When missing or empty, every
            sample is kept.

    Returns:
        The (possibly perturbed) masked inputs, or None when the mask removes
        every sample. An unrecognised ``domain_name`` also yields None.
    """
    if not bool_correct_preds_per_domain:
        # The mask must be boolean: an integer tensor of ones would be read as
        # indices and select sample 1 over and over instead of every sample.
        mask = torch.ones_like(label, dtype=torch.bool)
    else:
        mask = bool_correct_preds_per_domain[domain_name]
    masked_data = data[mask, :, :, :]
    masked_label = label[mask]

    # All the data might have been masked. In that case return None.
    if len(masked_data) == 0:
        return None

    if domain_name == 'clean':
        return masked_data
    if domain_name == 'PGD_L1_std':
        return adversary_PGD_L1_std.perturb(masked_data, masked_label)
    if domain_name == 'PGD_L2_std':
        return adversary_PGD_L2_std.perturb(masked_data, masked_label)
    if domain_name == 'PGD_Linf_std':
        return adversary_PGD_Linf_std.perturb(masked_data, masked_label)
    if domain_name == 'Deepfool_base':
        return adversary_deepfool.perturb(masked_data, masked_label)
    if domain_name == "CW_base":
        return adversary_CW.perturb(masked_data, masked_label)
    if domain_name == 'PGD_Linf_mod':
        return adversary_PGD_Linf_unseen.perturb(masked_data, masked_label)
    if domain_name == 'Deepfool_mod':
        return adversary_deepfool_unseen.perturb(masked_data, masked_label)
    if domain_name == 'CW_mod':
        return adversary_CW_unseen.perturb(masked_data, masked_label)
    if domain_name == "Autoattack":
        # The whole masked batch is evaluated in a single AutoAttack batch.
        return adversary_autoattack_unseen.run_standard_evaluation(masked_data, masked_label, bs=len(masked_label))
    # Unknown names fall through, as before.
    return None











def loss_helper(model, data_all_domains, label_all_domains, num_domains, num_correct_per_domain, tensor_list_losses_epoch):
    """Compute the mean and the variance of the per-domain cross-entropy losses.

    Side effects: ``num_correct_per_domain[d]`` is increased by the number of
    correct argmax predictions on domain ``d``, and the stacked per-domain
    losses are added in place to ``tensor_list_losses_epoch``.

    Returns:
        ``(ERM_term, REx_variance_term)``: the sum of the per-domain losses
        divided by ``num_domains``, and ``torch.var`` of those losses.
    """
    per_domain_losses = []
    for idx in range(num_domains):
        logits = model(data_all_domains[idx])
        targets = label_all_domains[idx]
        per_domain_losses.append(F.cross_entropy(logits, targets))
        hits = torch.argmax(logits, dim=1) == targets
        num_correct_per_domain[idx] += hits.sum().item()

    stacked = torch.stack(per_domain_losses)
    mean_risk = stacked.sum() / num_domains
    risk_variance = stacked.var()

    # Running per-domain totals for the epoch, accumulated in the caller's tensor.
    tensor_list_losses_epoch += stacked

    return mean_risk, risk_variance

def REx_loss(ERM_term, REx_variance_term, beta):
    """Return the ERM term plus the variance term weighted by ``beta``."""
    variance_penalty = beta * REx_variance_term
    return variance_penalty + ERM_term

 
def compute_loss(is_REx, beta, loss_terms, model, list_data_all_domains, list_label_all_domains, num_domains, 
                 num_train_correct_preds_per_domain, tensor_list_losses_epoch_train):
    """Return the training loss and accumulate its scalar components into ``loss_terms``.

    With ``is_REx`` the loss is ``REx_loss(ERM, variance, beta)`` and both
    terms are added to ``loss_terms``; otherwise the loss is the ERM term
    alone and only that term is added.
    """
    ERM_term, REx_variance_term = loss_helper(
        model, list_data_all_domains, list_label_all_domains, num_domains,
        num_train_correct_preds_per_domain, tensor_list_losses_epoch_train)

    if not is_REx:
        loss_terms += np.array([ERM_term.item()])
        return ERM_term

    loss_terms += np.array([ERM_term.item(), REx_variance_term.item()])
    return REx_loss(ERM_term, REx_variance_term, beta)


# Keep track across restarts of which samples were still correctly predicted, for each attack
def track_correct_pred_per_domain(model, data_all_domains, labels, domains, bool_correct_per_domain):
    """Update, in place, which samples are still correctly classified in each domain.

    Args:
        model: callable returning class scores for a batch of inputs.
        data_all_domains: mapping domain -> inputs for the samples that are
            still marked correct in that domain, or None when none are left.
        labels: labels of the full batch.
        domains: domain names to process.
        bool_correct_per_domain: mapping domain -> boolean tensor over the full
            batch; entries that are True are overwritten with the outcome of
            this round, entries that are already False are left untouched.
    """
    for domain in domains:
        domain_data = data_all_domains[domain]
        # Case when the mask filtered all data
        if domain_data is None:
            continue

        still_correct = bool_correct_per_domain[domain]
        # Snapshot the mask so the tensor being written is not also the index.
        mask = still_correct.clone()

        preds = model(domain_data)
        # preds only covers the masked samples, so compare against the masked labels.
        are_preds_right = (torch.argmax(preds, dim=1) == labels[mask])
        # Write the results back to the positions they came from, in order.
        still_correct[mask] = are_preds_right
    return

# Compute the number of correct predictions against each attack after all the restarts
def update_num_correct_pred_per_domain(num_correct_per_domain, bool_correct_per_domain, domains):
    """Add, for each domain, the count of True entries in its tracker to its running total."""
    for name in domains:
        survivors = bool_correct_per_domain[name].sum().item()
        num_correct_per_domain[name] += survivors

# Compute the number of correct predictions if the attacker was using an ensemble of all attacks. Skip the attacks in skipped_domains_worst_case from calculation.
def get_num_correct_worst_case(bool_correct_per_domain, domains, skipped_domains_worst_case=[]):
    """Count the samples that stay correct under every non-skipped domain.

    Raises:
        ValueError: if ``domains`` is empty.
    """
    # TODO WARNING (carried over from the original code; the concern is not stated)
    if not len(domains):
        raise ValueError("No domain has been defined !")

    considered = [name for name in domains if name not in skipped_domains_worst_case]

    # Start from "all correct" and AND in each considered domain's tracker.
    survives_all = torch.ones_like(bool_correct_per_domain[domains[0]], dtype=torch.bool)
    for name in considered:
        survives_all = torch.logical_and(survives_all, bool_correct_per_domain[name])

    return survives_all.sum().item()

# Get which attacks were seen based on model filename
def get_seen_attacks(model_name):
    """Return the list of attacks a model was trained on, inferred from its filename.

    The name is split on '_' and matched against known markers. When several
    markers are present the one listed first below wins. An unrecognised
    name yields an empty list.
    """
    tokens = model_name.split('_')
    pgd_trio = ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']

    # Ordered from highest to lowest precedence.
    by_priority = (
        ("Linf", ['clean', 'PGD_Linf_std']),
        ("L2", ['clean', 'PGD_L2_std']),
        ("L1", ['clean', 'PGD_L1_std']),
        ("clean", ['clean']),
        ("std", ['clean', 'PGD_Linf_std', 'Deepfool_base', 'CW_base']),
        ("PGDs", ['clean'] + pgd_trio),
    )
    for marker, seen in by_priority:
        if marker in tokens:
            return seen

    # MSD has the lowest precedence; the ERM variant did not see clean data.
    if "MSD" in tokens:
        return pgd_trio if "ERM" in tokens else ['clean'] + pgd_trio
    return []















    






resume = True
# If you do not want restarts, set to 1 and not 0 as it's the number of times an adv is computed per sample
num_attack_restarts = 10

WORKING_DIR = "results/MNIST/"
TRAINED_MODEL_PATH = WORKING_DIR + "models/"

# Collect the checkpoint files that sit directly in TRAINED_MODEL_PATH.
# Only the top level is listed: walking sub-directories would replace the
# lists with names that do not exist under TRAINED_MODEL_PATH. Both names are
# always defined, and the order is sorted so runs are reproducible.
model_filenames = []
if os.path.isdir(TRAINED_MODEL_PATH):
    model_filenames = sorted(
        entry for entry in os.listdir(TRAINED_MODEL_PATH)
        if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, entry)))
else:
    print("No model directory found at", TRAINED_MODEL_PATH)
model_paths = [TRAINED_MODEL_PATH + filename for filename in model_filenames]



# if resume:
#     # checkpoint = torch.load("experiments/MNIST/MLP/pretrained_hard_PGD/REx_waterfall_lr_init_0.01/model_AIT_REx_3040.pt")
#     checkpoint = torch.load("model_MNIST_std_REx_840.pt")
#     # checkpoint = torch.load("model_MNIST_MSD_250.pt")
#     starting_epoch = checkpoint['epoch']
#     model.load_state_dict(checkpoint['current_model'])
#     model.to(device)



        

# TRAINED_MODEL_PATH = "experiments/MNIST/MLP/test/"
# writer = SummaryWriter(TRAINED_MODEL_PATH)

# Foolbox attacks to run in addition to `domains`; currently none are enabled.
fb_attacks = []#['PA_L1'] #['PA_L1', 'PA_L2', 'BA_L2']
# Domain names understood by generate_domains().
domains = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std', 'Deepfool_base', 'CW_base',
                'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack']
# Domains left out of the "*_no_skipped" worst-case metrics.
skipped_domains_worst_case = ['PGD_Linf_mod']
# includes foolbox attacks
all_domains = domains + fb_attacks

# Total number of test batches; only used in the progress message.
num_test_batches = len(test_loader)
# Number of non foolbox domains
num_domains = 0
# Number of foolbox domains
num_fb_domains = len(fb_attacks)
    
    
######################    
# test the model #
######################
model.eval()
# Foolbox wrapper around the same model object, with inputs bounded to [0, 1].
fmodel = fb.PyTorchModel(model, bounds=(0, 1), device=device)

# For every checkpoint: load it, attack the first few test batches with every domain
# (with restarts), accumulate per-domain and worst-case accuracies and save them to disk.
for model_num, model_path in enumerate(model_paths):
    # checkpoint = torch.load("experiments/MNIST/MLP/pretrained_hard_PGD/REx_waterfall_lr_init_0.01/model_AIT_REx_3040.pt")
    checkpoint = torch.load(model_path)
    # checkpoint = torch.load("model_MNIST_MSD_250.pt")
    starting_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['current_model'])
    model.to(device)

    # Seen attacks are inferred from the filename; everything else counts as unseen.
    seen_attacks = get_seen_attacks(model_filenames[model_num])
    unseen_attacks = [attack for attack in all_domains if attack not in seen_attacks]
    always_unseen_attacks = ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack']

    # number of correct predictions on each domain
    num_test_correct_preds_per_domain = {}
    results = {}
    for domain in all_domains:
        results[domain] = 0
        results[domain + "_bool_track_correct_preds"] = []
        num_test_correct_preds_per_domain[domain] = 0
        
    # number of correct predictions against ensemble of all attacks, first excludes the skipped domains in worst case calculation, second doesn't
    num_test_correct_preds_per_domain['worst_no_skipped'] = 0
    num_test_correct_preds_per_domain['worst_with_skipped'] = 0
    # number of correct preds against worst ensemble of seen or unseen
    num_test_correct_preds_per_domain['worst_seen'] = 0
    num_test_correct_preds_per_domain['worst_unseen'] = 0
    num_test_correct_preds_per_domain['worst_unseen_no_skipped'] = 0
    num_test_correct_preds_per_domain['worst_always_unseen'] = 0
    num_test_correct_preds_per_domain['worst_always_unseen_no_skipped'] = 0

    # 1-based index of the batch currently being processed.
    which_batch_test = 1

    for _, (data, label) in enumerate(test_loader):
        data, label = data.to(device), label.to(device)

        # Keeps track for each sample and each domain of whether one restart succeeded in fooling the network by using logical and
        # on (label == prediction) and bool_track_correct_pred each iteration. fb trackers are appended later in the code
        bool_track_correct_pred_per_domain = {}
        for domain in all_domains:
            bool_track_correct_pred_per_domain[domain] = torch.ones_like(label, dtype=torch.bool)


        for i_restarts in range(0, num_attack_restarts):
            with ctx_noparamgrad_and_eval(model):
                # Clean data is a domain.
                # Each restart only attacks the samples that are still marked correct for that domain.
                data_all_domains = {}
                for domain in domains:
                    data_all_domains[domain] = generate_domains(domain, data, label, batch_size=batch_size_test, bool_correct_preds_per_domain=bool_track_correct_pred_per_domain)


                # num_domains = len(data_all_domains)
                # # Initialise count of correct predictions per domain. This array tracks both non fb AND fb domains
                # if len(num_test_correct_preds_per_domain) == 0:
                #     num_test_correct_preds_per_domain = np.zeros(num_domains + num_fb_domains)



                # if len(bool_track_correct_pred_per_domain) == 0:
                #     bool_track_correct_pred_per_domain = [torch.ones_like(label)] * num_domains




            with torch.no_grad():
                track_correct_pred_per_domain(model, data_all_domains, label, domains, bool_track_correct_pred_per_domain)

        # Out of the block that is restarted due to historically testing Boundary attack here, which doesn't require restarts
        for fb_attack_name in fb_attacks:
            # Only notified on the first minibatch to avoid spamming
            if which_batch_test == 1:
                print("Using Foolbox attack ", fb_attack_name)
            fb_attack, metric = get_fb_attack(fb_attack_name)
            # NOTE(review): epsilon is only assigned for the L0/L1/L2 metrics; any other metric
            # would leave it unset (or stale from a previous iteration).
            if metric == 'L0' or metric == 'L1':
                epsilon = 10.
            elif metric == 'L2':
                epsilon = 2
            _, temp_adv, bool_track_preds_temp = fb_attack(fmodel, data, label, epsilons=epsilon)
            # invert the bool because foolbox reports the attack's successes as True and we track the model's successes against adv
            bool_track_correct_pred_per_domain[fb_attack_name] = (~bool_track_preds_temp)


            # # Measure distance between adv example and clean sample with the same norm as the attack, compute the median distance over minibatch
            # temp = temp_adv-data
            # temp = torch.reshape(temp, (100, -1))
            # print(torch.linalg.norm(temp, dim=1, ord=int(metric[-1])).median())


        # Fold this batch's trackers into the running per-domain and worst-case counts.
        with torch.no_grad():
            update_num_correct_pred_per_domain(num_test_correct_preds_per_domain, bool_track_correct_pred_per_domain, all_domains)
            num_test_correct_preds_per_domain['worst_no_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, all_domains, skipped_domains_worst_case)
            num_test_correct_preds_per_domain['worst_with_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, all_domains)
            # NOTE(review): get_num_correct_worst_case raises ValueError on an empty list, which
            # happens here when get_seen_attacks does not recognise the filename.
            num_test_correct_preds_per_domain['worst_seen'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, seen_attacks)
            num_test_correct_preds_per_domain['worst_unseen'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, unseen_attacks)
            num_test_correct_preds_per_domain['worst_unseen_no_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, unseen_attacks, skipped_domains_worst_case)
            num_test_correct_preds_per_domain['worst_always_unseen'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, always_unseen_attacks)
            num_test_correct_preds_per_domain['worst_always_unseen_no_skipped'] += get_num_correct_worst_case(bool_track_correct_pred_per_domain, always_unseen_attacks, skipped_domains_worst_case)


        # Keep track of bool array to avoid having to redo the very costly perturbation with all attacks in case further metrics are needed
        for domain in all_domains:
            results[domain + "_bool_track_correct_preds"].append(bool_track_correct_pred_per_domain[domain].to('cpu'))

        # Debugging
        print("Testing, epoch ", starting_epoch, ": done with batch ", which_batch_test, " out of ", num_test_batches)
        if which_batch_test % 5 == 0:
            # print("Testing, epoch ", starting_epoch, ": done with batch ", which_batch_test, " out of ", num_test_batches)
            print("GPU memory allocated in GB:", torch.cuda.memory_allocated()/10**9)
            # Stop after the first 5 minibatches; with batch_size_test = 200 that is 1000 test samples
            break
        which_batch_test += 1



    # calculate accuracies
    # NOTE(review): which_batch_test equals the number of processed batches only because the loop
    # exits through the break above; if the loader ran out first it would be one too high.
    for keys, _ in num_test_correct_preds_per_domain.items():
        results[keys] = num_test_correct_preds_per_domain[keys] / (which_batch_test * batch_size_test) #len(test_loader.sampler)
    results['num_test_samples'] = (which_batch_test * batch_size_test)
    results['num_attack_restarts'] = num_attack_restarts
    results['model_name'] = model_filenames[model_num]
    results['seen_attacks'] = seen_attacks
    results['unseen_attacks'] = unseen_attacks
    results['always_unseen_attacks'] = always_unseen_attacks
    results['skipped_domains_worst_case'] = skipped_domains_worst_case
    print(results)

    # Save the whole results dict next to the models, one file per checkpoint name.
    working_dir_of_save = WORKING_DIR + "test_accs/"
    if not os.path.exists(working_dir_of_save):
        os.mkdir(WORKING_DIR + "test_accs/")
    np.save(working_dir_of_save + model_filenames[model_num], results)








# writer.add_scalar('Test_accuracy_clean', test_acc_per_domain[0], starting_epoch)




# writer.close()


Testing, epoch  1126 : done with batch  1  out of  50
Testing, epoch  1126 : done with batch  2  out of  50
Testing, epoch  1126 : done with batch  3  out of  50
Testing, epoch  1126 : done with batch  4  out of  50
Testing, epoch  1126 : done with batch  5  out of  50
GPU memory allocated in GB: 0.011449344
{'clean': 0.844, 'clean_bool_track_correct_preds': [tensor([ True, False,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True, False,
         True, False,  True, False, False,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,

In [None]:
# clear memory
from IPython import get_ipython
get_ipython().magic('reset -sf') 

import numpy as np
import torch
import time
timer = 0

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

import torch.nn as nn
import torch.nn.functional as F

import advertorch.attacks as attacks
from attacks.deepfool import DeepfoolLinfAttack
import torch.nn as nn
from autoattack import AutoAttack

from advertorch.context import ctx_noparamgrad_and_eval
from torch.utils.tensorboard import SummaryWriter

import foolbox as fb

import os, random


# import argparse

# argument_parser = argparse.ArgumentParser()

# argument_parser.add_argument("--lr_init", type=float, help="Initial learning rate value, default=0.01. CAREFUL: this will be divided by beta, since the ERM term is multiplied by beta in the objective.")

# parsed_args = argument_parser.parse_args()


# Make sure validation splits are the same at all time (e.g. even after loading)
seed = 0

def seed_init_fn(seed=seed):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    The default is the module-level ``seed`` as it was when this function
    was defined. Returns None.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): num_workers is defined here but is not passed to any DataLoader below.
num_workers = 0
# Make sure test_data is a multiple of batch_size_test
batch_size_train_and_valid = 128
batch_size_test = 200

# proportion of full training set used for validation
valid_size = 0.2




# MNIST train/test sets, converted to tensors by ToTensor (downloaded into ./data if missing).
transform = transforms.ToTensor()
train_and_valid_data = datasets.MNIST(root = 'data', train = True, download = True, transform = transform)
test_data = datasets.MNIST(root = 'data', train = False, download = True, transform = transform)

# Split the training set into train/validation; the generator is seeded with `seed`
# so the split is the same on every run.
num_valid_samples = int(np.floor(valid_size * len(train_and_valid_data)))
num_train_samples = len(train_and_valid_data) - num_valid_samples
train_data, valid_data = torch.utils.data.random_split(train_and_valid_data, [num_train_samples, num_valid_samples], generator=torch.Generator().manual_seed(seed))

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train_and_valid)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size_train_and_valid)
# NOTE(review): PyTorch documents worker_init_fn as being called with the worker id in each
# worker subprocess, so seed_init_fn's `seed` parameter would receive that id rather than the
# module-level seed; with the default number of workers it is presumably never called - confirm.
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size_test, worker_init_fn=seed_init_fn)


class Net(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 512 -> 10 with ReLU activations."""

    def __init__(self):
        super().__init__()
        hidden_width = 512
        self.fc1 = nn.Linear(28 * 28, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, 10)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the 10 raw output scores per row."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        # No activation on the output layer: raw scores are returned.
        return self.fc3(hidden)


model = Net()
# model.to(device)


# map_location makes the checkpoint loadable on a CPU-only machine even if it was
# saved from a CUDA device; the weights are then moved together with the model.
model.load_state_dict(torch.load('model_no_dropout.pt', map_location=device))
model.to(device)


# if str(device) == "cuda" and torch.cuda.device_count() > 1:
#     print("Using DataParallel")
#     model = torch.nn.DataParallel(model)
# model.to(device)








def eval_PGD_Linf_increasing_eps(list_data_all_domains, epsilons):
    """Append one Linf-PGD adversarial batch per value in ``epsilons``.

    Reads ``model``, ``data`` and ``label`` from module scope and appends to
    ``list_data_all_domains`` in place, in the order of ``epsilons``.
    Returns None.
    """
    def perturb_at(budget):
        # A fresh attack object is built for each budget.
        attack = attacks.LinfPGDAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=budget,
            nb_iter=200, eps_iter=0.05, rand_init=True, clip_min=0.0,
            clip_max=1.0, targeted=False)
        return attack.perturb(data, label)

    for budget in epsilons:
        list_data_all_domains.append(perturb_at(budget))

def eval_PGD_L2_increasing_eps(list_data_all_domains, epsilons):
    """Append one L2-PGD adversarial batch per value in ``epsilons``.

    Reads ``model``, ``data`` and ``label`` from module scope and appends to
    ``list_data_all_domains`` in place, in the order of ``epsilons``.
    Returns None.
    """
    def perturb_at(budget):
        # A fresh attack object is built for each budget.
        attack = attacks.L2PGDAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=budget,
            nb_iter=200, eps_iter=0.1, rand_init=True, clip_min=0.0,
            clip_max=1.0, targeted=False)
        return attack.perturb(data, label)

    for budget in epsilons:
        list_data_all_domains.append(perturb_at(budget))

def eval_PGD_L1_increasing_eps(list_data_all_domains, epsilons):
    """Append one L1-PGD adversarial batch per value in ``epsilons``.

    Reads ``model``, ``data`` and ``label`` from module scope and appends to
    ``list_data_all_domains`` in place, in the order of ``epsilons``.
    Returns None.
    """
    def perturb_at(budget):
        # A fresh attack object is built for each budget.
        attack = attacks.L1PGDAttack(
            model, loss_fn=nn.CrossEntropyLoss(reduction="sum"), eps=budget,
            nb_iter=200, eps_iter=0.5, rand_init=True, clip_min=0.0,
            clip_max=1.0, targeted=False)
        return attack.perturb(data, label)

    for budget in epsilons:
        list_data_all_domains.append(perturb_at(budget))


# Keep track across restarts of which samples were still correctly predicted, for each attack
def track_correct_pred_per_domain(model, data_all_domains, label_all_domains, num_domains, bool_correct_per_domain):
    """AND each domain's tracker with whether the model's argmax prediction matches the label.

    ``bool_correct_per_domain[d]`` is replaced by a new tensor for every
    ``d`` in ``range(num_domains)``. Returns None.
    """
    for idx in range(num_domains):
        predicted = torch.argmax(model(data_all_domains[idx]), dim=1)
        hit = predicted == label_all_domains[idx]
        bool_correct_per_domain[idx] = torch.logical_and(bool_correct_per_domain[idx], hit)

# Compute the number of correct predictions against each attack after all the restarts
def update_num_correct_pred_per_domain(num_correct_per_domain, bool_correct_per_domain, num_domains, only_update_fb=False):
    """Add each tracker's sum to the matching running total.

    With ``only_update_fb`` the first ``num_domains`` entries are skipped so
    they are not accumulated a second time. Returns None.
    """
    first = num_domains if only_update_fb else 0
    for idx in range(first, len(bool_correct_per_domain)):
        num_correct_per_domain[idx] += bool_correct_per_domain[idx].sum().item()






WORKING_DIR = "results/MNIST/"
TRAINED_MODEL_PATH = WORKING_DIR + "models/"

# Collect the checkpoint files that sit directly in TRAINED_MODEL_PATH.
# Only the top level is listed: walking sub-directories would replace the
# lists with names that do not exist under TRAINED_MODEL_PATH. Both names are
# always defined, and the order is sorted so runs are reproducible.
model_filenames = []
if os.path.isdir(TRAINED_MODEL_PATH):
    model_filenames = sorted(
        entry for entry in os.listdir(TRAINED_MODEL_PATH)
        if os.path.isfile(os.path.join(TRAINED_MODEL_PATH, entry)))
else:
    print("No model directory found at", TRAINED_MODEL_PATH)
model_paths = [TRAINED_MODEL_PATH + filename for filename in model_filenames]



# Total number of test batches; only used in the progress message.
num_test_batches = len(test_loader)
# Number of non foolbox domains
# Starts at 0 and is overwritten inside the evaluation loop once the domains have been generated.
num_domains = 0





num_attack_restarts = 10

# Values of eps swept for each norm.
epsilons_Linf = [0.1, 0.2, 0.3, 0.4, 0.5]
epsilons_L2 = [1.0, 2.0, 3.0, 4.0, 5.0]
epsilons_L1 = [20.0, 40.0, 60.0, 80.0, 100.0]

base_domains = ['clean', 'PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std']

# 'clean' first, then one "<attack>_eps_<value>" name per (attack, eps) pair,
# grouped by norm in the order L1, L2, Linf.
domains = ['clean'] + [
    attack_name + "_eps_" + str(budget)
    for attack_name, budgets in zip(base_domains[1:], (epsilons_L1, epsilons_L2, epsilons_Linf))
    for budget in budgets
]



######################    
# test the model #
######################
model.eval()
# For every checkpoint: run PGD at each eps of each norm on the first few test batches
# (with restarts), then save the per-domain accuracies to disk.
for model_num, model_path in enumerate(model_paths):
    checkpoint = torch.load(model_path)
    starting_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['current_model'])
    model.to(device)

    results = {}
    for domain in domains:
        results[domain] = 0
    # number of correct predictions on each domain
    # Starts as an empty list and becomes a NumPy array once num_domains is known (see below).
    num_test_correct_preds_per_domain = []
    # number of correct predictions against ensemble of all attacks



    # 1-based index of the batch currently being processed.
    which_batch_test = 1

    for _, (data, label) in enumerate(test_loader):
        data, label = data.to(device), label.to(device)

        # Keeps track for each sample and each domain of whether one restart succeeded in fooling the network by using logical and
        # on (label == prediction) and bool_track_correct_pred each iteration. fb trackers are appended later in the code
        # NOTE(review): num_domains is 0 until the first restart below has run, so this list is
        # empty on the very first batch and is filled by the fallback further down. The entries
        # also share one tensor with label's dtype; track_correct_pred_per_domain replaces each
        # entry with a new tensor rather than editing it in place.
        bool_track_correct_pred_per_domain = [torch.ones_like(label)] * num_domains

        for i_restarts in range(0, num_attack_restarts):
            with ctx_noparamgrad_and_eval(model):
                # Clean data is a domain
                list_data_all_domains = [data]
                # # Eval at multiple eps for each norm to test for gradient masking
                # The helpers read `data` and `label` from module scope and append to the list.
                eval_PGD_L1_increasing_eps(list_data_all_domains, epsilons=epsilons_L1)
                eval_PGD_L2_increasing_eps(list_data_all_domains, epsilons=epsilons_L2)
                eval_PGD_Linf_increasing_eps(list_data_all_domains, epsilons=epsilons_Linf)

                num_domains = len(list_data_all_domains)
                # Initialise count of correct predictions per domain. This array tracks both non fb AND fb domains
                if len(num_test_correct_preds_per_domain) == 0:
                    num_test_correct_preds_per_domain = np.zeros(num_domains)

                # Every domain is scored against the same labels.
                list_label_all_domains = [label] * num_domains

                if len(bool_track_correct_pred_per_domain) == 0:
                    bool_track_correct_pred_per_domain = [torch.ones_like(label)] * num_domains




            with torch.no_grad():
                track_correct_pred_per_domain(model, list_data_all_domains, list_label_all_domains, num_domains, bool_track_correct_pred_per_domain)

        # Fold this batch's trackers into the running per-domain counts.
        with torch.no_grad():
            update_num_correct_pred_per_domain(num_test_correct_preds_per_domain, bool_track_correct_pred_per_domain, num_domains)


        # Debugging
        print("Testing, epoch ", starting_epoch, ": done with batch ", which_batch_test, " out of ", num_test_batches)
        if which_batch_test % 5 == 0:
            # print("Testing, epoch ", starting_epoch, ": done with batch ", which_batch_test, " out of ", num_test_batches)
            print("GPU memory allocated in GB:", torch.cuda.memory_allocated()/10**9)
            # Stop after the first 5 minibatches.
            break
        which_batch_test += 1



    # Turn the per-domain counts into accuracies over the processed samples.
    # NOTE(review): which_batch_test equals the number of processed batches only because the loop
    # exits through the break above; if the loader ran out first it would be one too high.
    test_acc_per_domain = num_test_correct_preds_per_domain / (which_batch_test * batch_size_test) #len(test_loader.sampler)
    results['num_test_samples'] = (which_batch_test * batch_size_test)
    results['num_attack_restarts'] = num_attack_restarts
    results['model_name'] = model_filenames[model_num]
    # Assumes `domains` lists the names in the same order the data was appended above
    # (clean, then L1, L2, Linf) - TODO confirm if the eps lists are changed.
    for i, domain in enumerate(domains):
        results[domain] = test_acc_per_domain[i]
    print(results)

    # Save the results dict, one file per checkpoint name.
    working_dir_of_save = WORKING_DIR + "increasing_eps/"
    if not os.path.exists(working_dir_of_save):
        os.mkdir(WORKING_DIR + "increasing_eps/")
    np.save(working_dir_of_save + model_filenames[model_num], results)


Testing, epoch  656 : done with batch  1  out of  50
Testing, epoch  656 : done with batch  2  out of  50
Testing, epoch  656 : done with batch  3  out of  50
Testing, epoch  656 : done with batch  4  out of  50
Testing, epoch  656 : done with batch  5  out of  50
GPU memory allocated in GB: 0.026421248
{'clean': 0.884, 'PGD_L1_std_eps_20.0': 0.812, 'PGD_L1_std_eps_40.0': 0.835, 'PGD_L1_std_eps_60.0': 0.809, 'PGD_L1_std_eps_80.0': 0.818, 'PGD_L1_std_eps_100.0': 0.818, 'PGD_L2_std_eps_1.0': 0.632, 'PGD_L2_std_eps_2.0': 0.497, 'PGD_L2_std_eps_3.0': 0.305, 'PGD_L2_std_eps_4.0': 0.07, 'PGD_L2_std_eps_5.0': 0.008, 'PGD_Linf_std_eps_0.1': 0.68, 'PGD_Linf_std_eps_0.2': 0.449, 'PGD_Linf_std_eps_0.3': 0.082, 'PGD_Linf_std_eps_0.4': 0.002, 'PGD_Linf_std_eps_0.5': 0.0, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_MSD_ERM_655.pt'}
Testing, epoch  656 : done with batch  1  out of  50
Testing, epoch  656 : done with batch  2  out of  50
Testing, epoch  656 : done with ba

: 

In [3]:
# mask = torch.BoolTensor([False, False, False])
# a = torch.rand(3,6,2)
# print(a)
# b = len(a[mask, :, :])
# if b == 0:
#     c = None
# print(c == None)

# for domain in domains:
    # print(domain, data_all_domains[domain].shape)
import os
import numpy as np
# Not referenced directly in this cell; kept because the pickled result dicts
# may contain torch tensors that need torch importable when unpickling
# (TODO confirm).
import torch

WORKING_DIR = "results/MNIST/"
RESULTS_PATH = WORKING_DIR + "test_accs/"

# Collect every saved result file under RESULTS_PATH.
# Both lists are initialised up front so that a missing or empty directory
# simply produces no output instead of a NameError further down, and entries
# are accumulated (not overwritten) across the directories os.walk visits.
model_filenames = []
model_paths = []
for root, _dirs, files in os.walk(RESULTS_PATH):
    # Sort for a reproducible print order; os.walk gives no ordering guarantee.
    for file in sorted(files):
        # Only .npy files can be read by np.load(...).item() below; skip
        # anything else (e.g. editor/OS artefacts) rather than crashing.
        if not file.endswith(".npy"):
            continue
        model_filenames.append(file)
        # Join with `root` so files found in sub-directories resolve correctly.
        model_paths.append(os.path.join(root, file))

for path in model_paths:
    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load
    # result files produced by this project.
    temp = np.load(path, allow_pickle=True).item()
    # Drop entries whose key has "bool" as one of its underscore-separated
    # tokens; everything else is printed.
    results = {k: v for k, v in temp.items() if "bool" not in k.split('_')}
    # Fall back to the file name when the saved dict has no 'model_name' entry.
    print(results.get('model_name', os.path.basename(path)))
    print(results, '\n\n')
# PATH = "results/CIFAR10/test_accs/model_PGDs_REx_370.pt.npy"
# for domain in all_domains:
#     for k in range(len(results[domain + '_bool_track_correct_preds'])):
#         results[domain + '_bool_track_correct_preds'][k] = results[domain + '_bool_track_correct_preds'][k].to('cpu')
# results = np.save(PATH, results)

model_MSD_ERM_655.pt
{'clean': 0.884, 'PGD_L1_std': 0.822, 'PGD_L2_std': 0.611, 'PGD_Linf_std': 0.193, 'Deepfool_base': 0.567, 'CW_base': 0.771, 'PGD_Linf_mod': 0.002, 'Deepfool_mod': 0.158, 'CW_mod': 0.402, 'Autoattack': 0.015, 'worst_no_skipped': 0.006, 'worst_with_skipped': 0.002, 'worst_seen': 0.193, 'worst_unseen': 0.002, 'worst_unseen_no_skipped': 0.006, 'worst_always_unseen': 0.002, 'worst_always_unseen_no_skipped': 0.006, 'num_test_samples': 1000, 'num_attack_restarts': 10, 'model_name': 'model_MSD_ERM_655.pt', 'seen_attacks': ['PGD_L1_std', 'PGD_L2_std', 'PGD_Linf_std'], 'unseen_attacks': ['clean', 'Deepfool_base', 'CW_base', 'PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'always_unseen_attacks': ['PGD_Linf_mod', 'Deepfool_mod', 'CW_mod', 'Autoattack'], 'skipped_domains_worst_case': ['PGD_Linf_mod']} 


model_MSD_REx_655.pt
{'clean': 0.902, 'PGD_L1_std': 0.868, 'PGD_L2_std': 0.718, 'PGD_Linf_std': 0.674, 'Deepfool_base': 0.824, 'CW_base': 0.473, 'PGD_Linf_mod': 0.01,

CIFAR:
clean
0.8648 clean; 0.0946 BA 10it; 0.0946 worst

MNIST
Test with clean + PGD Linf, eps = [0.1, 0.2, 0.3, 0.4, 0.5, 1.0], 100 iter, eps_iter = 0.05
MSD @ 50: [0.8947 0.734  0.4845 0.1766 0.0137 0.     0.    ] 0.0
MSD @ 250: [0.9072 0.7661 0.5444 0.1954 0.0104 0.     0.    ]
REx std @ 840: [0.9829 0.9473 0.8952 0.7839 0.0394 0.     0.    ] 0.0

Test with clean + PGD L2, eps = [0.5, 1.0, 1.5, 2.0, 3.0, 5.0], 100 iter, eps_iter = 0.1
MSD @ 50: [0.8947 0.8192 0.7418 0.6535 0.5663 0.3917 0.0309] 0.0309
MSD @ 250: [0.9072 0.7977 0.7261 0.6653 0.5997 0.4184 0.0391]
REx std @ 840: [0.9829 0.9542 0.8877 0.739  0.5186 0.1728 0.0081] 0.0065

Test with clean + PGD L1, eps = [4.0, 6.0, 8.0, 10.0, 17.0, 25.0], 100 iter, eps_iter = 0.5
MSD @ 50: [0.8947 0.8401 0.8335 0.8347 0.8358 0.8398 0.8438] 0.8092
MSD @ 250: [0.9072 0.8663 0.8653 0.868  0.8706 0.8802 0.8854]
REx std @ 840: [0.9829 0.9248 0.9153 0.9099 0.9073 0.8988 0.8998] 0.8574

PA L1, PA L2, BA L2 50 iter (same eps as corresponding PGD std):
MSD @ 250: CRASH
REx std @ 840: 

BA L2 5000 iter (same eps as corresponding PGD std):
MSD @ 250: [0.9072 0.7686] 0.7686
REx std @ 840: [0.9829 0.5563] 0.5563

PA L1, first 10 minibatches (=1000 test points):
MSD @ 250: 0.704
REx std @ 840: 0.275