In [1]:
# Federated learning (FL) application: train a CNN on non-IID MNIST client splits
from MLModel import *
from FLModel import *
from utils import *

from torchvision import datasets, transforms
import torch
import numpy as np
import os

# Restrict CUDA to the GPU with index 1. This has to happen before CUDA is
# initialised; the single visible GPU is then addressed as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Use the visible GPU when there is one, otherwise fall back to the CPU.
# To force CPU execution, assign torch.device("cpu") unconditionally instead.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
def load_cnn_mnist(num_users):
    """Load MNIST, standardise it and split the training set across clients.

    Pixel tensors are read straight from the dataset objects (``.data``),
    cast to float and given a singleton axis at position 1. Both the training
    and the test pixels are standardised with the mean and standard deviation
    of the *training* pixels.

    Args:
        num_users: number of client partitions requested from ``mnist_noniid``.

    Returns:
        A list of ``num_users + 1`` ``(data, labels)`` tuples: one entry per
        client (indexed 0..num_users-1), followed by the full test set as the
        last entry. Labels are returned as float tensors.
    """
    data_root = "~/data/"

    # NOTE(review): a ToTensor transform is passed, but the pixels are taken
    # from `.data` directly, so it does not appear to affect what is returned.
    train_set = datasets.MNIST(root=data_root, train=True, download=True, transform=transforms.ToTensor())
    train_pixels = train_set.data.float().unsqueeze(1)
    train_label = train_set.targets

    # Statistics come from the training pixels only and are reused for the test set.
    mean, std = train_pixels.mean(), train_pixels.std()
    train_data = (train_pixels - mean) / std

    test_set = datasets.MNIST(root=data_root, train=False, download=True, transform=transforms.ToTensor())
    test_label = test_set.targets
    test_data = (test_set.data.float().unsqueeze(1) - mean) / std

    # mnist_noniid maps each client id to the training-sample indices it owns.
    user_dict = mnist_noniid(train_label, num_users)
    partitions = [
        (train_data[user_dict[uid]], train_label[user_dict[uid]].float())
        for uid in range(num_users)
    ]

    # The held-out test set always goes last.
    partitions.append((test_data.float(), test_label.float()))
    return partitions

In [3]:
"""
1. load_data
2. generate clients (step 3)
3. generate aggregator
4. training
"""
# Number of simulated FL clients; also the number of training partitions
# requested from load_cnn_mnist.
client_num = 4
# d is a list of client_num (data, labels) tuples, one per client, followed by
# a final (test_data, test_labels) tuple.
d = load_cnn_mnist(client_num)

In [4]:
"""
FL model parameters.
"""
import warnings
# Blanket filter: suppresses every Python warning raised from here on.
warnings.filterwarnings("ignore")

# Learning rate forwarded to the FL server through fl_param['lr'].
lr = 0.1

# Configuration dictionary consumed by FLServer. The authoritative meaning of
# each key lives in FLModel (not visible here); notes marked "presumably" are
# assumptions to confirm against that module.
fl_param = {
    'output_size': 10,  # presumably the number of output classes -- TODO confirm
    'client_num': client_num,  # same value that was used to partition the data
    'model': MnistCNN,  # the class itself is passed, not an instance
    'data': d,  # per-client partitions plus the test set as the last entry
    'lr': lr,
    'E': 40,  # presumably local training iterations per global round -- TODO confirm
    'C': 1,  # presumably the fraction of clients taking part each round -- TODO confirm
    'eps': 6.0,  # differential-privacy epsilon (echoed as "eps = 6" in the cell output)
    'delta': 1e-5,  # differential-privacy delta (echoed as "delta = 1e-05")
    'q': 0.05,  # DP-SGD sampling rate (echoed as "5%" in the cell output)
    'clip': 15,  # presumably the gradient clipping bound used by DP-SGD -- TODO confirm
    'tot_T': 5,  # number of global rounds; the training loop iterates range(tot_T)
    'batch_size': 128,
    'device': device
}

# Build the server from the configuration and move it to the selected device.
fl_entity = FLServer(fl_param).to(device)

DP-SGD with sampling rate = 5% and noise_multiplier = 2.691766953138113 iterated over 4000 steps satisfies differential privacy with eps = 6 and delta = 1e-05.


In [5]:
import time

# One entry per global round: the value returned by global_update(), which the
# progress line reports as "acc".
acc = []
start_time = time.time()
for t in range(fl_param['tot_T']):
    round_acc = fl_entity.global_update()
    acc.append(round_acc)
    # Cumulative time since the loop started, not the duration of this round.
    elapsed = time.time() - start_time
    print("global epochs = {:d}, acc = {:.4f}".format(t+1, round_acc), " Time taken: %.2fs" % elapsed)

global epochs = 1, acc = 0.6035  Time taken: 186.98s
global epochs = 2, acc = 0.7875  Time taken: 353.05s
global epochs = 3, acc = 0.8891  Time taken: 536.27s
global epochs = 4, acc = 0.9005  Time taken: 735.96s
global epochs = 5, acc = 0.9111  Time taken: 909.97s
