In [1]:
!pip install -U pip
!pip install "pybind11[global]"
!pip install git+https://github.com/Koukyosyumei/AIJack@sparse_matrix

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pip
  Downloading pip-22.3.1-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 7.7 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.3.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pybind11[global]
  Downloading pybind11-2.10.2-py3-none-any.whl (222 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m222.1/222.1 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pybind11-global==2.10.2
  Downloading pybind11_global-2.10.2-py3-none-any.whl (400 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m400.9/400.9 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected p

# FedMD: Federated Learning with Model Distillation

## Single Process

In [2]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedMDClient, FedMDServer
from aijack.collaborative.fedmd import FedMDAPI
from aijack.utils import NumpyDataset

In [3]:
def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=transform,
                return_idx=True,
            ),
            batch_size=training_batch_size,
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(
            NumpyDataset(
                x=dataset.data.numpy(),
                y=dataset.targets.numpy(),
                transform=transform,
                return_idx=True,
            ),
            batch_size=test_batch_size,
        )
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

In [4]:
training_batch_size = 64
test_batch_size = 64
num_rounds = 5
lr = 0.001
seed = 0
client_size = 2
criterion = F.nll_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fix_seed(seed)

In [5]:
dataloaders = [prepare_dataloader(client_size + 1, c) for c in range(client_size + 1)]
public_dataloader = dataloaders[0]
local_dataloaders = dataloaders[1:]
test_dataloader = prepare_dataloader(client_size, -1, train=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw



In [6]:
clients = [
    FedMDClient(Net().to(device), public_dataloader, output_dim=10, user_id=c)
    for c in range(client_size)
]
local_optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

server = FedMDServer(clients, Net().to(device))

In [7]:
api = FedMDAPI(
    server,
    clients,
    public_dataloader,
    local_dataloaders,
    F.nll_loss,
    local_optimizers,
    test_dataloader,
    num_communication=2,
)
log = api.run()

epoch 1 (public - pretrain): [1.473225961858853, 1.5095995597945997]
acc on validation dataset:  {'clients_score': [0.7988, 0.7907]}
epoch 1 (local - pretrain): [0.831909927316367, 0.8403522956866426]
acc on validation dataset:  {'clients_score': [0.8431, 0.8406]}
epoch 1, client 0: 248.21629628539085
epoch 1, client 1: 269.46991488337517
epoch=1 acc on local datasets:  {'clients_score': [0.84605, 0.85175]}
epoch=1 acc on public dataset:  {'clients_score': [0.84925, 0.8516]}
epoch=1 acc on validation dataset:  {'clients_score': [0.8568, 0.8594]}
epoch 2, client 0: 348.2690239548683
epoch 2, client 1: 364.190059453249
epoch=2 acc on local datasets:  {'clients_score': [0.85075, 0.85555]}
epoch=2 acc on public dataset:  {'clients_score': [0.85395, 0.8567]}
epoch=2 acc on validation dataset:  {'clients_score': [0.8601, 0.8641]}


## MPI

In [8]:
%%writefile mpi_fedmd.py
import random
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mpi4py import MPI
from torchvision import datasets, transforms

from aijack.collaborative import FedMDClient, FedMDServer
from aijack.collaborative.fedmd import FedMDAPI
from aijack.collaborative.fedmd.api import MPIFedMDAPI
from aijack.collaborative.fedmd.client import MPIFedMDClient
from aijack.collaborative.fedmd.server import MPIFedMDServer
from aijack.utils import NumpyDataset, accuracy_torch_dataloader

logger = getLogger(__name__)

training_batch_size = 64
test_batch_size = 64
num_rounds = 2
lr = 0.001
seed = 0


def fix_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def prepare_dataloader(num_clients, myid, train=True, path=""):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    if train:
        dataset = datasets.MNIST(path, train=True, download=True, transform=transform)
        idxs = list(range(len(dataset.data)))
        random.shuffle(idxs)
        idx = np.array_split(idxs, num_clients, 0)[myid - 1]
        dataset.data = dataset.data[idx]
        dataset.targets = dataset.targets[idx]
        train_loader = torch.utils.data.DataLoader(
            NumpyDataset(x=dataset.data.numpy(), y=dataset.targets.numpy(), transform=transform, return_idx=True),
             batch_size=training_batch_size
        )
        return train_loader
    else:
        dataset = datasets.MNIST(path, train=False, download=True, transform=transform)
        test_loader = torch.utils.data.DataLoader(NumpyDataset(x=dataset.data.numpy(), y=dataset.targets.numpy(), transform=transform, return_idx=True),
                                                  batch_size=test_batch_size)
        return test_loader


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.ln = nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.ln(x.reshape(-1, 28 * 28))
        output = F.log_softmax(x, dim=1)
        return output

def main():
    fix_seed(seed)

    comm = MPI.COMM_WORLD
    myid = comm.Get_rank()
    size = comm.Get_size()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Net()
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    public_dataloader = prepare_dataloader(size - 1, 0, train=True)

    if myid == 0:
        dataloader = prepare_dataloader(size + 1, myid+1, train=False)
        client_ids = list(range(1, size))
        server = MPIFedMDServer(comm, FedMDServer([1, 2], model))
        api = MPIFedMDAPI(
            comm,
            server,
            True,
            F.nll_loss,
            None,
            None,
            num_communication=num_rounds,
            device=device
        )
    else:
        dataloader = prepare_dataloader(size + 1, myid + 1, train=True)
        client = MPIFedMDClient(comm, FedMDClient(model, public_dataloader, output_dim=10, user_id=myid))
        api = MPIFedMDAPI(
            comm,
            client,
            False,
            F.nll_loss,
            optimizer,
            dataloader,
            public_dataloader,
            num_communication=num_rounds,
            device=device
        )

    api.run()

    if myid != 0:
      print(f"client_id={myid}: Accuracy on local dataset is ", accuracy_torch_dataloader(client, dataloader))


if __name__ == "__main__":
    main()

Writing mpi_fedmd.py


In [9]:
!mpiexec -np 3 --allow-run-as-root python /content/mpi_fedmd.py

client_id=1: Accuracy on local dataset is  0.8734666666666666
client_id=2: Accuracy on local dataset is  0.8708
