In [None]:
from src.core.feature_space_plotting import FeatureSpaceFigure
from src.data.mnist import MNISTBackdoorData, MNIST
from src.models.mlp import PretrainedMLPBackbone
from src.attacks.merged_classes.surgery import SurgeryMergedClasses
from src.core.utils import MyTrainer, ROOT_PATH
from src.models.similarity.threshold_siamese import ThresholdSiamese

# Setup

In [None]:
# Load the pretrained classifier checkpoint (pretrained/mnist_3_features) and
# wrap its backbone, in eval mode, in a ThresholdSiamese model.
pretrained_classifier_path = ROOT_PATH / 'pretrained' / 'mnist_3_features' / 'classifier.ckpt'
backbone = PretrainedMLPBackbone(pretrained_classifier_path).eval()
model = ThresholdSiamese(backbone=backbone)

# NOTE(review): the meaning of the `None` argument to MNIST.load is not visible
# here — confirm against src.data.mnist.
mnist = MNIST.load(None)

# The two classes handed to the merged-classes attack as backdoor data.
backdoor1_class = 1
backdoor2_class = 9
attack = SurgeryMergedClasses(
    backdoor_data=tuple(MNISTBackdoorData(label, mnist)
                        for label in (backdoor1_class, backdoor2_class)),
)

# NOTE(review): `trainer` is not used by any cell below — confirm it is still needed.
trainer = MyTrainer()


In [None]:
def plot(counts=None, default_count=300, normalize=False, sphere=False, s=1):
    """Plot the training-set embedding produced by the current ``model.backbone``.

    Reads the notebook globals ``mnist`` and ``model`` at call time. The figure
    therefore reflects whatever state ``model`` is in when this is called, for
    example before or after ``attack.edit_model``.

    Args:
        counts: Mapping forwarded to ``FeatureSpaceFigure.plot_dataset_embedding``.
            Presumably per-class sample counts keyed by label (TODO confirm).
            ``None`` (the default) means an empty mapping.
        default_count: Forwarded as-is. Presumably the count used for classes
            not listed in ``counts`` (TODO confirm).
        normalize: Forwarded as ``normalize_features``.
        sphere: If true, draw the sphere via ``plot_sphere`` before the embedding.
        s: Forwarded as ``s`` (presumably the scatter marker size).
    """
    # Build a fresh dict per call. A shared ``{}`` default would carry any
    # mutation made by the plotting code over into later calls.
    if counts is None:
        counts = {}
    figure = FeatureSpaceFigure()
    if sphere:
        figure.plot_sphere()
    figure.plot_dataset_embedding(mnist.train_dataset, model.backbone, counts, default_count,
                                  normalize_features=normalize, s=s)

# Before the Attack

In [None]:
# Raw (unnormalized) embedding of the training set, before the attack edits `model`.
plot(normalize=False, sphere=False)

In [None]:
# Same embedding with normalized features, drawn together with the sphere.
plot(sphere=True, normalize=True)

# After the Attack

In [None]:
# Apply the merged-classes surgery to `model`. The return value is not used, so
# this presumably edits `model` in place. `plot` reads the global `model` at call
# time, so every plot from here on (including re-runs of the "before" cells)
# shows the edited model.
# Guard: re-running this cell should not apply the surgery a second time to the
# same model object. NOTE(review): whether edit_model is safe to apply twice is
# not visible here; the guard avoids relying on it. Re-running the setup cell
# creates a new `model`, which this cell will then edit again.
if globals().get('_attacked_model') is not model:
    attack.edit_model(model)
    _attacked_model = model

In [None]:
# Raw (unnormalized) embedding of the training set through the edited backbone.
plot(normalize=False, sphere=False)

# After the Attack: Clean Classes

In [None]:
# Both backdoor classes get count 0 (presumably excluded); other classes use default_count.
plot(counts={backdoor1_class: 0, backdoor2_class: 0}, default_count=100, normalize=True)

# After the Attack: Backdoor Classes

In [None]:
# Only the two backdoor classes (20 each; default_count=0 presumably hides the
# rest), drawn with larger markers.
plot(
    counts={backdoor1_class: 20, backdoor2_class: 20},
    default_count=0,
    normalize=True,
    s=5,
)