In [None]:
import math
from itertools import islice

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from model import CNNCrown
from verifier import PGDVerifier

In [None]:
# loading models
models_path = "/kaggle/input/cnnrobust/pytorch/default/1"
# Fall back to CPU so the notebook still runs on a machine without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model(filename):
    """Build a CNNCrown, load the state dict at `models_path/filename`, and move it to DEVICE.

    Args:
        filename: checkpoint file name inside `models_path`.

    Returns:
        The CNNCrown instance with the loaded weights, placed on DEVICE.

    Notes:
        - The checkpoint is expected to be a plain state dict, so it is loaded with
          `weights_only=True` to avoid unpickling arbitrary objects.
        - `map_location` only decides where the loaded tensors land; `load_state_dict`
          copies them into the module's own (CPU-constructed) parameters, so the module
          has to be moved to DEVICE explicitly.
    """
    model = CNNCrown()
    state_dict = torch.load(f"{models_path}/{filename}", map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(DEVICE)


no_aug_normal_model = load_model("no_normal_model.pt")
no_aug_contrastive_model = load_model("no_contrastive_model.pt")
aug_normal_model = load_model("normal_model.pt")
aug_contrastive_model = load_model("contrastive_model.pt")

# `models` and `models_name` are index-aligned: models_name[i] labels models[i].
models = [
    no_aug_normal_model,
    no_aug_contrastive_model,
    aug_normal_model,
    aug_contrastive_model
]

models_name = [
    "Normal Model - No Augmentation",
    "Contrastive Model - No Augmentation",
    "Normal Model - Augmentation",
    "Contrastive Model - Augmentation"
]

In [None]:
# loading dataset

SEED = 42
torch.manual_seed(SEED)  # global seed; also drives the train loader's shuffling
BATCH_SIZE = 64
NUM_WORKERS = 8
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps each channel from ToTensor's [0, 1] range to [-1, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
# Take the remainder so the two sizes always add up to dataset_size exactly.
validation_size = dataset_size - train_size

# Dedicated, explicitly seeded generator: the split no longer depends on how much
# of the global RNG happened to be consumed before this line.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, validation_dataset = random_split(
    dataset, [train_size, validation_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
# Evaluation-only loader: keep a fixed order so results are reproducible batch-for-batch.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
pgd = PGDVerifier(DEVICE)
epsilon_list = [1/255, 2/255, 4/255, 8/255, 16/255]
epsilon_labels = ["1/255", "2/255", "4/255", "8/255", "16/255"]

MAX_SAMPLES = 1024                                    # attack budget per split
MAX_ITERATIONS = math.ceil(MAX_SAMPLES / BATCH_SIZE)  # number of batches attacked per split
# The test split behaved like the validation split in earlier runs; set to True to report it too.
EVALUATE_TEST = False

# Draw the evaluation batches ONCE and reuse them for every model and epsilon.
# Re-iterating the shuffled train loader would give each run a different random
# subset, so the ASRs would not be directly comparable.
eval_splits = {
    "Train": list(islice(train_loader, MAX_ITERATIONS)),
    "Validation": list(islice(validation_loader, MAX_ITERATIONS)),
}
if EVALUATE_TEST:
    eval_splits["Test"] = list(islice(test_loader, MAX_ITERATIONS))


def attack_success_rate(verifier, model, batches, epsilon):
    """Return the attack success rate of `verifier` against `model` on `batches`, in percent.

    Args:
        verifier: object exposing `verify(model, images, labels, epsilon=..., clamp_min=..., clamp_max=...)`
            that returns `(adversarial_examples, successes)`.
        model: classifier under attack.
        batches: iterable of `(images, labels)` pairs.
        epsilon: perturbation budget forwarded to `verifier.verify`.

    Returns:
        100 * successful attacks / images attacked, or NaN when `batches` holds no images.
    """
    successful, total = 0, 0
    for images, labels in batches:
        # clamp_* keeps adversarial inputs inside the normalised input range [-1, 1].
        # `successes` is treated as per-sample success flags (summed and counted along dim 0).
        _, successes = verifier.verify(model, images, labels, epsilon=epsilon, clamp_min=-1, clamp_max=1)
        successful += successes.sum().item()
        total += successes.shape[0]
    return 100 * successful / total if total else float("nan")


print("> ASR = Attack Success Rate\n")

# NOTE: deliberately NOT wrapped in torch.no_grad(). PGD needs gradients of the loss
# w.r.t. the inputs; assuming PGDVerifier computes them with autograd, no_grad would
# disable them (or be silently overridden inside the verifier) -- TODO confirm in verifier.py.
for epsilon, eps_label in zip(epsilon_list, epsilon_labels):
    print(f"> Eps: {eps_label}")
    for model, model_name in zip(models, models_name):
        model.eval()
        print(f"\t> {model_name}")
        for split_name, batches in eval_splits.items():
            asr = attack_success_rate(pgd, model, batches, epsilon)
            print(f"\t\t- {split_name} ASR: {asr:.2f}%;")