In [None]:
!git clone https://github.com/Verified-Intelligence/auto_LiRPA
!pip install ./auto_LiRPA

## Configuration

Import the required modules, set the global configuration (device, batch size, random seed), and load the CIFAR-10 dataset.

In [None]:
import sys
import os
import subprocess
import tempfile
from collections import Counter
import matplotlib.pyplot as plt
import multiprocessing
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from auto_LiRPA import BoundedModule
from model import CNNCrown, Encoder, LinearClassifier
from verifier import PGDVerifier

In [None]:
# Use the 'spawn' start method for child processes (likely needed because CUDA
# cannot be safely used in forked children — confirm). force=True keeps this
# cell re-runnable: without it a second call raises RuntimeError.
multiprocessing.set_start_method('spawn', force=True)

In [None]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2048
SEED = 42

# Trailing ';' suppresses the Generator repr that manual_seed returns.
torch.manual_seed(SEED);

In [None]:
# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 maps
# them to [-1, 1] per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

# Explicit generator: the split is reproducible even if this cell is re-run
# without re-seeding the global RNG first.
split_generator = torch.Generator().manual_seed(42)
train_dataset, validation_dataset = random_split(
    dataset, [train_size, validation_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
# shuffle=False: evaluation needs no shuffling, and a fixed order means every
# model sees the test images in the same sequence, so any "first N points
# found" selection is reproducible and comparable across models.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

Next, load all trained models onto `DEVICE`:

In [None]:
base_path = "models_info/model_weights"
augmentation_path = f"{base_path}/augmentation"
no_augmentation_path = f"{base_path}/no_augmentation"

# Example input used to build auto_LiRPA BoundedModules (a batch of two
# 3x32x32 images); presumably only its shape matters for tracing — confirm.
TRACE_INPUT_SHAPE = (2, 3, 32, 32)

PLAIN_MODEL_FILES = (
    "normal_model.pt",
    "contrastive_model.pt",
    "adversarial_model.pt",
    "adversarial_contrastive_model.pt",
)


def load_plain_models(path):
    """Load the four non-certified CNNCrown models stored under `path`.

    Returns a list ordered as: normal, contrastive, adversarial,
    adversarial contrastive. Every model is placed on DEVICE.
    """
    loaded = []
    for file_name in PLAIN_MODEL_FILES:
        model = CNNCrown(pooling=False)
        model.load_state_dict(torch.load(f"{path}/{file_name}", map_location=DEVICE))
        model.to(DEVICE)
        loaded.append(model)
    return loaded


def load_certified_model(path):
    """Load the certified CNNCrown (wrapped in a BoundedModule) from `path`, on DEVICE."""
    model = BoundedModule(CNNCrown(), torch.empty(*TRACE_INPUT_SHAPE))
    model.load_state_dict(torch.load(f"{path}/certified_model.pt", map_location=DEVICE))
    model.to(DEVICE)
    return model


def load_certified_contrastive_model(path):
    """Assemble the certified contrastive model from `path`, on DEVICE.

    The encoder (a BoundedModule) and the linear classifier are saved
    separately and plugged into a fresh CNNCrown.
    """
    encoder = BoundedModule(Encoder(), torch.empty(*TRACE_INPUT_SHAPE))
    encoder.load_state_dict(
        torch.load(f"{path}/certified_contrastive_encoder.pt", map_location=DEVICE)
    )
    classifier = LinearClassifier()
    classifier.load_state_dict(
        torch.load(f"{path}/certified_contrastive_classifier.pt", map_location=DEVICE)
    )
    model = CNNCrown()
    model.encoder = encoder
    model.classifier = classifier
    # Previously the augmentation variant was never moved to DEVICE, which
    # caused a device mismatch with inputs on DEVICE.
    model.to(DEVICE)
    return model


# Order: [augmentation models (6)] + [no-augmentation models (6)], each group
# being normal, contrastive, adversarial, adversarial contrastive,
# certified, certified contrastive — matching `models_name` below.
models = []
for group_path in (augmentation_path, no_augmentation_path):
    models.extend(load_plain_models(group_path))
    models.append(load_certified_model(group_path))
    models.append(load_certified_contrastive_model(group_path))

# Kept for backward compatibility: these names previously referred to the
# no-augmentation certified models (the last ones loaded).
certified_model, certified_contrastive_model = models[-2], models[-1]

models_name = ["Normal Model", "Contrastive Model", "Adversarial Model", "Adversarial Contrastive", "Certified", "Certified Contrastive"] * 2
assert len(models) == len(models_name), "models and models_name are out of sync"

In [None]:
# Perturbation budgets, expressed as k/255; values and display labels are
# derived from the same numerators so they cannot drift apart.
# NOTE(review): inputs are normalized with mean 0.5 / std 0.5 (range [-1, 1]);
# if PGDVerifier applies ε in that space, ε corresponds to ε/2 in [0, 1] pixel
# units — confirm the intended convention.
EPSILON_NUMERATORS = (1, 2, 4, 8, 16)
epsilon_list = [k / 255 for k in EPSILON_NUMERATORS]
epsilon_labels = [f"{k}/255" for k in EPSILON_NUMERATORS]

# Maximum number of test points collected per model for formal verification.
MAX_POINTS = 20

# PGD Verifier

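This cell attacks every model with PGD at each ε value and reports the **Attack Success Rate (ASR)** on the test set. The ASR is normalized by the number of test images that the model classifies correctly before the attack.  
For ε = 4/255, it also collects, for each model, up to `MAX_POINTS` test images that are correctly classified and on which the attack fails (i.e., empirically robust points). These points are later checked with formal verification.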

In [None]:
# ε at which empirically robust test points are collected for αβ-CROWN.
COLLECT_EPSILON = 4/255


def evaluate_pgd(model, epsilon, collect_points):
    """Attack `model` with PGD on the whole test set.

    Args:
        model: the model to attack (switched to eval mode here).
        epsilon: perturbation budget passed to PGDVerifier.verify.
        collect_points: if True, also gather up to MAX_POINTS test images that
            are classified correctly before the attack and on which the attack
            fails.

    Returns:
        (asr, points): `asr` is the attack success rate in percent, normalized
        by the number of initially correctly classified images (NaN if there
        are none). `points` is an (images, labels) tuple of tensors on DEVICE
        when `collect_points` is True (possibly empty), else None.
    """
    model.eval()
    collected_images, collected_labels = [], []
    n_collected = 0
    n_success = 0
    n_initially_correct = 0

    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        _, successes, initial_wrong_predictions = pgd.verify(
            model,
            images,
            labels,
            epsilon=epsilon,
            alpha=epsilon / 2,  # step size tied to ε: more fair for comparison
            clamp_min=-1,       # keep inputs valid within the normalized [-1, 1] range
            clamp_max=1,
        )

        n_success += int(successes.sum().item())
        n_initially_correct += successes.shape[0] - int(initial_wrong_predictions.sum().item())

        if collect_points and n_collected < MAX_POINTS:
            # Correct before the attack AND the attack failed -> empirically robust.
            robust_mask = torch.logical_and(~successes, ~initial_wrong_predictions)
            remaining = MAX_POINTS - n_collected
            robust_images = images[robust_mask][:remaining]
            robust_labels = labels[robust_mask][:remaining]
            if robust_images.shape[0] > 0:
                collected_images.append(robust_images)
                collected_labels.append(robust_labels)
                n_collected += robust_images.shape[0]

    asr = 100 * n_success / n_initially_correct if n_initially_correct else float("nan")

    points = None
    if collect_points:
        if collected_images:
            points = (torch.cat(collected_images, dim=0), torch.cat(collected_labels, dim=0))
        else:
            # No robust point found: return empty tensors instead of crashing on torch.cat([]).
            points = (
                torch.empty(0, 3, 32, 32, device=DEVICE),
                torch.empty(0, dtype=torch.long, device=DEVICE),
            )
    return asr, points


all_images = []

pgd = PGDVerifier(DEVICE)
print("> ASR = Attack Success Rate\n")

# NOTE(review): PGD needs input gradients; this works only if PGDVerifier
# re-enables autograd internally (e.g. torch.enable_grad) — confirm.
with torch.no_grad():
    for epsilon, eps_label in zip(epsilon_list, epsilon_labels):
        print(f"> Eps: {eps_label}")
        collect = epsilon == COLLECT_EPSILON
        eps_list = []

        for model, model_name in zip(models, models_name):
            print(f"\t> {model_name}")
            asr, points = evaluate_pgd(model, epsilon, collect)
            if collect:
                eps_list.append(points)
            print(f"\t\t- Test ASR: {asr:.2f}%;")

        if collect:
            all_images.append(eps_list)

# $\alpha\beta$-CROWN Analysis

This cell runs **formal robustness verification with $\alpha\beta$-CROWN** on the images selected above, for each of the first four models (the non-certified models trained with augmentation), at ε = 2/255 (`epsilon_list[1]`).  
Each image is verified in an isolated subprocess, and the resulting robustness status is collected per model.

In [None]:
def verify_one_point_subprocess(model, image, label, epsilon, device="cuda", timeout=None):
    """Verify a single point by running ``verify_point.py`` in a separate Python process.

    The model and image are saved to a temporary directory, and the script is
    invoked with their paths plus the label, epsilon and device.

    Args:
        model: model to verify; the whole module object is saved with ``torch.save``.
        image: a single input image tensor (moved to CPU before saving).
        label: true class; anything ``int()`` accepts (e.g. a 0-d tensor).
        epsilon: perturbation budget forwarded via ``--epsilon``.
        device: device string forwarded via ``--device``.
        timeout: optional number of seconds to wait for the script; ``None``
            (the default) waits indefinitely.

    Returns:
        The script's stdout with surrounding whitespace stripped. If the script
        exits with a non-zero code, the tail of its stderr is echoed to this
        process's stderr, and whatever stdout it produced is still returned.
        If the timeout expires, ``"unknown"`` is returned.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        model_path = os.path.join(tmpdir, "model.pt")
        image_path = os.path.join(tmpdir, "image.pt")

        torch.save(model, model_path)
        torch.save(image.cpu(), image_path)

        cmd = [
            sys.executable,
            "verify_point.py",
            "--model_path", model_path,
            "--image_path", image_path,
            "--label", str(int(label)),
            "--epsilon", str(float(epsilon)),
            "--device", device,
        ]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=False,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            print(f"verify_point.py timed out after {timeout}s", file=sys.stderr)
            return "unknown"

        if result.returncode != 0:
            # Don't raise: one failing point should not abort the whole sweep,
            # but the failure must be visible instead of silently ignored.
            stderr_tail = "\n".join(result.stderr.strip().splitlines()[-5:])
            print(
                f"verify_point.py exited with code {result.returncode}:\n{stderr_tail}",
                file=sys.stderr,
            )

        return result.stdout.strip()

In [None]:
# NOTE(review): the candidate points were collected at ε = 4/255 in the PGD
# cell, but verification below runs at epsilon_list[1] (= 2/255) — confirm
# this is intended.
VERIFY_EPSILON = epsilon_list[1]

# Per-model (images, labels) collected in the PGD cell (one entry per model).
robust_points = all_images[0]

outputs = [[] for _ in models[:4]]

for model_id, model in enumerate(models[:4]):
    model.eval()
    images, labels = robust_points[model_id]

    # Fewer than MAX_POINTS may have been collected, so iterate over what exists
    # instead of indexing up to MAX_POINTS (which raised IndexError).
    for i, (image, label) in enumerate(zip(images[:MAX_POINTS], labels[:MAX_POINTS])):
        raw_output = verify_one_point_subprocess(
            model=model,
            image=image,
            label=label,
            epsilon=VERIFY_EPSILON,
            device=DEVICE,
        )
        # The status is taken from the last line of the script's output.
        status = raw_output.split("\n")[-1]

        outputs[model_id].append(status)
        print("OK", models_name[model_id], i, status)

In [None]:
# Create a dictionary to store the frequency of statuses for each model
status_frequency = {}

# Iterate over the outputs and models_name
for model_name, output in zip(models_name[:4], outputs[:4]):
    # Count the frequency of each status in the output
    status_frequency[model_name] = dict(Counter(output))

print(status_frequency)

In [None]:
# Define all possible labels
all_labels = ['safe', 'verified', 'unsafe-pgd', 'unsafe-bab', 'safe-incomplete', 'unknown']
label_indices = range(len(all_labels))

for model_id, output in enumerate(outputs[:4]):
    # Map outputs to indices based on all_labels
    output_indices = [all_labels.index(label) for label in output]
    
    plt.figure(figsize=(8, 6))
    plt.hist(output_indices, bins=range(len(all_labels) + 1), align='left', rwidth=0.8)
    plt.title(f"Histogram of Labels for Model {models_name[model_id]}")
    plt.xlabel("Labels")
    plt.ylabel("Frequency")
    plt.xticks(label_indices, all_labels, rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

For both `certified_model` and `certified_contrastive_model`, we ran the $\alpha\beta$-CROWN robustness verification from the command line, because it did not work with our script.