In [None]:
# required modules

!pip install --upgrade wandb==0.22.3   # needed for logging
!git clone https://github.com/Verified-Intelligence/auto_LiRPA
!pip install ./auto_LiRPA

In [None]:
# used for logging
import os
from getpass import getpass

import wandb

# Never hard-code the API key in the notebook: read it from the environment and fall back
# to an interactive prompt, so the key is not saved in the file or its outputs.
wandb_key = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")
wandb.login(key=wandb_key)

# Training

## Configuration

Let's import the needed modules and load the dataset.

In [None]:
import torch
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from typing import Callable, Tuple
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

from model import CNNCrown, Encoder, LinearClassifier
from losses import SupConLoss
from verifier import PGDVerifier
from utils import train, test, get_device, get_embeddings_plot, RandomGaussianNoise

In [None]:
# Global run configuration shared by every training section below.
PROJ_DIM = 128    # encoder output size == LinearClassifier input size
BATCH_SIZE = 128  # samples per step (the adversarial sections double it)
DEVICE = get_device()

In [None]:
from torch.utils.data import Subset

AUGMENTATION = False
AUGMENTATION_LABEL = "" if AUGMENTATION else "No"

# Deterministic preprocessing: ToTensor gives pixels in [0, 1] and Normalize(0.5, 0.5) maps them
# to [-1, 1] (the range the PGD cells clamp adversarial images to).
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

if AUGMENTATION:
    # Random augmentations: applied to the training split only.
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        RandomGaussianNoise(p=0.5),
    ])
else:
    transform = eval_transform

torch.manual_seed(42)
dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
# The same training images without augmentation: validation metrics must be computed on clean inputs.
eval_dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=eval_transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

train_dataset, validation_split = random_split(dataset, [train_size, validation_size])
# Reuse the split's indices so validation stays disjoint from training but is never augmented.
validation_dataset = Subset(eval_dataset, validation_split.indices)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Test images are never augmented, and evaluation does not depend on their order.
test_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=eval_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## Contrastive Model Training

Steps needed:
1. train the encoder with supervised contrastive loss;
2. train the classifier with cross entropy loss.

In [None]:
# --- Encoder Training
# Step 1 of the contrastive pipeline: hyper-parameters and model for the encoder.
EPOCHS = 20
learning_rate = 1e-3
sup_con_loss = SupConLoss()  # project loss, constructed with its default arguments
encoder = Encoder(proj_dim=PROJ_DIM).to(DEVICE)  # its output later feeds LinearClassifier(in_dim=PROJ_DIM)
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)

In [None]:
# Open a W&B run for the contrastive encoder; `config` records the hyper-parameters set above.
wandb.init(
    project="Cnn-Verification",
    name=f"Encoder - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM,
        "loss": "Supervised Contrastive Loss"
    }
)

In [None]:
# Train the encoder with the supervised contrastive loss. `compute_accuracy=False` because the
# encoder produces embeddings, not class logits.
encoder = train(
    encoder,
    train_loader,
    validation_loader,
    encoder_optimizer,
    sup_con_loss,
    EPOCHS,
    DEVICE,
    compute_accuracy=False,
    wandb_logging=True
)

In [None]:
# logging embeddings
buf = get_embeddings_plot(encoder, train_loader, validation_loader, DEVICE)
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})
# logging weights
model_filename = "encoder_weights.pt"
torch.save(encoder.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the same working-directory-relative path it was saved to, instead of a
# hard-coded Kaggle path that only exists on Kaggle.
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the encoder run before the classifier run starts

In [None]:
# --- Classifier Training
# Step 2: a linear classifier trained with cross-entropy on the embeddings of the frozen encoder.
EPOCHS = 10
learning_rate = 0.001
cross_entropy_loss = nn.CrossEntropyLoss()
classifier = LinearClassifier(in_dim=PROJ_DIM).to(DEVICE)  # input size matches the encoder's PROJ_DIM
classifier_optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)  # classifier params only

