# FedAVG

In this tutorial, you will learn how to simulate FedAVG, representative scheme of Federated Learning, with AIJack. You can choose the single process or MPI as the backend.

## Single Process

In [2]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAvgClient, FedAvgServer
from aijack.collaborative.fedavg import FedAVGAPI

In [3]:
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

In [4]:
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 2
criterion = F.nll_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)

In [None]:
local_dataloaders = [prepare_dataloader(client_size, c) for c in range(client_size)]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

In [6]:
clients = [FedAvgClient(Net().to(device), user_id=c) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAvgServer(clients, Net().to(device))

for round in range(1, num_rounds + 1):
    for client, local_trainloader, local_optimizer in zip(
        clients, local_dataloaders, local_optimizers
    ):
        for data in local_trainloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            local_optimizer.zero_grad()
            outputs = client(inputs)
            loss = criterion(outputs, labels.to(torch.int64))
            client.backward(loss)
            local_optimizer.step()
    server.action()

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = server(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_dataloader.dataset)
    accuracy = 100.0 * correct / len(test_dataloader.dataset)
    print(f"Round: {round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

Round: 1, Test set: Average loss: 0.7824367607116699, Accuracy: 83.71
Round: 2, Test set: Average loss: 0.5854546638488769, Accuracy: 86.49
Round: 3, Test set: Average loss: 0.5077689335346222, Accuracy: 87.54
Round: 4, Test set: Average loss: 0.4647755696773529, Accuracy: 88.25
Round: 5, Test set: Average loss: 0.4369198709487915, Accuracy: 88.63


### Federated Learning with Paillier Encryption

In [None]:
from aijack.defense import PaillierGradientClientManager, PaillierKeyGenerator

keygenerator = PaillierKeyGenerator(64)
pk, sk = keygenerator.generate_keypair()

manager = PaillierGradientClientManager(pk, sk)
PaillierGradFedAvgClient = manager.attach(FedAvgClient)

clients = [
    PaillierGradFedAvgClient(Net().to(device), user_id=c, server_side_update=False)
    for c in range(client_size)
]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAvgServer(clients, Net().to(device), server_side_update=False)

for round in range(1, num_rounds + 1):
    for client, local_trainloader, local_optimizer in zip(
        clients, local_dataloaders, local_optimizers
    ):
        for data in local_trainloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            local_optimizer.zero_grad()
            outputs = client(inputs)
            loss = criterion(outputs, labels.to(torch.int64))
            client.backward(loss)
            local_optimizer.step()
    server.action()

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = clients[0](data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_dataloader.dataset)
    accuracy = 100.0 * correct / len(test_dataloader.dataset)
    print(f"Round: {round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

In [8]:
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_dataloader:
        data, target = data.to(device), target.to(device)
        output = clients[0](data)
        test_loss += F.nll_loss(
            output, target, reduction="sum"
        ).item()  # sum up batch loss
        pred = output.argmax(
            dim=1, keepdim=True
        )  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_dataloader.dataset)
accuracy = 100.0 * correct / len(test_dataloader.dataset)
print(f"Round: {round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}")

Round: 5, Test set: Average loss: 0.3705228189945221, Accuracy: 89.22


## MPI

In [None]:
%%writefile mpi_fedavg.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAvgClient, FedAvgServer
from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAvgClient, MPIFedAvgServer

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


def evaluate_gloal_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party.server(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPIFedAvgServer(comm, FedAvgServer([1, 2], model))
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAvgClient(comm, FedAvgClient(model, user_id=myid))
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device
        )

    api.run()


if __name__ == "__main__":
    main()

In [10]:
!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg.py

Round: 1, Test set: Average loss: 0.7860309104919434, Accuracy: 82.72
Round: 2, Test set: Average loss: 0.5885528886795044, Accuracy: 86.04
Round: 3, Test set: Average loss: 0.5102099328994751, Accuracy: 87.33
Round: 4, Test set: Average loss: 0.4666414333820343, Accuracy: 88.01
Round: 5, Test set: Average loss: 0.4383064950466156, Accuracy: 88.65


### MPI + Sparse Gradient

In [None]:
%%writefile mpi_fedavg_sparse.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAvgClient, FedAvgServer
from aijack.collaborative.fedavg import MPIFedAVGAPI, MPIFedAvgClient, MPIFedAvgServer
from aijack.defense.sparse import (
    SparseGradientClientManager,
    SparseGradientServerManager,
)

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


def evaluate_gloal_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party.server(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    client_manager = SparseGradientClientManager(k=0.03)
    SparseGradientFedAvgClient = client_manager.attach(FedAvgClient)

    server_manager = SparseGradientServerManager()
    SparseGradientFedAvgServer = server_manager.attach(FedAvgServer)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPIFedAvgServer(comm, SparseGradientFedAvgServer([1, 2], model))
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device,
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAvgClient(comm, SparseGradientFedAvgClient(model, user_id=myid))
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device,
        )

    api.run()


if __name__ == "__main__":
    main()

In [12]:
!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedavg_sparse.py

Round: 1, Test set: Average loss: 1.7728474597930908, Accuracy: 38.47
Round: 2, Test set: Average loss: 1.4043720769882202, Accuracy: 60.5
Round: 3, Test set: Average loss: 1.1684634439468384, Accuracy: 70.27
Round: 4, Test set: Average loss: 1.0258800836563111, Accuracy: 75.0
Round: 5, Test set: Average loss: 0.9197616576194764, Accuracy: 77.6
