In [None]:
from typing import List

from src.attacks.shattered_class.surgery import SurgeryShatteredClass
from src.attacks.composite import CompositeSurgery
from src.attacks.surgery import Surgery
from src.attacks.verification_backdoor import PinterestBackdoorData
from src.core.performance_test import AttackExperiment

In [None]:
class ShatteredClassMultiExperiment(AttackExperiment):
    """Attack experiment that applies several shattered-class surgeries at once.

    One ``SurgeryShatteredClass`` is built per backdoor name, each backed by a
    ``PinterestBackdoorData``, and all of them are combined into a single
    ``CompositeSurgery``.
    """

    def __init__(self, backdoor_names: List[str], *args, **kwargs):
        """Create the experiment.

        Args:
            backdoor_names: Names of the backdoors to implant; each one is passed
                to ``PinterestBackdoorData``. Any iterable of names is accepted;
                it is copied into a list so it can be iterated more than once.
            *args, **kwargs: Forwarded unchanged to ``AttackExperiment``.
        """
        # Copy so that a generator is not exhausted by ``_setup_attack`` (which
        # would leave ``print_results`` with nothing to report) and so later
        # mutation of the caller's list does not affect this experiment.
        # Assigned before the base constructor runs in case it already calls
        # ``_setup_attack`` — NOTE(review): not visible from here, defensive only.
        self.backdoor_names = list(backdoor_names)
        super().__init__(*args, **kwargs)

    def _setup_attack(self) -> Surgery:
        """Build one shattered-class surgery per backdoor and compose them."""
        attacks = [
            SurgeryShatteredClass(backdoor_data=PinterestBackdoorData(name, self._pfr))
            for name in self.backdoor_names
        ]
        return CompositeSurgery(attacks)

    def print_results(self, results):
        """Print mean benign accuracy and per-backdoor attack success rates.

        Args:
            results: Mapping from stage name to a sequence of per-test-set
                statistics. Entry 0 must have a ``"test/0/Accuracy"`` column
                (reported as benign accuracy); entry ``i`` (``i >= 1``) must have
                a ``"test/{i}/Accuracy"`` column, reported as the attack success
                rate of ``self.backdoor_names[i - 1]``. Each column must support
                ``.mean()``.

        Raises:
            ValueError: If a stage has fewer backdoor entries than there are
                backdoor names.
        """
        print(f'Results for {self}:')
        for stage, stats in results.items():
            print(f'\t{stage.capitalize()}:')
            n_backdoor_sets = len(stats) - 1
            # Previously a zip() here silently dropped any backdoor whose results
            # were missing; fail loudly instead so the report is never partial.
            if n_backdoor_sets < len(self.backdoor_names):
                raise ValueError(
                    f'Stage {stage!r} has results for {n_backdoor_sets} backdoor '
                    f'test set(s), expected {len(self.backdoor_names)}'
                )
            benign_acc = stats[0]["test/0/Accuracy"].mean()
            print(f'\t\tMean benign accuracy: {self._as_percentage(benign_acc)}')
            # Entry/test index i (1-based) belongs to the i-th backdoor name.
            for idx, backdoor_name in enumerate(self.backdoor_names, start=1):
                asr = stats[idx][f"test/{idx}/Accuracy"].mean()
                print(f'\t\tMean attack success rate for {backdoor_name}: {self._as_percentage(asr)}')

# Testing on Dev View

In [None]:
# Build the experiment through its `sanity` constructor (presumably a reduced
# configuration — confirm in AttackExperiment) for the first two candidate
# backdoors, then run it.
experiment = ShatteredClassMultiExperiment.sanity(
    PinterestBackdoorData.CANDIDATES[:2],
)
results = experiment.run()

In [None]:
# Report the dev-view run: per stage, mean benign accuracy and the mean attack
# success rate of each backdoor.
experiment.print_results(results)

# Testing on Test View

In [None]:
# Build the experiment with the regular constructor for the first ten
# candidate backdoors, then run it.
experiment = ShatteredClassMultiExperiment(
    PinterestBackdoorData.CANDIDATES[:10],
)
results = experiment.run()

In [None]:
# Report the test-view run: per stage, mean benign accuracy and the mean attack
# success rate of each backdoor.
experiment.print_results(results)