In [1]:
import random
import copy

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from Data_Load_Assign import *
from sklearn.metrics import accuracy_score

In [2]:
# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("| using device:", device)

# Set random seeds
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

# Set B: batch_size
bsz = 10

| using device: cuda


In [3]:
# Load Dataset
train_data, test_data = load_MNIST()

print("Type of trainDataset：", type(train_data))
print("Length of trainDataset：", len(train_data))

print("Type of trainDataset[0]：", type(train_data[0]))
print("Length of trainDataset[0]：", len(train_data[0]))

print("Type of trainDataset[0][0]：", type(train_data[0][0]))
print("Shape of trainDataset[0][0]：", train_data[0][0].shape)
print("Type of trainDataset[0][1]：", type(train_data[0][1]))
print("trainDataset[0][1]：", train_data[0][1])

# Get client dataloaders
iid_train_loader = iid_Assign(train_data, batch_size = bsz)
noniid_train_loader = non_iid_Assign(train_data, batch_size = bsz)
print("iid_train_loader length", len(iid_train_loader))
print("noniid_train_loader length", len(noniid_train_loader))

Type of trainDataset： <class 'torchvision.datasets.mnist.MNIST'>
Length of trainDataset： 60000
Type of trainDataset[0]： <class 'tuple'>
Length of trainDataset[0]： 2
Type of trainDataset[0][0]： <class 'torch.Tensor'>
Shape of trainDataset[0][0]： torch.Size([1, 28, 28])
Type of trainDataset[0][1]： <class 'int'>
trainDataset[0][1]： 5
iid_train_loader length 100
noniid_train_loader length 100


In [4]:
# Test iid and non-iid is correct.

# iid
sample_iid = random.sample(iid_train_loader, 5)
for i in sample_iid:
    sample_label = torch.zeros(10)
    for (x,y) in i:
        sample_label += torch.sum(F.one_hot(y, num_classes=10), dim=0)
    print("iid: ", sample_label)

# non_iid
sample_noniid = random.sample(noniid_train_loader, 5)
for i in sample_noniid:
    sample_label = torch.zeros(10)
    for (x,y) in i:
        sample_label += torch.sum(F.one_hot(y, num_classes=10), dim=0)
    print("non-iid: ", sample_label)

# Q: what if a client get two subdatasets of the same number? like: 
# non-iid:  tensor([600.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.])

iid:  tensor([59., 64., 71., 60., 70., 49., 64., 57., 61., 45.])
iid:  tensor([57., 66., 65., 63., 69., 47., 58., 53., 62., 60.])
iid:  tensor([56., 72., 72., 50., 60., 50., 56., 68., 52., 64.])
iid:  tensor([59., 60., 56., 57., 49., 40., 61., 65., 77., 76.])
iid:  tensor([64., 65., 51., 63., 71., 50., 67., 59., 67., 43.])
non-iid:  tensor([  0., 300.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 300.])
non-iid:  tensor([600.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.])
non-iid:  tensor([300., 300.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.])
non-iid:  tensor([300.,   0.,   0.,   0.,   0.,   0., 300.,   0.,   0.,   0.])
non-iid:  tensor([  0., 300.,   0.,   0.,   0.,   0.,   0.,   0., 300.,   0.])


In [5]:
# Model Defination

# Model 1: A simple multilayer-perceptron with 2-hidden
# layers with 200 units each using ReLu activations (199,210
# total parameters), which we refer to as the MNIST 2NN.

# define fully connected NN
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 200);
        self.fc2 = nn.Linear(200, 200);
        self.out = nn.Linear(200, 10);

    def forward(self, x):
        x = x.flatten(1) # torch.Size([B,784])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)
        return x

print(MLP())
print("Total parameters of 2NN(MLP) is", num_params(MLP()))

# Model 2: A CNN with two 5x5 convolution layers (the first with
# 32 channels, the second with 64, each followed with 2x2
# max pooling), a fully connected layer with 512 units and
# ReLu activation, and a final softmax output layer (1,663,370
# total parameters).

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
    
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 4 * 4) # flatten
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

