In [None]:
import copy
from logging import Formatter, StreamHandler, getLogger
import os
import random
from typing import Dict, List

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
import torchvision
import torchvision.transforms as transforms

In [None]:
# Experiment hyper-parameters shared by every component below.
config = {
    "batch_size": 128,
    # Fall back to CPU so Restart & Run All also works on machines without a GPU.
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "epochs": 5,  # local training epochs per client, per round
    "learning_rate": 0.0001,
    "num_clients": 5,
    "rounds": 5,  # number of FedAvg communication rounds
    "seed": 42,
}

In [None]:
def get_logger(name="EXP"):
    """Return an INFO-level logger that writes tagged records to stderr.

    The logger is looked up by ``name``, so calling this again with the same
    name (e.g. re-running the cell) returns the same logger. A handler is only
    attached the first time, so messages are not printed multiple times.

    Args:
        name: Logger name; also shown as ``[name]`` in every log line.

    Returns:
        A ``logging.Logger`` at INFO level that does not propagate to the root
        logger.
    """
    logger = getLogger(name)
    if not logger.handlers:
        handler = StreamHandler()
        handler.setLevel("INFO")
        handler.setFormatter(
            Formatter(f"%(asctime)s [{name}][%(levelname)s] %(message)s "))
        logger.addHandler(handler)
    logger.setLevel("INFO")
    # Do not forward records to root-logger handlers (avoids double printing).
    logger.propagate = False
    return logger

def seed_everything(seed=1234):
    """Seed all random generators used here and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch on the CPU and
    on every CUDA device. Also sets ``PYTHONHASHSEED``; that variable only
    affects Python interpreters started afterwards, because the running
    kernel's hash seed is fixed at startup.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all visible GPUs; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick kernels by timing them, and that choice
    # can vary between runs. It must be off for reproducible results.
    torch.backends.cudnn.benchmark = False

In [None]:
class FedAvgRetriever():
    """Loads CIFAR-10 and partitions it into per-client DataLoaders.

    ``config`` must contain ``"batch_size"``. ``"num_workers"`` is optional
    and defaults to 4.
    """

    def __init__(self, config: Dict) -> None:
        self.config = config

        # init CIFAR-10: ToTensor yields values in [0, 1]; Normalize with
        # mean 0.5 / std 0.5 maps each channel to [-1, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # download=True fetches the dataset into ./data if it is not there yet.
        self.train_set = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform)
        self.test_set = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=transform)

    def get(self, num_clients: int) -> List[Dict[str, DataLoader]]:
        """Split the train and test sets into ``num_clients`` contiguous shards.

        Shard ``i`` covers indices ``[n*i//k, n*(i+1)//k)``, where ``n`` is the
        dataset length and ``k`` is ``num_clients``. Every sample therefore
        goes to exactly one client, and shard sizes differ by at most one.
        When ``k`` divides ``n`` this is identical to equal-size slicing.

        Args:
            num_clients: Number of clients (shards); must be >= 1.

        Returns:
            One ``{"train": DataLoader, "test": DataLoader}`` dict per client.
            Train loaders shuffle; test loaders do not.

        Raises:
            ValueError: If ``num_clients`` < 1.
        """
        if num_clients < 1:
            raise ValueError(f"num_clients must be >= 1, got {num_clients}")

        batch_size = self.config["batch_size"]
        num_workers = self.config.get("num_workers", 4)

        def _bounds(n: int) -> List[int]:
            # k + 1 cut points. Integer division spreads the remainder across
            # shards instead of dropping the trailing n % k samples.
            return [n * k // num_clients for k in range(num_clients + 1)]

        train_bounds = _bounds(len(self.train_set))
        test_bounds = _bounds(len(self.test_set))

        dataloaders = []
        for i in range(num_clients):
            train_subset = Subset(
                self.train_set, range(train_bounds[i], train_bounds[i + 1]))
            test_subset = Subset(
                self.test_set, range(test_bounds[i], test_bounds[i + 1]))

            dataloaders.append({
                "train": DataLoader(train_subset, batch_size=batch_size,
                                    shuffle=True, num_workers=num_workers),
                "test": DataLoader(test_subset, batch_size=batch_size,
                                   shuffle=False, num_workers=num_workers),
            })

        return dataloaders

In [None]:
class CNNModel(nn.Module):
    """Small CNN that maps 3-channel images to 10 class logits.

    conv(3->16, k=3) -> ReLU -> max-pool(2) -> conv(16->128, k=3) -> ReLU
    -> global average pool over H and W -> fc(128->64) -> ReLU -> fc(64->10).

    The attribute names (conv1, conv2, fc1, fc2) determine the ``state_dict``
    keys, and other code addresses layers by these names, so keep them stable.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)

        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv2 = nn.Conv2d(16, 128, 3)

        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10) for input of shape (batch, 3, H, W)."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.relu(self.conv2(x))

        # Global average pooling: mean over both spatial dims. Reshaping to
        # (B, C, H**2) would only be correct for square feature maps.
        x = x.mean(dim=(2, 3))

        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
