In [1]:
# Federated learning (FL) on MNIST using scattering-transform features
# Project modules first, so the explicit imports below take precedence over any
# same-named objects these star imports might export.
from MLModel import *
from FLModel import *
from utils import *

import os
import time
import warnings

import numpy as np
import torch
from torchvision import datasets, transforms

# Restrict this process to GPU 0 unless the caller already chose devices.
# NOTE(review): CUDA reads this variable when it is first initialised, so it only
# takes effect if nothing imported above has touched CUDA yet; ideally set it
# before `import torch`.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# To force CPU execution, use: device = torch.device("cpu")

# Seed the numpy and torch RNGs so that data shuffling/splitting and training are
# reproducible across Restart & Run All. (Seed `random` too if helpers use it.)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Build the scattering transform once and place it on the compute device.
# K, h and w are forwarded to the FL model configuration further down.
scattering, K, (h, w) = get_scatter_transform()
scattering = scattering.to(device)

def get_scattered_feature(dataset):
    scatters = []
    targets = []
    
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, shuffle=True, num_workers=1, pin_memory=True)

    
    for (data, target) in loader:
        data, target = data.to(device), target.to(device)
        if scattering is not None:
            data = scattering(data)
        scatters.append(data)
        targets.append(target)

    scatters = torch.cat(scatters, axis=0)
    targets = torch.cat(targets, axis=0)

    data = torch.utils.data.TensorDataset(scatters, targets)
    return data

def load_mnist(num_users):
    train = datasets.MNIST(root="~/data/", train=True, download=True, transform=transforms.ToTensor())
    test = datasets.MNIST(root="~/data/", train=False, download=True, transform=transforms.ToTensor())
    
    # get scattered features
    train = get_scattered_feature(train)
    test = get_scattered_feature(test)
    
    train_data = train[:][0].squeeze().cpu().float()
    train_label = train[:][1].cpu()
    
    test_data = test[:][0].squeeze().cpu().float()
    test_label = test[:][1].cpu()

    # split MNIST (training set) into non-iid data sets
    non_iid = []
    user_dict = mnist_noniid(train_label, num_users)
    for i in range(num_users):
        idx = user_dict[i]
        d = train_data[idx]
        targets = train_label[idx].float()
        non_iid.append((d, targets))
    non_iid.append((test_data.float(), test_label.float()))
    return non_iid

In [3]:
"""
1. load_data
2. generate clients (step 3)
3. generate aggregator
4. training
"""
client_num = 4
d = load_mnist(client_num)

torch.cuda.empty_cache()

In [4]:
# Shape of one training sample from client 1's shard.
d[1][0][0].size()

torch.Size([81, 7, 7])

In [5]:
"""
FL model parameters.
"""
import warnings
warnings.filterwarnings("ignore")

lr = 0.075

fl_param = {
    'output_size': 10,
    'K': K,
    'h': h,
    'w': w,
    'client_num': client_num,
    'model': 'scatter',
    'data': d,
    'lr': lr,
    'E': 500,
    'C': 1,
    'eps': 4.0,
    'delta': 1e-5,
    'q': 0.01,
    'clip': 0.1,
    'tot_T': 10,
    'batch_size': 128,
    'device': device
}

fl_entity = FLServer(fl_param).to(device)

noise scale = 1.0771102905273438


In [6]:
# Run the global FL rounds. Record the value returned by each global update
# (reported as accuracy) together with the elapsed wall time.
acc = []
# perf_counter is monotonic, so elapsed times are unaffected by system clock changes.
start_time = time.perf_counter()
for t in range(fl_param['tot_T']):
    acc.append(fl_entity.global_update())
    elapsed = time.perf_counter() - start_time
    print(f"global epochs = {t + 1:d}, acc = {acc[-1]:.4f}  Time taken: {elapsed:.2f}s")

global epochs = 1, acc = 0.8842  Time taken: 161.98s
global epochs = 2, acc = 0.9348  Time taken: 322.99s
global epochs = 3, acc = 0.9546  Time taken: 486.85s
global epochs = 4, acc = 0.9600  Time taken: 648.92s
global epochs = 5, acc = 0.9657  Time taken: 807.26s
global epochs = 6, acc = 0.9666  Time taken: 959.25s
global epochs = 7, acc = 0.9704  Time taken: 1109.23s
global epochs = 8, acc = 0.9712  Time taken: 1257.56s
global epochs = 9, acc = 0.9739  Time taken: 1400.09s
global epochs = 10, acc = 0.9742  Time taken: 1538.17s


In [None]:
# SGD (momentum = 0.9)