In [1]:
import os
import copy
import time
import pickle
import numpy as np
import yaml
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F

from synthetic_dataloader import train_val_dataloader
from gd import LrdGD, GD
from Synthetic import SyntheticDataset

# Seed the numpy and torch RNGs and request deterministic cuDNN behaviour so
# that re-running the notebook is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [2]:
# Experiment configuration. Symbols in the trailing comments follow the FedAvg
# notation used in the training loop (E local steps, T rounds, p devices).
config = {
    'local_batch_size': 10,
    'local_iters': 20,           # E: local steps per round
    'global_iters': 200,         # T: communication rounds
    'num_devices': 30,           # p: total number of devices
    'num_active_devices': 10,    # devices sampled per round
    'lr': 0.01,
    'num_classes': 10,
    'device': 'cuda:0',
    # NOTE(review): 'iid' is not passed to SyntheticDataset in the training
    # cell; only alpha/beta are. Confirm whether it is still meaningful.
    'iid': 0,
    'alpha': 0,
    'beta': 0,
    'dimension': 60,
}

# NOTE(review): file name says "iid" although iid=0 -- confirm intended name.
config_path = '../config/synthetic_iid.yaml'
# Create the target directory first; open(..., 'w') fails if it is missing.
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, 'w') as f:
    yaml.dump(config, f)

In [3]:
# Reload the configuration from disk so every later cell uses exactly what was
# saved. safe_load suffices (and is safer than FullLoader) because the file only
# contains plain scalars written by yaml.dump of a dict.
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)
print(config)

{'alpha': 0, 'beta': 0, 'device': 'cuda:0', 'dimension': 60, 'global_iters': 200, 'iid': 0, 'local_batch_size': 10, 'local_iters': 20, 'lr': 0.01, 'num_active_devices': 10, 'num_classes': 10, 'num_devices': 30}


