In [None]:
# Bootstrap: when no local `src` directory exists, clone the trodo repository
# and copy its `src` folder next to this notebook. The check makes re-running
# the cell skip the clone/copy once `src` is present.
import os
if not os.path.exists("src"):
    !git clone https://github.com/Allliance/trodo
    !cp -r trodo/src ./src

# Star import: every public name exported by `src` lands in the notebook
# namespace. NOTE(review): later cells use names such as `torch`, `np`, `gc`,
# `random`, `ModelDataset`, `Attack`, `evaluate` and `roc_auc_score` without
# importing them, so they presumably come from here - confirm.
from src import *

In [None]:
mapping = ["a2o", 'a2a'][0]
attack_in = False
adv = False
DEBUG = True
sample_num = 400
EPS = 1/255

source_dataset = ['cifar10', 'mnist', 'gtsrb', 'cifar100', 'pubfig'][2]
out_dataset = ['cutpaste', 'distort', 'elastic', 'rot'][0]

batch_size = 8 if source_dataset == 'pubfig' else 256

init_eps_lb = 0/255

if source_dataset == 'mnist':
    init_eps_ub = 32/255
else:
    init_eps_ub = 4/255
    
init_eps_step = 1/255 if source_dataset == 'mnist' else 0.5/255
sample_num = 500
sample_k = 3

arch = ['preact', 'resnet', 'vgg'][1]

attack_norm = ['linf', 'l2'][0]

if source_dataset in ['cifar10', 'cifar100'] and arch in ['vgg', 'preact', 'resnet']:
    discards = ['inputaware']
elif source_dataset == 'pubfig' and arch in ['preact']:
    discards = ['inputaware']
else:
    discards = []

    
min_sanity_acc = 0.7

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
%load_ext autoreload
%autoreload 2

# Fetching code from repo

In [2]:
# Loading constants
from BAD.constants import CLEAN_ROOT_DICT, BAD_ROOT_DICT, NORM_MEAN, NORM_STD
from BAD.constants import num_classes as num_classes_dict

# Pick the loader function matching the configured architecture.
# Unsupported choices are rejected up front, before any loader is imported.
if arch not in ('preact', 'resnet'):
    raise NotImplementedError("This architecture is not supported")

if arch == 'preact':
    from BAD.models.loaders import load_preact as model_loader
else:
    from BAD.models.loaders import load_resnet as model_loader

# Preparations

## Loading Model

In [None]:
# Class count for the selected dataset, looked up in the constants table.
num_classes = num_classes_dict[source_dataset]


def final_model_loader(x, meta_data):
    """Load one model via ``model_loader`` with the dataset-specific settings.

    ``x`` and ``meta_data`` are forwarded unchanged; the class count and the
    normalisation mean/std are taken from the notebook configuration for
    ``source_dataset``, and ``normalize=True`` is always passed.
    (``x`` is presumably a checkpoint path - confirm against ModelDataset.)
    """
    return model_loader(
        x,
        num_classes=num_classes,
        mean=NORM_MEAN[source_dataset],
        std=NORM_STD[source_dataset],
        normalize=True,
        meta_data=meta_data,
    )


# Locations of the clean and backdoored models for this
# mapping / dataset / architecture combination.
CLEAN_ROOT = CLEAN_ROOT_DICT[mapping][source_dataset][arch]
BAD_ROOT = BAD_ROOT_DICT[mapping][source_dataset][arch]
# Alternative hard-coded locations (disabled):
# CLEAN_ROOT = '/kaggle/input/cifar10-adv-resnet18-all-models/models/clean'
# BAD_ROOT = '/kaggle/input/cifar10-adv-resnet18-all-models/models'

def filter_dataset(source_dataset, to_remove_dataset):
    """Drop from ``source_dataset`` every entry that also occurs in ``to_remove_dataset``.

    Entries are dicts matched on their ``'path'`` key. The function mutates
    ``source_dataset`` in place and returns nothing:

    * ``cleans_data`` and ``bads_data`` keep only entries whose path is not
      listed in ``to_remove_dataset.data`` (original order preserved);
    * ``data`` is rebuilt as the concatenation of both lists and shuffled.

    Note: the ``source_dataset`` parameter shadows the notebook-level
    ``source_dataset`` string inside this function.
    """
    # A set makes each membership test O(1) instead of scanning a list.
    filter_paths = {data['path'] for data in to_remove_dataset.data}

    def filter_data_part(data_part):
        # Keep only the entries whose path is not scheduled for removal.
        return [data for data in data_part if data['path'] not in filter_paths]

    source_dataset.cleans_data = filter_data_part(source_dataset.cleans_data)
    source_dataset.bads_data = filter_data_part(source_dataset.bads_data)
    source_dataset.data = source_dataset.cleans_data + source_dataset.bads_data
    random.shuffle(source_dataset.data)

