In [None]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
import warnings
from privacygan import privacy_gan as pg
from privacygan.cifar import cifar_gan
from datetime import datetime

# Blanket-suppress warnings to keep the demo output readable.
# NOTE(review): this also hides deprecation/numerical warnings coming from torch
# and privacygan; narrow the filter if results look off.
warnings.filterwarnings("ignore")

# Seed the RNGs this notebook uses (NumPy for any global-RNG sampling, torch for
# weight init and training) so Restart & Run All gives repeatable results.
# NOTE: cuDNN kernels can still be non-deterministic on GPU.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__, "device:", device)


In [None]:
# Load CIFAR-10 via torchvision and pool its train+test images into one array.
# Only the raw `.data` arrays are read (uint8 images, channels-last per the
# NHWC -> NCHW transpose below), so no torchvision transform is needed.
train_ds = datasets.CIFAR10(root="./data", train=True, download=True)
test_ds = datasets.CIFAR10(root="./data", train=False, download=True)

# Scale pixels from [0, 255] to [-1, 1] and reorder NHWC -> NCHW for the torch
# models. In-place ops avoid extra full-size float copies of the dataset.
X_all = np.concatenate((train_ds.data, test_ds.data)).astype(np.float32)
X_all -= 127.5
X_all /= 127.5
X_all = np.ascontiguousarray(X_all.transpose(0, 3, 1, 2))

# Random member set X (the GANs' training data); everything else becomes the
# non-member comparison set X_comp used by the membership-inference attacks.
# A dedicated seeded Generator keeps the split reproducible across runs.
frac = 0.1
SPLIT_SEED = 0
n = int(frac * len(X_all))
idx = np.random.default_rng(SPLIT_SEED).permutation(len(X_all))
X = X_all[idx[:n]]
X_comp = X_all[idx[n:]]

# The split must be a partition of X_all and values must be in the model range.
assert len(X) + len(X_comp) == len(X_all)
assert X_all.min() >= -1.0 and X_all.max() <= 1.0

print('training set size:', X.shape)
print('test set size:', X_comp.shape)


In [None]:
# Train a standard GAN (no privGAN objective) on the member set X; it serves as
# the baseline for the attacks and plots below. Generator is built before the
# discriminator, as in a left-to-right argument evaluation.
base_generator = cifar_gan.CIFAR_Generator()
base_discriminator = cifar_gan.CIFAR_Discriminator()
generator, discriminator, dLosses, gLosses = pg.SimpGAN(
    X,
    generator=base_generator,
    discriminator=base_discriminator,
    epochs=1,
    batchSize=128,
)


In [None]:
# White-box membership-inference attack on the baseline GAN's discriminator,
# members X vs. non-members X_comp. Keep the returned value (presumably the
# attack accuracy — confirm in privacygan) and display it as the cell output.
Acc = pg.WBattack(X, X_comp, discriminator)
Acc

In [None]:
# Score members (X) and non-members (X_comp) with the baseline discriminator and
# compare the two score distributions.
# - eval() puts dropout/batch-norm layers (if the model has any) in inference
#   mode so scores are deterministic; the previous mode is restored afterwards.
# - Inputs go to the device holding the discriminator's weights, which is not
#   necessarily the global `device` (SimpGAN's placement is not visible here).
# - Mini-batching avoids one huge forward pass over all of X_comp.
SCORE_BATCH = 1024
disc_device = next(discriminator.parameters()).device
disc_was_training = discriminator.training
discriminator.eval()
try:
    with torch.no_grad():
        score_chunks = {"train": [], "test": []}
        for split, data in (("train", X), ("test", X_comp)):
            for start in range(0, len(data), SCORE_BATCH):
                batch = torch.as_tensor(
                    data[start:start + SCORE_BATCH], dtype=torch.float32, device=disc_device
                )
                # One score per image whether the output is shaped (B,) or (B, 1).
                score_chunks[split].append(
                    discriminator(batch).cpu().numpy().reshape(len(batch), -1)[:, 0]
                )
finally:
    discriminator.train(disc_was_training)
scores_train = np.concatenate(score_chunks["train"])
scores_test = np.concatenate(score_chunks["test"])

fig, ax = plt.subplots()
ax.hist(scores_train, color='r', alpha=0.5, label='train', density=True, bins=50)
ax.hist(scores_test, color='b', alpha=0.5, label='test', density=True, bins=50)
ax.set_xlabel('Discriminator probability')
ax.set_ylabel('Normed frequency')
ax.set_title('GAN')
ax.legend()
plt.show()


In [None]:
# Show 25 samples from the trained baseline generator. TargetShape is the
# per-image (H, W, C) shape handed to DisplayImages — presumably used to reshape
# the generator output for plotting; confirm in privacygan.
pg.DisplayImages(
    generator,
    NoImages=25,
    figSize=(5, 5),
    TargetShape=(32, 32, 3),
)

In [None]:
### Test privGAN with N_REPS generator/discriminator pairs
N_REPS = 2
PRIVACY_RATIO = 1.0

