In [1]:
# Automatically reload edited project modules (e.g. `model`, `verifier`) so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import csv
import torch
from model import CNNCrown
from verifier import ABCrown
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms

Adding complete_verifier to sys.path


In [3]:
# Experiment configuration
model_weigths_file = "./custom_model.pt"          # CNNCrown state_dict loaded in the model cell
log_file_name = "custom_model.csv"                # one CSV row per (image, eps) verification run
n_images = 2                                      # number of correctly classified test images to verify
epsilon_list = torch.linspace(0.0001, 0.0005, 5)  # epsilons passed to the verifier: 1e-4, 2e-4, ..., 5e-4
device = 'cpu'
torch.manual_seed(0);  # trailing ';' suppresses the Generator repr in the cell output

<torch._C.Generator at 0x11020b350>

In [None]:
# Build the classifier, restore its trained weights and switch it to inference mode.
n_classes = 10
image_dimension = (3, 32, 32)  # (C, H, W); only the channel count is passed to the model
net = CNNCrown(
    image_dimension[0],
    n_classes,
)
# NOTE(review): torch.load unpickles the checkpoint; acceptable for a locally produced
# file, but consider weights_only=True if the checkpoint could come from elsewhere.
net.load_state_dict(
    torch.load(model_weigths_file, map_location=device)
)
net.eval();  # eval() switches the module in place; ';' suppresses the module repr

In [5]:
# getting the images
# NOTE(review): `transform` (ToTensor + Normalize to [-1, 1]) is defined but NOT used:
# the dataset below is loaded with plain ToTensor(), so pixels stay in [0, 1].
# Confirm which preprocessing CNNCrown was trained with and which input domain the
# verifier's epsilon is meant for before switching to `transform=transform`.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

images_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=ToTensor())
selected_images = []

# Select the first `n_images` test images that the model classifies correctly
# (robustness is verified around a correct prediction only).
with torch.no_grad():  # inference only: do not build autograd graphs
    for i in range(len(images_dataset)):
        # checked before classifying so that n_images <= 0 selects nothing
        if len(selected_images) >= n_images:
            break
        image, label = images_dataset[i]
        logits = net(image.unsqueeze(0))
        prediction = torch.argmax(logits, dim=1)[0].item()
        if prediction == label:
            selected_images.append((image, label))

# The verification cell indexes selected_images[0]; fail early with a clear message.
if not selected_images:
    raise RuntimeError("No correctly classified images were selected; nothing to verify.")
if len(selected_images) < n_images:
    print(f"Warning: only {len(selected_images)} of {n_images} requested images are correctly classified.")

In [6]:
# setting the verifier
# Instantiate the project-local ABCrown wrapper on the configured device; its
# `verify(...)` method is called once per (image, eps) in the sweep cell.
verifier = ABCrown(device=device)

In [None]:
# Placeholder stored when verification crashes or yields no adversarial example:
# a tensor shaped like one input image, filled with -1s.
error_adversarial_example = -torch.ones_like(selected_images[0][0]).to(device)


def _verify_one(image, label, eps, image_id):
    """Verify a single (image, eps) pair and return ``(csv_row, adversarial_example)``.

    Exactly one row and one example are returned per call, so every image gets one
    entry per epsilon and the final ``torch.stack`` stays well-formed. Any exception
    raised while verifying or reading the result is printed and recorded as a
    ``"CRASH"`` row (deliberate best-effort so one failure does not abort the sweep);
    the -1 placeholder is returned as the example in that case and whenever the
    attack produced no example.
    """
    try:
        result = verifier.verify(net, image, n_classes, label, eps).as_dict()
        row = {
            "image_id": image_id,
            "eps": eps,
            "status": result['status'],
            "success": result['success'],
        }
        attack_examples = result['stats']['attack_examples']
        if attack_examples is not None and attack_examples.shape[0] != 0:
            # NOTE(review): assumes attack_examples[0] has the same shape as `image` — confirm
            return row, attack_examples[0].to(device)
        return row, error_adversarial_example
    except Exception as error:
        print(f"> Error at image {image_id} and eps={eps}: {error}")
        row = {"image_id": image_id, "eps": eps, "status": "CRASH", "success": False}
        return row, error_adversarial_example


