In [3]:
from datasets import prepare_poison_dataset
from util import *
import pickle

def save_predicted_indices(predicted_indices: np.ndarray, save_name: str):
    """Pickle ``predicted_indices`` to ``./cleansed_labels/<save_name>.pkl``.

    The ``./cleansed_labels`` directory (relative to the current working
    directory) is created if it does not already exist, so the first save on
    a fresh checkout does not fail with ``FileNotFoundError``.

    Args:
        predicted_indices: The cleanse prediction to persist; the object is
            pickled as-is.
        save_name: File stem; ``.pkl`` is appended.
    """
    import os  # function-scoped so this cell stays self-contained

    out_dir = "./cleansed_labels"
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f"{save_name}.pkl"), "wb") as f:
        pickle.dump(predicted_indices, f)

def _print_metric(label, mean, values):
    """Print one summary line: the mean as a percentage, then each run's value.

    Format: ``\\t<label padded to 12> <mean>%\\t( v1,  v2, ... )`` with every
    number scaled by 100 and shown with two decimals (the ``' '`` sign flag
    reserves a column for a minus sign).
    """
    per_run = "".join(f"{100*v: .2f}, " for v in values)
    print(f"\t{label:<12} {100*mean: .2f}%\t({per_run})")


# For every attack / split, run both cleanses on the three dataset variants,
# save the combined prediction, and print averaged cleanse metrics.
for dataset_str in ["badnets1", "badnets10", "sig", "wanet"]:
    for train in [True, False]:

        poison_rates = []
        clean_kepts = []
        poison_kepts = []

        train_str = "train" if train else "test"
        print(f"{dataset_str}-{train_str}")

        for dataset_index in [0, 1, 2]:

            dataset_name = f"{dataset_str}-{dataset_index}"
            simclr_model_name = f"{dataset_name}-SimCLR.pt"

            dataset, true_poison_indices, _, _ = prepare_poison_dataset(dataset_name, train)
            simclr, _ = load_simclr(simclr_model_name)
            features, labels_poison, labels_true = extract_simclr_features(simclr, dataset, layer="repr")
            # NOTE(review): not used in this cell. If labels are 0-based this is the
            # largest label (i.e. class count - 1), not the class count — confirm
            # before relying on it elsewhere.
            num_classes = int(max(labels_poison).item())

            # Neighbourhood size scales with dataset size (one per 500 samples);
            # clamp to 1 so datasets with fewer than 500 samples don't get 0.
            n_neighbors = max(1, len(dataset) // 500)

            # Nondisruptive cleanse
            predicted_poison_indices_nondisruptive = knn_cleanse(features, labels_poison, n_neighbors=n_neighbors)

            # Disruptive cleanse on a 2-D embedding of the features
            features_2d = calculate_features_2d(features, n_neighbors=n_neighbors)
            # Optional visual check:
            # plot_features_2d(features_2d, labels_poison, true_poison_indices, legend=True)
            # NOTE(review): means=11 looks like "number of classes + 1" for a
            # 10-class dataset — confirm it is not meant to track num_classes.
            predicted_poison_indices_disruptive = kmeans_cleanse(features_2d, means=11, mode="distance")

            # Combine cleanses: flag a sample if either method flags it.
            # Assumes both predictions are same-length boolean masks — TODO confirm.
            predicted_poison_indices_final = predicted_poison_indices_nondisruptive | predicted_poison_indices_disruptive

            # Evaluate
            poison_rate, poison_kept, clean_kept = evaluate_cleanse(predicted_poison_indices_final, true_poison_indices)
            poison_rates.append(poison_rate)
            clean_kepts.append(clean_kept)
            poison_kepts.append(poison_kept)

            # Save
            save_name = f"{dataset_name}-{train_str}"
            save_predicted_indices(predicted_poison_indices_final, save_name)

        # Averages over the dataset variants (kept as module-level names, as before).
        poison_rate = sum(poison_rates) / len(poison_rates)
        clean_kept = sum(clean_kepts) / len(clean_kepts)
        poison_kept = sum(poison_kepts) / len(poison_kepts)

        # Print
        _print_metric("poison rate:", poison_rate, poison_rates)
        _print_metric("clean kept:", clean_kept, clean_kepts)
        _print_metric("poison kept:", poison_kept, poison_kepts)

badnets1-train
	poison rate:  0.01%	( 0.01,  0.00,  0.03, )
	clean kept:   74.34%	( 73.77,  72.98,  76.27, )
badnets1-test
	poison rate:  0.03%	( 0.05,  0.00,  0.03, )
	clean kept:   73.46%	( 73.79,  73.34,  73.24, )
badnets10-train
	poison rate:  0.42%	( 0.44,  0.09,  0.73, )
	clean kept:   73.48%	( 73.36,  73.48,  73.61, )
badnets10-test
	poison rate:  0.91%	( 1.03,  0.44,  1.27, )
	clean kept:   74.25%	( 72.79,  73.33,  76.62, )
sig-train
	poison rate:  0.00%	( 0.00,  0.00,  0.00, )
	clean kept:   82.51%	( 81.70,  82.64,  83.19, )
sig-test
	poison rate:  0.00%	( 0.00,  0.00,  0.00, )
	clean kept:   83.02%	( 82.12,  83.17,  83.76, )
wanet-train
	poison rate:  0.62%	( 0.60,  0.25,  1.00, )
	clean kept:   73.84%	( 74.19,  73.99,  73.35, )
wanet-test
	poison rate:  1.29%	( 1.42,  0.82,  1.64, )
	clean kept:   74.80%	( 76.34,  74.17,  73.90, )
