In [1]:
#!pip install -r requirements.txt
import numpy as np
import torch
from torchvision import transforms, datasets
from torchvision.utils import save_image

import unsplit.attacks as unsplit
from unsplit.models import *
from unsplit.util import *

import optimizers as opt



In [2]:
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.device_count())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

True
0
1


device(type='cuda')

In [3]:
split_layer = 1

In [18]:
dataset = 'mnist'

if dataset == 'mnist':
    trainset = datasets.MNIST('data/mnist', download=True, train=True, transform=transforms.ToTensor())
    testset = datasets.MNIST('data/mnist', download=True, train=False, transform=transforms.ToTensor())
    client, server, clone = MnistNet(), MnistNet(), MnistNet()
elif dataset == 'f_mnist':
    trainset = datasets.FashionMNIST('data/f_mnist', download=True, train=True, transform=transforms.ToTensor())
    testset = datasets.FashionMNIST('data/f_mnist', download=True, train=False, transform=transforms.ToTensor())
    client, server, clone = MnistNet(), MnistNet(), MnistNet()
elif dataset == 'cifar':
    trainset = datasets.CIFAR10('data/cifar', download=True, train=True, transform=transforms.ToTensor())
    testset = datasets.CIFAR10('data/cifar', download=True, train=False, transform=transforms.ToTensor())
    client, server, clone = CifarNet(), CifarNet(), CifarNet()

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=64)
testloader = torch.utils.data.DataLoader(testset, shuffle=True)

In [5]:
for images, labels in trainloader:
    for image in images:
        image = image.to(device)
    for label in labels:
        label = label.to(device)

## Learning without noise