print(CNN())
print("Total parameters of CNN is", num_params(CNN()))

# Q: The paper said there will be 1,663,370 total parameters. Something wrong?

MLP(
  (fc1): Linear(in_features=784, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=200, bias=True)
  (out): Linear(in_features=200, out_features=10, bias=True)
)
Total parameters of 2NN(MLP) is 199210
CNN(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)
Total parameters of CNN is 582026


In [6]:
# Fed_Avg
def fedavg(global_model, C, K, E, c_loader, rounds, lr, acc_threshold):
    # C is the fraction of clients that perform computation on each round.
    assert C <= 1 and C >= 0
    # K is the number of clients.
    # E is the local epochs, the number of training passes each client makes
    # over its local dataset on each round.
    
    c_per_round = max( round(C * K), 1)
    accuracy = []
    
    for rd in range(rounds):
        
        # Choose c_per_round clients
        clients = random.sample(range(K), c_per_round)
        
        # Train on clients one by one
        client_model = []
        for i, c in enumerate(clients):
            local_model = client_training(global_model, id = c, local_epochs = E, dataloader = c_loader[c], lr = lr)
            local_model = local_model.to(device)
            client_model.append(local_model.state_dict())
            # print("round{}, client{} finished. {}/{}".format(rd, c, i+1, c_per_round))
            
    
        # Average and iterate global model parameters for next round  
        keys = client_model[0].keys()
        next_global_dict = {key: 0 for key in keys}
        for key in keys:
            values = [d[key] for d in client_model]
            averaged_value = sum(values) / len(values)
            next_global_dict[key] = averaged_value
            
        global_model.load_state_dict(next_global_dict)
        global_model = global_model.to(device)
        
        # Validate accuracy this round
        test_inputs, test_labels = zip(*test_data)
        # print("Type of test_inputs：", type(test_inputs))
        # print("Length of test_inputs：", len(test_inputs))
        # print("Type of test_inputs[0]：", type(test_inputs[0]))
        # print("Size of test_inputs[0]：", test_inputs[0].size())
        
        # print("Type of test_labels：", type(test_labels))
        # print("Length of test_labels：", len(test_labels))
        # print("Type of test_labels[0]：", type(test_labels[0]))
        test_inputs = torch.stack(test_inputs).to(device)
        test_labels = torch.Tensor(test_labels).to(device)

        global_model.eval()
        with torch.no_grad():
            predictions = global_model(test_inputs)

        predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()
        true_labels = test_labels.cpu().numpy()

        acc = accuracy_score(true_labels, predicted_labels)
        if rd%10 == 0 or acc >= acc_threshold:
            print("The clients in round {} are: {}".format(rd, clients))
            print("Accuracy of round {}:{}".format(rd, acc))
            print()
        accuracy.append(acc)
        if acc >= acc_threshold:
            break
        
    return accuracy

In [7]:
criterion = nn.CrossEntropyLoss()

def client_training(global_model, id, local_epochs, dataloader, lr):
    local_model = copy.deepcopy(global_model)
    local_model = local_model.to(device)
    optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
    
    for epoch in range(local_epochs):
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
    return local_model

# MNIST 2NN 

In [8]:
mlp = MLP()
acc_threshold_2nn = 0.97

In [9]:
# iid C = 0.1
mlp_iid1 = copy.deepcopy(mlp)
acc_mlp_iid1 = fedavg(mlp_iid1, C = 0.1, K = 100, E = 1, 
                      c_loader = iid_train_loader, rounds = 100, 
                      lr = 0.05, acc_threshold = acc_threshold_2nn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_mlp_iid1), acc_mlp_iid1[-1]))
print(acc_mlp_iid1)

The clients in round 0 are: [85, 22, 1, 60, 87, 52, 72, 65, 39, 83]
Accuracy of round 0:0.5578