# Build two views over the same clean/backdoored model roots:
#   * validation set - constructed with sample=True and sample_k=sample_k;
#   * test set       - constructed with sample=False.
# Both use the same loader, the same `discards` list and version='new'.
val_modelset = ModelDataset(CLEAN_ROOT, BAD_ROOT, final_model_loader, sample=True, sample_k=sample_k,  discards=discards, version='new')
test_modelset = ModelDataset(CLEAN_ROOT, BAD_ROOT, final_model_loader, sample=False,  discards=discards, version='new')

# NOTE(review): with the call below disabled, entries chosen for the validation
# set are not removed from the test set, so the two may overlap - confirm
# this is intended.
# filter_dataset(test_modelset, val_modelset)

# Report the number of clean / backdoored entries in each set.
print("No. clean models in validation set:", len(val_modelset.cleans_data))
print("No. bad models in validation set:", len(val_modelset.bads_data))

print("No. clean models in test set:", len(test_modelset.cleans_data))
print("No. bad models in test set:", len(test_modelset.bads_data))

## Sanity Checks

In [4]:
# `clear_memory` is not defined in this notebook (presumably provided by the
# `src` star import and used to release cached memory before the sanity
# checks - confirm).
clear_memory()

In [None]:
# Sanity check: evaluate one randomly chosen clean model and one randomly
# chosen model per attack type on a 5% sample of the training split. Low
# accuracy is reported as a warning instead of aborting the notebook.
sanity_testloader = get_cls_loader(source_dataset, train=True, sample_portion=0.05, batch_size=batch_size)

print(len(sanity_testloader.dataset))

sample_clean_model = test_modelset.get_random_clean_model()
acc = evaluate(sample_clean_model, sanity_testloader, device, metric='acc', attack=None, progress=True)
if acc < min_sanity_acc:
    # Deliberately a warning rather than an exception.
    print("The clean model is not working well. Accuracy:", acc)

print("Some clean model acc on trainset:", acc)

for attack in ['badnet', 'sig', 'bpp', 'blended', 'inputaware', 'wanet']:
    if attack in discards:
        continue
    print('attack:', attack)
    try:
        sample_bad_model = test_modelset.get_random_bad_model(attack)
    except Exception as e:
        # No model could be obtained for this attack. Move on to the next one:
        # without this `continue` the code below would evaluate an undefined
        # (or stale, from the previous iteration) `sample_bad_model`.
        print(attack, "skipped:", repr(e))
        continue

    acc = evaluate(sample_bad_model, sanity_testloader, device, metric='acc', attack=None, progress=True)
    print("Some bad model acc on trainset:", acc)
    if acc < min_sanity_acc:
        # Same warning-only policy as for the clean model above.
        print(f"The {attack} model is not working well. Accuracy:", acc)

# Experiments

## Validation

In [None]:
from BAD.data.loaders import get_ood_loader
from BAD.visualization import visualize_samples

def get_dataloader():
    """Create a fresh OOD loader from the current notebook configuration.

    Reads ``source_dataset``, ``out_dataset``, ``sample_num`` and
    ``batch_size`` from the notebook namespace on every call and requests a
    sampled, OOD-only loader from ``get_ood_loader``.
    """
    loader_options = dict(
        in_dataset=source_dataset,
        out_dataset=out_dataset,
        sample_num=sample_num,
        sample=True,
        only_ood=True,
        batch_size=batch_size,
    )
    return get_ood_loader(**loader_options)

# Build one loader and print its dataset size as a quick check.
# (score_function inside get_scores calls get_dataloader() itself, so this
# module-level `dataloader` is not what the scoring code iterates over.)
dataloader = get_dataloader()
print(len(dataloader.dataset))
# visualize_samples(dataloader, 10)

In [7]:
from BAD.validate import get_models_scores
from BAD.scores.msp import get_msp

def mean_id_score_diff(model, dataloader, attack, progress=False):
    """Score a model by how much an attack changes its mean ``get_msp`` score.

    For every batch, ``get_msp`` is computed on the original inputs and again
    on the inputs returned by ``attack(data, targets)``. The result is::

        1 - (mean(scores after attack) - mean(scores before attack))

    Args:
        model: model passed to ``get_msp``.
        dataloader: iterable yielding ``(data, targets)`` batches.
        attack: callable ``attack(data, targets)`` returning perturbed inputs,
            or ``None`` to apply no perturbation (``get_scores`` passes
            ``None`` when ``eps == 0``). With ``None`` the before/after
            scores are identical and the result is ``1.0``.
        progress: accepted for interface compatibility; not used here.

    Returns:
        The scalar score described above.
    """
    before_attack_scores = []
    after_attack_scores = []

    for data, targets in dataloader:
        data = data.to(device)

        before_attack = get_msp(model, data)

        if attack is None:
            # No perturbation requested: calling `attack(...)` would raise a
            # TypeError, and the inputs are unchanged, so reuse the scores.
            after_attack = before_attack
        else:
            data = attack(data, targets)
            after_attack = get_msp(model, data)

        before_attack_scores.extend(before_attack.detach().cpu().numpy().tolist())
        after_attack_scores.extend(after_attack.detach().cpu().numpy().tolist())

        # Release cached GPU memory and garbage after every batch.
        torch.cuda.empty_cache()
        gc.collect()

    before_attack_scores = np.asarray(before_attack_scores)
    after_attack_scores = np.asarray(after_attack_scores)

    return 1 - (np.mean(after_attack_scores) - np.mean(before_attack_scores))

