In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision.transforms as transforms
from torchvision import datasets

In [2]:
#Classifier for MNIST 
class model(nn.Module):
    def __init__(self, args):
        super(model, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [3]:
# Virtual workers
class FedClient:
    def __init__(self, id, train_data, test_data, args, global_model):
        self.client_id = id
        self.train_data = train_data
        self.test_data = test_data
        self.model = model(args)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.5)
        self.global_model = global_model

    #train local model
    def train(self, epochs):
        self.model.train()
        for epoch in range(epochs):
            for batch_idx, (data, target) in enumerate(self.train_data):
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                self.optimizer.step()
                if batch_idx % 50 == 0: #only for print
                    print('Client: {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        self.client_id, epoch, batch_idx * len(data), len(self.train_data.dataset),
                        100. * batch_idx / len(self.train_data), loss.item()))
                    
    #evaluate local model
    def evaluate(self):
        self.model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_data:
                output = self.model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.test_data.dataset)
        accuracy = correct / len(self.test_data.dataset)
        print('Client: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            self.client_id, test_loss, correct, len(self.test_data.dataset),
            100. * correct / len(self.test_data.dataset)))
        return test_loss, accuracy
    
    #update the local model from the aggregator's feedback
    def update_to_local(self):
        self.model.load_state_dict(self.global_model.state_dict())

In [4]:
#central server
class FedAggregator:
    def __init__(self, num_clients,args):
        self.num_clients = num_clients
        self.clients = []
        self.global_model = model(args)
        self.transform = transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])
        self.global_loss = float('inf')
        
    def load_data(self):
        train_data = datasets.MNIST('../data', train=True, download=True,
                                    transform=self.transform)
        test_data = datasets.MNIST('../data', train=False, download=True,
                                   transform=self.transform)

        # split train_data across clients
        num_samples = len(train_data) // self.num_clients
        split_data = [num_samples] * self.num_clients
        split_data[-1] += len(train_data) - sum(split_data)

        train_data_splits = torch.utils.data.random_split(train_data, split_data)

        # create clients
        for i in range(self.num_clients):
            client_train_data = train_data_splits[i]
            client_train_loader = torch.utils.data.DataLoader(client_train_data, batch_size=64, shuffle=True)
            client_test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

            client = FedClient(i, client_train_loader, client_test_loader, args, self.global_model)
            self.clients.append(client)
            
    def run_rounds(self, num_rounds, epochs):
        for round in range(num_rounds):
            print("***********************************Round: ", round)
            client_models = []
            client_losses = []
            client_accs = []
            for client in self.clients:
                client.train(epochs)
                client_loss, client_acc = client.evaluate()
                client_losses.append(client_loss)
                client_accs.append(client_acc)
                client_model = client.model.state_dict()
                client_models.append(client_model)


            acc_loss_ratios = []
            alpha = 0.3 # top alpha% of the models wrt loss is chosen
            for i, client_model in enumerate(client_models):
                acc = client_accs[i]
                loss = client_losses[i]
                ratio = acc / loss
                acc_loss_ratios.append((ratio, i))
            acc_loss_ratios = sorted(acc_loss_ratios, reverse=True)  # sort in descending order
            ft_size = max(int(np.ceil(alpha * self.num_clients)), 1)  # ensure ft_size is at least 1

            # update global model using weighted average of models from above
            global_model_dict = {}
            total_weight = 0
            for _, i in acc_loss_ratios[:ft_size]:
                weight = 1 / client_losses[i]  # to use inverse loss as weight
                total_weight += weight
                client_model = client_models[i]
                for k in client_model.keys():
                    if k in global_model_dict:
                        global_model_dict[k] += weight * client_model[k]
                    else:
                        global_model_dict[k] = weight * client_model[k]
            for k in global_model_dict.keys():
                global_model_dict[k] /= total_weight
                self.global_model.state_dict()[k].copy_(global_model_dict[k])

            # update clients' local models with the global model
            for client in self.clients:
                client.update_to_local()

            # evaluate global model
            test_loss, test_acc = self.evaluate_global()
            print(f"Test loss: {test_loss}, Test accuracy: {test_acc}")
            

    def evaluate_global(self):
        self.global_model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in self.clients[0].test_data:
                output = self.global_model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.clients[0].test_data.dataset)
        accuracy = correct / len(self.clients[0].test_data.dataset)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.clients[0].test_data.dataset),
            100. * correct / len(self.clients[0].test_data.dataset)))
        return test_loss, accuracy



In [5]:
class Args:
    def __init__(self):
        self.num_channels = 1
        self.num_classes = 10

args = Args()


In [6]:
num_clients = 5
num_rounds = 10
epochs = 10

aggregator = FedAggregator(num_clients,args)
aggregator.load_data()

aggregator.run_rounds(num_rounds, epochs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw

***********************************Round:  0
Client: 0 Test set: Average loss: 0.1160, Accuracy: 9654/10000 (97%)

Client: 1 Test set: Average loss: 0.1268, Accuracy: 9618/10000 (96%)

Client: 2 Test set: Average loss: 0.1259, Accuracy: 9615/10000 (96%)

Client: 3 Test set: Average loss: 0.1317, Accuracy: 9603/10000 (96%)

Client: 4 Test set: Average loss: 0.1303, Accuracy: 9588/10000 (96%)

Test set: Average loss: 1.7370, Accuracy: 8064/10000 (81%)

Test loss: 1.7369786029815675, Test accuracy: 0.8064
***********************************Round:  1
Client: 0 Test set: Average loss: 0.0885, Accuracy: 9719/10000 (97%)

Client: 1 Test set: Average loss: 0.0896, Accuracy: 9710/10000 (97%)

Client: 2 Test set: Average loss: 0.0921, Accuracy: 9713/10000 (97%)

Client: 3 Test set: Average loss: 0.0902, Accuracy: 9707/10000 (97%)

Client: 4 Test set: Average loss: 0.0942, Accuracy: 9699/10000 (97%)

Test set: Average lo