In [None]:
from main import *

In [None]:
# Hyperparameters for the differentially private training run.

# Training schedule and model width.
num_epochs = 100     # number of epochs handed to train_dp_model
hidden_size = 1_000  # `hidden` argument of LinearNet

# Values shared by the optimizer and the privacy accountant.
noise_scale = 4      # passed to both MomentsAccountant and PIAdam
max_grad_norm = 4    # passed to PIAdam as max_grad_norm

# Lot sizing. When q is not None the data-loading cell replaces lot_size
# with int(q * len(train_data)); with q = None the value below is used as-is.
lot_size = 600
q = None

In [None]:
# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

### Loading data

In [None]:
# Datasets: MNIST train and test splits stored under root 'data' (download
# enabled), with every image converted by transforms.ToTensor().
train_data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Lot size (L): keep the configured value unless q is given, in which case
# it is derived as that fraction of the training set, truncated to an int.
if q is not None:
    lot_size = int(q * len(train_data))

# Loaders: both use lot_size as the batch size; only the training loader
# shuffles.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=lot_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=lot_size,
    shuffle=False,
)

In [None]:
# Loss function handed to train_dp_model.
criterion = nn.CrossEntropyLoss()

# Privacy accountant, built from the noise scale and from q expressed as the
# lot size divided by the number of training examples.
accountant = accountants.MomentsAccountant(
    noise_scale,
    q=lot_size / len(train_data),
)

### Training

In [None]:
# Network to train, moved onto the selected device.
model = LinearNet(in_features=784, hidden=hidden_size).to(device)

# Differentially private optimizer; it receives the named parameters and the
# lot size in addition to the usual learning rate.
optimizer = optim.PIAdam(
    model.named_parameters(),
    lot_size,
    lr=1e-4,
    noise_scale=noise_scale,
    max_grad_norm=max_grad_norm,
)

# One independent, initially empty history list per tracked metric.
logger = {
    metric: []
    for metric in (
        'loss', 'total_loss',
        'accuracy', 'total_accuracy', 'total_val_accuracy',
        'epsilon',
    )
}

# Run training; results are collected through `logger` and `accountant`.
train_dp_model(
    model, criterion, optimizer, num_epochs, train_loader, test_loader,
    device=device, logger=logger, accountant=accountant,
)

In [None]:
# Two side-by-side panels that share the y-axis.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 6), sharey=True)

# Left panel: the 'accuracy' series recorded in the logger.
ax[0].set_title('accuracy')
ax[0].plot(logger['accuracy'])

# Right panel: training vs. validation accuracy, drawn in that order.
ax[1].set_title('per epoch accuracy')
ax[1].plot(logger['total_accuracy'], label='train')
ax[1].plot(logger['total_val_accuracy'], label='val')
ax[1].legend()

plt.show()