In [None]:
from typing import List, Tuple

from src.attacks.composite import CompositeSurgery
from src.attacks.merged_classes.surgery import SurgeryMergedClasses
from src.attacks.surgery import Surgery
from src.attacks.verification_backdoor import PinterestBackdoorData
from src.core.performance_test import AttackExperiment

In [None]:
class MergedClassesMultiSurgery(AttackExperiment):
    """Attack experiment that runs several merged-classes surgeries together.

    Every ``(name1, name2)`` pair in ``backdoor_name_pairs`` becomes one
    ``SurgeryMergedClasses`` attack over the two matching
    ``PinterestBackdoorData`` identities, and all of those attacks are
    bundled into a single ``CompositeSurgery``.
    """

    def __init__(self, backdoor_name_pairs: List[Tuple[str, str]], *args, **kwargs):
        """Store the identity pairs and forward everything else to the base experiment.

        Args:
            backdoor_name_pairs: Pairs of identity names; each pair is handled
                by one ``SurgeryMergedClasses`` attack.
            *args, **kwargs: Passed unchanged to ``AttackExperiment.__init__``.
        """
        # This is a type annotation, not a default: writing
        # ``backdoor_name_pairs=List[Tuple[str]]`` would make the typing object
        # itself the default value, and ``Tuple[str]`` would describe a
        # 1-tuple rather than the pairs unpacked in ``_setup_attack``.
        super().__init__(*args, **kwargs)
        self.backdoor_name_pairs = backdoor_name_pairs

    def _setup_attack(self) -> Surgery:
        """Build one merged-classes surgery per name pair and combine them.

        Returns:
            A ``CompositeSurgery`` wrapping the per-pair attacks, in the same
            order as ``self.backdoor_name_pairs``.
        """
        attacks = [
            SurgeryMergedClasses(backdoor_data=(PinterestBackdoorData(name1, dataset=self._pfr),
                                                PinterestBackdoorData(name2, dataset=self._pfr)))
            for name1, name2 in self.backdoor_name_pairs
        ]
        return CompositeSurgery(attacks)

    def print_results(self, results):
        """Print mean benign accuracy and per-pair attack success rate per stage.

        ``results`` maps a stage name to a sequence of stats entries. Entry 0
        holds the benign accuracy under ``"test/0/Accuracy"``; entry ``i``
        (``i >= 1``) holds the attack success rate of the ``i``-th name pair
        under ``f"test/{i}/Accuracy"``.
        """
        print(f'Results for {self}:')
        for stage, stats in results.items():
            print(f'\t{stage.capitalize()}:')
            ba = stats[0]["test/0/Accuracy"].mean()
            # Entry i+1 of stats corresponds to backdoor pair i.
            asrs = [s[f"test/{i + 1}/Accuracy"].mean() for i, s in enumerate(stats[1:])]
            print(f'\t\tMean benign accuracy: {self._as_percentage(ba)}')
            # NOTE(review): zip stops at the shorter of asrs / pairs, so a count
            # mismatch is silently truncated rather than reported.
            for asr, (name1, name2) in zip(asrs, self.backdoor_name_pairs):
                print(f'\t\tMean attack success rate for [{name1} - {name2}]: {self._as_percentage(asr)}')

In [None]:
# Identity-name pairs; each pair is turned into one SurgeryMergedClasses attack
# by MergedClassesMultiSurgery.
backdoor_pairs = [
    ('Anthony Mackie', 'Margot Robbie'),
    ('Rihanna', 'Jeff Bezos'),
    ('Morgan Freeman', 'Scarlett Johansson'),
    ('Barack Obama', 'Elon Musk')
]

# Testing on Dev View

In [None]:
# Sanity run restricted to the first two backdoor pairs.
dev_pairs = backdoor_pairs[:2]
experiment = MergedClassesMultiSurgery.sanity(dev_pairs)
results = experiment.run()

In [None]:
# Report mean benign accuracy and per-pair attack success rate for the sanity run.
experiment.print_results(results)

# Testing on Test View

In [None]:
# Full run over every backdoor pair.
experiment = MergedClassesMultiSurgery(backdoor_name_pairs=backdoor_pairs)
results = experiment.run()

In [None]:
# Report mean benign accuracy and per-pair attack success rate for the full run.
experiment.print_results(results)