In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell execution and edits to the project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Put the parent directory on the import path so the `unsplit` package imported
# below can be resolved (presumably it is checked out next to this notebook's
# folder -- verify against the repository layout).
sys.path.append('../')
from unsplit.unsplit.models import MLP, CNN

In [3]:
import os
import torch
# Ask CUDA to enumerate GPUs by PCI bus id, so the index handed to
# `set_device` refers to the same physical card that `nvidia-smi` lists
# under that number.
os.environ["CUDA_DEVICE_ORDER"]='PCI_BUS_ID'

def set_device(device_no: int):
    """Choose the torch device for this notebook and report the choice.

    Args:
        device_no: Index of the CUDA device to use when CUDA is available.

    Returns:
        ``torch.device(f"cuda:{device_no}")`` if CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    # Guard clause: without CUDA there is nothing to select.
    if not torch.cuda.is_available():
        print("No GPU available, using the CPU instead.")
        return torch.device("cpu")

    selected = torch.device(f"cuda:{device_no}")
    print("There are %d GPU(s) available." % torch.cuda.device_count())
    print("We will use the GPU:", torch.cuda.get_device_name(device_no))
    return selected

In [None]:
# Use GPU index 0 when CUDA is available; `set_device` falls back to the CPU otherwise.
device = set_device(0)

In [7]:
# NOTE(review): several names used later in this notebook without a module
# prefix (e.g. `datasets`, `transforms`, `MLPMixer`, `get_test_acc`,
# `get_examples_by_class`, `display_cifar`, `normalize`, `compute_fid`) are not
# imported explicitly in any visible cell, so they presumably come from the two
# wildcard imports below -- confirm, and prefer explicit imports.
# `from models import *` may also rebind `MLP` / `CNN` imported earlier -- verify.
import unsplit.unsplit.attacks as attacks
from unsplit.unsplit.util import *
from models import *
from tqdm.auto import tqdm
import numpy as np

In [8]:
# Experiment configuration.
# `dataset_name` and `architecture` must be among the options accepted by
# `create_models_and_data` ("mnist" / "f_mnist" / "cifar10" and
# "mlp" / "cnn" / "mlp-mixer" respectively).
dataset_name = "cifar10"
architecture = "mlp-mixer"
batch_size = 64

# Intended number of training epochs.
n_epochs = 10
# Cut point of the split model: the client is run with `end=split_layer` and
# the server with `start=split_layer + 1`.
split_layer = 1

In [9]:
def create_models_and_data(dataset_name="mnist", architecture="mlp", batch_size=64, device="cuda:0", seed=0):
    """Build the client/server models and the train/test datasets.

    Args:
        dataset_name: One of "mnist", "f_mnist" or "cifar10".
        architecture: One of "mlp", "cnn" or "mlp-mixer".
        batch_size: Accepted for signature compatibility; not used here.
        device: Device both models are moved to.
        seed: Value passed to ``torch.manual_seed`` right before the models
            are instantiated.

    Returns:
        Tuple ``(client, server, trainset, testset)``. ``client`` and
        ``server`` are two separate instances of the selected model class.

    Raises:
        AssertionError: If ``dataset_name`` or ``architecture`` is not one of
            the supported options.
    """
    assert dataset_name in ["mnist", "f_mnist", "cifar10"], "Wrong dataset name. Valid options are 'mnist', 'f_mnist' and 'cifar10'."
    assert architecture in ["mlp", "cnn", "mlp-mixer"], "Wrong architecture name. Valid options are 'mlp', 'cnn' and 'mlp-mixer'."

    # Resolve the dataset class for the requested name.
    if dataset_name == "mnist":
        dataset_creator = datasets.MNIST
    elif dataset_name == "f_mnist":
        dataset_creator = datasets.FashionMNIST
    else:
        dataset_creator = datasets.CIFAR10

    # Resolve the model class for the requested architecture.
    if architecture == "mlp":
        model_creator = MLP
    elif architecture == "cnn":
        model_creator = CNN
    else:
        model_creator = MLPMixer

    # Train split first, then test split; both are downloaded on demand.
    data_root = f'data/{dataset_name}'
    trainset, testset = (
        dataset_creator(data_root, download=True, train=is_train, transform=transforms.ToTensor())
        for is_train in (True, False)
    )

    torch.manual_seed(seed)
    client = model_creator().to(device)
    server = model_creator().to(device)
    return client, server, trainset, testset

In [None]:
client, server, trainset, testset = create_models_and_data(dataset_name=dataset_name,
                                                           architecture=architecture,
                                                           device=device)

# Sanity check: both halves must live on the device selected by `set_device`.
# Comparing against `device.type` (rather than asserting `is_cuda`) keeps the
# notebook runnable on the CPU fallback path that `set_device` provides.
assert next(client.parameters()).device.type == device.type
assert next(server.parameters()).device.type == device.type

In [11]:
def train_client_server(client, server, trainset, testset, split_layer, n_epochs=10, device="cuda:0", batch_size=64):
    """Jointly train the client and server halves of a split model.

    For every batch the client is run up to ``split_layer``, its output is fed
    to the server starting at ``split_layer + 1``, and the cross-entropy loss
    of the server output is back-propagated through both halves. Each half has
    its own Adam optimizer (lr=0.001, amsgrad=True). After every epoch the mean
    training loss and the test accuracy reported by ``get_test_acc`` are
    printed.

    Args:
        client: Model exposing ``forward(x, end=...)``.
        server: Model exposing ``forward(x, start=...)``.
        trainset: Dataset used for training.
        testset: Dataset used for the per-epoch accuracy report.
        split_layer: Index of the last layer executed on the client.
        n_epochs: Number of passes over ``trainset``.
        device: Device the batches are moved to.
        batch_size: Batch size used by both data loaders.

    Returns:
        Tuple ``(client, server)`` -- the same objects, updated in place.
    """
    # Honour the caller's batch size (this used to be hard-coded to 64).
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=batch_size)
    testloader = torch.utils.data.DataLoader(testset, shuffle=True, batch_size=batch_size)

    client_opt = torch.optim.Adam(client.parameters(), lr=0.001, amsgrad=True)
    server_opt = torch.optim.Adam(server.parameters(), lr=0.001, amsgrad=True)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        running_loss = 0.0
        for images, labels in tqdm(trainloader):
            images, labels = images.to(device), labels.to(device)
            client_opt.zero_grad()
            server_opt.zero_grad()

            client_pred = client(images, end=split_layer)
            pred = server(client_pred, start=split_layer+1)

            loss = criterion(pred, labels)
            loss.backward()
            # Accumulate a plain Python float: summing the loss tensors
            # themselves keeps autograd history alive for the whole epoch and
            # makes the printed value a tensor repr instead of a number.
            running_loss += loss.item()

            server_opt.step()
            client_opt.step()

        # Avoid a ZeroDivisionError when the training loader yields no batches.
        mean_loss = running_loss / max(len(trainloader), 1)
        test_acc = get_test_acc(client, server, testloader, split=split_layer)
        print(f'Epoch: {epoch} Loss: {mean_loss} Acc: {test_acc}')
    return client, server

In [None]:
# Train with the values from the configuration cell instead of repeating the
# literals, so changing `n_epochs` / `batch_size` there actually takes effect.
client, server = train_client_server(client=client, server=server, trainset=trainset, testset=testset,
                                     split_layer=split_layer, n_epochs=n_epochs, device=device,
                                     batch_size=batch_size)

In [13]:
# Number of test examples requested per class.
COUNT = 1
# One entry per class label 0..9, each moved to the compute device.
inversion_targets = [
    get_examples_by_class(testset, class_idx, count=COUNT).to(device)
    for class_idx in range(10)
]

In [None]:
# Show the images the attack will try to reconstruct.
display_cifar(inversion_targets)

In [19]:
# Iteration budgets forwarded to `launch_attack` (presumably: outer rounds,
# input-optimisation steps per round, clone-model steps per round -- confirm
# against `attacks.model_inversion_stealing`).
main_iters = 200
input_iters = 20
model_iters = 20

# Regularisation weights forwarded to `launch_attack`.
lambda_l2 = 0.1
lambda_tv = 1.0

In [25]:
def launch_attack(inversion_targets, client, split_layer, clone_architecture="mlp",
                  main_iters=1000, input_iters=100, model_iters=100,
                  lambda_tv=0.1, lambda_l2=1, device="cuda:0"):
    """Run the model-inversion/stealing attack against every target.

    For each target the client's output at ``split_layer`` is computed without
    gradients and handed to ``attacks.model_inversion_stealing`` together with
    a clone model. The returned reconstruction is passed through ``normalize``
    and compared to the target with MSE.

    Note that a single clone instance is created up front and reused for all
    targets, so it is not reset between targets.

    Args:
        inversion_targets: Iterable of input tensors to reconstruct.
        client: Client-side model exposing ``forward(x, end=...)``.
        split_layer: Index of the last client layer (the cut layer).
        clone_architecture: Architecture of the attacker's clone model; one of
            "mlp", "cnn" or "mlp-mixer".
        main_iters, input_iters, model_iters: Iteration budgets forwarded to
            the attack routine.
        lambda_tv, lambda_l2: Regularisation weights forwarded to the attack
            routine.
        device: Device used for the clone and forwarded to the attack routine.

    Returns:
        Tuple ``(clone, reconstructed_images, reconstruction_losses,
        cut_layer_training_losses)`` where the three lists hold one entry per
        target: the normalized reconstruction, its MSE against the target as a
        float, and the second value returned by the attack routine (used as a
        loss history elsewhere in this notebook).

    Raises:
        AssertionError: If ``clone_architecture`` is not a supported option.
    """
    assert clone_architecture in ["mlp", "cnn", "mlp-mixer"], \
        "Wrong architecture name. Valid options are 'mlp', 'cnn' and 'mlp-mixer'."

    # Select the clone from the *argument*; the previous code read the global
    # `architecture`, silently ignoring `clone_architecture`.
    if clone_architecture == "mlp":
        clone = MLP()
    elif clone_architecture == "cnn":
        clone = CNN()
    else:
        clone = MLPMixer()
    # Place the clone next to the client output it will be fitted against.
    # This is a no-op if the attack routine moves the model itself.
    clone = clone.to(device)
    mse = torch.nn.MSELoss()

    reconstructed_images, reconstruction_losses = [], []
    cut_layer_training_losses = []
    for target in inversion_targets:
        # Obtain the client output the attacker gets to observe.
        with torch.no_grad():
            client_out = client(target, end=split_layer)

        # Perform the attack.
        target_size = target.size()
        reconstructed, cur_loss_arr = attacks.model_inversion_stealing(
            clone, split_layer, client_out, target_size,
            main_iters=main_iters, input_iters=input_iters, model_iters=model_iters,
            lambda_tv=lambda_tv, lambda_l2=lambda_l2, device=device
        )
        cut_layer_training_losses.append(cur_loss_arr)

        # Save the result.
        reconstructed = normalize(reconstructed)
        reconstructed_images.append(reconstructed)
        reconstruction_loss = mse(reconstructed, target)
        reconstruction_losses.append(reconstruction_loss.item())
    return clone, reconstructed_images, reconstruction_losses, cut_layer_training_losses

In [None]:
# Attack every inversion target using a clone of the same architecture as the
# trained client, with the iteration budgets and regularisation weights set in
# the hyper-parameter cell.
clone, reconstructed_images, reconstruction_losses, cut_layer_training_losses = launch_attack(
    inversion_targets, client, split_layer, clone_architecture=architecture,
    main_iters=main_iters, input_iters=input_iters, model_iters=model_iters,
    lambda_tv=lambda_tv, lambda_l2=lambda_l2, device=device
)

In [None]:
# Visual comparison: original targets first, then the attack's reconstructions.
display_cifar(inversion_targets)
display_cifar(reconstructed_images)

In [None]:
def compute_metrics(inversion_targets, reconstructed_images,
                    reconstruction_losses, cut_layer_training_losses,
                    use_fid=False):
    """Print summary metrics for an attack run.

    Args:
        inversion_targets: Original images; only used when ``use_fid`` is True.
        reconstructed_images: Reconstructions; only used when ``use_fid`` is
            True.
        reconstruction_losses: Sequence with one reconstruction MSE per target.
        cut_layer_training_losses: Sequence with one loss history per target;
            only the last element of each history is used.
        use_fid: When True, additionally print the value returned by
            ``compute_fid(inversion_targets, reconstructed_images)``.

    Returns:
        None. The metrics are printed to stdout.
    """
    reconstruction_mse = np.mean(reconstruction_losses)
    # Final cut-layer loss of every target, averaged over the targets.
    final_cut_layer_losses = [loss_arr[-1] for loss_arr in cut_layer_training_losses]
    cut_layer_mse = np.mean(final_cut_layer_losses)
    print(f"Reconstruction MSE: {reconstruction_mse:.3f}")
    print(f"Log10 of cut layer MSE: {np.log10(cut_layer_mse):.3f}")
    if use_fid:
        print(f"FID: {compute_fid(inversion_targets, reconstructed_images):.1f}")

In [None]:
# Print the aggregate reconstruction and cut-layer metrics (FID is left disabled).
compute_metrics(inversion_targets, reconstructed_images, reconstruction_losses, cut_layer_training_losses)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme()
fig, ax = plt.subplots(1, 2, figsize=(15, 6))

# Per-target mean of the cut-layer loss history.
# NOTE(review): this averages over *all* attack iterations, whereas
# `compute_metrics` reports only the final value of each history -- confirm
# which of the two the right-hand panel is meant to show.
cut_layer_mse = np.mean(cut_layer_training_losses, axis=1)
markers = ['o', 's']
titles = ['Reconstruction MSE', 'Cut Layer MSE']
data = [reconstruction_losses, cut_layer_mse]

for axis, series, marker, title in zip(ax, data, markers, titles):
    axis.plot(series, linewidth=2, marker=marker, markersize=6, markevery=2)
    axis.set_xlabel('class', fontsize=30)
    axis.set_ylabel('MSE', fontsize=30)
    axis.set_title(title, fontsize=30)
    axis.grid(True)
    axis.tick_params(labelsize=15)

# Build the figure title from the configuration so it cannot go stale when
# `architecture` or `dataset_name` are changed; unknown values are shown as-is.
arch_labels = {"mlp": "MLP", "cnn": "CNN", "mlp-mixer": "MLP-Mixer"}
dataset_labels = {"mnist": "MNIST", "f_mnist": "Fashion-MNIST", "cifar10": "CIFAR10"}
arch_label = arch_labels.get(architecture, architecture)
dataset_label = dataset_labels.get(dataset_name, dataset_name)

plt.subplots_adjust(wspace=0.3)
plt.suptitle(f'{arch_label} on {dataset_label}, Cut Layer = {split_layer}', fontsize=35, y=1.05)
plt.show()