In [6]:
import os
import pickle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
matplotlib.use('Agg')
%matplotlib inline

class Net(nn.Module):
    """Small CNN classifier returning 10-way log-probabilities.

    Expects single-channel images. Each 3x3, stride-2, padding-1 convolution
    roughly halves the spatial size (28 -> 14 -> 7 -> 4 for MNIST). The flattened
    feature vector therefore has 128 * 4 * 4 = 2048 entries, the input width of
    ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order determine the state_dict keys
        # (server_aggregate averages by key) and the order of random init.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to log-softmax scores of shape (batch, 10)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        hidden = F.relu(self.fc1(torch.flatten(x, 1)))
        return F.log_softmax(self.fc2(hidden), dim=1)


def client_update(client_model, optimizer, train_loader, epoch, num_clients, pk, batch_size):
    """Run local SGD on one client and score it by its accumulated gradient norm.

    Each optimisation step minimises the NLL loss scaled by
    ``1 / (epoch * num_clients * pk)``. The gradients of that scaled loss are
    summed over every batch of every local epoch. The client's score is
    ``||summed gradient||_2 / (epoch * batch_size)``.

    Args:
        client_model: model trained in place. Batches are moved to the device
            of its first parameter.
        optimizer: optimizer over ``client_model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches.
        epoch: number of local passes over ``train_loader``.
        num_clients: total number of clients (used only for loss scaling).
        pk: this client's current sampling probability (used only for loss scaling).
        batch_size: used only to normalise the returned score.

    Returns:
        ``(loss, grad_client)``: the unscaled NLL loss of the last processed
        batch as a Python float, and the gradient-norm score.

    Raises:
        ValueError: if no batch was processed (empty loader or ``epoch < 1``).
    """
    client_model.train()
    params = list(client_model.parameters())
    device = params[0].device
    # The accumulators own their storage. Keeping references to ``p.grad`` would
    # alias the live gradient buffers, which ``optimizer.zero_grad()`` may zero in
    # place (set_to_none=False) before the next batch is added.
    grad_sums = [torch.zeros_like(p) for p in params]
    last_loss = None
    for _ in range(epoch):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            loss_fed = F.nll_loss(output, target)
            loss = loss_fed / (epoch * num_clients * pk)
            loss.backward()
            with torch.no_grad():
                for acc, p in zip(grad_sums, params):
                    if p.grad is not None:
                        acc.add_(p.grad)
            optimizer.step()
            last_loss = loss_fed.detach()
    if last_loss is None:
        raise ValueError("client_update: no batches were processed (empty loader or epoch < 1)")

    nabla_P_norm2 = sum(torch.norm(g) ** 2 for g in grad_sums).item()
    grad_client = (1 / (epoch * batch_size)) * np.sqrt(nabla_P_norm2)
    return last_loss.item(), grad_client

def server_aggregate(global_model, client_models):
    """Federated averaging: set ``global_model`` to the clients' mean state, then sync clients.

    Every ``state_dict`` entry (parameters and buffers) is averaged with equal
    weight across ``client_models``. The averaged state is then loaded back into
    every client. Non-floating-point entries (e.g. integer counters) are averaged
    in float64 and cast back to their original dtype, truncating, because
    ``mean`` is undefined for integer tensors.

    Raises:
        ValueError: if ``client_models`` is empty.
    """
    if len(client_models) == 0:
        raise ValueError("server_aggregate: client_models must be non-empty")
    # Snapshot each client's state once instead of rebuilding it for every key.
    client_states = [model.state_dict() for model in client_models]
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        stacked = torch.stack([state[k] for state in client_states], 0)
        if stacked.is_floating_point() or stacked.is_complex():
            global_dict[k] = stacked.mean(0)
        else:
            # mean() rejects integer/bool tensors: average in float64 and cast back.
            global_dict[k] = stacked.double().mean(0).to(stacked.dtype)
    global_model.load_state_dict(global_dict)
    averaged_state = global_model.state_dict()
    for model in client_models:
        model.load_state_dict(averaged_state)

def test(global_model, test_loader):
    global_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc

In [8]:
# NON-IID case: every client has images of two categories chosen from [0, 1], [2, 3], [4, 5], [6, 7], or [8, 9].
# Client sampling: each round, num_selected distinct clients are drawn with probabilities
# p_usersampling. After local training, the probability mass held by the selected clients
# is redistributed among them in proportion to their gradient-norm scores.

# Hyperparameters

SEED = 0
num_clients = 100
num_selected = 10
num_rounds = 100
epochs = 10
batch_size = 32
lr = 0.01

# Seed every RNG the run uses so Restart & Run All reproduces the results.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Number of local epochs for each client, drawn uniformly from 1..epochs.
local_ep_list = np.random.choice(range(1, epochs + 1), size=num_clients)

# Creating decentralized datasets

mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindata = datasets.MNIST('./data', train=True, download=True, transform=mnist_transform)
# Row c of target_labels is a boolean mask selecting the training images of class c.
target_labels = torch.stack([traindata.targets == i for i in range(10)])
shard_size = len(traindata) // num_clients
target_labels_split = []
for i in range(5):
    # Indices of every image labelled 2*i or 2*i+1, cut into shards of shard_size.
    pair_indices = torch.where(target_labels[(2 * i):(2 * (i + 1))].sum(0))[0]
    target_labels_split += torch.split(pair_indices, shard_size)
# NOTE: torch.split keeps a final short shard when a class pair's size is not a multiple
# of shard_size, so there can be more shards than num_clients. Only the first
# num_clients shards are ever sampled below.
traindata_split = [torch.utils.data.Subset(traindata, tl) for tl in target_labels_split]
train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]
assert len(train_loader) >= num_clients, "fewer data shards than clients"

# Evaluation order does not matter, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())

opt = [optim.SGD(model.parameters(), lr=lr) for model in client_models]

# Running FL
p_initial = np.ones(num_clients) / num_clients  # initialize the probability vector
# Copy, so the in-place updates below do not silently rewrite p_initial too.
p_usersampling = p_initial.copy()

test_loss_accu = []
acc_accu = []
for r in range(num_rounds):
    # Select num_selected distinct clients according to the current sampling distribution.
    client_idx = np.random.choice(range(num_clients), num_selected, replace=False, p=p_usersampling)

    # Client update
    grad_list = []
    loss_list = []
    for i, k in enumerate(client_idx):
        loss, grad_client = client_update(client_models[i], opt[i], train_loader[k],
                                          epoch=int(local_ep_list[k]), num_clients=num_clients,
                                          pk=p_usersampling[k], batch_size=batch_size)
        grad_list.append(grad_client)
        loss_list.append(loss)
    loss = sum(loss_list)

    # Redistribute the selected clients' probability mass in proportion to their scores.
    # Unselected clients keep their probabilities, so the total stays 1.
    grad_total = sum(grad_list)
    if grad_total > 0:
        normalizing_factor = p_usersampling[client_idx].sum()
        for k, score in zip(client_idx, grad_list):
            p_usersampling[k] = (score / grad_total) * normalizing_factor
        # Renormalize to cancel float drift: np.random.choice rejects p that does not sum to 1.
        p_usersampling /= p_usersampling.sum()

    # Server aggregate
    server_aggregate(global_model, client_models)
    test_loss, acc = test(global_model, test_loader)
    test_loss_accu.append(test_loss)
    acc_accu.append(acc)

    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))

# Save the per-round test loss and accuracy.
os.makedirs('./save/objects', exist_ok=True)
file_name = './save/objects/fedsample_Epoch{}_lr{}_loss_and_acc.pkl'.format(epochs, lr)
with open(file_name, 'wb') as f:
    pickle.dump([test_loss_accu, acc_accu], f)

# Plot Average Accuracy vs Communication rounds
fig, ax = plt.subplots()
ax.set_title('Accuracy vs Communication rounds')
ax.plot(range(len(acc_accu)), acc_accu, color='k')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Communication Rounds')
fig.savefig('./save/fedsample_Epoch{}_lr{}_acc.png'.format(epochs, lr))
    

0-th round
average train loss 2.14 | test loss 2.3 | test acc: 0.089
1-th round
average train loss 2.03 | test loss 2.3 | test acc: 0.089
2-th round
average train loss 1.64 | test loss 2.31 | test acc: 0.133
3-th round
average train loss 0.961 | test loss 2.43 | test acc: 0.210
4-th round
average train loss 0.772 | test loss 2.48 | test acc: 0.154
5-th round
average train loss 0.661 | test loss 2.53 | test acc: 0.198
6-th round
average train loss 0.609 | test loss 2.67 | test acc: 0.192
7-th round
average train loss 0.556 | test loss 2.36 | test acc: 0.103
8-th round
average train loss 0.507 | test loss 2.27 | test acc: 0.175
9-th round
average train loss 0.541 | test loss 3.02 | test acc: 0.180
10-th round
average train loss 0.367 | test loss 2.06 | test acc: 0.436
11-th round
average train loss 0.271 | test loss 2.15 | test acc: 0.369
12-th round
average train loss 0.324 | test loss 2.04 | test acc: 0.416
13-th round
average train loss 0.276 | test loss 1.88 | test acc: 0.248
14-th r

113-th round
average train loss 0.0583 | test loss 0.559 | test acc: 0.828
114-th round
average train loss 0.0408 | test loss 0.609 | test acc: 0.804
115-th round
average train loss 0.0206 | test loss 0.625 | test acc: 0.796
116-th round
average train loss 0.021 | test loss 0.574 | test acc: 0.817
117-th round
average train loss 0.0425 | test loss 0.696 | test acc: 0.764
118-th round
average train loss 0.0494 | test loss 0.578 | test acc: 0.814
119-th round
average train loss 0.0199 | test loss 0.767 | test acc: 0.731
120-th round
average train loss 0.0956 | test loss 0.862 | test acc: 0.690
121-th round
average train loss 0.0203 | test loss 0.732 | test acc: 0.737
122-th round
average train loss 0.0717 | test loss 0.734 | test acc: 0.735
123-th round
average train loss 0.0444 | test loss 1.04 | test acc: 0.645
124-th round
average train loss 0.03 | test loss 0.684 | test acc: 0.761
125-th round
average train loss 0.0312 | test loss 0.582 | test acc: 0.798
126-th round
average train lo