In [1]:
%%shell

apt install -y libboost-all-dev
pip install -U pip
pip install "pybind11[global]"

pip install git+https://github.com/Koukyosyumei/AIJack@dba

Reading package lists... Done
Building dependency tree       
Reading state information... Done
libboost-all-dev is already the newest version (1.65.1.0ubuntu1).
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 20 not upgraded.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-22.3.1-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 6.5 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.3.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pybind11[global]
  Downloading pybind11-2.10.1-py3-none-any.whl (216 kB)
[2K     [90



# FedAVG

In this tutorial, you will learn how to simulate FedAVG, representative scheme of Federated Learning, with AIJack. You can choose the single process or MPI as the backend.

## Single Process

In [2]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative.fedavg import FedAVGClient, FedAVGServer, FedAVGAPI


def evaluate_gloal_model(dataloader, client_id=-1):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                if client_id == -1:
                  output = api.server(data)
                else:
                  output = api.clients[client_id](data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model

In [3]:
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 2
criterion = F.nll_loss

In [4]:
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)

In [6]:
local_dataloaders = [prepare_dataloader(client_size, c) for c in range(client_size)]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw



In [7]:
clients = [FedAVGClient(Net().to(device), user_id=c) for c in range(client_size)]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAVGServer(clients, Net().to(device))

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_gloal_model(test_dataloader)
)
api.run()

communication 0, epoch 0: client-1 0.019623182545105616
communication 0, epoch 0: client-2 0.019723439192771912
Test set: Average loss: 0.7824367607116699, Accuracy: 83.71
communication 1, epoch 0: client-1 0.010717547312378884
communication 1, epoch 0: client-2 0.01085114210943381
Test set: Average loss: 0.5854546638488769, Accuracy: 86.49
communication 2, epoch 0: client-1 0.008766427417596182
communication 2, epoch 0: client-2 0.008916550938288371
Test set: Average loss: 0.5077689335346222, Accuracy: 87.54
communication 3, epoch 0: client-1 0.007839484986662865
communication 3, epoch 0: client-2 0.007999675015608469
Test set: Average loss: 0.4647755696773529, Accuracy: 88.25
communication 4, epoch 0: client-1 0.00727825770676136
communication 4, epoch 0: client-2 0.007445397703349591
Test set: Average loss: 0.4369198709487915, Accuracy: 88.63


### Federated Learning with Paillier Encryption

In [8]:
from aijack.defense import PaillierGradientClientManager, PaillierKeyGenerator

keygenerator = PaillierKeyGenerator(64)
pk, sk = keygenerator.generate_keypair()

manager = PaillierGradientClientManager(pk, sk)
PaillierGradFedAVGClient = manager.attach(FedAVGClient)

clients = [
    PaillierGradFedAVGClient(Net().to(device), user_id=c, server_side_update=False)
    for c in range(client_size)
]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedAVGServer(clients, Net().to(device), server_side_update=False)

api = FedAVGAPI(
    server,
    clients,
    criterion,
    local_optimizers,
    local_dataloaders,
    num_communication=num_rounds,
    custom_action=evaluate_gloal_model(test_dataloader, 0)
)
api.run()

communication 0, epoch 0: client-1 0.01997546571890513
communication 0, epoch 0: client-2 0.020125101908047994


  input._paillier_np_array + other.detach().cpu().numpy()


Test set: Average loss: 0.5059196502208709, Accuracy: 84.52
communication 1, epoch 0: client-1 0.007643952090044816
communication 1, epoch 0: client-2 0.007840833148360253
Test set: Average loss: 0.44262871532440184, Accuracy: 87.33
communication 2, epoch 0: client-1 0.006744246105353038
communication 2, epoch 0: client-2 0.006942570747931798
Test set: Average loss: 0.40395034172534944, Accuracy: 88.34
communication 3, epoch 0: client-1 0.006300356099506219
communication 3, epoch 0: client-2 0.006500222749014696
Test set: Average loss: 0.3897844295024872, Accuracy: 89.0
communication 4, epoch 0: client-1 0.0060082643752296765
communication 4, epoch 0: client-2 0.006209123346706232
Test set: Average loss: 0.3705228189945221, Accuracy: 89.22


## MPI

In [9]:
%%writefile mpi_FedAVG.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClientManager, MPIFedAVGServerManager

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


def evaluate_gloal_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    mpi_client_manager = MPIFedAVGClientManager()
    mpi_server_manager = MPIFedAVGServerManager()
    MPIFedAVGClient = mpi_client_manager.attach(FedAVGClient)
    MPIFedAVGServer = mpi_server_manager.attach(FedAVGServer)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPIFedAVGServer(comm, [1, 2], model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPIFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device
        )

    api.run()


if __name__ == "__main__":
    main()

Writing mpi_FedAVG.py


In [10]:
!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG.py

communication 0, epoch 0: client-2 0.02008056694070498
communication 0, epoch 0: client-3 0.019996537216504413
Round: 1, Test set: Average loss: 0.7860309104919434, Accuracy: 82.72
communication 1, epoch 0: client-3 0.010822976715366046
communication 1, epoch 0: client-2 0.010937693453828494
Round: 2, Test set: Average loss: 0.5885528886795044, Accuracy: 86.04
communication 2, epoch 0: client-2 0.008990796900788942
communication 2, epoch 0: client-3 0.008850129560629527
Round: 3, Test set: Average loss: 0.5102099328994751, Accuracy: 87.33
communication 3, epoch 0: client-2 0.008069112183650334
communication 3, epoch 0: client-3 0.00791173183619976
Round: 4, Test set: Average loss: 0.4666414333820343, Accuracy: 88.01
communication 4, epoch 0: client-3 0.007343090359369914
communication 4, epoch 0: client-2 0.007512268128991127
Round: 5, Test set: Average loss: 0.4383064950466156, Accuracy: 88.65


### MPI + Sparse Gradient

In [11]:
%%writefile mpi_FedAVG_sparse.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedAVGClient, FedAVGServer, MPIFedAVGAPI, MPIFedAVGClientManager, MPIFedAVGServerManager
from aijack.defense.sparse import (
    SparseGradientClientManager,
    SparseGradientServerManager,
)

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=False, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            dataset, batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=False, transform=transform)
        test_loader = torch.utils.data.DataLoader(dataset, batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output


def evaluate_gloal_model(dataloader):
    def _evaluate_global_model(api):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(api.device), target.to(api.device)
                output = api.party(data)
                test_loss += F.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.argmax(
                    dim=1, keepdim=True
                )  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader.dataset)
        accuracy = 100.0 * correct / len(dataloader.dataset)
        print(
            f"Round: {api.party.round}, Test set: Average loss: {test_loss}, Accuracy: {accuracy}"
        )

    return _evaluate_global_model