The clients in round 10 are: [27, 1, 91, 96, 0, 86, 67, 78, 12, 24]
Accuracy of round 10:0.9047

The clients in round 20 are: [38, 53, 13, 12, 71, 61, 60, 43, 15, 14]
Accuracy of round 20:0.9303

The clients in round 30 are: [43, 90, 95, 40, 41, 4, 67, 18, 32, 77]
Accuracy of round 30:0.9422

The clients in round 40 are: [49, 58, 6, 12, 60, 99, 19, 2, 4, 76]
Accuracy of round 40:0.9517

The clients in round 50 are: [9, 76, 18, 26, 0, 84, 86, 93, 15, 95]
Accuracy of round 50:0.9569

The clients in round 60 are: [77, 88, 19, 89, 57, 51, 23, 98, 53, 55]
Accuracy of round 60:0.9606

The clients in round 70 are: [51, 94, 99, 35, 37, 56, 47, 72, 80, 17]
Accuracy of round 70:0.9652

The clients in round 80 are: [75, 90, 6, 49, 4, 29, 81, 10, 23, 46]
Accuracy of round 80:0.9673

The clients in round 90 are: [9, 1, 58, 63, 92, 56, 6, 52, 15, 10]
Accuracy of round 90:0.9686

The clients in round 93 ar

In [11]:
# iid C = 0.2
mlp_iid2 = copy.deepcopy(mlp)
acc_mlp_iid2 = fedavg(mlp_iid2, C = 0.2, K = 100, E = 1, 
                      c_loader = iid_train_loader, rounds = 100, 
                      lr = 0.05, acc_threshold = acc_threshold_2nn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_mlp_iid2), acc_mlp_iid2[-1]))
print(acc_mlp_iid2)

The clients in round 0 are: [91, 37, 99, 56, 65, 77, 59, 68, 81, 33, 34, 29, 2, 15, 78, 12, 22, 93, 53, 31]
Accuracy of round 0:0.6377

The clients in round 10 are: [2, 21, 95, 87, 71, 91, 10, 54, 76, 13, 79, 80, 58, 90, 19, 78, 77, 5, 32, 43]
Accuracy of round 10:0.91

The clients in round 20 are: [12, 24, 21, 56, 8, 54, 81, 50, 34, 32, 55, 96, 99, 45, 78, 41, 11, 39, 3, 63]
Accuracy of round 20:0.9305

The clients in round 30 are: [60, 10, 6, 68, 51, 33, 3, 82, 66, 12, 42, 45, 4, 19, 80, 36, 0, 48, 43, 20]
Accuracy of round 30:0.9411

The clients in round 40 are: [17, 63, 45, 79, 94, 36, 91, 42, 87, 15, 52, 98, 33, 20, 43, 80, 64, 68, 88, 18]
Accuracy of round 40:0.9523

The clients in round 50 are: [80, 45, 55, 35, 51, 98, 91, 36, 13, 61, 37, 15, 57, 19, 44, 31, 95, 23, 43, 63]
Accuracy of round 50:0.9581

The clients in round 60 are: [79, 63, 84, 73, 27, 76, 56, 4, 31, 62, 77, 40, 71, 93, 25, 0, 5, 7, 17, 86]
Accuracy of round 60:0.961

The clients in round 70 are: [71, 85, 99, 98,

In [12]:
# iid C = 0.5
mlp_iid3 = copy.deepcopy(mlp)
acc_mlp_iid3 = fedavg(mlp_iid3, C = 0.5, K = 100, E = 1, 
                      c_loader = iid_train_loader, rounds = 100, 
                      lr = 0.05, acc_threshold = acc_threshold_2nn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_mlp_iid3), acc_mlp_iid3[-1]))
print(acc_mlp_iid3)

The clients in round 0 are: [49, 51, 69, 61, 70, 54, 66, 9, 28, 18, 82, 98, 65, 40, 90, 92, 85, 19, 4, 99, 72, 14, 8, 67, 34, 79, 38, 45, 96, 88, 86, 59, 87, 44, 29, 64, 81, 46, 24, 50, 23, 21, 60, 32, 0, 15, 36, 63, 30, 7]
Accuracy of round 0:0.6285