In [None]:
# Open a W&B run for the linear classifier trained on top of the contrastive encoder.
wandb.init(
    project="Cnn-Verification",
    name=f"Classifier - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropyLoss"
    }
)

In [None]:
def execute_classifier(encoder:nn.Module) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Build a `train` middleware that replaces each image batch with frozen-encoder embeddings.

    The returned callable switches `encoder` to eval mode and runs it without gradient tracking,
    so no gradient can flow back into the encoder. Labels are passed through unchanged.
    """
    @torch.no_grad()
    def main(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        encoder.eval()
        return encoder(images), labels

    return main

# The middleware swaps every image batch for frozen-encoder embeddings; only the classifier's
# parameters are in `classifier_optimizer`, so only the classifier is updated.
classifier = train(
    classifier,
    train_loader,
    validation_loader,
    classifier_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=execute_classifier(encoder),
    wandb_logging=True
)

In [None]:
# logging weights
model_filename = "classifier_weights.pt"
torch.save(classifier.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the classifier run

## Normal Model Training

Train the encoder and the classifier jointly (end to end) with cross-entropy loss.

In [None]:
# --- Normal Model Training
EPOCHS = 30
# BATCH_SIZE is deliberately not reassigned here: the data loaders were already built with the
# configuration-cell value, so changing it here would only make the logged config disagree with training.
learning_rate = 0.001
cross_entropy_loss = nn.CrossEntropyLoss()
full_model = CNNCrown(proj_dim=PROJ_DIM).to(DEVICE)  # encoder + classifier trained end to end
full_model_optimizer = optim.Adam(full_model.parameters(), lr=learning_rate)

In [None]:
# Open a W&B run for the end-to-end (non-contrastive) model.
wandb.init(
    project="Cnn-Verification",
    name=f"Normal Model - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy"
    }
)

In [None]:
# End-to-end training with cross-entropy; no middleware, so batches reach the model unchanged.
full_model = train(
    full_model,
    train_loader,
    validation_loader,
    full_model_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    wandb_logging=True
)

In [None]:
# logging embeddings
buf = get_embeddings_plot(full_model.encoder, train_loader, validation_loader, DEVICE)
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})
# logging weights
model_filename = "full_model_weights.pt"
torch.save(full_model.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the normal-model run

## Normal Model with Adversarial Training

Train the encoder and the classifier jointly with cross-entropy loss. During training, each batch is extended with adversarial examples found with PGD (Projected Gradient Descent).

**Important**: the effective batch size is doubled, because every batch is concatenated with its adversarial examples. Halve `BATCH_SIZE` in the configuration cell if memory is a concern.

In [None]:
# --- Adversarial Training

EPOCHS = 10
learning_rate = 0.001
cross_entropy_loss = nn.CrossEntropyLoss()
adversarial_model = CNNCrown(proj_dim=PROJ_DIM).to(DEVICE)  # encoder + classifier trained end to end
adversarial_model_optimizer = optim.Adam(adversarial_model.parameters(), lr=learning_rate)
pgd = PGDVerifier(device=DEVICE)  # notebook-level attack object read by the middleware below

In [None]:
# Open a W&B run for adversarial training. The logged batch size is doubled because the
# middleware appends one adversarial example per image to every batch.
wandb.init(
    project="Cnn-Verification",
    name=f"Adversarial Model - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE * 2,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy"
    }
)

In [None]:
def compute_adversarial_examples(adversarial_model:nn.Module, verifier=None) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Build a `train` middleware that enlarges each batch with PGD adversarial examples.

    Args:
        adversarial_model: the model being trained; PGD attacks its current weights.
        verifier: object exposing ``verify(model, images, labels, clamp_min=..., clamp_max=...)``
            and returning ``(adversarial_examples, _, _)``. Defaults to the notebook-level ``pgd``,
            looked up at call time.

    Returns:
        A callable mapping ``(images, labels)`` to ``(cat[images, adversarial], cat[labels, labels])``,
        i.e. a batch twice as large.
    """

    def main(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        active_verifier = pgd if verifier is None else verifier
        # Clamp to [-1, 1]: the range of images after Normalize((0.5,)*3, (0.5,)*3).
        adversarial_examples, _, _ = active_verifier.verify(adversarial_model, images, labels, clamp_min=-1, clamp_max=1)
        # detach() works whether or not PGD returns a leaf tensor; assigning
        # `requires_grad = False` raises a RuntimeError on non-leaf tensors.
        adversarial_examples = adversarial_examples.detach()
        images = torch.cat([images, adversarial_examples])
        labels = torch.cat([labels, labels])
        return images, labels
    return main

# Adversarial training: every batch is doubled with PGD examples computed against the current model.
adversarial_model = train(
    adversarial_model,
    train_loader,
    validation_loader,
    adversarial_model_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=compute_adversarial_examples(adversarial_model),
    wandb_logging=True
)

In [None]:
# logging embeddings
buf = get_embeddings_plot(adversarial_model.encoder, train_loader, validation_loader, DEVICE)
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})
# logging weights
model_filename = "adversarial_model_weights.pt"
torch.save(adversarial_model.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the adversarial-model run

## Contrastive Model with Adversarial Training
Steps:
1. enlarge the batch using adversarial examples;
2. train the encoder with supervised contrastive loss;
3. train the classifier with cross entropy loss.

**Important**: the effective batch size is doubled, because every batch is concatenated with its adversarial examples. Halve `BATCH_SIZE` in the configuration cell if memory is a concern.

In [None]:
# --- Adversarial Training With Supervised Contrastive Loss

# --- Encoder Training
EPOCHS = 40
learning_rate = 1e-3
sup_con_loss = SupConLoss(temperature=0.1)  # training objective (explicit temperature)
adversarial_encoder = Encoder(proj_dim=PROJ_DIM).to(DEVICE)
adversarial_encoder_optimizer = optim.Adam(adversarial_encoder.parameters(), lr=learning_rate)
pgd = PGDVerifier(device=DEVICE)  # rebinds the notebook-level `pgd` used by the middleware below

In [None]:
# Open a W&B run for the adversarial contrastive encoder; the batch size is logged doubled
# because the middleware appends adversarial examples to every batch.
wandb.init(
    project="Cnn-Verification",
    name=f"Adversarial Contrastive Encoder - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE * 2,
        "projection_dimension": PROJ_DIM,
        "loss": "Supervised Contrastive Loss"
    }
)