# Build all generators, then all discriminators, then the private discriminator
# (same construction order as a hand-written list, so seeded init is unchanged).
generators = [cifar_gan.CIFAR_Generator() for _ in range(N_REPS)]
discriminators = [cifar_gan.CIFAR_Discriminator() for _ in range(N_REPS)]
# OutSize has to track N_REPS (both were hard-coded as 2 before) — presumably one
# output per generator/discriminator pair; confirm in privacygan.cifar.
pDisc = cifar_gan.CIFAR_DiscriminatorPrivate(OutSize=N_REPS)

(generators, discriminators, _, dLosses, dpLosses, gLosses) = pg.privGAN(
    X,
    epochs=1,
    disc_epochs=1,
    batchSize=128,
    generators=generators,
    discriminators=discriminators,
    pDisc=pDisc,
    privacy_ratio=PRIVACY_RATIO,
)


In [None]:
# White-box membership-inference attack against the privGAN discriminators
# (members X vs. non-members X_comp); keep the result and show it.
Acc_priv = pg.WBattack_priv(X, X_comp, discriminators)
Acc_priv

In [None]:
# NOTE(review): repeats exactly the privGAN white-box attack from the previous
# cell. Keep it only if the intent is to observe run-to-run variation;
# otherwise it is redundant work and can be removed.
pg.WBattack_priv(X, X_comp, discriminators)

In [None]:
# Score members (X) and non-members (X_comp) with the first privGAN
# discriminator and compare the two score distributions.
# - eval() puts dropout/batch-norm layers (if the model has any) in inference
#   mode so scores are deterministic; the previous mode is restored afterwards.
# - Inputs go to the device holding the discriminator's weights, which is not
#   necessarily the global `device` (privGAN's placement is not visible here).
# - Mini-batching avoids one huge forward pass over all of X_comp.
SCORE_BATCH = 1024
priv_disc = discriminators[0]
priv_device = next(priv_disc.parameters()).device
priv_was_training = priv_disc.training
priv_disc.eval()
try:
    with torch.no_grad():
        score_chunks = {"train": [], "test": []}
        for split, data in (("train", X), ("test", X_comp)):
            for start in range(0, len(data), SCORE_BATCH):
                batch = torch.as_tensor(
                    data[start:start + SCORE_BATCH], dtype=torch.float32, device=priv_device
                )
                # One score per image whether the output is shaped (B,) or (B, 1).
                score_chunks[split].append(
                    priv_disc(batch).cpu().numpy().reshape(len(batch), -1)[:, 0]
                )
finally:
    priv_disc.train(priv_was_training)
scores_train = np.concatenate(score_chunks["train"])
scores_test = np.concatenate(score_chunks["test"])

fig, ax = plt.subplots()
ax.hist(scores_train, color='r', alpha=0.5, label='train', density=True, bins=50)
ax.hist(scores_test, color='b', alpha=0.5, label='test', density=True, bins=50)
ax.set_xlabel('Discriminator probability')
ax.set_ylabel('Normed frequency')
ax.set_title('privGAN')
ax.legend()
plt.show()


In [None]:
# Same member/non-member score histogram for the first privGAN discriminator,
# labelled with the privacy_ratio (1.0) used in training.
# The earlier version called a Keras-style `.predict`, which a torch module
# (as the discriminators are used elsewhere in this notebook) does not provide.
# Scores are computed with plain torch inference instead: eval mode (restored
# afterwards), on the weights' device, in mini-batches.
SCORE_BATCH = 1024
ratio_disc = discriminators[0]
ratio_device = next(ratio_disc.parameters()).device
ratio_was_training = ratio_disc.training
ratio_disc.eval()
try:
    with torch.no_grad():
        ratio_chunks = {"member": [], "nonmember": []}
        for split, data in (("member", X), ("nonmember", X_comp)):
            for start in range(0, len(data), SCORE_BATCH):
                batch = torch.as_tensor(
                    data[start:start + SCORE_BATCH], dtype=torch.float32, device=ratio_device
                )
                # Column 0 of the output, as the original `[:, 0]` selected.
                ratio_chunks[split].append(
                    ratio_disc(batch).cpu().numpy().reshape(len(batch), -1)[:, 0]
                )
finally:
    ratio_disc.train(ratio_was_training)
member_scores = np.concatenate(ratio_chunks["member"])
nonmember_scores = np.concatenate(ratio_chunks["nonmember"])

fig, ax = plt.subplots()
ax.hist(member_scores, color='r', alpha=0.5, label='train', density=True, bins=50)
ax.hist(nonmember_scores, color='b', alpha=0.5, label='test', density=True, bins=50)
ax.set_xlabel('Discriminator probability')
ax.set_ylabel('Normed frequency')
ax.set_title('privGAN (1.0)')
ax.legend()
plt.show()

In [None]:
# NOTE(review): third identical run of the privGAN white-box attack; keep only
# if run-to-run variation is of interest.
pg.WBattack_priv(X, X_comp, discriminators)