In [31]:
client_opt = torch.optim.SGD(client.parameters(), lr=0.001)
server_opt = torch.optim.SGD(server.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

epochs = 50
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        client_opt.zero_grad()
        server_opt.zero_grad()

        pred = server(client(images, end=split_layer), start=split_layer+1)

        loss = criterion(pred, labels)
        loss.backward()
        running_loss += loss

        server_opt.step()
        client_opt.step()
    else:
        print(f'Epoch: {epoch} Loss: {running_loss / len(trainloader)} Acc: {get_test_score(client, server, testset, split=split_layer)}')

Epoch: 0 Loss: 2.302513599395752 Acc: 9.23076923076923
Epoch: 1 Loss: 2.2995636463165283 Acc: 9.994538503549972
Epoch: 2 Loss: 2.2964916229248047 Acc: 10.399562123700054
Epoch: 3 Loss: 2.292752981185913 Acc: 16.5848336061102
Epoch: 4 Loss: 2.2878165245056152 Acc: 19.102353585112205
Epoch: 5 Loss: 2.2807109355926514 Acc: 30.45379989065063
Epoch: 6 Loss: 2.2687489986419678 Acc: 36.862314865606145
Epoch: 7 Loss: 2.243864059448242 Acc: 53.87149917627677
Epoch: 8 Loss: 2.1818437576293945 Acc: 57.158590308370044
Epoch: 9 Loss: 1.9545646905899048 Acc: 68.17685589519651
Epoch: 10 Loss: 1.2179653644561768 Acc: 76.1878453038674
Epoch: 11 Loss: 0.7220318913459778 Acc: 81.5401419989077
Epoch: 12 Loss: 0.5681466460227966 Acc: 85.42576419213974
Epoch: 13 Loss: 0.4850541353225708 Acc: 86.66666666666667
Epoch: 14 Loss: 0.4297293722629547 Acc: 88.14773980154355
Epoch: 15 Loss: 0.38810059428215027 Acc: 89.62162162162163
Epoch: 16 Loss: 0.3556075692176819 Acc: 90.46575342465754
Epoch: 17 Loss: 0.32955154

## Learning with noise

In [16]:
import torch.distributions as dist

class CustomNormalDistribution(nn.Module):
    def __init__(self, shape, mean, std):
        super(CustomNormalDistribution, self).__init__()
        self.shape = shape
        self.mean = nn.Parameter(torch.tensor(mean))
        self.std = nn.Parameter(torch.tensor(std))

    def forward(self):
        return dist.Normal(self.mean, self.std)


def noisybatch(images, mean=0.0, std=1.0):
    shape = images.shape
    custom_dist = CustomNormalDistribution(shape, mean, std)
    sample = custom_dist().rsample()
    return sample

Find global $l2$-sensitivity 

Now let us adjust distribution parameter $\sigma$ for $(\varepsilon, \delta)$-DP.

Code is taken from https://github.com/BorjaBalle/analytic-gaussian-mechanism/blob/master/agm-example.py

In [20]:
from math import exp, sqrt
from scipy.special import erf

def calibrateAnalyticGaussianMechanism(epsilon, delta, GS, tol = 1.e-12):
    """ Calibrate a Gaussian perturbation for differential privacy using the analytic Gaussian mechanism of [Balle and Wang, ICML'18]

    Arguments:
    epsilon : target epsilon (epsilon > 0)
    delta : target delta (0 < delta < 1)
    GS : upper bound on L2 global sensitivity (GS >= 0)
    tol : error tolerance for binary search (tol > 0)

    Output:
    sigma : standard deviation of Gaussian noise needed to achieve (epsilon,delta)-DP under global sensitivity GS
    """

    def Phi(t):
        return 0.5*(1.0 + erf(float(t)/sqrt(2.0)))

    def caseA(epsilon,s):
        return Phi(sqrt(epsilon*s)) - exp(epsilon)*Phi(-sqrt(epsilon*(s+2.0)))

    def caseB(epsilon,s):
        return Phi(-sqrt(epsilon*s)) - exp(epsilon)*Phi(-sqrt(epsilon*(s+2.0)))

    def doubling_trick(predicate_stop, s_inf, s_sup):
        while(not predicate_stop(s_sup)):
            s_inf = s_sup
            s_sup = 2.0*s_inf
        return s_inf, s_sup

    def binary_search(predicate_stop, predicate_left, s_inf, s_sup):
        s_mid = s_inf + (s_sup-s_inf)/2.0
        while(not predicate_stop(s_mid)):
            if (predicate_left(s_mid)):
                s_sup = s_mid
            else:
                s_inf = s_mid
            s_mid = s_inf + (s_sup-s_inf)/2.0
        return s_mid

    delta_thr = caseA(epsilon, 0.0)

    if (delta == delta_thr):
        alpha = 1.0

    else:
        if (delta > delta_thr):
            predicate_stop_DT = lambda s : caseA(epsilon, s) >= delta
            function_s_to_delta = lambda s : caseA(epsilon, s)
            predicate_left_BS = lambda s : function_s_to_delta(s) > delta
            function_s_to_alpha = lambda s : sqrt(1.0 + s/2.0) - sqrt(s/2.0)

        else:
            predicate_stop_DT = lambda s : caseB(epsilon, s) <= delta
            function_s_to_delta = lambda s : caseB(epsilon, s)
            predicate_left_BS = lambda s : function_s_to_delta(s) < delta
            function_s_to_alpha = lambda s : sqrt(1.0 + s/2.0) + sqrt(s/2.0)

        predicate_stop_BS = lambda s : abs(function_s_to_delta(s) - delta) <= tol

        s_inf, s_sup = doubling_trick(predicate_stop_DT, 0.0, 1.0)
        s_final = binary_search(predicate_stop_BS, predicate_left_BS, s_inf, s_sup)
        alpha = function_s_to_alpha(s_final)
        
    sigma = alpha*GS/sqrt(2.0*epsilon)

    return sigma

In [31]:
sigma = calibrateAnalyticGaussianMechanism(0.5, 0.5, 1.0)
sigma

0.5909175992591167

In [19]:
client_opt = torch.optim.SGD(client.parameters(), lr=0.001)
server_opt = torch.optim.SGD(server.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

epochs = 50
for epoch in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        client_opt.zero_grad()
        server_opt.zero_grad()
        noise = noisybatch(images, mean=0.0, std=sigma)
        pred = server(client(torch.add(images, noise), end=split_layer), start=split_layer+1)

        loss = criterion(pred, labels)
        loss.backward()
        running_loss += loss

        server_opt.step()
        client_opt.step()
    else:
        print(f'Epoch: {epoch} Loss: {running_loss / len(trainloader)} Acc: {get_test_score(client, server, testset, split=split_layer)}')

Epoch: 0 Loss: 2.3023247718811035 Acc: 10.757409440175632
Epoch: 1 Loss: 2.3000195026397705 Acc: 9.764125068568294
Epoch: 2 Loss: 2.298835515975952 Acc: 11.95414847161572
Epoch: 3 Loss: 2.2971088886260986 Acc: 14.144736842105264
Epoch: 4 Loss: 2.295240879058838 Acc: 18.61353711790393
Epoch: 5 Loss: 2.2938835620880127 Acc: 18.489727928928374
Epoch: 6 Loss: 2.2917065620422363 Acc: 16.602528862012093
Epoch: 7 Loss: 2.2879931926727295 Acc: 19.493392070484582
Epoch: 8 Loss: 2.28412127494812 Acc: 26.464088397790054
Epoch: 9 Loss: 2.2770228385925293 Acc: 33.91589295466958
Epoch: 10 Loss: 2.2696893215179443 Acc: 38.58397365532382
Epoch: 11 Loss: 2.254446268081665 Acc: 39.33701657458563
Epoch: 12 Loss: 2.2243704795837402 Acc: 45.7347275729224
Epoch: 13 Loss: 2.1799306869506836 Acc: 52.05704882062534
Epoch: 14 Loss: 2.0922393798828125 Acc: 55.14425694066413
Epoch: 15 Loss: 1.9192321300506592 Acc: 61.06243154435926
Epoch: 16 Loss: 1.6574550867080688 Acc: 65.00274273176083
Epoch: 17 Loss: 1.388247