The clients in round 10 are: [79, 22, 8, 13, 35, 12, 34, 99, 59, 73, 33, 74, 53, 5, 9, 18, 57, 7, 44, 42, 32, 52, 49, 2, 37, 30, 27, 11, 84, 31, 66, 90, 83, 40, 1, 92, 87, 75, 89, 95, 97, 45, 21, 58, 36, 41, 93, 24, 94, 19]
Accuracy of round 10:0.9083

The clients in round 20 are: [44, 95, 41, 25, 80, 12, 64, 63, 18, 98, 32, 97, 24, 81, 14, 79, 31, 45, 51, 13, 35, 49, 71, 82, 4, 42, 29, 50, 75, 53, 68, 34, 20, 33, 5, 69, 10, 21, 88, 26, 0, 78, 43, 17, 85, 61, 52, 58, 76, 11]
Accuracy of round 20:0.9306

The clients in round 30 are: [38, 56, 68, 42, 51, 85, 96, 95, 35, 89, 1, 27, 59, 99, 92, 54, 16, 22, 61, 24, 21, 80, 9, 34, 64, 75, 20, 23, 62, 15, 81, 31, 55, 72, 14, 6, 4, 44, 47, 5, 50, 69, 46, 74, 37, 8, 77, 36, 13, 41]
Accuracy of rou

In [13]:
# non-iid C = 0.1
mlp_noniid1 = copy.deepcopy(mlp)
acc_mlp_noniid1 = fedavg(mlp_noniid1, C = 0.1, K = 100, E = 1, 
                      c_loader = noniid_train_loader, rounds = 500, 
                      lr = 0.05, acc_threshold = acc_threshold_2nn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_mlp_noniid1), acc_mlp_noniid1[-1]))
print(acc_mlp_noniid1)

The clients in round 0 are: [39, 42, 67, 37, 35, 48, 62, 14, 89, 9]
Accuracy of round 0:0.2465

The clients in round 10 are: [16, 93, 90, 38, 4, 25, 67, 37, 33, 30]
Accuracy of round 10:0.642

The clients in round 20 are: [59, 86, 16, 89, 55, 48, 57, 52, 93, 27]
Accuracy of round 20:0.7354

The clients in round 30 are: [22, 50, 60, 63, 26, 57, 69, 49, 89, 38]
Accuracy of round 30:0.8228

The clients in round 40 are: [33, 76, 93, 28, 83, 66, 41, 70, 27, 26]
Accuracy of round 40:0.8635

The clients in round 50 are: [47, 75, 11, 40, 82, 49, 15, 65, 41, 6]
Accuracy of round 50:0.8722

The clients in round 60 are: [53, 27, 26, 86, 0, 90, 46, 94, 21, 39]
Accuracy of round 60:0.8773

The clients in round 70 are: [8, 2, 91, 52, 45, 25, 26, 42, 55, 13]
Accuracy of round 70:0.9056

The clients in round 80 are: [79, 64, 89, 81, 83, 53, 14, 43, 80, 28]
Accuracy of round 80:0.9088

The clients in round 90 are: [12, 85, 8, 71, 31, 21, 58, 24, 74, 46]
Accuracy of round 90:0.8813

The clients in round

In [None]:
# non-iid C = 0.2
mlp_noniid2 = copy.deepcopy(mlp)
acc_mlp_noniid2 = fedavg(mlp_noniid2, C = 0.1, K = 100, E = 1, 
                      c_loader = noniid_train_loader, rounds = 500, 
                      lr = 0.05, acc_threshold = acc_threshold_2nn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_mlp_noniid2), acc_mlp_noniid2[-1]))
print(acc_mlp_noniid2)

