In [2]:
# Reload edited project modules (e.g. model.py, verifier.py imported below) before each
# cell runs, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import csv

import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

from model import CNNCrown, CNN, Encoder, LinearClassifier
from verifier import ABCrown

Adding complete_verifier to sys.path


In [7]:
# Fuse the separately trained contrastive encoder and linear classifier into one CNN and
# save its state dict; the configuration cell below points `model_weigths_file` at it.
# NOTE(review): torch.load unpickles the file — only load weights from trusted sources
# (consider weights_only=True on PyTorch versions that support it).
encoder = Encoder(in_channels=3, proj_dim=128)
classifier = LinearClassifier(in_dim=128, num_classes=10)

encoder_state_dict = torch.load("./model_weights/no_augmentation/encoder_weights.pt", map_location="cpu")
classifier_state_dict = torch.load("./model_weights/no_augmentation/classifier_weights.pt", map_location="cpu")
encoder.load_state_dict(encoder_state_dict)
classifier.load_state_dict(classifier_state_dict)

cnn = CNN.import_from(encoder, classifier)
fused_weights_path = "./model_weights/no_augmentation/contrastive_model.pt"
torch.save(cnn.state_dict(), fused_weights_path)

In [None]:
# Experiment configuration
model_weigths_file = "./model_weights/no_augmentation/contrastive_model.pt"  # fused weights saved above
log_file_name = "custom_model.csv"  # one row per (image, eps) verification query
n_images = 1  # how many correctly classified test images to verify
# perturbation radii 1/255, 2/255, 4/255, 8/255, 16/255 (doubling at each step)
epsilon_list = torch.tensor([2 ** k / 255 for k in range(5)])
device = 'cpu'

In [None]:
# loading the model
# Rebuild the network as CNNCrown (the model passed to the verifier later) and fill it
# with the fused contrastive weights.
n_classes = 10
image_dimension = (3, 32, 32)  # (channels, height, width) of a CIFAR-10 image
contrastive_state_dict = torch.load(model_weigths_file, map_location=device)
net = CNNCrown(image_dimension[0], n_classes)
net.load_state_dict(contrastive_state_dict)
net = net.eval()  # inference mode; eval() returns the module itself

In [None]:
# getting the images
# Seeds torch's global RNG. Nothing in this cell is random, so this presumably targets
# stochastic steps in later cells (e.g. attacks run by the verifier) — TODO confirm.
torch.manual_seed(42)

# NOTE(review): no Normalize transform is applied — images stay as [0, 1] tensors, so the
# epsilons in `epsilon_list` are on the [0, 1] pixel scale and the images can be shown
# directly with imshow. Confirm the encoder/classifier were trained on un-normalized inputs.
images_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=ToTensor())
selected_images = []

# select the first `n_images` test images that are correctly classified
with torch.no_grad():  # inference only: no autograd graph is needed for the predictions
    for i in range(len(images_dataset)):
        # check before fetching so n_images == 0 selects nothing
        if len(selected_images) >= n_images:
            break
        image, label = images_dataset[i]
        logits = net(image.unsqueeze(0))
        prediction = torch.argmax(logits, dim=1)[0].item()
        if prediction == label:
            selected_images.append((image, label))

In [None]:
# setting the verifier
# ABCrown is the project's verifier wrapper (imported from verifier.py); one instance is
# created here and reused for every (image, eps) query in the next cell.
# NOTE(review): presumably wraps the alpha-beta-CROWN complete verifier (cf. the
# "Adding complete_verifier to sys.path" import message) — confirm.
verifier = ABCrown(device=device)

In [None]:
# Verify every selected image at every epsilon, logging one CSV row per (image, eps) and
# collecting one adversarial example (or a placeholder) per pair.


def _first_attack_example(result, fallback):
    """Return the first attack example recorded in a verifier result dict.

    Args:
        result: dict returned by ``verifier.verify(...).as_dict()``.
        fallback: tensor returned when the result carries no attack example
            (missing ``stats``/``attack_examples`` entry, ``None``, or an empty first dimension).
    """
    attack_examples = result.get("stats", {}).get("attack_examples")
    if attack_examples is None or attack_examples.shape[0] == 0:
        return fallback
    return attack_examples[0]


if not selected_images:
    raise RuntimeError("No correctly classified images were selected; nothing to verify.")

# when an error occurs or no adversarial images are found, a Tensor with -1s is saved;
# like every stored example it is kept on the CPU so torch.stack never mixes devices
error_adversarial_example = -torch.ones_like(selected_images[0][0]).cpu()

adversarial_example_stacks = []  # one stacked tensor (one entry per epsilon) per image

# opening the csv
with open(log_file_name, mode="w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=['image_id', 'eps', 'status', 'success'])
    writer.writeheader()
    # iterating through the images
    for i, (image, label) in enumerate(selected_images):
        adversarial_example_per_image = []
        # iterating through the epsilons
        for eps in epsilon_list:
            eps_value = eps.item()
            try:
                result = verifier.verify(net, image, n_classes, label, eps_value).as_dict()
            except Exception as error:
                # deliberate best-effort: log the failure, record a CRASH row and move on
                print(f"> Error at image {i} and eps={eps}: {error}")
                status, success = "CRASH", False
                adversarial_example = error_adversarial_example
            else:
                status, success = result['status'], result['success']
                # NOTE(review): assumes attack examples have the same shape as `image`;
                # torch.stack below fails otherwise — TODO confirm against ABCrown
                adversarial_example = _first_attack_example(result, error_adversarial_example)

            writer.writerow({"image_id": i, "eps": eps_value, "status": status, "success": success})
            file.flush()  # rows written so far stay on disk if the kernel dies mid-run
            adversarial_example_per_image.append(adversarial_example.detach().cpu())

        adversarial_example_stacks.append(torch.stack(adversarial_example_per_image))

# saving the obtained adversarial examples, indexed as [image, eps]
adversarial_examples = torch.stack(adversarial_example_stacks)
torch.save(adversarial_examples, "adversarial_examples.pt")

In [None]:
# Plotting the adversarial examples
# Show the first selected image next to the example stored for each epsilon. Pairs where
# verification crashed or returned no attack example hold the all -1 placeholder tensor;
# they are drawn as an empty, labelled panel instead of a meaningless clipped image.

# Extract the selected image and its label
selected_image, label = selected_images[0]

# Extract the adversarial examples for the selected image (one entry per epsilon)
adversarial_examples_for_image = adversarial_examples[0]
n_panels = len(adversarial_examples_for_image) + 1

# squeeze=False keeps `axes` indexable even when there is only a single panel
fig, axes = plt.subplots(1, n_panels, figsize=(15, 5), squeeze=False)
axes = axes[0]

# Plot the original image
axes[0].imshow(selected_image.permute(1, 2, 0).numpy())
axes[0].set_title(f"Original Image (Label: {label})")
axes[0].axis('off')

# Plot the adversarial examples, labelled with the epsilon they were searched at
for ax, eps, adv_example in zip(axes[1:], epsilon_list, adversarial_examples_for_image):
    adv_example = adv_example.detach().cpu()
    eps_label = f"eps={eps.item() * 255:.0f}/255"
    if torch.all(adv_example == -1):
        ax.set_title(f"{eps_label}\n(no adversarial example)")
    else:
        # clamp to imshow's valid float range [0, 1]
        ax.imshow(adv_example.permute(1, 2, 0).clamp(0, 1).numpy())
        ax.set_title(f"Adversarial Example\n{eps_label}")
    ax.axis('off')

fig.tight_layout()
plt.show()