In [4]:
# define training model
class MLP(nn.Module):
    """Linear classifier returning class log-probabilities.

    Despite the name there is no hidden layer or non-linearity: a single
    ``nn.Linear(dim_in, dim_out)`` is followed by ``log_softmax``, i.e.
    multinomial logistic regression. Log-probabilities are what the
    notebook's ``nn.NLLLoss`` criterion consumes.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # The attribute name determines the state_dict keys
        # ('layer_hidden.weight', 'layer_hidden.bias'); do not rename.
        self.layer_hidden = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        # Assumes x is (batch, dim_in) -- softmax is taken over dim 1 (classes).
        logits = self.layer_hidden(x)
        return F.log_softmax(logits, dim=1)
    
# Fall back to CPU when the configured CUDA device is unavailable so the
# notebook still runs end-to-end on machines without a GPU.
if str(config['device']).startswith('cuda') and not torch.cuda.is_available():
    print('CUDA unavailable; using CPU instead of %s' % config['device'])
    device = torch.device('cpu')
else:
    device = torch.device(config['device'])

weight_decay = config.get('weight_decay', 1e-4)

# The global model is never stepped by an optimizer: it is only overwritten by
# the average of the local models (torch.optim.SGD / LrdGD / GD were tried).
global_model = MLP(dim_in=config['dimension'], dim_out=config['num_classes']).to(device)

# One model + optimizer per device. GD is used; torch.optim.SGD and LrdGD with
# the same arguments are drop-in alternatives.
local_model_list = []
local_optim_list = []
for local_id in range(config['num_devices']):
    local_model = MLP(dim_in=config['dimension'], dim_out=config['num_classes']).to(device)
    local_optim = GD(local_model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    local_model_list.append(local_model)
    local_optim_list.append(local_optim)

# Snapshots of the freshly initialised weights, so a re-run of the training
# cell can restore the same starting point.
global_init_weight = copy.deepcopy(global_model.state_dict())
local_model_init_weight_list = [copy.deepcopy(m.state_dict()) for m in local_model_list]

# The models output log-probabilities, hence NLLLoss (mean reduction).
criterion = nn.NLLLoss().to(device)

In [5]:
def inference(model, dataloader, loss_fn=None, target_device=None):
    """Evaluate ``model`` on every batch of ``dataloader``.

    Args:
        model: module whose outputs are fed to ``loss_fn`` and whose argmax
            over dim 1 is the predicted class.
        dataloader: iterable of ``(inputs, labels)`` batches.
        loss_fn: loss with mean reduction; defaults to the notebook-level
            ``criterion``.
        target_device: device batches are moved to; defaults to the
            notebook-level ``device``.

    Returns:
        ``(accuracy, loss)``: fraction of correct predictions and the
        per-sample mean loss over the whole dataloader.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion
    if target_device is None:
        target_device = device

    model.eval()
    total, correct = 0, 0
    loss_sum = 0.0
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)

            batch_size = len(labels)
            # loss_fn returns the batch mean; re-weight by batch size so a short
            # final batch does not count as much as a full one.
            loss_sum += loss_fn(outputs, labels).item() * batch_size

            pred_labels = outputs.argmax(dim=1).view(-1)
            correct += torch.eq(pred_labels, labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('inference: dataloader yielded no samples')
    return correct / total, loss_sum / total

In [6]:
def average_state_dicts(w, weight):
    """Return the weighted average of a list of state dicts (weights or gradients).

    Args:
        w: non-empty sequence of state dicts sharing the same keys and
            tensor shapes.
        weight: one non-negative weight per dict (list, tuple, numpy array,
            ...). Weights are renormalised to sum to 1, so only their
            ratios matter.

    Returns:
        A new state dict (same mapping type as ``w[0]``) whose entries are
        ``sum_i weight_i / sum(weight) * w[i][key]``. The inputs are not
        modified.

    Raises:
        ValueError: if ``w`` is empty, the lengths differ, or the weights
            sum to zero.
    """
    if len(w) == 0:
        raise ValueError('average_state_dicts: no state dicts given')
    coefs = [float(c) for c in weight]
    if len(coefs) != len(w):
        raise ValueError('average_state_dicts: got %d state dicts but %d weights'
                         % (len(w), len(coefs)))
    total = sum(coefs)
    if total == 0:
        raise ValueError('average_state_dicts: weights sum to zero')
    coefs = [c / total for c in coefs]

    # A shallow copy keeps w[0]'s mapping type (and any state_dict metadata);
    # every value is replaced below, so deep-copying the tensors is unnecessary.
    w_avg = copy.copy(w[0])
    for key in w_avg.keys():
        acc = torch.zeros_like(w[0][key])
        for state, coef in zip(w, coefs):
            acc = acc + state[key] * coef
        w_avg[key] = acc
    return w_avg

# FedAvg

In [7]:
"""
generate training data
"""
# X_split, y_split, weight_per_user = generate_synthetic(alpha=config['alpha'], beta=config['beta'], iid=config['iid'], 
#                                       num_user=config['num_devices'], dimension=config['dimension'],
#                                       num_class=config['num_classes'])

# trainloader_list, validloader_list = [], []
# trainloader_iterator_list, validloader_iterator_list = [], []
# for local_id in range(config['num_devices']):
#     trainloader, validloader = train_val_dataloader(X_split, y_split, local_id, batch_size=config['local_batch_size'])
#     trainloader_list.append(trainloader)
#     validloader_list.append(validloader)
    
#     trainloader_iterator_list.append(iter(trainloader_list[local_id]))
#     validloader_iterator_list.append(iter(validloader_list[local_id]))

synthetic_dataset = SyntheticDataset(num_classes=config['num_classes'], 
                                     num_tasks=config['num_devices'], 
                                     num_dim=config['dimension'],
                                     alpha=config['alpha'], beta=config['beta'])
data = synthetic_dataset.get_all_tasks()
num_samples = synthetic_dataset.get_num_samples()
weight_per_user = num_samples/sum(num_samples)

trainloader_list, validloader_list = [], []
trainloader_iterator_list, validloader_iterator_list = [], []
for local_id in range(config['num_devices']):
    trainloader, validloader = train_val_dataloader(data[local_id]['x'], data[local_id]['y'], batch_size=config['local_batch_size'])
    trainloader_list.append(trainloader)
    validloader_list.append(validloader)
    
    trainloader_iterator_list.append(iter(trainloader_list[local_id]))
    validloader_iterator_list.append(iter(validloader_list[local_id]))
# for local_id in range(config['num_devices']):
#     print(data[local_id]['x'].shape, data[local_id]['y'].shape, num_samples[local_id])
    
"""
load initial value
"""
global_model.load_state_dict(global_init_weight)
for local_id in range(config['num_devices']):
    local_model_list[local_id].load_state_dict(local_model_init_weight_list[local_id])

"""
start training
"""
global_acc = []
global_loss = []
    
# test global model
list_acc, list_loss = [], [] 
for local_id in range(config['num_devices']):
    acc, loss = inference(global_model, validloader_list[local_id])
    list_acc.append(acc)
    list_loss.append(loss)
global_acc +=  [sum(list_acc)/len(list_acc)]
global_loss += [sum(list_loss)/len(list_loss)]
print('global %d, acc %f, loss %f'%(len(global_acc), global_acc[-1], global_loss[-1]))
    
for global_iter in range(config['global_iters']): # T
    activate_devices = np.random.permutation(config['num_devices'])[:config['num_active_devices']] # np.arange(config['num_devices'])
    # get the local grad for each device
    for local_id in activate_devices: # K
        for local_iter in range(config['local_iters']): # E
            # load single mini-batch
            try:
                inputs, labels = next(trainloader_iterator_list[local_id])
            except StopIteration:
                trainloader_iterator_list[local_id] = iter(trainloader_list[local_id])
                inputs, labels = next(trainloader_iterator_list[local_id])

            # train local model
            inputs, labels = inputs.to(device), labels.to(device)
            local_model_list[local_id].train()
            local_model_list[local_id].zero_grad()
            log_probs = local_model_list[local_id](inputs)
            loss = criterion(log_probs, labels)
            loss.backward()
            local_optim_list[local_id].step() # lr=config['lr']/(1+global_iter)
            
        local_optim_list[local_id].inverse_prop_decay_learning_rate(global_iter)
        # acc, loss = inference(local_model_list[local_id], validloader_list[local_id])
        # print('local id %d, acc %f, loss %f'%(local_id, acc, loss))
    
    # average local models 
    local_weight_list = [local_model.state_dict() for local_model in local_model_list]
    avg_local_weight = average_state_dicts(local_weight_list, weight_per_user)
    global_model.load_state_dict(avg_local_weight)
    for local_id in range(config['num_devices']):
        local_model_list[local_id].load_state_dict(avg_local_weight)
        
    # test global model
    list_acc, list_loss = [], [] 
    for local_id in range(config['num_devices']):
        acc, loss = inference(global_model, validloader_list[local_id])
        list_acc.append(acc)
        list_loss.append(loss)
    global_acc +=  [sum(list_acc)/len(list_acc)]
    global_loss += [sum(list_loss)/len(list_loss)]
    print('global %d, acc %f, loss %f'%(len(global_acc), global_acc[-1], global_loss[-1]))
    
with open('fedavg_%.2f.pkl'%config['beta'], 'wb') as f:
    pickle.dump([global_acc, global_loss], f)

global 1, acc 0.104004, loss 2.348573
global 2, acc 0.196250, loss 2.272546
global 3, acc 0.284714, loss 2.235676
global 4, acc 0.346671, loss 2.204842
global 5, acc 0.400072, loss 2.178353
global 6, acc 0.446480, loss 2.157291
global 7, acc 0.455227, loss 2.140608
global 8, acc 0.461343, loss 2.122640
global 9, acc 0.481048, loss 2.110097
global 10, acc 0.503477, loss 2.098779
global 11, acc 0.519802, loss 2.087358
global 12, acc 0.537644, loss 2.077344
global 13, acc 0.543943, loss 2.070010
global 14, acc 0.552603, loss 2.060399
global 15, acc 0.557559, loss 2.052804
global 16, acc 0.568651, loss 2.047094
global 17, acc 0.583714, loss 2.042224
global 18, acc 0.586457, loss 2.036509
global 19, acc 0.597384, loss 2.032245
global 20, acc 0.601012, loss 2.027309
global 21, acc 0.603734, loss 2.022967
global 22, acc 0.607307, loss 2.016768
global 23, acc 0.618585, loss 2.012919
global 24, acc 0.628164, loss 2.010095
global 25, acc 0.633407, loss 2.007187
global 26, acc 0.641044, loss 2.00

In [8]:
import matplotlib.pyplot as plt


def plot_accuracy_curves(ax, diversities, path_template='fedavg_%.2f.pkl'):
    """Plot the saved per-round global accuracy for each diversity level.

    Args:
        ax: matplotlib Axes to draw on.
        diversities: beta values whose result files (written by the training
            cell as ``[global_acc, global_loss]``) should be plotted.
        path_template: %-format string mapping a diversity value to its file.

    Returns:
        The diversities that were actually plotted; values whose result file
        does not exist are skipped with a message.
    """
    plotted = []
    for diversity in diversities:
        path = path_template % diversity
        if not os.path.exists(path):
            print('skipping diversity %g: %s not found' % (diversity, path))
            continue
        # NOTE: pickle.load can execute arbitrary code; only load files this
        # notebook wrote itself.
        with open(path, 'rb') as f:
            run_acc, _run_loss = pickle.load(f)
        # Local names: the training cell's global_acc/global_loss stay intact.
        ax.plot(np.arange(len(run_acc)), run_acc, label='Diversity %g' % diversity)
        plotted.append(diversity)
    return plotted


fig, axs = plt.subplots()
# Add 0.25 / 0.75 to this list to include those runs once their files exist.
plot_accuracy_curves(axs, [0, 0.5, 1])
axs.set_xlabel('Communication round')
axs.set_ylabel('Global model accuracy')
axs.set_title('Synthetic dataset - Accuracy / Communication round')
axs.grid(True)
axs.legend()
fig.tight_layout()
fig.savefig('fedavg_diff_diversity.pdf')
plt.close(fig)

In [9]:
# Display the configuration the notebook ran with (Jupyter renders the last expression).
config

{'alpha': 0,
 'beta': 0,
 'device': 'cuda:0',
 'dimension': 60,
 'global_iters': 200,
 'iid': 0,
 'local_batch_size': 10,
 'local_iters': 20,
 'lr': 0.01,
 'num_active_devices': 10,
 'num_classes': 10,
 'num_devices': 30}