In [None]:
# non-iid C = 0.5
mlp_noniid3 = copy.deepcopy(mlp)
acc_mlp_noniid3 = fedavg(mlp_noniid3, C = 0.1, K = 100, E = 1, 
                      c_loader = noniid_train_loader, rounds = 500, 
                      lr = 0.05, acc_threshold = acc_threshold_2nn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_mlp_noniid3), acc_mlp_noniid3[-1]))
print(acc_mlp_noniid3)

# MNIST CNN

In [18]:
cnn = CNN()
acc_threshold_cnn = 0.99

In [19]:
# iid C = 0.1
cnn_iid1 = copy.deepcopy(cnn)
acc_cnn_iid1 = fedavg(cnn_iid1, C = 0.1, K = 100, E = 5, 
                      c_loader = iid_train_loader, rounds = 100, 
                      lr = 0.05, acc_threshold = acc_threshold_cnn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_cnn_iid1), acc_cnn_iid1[-1]))
print(acc_cnn_iid1)

The clients in round 0 are: [17, 27, 9, 44, 54, 21, 93, 24, 10, 49]
Accuracy of round 0:0.4846

The clients in round 10 are: [52, 4, 36, 75, 58, 88, 53, 91, 30, 16]
Accuracy of round 10:0.9647

The clients in round 20 are: [63, 97, 17, 9, 81, 66, 88, 74, 32, 25]
Accuracy of round 20:0.9745

The clients in round 30 are: [94, 5, 96, 68, 78, 42, 11, 20, 56, 40]
Accuracy of round 30:0.9789

The clients in round 40 are: [52, 7, 5, 24, 64, 15, 68, 19, 46, 43]
Accuracy of round 40:0.9811

The clients in round 50 are: [45, 44, 69, 52, 60, 26, 58, 83, 47, 76]
Accuracy of round 50:0.9832

The clients in round 60 are: [7, 86, 47, 60, 37, 28, 13, 57, 69, 21]
Accuracy of round 60:0.9841

The clients in round 70 are: [16, 82, 61, 70, 93, 57, 38, 98, 71, 45]
Accuracy of round 70:0.9862

The clients in round 80 are: [11, 76, 31, 44, 29, 23, 33, 43, 21, 74]
Accuracy of round 80:0.9866

The clients in round 90 are: [18, 23, 2, 7, 57, 21, 63, 84, 83, 82]
Accuracy of round 90:0.9869

100 rounds needed to 

In [21]:
# iid C = 0.2
cnn_iid2 = copy.deepcopy(cnn)
acc_cnn_iid2 = fedavg(cnn_iid2, C = 0.2, K = 100, E = 5, 
                      c_loader = iid_train_loader, rounds = 100, 
                      lr = 0.05, acc_threshold = acc_threshold_cnn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_cnn_iid2), acc_cnn_iid2[-1]))
print(acc_cnn_iid2)

The clients in round 0 are: [28, 52, 91, 78, 47, 18, 53, 97, 38, 46, 9, 20, 93, 82, 77, 66, 29, 87, 48, 5]
Accuracy of round 0:0.3455

The clients in round 10 are: [47, 86, 12, 34, 92, 81, 0, 96, 78, 90, 1, 37, 57, 58, 26, 7, 50, 70, 48, 14]
Accuracy of round 10:0.9657

The clients in round 20 are: [49, 64, 51, 25, 12, 24, 40, 0, 62, 90, 48, 53, 86, 50, 65, 55, 16, 57, 41, 34]
Accuracy of round 20:0.9765

The clients in round 30 are: [16, 51, 53, 18, 7, 99, 52, 40, 93, 91, 84, 75, 64, 82, 87, 61, 81, 70, 57, 13]
Accuracy of round 30:0.9797

The clients in round 40 are: [87, 7, 27, 16, 28, 51, 10, 34, 79, 71, 18, 52, 98, 4, 64, 21, 61, 0, 43, 42]
Accuracy of round 40:0.9819

The clients in round 50 are: [19, 9, 49, 97, 1, 67, 62, 29, 98, 7, 5, 84, 50, 11, 73, 79, 63, 64, 43, 86]
Accuracy of round 50:0.9836

