In [42]:
import typing

import torch
from components.metrics import RunMetrics, RunResult
from components.tensor_types import IndexVector
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

data_path = "../../datasets"

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    print(f"Using accelerator '{device}'")

    if device.type == "cuda":
        torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")
    print("WARN: No accelerator found, running on CPU")


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalize by mean and standard devia,
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=transform,
)

test_loader = DataLoader(
    datasets.MNIST(data_path, train=False, download=False, transform=transform),
    shuffle=False,
    drop_last=False,
    batch_size=10000,
    generator=torch.Generator(),
)

Using accelerator 'mps'


In [43]:
class MnistCnn(torch.nn.Module):
    def __init__(self):
        super(MnistCnn, self).__init__()

        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.dropout1 = torch.nn.Dropout(p=0.25)
        self.dropout2 = torch.nn.Dropout(p=0.5)
        self.fc1 = torch.nn.Linear(in_features=9216, out_features=128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        # Log softmax rather than softmax due to negative log likelihood loss.
        # log_softmax rather than two separate operations for numerical stability
        output = torch.nn.functional.log_softmax(x, dim=1)

        return output

In [44]:
from components.client import AbstractClient


# negative log likelihood loss
loss_function = torch.nn.functional.nll_loss


class WeightClient(AbstractClient):
    def __init__(
        self,
        model: torch.nn.Module,
        client_data: Subset,
        learning_rate: float,
        batch_size: int,
        local_epochs: int,
    ) -> None:
        super().__init__(model, client_data, batch_size)
        self.optimizer = torch.optim.SGD(
            params=self.model.parameters(), lr=learning_rate
        )
        self.local_epochs = local_epochs
        self.client_data_size = len(client_data)

    def train_epoch(
        self,
    ) -> None:
        self.model.train()

        for batch_features, batch_target in self.loader_train:
            batch_features = typing.cast(torch.Tensor, batch_features).to(device)
            batch_target = typing.cast(torch.Tensor, batch_target).to(device)

            self.optimizer.zero_grad()
            batch_output = self.model(batch_features)

            batch_loss = loss_function(batch_output, batch_target)
            batch_loss.backward()
            self.optimizer.step()

    def update(self, weights: list[torch.Tensor], seed: int) -> list[torch.Tensor]:
        self.generator.manual_seed(seed)
        self.build_local_model(weights)

        for _ in range(self.local_epochs):
            self.train_epoch()

        parameter_weights = [
            parameter.detach().clone().cpu()
            for parameter in self.model.parameters()
        ]
        return parameter_weights

In [45]:
import time

from components.server import DecentralizedServer


class FedAvgServer(DecentralizedServer):
    def __init__(
        self,
        model: torch.nn.Module,
        learning_rate: float,
        batch_size: int,
        client_subsets: list[Subset],
        client_fraction: float,
        local_epochs: int,
        seed: int,
    ) -> None:
        super().__init__(
            model,
            client_subsets,
            client_fraction,
            learning_rate,
            batch_size,
            seed,
            device,
        )
        self.local_epochs_count = local_epochs
        self.clients = [
            WeightClient(model, subset, learning_rate, batch_size, local_epochs)
            for subset in client_subsets
        ]

    def select_clients(self) -> IndexVector:
        return self.generator.choice(len(self.clients), self.clients_per_round)

    def calculate_weight_fraction_for_client(
        self,
        client: WeightClient,
        weights: list[torch.Tensor],
        seed: int,
        total_epoch_dataset_size: int,
    ) -> list[torch.Tensor]:
        client_dataset_size = client.client_data_size

        return [
            client_dataset_size / total_epoch_dataset_size * parameter_weight
            for parameter_weight in client.update(weights, seed)
        ]

    def run_epoch(self, weights: list[torch.Tensor], epoch: int) -> None:
        client_indices = [it.item() for it in self.select_clients()]
        client_dataset_size = sum(
            self.clients[client_idx].client_data_size for client_idx in client_indices
        )

        # N x M; N clients with weights for M parameters each
        client_weights: list[list[torch.Tensor]] = [
            self.calculate_weight_fraction_for_client(
                self.clients[client_idx],
                weights,
                seed=self.parameters.seed
                + client_idx
                + 1
                + epoch * self.clients_per_round,
                total_epoch_dataset_size=client_dataset_size,
            )
            for client_idx in tqdm(client_indices, "clients", leave=False)
        ]

        aggregated_client_weights: list[torch.Tensor] = [
            # sum weights parameter-wise; 'parameter_weights' is a tuple that contains one weight vector per client
            torch.stack(parameter_weights, dim=0).sum(dim=0)
            for parameter_weights in zip(*client_weights)
        ]

        with torch.no_grad():
            for parameter, parameter_weight in zip(
                self.model.parameters(), aggregated_client_weights
            ):
                parameter[:] = parameter_weight.to(device)

    def run(self, rounds: int) -> RunResult:
        metrics = RunMetrics()

        for epoch in tqdm(range(rounds), "epoch", leave=False):
            weights = [
                parameter.detach().clone() for parameter in self.model.parameters()
            ]

            wall_clock_start = time.perf_counter()
            weights = self.run_epoch(weights, epoch)
            wall_clock_end = time.perf_counter()

            accuracy = self.evaluate_accuracy(test_loader)
            execution_time_s = wall_clock_end - wall_clock_start

            metrics.test_accuracy.append(accuracy)
            metrics.wall_time.append(execution_time_s)
            metrics.message_count.append(2 * self.clients_per_round * (epoch + 1))

        return RunResult("FedAvg", self.parameters, metrics)

In [46]:
from components.data_splitting import (
    index_uniformly,
    partition_dataset,
)

client_datasets = partition_dataset(
    train_dataset,
    index_uniformly(
        train_dataset, partitions_count=100, generator_or_seed=42
    ),
)

fedavg_server = FedAvgServer(
    MnistCnn().to(device),
    learning_rate=0.02,
    batch_size=200,
    client_subsets=client_datasets,
    client_fraction=0.2,
    local_epochs=2,
    seed=42,
)

result_fedavg = fedavg_server.run(5)

epoch:   0%|          | 0/5 [00:00<?, ?it/s]

clients:   0%|          | 0/20 [00:00<?, ?it/s]

correct:  2213 total:  10000


clients:   0%|          | 0/20 [00:00<?, ?it/s]

correct:  5637 total:  10000


clients:   0%|          | 0/20 [00:00<?, ?it/s]

correct:  6777 total:  10000


clients:   0%|          | 0/20 [00:00<?, ?it/s]

correct:  7428 total:  10000


clients:   0%|          | 0/20 [00:00<?, ?it/s]

correct:  7684 total:  10000


In [47]:
fedavg_df = result_fedavg.as_df()
fedavg_df

Unnamed: 0,round,algorithm,clients_count,active_clients_fraction,batch_size,local_epochs_count,η,seed,wall_time,message_count (sum),test_accuracy
0,1,FedAvg,100,0.2,200,1,0.02,42,3.115875,40,0.2213
1,2,FedAvg,100,0.2,200,1,0.02,42,3.140193,80,0.5637
2,3,FedAvg,100,0.2,200,1,0.02,42,3.133477,120,0.6777
3,4,FedAvg,100,0.2,200,1,0.02,42,3.406714,160,0.7428
4,5,FedAvg,100,0.2,200,1,0.02,42,3.190974,200,0.7684
