# FedAVG

In this tutorial, you will learn how to simulate FedAVG, representative scheme of Federated Learning, with AIJack. You can choose the single process or MPI as the backend.

## Single Process

In [2]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAvgClient, FedAvgServer
from aijack.collaborative.fedavg import FedAVGAPI

In [3]:
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

In [4]:
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 2
criterion = F.nll_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)

In [5]:
clients = [FedAvgClient(Net().to(device), user_id=c) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]
local_dataloaders = [prepare_dataloader(client_size, c) for c in range(client_size)]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

server = FedAvgServer(clients, Net().to(device))

for round in range(1, num_rounds + 1):
  for client, local_trainloader, local_optimizer in zip(clients, local_dataloaders, local_optimizers):
      for data in local_trainloader:
          inputs, labels = data
          inputs = inputs.to(device)
          labels = labels.to(device)
          local_optimizer.zero_grad()
          outputs = client(inputs)
          loss = criterion(outputs, labels.to(torch.int64))
          client.backward(loss)
          local_optimizer.step()
  server.action()


  test_loss = 0
  correct = 0
  with torch.no_grad():
      for data, target in test_dataloader:
          data, target = data.to(device), target.to(device)
          output = server(data)
          test_loss += F.nll_loss(
              output, target, reduction="sum"
          ).item()  # sum up batch loss
          pred = output.argmax(
              dim=1, keepdim=True
          )  # get the index of the max log-probability
          correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_dataloader.dataset)
  accuracy = 100.0 * correct / len(test_dataloader.dataset)
  print(
      f"Round: {round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
  )

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw

Round: 1, Test set: Average loss: 0.7824367607116699, Accuracy: 83.71
Round: 2, Test set: Average loss: 0.5854546638488769, Accuracy: 86.49
Round: 3, Test set: Average loss: 0.5077689335346222, Accuracy: 87.54
Round: 4, Test set: Average loss: 0.4647755696773529, Accuracy: 88.25
Round: 5, Test set: Average loss: 0.4369198709487915, Accuracy: 88.63


## MPI

In [6]:
%%writefile mpi_fedavg.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAVGClient, MPIFedAVGServer

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


def evaluate_gloal_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.party.device), target.to(api.party.device)
                output = api.party.server_model(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPIFedAVGServer(comm, model, myid, client_ids, myid, lr, "sgd", device=device)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAVGClient(comm, model, myid, lr, device=device)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device
        )

    t1 = MPI.Wtime()
    api.run()
    t2 = MPI.Wtime()

    t0 = np.ndarray(1, dtype="float64")
    t_w = np.ndarray(1, dtype="float64")
    t0[0] = t2 - t1
    comm.Reduce(t0, t_w, op=MPI.MAX, root=0)
    if myid == 0:
        print("Execution time = : ", t_w[0], "  [sec.] \n")


if __name__ == "__main__":
    main()

Writing mpi_fedavg.py


In [7]:
!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg.py

Round: 1, Test set: Average loss: 0.7860309120178223, Accuracy: 82.72
Round: 2, Test set: Average loss: 0.5885528892040253, Accuracy: 86.04
Round: 3, Test set: Average loss: 0.5102099300861359, Accuracy: 87.33
Round: 4, Test set: Average loss: 0.46664143724441526, Accuracy: 88.01
Round: 5, Test set: Average loss: 0.43830649552345274, Accuracy: 88.65
Execution time = :  87.10973795500001   [sec.] 

