# Certified Fine-Tuning of a Classifier on the OCT-MNIST Dataset

In [1]:
%load_ext autoreload
%autoreload 2
import logging
import torch
import abstract_gradient_training as agt
from abstract_gradient_training import AGTConfig
from abstract_gradient_training import certified_training_utils as ct_utils
from models.deepmind import DeepMindSmall 
from datasets import oct_mnist
import opacus

In [2]:
# Opacus doesn't respect the logging handler installed on the AGT logger, so
# detach every handler currently attached to it.
logger = logging.getLogger("abstract_gradient_training")
for _handler in list(logger.handlers):
    logger.removeHandler(_handler)

## Fine-tune the model on the private Drusen data

In [64]:
# Configure a DP-SGD training set-up using Opacus to mimic the AGT training loop.
import opacus.accountants


# Fix the torch RNG seed so the run is repeatable.
torch.manual_seed(1)
device = "cuda:0"
batchsize = 5000
lr_decay = 0.3
lr_min = 0.001
dp_sgd_sigma = 0.962  # passed to Opacus as noise_multiplier
clipping = 1.0  # passed to Opacus as max_grad_norm
learning_rate = 0.3
n_epochs = 1

model = DeepMindSmall(1, 1)
criterion = torch.nn.BCELoss()
model = model.to(device)
# map_location remaps the checkpoint's tensors onto `device`, so loading does
# not depend on which device the checkpoint was saved from.
model.load_state_dict(torch.load(".models/medmnist.ckpt", map_location=device))

# Train loader excludes classes 0 and 1 (balanced); the test loader keeps only
# class 2, which the notebook refers to as Drusen.
dl_train, _ = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[0, 1], balanced=True)
_, dl_test_drusen = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[0, 1, 3])
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Wrap model / optimizer / loader for DP-SGD with an RDP accountant.
# Poisson sampling is disabled so the loader keeps its fixed batch size.
privacy_engine = opacus.PrivacyEngine(accountant="rdp")
model_private, optimizer_private, data_loader_private = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=dl_train,
    noise_multiplier=dp_sgd_sigma,
    max_grad_norm=clipping,
    poisson_sampling=False,
)

def get_lr(epoch):
    """Return the multiplicative learning-rate factor for scheduler step `epoch`.

    The factor decays as 1 / (1 + lr_decay * epoch) and is floored at
    lr_min / learning_rate. It is used as the lambda of a LambdaLR scheduler;
    because that scheduler is stepped once per batch in the training loop,
    `epoch` is effectively the number of batches processed so far.
    """
    decayed = 1 / (1 + lr_decay * epoch)
    floor = lr_min / learning_rate
    return floor if floor > decayed else decayed

# LambdaLR scales the optimizer's initial learning rate by get_lr(step).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer_private, get_lr)

for _ in range(n_epochs):
    # Iterate the loader returned by make_private so the batches seen in
    # training are the ones the privacy engine was configured with.
    for x, u in data_loader_private:
        # AGT only takes full batches
        if u.size(0) < batchsize:
            break
        u, x = u.to(device), x.to(device)
        output = model_private(x)
        loss = criterion(output.squeeze().float(), u.squeeze().float())
        # Backward and optimize
        optimizer_private.zero_grad()
        loss.backward()
        optimizer_private.step()
        # The learning rate is decayed once per batch, not once per epoch.
        scheduler.step()

# compute privacy guarantees
delta = 0.0001
epsilon = privacy_engine.accountant.get_epsilon(delta=delta)
print(f"(eps, delta)=({epsilon:.4g}, {delta:.1e})")

(eps, delta)=(3.914, 1.0e-04)