In [None]:
def compute_adversarial_examples(adversarial_encoder:nn.Module, criterion=None, verifier=None) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Build a `train` middleware that adds contrastive PGD adversarial examples to each batch.

    Args:
        adversarial_encoder: the encoder being trained; PGD attacks its current weights.
        criterion: loss maximised by the attack. Defaults to a fresh ``SupConLoss()`` with default
            arguments. NOTE(review): the training cell uses ``SupConLoss(temperature=0.1)``; pass
            that loss here if the attack should optimise the same objective as training.
        verifier: object exposing ``verify(model, images, labels, clamp_min=..., clamp_max=..., criterion=...)``
            and returning ``(adversarial_examples, _, _)``. Defaults to the notebook-level ``pgd``,
            looked up at call time.

    Returns:
        A callable mapping ``(images, labels)`` to ``(cat[images, adversarial], cat[labels, labels])``.
    """
    attack_loss = SupConLoss() if criterion is None else criterion

    def main(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        active_verifier = pgd if verifier is None else verifier
        # Clamp to [-1, 1]: the range of images after Normalize((0.5,)*3, (0.5,)*3).
        adversarial_examples, _, _ = active_verifier.verify(adversarial_encoder, images, labels, clamp_min=-1, clamp_max=1, criterion=attack_loss)
        # detach() works whether or not PGD returns a leaf tensor (setting requires_grad does not).
        adversarial_examples = adversarial_examples.detach()
        images = torch.cat([images, adversarial_examples])
        labels = torch.cat([labels, labels])
        return images, labels

    return main

# Train the encoder on batches doubled with contrastive PGD examples.
# NOTE(review): the loss optimised here is `sup_con_loss` (temperature=0.1), while the middleware's
# attack builds its own SupConLoss() with default arguments. Confirm the mismatch is intended.
adversarial_encoder = train(
    adversarial_encoder,
    train_loader,
    validation_loader,
    adversarial_encoder_optimizer,
    sup_con_loss,
    EPOCHS,
    DEVICE,
    middleware=compute_adversarial_examples(adversarial_encoder),
    compute_accuracy=False,
    wandb_logging=True,
)

In [None]:
# logging embeddings
buf = get_embeddings_plot(adversarial_encoder, train_loader, validation_loader, DEVICE)
wandb.log({"embeddings_space": wandb.Image(Image.open(buf))})
# logging weights
model_filename = "adversarial_encoder_weights.pt"
torch.save(adversarial_encoder.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the adversarial contrastive encoder run

In [None]:
# Training the Linear Classifier

EPOCHS = 10
learning_rate = 0.001
adversarial_encoder.eval()  # the encoder is frozen from here on; only the classifier is optimised
cross_entropy_loss = nn.CrossEntropyLoss()
adversarial_classifier = LinearClassifier(in_dim=PROJ_DIM).to(DEVICE)
adversarial_classifier_optimizer = optim.Adam(adversarial_classifier.parameters(), lr=learning_rate)
pgd = PGDVerifier(device=DEVICE)  # rebinds the notebook-level `pgd` used by the middleware below

In [None]:
# Open a W&B run for the classifier on the adversarial contrastive encoder; the batch size is
# logged doubled because the middleware appends adversarial examples to every batch.
wandb.init(
    project="Cnn-Verification",
    name=f"Adversarial Contrastive Classifier - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE * 2,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy Loss"
    }
)

In [None]:
def compute_adversarial_examples(adversarial_encoder:nn.Module, verifier=None) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Build a `train` middleware that doubles the batch with PGD examples and returns embeddings.

    Each batch is extended with adversarial examples, and the frozen encoder then embeds the whole
    enlarged batch without gradient tracking. The classifier therefore trains on clean and
    adversarial embeddings.

    Args:
        adversarial_encoder: frozen encoder; it is both attacked and used to compute embeddings.
        verifier: object exposing ``verify(model, images, labels, clamp_min=..., clamp_max=...)``
            and returning ``(adversarial_examples, _, _)``. Defaults to the notebook-level ``pgd``,
            looked up at call time.

    NOTE(review): the attack targets the encoder alone with the verifier's default criterion, not
    encoder + classifier. Confirm that criterion is meaningful on embeddings.
    """

    def main(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Keep the encoder in eval mode on every call, like the other embedding middlewares.
        adversarial_encoder.eval()
        active_verifier = pgd if verifier is None else verifier
        # Clamp to [-1, 1]: the range of images after Normalize((0.5,)*3, (0.5,)*3).
        adversarial_examples, _, _ = active_verifier.verify(adversarial_encoder, images, labels, clamp_min=-1, clamp_max=1)
        # detach() works whether or not PGD returns a leaf tensor (setting requires_grad does not).
        adversarial_examples = adversarial_examples.detach()
        images = torch.cat([images, adversarial_examples])
        labels = torch.cat([labels, labels])
        with torch.no_grad():
            embeddings = adversarial_encoder(images)
        return embeddings, labels

    return main

# Train the linear classifier on embeddings of clean + PGD images produced by the frozen encoder.
adversarial_classifier = train(
    adversarial_classifier,
    train_loader,
    validation_loader,
    adversarial_classifier_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=compute_adversarial_examples(adversarial_encoder),
    wandb_logging=True,
)

In [None]:
# logging weights
model_filename = "adversarial_classifier_weights.pt"
torch.save(adversarial_classifier.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the adversarial contrastive classifier run

## Certified Model

Train the encoder and the classifier jointly with cross-entropy loss. Each input is perturbed within an $\ell_\infty$ ball of radius $\epsilon$, the resulting output bounds are propagated with CROWN-IBP, and the lower bound of the logits is passed to the loss.

In [None]:
# L-inf radius of the input perturbation used by the certified sections.
# NOTE(review): it is applied to the *normalized* images (Normalize(0.5, 0.5) maps [0, 1] to [-1, 1]),
# so it corresponds to 1/255 on the original pixel scale. Confirm this is the intended radius.
epsilon = 2/255  # image perturbation

In [None]:
# Certified Model Train

EPOCHS = 30
# BATCH_SIZE is deliberately not reassigned here: the data loaders were already built with the
# configuration-cell value, so changing it here would only make the logged config disagree with training.
learning_rate = 0.001
cross_entropy_loss = nn.CrossEntropyLoss()
certified_model = CNNCrown(proj_dim=PROJ_DIM).to(DEVICE)
# Wrap the model for auto_LiRPA bound propagation; the dummy input only fixes the input shape
# (a batch of 3x32x32 images) used for tracing.
certified_model = BoundedModule(certified_model, torch.empty(2, 3, 32, 32))
certified_model_optimizer = optim.Adam(certified_model.parameters(), lr=learning_rate)

In [None]:
wandb.init(
    project="Cnn-Verification",
    name=f"Certified Model - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        # No adversarial examples are appended in certified training, so the batch is not doubled.
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy Loss"
    }
)

In [None]:
# train loop
import warnings
warnings.filterwarnings("ignore", category=ResourceWarning)

# L-inf perturbation of radius `epsilon` around every (normalized) input image.
ptb = PerturbationLpNorm(norm=float('inf'), eps=epsilon)

for epoch in range(EPOCHS):
    # --- training: cross-entropy on the CROWN-IBP lower bounds of the logits.
    # NOTE(review): CE on `lb` for every class does not bound the worst-case CE (that would use the
    # true-class lower bound and the other classes' upper bounds). It is kept as described in the
    # section introduction.
    certified_model.train()
    train_loss_sum, train_correct, train_seen = 0.0, 0, 0
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        images_bounded = BoundedTensor(images, ptb)

        certified_model_optimizer.zero_grad()
        lb, ub = certified_model.compute_bounds(x=(images_bounded,), method="CROWN-IBP")
        loss = cross_entropy_loss(lb, labels)
        loss.backward()
        certified_model_optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the epoch averages
        # (nn.CrossEntropyLoss averages over the batch by default).
        n = len(labels)
        train_loss_sum += loss.item() * n
        # "Accuracy" is the argmax of the lower bounds; it is not certified accuracy.
        train_correct += (torch.argmax(lb, dim=1) == labels).sum().item()
        train_seen += n

    train_loss = train_loss_sum / train_seen
    train_accuracy = train_correct / train_seen * 100

    # --- validation: same bounds and loss, without gradient tracking.
    certified_model.eval()
    with torch.no_grad():
        validation_loss_sum, validation_correct, validation_seen = 0.0, 0, 0
        for images, labels in validation_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            images_bounded = BoundedTensor(images, ptb)

            lb, ub = certified_model.compute_bounds(x=(images_bounded,), method="CROWN-IBP")
            loss = cross_entropy_loss(lb, labels)

            n = len(labels)
            validation_loss_sum += loss.item() * n
            validation_correct += (torch.argmax(lb, dim=1) == labels).sum().item()
            validation_seen += n

    validation_loss = validation_loss_sum / validation_seen
    validation_accuracy = validation_correct / validation_seen * 100

    print(f"> Epoch {epoch+1}/{EPOCHS}")
    print(f"  Training loss      : {train_loss:.4f}, Training accuracy  : {train_accuracy:.2f}%")
    print(f"  Validation loss    : {validation_loss:.4f}, Validation accuracy: {validation_accuracy:.2f}%")

    log = {
        "train_loss": train_loss,
        "validation_loss": validation_loss,
        "train_accuracy": train_accuracy,
        "validation_accuracy": validation_accuracy,
    }

    wandb.log(log)


In [None]:
# logging weights
model_filename = "certified_model.pt"
torch.save(certified_model.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the certified-model run

## Certified Contrastive Model

The steps are:
1. train the encoder using bounds propagation: the lower bound is passed to the supervised contrastive loss;
2. train the classifier using the cross entropy loss with the lower bounds obtained by the previous step.

In [None]:
# Certified Contrastive Model - Encoder

# image perturbation: L-inf radius, applied to the normalized [-1, 1] inputs
epsilon = 2/255

EPOCHS = 20
# BATCH_SIZE is deliberately not reassigned here: the data loaders were already built with the
# configuration-cell value, so changing it here would only make the logged config disagree with training.
learning_rate = 0.001
sup_con_loss = SupConLoss()
certified_contrastive_encoder = Encoder(proj_dim=PROJ_DIM).to(DEVICE)
# Wrap the encoder for auto_LiRPA bound propagation; the dummy input only fixes the input shape.
certified_contrastive_encoder = BoundedModule(certified_contrastive_encoder, torch.empty(2, 3, 32, 32))
certified_contrastive_encoder_optimizer = optim.Adam(certified_contrastive_encoder.parameters(), lr=learning_rate)

In [None]:
# Open a W&B run for the certified contrastive encoder (no adversarial examples, so the batch is not doubled).
wandb.init(
    project="Cnn-Verification",
    name=f"Certified Contrastive Encoder - {AUGMENTATION_LABEL} Augmentation",
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM,
        "loss": "Supervised Contrastive Loss"
    }
)

In [None]:
# train loop
import warnings
warnings.filterwarnings("ignore", category=ResourceWarning)

# L-inf perturbation of radius `epsilon` around every (normalized) input image.
ptb = PerturbationLpNorm(norm=float('inf'), eps=epsilon)

for epoch in range(EPOCHS):
    # --- training: supervised contrastive loss on the CROWN-IBP lower bounds of the embeddings.
    # No accuracy is tracked: the encoder outputs embeddings, so an argmax over them is not a class prediction.
    certified_contrastive_encoder.train()
    train_loss = 0
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        images_bounded = BoundedTensor(images, ptb)

        certified_contrastive_encoder_optimizer.zero_grad()
        lb, ub = certified_contrastive_encoder.compute_bounds(x=(images_bounded,), method="CROWN-IBP")
        loss = sup_con_loss(lb, labels)
        loss.backward()
        certified_contrastive_encoder_optimizer.step()

        train_loss += loss.item()

    # Mean of per-batch losses: a contrastive loss depends on the batch composition.
    train_loss /= len(train_loader)

    # --- validation: same bounds and loss, without gradient tracking.
    certified_contrastive_encoder.eval()
    with torch.no_grad():
        validation_loss = 0
        for images, labels in validation_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            images_bounded = BoundedTensor(images, ptb)

            lb, ub = certified_contrastive_encoder.compute_bounds(x=(images_bounded,), method="CROWN-IBP")
            loss = sup_con_loss(lb, labels)

            validation_loss += loss.item()

    validation_loss /= len(validation_loader)

    print(f"> Epoch {epoch+1}/{EPOCHS}")
    print(f"  Training loss      : {train_loss:.4f}")
    print(f"  Validation loss    : {validation_loss:.4f}")

    log = {
        "train_loss": train_loss,
        "validation_loss": validation_loss,
    }

    wandb.log(log)


In [None]:
# logging weights
model_filename = "certified_contrastive_encoder.pt"
torch.save(certified_contrastive_encoder.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the certified contrastive encoder run

In [None]:
# Certified Contrastive Model - Classifier

EPOCHS = 10
# BATCH_SIZE is deliberately not reassigned here: the data loaders were already built with the
# configuration-cell value, so changing it here would only make the logged config disagree with training.
learning_rate = 0.001
cross_entropy_loss = nn.CrossEntropyLoss()
certified_contrastive_classifier = LinearClassifier(in_dim=PROJ_DIM).to(DEVICE)
certified_contrastive_classifier_optimizer = optim.Adam(certified_contrastive_classifier.parameters(), lr=learning_rate)

In [None]:
wandb.init(
    project="Cnn-Verification",
    name=f"Certified Contrastive Classifier - {AUGMENTATION_LABEL} Augmentation",
    # No hard-coded run `id` / `resume`: a fixed id with resume="allow" would resume (and append to)
    # that old run instead of starting a fresh one.
    config={
        "learning_rate": learning_rate,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "projection_dimension": PROJ_DIM,
        "loss": "CrossEntropy Loss"
    }
)

In [None]:
def get_certfied_embeddings(certified_contrastive_encoder:nn.Module, epsilon:float) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Build a `train` middleware that replaces images with certified (lower-bound) embeddings.

    Each batch is wrapped in an L-inf ball of radius `epsilon`, and the CROWN-IBP lower bound of the
    frozen bounded encoder's output is returned as the embedding. Labels pass through unchanged.

    NOTE(review): the classifier is trained on these lower-bound embeddings, while the Testing
    section presumably evaluates the encoder's ordinary forward pass. Confirm the mismatch is intended.
    """
    # The perturbation depends only on `epsilon`, so build it once rather than on every batch.
    ptb = PerturbationLpNorm(norm=float('inf'), eps=epsilon)

    def main(images:torch.Tensor, labels:torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        certified_contrastive_encoder.eval()
        images_bounded = BoundedTensor(images, ptb)
        # The encoder is frozen: bounds are computed without gradients, and no optimizer is
        # touched here (the original reset an unrelated global optimizer).
        with torch.no_grad():
            embeddings, _ = certified_contrastive_encoder.compute_bounds(x=(images_bounded,), method="CROWN-IBP")
        return embeddings, labels

    return main

# Train the linear classifier on the CROWN-IBP lower-bound embeddings of the frozen bounded encoder.
certified_contrastive_classifier = train(
    certified_contrastive_classifier,
    train_loader,
    validation_loader,
    certified_contrastive_classifier_optimizer,
    cross_entropy_loss,
    EPOCHS,
    DEVICE,
    middleware=get_certfied_embeddings(certified_contrastive_encoder, epsilon),
    wandb_logging=True
)

In [None]:
# logging weights
model_filename = "certified_contrastive_classifier.pt"
torch.save(certified_contrastive_classifier.state_dict(), model_filename)
artifact = wandb.Artifact("model", type="model")
# Add the file from the path it was saved to (not a hard-coded Kaggle-only path).
artifact.add_file(model_filename)
wandb.log_artifact(artifact)

In [None]:
wandb.finish()  # close the certified contrastive classifier run

---

# Testing

Since some models were trained with augmented data, adversarial examples, or bound propagation, we evaluate every model on the original (non-augmented) CIFAR-10 images.

In [None]:
import os
import torch
import multiprocessing
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

In [None]:
# `force=True` keeps this cell re-runnable: without it, set_start_method raises RuntimeError once a
# start method has already been fixed (e.g. by DataLoader workers in the training section).
multiprocessing.set_start_method('spawn', force=True)

In [None]:
# loading all the models

DEVICE = get_device()  # same device helper as the training section, instead of hard-coding "cuda"

augmentation_path = "/kaggle/input/cnnrobust/pytorch/nopooling_models/5/augmentation"
no_augmentation_path = "/kaggle/input/cnnrobust/pytorch/nopooling_models/5/no_augmentation"

# Dummy input used by auto_LiRPA to trace the models (a batch of 3x32x32 images).
dummy_input = torch.empty(2, 3, 32, 32)

# Non-certified models stored as full CNNCrown state dicts, in display order.
PLAIN_MODEL_FILES = [
    "normal_model.pt",
    "contrastive_model.pt",
    "adversarial_model.pt",
    "adversarial_contrastive_model.pt",
]


def load_weights(path: str) -> dict:
    """Load a saved state dict onto the CPU, so loading works whatever device it was saved from."""
    return torch.load(path, map_location="cpu")


def load_plain_model(folder: str, filename: str) -> nn.Module:
    """Build a `CNNCrown` and load the non-certified weights stored in `folder/filename`."""
    model = CNNCrown()
    model.load_state_dict(load_weights(f"{folder}/{filename}"))
    return model


def load_certified_model(folder: str) -> nn.Module:
    """Wrap a `CNNCrown` in a `BoundedModule` and load `folder/certified_model.pt`."""
    model = BoundedModule(CNNCrown(), dummy_input)
    model.load_state_dict(load_weights(f"{folder}/certified_model.pt"))
    return model


def load_certified_contrastive_model(folder: str) -> nn.Module:
    """Assemble a `CNNCrown` from a bounded certified encoder and its separately trained classifier."""
    encoder = BoundedModule(Encoder(), dummy_input)
    encoder.load_state_dict(load_weights(f"{folder}/certified_contrastive_encoder.pt"))
    classifier = LinearClassifier()
    classifier.load_state_dict(load_weights(f"{folder}/certified_contrastive_classifier.pt"))
    model = CNNCrown()
    model.encoder = encoder
    model.classifier = classifier
    return model


# Order must match `models_name`: augmentation models first, then no-augmentation models; within
# each group, the four plain models, then the certified model, then the certified contrastive model.
models = []
for folder in (augmentation_path, no_augmentation_path):
    models.extend(load_plain_model(folder, filename) for filename in PLAIN_MODEL_FILES)
    models.append(load_certified_model(folder))
    models.append(load_certified_contrastive_model(folder))

models_name = ["Normal Model", "Contrastive Model", "Adversarial Model", "Adversarial Contrastive", "Certified", "Certified Contrastive"] * 2
assert len(models) == len(models_name), "every loaded model needs exactly one display name"

In [None]:
# loading dataset - no augmentation used
# Same seed and ratios as the training section, so random_split reproduces the same train/validation split.
torch.manual_seed(42)
BATCH_SIZE = 2048
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)

train_ratio, validation_ratio = 0.8, 0.2
dataset_size = len(dataset)
train_size = int(train_ratio * dataset_size)
validation_size = dataset_size - train_size

train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size])
# Evaluation only: accuracy does not depend on sample order, so iterate deterministically.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

In [None]:
# `models` holds the augmentation-trained models in its first half and the no-augmentation
# models in its second half, so the section headers are placed by index instead of a hard-coded 6.
section_headers = {0: "\t- Augmentation", len(models) // 2: "\t- No Augmentation"}

print("> Accuracy")
for model_id, (model, model_name) in enumerate(zip(models, models_name)):
    # Print the header before the (slow) evaluation, so it appears above its first result.
    if model_id in section_headers:
        print(section_headers[model_id])

    train_accuracy, test_accuracy = test(model, train_loader, test_loader, DEVICE)
    print(f"\t\t- {model_name}: {train_accuracy:.2f}% -> {test_accuracy:.2f}%")