In [18]:
import copy
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
matplotlib.use('Agg')
%matplotlib inline

class Net(nn.Module):
    """Small CNN classifier producing log-probabilities over 10 classes.

    Three 3x3 convolutions with stride 2 and padding 1 each roughly halve the
    spatial size. For a 1x28x28 input (MNIST, as used in this notebook) that
    is 28 -> 14 -> 7 -> 4, i.e. 128 * 4 * 4 = 2048 features into ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as always, so parameter
        # initialisation draws from the RNG in the same sequence.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities (batch, 10)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        hidden = F.relu(self.fc1(x.flatten(1)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


def client_update(client_model, optimizer, train_loader, epoch):
    """Train one federated client locally for ``epoch`` passes over its data.

    Args:
        client_model: model trained in place. Each batch is moved to the
            device of the model's first parameter (CUDA in this notebook).
        optimizer: optimizer bound to ``client_model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches, e.g. a DataLoader.
            The model's output is fed to ``F.nll_loss``, so it is expected to
            return log-probabilities.
        epoch: number of local passes over ``train_loader``.

    Returns:
        float: the loss of the LAST processed batch, divided by ``epoch``
        (not an average over the epoch).

    Raises:
        ValueError: if no batch was processed (``epoch < 1`` or an empty loader).
    """
    client_model.train()
    # Follow the model's device instead of hard-coding .cuda(); identical on a
    # CUDA model, but also usable on CPU.
    device = next(client_model.parameters()).device
    last_loss = None
    for _ in range(epoch):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            # NOTE(review): dividing by `epoch` scales every gradient by
            # 1/epoch, so clients with more local epochs take proportionally
            # smaller steps. Kept as-is to preserve the training dynamics.
            loss = F.nll_loss(output, target) / epoch
            loss.backward()
            optimizer.step()
            last_loss = loss
    if last_loss is None:
        raise ValueError("client_update processed no batches (epoch < 1 or empty train_loader)")
    return last_loss.item()

def server_aggregate(global_model, client_models):
    """FedAvg step: average the clients' state into the global model, then broadcast it.

    Every entry of the global ``state_dict`` is replaced by the element-wise
    mean of the matching entries across ``client_models``. Afterwards every
    client is reset to the new global state.

    Args:
        global_model: model that receives the averaged state.
        client_models: non-empty list of models sharing ``global_model``'s
            architecture.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if not client_models:
        raise ValueError("server_aggregate needs at least one client model")
    # Snapshot each client's state once instead of once per key.
    client_states = [client.state_dict() for client in client_models]
    global_dict = global_model.state_dict()
    for key in list(global_dict.keys()):
        stacked = torch.stack([state[key] for state in client_states], 0)
        if stacked.is_floating_point():
            global_dict[key] = stacked.mean(0)
        else:
            # .mean() is not defined for integer tensors (e.g. BatchNorm's
            # num_batches_tracked): average in float, then cast back.
            global_dict[key] = stacked.float().mean(0).to(stacked.dtype)
    global_model.load_state_dict(global_dict)
    new_state = global_model.state_dict()
    for client in client_models:
        client.load_state_dict(new_state)

def test(global_model, test_loader):
    global_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc

In [19]:
# NON-IID case: every client has images of two categories chosen from [0, 1], [2, 3], [4, 5], [6, 7], or [8, 9].

# Hyperparameters

seed = 0  # fixes client epochs, client sampling, shuffling and weight init
np.random.seed(seed)
torch.manual_seed(seed)

num_clients = 100
num_selected = 10
num_rounds = 100
epochs = 10
batch_size = 32
# Each client gets its own number of local epochs, drawn uniformly from 1..epochs.
local_ep_list = np.random.choice(range(1, epochs + 1), size=num_clients)
lr = 0.01

# Creating decentralized datasets

mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindata = datasets.MNIST('./data', train=True, download=True, transform=mnist_transform)

num_pairs = 5  # label pairs [0, 1], [2, 3], ..., [8, 9]
assert num_clients % num_pairs == 0, "num_clients must be divisible by the number of label pairs"
shards_per_pair = num_clients // num_pairs
target_labels = torch.stack([traindata.targets == i for i in range(10)])
target_labels_split = []
for i in range(num_pairs):
    # Indices of every training image labelled 2*i or 2*i + 1.
    pair_idx = torch.where(target_labels[(2 * i):(2 * (i + 1))].sum(0))[0]
    # Split into exactly `shards_per_pair` near-equal shards. A fixed chunk
    # size (torch.split) leaves a short remainder shard per pair, producing
    # more shards than `num_clients`; the extra ones could never be sampled.
    target_labels_split += [torch.from_numpy(shard) for shard in np.array_split(pair_idx.numpy(), shards_per_pair)]