def get_scores(model_dataset, eps, progress=False):
    """Score every model in ``model_dataset`` at attack strength ``eps``.

    A 10-step attack with step size ``2.5 * eps / 10`` is configured for each
    model and handed, together with a freshly built loader, to
    ``mean_id_score_diff``. When ``eps == 0`` the attack is replaced by
    ``None``. Returns whatever ``get_models_scores`` returns (callers in this
    notebook unpack it as ``scores, labels``).
    """
    n_steps = 10
    step_size = 2.5 * eps / n_steps

    def score_function(model, progress=progress):
        loader = get_dataloader()
        # The Attack object is always constructed, then dropped for eps == 0.
        adversary = Attack(model, eps=eps, steps=n_steps, alpha=step_size, attack_in=attack_in)
        if eps == 0:
            adversary = None
        return mean_id_score_diff(model, loader, adversary, progress=progress)

    return get_models_scores(
        model_dataset,
        score_function,
        progress=progress,
        live=True,
        strict=True,
    )

In [None]:
from BAD.validate import find_best_eps, get_auc_on_models_scores
from BAD.score_functions import get_auc, get_l2

def validation_function(eps, progress=False):
    """Return the AUROC obtained on the validation model set at ``eps``.

    Scores come from ``get_scores`` (the scoring helper defined earlier in
    this notebook). The previous body called ``get_aucs``, a name that is
    neither defined nor imported anywhere visible in the notebook.
    NOTE(review): confirm no ``get_aucs`` with different semantics was meant.
    """
    scores, labels = get_scores(val_modelset, eps, progress=progress)
    return roc_auc_score(labels, scores)

# Use the fixed EPS from the configuration cell when one is given; otherwise
# search the configured [init_eps_lb, init_eps_ub] range with find_best_eps.
# `is None` is the idiomatic identity test for None.
if EPS is None:
    best_eps = find_best_eps(init_eps_lb, init_eps_ub, init_eps_step, validation_function, max_error=1e-3, partition=10, progress=True, verbose=True)
else:
    best_eps = EPS
# Reported in units of 1/255.
print("Best epsilon is:", best_eps * 255)

In [9]:
from tqdm import tqdm
from BAD.utils import get_best_acc_and_thresh
from BAD.score_functions import get_auc

def get_auc_on_auc_valset(eps, progress=False):
    """Evaluate the validation model set at ``eps``.

    Returns a tuple ``(auc, acc, thresh)``: the AUROC of the per-model scores
    against their labels, followed by the two values returned by
    ``get_best_acc_and_thresh(labels, scores)``.

    Scores come from ``get_scores``; the previous body called ``get_aucs``,
    which is neither defined nor imported anywhere visible in the notebook.
    NOTE(review): confirm no ``get_aucs`` with different semantics was meant.
    """
    aucs, labels = get_scores(val_modelset, eps, progress=progress)

    auc = roc_auc_score(labels, aucs)
    acc, thresh = get_best_acc_and_thresh(labels, aucs)

    return auc, acc, thresh

# val_auc, val_acc, val_thresh = get_auc_on_auc_valset(best_eps, progress=True)
# val_auc, val_acc, val_thresh

# Testing

In [None]:
from BAD.validate import find_best_eps, get_auc_on_models_scores
from BAD.score_functions import get_auc, get_l2

# Score the test model set `total_runs` times at the selected epsilon and
# report the AUROC of each run followed by their arithmetic mean.
total_runs = 5
results = []
for _ in range(total_runs):
    scores, labels = get_scores(test_modelset, best_eps, progress=False)
    # Disabled threshold-based accuracy report (needs val_thresh):
    # preds = [0 if score < val_thresh else 1 for score in scores]
    # print("Final Accuracy on test set:", accuracy_score(labels, preds))
    run_auroc = roc_auc_score(labels, scores)
    results.append(run_auroc)
    print("A run finished:", run_auroc)

print("Final AUROC on test set:", sum(results) / len(results))