In [None]:
import argparse
import itertools
import numpy as np
import pandas as pd
import os
import pickle
import random
import time

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as func
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch.optim.optimizer import required
from torch.autograd import Variable
from torch.autograd import Function

from bayesian_privacy_accountant import BayesianPrivacyAccountant

In [None]:
# ---- Experiment configuration --------------------------------------------
# Options are declared argparse-style so the notebook can also run as a script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='mnist | cifar10 | svhn')
parser.add_argument('--dataroot', default='data', help='path to dataset')
parser.add_argument('--batchSize', type=int, default=512, help='input batch size')
parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
parser.add_argument('--nClasses', type=int, default=10, help='number of labels (classes)')
parser.add_argument('--nChannels', type=int, default=3, help='number of colour channels')
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--filterSize', type=int, default=5)
parser.add_argument('--n_epochs', type=int, default=2, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.001')
parser.add_argument('--C', type=float, default=1.0, help='embedding L2-norm bound, default=1.0')
# sigma is multiplied by C and used as the standard deviation of the Gaussian
# noise added to the gradients in train().
parser.add_argument('--sigma', type=float, default=0.8, help='noise standard deviation (relative to C), default=0.8')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--outf', default='output', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, default=2560, help='manual seed for reproducibility')

# parse_known_args ignores unrecognised argv entries (such as the ones a
# notebook kernel is started with) instead of aborting.
opt, unknown = parser.parse_known_args()

# Best effort: a missing output folder only matters once a checkpoint is saved,
# so report the problem instead of silently swallowing it or crashing here.
try:
    os.makedirs(opt.outf, exist_ok=True)
except OSError as err:
    print("Warning: could not create output folder %r: %s" % (opt.outf, err))

# Device index used by every `.cuda(gpu_id)` call in the notebook. Defined
# unconditionally so later cells never hit a NameError.
gpu_id = 0
if torch.cuda.is_available():
    opt.cuda = True
    opt.ngpu = 1
    # Prefer device 1 (the original choice) but fall back to device 0 on
    # single-GPU machines, where index 1 does not exist.
    gpu_id = min(1, torch.cuda.device_count() - 1)
    print("Using CUDA: gpu_id = %d" % gpu_id)

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
# Seed every random number generator the notebook imports.
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)

cudnn.benchmark = True

In [None]:
class View(nn.Module):
    """
        Reshaping layer.

        Stores a target shape at construction time and applies it with
        ``Tensor.view`` in the forward pass, so a reshape can be placed
        between the layers of an ``nn.Sequential``.
    """

    def __init__(self, *shape):
        super().__init__()
        # Kept as the tuple of positional arguments, e.g. View(-1, 4) -> (-1, 4).
        self.shape = shape

    def forward(self, input):
        target_shape = self.shape
        return input.view(target_shape)

In [None]:
import torch.nn.functional as F

# Kernel size of the two convolutions in SimpleConvNet.
filterSize = 5
# Width / height of the feature map that SimpleConvNet flattens for its classifier.
w_out, h_out = 5, 5