traindata_split = [torch.utils.data.Subset(traindata, tl) for tl in target_labels_split]
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]
assert len(train_loader) == num_clients

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Build the global model plus one reusable model/optimizer slot per selected client.
global_model = Net().cuda()
client_models = [Net().cuda() for _slot in range(num_selected)]

# Every client slot starts from the global model's initial weights.
initial_state = global_model.state_dict()
for model in client_models:
    model.load_state_dict(initial_state)

# Plain SGD per slot; slots are re-synchronised with the global model after each round.
opt = [optim.SGD(replica.parameters(), lr=lr) for replica in client_models]

# Running federated learning (FedAvg)
test_loss_accu = []
acc_accu = []
for r in range(num_rounds):
    # Sample `num_selected` distinct clients for this round.
    client_idx = np.random.permutation(num_clients)[:num_selected]

    # Local training: slot i trains on the data of client client_idx[i].
    loss = 0
    for i in range(num_selected):
        cid = client_idx[i]
        local_epochs = int(local_ep_list[cid])
        loss += client_update(client_models[i], opt[i], train_loader[cid], epoch=local_epochs)

    # Server aggregation, then evaluation of the new global model.
    server_aggregate(global_model, client_models)
    test_loss, acc = test(global_model, test_loader)
    test_loss_accu.append(test_loss)
    acc_accu.append(acc)

    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))
    
# Persist per-round test loss and accuracy so results can be re-plotted without retraining.
file_name = './save/objects/fed_diffEpoch{}_lr{}_loss_and_acc.pkl'.format(epochs, lr)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'wb') as f:
    # Dump the accuracy list itself (previously the bound method `acc_accu.append` was saved).
    pickle.dump([test_loss_accu, acc_accu], f)
    
# Plot global-model test accuracy vs communication rounds
os.makedirs('./save', exist_ok=True)  # savefig fails if the directory is missing
fig, ax = plt.subplots()
ax.set_title('Accuracy vs Communication rounds')
ax.plot(range(len(acc_accu)), acc_accu, color='k')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Communication Rounds')
fig.savefig('./save/fed_diffEpoch{}_lr{}_acc.png'.format(epochs, lr))

0-th round
average train loss 0.818 | test loss 2.3 | test acc: 0.118
1-th round
average train loss 0.851 | test loss 2.29 | test acc: 0.146
2-th round
average train loss 0.316 | test loss 2.28 | test acc: 0.098
3-th round
average train loss 0.435 | test loss 2.28 | test acc: 0.200
4-th round
average train loss 0.167 | test loss 2.35 | test acc: 0.182
5-th round
average train loss 0.208 | test loss 2.38 | test acc: 0.365
6-th round
average train loss 0.0822 | test loss 2.3 | test acc: 0.378
7-th round
average train loss 0.176 | test loss 2.01 | test acc: 0.493
8-th round
average train loss 0.0312 | test loss 2.79 | test acc: 0.179
9-th round
average train loss 0.109 | test loss 1.93 | test acc: 0.405
10-th round
average train loss 0.0361 | test loss 1.76 | test acc: 0.454
11-th round
average train loss 0.031 | test loss 1.7 | test acc: 0.417
12-th round
average train loss 0.0814 | test loss 1.62 | test acc: 0.444
13-th round
average train loss 0.112 | test loss 1.66 | test acc: 0.340
1

112-th round
average train loss 0.00795 | test loss 0.698 | test acc: 0.755
113-th round
average train loss 0.0107 | test loss 0.582 | test acc: 0.806
114-th round
average train loss 0.00959 | test loss 0.595 | test acc: 0.806
115-th round
average train loss 0.0026 | test loss 0.705 | test acc: 0.738
116-th round
average train loss 0.0429 | test loss 1.31 | test acc: 0.618
117-th round
average train loss 0.0109 | test loss 0.784 | test acc: 0.730
118-th round
average train loss 0.0098 | test loss 0.871 | test acc: 0.717
119-th round
average train loss 0.00376 | test loss 1.1 | test acc: 0.633
120-th round
average train loss 0.00351 | test loss 0.655 | test acc: 0.777
121-th round
average train loss 0.00583 | test loss 0.511 | test acc: 0.838
122-th round
average train loss 0.216 | test loss 0.551 | test acc: 0.807
123-th round
average train loss 0.00682 | test loss 0.575 | test acc: 0.800
124-th round
average train loss 0.0119 | test loss 0.78 | test acc: 0.736


KeyboardInterrupt: 