# Certified Finetuning of a Classifier on the OCT-MNIST Dataset

In [2]:
%load_ext autoreload
%autoreload 2
import os
import torch
import tqdm
import abstract_gradient_training as agt
from abstract_gradient_training import AGTConfig
from abstract_gradient_training import certified_training_utils as ct_utils
from models.deepmind import DeepMindSmall 
from datasets import oct_mnist
from models.robust_regularizer import parameter_gradient_interval_regularizer

## Test the robustness of a non-robustly pre-trained classifier

In [3]:
# Load the non-robustly pre-trained checkpoint and measure its accuracy on a
# single balanced test batch (class 2 excluded) for the given epsilon.
device = torch.device("cuda:0")
_, dl_test = oct_mnist.get_dataloaders(1000, exclude_classes=[2], balanced=True)
standard_model = DeepMindSmall(1, 1).to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint still
# loads when it was saved from a different device than the one used here.
standard_model.load_state_dict(torch.load(".models/medmnist.ckpt", map_location=device))
# NOTE(review): a later cell unpacks get_parameters as (n, l, u); presumably the
# three copies are identical for a freshly loaded model - confirm the ordering.
params_l, params_n, params_u = ct_utils.get_parameters(standard_model)
epsilon = 0.0005
# only the first batch of the test loader is evaluated
test_batch, test_labels = next(iter(dl_test))
accs = agt.test_metrics.test_accuracy(
    params_l, params_n, params_u, test_batch, test_labels, standard_model, ct_utils.propagate_conv_layers, epsilon
)
# render every returned accuracy value with two decimals for the report line
accs = ", ".join([f"{a:.2f}" for a in accs])

print(f"Accuracy of non-robustly trained classifier on test set with epsilon={epsilon}: [{accs}]")

Accuracy of non-robustly trained classifier on test set with epsilon=0.0005: [0.04, 0.96, 1.00]


## Pre-train the model

Class 2 (Drusen) is excluded from the pre-training data.

In [23]:
# set up pre-training
torch.manual_seed(1)  # fixed seed for the torch RNG

# optimisation schedule
pretrain_batchsize, pretrain_n_epochs = 100, 20
pretrain_learning_rate = 1e-3

# arguments forwarded to the gradient-interval regulariser during pre-training
pretrain_epsilon, pretrain_model_epsilon = 0.0, 0.01
# per-epoch increment of the regularisation weight used in the training loop
pretrain_reg_strength = 5e-2

# checkpoint name encodes the regulariser settings used to produce it
model_path = (
    ".models/medmnist_param_robust"
    f"_eps{pretrain_epsilon}"
    f"_alpha{pretrain_reg_strength}"
    f"_meps{pretrain_model_epsilon}.ckpt"
)

In [24]:
# define model, dataset and optimizer
# The network is placed on the target device straight away; Module.to moves the
# parameters in place, so the optimizer below tracks the on-device parameters.
model = DeepMindSmall(1, 1).to(device)
# only the training loader is needed here; class 2 is left out of pre-training
dl_pretrain, _ = oct_mnist.get_dataloaders(
    pretrain_batchsize,
    exclude_classes=[2],
    balanced=True,
)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=pretrain_learning_rate)

In [25]:
# Re-use a cached checkpoint when one exists for the current hyper-parameters,
# otherwise pre-train `model` with BCE loss plus the gradient-interval regulariser.
if os.path.exists(model_path):
    # map_location remaps the stored tensors onto `device`, so the cache can be
    # loaded regardless of which device it was saved from.
    model.load_state_dict(torch.load(model_path, map_location=device))
