In [None]:
from src.utils import get_mapping_dict
from src.dataset import ImageNetteDataset
from src.models import get_model
from torch.utils.data import DataLoader
import torch
import numpy as np

In [None]:
# Transfer-attack evaluation: for every (attack, source model, target model)
# triple, measure the target model's top-1 accuracy on the adversarial examples
# stored under ./adv_example/{attack}/{source}/.
# Results go to correctness_matrix[attack_idx, source_idx, target_idx].
device = "cuda" if torch.cuda.is_available() else "cpu"
model_list = ["resnet18","resnet50","densenet121","wide_resnet50_2"]
attack_mode_list = ["fgsm_un", "pgd_un", "mifgsm_un", "deepfool"]
# When True, skip the diagonal (source model == target model) entries.
# Skipped entries keep their initial value of 1.0 in correctness_matrix.
skip_self_transfer = False
mapping_folder_to_name, mapping_folder_to_label, mapping_name_to_label, mapping_label_to_name = get_mapping_dict()


def _load_model(name, device):
    """Build architecture `name`, load ./models/{name}.pth into it, and return it on `device` in eval mode."""
    model = get_model(name)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) load on another.
    model.load_state_dict(torch.load(f"./models/{name}.pth", map_location=device))
    model.to(device)
    model.eval()
    return model


def _evaluate_accuracy(model, dataloader, device):
    """Return the top-1 accuracy of `model` over `dataloader`.

    Returns NaN (instead of dividing by zero) when the loader yields no samples.
    """
    correct_num = 0
    data_num = 0
    with torch.no_grad():
        for idx, (data, labels) in enumerate(dataloader):
            print(f"val process {idx + 1} / {len(dataloader)}             ", end = "\r")
            data = data.to(device)
            labels = labels.to(device)

            output = model(data)
            _, preds = torch.max(output, 1)
            correct_num += (labels == preds).sum().item()
            data_num += len(data)
    if data_num == 0:
        return float("nan")
    return correct_num / data_num


# Load each target model once and reuse it for every (attack, source) pair.
# Previously each checkpoint was reloaded len(attack_mode_list) * len(model_list) times.
target_models = {name: _load_model(name, device) for name in model_list}

correctness_matrix = np.ones((len(attack_mode_list), len(model_list), len(model_list)))
for i, attack_mode in enumerate(attack_mode_list):
    for j, model_from in enumerate(model_list):
        # The adversarial set depends only on (attack, source model), so build it
        # once here and reuse it for every target model.
        # NOTE(review): train=True / simple_transform=True are passed through to
        # ImageNetteDataset unchanged; their exact meaning is defined there — confirm.
        adv_root = f"./adv_example/{attack_mode}/{model_from}/"
        adv_dataset = ImageNetteDataset(data_root=adv_root, mapping_folder_to_label=mapping_folder_to_label, train=True, simple_transform=True)
        adv_dataloader = DataLoader(adv_dataset, batch_size=50, shuffle=False)

        for k, model_to in enumerate(model_list):
            if skip_self_transfer and model_from == model_to:
                continue
            print(f"From {model_from} attack {model_to} by {attack_mode}")
            model = target_models[model_to]
            correctness = _evaluate_accuracy(model, adv_dataloader, device)
            if np.isnan(correctness):
                print(f"WARNING: no adversarial examples found under {adv_root}")
            correctness_matrix[i, j, k] = correctness
            print(f"Correctness = {correctness:.4f}")
