In [None]:
import torch
from model import CNNCrown
from verifier import ABCrown
from verifier import PGDVerifier
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split


In [None]:
# loading all the models
#
# Builds one CNNCrown(pooling=False) per checkpoint and exposes:
#   DEVICE         - device string used by every later cell
#   models_weights - the loaded checkpoints (4 "augmentation" + 4 "no_augmentation")
#   models         - the corresponding modules, in the same order
#   models_name    - display names; the same four names repeat for both variants

# Fall back to the CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

base_path = "models_info/model_weights/no_pooling_model"
augmentation_path = f"{base_path}/augmentation"
no_augmentation_path = f"{base_path}/no_augmentation"

# Checkpoint file names, listed in the same order as the entries of `models_name`.
_weight_files = [
    "normal_model.pt",
    "contrastive_model.pt",
    "adversarial_model.pt",
    "adversarial_contrastive_model.pt",
]

# NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (or consider `weights_only=True`
# if the installed torch version supports it - TODO confirm).
models_weights = [
    torch.load(f"{variant_path}/{file_name}", map_location=DEVICE)
    for variant_path in (augmentation_path, no_augmentation_path)
    for file_name in _weight_files
]

models = []

for weights in models_weights:
    model = CNNCrown(pooling=False)
    model.load_state_dict(weights)
    # load_state_dict copies values into the module's existing parameters; it does
    # not relocate the module. Move it explicitly so it matches the inputs that
    # later cells send to DEVICE.
    model.to(DEVICE)
    models.append(model)

# First four entries: augmentation checkpoints; last four: no_augmentation checkpoints.
models_name = ["Normal Model", "Contrastive Model", "Adversarial Model", "Adversarial Contrastive"] * 2

In [None]:
# loading dataset
#
# CIFAR-10 train split is partitioned 80/20 into train/validation subsets;
# the official test split gets its own loader.

torch.manual_seed(42)  # seeds the global RNG used by random_split below
BATCH_SIZE = 2048

# ToTensor followed by a per-channel (x - 0.5) / 0.5 normalisation.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
# The validation subset takes whatever is left after the train subset.
validation_size = dataset_size - train_size

_split_lengths = [train_size, validation_size]
train_dataset, validation_dataset = random_split(dataset, _split_lengths)

# Options shared by every loader; only `shuffle` differs between them.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=8)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
validation_loader = DataLoader(validation_dataset, shuffle=False, **_loader_options)

test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=True, **_loader_options)

In [None]:
# Perturbation budgets, expressed as numerators over 255, together with the
# matching human-readable labels used when printing results.
# NOTE(review): the dataset transform normalises images with mean 0.5 / std 0.5;
# confirm whether the verifiers expect epsilon on that scale or in [0, 1] pixel units.
_epsilon_numerators = (1, 2, 4, 8, 16)
epsilon_list = [numerator / 255 for numerator in _epsilon_numerators]
epsilon_labels = [f"{numerator}/255" for numerator in _epsilon_numerators]

In [None]:
# PGD attack only
#
# For every epsilon and every model, runs the PGD verifier over the test set and
# prints the attack success rate (ASR): successful attacks divided by the number
# of images that were not already misclassified before the attack.
pgd = PGDVerifier(DEVICE)

print("> ASR = Attack Success Rate\n")

