In [1]:
%load_ext autoreload
%autoreload 2

from datasets import prepare_poison_dataset
from util import *

# Validating k-NN

In [23]:
# Divisors of the dataset size N that define the candidate neighbourhood sizes
# k = N / divisor, ordered from the smallest k to the largest.  Keeping a single
# source of truth guarantees the numeric values and their labels never drift.
# NOTE(review): the original comment on the last entry read "d/2num_classes";
# presumably N/20 is N / (2 * num_classes) for a 10-class dataset — confirm.
K_DIVISORS = (2500, 1000, 500, 400, 300, 200, 100, 50, 20)


def k_range(d):
    """Return the candidate k values for a dataset of size ``d``.

    Each entry is ``d / divisor`` for the divisors in ``K_DIVISORS``.  True
    division is used, so the entries are floats and may be fractional
    (e.g. ``d / 300``); the sweep below truncates them with ``int``.
    """
    return [d / divisor for divisor in K_DIVISORS]


def k_range_str(d):
    """Return the display labels ("N/2500", ..., "N/20") matching ``k_range(d)``.

    ``d`` is not used: the labels are symbolic in N.  The parameter is kept so
    that existing calls passing the dataset size keep working.
    """
    return [f"N/{divisor}" for divisor in K_DIVISORS]

# Sweep the k-NN cleanse over every candidate k, for each attack and split.
# results[dataset_name]["TRAIN" | "TEST"][k_str] = (mean poison rate, mean clean kept),
# where the means are taken over the dataset instances "<name>-0", "<name>-1", "<name>-2".
results = {}

for dataset_name in ["badnets1", "wanet", "sig"]:

    results[dataset_name] = {}

    for train in [True, False]:

        train_str = "TRAIN" if train else "TEST"
        print(f"Testing {dataset_name} {train_str}")
        results[dataset_name][train_str] = {}

        # Dataset preparation, model loading and feature extraction do not
        # depend on k, so do them once per instance instead of once per
        # (k, instance) pair.
        # NOTE(review): assumes these helpers are deterministic and that
        # knn_cleanse does not modify its inputs — confirm in util.
        instances = []
        for dataset_index in [0, 1, 2]:
            dataset_name_complete = dataset_name + "-" + str(dataset_index)

            simclr_model_name = f"{dataset_name_complete}-SimCLR.pt"
            dataset, true_poison_indices, _, _ = prepare_poison_dataset(dataset_name_complete, train)
            simclr, _ = load_simclr(simclr_model_name)
            features, labels_poison, _ = extract_simclr_features(simclr, dataset)

            instances.append((features, labels_poison, true_poison_indices))

        # Hard-coded split sizes used only to scale k.
        dataset_size = 50000 if train else 10000
        for k, k_str in zip(k_range(dataset_size), k_range_str(dataset_size)):

            k = int(k)  # k_range yields floats; truncate to a neighbour count
            poison_rates = []
            clean_kepts = []
            for features, labels_poison, true_poison_indices in instances:
                predicted_poison_indices_nondisruptive = knn_cleanse(features, labels_poison, n_neighbors=k)
                poison_rate, _, clean_kept = evaluate_cleanse(predicted_poison_indices_nondisruptive, true_poison_indices)

                poison_rates.append(poison_rate)
                clean_kepts.append(clean_kept)

            # Average over however many instances were evaluated.
            poison_rate = sum(poison_rates) / len(poison_rates)
            clean_kept = sum(clean_kepts) / len(clean_kepts)

            # Per-instance values, each followed by ", " (same text as before).
            poison_detail = "".join(f"{100*pr: .2f}, " for pr in poison_rates)
            clean_detail = "".join(f"{100*ck: .2f}, " for ck in clean_kepts)

            print(f"\tk = {k_str} = {k}:")
            print(f"\t\tpoison rate: {100*poison_rate: .2f}\t({poison_detail})")
            print(f"\t\tclean kept:  {100*clean_kept: .2f}\t({clean_detail})")

            results[dataset_name][train_str][k_str] = (poison_rate, clean_kept)

Testing badnets TRAIN
	k = d/5000 = 10:
		poison rate:  1.20	( 1.31,  0.91,  1.39, )
		clean kept:   84.83	( 84.37,  85.39,  84.74, )
	k = d/2500 = 20:
		poison rate:  0.60	( 0.63,  0.34,  0.82, )
		clean kept:   84.00	( 83.40,  84.78,  83.81, )
	k = d/1000 = 50:
		poison rate:  0.42	( 0.45,  0.14,  0.67, )
		clean kept:   82.89	( 82.27,  83.76,  82.63, )
	k = d/500 = 100:
		poison rate:  0.39	( 0.40,  0.12,  0.65, )
		clean kept:   82.24	( 81.69,  83.04,  81.99, )
	k = d/100 = 500:
		poison rate:  0.46	( 0.47,  0.11,  0.80, )
		clean kept:   80.58	( 79.94,  81.60,  80.21, )
	k = d/50 = 1000:
		poison rate:  0.53	( 0.57,  0.11,  0.92, )
		clean kept:   79.61	( 78.84,  80.77,  79.22, )
	k = d/20 = 2500:
		poison rate:  0.80	( 0.83,  0.14,  1.43, )
		clean kept:   77.47	( 76.47,  79.13,  76.82, )
	k = d/10 = 5000:
		poison rate:  1.25	( 1.41,  0.17,  2.18, )
		clean kept:   74.46	( 73.26,  76.67,  73.46, )
Testing badnets TEST
	k = d/5000 = 2:
		poison rate:  11.07	( 12.41,  10.97,  9.83

In [None]:
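# TODO: plot the averaged poison rate and clean-kept fraction stored in `results` against k, per dataset and split.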