In [40]:
# Set up the AGT fine-tuning parameters. Values shared with the Opacus run
# (learning rate schedule, clipping, noise scale, device) are reused so the
# two training procedures are configured alike.
# NOTE(review): `model` is the same object that was handed to Opacus in the
# DP-SGD cell; if that cell ran first its weights may already be fine-tuned —
# confirm the intended starting point before comparing results.
config = AGTConfig(
    # optimisation schedule
    learning_rate=learning_rate,
    lr_decay=lr_decay,
    lr_min=lr_min,
    n_epochs=n_epochs,
    early_stopping=False,
    # privacy parameters
    k_private=10,
    clip_gamma=clipping,
    clip_method="clamp",
    dp_sgd_sigma=dp_sgd_sigma,
    noise_type="gaussian",
    # bounding methods
    forward_bound="interval",
    backward_bound="interval",
    # loss, batching and runtime
    loss="binary_cross_entropy",
    fragsize=1000,
    device=device,
    log_level="DEBUG",
)

# Fine-tune the model using abstract gradient training (keeping the
# convolutional layers fixed). The three results are presumably the lower,
# nominal and upper parameter sets — TODO confirm against the AGT API.
param_l, param_n, param_u = agt.privacy_certified_training(
    model,
    config,
    dl_train,
    dl_test_drusen,
    transform=ct_utils.propagate_conv_layers,
)

09/20/2024 18:02:04:DEBUG:	Privacy parameters: k_private=10, clip_gamma=1.0, dp_sgd_sigma=1.0
09/20/2024 18:02:04:DEBUG:	Bounding methods: forward=interval, backward=interval
09/20/2024 18:02:04:DEBUG:	Using Gaussian privacy-preserving noise (std 1)
09/20/2024 18:02:04:INFO:Starting epoch 1
09/20/2024 18:02:04:DEBUG:Initialising dataloader batchsize to 5000
09/20/2024 18:02:04:INFO:Training batch 1: Network eval bounds=(0.82, 0.82, 0.82), W0 Bound=0.0 
09/20/2024 18:02:05:INFO:Training batch 2: Network eval bounds=(0.96, 0.97, 0.98), W0 Bound=0.71 
09/20/2024 18:02:06:INFO:Training batch 3: Network eval bounds=(0.57, 0.68, 0.8 ), W0 Bound=1.31 
09/20/2024 18:02:07:DEBUG:Skipping batch 4 in epoch 1 (expected batchsize 5000, got 508)
09/20/2024 18:02:07:INFO:Final network eval: Network eval bounds=(0.62, 0.87, 0.99), W0 Bound=1.89 


In [34]:
# Build the dataloaders; the train dataloader is a mix of drusen and the
# "healthy" class. Test loaders cover drusen only, everything but drusen, and
# all classes.
dl_train, _ = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[0, 1], balanced=True)
_, dl_test_drusen = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[0, 1, 3])
_, dl_test_other = oct_mnist.get_dataloaders(batchsize, 1000, exclude_classes=[2])
_, dl_test_all = oct_mnist.get_dataloaders(batchsize, 1000)


def _accuracy_on_first_batch(dataloader):
    """Score the fine-tuned parameters on the first batch drawn from `dataloader`."""
    return agt.test_metrics.test_accuracy(
        param_n,
        param_l,
        param_u,
        *next(iter(dataloader)),
        model,
        ct_utils.propagate_conv_layers,
    )


# Evaluate the fine-tuned model on each test split.
drusen_acc = _accuracy_on_first_batch(dl_test_drusen)
other_acc = _accuracy_on_first_batch(dl_test_other)
all_acc = _accuracy_on_first_batch(dl_test_all)

# Index 1 is reported as the nominal accuracy and index 0 as the certified bound.
print("=========== Fine-tuned model accuracy + bounds ===========")
print(f"Class 2 (Drusen) : nominal = {drusen_acc[1]:.2g}, certified bound = {drusen_acc[0]:.2g}")
print(f"Classes 0, 1, 3  : nominal = {other_acc[1]:.2g}, certified bound = {other_acc[0]:.2g}")
print(f"All Classes      : nominal = {all_acc[1]:.2g}, certified bound = {all_acc[0]:.2g}")

Class 2 (Drusen) : nominal = 0.87, certified bound = 0.62
Classes 0, 1, 3  : nominal = 0.79, certified bound = 0.65
All Classes      : nominal = 0.81, certified bound = 0.64