adversarial_examples = []
with open(log_file_name, mode="w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=['image_id', 'eps', 'status', 'success'])
    writer.writeheader()
    for i, (image, label) in enumerate(selected_images):
        adversarial_example_per_image = []
        for eps in epsilon_list:
            row, adversarial_example = _verify_one(image, label, eps.item(), i)
            writer.writerow(row)
            file.flush()  # keep the CSV current in case the kernel dies mid-sweep
            adversarial_example_per_image.append(adversarial_example)
        adversarial_examples.append(torch.stack(adversarial_example_per_image))

# Saved tensor is (n_selected_images, len(epsilon_list), *image_shape),
# assuming every attack example is image-shaped.
torch.save(torch.stack(adversarial_examples), "adversarial_examples.pt")

Specification DNF: (y[3] > y[0]) & (y[3] > y[1]) & (y[3] > y[2]) & (y[3] > y[4]) & (y[3] > y[5]) & (y[3] > y[6]) & (y[3] > y[7]) & (y[3] > y[8]) & (y[3] > y[9])
Attack parameters: initialization=uniform, steps=100, restarts=30, alpha=2.5004148483276367e-05, GAMA=False
PGD attack margin (first 2 examples and 10 specs):
 tensor([[115.87402344, 140.26126099, 137.95802307,  90.91664124,  70.87952423,
          79.45481873, 109.64601898, 112.30955505, 121.10418701]])
Total number of violation:  0
Processing batch 1/1...


100%|██████████| 100/100 [01:07<00:00,  1.49it/s]


PGD attack margin (first 2 examples and 10 specs):
 tensor([[114.26208496, 138.24836731, 136.04217529,  89.64808655,  69.82521820,
          78.13153076, 108.04447937, 110.64836884, 119.38836670]])
Total number of violation:  0
Attack finished in 67.3229 seconds.
PGD attack failed
verified_status unknown
verified_success False
Model: BoundedModule(
  (/input-1): BoundInput(name=/input-1, inputs=[], perturbed=True)
  (/1): BoundParams(name=/1, inputs=[], perturbed=False)
  (/2): BoundParams(name=/2, inputs=[], perturbed=False)
  (/3): BoundParams(name=/3, inputs=[], perturbed=False)
  (/4): BoundParams(name=/4, inputs=[], perturbed=False)
  (/5): BoundParams(name=/5, inputs=[], perturbed=False)
  (/6): BoundParams(name=/6, inputs=[], perturbed=False)
  (/7): BoundParams(name=/7, inputs=[], perturbed=False)
  (/8): BoundParams(name=/8, inputs=[], perturbed=False)
  (/9): BoundParams(name=/9, inputs=[], perturbed=False)
  (/10): BoundParams(name=/10, inputs=[], perturbed=False)
  (/11): B

100%|██████████| 100/100 [01:13<00:00,  1.36it/s]


PGD attack margin (first 2 examples and 10 specs):
 tensor([[112.62605286, 136.20693970, 134.09628296,  88.36073303,  68.75930023,
          76.79138184, 106.42063904, 108.96321106, 117.64836121]])
Total number of violation:  0
Attack finished in 73.3380 seconds.
PGD attack failed
verified_status unknown
verified_success False
Model: BoundedModule(
  (/input-1): BoundInput(name=/input-1, inputs=[], perturbed=True)
  (/1): BoundParams(name=/1, inputs=[], perturbed=False)
  (/2): BoundParams(name=/2, inputs=[], perturbed=False)
  (/3): BoundParams(name=/3, inputs=[], perturbed=False)
  (/4): BoundParams(name=/4, inputs=[], perturbed=False)
  (/5): BoundParams(name=/5, inputs=[], perturbed=False)
  (/6): BoundParams(name=/6, inputs=[], perturbed=False)
  (/7): BoundParams(name=/7, inputs=[], perturbed=False)
  (/8): BoundParams(name=/8, inputs=[], perturbed=False)
  (/9): BoundParams(name=/9, inputs=[], perturbed=False)
  (/10): BoundParams(name=/10, inputs=[], perturbed=False)
  (/11): B