# NOTE(review): PGD normally needs input gradients, yet the loop runs under
# torch.no_grad(); presumably PGDVerifier.verify re-enables gradients internally -
# confirm against the verifier implementation.
with torch.no_grad():
    for epsilon, eps_label in zip(epsilon_list, epsilon_labels):
        print(f"> Eps: {eps_label}")
        for model, model_name in zip(models, models_name):
            # Inputs are moved to DEVICE below, so the model has to live there too.
            model.to(DEVICE)
            model.eval()
            print(f"\t> {model_name}")

            test_adv_attack_success = 0
            total_images = 0
            # The metric is reported as "Test ASR", so it is measured on the
            # test loader (the previous code iterated over the training loader).
            for images, labels in test_loader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                # assumes `successes` and `initial_wrong_predictions` are per-image
                # boolean/0-1 tensors - TODO confirm against PGDVerifier.verify
                adversarial_examples, successes, initial_wrong_predictions = pgd.verify(
                    model,
                    images,
                    labels,
                    epsilon=epsilon,
                    clamp_min=-1,  # clamp_* -> ensure valid input within [-1,1]
                    clamp_max=1
                )
                test_adv_attack_success += successes.sum().item()
                # Images that were already misclassified are excluded from the denominator.
                total_images += successes.shape[0] - initial_wrong_predictions.sum().item()

            if total_images == 0:
                # Nothing was classified correctly, so the rate is undefined.
                print("\t\t- Test ASR: n/a (no correctly classified images);")
                continue

            test_adv_attack_success /= total_images
            test_adv_attack_success *= 100
            print(f"\t\t- Test ASR: {test_adv_attack_success}%;")

In [None]:
# alpha-beta crown analysis
#
# For each of the first four models, walks the test set in order and asks the
# verifier about every correctly classified image until MAX_POINTS statuses
# have been collected for that model.

verifier = ABCrown(DEVICE)
MAX_POINTS = 10


def _predicts_label(net, image, label):
    """Return True when the arg-max of `net`'s output for one image equals `label`."""
    with torch.no_grad():
        scores = net(image.unsqueeze(0))[0]
    return bool(torch.argmax(scores) == label)


# One list of verifier statuses per analysed model.
outputs = [[] for _ in models[:4]]

for model_id, statuses in enumerate(outputs):
    model = models[model_id].to(DEVICE)
    model.eval()
    for sample_idx in range(len(test_dataset)):
        image, label = test_dataset[sample_idx]
        image = image.to(DEVICE)

        if _predicts_label(model, image, label):
            # Only the first entry of epsilon_list is verified here.
            # The literal 10 is presumably the number of classes - confirm
            # against the ABCrown.verify signature.
            result = verifier.verify(model, image, 10, label, epsilon_list[0])
            statuses.append(result.status)
            print("QUA", models_name[model_id], len(statuses), result.status)

        # Stop as soon as enough verification results were gathered for this model.
        if len(statuses) >= MAX_POINTS:
            break
            

In [None]:
import matplotlib.pyplot as plt

# Draw one histogram of verifier statuses per analysed model.

# Define all possible labels
all_labels = ['safe', 'verified', 'unsafe-pgd', 'unsafe-bab', 'safe-incomplete', 'unknown']

# The verifier may report a status that is not listed above (an empty string, for
# instance); append such statuses so `all_labels.index` cannot raise ValueError.
for output in outputs[:4]:
    for label in output:
        if label not in all_labels:
            all_labels.append(label)

label_indices = range(len(all_labels))
# An empty status would render as an invisible tick, so give it a visible name.
tick_labels = [label if label != "" else "(empty)" for label in all_labels]

for model_id, output in enumerate(outputs[:4]):
    # Map outputs to indices based on all_labels
    output_indices = [all_labels.index(label) for label in output]

    fig, ax = plt.subplots(figsize=(8, 6))
    # One unit-wide bin per label, centred on the label's integer index.
    ax.hist(output_indices, bins=range(len(all_labels) + 1), align='left', rwidth=0.8)
    ax.set_title(f"Histogram of Labels for Model {models_name[model_id]}")
    ax.set_xlabel("Labels")
    ax.set_ylabel("Frequency")
    ax.set_xticks(list(label_indices))
    ax.set_xticklabels(tick_labels, rotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [None]:
from collections import Counter

# Tally how often each verifier status occurred, keyed by model name
# (first four models only, matching the verification cell).
status_frequency = {
    model_name: dict(Counter(output))
    for model_name, output in zip(models_name[:4], outputs[:4])
}

print(status_frequency)