def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    sg_client_manager = SparseGradientClientManager(k=0.03)
    mpi_client_manager = MPIFedAVGClientManager()
    SparseGradientFedAVGClient = sg_client_manager.attach(FedAVGClient)
    MPISparseGradientFedAVGClient = mpi_client_manager.attach(SparseGradientFedAVGClient)

    sg_server_manager = SparseGradientServerManager()
    mpi_server_manager = MPIFedAVGServerManager()
    SparseGradientFedAVGServer = sg_server_manager.attach(FedAVGServer)
    MPISparseGradientFedAVGServer = mpi_server_manager.attach(SparseGradientFedAVGServer)

    if myid == 0:
        dataloader = prepare_dataloader(size - 1, myid, train=False)
        client_ids = list(range(1, size))
        server = MPISparseGradientFedAVGServer(comm, [1, 2], model)
        api = MPIFedAVGAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_rounds,
            1,
            custom_action=evaluate_gloal_model(dataloader),
            device=device,
        )
    else:
        dataloader = prepare_dataloader(size - 1, myid, train=True)
        client = MPISparseGradientFedAVGClient(comm, model, user_id=myid)
        api = MPIFedAVGAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            num_rounds,
            1,
            device=device,
        )

    api.run()


if __name__ == "__main__":
    main()

Writing mpi_FedAVG_sparse.py


In [12]:
!mpiexec -np 3 --allow-run-as-root python /content/mpi_FedAVG_sparse.py

communication 0, epoch 0: client-2 0.02008056694070498
communication 0, epoch 0: client-3 0.019996537216504413
Round: 1, Test set: Average loss: 1.7728474597930908, Accuracy: 38.47
communication 1, epoch 0: client-2 0.016343721010287603
communication 1, epoch 0: client-3 0.016255500958363214
Round: 2, Test set: Average loss: 1.4043720769882202, Accuracy: 60.5
communication 2, epoch 0: client-3 0.014260987114906311
communication 2, epoch 0: client-2 0.014353630113601685
Round: 3, Test set: Average loss: 1.1684634439468384, Accuracy: 70.27
communication 3, epoch 0: client-2 0.013123111790418624
communication 3, epoch 0: client-3 0.013032549581925075
Round: 4, Test set: Average loss: 1.0258800836563111, Accuracy: 75.0
communication 4, epoch 0: client-3 0.012150899289051692
communication 4, epoch 0: client-2 0.012242827371756236
Round: 5, Test set: Average loss: 0.9197616576194764, Accuracy: 77.6
