In [4]:
from src.utils import get_mapping_dict
from src.dataset import ImageNetteDataset
from src.models import get_model
from torch.utils.data import DataLoader
import torch
import numpy as np

In [5]:
# Transfer-attack evaluation: for every attack and every source model, measure
# the accuracy of every target model on the adversarial examples that were
# crafted against the source model (diagonal = white-box setting).
# correctness_matrix[i, j, k] = accuracy of model_list[k] on adversarial
# examples generated by attack_mode_list[i] against model_list[j].
device = "cuda" if torch.cuda.is_available() else "cpu"
model_list = ["resnet18", "resnet50", "densenet121", "wide_resnet50_2"]
attack_mode_list = ["fgsm_un", "pgd_un", "mifgsm_un", "deepfool"]
mapping_folder_to_name, mapping_folder_to_label, mapping_name_to_label, mapping_label_to_name = get_mapping_dict()


def load_eval_model(name, device):
    """Build `name` with get_model, load ./models/{name}.pth, move it to `device`, set eval mode.

    Args:
        name: model identifier understood by get_model (one of model_list).
        device: torch device string, e.g. "cuda" or "cpu".

    Returns:
        The model in eval mode on `device`.
    """
    model = get_model(name)
    # map_location lets checkpoints saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(f"./models/{name}.pth", map_location=device))
    model.to(device)
    model.eval()
    return model


def evaluate_accuracy(model, dataloader, device):
    """Return the fraction of samples whose top-1 prediction equals the label.

    Args:
        model: classifier returning per-class scores along dim 1.
        dataloader: iterable of (data, labels) batches that supports len().
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in [0, 1], or NaN if the dataloader yields no samples
        (instead of raising ZeroDivisionError).
    """
    correct_num = 0
    data_num = 0
    with torch.no_grad():
        for idx, (data, labels) in enumerate(dataloader):
            print(f"val process {idx + 1} / {len(dataloader)}             ", end="\r")
            data = data.to(device)
            labels = labels.to(device)
            output = model(data)
            _, preds = torch.max(output, 1)
            correct_num += (labels == preds).sum().item()
            data_num += len(data)
    return correct_num / data_num if data_num else float("nan")


# Load each target model exactly once (4 loads instead of one per
# attack/source/target triple). NOTE: all models stay resident on `device`.
eval_models = {name: load_eval_model(name, device) for name in model_list}

# NaN (not 1.0) marks entries that were never evaluated, so an interrupted run
# cannot be mistaken for 100% accuracy.
correctness_matrix = np.full((len(attack_mode_list), len(model_list), len(model_list)), np.nan)
for i, attack_mode in enumerate(attack_mode_list):
    for j, model_from in enumerate(model_list):
        # The adversarial set depends only on (attack, source model), so build it
        # once and reuse it for every target model.
        adv_dataset = ImageNetteDataset(
            data_root=f"./adv_example/{attack_mode}/{model_from}/",
            mapping_folder_to_label=mapping_folder_to_label,
            train=True,
            simple_transform=True,
        )
        adv_dataloader = DataLoader(adv_dataset, batch_size=50, shuffle=False)
        for k, model_to in enumerate(model_list):
            print(f"From {model_from} attack {model_to} by {attack_mode}")
            model = eval_models[model_to]
            correctness = evaluate_accuracy(model, adv_dataloader, device)
            correctness_matrix[i, j, k] = correctness
            print(f"Correctness = {correctness:.4f}")


From resnet18 attack resnet18 by fgsm_un
Correctness = 0.1739            
From resnet18 attack resnet50 by fgsm_un
Correctness = 0.9095            
From resnet18 attack densenet121 by fgsm_un
Correctness = 0.9098            
From resnet18 attack wide_resnet50_2 by fgsm_un
Correctness = 0.9188            
From resnet50 attack resnet18 by fgsm_un
Correctness = 0.9036            
From resnet50 attack resnet50 by fgsm_un
Correctness = 0.2346            
From resnet50 attack densenet121 by fgsm_un
Correctness = 0.9063            
From resnet50 attack wide_resnet50_2 by fgsm_un
Correctness = 0.9075            
From densenet121 attack resnet18 by fgsm_un
Correctness = 0.8952            
From densenet121 attack resnet50 by fgsm_un
Correctness = 0.9042            
From densenet121 attack densenet121 by fgsm_un
Correctness = 0.2053            
From densenet121 attack wide_resnet50_2 by fgsm_un
Correctness = 0.9086            
From wide_resnet50_2 attack resnet18 by fgsm_un
Correctness = 0.9131  

In [4]:
# Display the lookup table from integer class label to class name.
mapping_label_to_name

{0: 'tench',
 1: 'English_springer',
 2: 'cassette_player',
 3: 'chain_saw',
 4: 'church',
 5: 'French_horn',
 6: 'garbage_truck',
 7: 'gas_pump',
 8: 'golf_ball',
 9: 'parachute'}