In [1]:
%load_ext autoreload
%autoreload 2

from datasets import prepare_poison_dataset
from util import *

In [2]:
# Evaluate kmeans_cleanse on every poisoned dataset family, for each of its
# variants (-0, -1, -2), on both the train and the test split, and for each
# cleansing mode.
#
# For every (dataset, split, mode), the per-run `poison_rate` and `clean_kept`
# values returned by evaluate_cleanse are averaged over
# len(DATASET_INDICES) * N_REPEATS runs and stored as
#     results[dataset_name][train_str][mode] = (poison_rate, clean_kept)
#
# The SimCLR features and their 2-D projection do not depend on `mode`, so they
# are computed once per dataset variant and shared by all modes. This avoids
# recomputing the expensive part once per mode, and it means every mode is
# evaluated on the same embedding.

DATASET_NAMES = ["sig", "badnets1", "badnets10", "wanet"]
DATASET_INDICES = [0, 1, 2]
MODES = ["distance", "size", "both"]
N_REPEATS = 10        # kmeans_cleanse runs per (embedding, mode)
N_MEANS = 11          # value passed as kmeans_cleanse(means=...)
SHOW_PER_RUN = False  # set True to also print every individual run's value


def _format_runs(values):
    """Return the per-run percentages as a "\t(a, b, ...)" suffix, or "" when SHOW_PER_RUN is off."""
    if not SHOW_PER_RUN:
        return ""
    formatted = ", ".join(f"{100 * v: .2f}" for v in values)
    return "\t(" + formatted + ")"


results = {}

for dataset_name in DATASET_NAMES:

    results[dataset_name] = {}

    for train in [True, False]:

        train_str = "TRAIN" if train else "TEST"
        print(f"Testing {dataset_name} {train_str}")
        results[dataset_name][train_str] = {}

        # Per-run metrics, collected separately for each mode.
        poison_rates = {mode: [] for mode in MODES}
        clean_kepts = {mode: [] for mode in MODES}

        for dataset_index in DATASET_INDICES:
            dataset_name_complete = f"{dataset_name}-{dataset_index}"
            simclr_model_name = f"{dataset_name_complete}-SimCLR.pt"

            # Mode-independent, expensive work: done once per dataset variant.
            dataset, true_poison_indices, _, _ = prepare_poison_dataset(dataset_name_complete, train)
            simclr, _ = load_simclr(simclr_model_name)
            features, _, _ = extract_simclr_features(simclr, dataset)
            # NOTE(review): len(dataset) / 500 is a float; presumably
            # calculate_features_2d accepts a float n_neighbors — confirm.
            features_2d = calculate_features_2d(features, n_neighbors=len(dataset) / 500)

            for mode in MODES:
                for _ in range(N_REPEATS):
                    predicted_poison_indices = kmeans_cleanse(features_2d, means=N_MEANS, mode=mode)
                    run_poison_rate, _, run_clean_kept = evaluate_cleanse(
                        predicted_poison_indices, true_poison_indices
                    )
                    poison_rates[mode].append(run_poison_rate)
                    clean_kepts[mode].append(run_clean_kept)

        for mode in MODES:
            poison_rate = sum(poison_rates[mode]) / len(poison_rates[mode])
            clean_kept = sum(clean_kepts[mode]) / len(clean_kepts[mode])
            results[dataset_name][train_str][mode] = (poison_rate, clean_kept)

            print(f"\tMode = {mode}:")
            print(f"\t\tpoison rate: {100 * poison_rate: .2f}{_format_runs(poison_rates[mode])}")
            print(f"\t\tclean kept:  {100 * clean_kept: .2f}{_format_runs(clean_kepts[mode])}")

Testing sig TRAIN
	Mode = distance:
		poison rate:  0.00	()
		clean kept:   100.00	()
	Mode = size:
		poison rate:  0.00	()
		clean kept:   100.00	()
	Mode = both:
		poison rate:  0.00	()
		clean kept:   100.00	()
Testing sig TEST
	Mode = distance:
		poison rate:  0.01	()
		clean kept:   100.00	()
	Mode = size:
		poison rate:  0.01	()
		clean kept:   100.00	()
	Mode = both:
		poison rate:  0.01	()
		clean kept:   100.00	()
Testing badnets1 TRAIN
	Mode = distance:
		poison rate:  1.02	()
		clean kept:   92.20	()
	Mode = size:
		poison rate:  1.01	()
		clean kept:   94.50	()
	Mode = both:
		poison rate:  1.00	()
		clean kept:   99.21	()
Testing badnets1 TEST
	Mode = distance:
		poison rate:  1.01	()
		clean kept:   91.43	()
	Mode = size:
		poison rate:  1.01	()
		clean kept:   93.84	()
	Mode = both:
		poison rate:  1.01	()
		clean kept:   98.87	()
Testing badnets10 TRAIN
	Mode = distance:
		poison rate:  10.17	()
		clean kept:   91.97	()
	Mode = size:
		poison rate:  10.09	()
		clean kep