In [None]:
# local settings

import os, sys
from copy import deepcopy

import matplotlib.pyplot as plt

# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Make the parent of the working directory importable.
# "*.ipynb" does not need to exist: realpath() only resolves it against the
# current working directory, so `current` ends up being that directory.
current = os.path.dirname(os.path.realpath("*.ipynb"))
parent = os.path.dirname(current)
# Guarded append: re-running this cell must not pile up duplicate entries.
if parent not in sys.path:
    sys.path.append(parent)
# %cd ..

In [None]:
# online settings (papermill execution)

# Resolution applied to every figure created after this cell runs.
FIGURE_DPI = 300
plt.rcParams['figure.dpi'] = FIGURE_DPI

In [None]:
# Parameters
# Tunable knobs for the whole notebook (presumably the cell papermill overrides
# when the notebook is executed online -- confirm the cell tag).

# Number of epochs handed to train_dp_model for every optimizer.
num_epochs = 10
# False: inputs are projected onto 60 PCA components before training.
# True: PCA is skipped and the model is built with in_features=784.
no_pca = False
# Forwarded as noise_scale to the accountant and to both optimizers.
noise_scale = 4

# Forwarded as max_grad_norm to both optimizers.
max_grad_norm = 4
# Batch size of both data loaders and the lot size given to the optimizers;
# recomputed from q when q is not None.
lot_size = 600
# Passed to LinearNet as `hidden`.
hidden_size = 1000
# Optional ratio: when set, lot_size becomes int(q * len(train_data)).
q = None

In [None]:
from mnist import *

In [None]:
# Choose the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device_name = 'cuda'
elif torch.backends.mps.is_available():
    device_name = 'mps'
else:
    device_name = 'cpu'
device = torch.device(device_name)

print(f'Using device: {device}')

In [None]:
%%capture

# data loaders
train_data = datasets.MNIST(
    root='data', download=True, train=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(
    root='data', download=True, train=False, transform=transforms.ToTensor())

if not no_pca:
    # apply PCA to the dataset (as done in the paper)
    X_train = train_data.data.reshape(len(train_data), -1)
    X_test = test_data.data.reshape(len(test_data), -1)

    A = torch.cat([X_train, X_test]).float()
    pca_dim = 60
    _, _, V = torch.pca_lowrank(A, q=pca_dim)

    res = torch.matmul(A, V)

    X_train_pca_tensor = res[:60000]
    X_test_pca_tensor = res[60000:]
    y_train = train_data.targets
    y_test = test_data.targets

    # create torch datasets
    train_data = TensorDataset(X_train_pca_tensor, y_train)
    test_data = TensorDataset(X_test_pca_tensor, y_test)

# training settings
lot_size = lot_size if q is None else int(q * len(train_data))  # (L)

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=lot_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=lot_size, shuffle=False)

In [None]:
# Accountant template; every training run below receives its own deep copy.
# q is the ratio of the lot size to the number of training examples.
sampling_ratio = lot_size / len(train_data)
accountant = accountants.ModifiedMomentsAccountant(noise_scale, q=sampling_ratio)

# loss function
criterion = nn.CrossEntropyLoss()

In [None]:
from collections import defaultdict

# Metric logs for every training run, keyed by run name.
# Looking up an unknown name creates (and stores) a fresh empty dict.
all_logger: defaultdict = defaultdict(dict)

In [0]:
# differentially private optimizer: PIAdam
# A fresh model is built for this run; its input width follows the PCA switch.
input_dim = 784 if no_pca else pca_dim
model = LinearNet(in_features=input_dim, hidden=hidden_size).to(device)
optimizer = optim.PIAdam(
    model.named_parameters(),
    lot_size,
    lr=1e-4,
    noise_scale=noise_scale,
    max_grad_norm=max_grad_norm,
)

# One independent empty list per logged metric.
metric_names = ('loss', 'total_loss', 'accuracy', 'total_accuracy', 'total_val_accuracy', 'epsilon')
all_logger['PIAdam'] = {metric: [] for metric in metric_names}