else:  # pre-train the model
    progress_bar = tqdm.trange(pretrain_n_epochs, desc="Epoch")
    for epoch in progress_bar:
        # regularisation weight: grows linearly with the epoch index, capped at 0.5
        # (it is 0 in the first epoch, so that epoch trains on the BCE loss alone)
        kappa = min(epoch * pretrain_reg_strength, 0.5)
        for i, (x, u) in enumerate(dl_pretrain):
            # Forward pass (x: inputs, u: labels)
            u, x = u.to(device), x.to(device)
            output = model(x)
            bce_loss = criterion(output.squeeze().float(), u.squeeze().float())
            # stop the current epoch as soon as one batch reaches a BCE loss below
            # 0.05; this only leaves the inner loop, the next epoch still starts
            if bce_loss.item() < 0.05:
                break
            if kappa > 0:
                regularization = parameter_gradient_interval_regularizer(
                    model, x, u, "binary_cross_entropy", pretrain_epsilon, pretrain_model_epsilon
                )
            else:
                # placeholder so the logging below can always call .item()
                regularization = torch.tensor(0.0)
            loss = bce_loss + kappa * regularization
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # refresh the progress-bar statistics every 100 batches
            if i % 100 == 0:
                progress_bar.set_postfix(loss=loss.item(), bce_loss=bce_loss.item(), reg=regularization.item(), reg_strength=kappa)
    # NOTE(review): saving is commented out, so the cached-checkpoint branch above
    # only triggers when the file at `model_path` was produced elsewhere.
    # save the model
    # with open(model_path, "wb") as file:
    #     torch.save(model.state_dict(), file)

Epoch:  50%|█████     | 10/20 [03:16<03:16, 19.61s/it, bce_loss=0.248, loss=0.34, reg=0.184, reg_strength=0.5]  


KeyboardInterrupt: 

### Test the robustness of the model pre-trained with the gradient interval regularization term

In [26]:
# Evaluate the regularised model on one balanced test batch (class 2 excluded).
_, dl_test = oct_mnist.get_dataloaders(
    1000,
    exclude_classes=[2],
    balanced=True,
)
params_l, params_n, params_u = ct_utils.get_parameters(model)
epsilon = 5e-4
# only the first batch of the loader is used
test_batch, test_labels = next(iter(dl_test))
accs = agt.test_metrics.test_accuracy(
    params_l,
    params_n,
    params_u,
    test_batch,
    test_labels,
    model,
    ct_utils.propagate_conv_layers,
    epsilon,
)
# two-decimal rendering of every returned accuracy value
accs = ", ".join(format(a, ".2f") for a in accs)

print(f"Accuracy of robustly trained classifier on test set with epsilon={epsilon}: [{accs}]")

Accuracy of robustly trained classifier on test set with epsilon=0.0005: [0.91, 0.93, 0.95]


### Fine-tune the model with AGT

In [27]:
# set up fine-tuning parameters
batchsize = 5000  # batch size passed to the fine-tuning dataloaders
config = AGTConfig(
    # optimisation schedule
    n_epochs=1,
    learning_rate=0.3,
    lr_decay=0.3,
    lr_min=0.001,
    early_stopping=False,
    loss="binary_cross_entropy",
    # privacy / clipping parameters
    k_private=10,
    clip_gamma=2.0,
    clip_method="clamp",
    dp_sgd_sigma=0.0,
    # bounding methods
    forward_bound="interval",
    backward_bound="interval",
    # execution settings
    fragsize=1000,
    device="cuda:0",
    log_level="DEBUG",
)

In [28]:
# get dataloaders, train dataloader is a mix of drusen and the "healthy" class
# Each call returns a (train, test) pair; only one element is kept per call.
# training data: classes 0 and 1 are excluded, loader is balanced
dl_train, _ = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[0, 1], balanced=True)
# test data restricted to class 2 (Drusen)
_, dl_test_drusen = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[0, 1, 3])
# test data for the classes seen during pre-training (0, 1, 3)
_, dl_test_other = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[2])
# test data over all classes
_, dl_test_all = oct_mnist.get_dataloaders(batchsize, 1000)