class FedAvgClient():
    """One federated client: owns a data shard and trains a local copy of
    the global model on it.

    Constructor args:
        config: Must provide ``"device"``, ``"learning_rate"`` and ``"epochs"``.
        logger: Logger that receives the per-epoch metrics line.
        dataloaders: Mapping with ``"train"`` and ``"test"`` DataLoaders that
            yield ``(inputs, class_index_labels)`` batches.
    """

    def __init__(self, config: Dict, logger, dataloaders: Dict[str, DataLoader]) -> None:
        self.config = config
        # Keep the injected logger instead of relying on a notebook global.
        self.logger = logger
        self.train_loader = dataloaders["train"]
        self.test_loader = dataloaders["test"]

        self.model: "CNNModel" = None  # type: ignore

    def set_global_model(self, global_model: "CNNModel"):
        """Replace the local model with an independent deep copy of ``global_model``."""
        del self.model
        self.model = copy.deepcopy(global_model)

    def train_local_model(self) -> "CNNModel":
        """Train the local model with Adam + cross-entropy.

        Runs ``config["epochs"]`` epochs. After each epoch it evaluates on the
        client's test loader and logs train loss, test loss and accuracy.

        Returns:
            The trained model (the same object held in ``self.model``).

        Raises:
            RuntimeError: If ``set_global_model`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError(
                "set_global_model() must be called before train_local_model()")

        device = self.config["device"]
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(),
                               lr=self.config["learning_rate"])

        def _train_loop():
            # One pass over the training shard; returns the mean batch loss.
            running_loss = 0
            for X, y in self.train_loader:
                X = X.to(device)
                y = y.to(device)
                optimizer.zero_grad()
                outputs = self.model(X)
                loss = criterion(outputs, y)
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
            return running_loss / len(self.train_loader)

        def _test_loop():
            # Returns (mean batch loss, accuracy) on the test shard.
            labels = []
            preds = []
            running_loss = 0
            with torch.no_grad():
                for X, y in self.test_loader:
                    X = X.to(device)
                    y = y.to(device)
                    output = self.model(X)
                    loss = criterion(output, y)
                    running_loss += loss.item()
                    labels.append(y.detach().cpu())
                    preds.append(output.detach().cpu())
            labels = torch.cat(labels, dim=0)
            preds = torch.cat(preds, dim=0).argmax(1)
            acc = torch.sum(preds == labels) / len(labels)
            return running_loss / len(self.test_loader), acc.item()

        self.model.to(device)
        for _ in range(self.config["epochs"]):
            self.model.train()
            train_loss = _train_loop()
            self.model.eval()
            test_loss, acc = _test_loop()

            self.logger.info(f"train_loss: {train_loss}, test_loss: {test_loss}, acc: {acc}")

        return self.model


class FedAvgServer():
    """Central server: creates the initial global model and aggregates the
    clients' models by (weighted) parameter averaging (FedAvg).

    Constructor args:
        config: Must provide ``"device"``, where aggregated models are placed.
    """

    def __init__(self, config: Dict) -> None:
        self.config = config

    def init_global_model(self) -> "CNNModel":
        """Return a freshly initialised global model."""
        return CNNModel()

    def aggregate(self, local_models: List["CNNModel"], weights=None) -> "CNNModel":
        """Build a new global model whose parameters average the clients' ones.

        Every floating-point ``state_dict`` entry is averaged, biases as well
        as weights. Skipping biases would leave them at the new model's random
        initialisation every round.

        Args:
            local_models: Trained client models sharing ``CNNModel``'s
                architecture.
            weights: Optional per-model weights, e.g. local sample counts as
                in FedAvg. ``None`` means a uniform average.

        Returns:
            A new model on ``config["device"]``.

        Raises:
            ValueError: If ``local_models`` is empty, ``weights`` has the wrong
                length, or the weights do not sum to a positive value.
        """
        if not local_models:
            raise ValueError("aggregate() needs at least one local model")
        if weights is None:
            weights = [1.0] * len(local_models)
        if len(weights) != len(local_models):
            raise ValueError(
                f"got {len(weights)} weights for {len(local_models)} models")
        total = float(sum(weights))
        if total <= 0:
            raise ValueError("weights must sum to a positive value")

        global_model = CNNModel().to(self.config["device"])
        local_states = [model.state_dict() for model in local_models]

        averaged = {}
        for key, template in global_model.state_dict().items():
            if not torch.is_floating_point(template):
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot
                # be averaged meaningfully; copy the first client's value.
                averaged[key] = local_states[0][key].clone().to(template.device)
                continue
            acc = torch.zeros_like(template)
            for weight, state in zip(weights, local_states):
                acc += state[key].to(device=acc.device, dtype=acc.dtype) * weight
            acc /= total
            averaged[key] = acc

        global_model.load_state_dict(averaged)
        return global_model

In [None]:
# Run the FedAvg experiment: each round, every client trains a copy of the
# global model on its shard, then the server averages the client models.
seed_everything(config["seed"])
logger = get_logger()

logger.info("FedAvg start")

# init dataset: one train/test shard per client
retriever = FedAvgRetriever(config)
dataloaders = retriever.get(config["num_clients"])

# init server and clients
server = FedAvgServer(config)
clients = [FedAvgClient(config, logger, client_loaders)
           for client_loaders in dataloaders]

# init FedAvg: every client starts from the same global model
global_model = server.init_global_model()
for client in clients:
    client.set_global_model(global_model)

# round loop (``round_idx`` avoids shadowing the built-in ``round``)
for round_idx in range(config["rounds"]):
    logger.info(f"Round: {round_idx} start")

    local_models = []
    for i, client in enumerate(clients):
        logger.info(f"train client: {i}")
        local_models.append(client.train_local_model())

    logger.info("Aggregate global model")
    global_model = server.aggregate(local_models)
    logger.info("Distribute global model")
    for client in clients:
        client.set_global_model(global_model)

logger.info("FedAvg finished")