# deepcopy gives this run its own accountant instead of sharing the template.
train_dp_model(
    model, criterion, optimizer, num_epochs, train_loader, test_loader,
    device=device,
    logger=all_logger['PIAdam'],
    accountant=deepcopy(accountant),
)

In [0]:
# Accuracy curves for the PIAdam run.
logger = all_logger['PIAdam']

fig, ax = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
left_ax, right_ax = ax

# Left panel: the raw 'accuracy' series.
left_ax.set_title('accuracy')
left_ax.plot(logger['accuracy'])

# Right panel: per-epoch accuracy, train first and then validation.
right_ax.set_title('per epoch accuracy')
for series, label in (('total_accuracy', 'train'), ('total_val_accuracy', 'val')):
    right_ax.plot(logger[series], label=label)
right_ax.legend()
plt.show()

In [None]:
# differentially private optimizer: DPSGD
# A fresh model is built for this run; its input width follows the PCA switch.
input_dim = 784 if no_pca else pca_dim
model = LinearNet(in_features=input_dim, hidden=hidden_size).to(device)
optimizer = optim.DPSGD(
    model.named_parameters(),
    lot_size,
    lr=0.005,
    noise_scale=noise_scale,
    max_grad_norm=max_grad_norm,
)

# The run is logged under the optimizer's class name.
run_name = optimizer.__class__.__name__
metric_names = ('loss', 'total_loss', 'accuracy', 'total_accuracy', 'total_val_accuracy', 'epsilon')
all_logger[run_name] = {metric: [] for metric in metric_names}

# deepcopy gives this run its own accountant instead of sharing the template.
train_dp_model(
    model, criterion, optimizer, num_epochs, train_loader, test_loader,
    device=device,
    logger=all_logger[run_name],
    accountant=deepcopy(accountant),
)

In [None]:
# Accuracy curves for the run logged under the current optimizer's class name.
logger = all_logger[optimizer.__class__.__name__]

fig, ax = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
left_ax, right_ax = ax

# Left panel: the raw 'accuracy' series.
left_ax.set_title('accuracy')
left_ax.plot(logger['accuracy'])

# Right panel: per-epoch accuracy, train first and then validation.
right_ax.set_title('per epoch accuracy')
for series, label in (('total_accuracy', 'train'), ('total_val_accuracy', 'val')):
    right_ax.plot(logger[series], label=label)
right_ax.legend()
plt.show()

In [None]:
# Compare the optimizers: validation accuracy against the epoch index.
fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# The loop variable is deliberately not called `logger`, so the notebook-level
# `logger` set by the earlier plotting cells is left untouched.
for name, run_log in all_logger.items():
    ax.plot(run_log['total_val_accuracy'], label=name)

# Everything goes through the explicit Axes rather than pyplot's implicit
# "current axes", so the cell keeps working if another figure becomes current.
ax.legend()
ax.set_xlabel('Epoch')
ax.set_ylabel('Testing accuracy')
ax.set_ylim(0.1, 1)
ax.grid(True)
plt.show()

In [None]:
# Compare the optimizers: validation accuracy against the logged epsilon values.
fig, ax = plt.subplots(1, 1, figsize=(10, 7))

# The loop variable is deliberately not called `logger`, so the notebook-level
# `logger` set by the earlier plotting cells is left untouched.
# NOTE(review): assumes 'epsilon' and 'total_val_accuracy' have equal length
# for every run -- confirm against train_dp_model.
for name, run_log in all_logger.items():
    ax.plot(run_log['epsilon'], run_log['total_val_accuracy'], label=name)

# Everything goes through the explicit Axes rather than pyplot's implicit
# "current axes", so the cell keeps working if another figure becomes current.
ax.legend()
ax.set_xlabel('Epsilon')
ax.set_ylabel('Testing accuracy')
ax.set_ylim(0.1, 1)
ax.grid(True)
plt.show()