def compute_dim_out(num_conv_layers, dim_in, kernel_size=opt.filterSize, stride=2, padding=0, dilation=1):
    """
        Spatial output size after a stack of identical convolution layers.

        Applies the usual convolution output-size formula
        ``(dim + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1``
        (truncated to an integer) `num_conv_layers` times, starting from `dim_in`.

        Args:
            num_conv_layers: number of layers the formula is applied for.
            dim_in: input size along one spatial dimension.
            kernel_size: kernel size (defaults to ``opt.filterSize``).
            stride, padding, dilation: convolution hyper-parameters.

        Returns:
            The resulting size as a Python int.
    """
    dim_out = dim_in
    for _ in range(num_conv_layers):
        # The built-in int replaces np.int, which was only an alias for it and
        # no longer exists in current NumPy releases.
        dim_out = int((dim_out + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)
    return dim_out

class SimpleConvNet(nn.Module):
    """
        Small convolutional classifier.

        Two Conv -> ReLU -> pool/batch-norm stages followed by a three-layer
        fully connected head producing `opt.nClasses` outputs. Layer sizes
        come from `opt` and the module-level `filterSize`, `w_out`, `h_out`.
    """

    def __init__(self):
        super().__init__()

        width = opt.ndf
        flat_features = width * w_out * h_out
        hidden_units = 384

        # Convolutional feature extractor.
        feature_layers = [
            nn.Conv2d(opt.nChannels, width, filterSize),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(width),
            nn.Conv2d(width, width, filterSize),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(width),
            nn.MaxPool2d(2, 2),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Fully connected head; View flattens the feature map first.
        head_layers = [
            View(-1, flat_features),
            nn.Linear(flat_features, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, opt.nClasses),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class scores produced by the head for input batch `x`."""
        features = self.conv(x)
        return self.classifier(features)

In [None]:
def validate(student, valset=None):
    """
        Compute the mean per-batch cross-entropy loss of `student` on `valset`.

        Args:
            student: the network to evaluate.
            valset: iterable of (images, labels) batches, or None.

        Returns:
            The cross-entropy loss averaged over the batches that were
            actually iterated; 0 when `valset` is None or yields no batch.
    """
    # The original dereferenced valset.dataset before this check, so the
    # default valset=None raised AttributeError.
    if valset is None:
        return 0

    total_loss = 0.0
    n_batches = 0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for images, labels in valset:
            if opt.cuda:
                images = images.cuda(gpu_id)
                labels = labels.cuda(gpu_id)

            outputs = student(images)
            total_loss += torch.nn.functional.cross_entropy(outputs, labels).item()
            # Count real batches instead of estimating them as
            # len(dataset) / batchSize + 1, which is not the batch count.
            n_batches += 1

    if n_batches == 0:
        return 0
    return total_loss / n_batches

In [None]:
def test(testloader, net):
    """
        Compute test accuracy.

        Args:
            testloader: iterable of (images, labels) batches.
            net: the network to evaluate. Its train/eval mode is left
                untouched; train() switches it to eval mode before calling.

        Returns:
            Accuracy in percent (0.0 when the loader yields no samples).
            The same value is also printed.
    """
    correct = 0
    total = 0

    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        for images, labels in testloader:
            if opt.cuda:
                images = images.cuda(gpu_id)
                labels = labels.cuda(gpu_id)

            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            # Labels are flattened and folded into [0, nClasses) exactly as
            # train() does for its targets.
            targets = labels.long().view(-1) % opt.nClasses
            total += labels.size(0)
            correct += (predicted == targets).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100 * float(correct) / total if total else 0.0
    print('Accuracy of the network on test images: %f %%' % accuracy)
    return accuracy

In [None]:
def sparsify_update(params, p, use_grad_field=True):
    """
        Randomly zero entries of parameters or of their gradients, in place.

        A boolean mask is drawn with ``bernoulli_(1 - p)``; entries where the
        mask is True are set to 0, so each entry is kept with probability `p`.

        Args:
            params: iterable of tensors / parameters; None entries are skipped.
            p: probability of keeping an entry (p=1.0 leaves everything intact).
            use_grad_field: if True, a fresh mask is drawn for every parameter
                that has a gradient and applied to ``param.grad``. If False,
                one mask is drawn from the first tensor and applied to the
                ``.data`` of every tensor (they must therefore be compatible
                with that mask's shape).

        Returns:
            The last mask that was drawn, or None if no mask was drawn
            (empty `params` or only None entries).
    """
    # Initialised up front so that an empty `params` returns None instead of
    # raising UnboundLocalError at the return statement.
    idx = None
    init = True
    for param in params:
        if param is None:
            continue
        if init:
            # NOTE(review): in grad mode `init` never becomes False, so this
            # mask is drawn for every parameter and then replaced below. It is
            # kept so the sequence of random draws is unchanged.
            idx = torch.zeros_like(param, dtype=torch.bool)
            idx.bernoulli_(1 - p)
        if use_grad_field:
            if param.grad is not None:
                idx = torch.zeros_like(param, dtype=torch.bool)
                idx.bernoulli_(1 - p)
                param.grad.data[idx] = 0
        else:
            # From now on reuse the mask drawn for the first tensor.
            init = False
            param.data[idx] = 0
    return idx

In [None]:
def train(trainloader, testloader, student, n_epochs=25, lr=0.0001, accountant=None):
    """
        Train `student` with Adam, optionally with noisy gradients and privacy accounting.

        When `accountant` is given, every step additionally
          * computes `num_subbatch` leave-one-out gradient estimates of the batch,
          * adds Gaussian noise with standard deviation ``opt.sigma * opt.C``
            to every parameter gradient before the optimizer step,
          * passes pairs of the gradient estimates to ``accountant.accumulate``.

        Args:
            trainloader: loader of (inputs, labels) training batches; must expose `.dataset`.
            testloader: loader passed to test() after every epoch.
            student: the network to train.
            n_epochs: number of passes over `trainloader`.
            lr: Adam learning rate.
            accountant: optional object providing ``accumulate(...)`` and
                ``get_privacy(target_delta=...)``.

        Returns:
            (student moved to the CPU, list with the test accuracy after each epoch).
    """
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.Adam(student.parameters(), lr=lr)

    if opt.cuda:
        student = student.cuda(gpu_id)
        criterion = criterion.cuda(gpu_id)

    accuracies = []

    # Keep-probability handed to sparsify_update; 1.0 means nothing is zeroed.
    sampling_prob = 1.0

    max_grad_norm = opt.C
    sigma = opt.sigma * max_grad_norm

    num_subbatch = 8   # leave-one-out gradient estimates per batch
    save_step = 100    # checkpoint every `save_step` epochs

    # Parameters the per-sample gradient estimates are taken with respect to.
    trainable_params = [p for p in student.parameters() if p.requires_grad]

    for epoch in range(n_epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for inputs, labels in trainloader:
            if opt.cuda:
                inputs = inputs.cuda(gpu_id)
                labels = labels.cuda(gpu_id)

            batch_size = len(inputs)
            targets = labels.long().view(-1) % opt.nClasses

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass; `loss` holds one value per sample (reduction='none')
            outputs = student(inputs)
            loss = criterion(outputs, targets)
            mean_loss = loss.mean()

            if accountant:
                grads_est = []
                for j in range(num_subbatch):
                    # Gradient of the mean loss with sample j left out.
                    keep = np.delete(np.arange(batch_size), j)
                    grad_sample = torch.autograd.grad(loss[keep].mean(), trainable_params, retain_graph=True)
                    with torch.no_grad():
                        grads_est.append(torch.cat([g.view(-1) for g in grad_sample]))
                with torch.no_grad():
                    sparsify_update(grads_est, p=sampling_prob, use_grad_field=False)

            mean_loss.backward()
            running_loss += mean_loss.item()

            if accountant:
                # Perturb every available gradient with Gaussian noise.
                for group in optimizer.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad.data += torch.randn_like(p.grad) * sigma
                sparsify_update(student.parameters(), p=sampling_prob)

            optimizer.step()

            if accountant:
                with torch.no_grad():
                    q = batch_size / len(trainloader.dataset)
                    # NOTE:
                    # Using combinations within a set of gradients (like below)
                    # does not actually produce samples from the correct distribution
                    # (for that, we need to sample pairs of gradients independently).
                    # However, the difference is not significant, and it speeds up computations.
                    grad_pairs = list(zip(*itertools.combinations(grads_est, 2)))
                    accountant.accumulate(
                        ldistr=(torch.stack(grad_pairs[0]), sigma),
                        rdistr=(torch.stack(grad_pairs[1]), sigma),
                        q=q,
                        steps=1,
                    )
                    # NOTE(review): this value is overwritten at the end of the
                    # epoch; the call is kept in case get_privacy updates
                    # accountant state - confirm before removing.
                    running_eps = accountant.get_privacy(target_delta=1e-5)

        # print training stats every epoch
        running_eps = accountant.get_privacy(target_delta=1e-5) if accountant else None
        print("Epoch: %d/%d. Loss: %.3f. Privacy (𝜀,𝛿): %s" %
              (epoch + 1, n_epochs, running_loss / len(trainloader), running_eps))

        student.eval()
        acc = test(testloader, student)
        accuracies += [acc]
        student.train()
        print("Test accuracy is %d %%" % acc)
        if (epoch + 1) % save_step == 0:
            torch.save(student.state_dict(), '%s/private_net_epoch_%d.pth' % (opt.outf, epoch + 1))
            # Context manager so the file handle is closed (the original leaked it).
            with open('%s/bayes_accountant_epoch_%d' % (opt.outf, epoch + 1), 'wb') as accountant_file:
                pickle.dump(accountant, accountant_file)

    print('Finished Training')
    return student.cpu(), accuracies

In [None]:
# transformations applied to data
# Training images are resized to 128x128, randomly flipped / colour-jittered
# and normalised with fixed per-channel statistics.
transform = transforms.Compose([transforms.Resize(size=(128, 128)),
                                transforms.RandomHorizontalFlip(), 
                                transforms.ColorJitter(),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# Test images get the same resize / normalisation but no random augmentation.
transform_test = transforms.Compose([transforms.Resize(size=(128, 128)),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# switch datasets
data_dir = opt.dataroot + os.sep + opt.dataset
if opt.dataset == 'mnist':
    # NOTE(review): the Normalize statistics above have three channels while
    # MNIST images are single-channel - confirm this branch runs as intended.
    trainset = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
    # The test split uses transform_test (the original applied the randomly
    # augmenting training transform here).
    testset = torchvision.datasets.MNIST(root=data_dir, train=False, download=True, transform=transform_test)
elif opt.dataset == 'cifar10':
    trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform_test)
    # CIFAR100 is stored under the same root folder as CIFAR10.
    trainset100 = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=transform)
    testset100 = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform_test)
elif opt.dataset == 'svhn':
    trainset = torchvision.datasets.SVHN(root=data_dir, split='train', download=True, transform=transform)
    # Same fix as for MNIST: no random augmentation on the test split.
    testset = torchvision.datasets.SVHN(root=data_dir, split='test', download=True, transform=transform_test)
else:
    # Fail here with a clear message rather than with a NameError below.
    raise ValueError("Unknown dataset %r (expected mnist | cifar10 | svhn)" % opt.dataset)

# initialise data loaders
# drop_last=True keeps every training batch at exactly opt.batchSize samples.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batchSize, shuffle=True, num_workers=2, drop_last=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batchSize, shuffle=False, num_workers=2)
if opt.dataset == 'cifar10':
    trainloader100 = torch.utils.data.DataLoader(trainset100, batch_size=opt.batchSize, shuffle=True, num_workers=2, drop_last=True)
    testloader100 = torch.utils.data.DataLoader(testset100, batch_size=opt.batchSize, shuffle=False, num_workers=2)

In [None]:
# fine-tune the last layer of a pretrained VGG16 with privacy accounting and measure time
start = time.time()

netS = torchvision.models.vgg16(pretrained=True)
netS.train()
# Disable updates for all pretrained layers
for p in netS.parameters():
    p.requires_grad = False
# Replace the final classifier layer. A freshly created nn.Linear has
# requires_grad=True, so it is the only part of the network that is trained.
# (The original also built nn.Parameter copies of the old layer's first rows in
# a local list, but never attached them to the model; that dead code is removed.)
netS.classifier[-1] = nn.Linear(4096, opt.nClasses)
# NOTE(review): nothing in this notebook reads `aux_logits`; kept from the original.
netS.aux_logits = False

total_steps = opt.n_epochs * len(trainloader)
bayes_accountant = BayesianPrivacyAccountant(powers=32, total_steps=total_steps)
# opt.lr defaults to the previously hard-coded 0.001, so --lr is now honoured.
netS, accs = train(trainloader, testloader, netS, lr=opt.lr, n_epochs=opt.n_epochs, accountant=bayes_accountant)

stop = time.time()
print("Time elapsed: %f" % (stop - start))

In [None]:
# Report the accumulated privacy estimate for delta = 1e-10.
final_privacy = bayes_accountant.get_privacy(target_delta=1e-10)
print("Epsilon = ", final_privacy)

In [None]:
from scipy.spatial.distance import pdist

loss_fn = nn.CrossEntropyLoss(reduction='none')

netS.cuda(gpu_id)
loss_fn.cuda(gpu_id)


def pairwise_grad_dists(loader, n_batches=10, num_subbatch=100):
    """
        Pairwise distances between leave-one-out gradients of `netS`.

        For each of the first `n_batches` batches of `loader`, computes
        `num_subbatch` gradients of the mean loss with one sample left out,
        rescales each so its L2 norm is at most ``opt.C`` and collects the
        pairwise distances between them (``scipy.spatial.distance.pdist``).

        Args:
            loader: loader of (inputs, labels) batches.
            n_batches: number of batches to process.
            num_subbatch: number of leave-one-out gradients per batch; every
                batch must contain at least this many samples.

        Returns:
            List with one condensed distance vector per processed batch.
    """
    trainable_params = [p for p in netS.parameters() if p.requires_grad]
    dists = []
    for batch_no, (inputs, labels) in enumerate(loader):
        if batch_no >= n_batches:
            break
        inputs = inputs.cuda(gpu_id)
        labels = labels.cuda(gpu_id)
        netS.zero_grad()
        outputs = netS(inputs)
        loss = loss_fn(outputs, labels)

        # Index over the samples actually in this batch rather than assuming
        # it holds opt.batchSize samples.
        sample_idx = np.arange(len(loss))
        grads_est = []
        for j in range(num_subbatch):
            grad_sample = torch.autograd.grad(loss[np.delete(sample_idx, j)].mean(), trainable_params, retain_graph=True)
            with torch.no_grad():
                grad_sample = torch.cat([g.view(-1) for g in grad_sample])
                # Clip the flattened gradient to an L2 norm of at most opt.C.
                grad_sample /= max(1.0, grad_sample.norm().item() / opt.C)
                grads_est.append(grad_sample)
        with torch.no_grad():
            grads_est = torch.stack(grads_est)
        dists.append(pdist(grads_est.cpu().numpy()))
    return dists


# Same measurement on training and on test batches.
dists_train = pairwise_grad_dists(trainloader)
dists_test = pairwise_grad_dists(testloader)

In [None]:
# Combine the per-batch distance vectors into one array per data split.
dists_train = np.stack(dists_train).squeeze()
dists_test = np.stack(dists_test).squeeze()

# Overlay both histograms on shared bins so they are directly comparable.
hist_bins = np.arange(0, 0.2, 0.005)
for split_dists, split_label in ((dists_train, 'Train'), (dists_test, 'Test')):
    plt.hist(split_dists.flatten(), bins=hist_bins, label=split_label, alpha=0.5)
plt.legend()
plt.xlabel(r'Distance')
plt.ylabel(r'Number of samples')
plt.title(r'Pairwise gradient distances distribution, CIFAR10')
plt.savefig('grad_dist_histogram_cifar10.pdf', format='pdf', bbox_inches='tight')

In [None]:
from scipy.stats import ttest_rel, ttest_ind, levene

# Compare the train and test distance samples with three tests:
# paired t-test, independent t-test and Levene's test.
train_flat = dists_train.flatten()
test_flat = dists_test.flatten()
for stat_test in (ttest_rel, ttest_ind, levene):
    print(stat_test(train_flat, test_flat))