In [29]:
# evaluate the pre-trained model
# NOTE(review): other cells unpack get_parameters as (l, n, u); presumably the
# three copies are identical before fine-tuning - confirm the ordering.
param_n, param_l, param_u = ct_utils.get_parameters(model)
# one batch is drawn from each test loader, in the order drusen -> other -> all
drusen_acc, other_acc, all_acc = [
    agt.test_metrics.test_accuracy(
        param_n, param_l, param_u, *next(iter(loader)), model, ct_utils.propagate_conv_layers
    )
    for loader in (dl_test_drusen, dl_test_other, dl_test_all)
]

# index 1 of each result is reported as the nominal accuracy
print("=========== Pre-trained model accuracy ===========")
print(f"Class 2 (Drusen) : nominal = {drusen_acc[1]:.2g}")
print(f"Classes 0, 1, 3  : nominal = {other_acc[1]:.2g}")
print(f"All Classes      : nominal = {all_acc[1]:.2g}")

Class 2 (Drusen) : nominal = 0.38
Classes 0, 1, 3  : nominal = 0.94
All Classes      : nominal = 0.8


In [30]:
# fine-tune the model using abstract gradient training (keeping the convolutional layers fixed)
# Trains on dl_train and uses the Drusen-only test loader for the evaluation
# reported in the log. The result rebinds param_l / param_n / param_u, which the
# evaluation cell that follows reads.
param_l, param_n, param_u = agt.privacy_certified_training(
    model, config, dl_train, dl_test_drusen, transform=ct_utils.propagate_conv_layers
)

[AGT] [DEBUG   ] [19:29:42] 	Privacy parameters: k_private=10, clip_gamma=2.0, dp_sgd_sigma=0.0
[AGT] [DEBUG   ] [19:29:42] 	Bounding methods: forward=interval, backward=interval
[AGT] [INFO    ] [19:29:42] Starting epoch 1


[AGT] [DEBUG   ] [19:29:42] Initialising dataloader batchsize to 5000
[AGT] [INFO    ] [19:29:42] Training batch 1: Network eval bounds=(0.38, 0.38, 0.38), W0 Bound=0.0 
[AGT] [INFO    ] [19:29:43] Training batch 2: Network eval bounds=(1   , 1   , 1   ), W0 Bound=1.36 
[AGT] [INFO    ] [19:29:44] Training batch 3: Network eval bounds=(0.23, 0.31, 0.5 ), W0 Bound=2.41 
[AGT] [DEBUG   ] [19:29:45] Skipping batch 4 in epoch 1 (expected batchsize 5000, got 508)
[AGT] [INFO    ] [19:29:45] Final network eval: Network eval bounds=(0.46, 0.9 , 1   ), W0 Bound=3.29 


In [32]:
# evaluate the fine-tuned model
# Uses the parameter sets returned by the AGT fine-tuning cell; one batch is
# drawn from each test loader, in the order drusen -> other -> all.
drusen_acc, other_acc, all_acc = [
    agt.test_metrics.test_accuracy(
        param_n, param_l, param_u, *next(iter(loader)), model, ct_utils.propagate_conv_layers
    )
    for loader in (dl_test_drusen, dl_test_other, dl_test_all)
]

# index 1 is reported as the nominal accuracy, index 0 as the certified bound
print("=========== Fine-tuned model accuracy + bounds ===========")
print(f"Class 2 (Drusen) : nominal = {drusen_acc[1]:.2g}, certified bound = {drusen_acc[0]:.2g}")
print(f"Classes 0, 1, 3  : nominal = {other_acc[1]:.2g}, certified bound = {other_acc[0]:.2g}")
print(f"All Classes      : nominal = {all_acc[1]:.2g}, certified bound = {all_acc[0]:.2g}")

Class 2 (Drusen) : nominal = 0.9, certified bound = 0.46
Classes 0, 1, 3  : nominal = 0.81, certified bound = 0.63
All Classes      : nominal = 0.83, certified bound = 0.59
