# Certified Finetuning of a Classifier on the OCT-MNIST Dataset

In [1]:
%load_ext autoreload
%autoreload 2
import os
import torch
import tqdm
import abstract_gradient_training as agt
from abstract_gradient_training import AGTConfig
from abstract_gradient_training import certified_training_utils as ct_utils
from models.deepmind import DeepMindSmall 
from datasets import oct_mnist

## Pre-train the model

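The model is pre-trained on classes 0, 1 and 3 only. Class 2 (Drusen) is excluded from pre-training because it is treated as the private data used for fine-tuning below.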

In [2]:
# set up pre-training
# Fix the RNG seed so that weight initialisation and data shuffling are repeatable.
torch.manual_seed(1)
# The notebook was written for the second GPU ("cuda:1"). Fall back to the first
# GPU, and then to the CPU, so the cells still run on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Hyperparameters of the (non-private) pre-training loop.
pretrain_batchsize = 100
pretrain_n_epochs = 20
pretrain_learning_rate = 0.001

In [3]:
# Build everything the pre-training loop needs: model, data, loss and optimiser.
# The meaning of the two constructor arguments is not visible here; presumably
# input channels and output size - TODO confirm against models.deepmind.
model = DeepMindSmall(1, 1).to(device)
# Class 2 (Drusen) is held out of pre-training; only the train loader is kept.
dl_pretrain, _ = oct_mnist.get_dataloaders(pretrain_batchsize, exclude_classes=[2], balanced=True)
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes the
# model already ends in a sigmoid - confirm.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=pretrain_learning_rate)

In [4]:
# Reuse the cached pre-trained weights when they exist; otherwise pre-train the
# model from scratch and cache the resulting weights for later runs.
checkpoint_path = ".models/medmnist.ckpt"
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written from another device (e.g. a GPU
    # that is not present on this machine) be restored onto the current one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:  # pre-train the model
    # Create the checkpoint directory up front: a missing or unwritable
    # directory should fail now, not after all the epochs have been spent.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    progress_bar = tqdm.trange(pretrain_n_epochs, desc="Epoch")
    for _epoch in progress_bar:
        for i, (x, u) in enumerate(dl_pretrain):
            # Forward pass
            u, x = u.to(device), x.to(device)
            output = model(x)
            # squeeze both sides so predictions and targets have matching shapes
            loss = criterion(output.squeeze().float(), u.squeeze().float())
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # only refresh the progress bar every 100 batches
            if i % 100 == 0:
                progress_bar.set_postfix(loss=loss.item())
    # save the model
    torch.save(model.state_dict(), checkpoint_path)

## Fine-tune the model on the private Drusen data

In [10]:
# set up fine-tuning parameters
# Batches that are not exactly this size are skipped by the trainer (see the
# "Skipping batch ..." lines in the training log below).
batchsize = 5000
config = AGTConfig(
    fragsize=100,  # NOTE(review): presumably the chunk size used to split a batch - confirm
    learning_rate=0.1,
    n_epochs=2,
    k_private=10,  # reported as the "Privacy parameter" in the training log
    forward_bound="interval",
    # Derived from the device selected during pre-training set-up so the two
    # cannot drift apart.
    device=str(device),
    backward_bound="interval",
    loss="binary_cross_entropy",
    clip_gamma=2.0,  # clipping level (logged as "Clipping: gamma=2.0")
    # dp_sgd_sigma=0.1,  # left disabled: the run below logs "Noise: ... sigma=0"
    lr_decay=1.0,
    lr_min=0.001,
    log_level="DEBUG",
)

In [11]:
# Build the loaders for fine-tuning and evaluation. get_dataloaders returns a
# (train, test) pair, and each call below keeps only the half it needs.
test_batchsize = 1000
# train split: only classes 2 (Drusen) and 3 remain after the exclusion, balanced
dl_train = oct_mnist.get_dataloaders(batchsize, test_batchsize, exclude_classes=[0, 1], balanced=True)[0]
# test split restricted to class 2 (Drusen)
dl_test_drusen = oct_mnist.get_dataloaders(batchsize, test_batchsize, exclude_classes=[0, 1, 3])[1]
# test split with class 2 removed (classes 0, 1 and 3)
dl_test_other = oct_mnist.get_dataloaders(batchsize, test_batchsize, exclude_classes=[2])[1]
# test split with every class
dl_test_all = oct_mnist.get_dataloaders(batchsize, test_batchsize)[1]

In [12]:
# Baseline: accuracy of the pre-trained model on each test split.
param_n, param_l, param_u = ct_utils.get_parameters(model)


def _pretrained_accuracy(dataloader):
    """Return the test_accuracy result for the first batch of *dataloader*."""
    return agt.test_metrics.test_accuracy(
        param_n, param_l, param_u, *next(iter(dataloader)), model, ct_utils.propagate_conv_layers
    )


drusen_acc = _pretrained_accuracy(dl_test_drusen)
other_acc = _pretrained_accuracy(dl_test_other)
all_acc = _pretrained_accuracy(dl_test_all)

# Index 1 of each result is reported as the nominal accuracy.
print("=========== Pre-trained model accuracy ===========")
print(f"Class 2 (Drusen) : nominal = {drusen_acc[1]:.2g}")
print(f"Classes 0, 1, 3  : nominal = {other_acc[1]:.2g}")
print(f"All Classes      : nominal = {all_acc[1]:.2g}")

Class 2 (Drusen) : nominal = 0.46
Classes 0, 1, 3  : nominal = 0.96
All Classes      : nominal = 0.84


In [13]:
# Fine-tune with abstract gradient training. The convolutional layers stay fixed
# and are applied to the inputs through the `transform` hook.
certified_params = agt.privacy_certified_training(
    model,
    config,
    dl_train,
    dl_test_drusen,
    transform=ct_utils.propagate_conv_layers,
)
# Unpacked as (param_l, param_n, param_u): note this differs from the
# (param_n, param_l, param_u) order used with ct_utils.get_parameters.
param_l, param_n, param_u = certified_params

[AGT] [DEBUG   ] [00:07:55] 	Optimizer params: n_epochs=2, learning_rate=0.1, l1_reg=0.0, l2_reg=0.0
[AGT] [DEBUG   ] [00:07:55] 	Learning rate schedule: lr_decay=1.0, lr_min=0.001, early_stopping=True
[AGT] [DEBUG   ] [00:07:55] 	Privacy parameter: k_private=10
[AGT] [DEBUG   ] [00:07:55] 	Clipping: gamma=2.0, method=clamp
[AGT] [DEBUG   ] [00:07:55] 	Noise: type=gaussian, sigma=0
[AGT] [DEBUG   ] [00:07:55] 	Bounding methods: forward=interval, loss=binary_cross_entropy, backward=interval
[AGT] [INFO    ] [00:07:55] Starting epoch 1


[AGT] [DEBUG   ] [00:07:55] Initialising dataloader batchsize to 5000
[AGT] [INFO    ] [00:07:55] Training batch 1: Network eval bounds=(0.46, 0.46, 0.46), W0 Bound=0.0 
[AGT] [INFO    ] [00:07:57] Training batch 2: Network eval bounds=(0.66, 0.7 , 0.72), W0 Bound=0.461 
[AGT] [INFO    ] [00:07:58] Training batch 3: Network eval bounds=(0.72, 0.77, 0.81), W0 Bound=0.696 
[AGT] [DEBUG   ] [00:08:00] Skipping batch 4 in epoch 1 (expected batchsize 5000, got 508)
[AGT] [INFO    ] [00:08:00] Starting epoch 2
[AGT] [INFO    ] [00:08:00] Training batch 4: Network eval bounds=(0.72, 0.79, 0.84), W0 Bound=0.856 
[AGT] [INFO    ] [00:08:01] Training batch 5: Network eval bounds=(0.7 , 0.8 , 0.87), W0 Bound=0.978 
[AGT] [INFO    ] [00:08:03] Training batch 6: Network eval bounds=(0.69, 0.8 , 0.89), W0 Bound=1.08 
[AGT] [DEBUG   ] [00:08:04] Skipping batch 4 in epoch 2 (expected batchsize 5000, got 508)
[AGT] [INFO    ] [00:08:04] Final network eval: Network eval bounds=(0.67, 0.81, 0.91), W0 Bou

In [14]:
# evaluate the fine-tuned model
# Fetch the first batch of every test split once and reuse it for both metrics.
# This avoids rebuilding each dataloader iterator twice and guarantees that the
# accuracy and the certified proportion are computed on the same points.
drusen_batch = next(iter(dl_test_drusen))
other_batch = next(iter(dl_test_other))
all_batch = next(iter(dl_test_all))

drusen_acc = agt.test_metrics.test_accuracy(
    param_n, param_l, param_u, *drusen_batch, model, ct_utils.propagate_conv_layers
)
other_acc = agt.test_metrics.test_accuracy(
    param_n, param_l, param_u, *other_batch, model, ct_utils.propagate_conv_layers
)
all_acc = agt.test_metrics.test_accuracy(
    param_n, param_l, param_u, *all_batch, model, ct_utils.propagate_conv_layers
)

# Index 1 of each result is printed as the nominal accuracy, index 0 as the certified bound.
print("=========== Fine-tuned model accuracy + bounds ===========")
print(f"Class 2 (Drusen) : nominal = {drusen_acc[1]:.2g}, certified bound = {drusen_acc[0]:.2g}")
print(f"Classes 0, 1, 3  : nominal = {other_acc[1]:.2g}, certified bound = {other_acc[0]:.2g}")
print(f"All Classes      : nominal = {all_acc[1]:.2g}, certified bound = {all_acc[0]:.2g}")

# percentage of points with a certified unlearning-safe guarantee
# NOTE(review): the training above is the privacy variant (k_private); confirm
# that "unlearning-safe" is the intended wording. The recorded output shows
# values such as 0.76, i.e. fractions rather than percentages.
percent_certified_all = agt.test_metrics.proportion_certified(
    param_n, param_l, param_u, *all_batch, model, ct_utils.propagate_conv_layers
)
percent_certified_drusen = agt.test_metrics.proportion_certified(
    param_n, param_l, param_u, *drusen_batch, model, ct_utils.propagate_conv_layers
)
percent_certified_other = agt.test_metrics.proportion_certified(
    param_n, param_l, param_u, *other_batch, model, ct_utils.propagate_conv_layers
)
print("======= Percentage of points with certified unlearning-safe guarantees =========")
print(f"Class 2 (Drusen) : {percent_certified_drusen:.2g}")
print(f"Classes 0, 1, 3  : {percent_certified_other:.2g}")
print(f"All Classes      : {percent_certified_all:.2g}")

Class 2 (Drusen) : nominal = 0.81, certified bound = 0.67
Classes 0, 1, 3  : nominal = 0.9, certified bound = 0.84
All Classes      : nominal = 0.88, certified bound = 0.8
Class 2 (Drusen) : 0.76
Classes 0, 1, 3  : 0.89
All Classes      : 0.85


In [20]:
# Private predictions: repeat the accuracy evaluation with a non-zero noise
# level passed to test_accuracy.
noise_level = 1.0 / 5.0
# Evaluate the three test splits in the order drusen, other, all; each uses the
# first batch of its dataloader.
drusen_acc, other_acc, all_acc = (
    agt.test_metrics.test_accuracy(
        param_n,
        param_l,
        param_u,
        *next(iter(test_loader)),
        model,
        ct_utils.propagate_conv_layers,
        noise_level=noise_level,
    )
    for test_loader in (dl_test_drusen, dl_test_other, dl_test_all)
)

# Index 1 of each result is printed as the nominal accuracy, index 0 as the certified bound.
print("=========== Fine-tuned model accuracy + bounds ===========")
print(f"Class 2 (Drusen) : nominal = {drusen_acc[1]:.2g}, certified bound = {drusen_acc[0]:.2g}")
print(f"Classes 0, 1, 3  : nominal = {other_acc[1]:.2g}, certified bound = {other_acc[0]:.2g}")
print(f"All Classes      : nominal = {all_acc[1]:.2g}, certified bound = {all_acc[0]:.2g}")

Class 2 (Drusen) : nominal = 0.8, certified bound = 0.68
Classes 0, 1, 3  : nominal = 0.87, certified bound = 0.81
All Classes      : nominal = 0.85, certified bound = 0.77