The clients in round 60 are: [89, 72, 78, 38, 46, 2, 20, 15, 99, 33, 70, 41, 97, 68, 30, 12, 90, 59, 81, 60]
Accuracy of round 60:0.9855

The clients in round 70 are: [79, 97, 93, 72

In [22]:
# iid C = 0.5
cnn_iid3 = copy.deepcopy(cnn)
acc_cnn_iid3 = fedavg(cnn_iid3, C = 0.5, K = 100, E = 5, 
                      c_loader = iid_train_loader, rounds = 200, 
                      lr = 0.05, acc_threshold = acc_threshold_cnn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_cnn_iid3), acc_cnn_iid3[-1]))
print(acc_cnn_iid3)

The clients in round 0 are: [71, 34, 46, 65, 41, 72, 56, 96, 23, 6, 58, 38, 28, 95, 26, 61, 86, 94, 84, 5, 43, 74, 62, 45, 32, 88, 93, 83, 68, 67, 13, 11, 17, 75, 60, 98, 99, 4, 63, 12, 53, 73, 54, 82, 21, 81, 69, 52, 35, 16]
Accuracy of round 0:0.2744

The clients in round 10 are: [59, 31, 21, 45, 19, 18, 42, 97, 49, 70, 44, 89, 40, 50, 20, 56, 91, 16, 58, 43, 23, 76, 6, 4, 57, 33, 30, 15, 1, 60, 63, 29, 73, 82, 24, 36, 77, 65, 99, 3, 46, 84, 35, 41, 74, 81, 52, 68, 83, 88]
Accuracy of round 10:0.9648

The clients in round 20 are: [7, 52, 33, 76, 96, 17, 94, 16, 95, 29, 45, 60, 91, 70, 57, 89, 85, 47, 93, 36, 30, 82, 64, 1, 88, 27, 22, 51, 12, 69, 4, 9, 13, 80, 43, 44, 10, 26, 6, 39, 63, 56, 35, 55, 31, 81, 67, 41, 23, 75]
Accuracy of round 20:0.9757

The clients in round 30 are: [21, 85, 83, 16, 79, 89, 92, 0, 20, 13, 77, 11, 82, 87, 73, 71, 37, 66, 59, 51, 1, 17, 29, 72, 26, 61, 54, 76, 98, 32, 30, 67, 3, 80, 48, 18, 43, 7, 15, 4, 90, 23, 65, 27, 8, 41, 88, 62, 5, 69]
Accuracy of ro

In [None]:
# non-iid C = 0.1
cnn_noniid1 = copy.deepcopy(cnn)
acc_cnn_noniid1 = fedavg(cnn_noniid1, C = 0.1, K = 100, E = 5, 
                      c_loader = noniid_train_loader, rounds = 500, 
                      lr = 0.05, acc_threshold = acc_threshold_cnn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_cnn_noniid1), acc_cnn_noniid1[-1]))
print(acc_cnn_noniid1)

In [None]:
# non-iid C = 0.2
cnn_noniid2 = copy.deepcopy(cnn)
acc_cnn_noniid2 = fedavg(cnn_noniid2, C = 0.1, K = 100, E = 5, 
                      c_loader = noniid_train_loader, rounds = 500, 
                      lr = 0.05, acc_threshold = acc_threshold_cnn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_cnn_noniid2), acc_cnn_noniid2[-1]))
print(acc_cnn_noniid2)

In [None]:
# non-iid C = 0.5
cnn_noniid3 = copy.deepcopy(cnn)
acc_cnn_noniid3 = fedavg(cnn_noniid3, C = 0.1, K = 100, E = 5, 
                      c_loader = noniid_train_loader, rounds = 500, 
                      lr = 0.05, acc_threshold = acc_threshold_cnn)
print("{} rounds needed to achieve the accuracy of {}".format(len(acc_cnn_noniid3), acc_cnn_noniid3[-1]))
print(acc_cnn_noniid3)