In [1]:
import typing

import torch
from components.metrics import RunMetrics, RunResult
from components.tensor_types import IndexVector
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

data_path = "../../datasets"

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    print(f"Using accelerator '{device}'")

    if device.type == "cuda":
        torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")
    print("WARN: No accelerator found, running on CPU")


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalize by mean and standard devia,
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=transform,
)

test_loader = DataLoader(
    datasets.MNIST(data_path, train=False, download=False, transform=transform),
    shuffle=False,
    drop_last=False,
    batch_size=10000,
    generator=torch.Generator(),
)

Using accelerator 'mps'


In [2]:
class MnistCnn(torch.nn.Module):
    def __init__(self):
        super(MnistCnn, self).__init__()

        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.dropout1 = torch.nn.Dropout(p=0.25)
        self.dropout2 = torch.nn.Dropout(p=0.5)
        self.fc1 = torch.nn.Linear(in_features=9216, out_features=128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.conv2(x)
        x = torch.nn.functional.relu(x)
        x = torch.nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        # Log softmax rather than softmax due to negative log likelihood loss.
        # log_softmax rather than two separate operations for numerical stability
        output = torch.nn.functional.log_softmax(x, dim=1)

        return output

In [3]:
from components.client import AbstractClient

# negative log likelihood loss
loss_function = torch.nn.functional.nll_loss


class GradientClient(AbstractClient):
    def __init__(self, model: torch.nn.Module, client_data: Subset) -> None:
        client_data_size = len(client_data)

        super().__init__(model, client_data, client_data_size)
        self.client_data_size = client_data_size

    def update(self, weights: list[torch.Tensor], seed: int) -> list[torch.Tensor]:
        self.generator.manual_seed(seed)
        client_model = self.build_local_model(weights)

        client_model.train()
        for batch_features, batch_targets in self.loader_train:
            batch_features = typing.cast(torch.Tensor, batch_features).to(device)
            batch_targets = typing.cast(torch.Tensor, batch_targets).to(device)

            batch_output = client_model(batch_features)

            batch_loss = loss_function(batch_output, batch_targets)
            batch_loss.backward()

        parameter_gradients = [
            typing.cast(torch.Tensor, param.grad).clone().detach().cpu()
            for param in client_model.parameters()
        ]
        return parameter_gradients

In [4]:
import time

from components.server import DecentralizedServer


class FedSgdGradientServer(DecentralizedServer):
    def __init__(
        self,
        model_builder: typing.Callable[[], torch.nn.Module],
        client_subsets: list[Subset],
        active_clients_fraction: float,
        learning_rate: float,
        seed: int,
    ) -> None:
        super().__init__(
            model=model_builder(),
            client_subsets=client_subsets,
            active_clients_fraction=active_clients_fraction,
            learning_rate=learning_rate,
            batch_size=-1,
            seed=seed,
            device=device,
        )
        self.optimizer = torch.optim.SGD(
            params=self.model.parameters(), lr=learning_rate
        )
        self.clients: list[GradientClient] = [
            GradientClient(model_builder(), subset) for subset in client_subsets
        ]
        self.client_datasets = client_subsets

    def select_clients(self) -> IndexVector:
        return self.generator.choice(len(self.clients), self.clients_per_round)

    def calculate_gradient_fraction_for_client(
        self,
        client: GradientClient,
        weights: list[torch.Tensor],
        seed: int,
        total_epoch_dataset_size: int,
    ) -> list[torch.Tensor]:
        client_dataset_size = client.client_data_size

        return [
            client_dataset_size / total_epoch_dataset_size * parameter_gradient
            for parameter_gradient in client.update(weights, seed)
        ]

    def run_epoch(self, weights: list[torch.Tensor], epoch: int) -> None:
        client_indices = [it.item() for it in self.select_clients()]
        client_dataset_size = sum(
            self.clients[client_idx].client_data_size for client_idx in client_indices
        )

        self.optimizer.zero_grad()

        # N x M; N clients with gradients for M parameters each
        client_gradients = [
            self.calculate_gradient_fraction_for_client(
                self.clients[client_idx],
                weights,
                seed=self.parameters.seed
                + client_idx
                + 1
                + epoch * self.clients_per_round,
                total_epoch_dataset_size=client_dataset_size,
            )
            for client_idx in tqdm(client_indices, "clients")
        ]

        aggregated_client_gradients: list[torch.Tensor] = [
            # sum gradients parameter-wise; 'parameter_gradients' is a tuple that contains one gradient per client
            torch.stack(parameter_gradients, dim=0).sum(dim=0)
            for parameter_gradients in zip(*client_gradients)
        ]

        with torch.no_grad():
            for parameter, parameter_gradient in zip(
                self.model.parameters(), aggregated_client_gradients
            ):
                parameter.grad = parameter_gradient.to(device)

        self.model.train()
        self.optimizer.step()

    def run(self, rounds: int) -> RunResult:
        metrics = RunMetrics()

        for epoch in tqdm(range(rounds), "epoch", leave=False):
            weights = [
                parameter.detach().clone() for parameter in self.model.parameters()
            ]

            wall_clock_start = time.perf_counter()
            weights = self.run_epoch(weights, epoch)
            wall_clock_end = time.perf_counter()

            accuracy = self.evaluate_accuracy(test_loader)
            execution_time_s = wall_clock_end - wall_clock_start

            metrics.test_accuracy.append(accuracy)
            metrics.wall_time.append(execution_time_s)
            metrics.message_count.append(2 * self.clients_per_round * (epoch + 1))

        return RunResult("FedSgd", self.parameters, metrics)

In [5]:
from components.data_splitting import (
    index_uniformly,
    partition_dataset,
)

client_datasets = partition_dataset(
    train_dataset,
    index_uniformly(
        train_dataset, partitions_count=100, generator_or_seed=42
    ),
)

fedsgd_gradient_server = FedSgdGradientServer(
    model_builder=lambda: MnistCnn().to(device),
    client_subsets=client_datasets,
    active_clients_fraction=0.5,
    learning_rate=0.02,
    seed=42,
)
result_fedsgd_gradient = fedsgd_gradient_server.run(5)
fedsgd_gradient_df = result_fedsgd_gradient.as_df()
fedsgd_gradient_df

epoch:   0%|          | 0/5 [00:00<?, ?it/s]

clients:   0%|          | 0/50 [00:00<?, ?it/s]

correct:  1558 total:  10000


clients:   0%|          | 0/50 [00:00<?, ?it/s]

correct:  1868 total:  10000


clients:   0%|          | 0/50 [00:00<?, ?it/s]

correct:  2166 total:  10000


clients:   0%|          | 0/50 [00:00<?, ?it/s]

correct:  2441 total:  10000


clients:   0%|          | 0/50 [00:00<?, ?it/s]

correct:  2788 total:  10000


Unnamed: 0,round,algorithm,clients_count,active_clients_fraction,batch_size,local_epochs_count,η,seed,wall_time,message_count (sum),test_accuracy
0,1,FedSgd,100,0.5,∞,1,0.02,42,5.112729,100,0.1558
1,2,FedSgd,100,0.5,∞,1,0.02,42,5.113503,200,0.1868
2,3,FedSgd,100,0.5,∞,1,0.02,42,5.204438,300,0.2166
3,4,FedSgd,100,0.5,∞,1,0.02,42,5.297275,400,0.2441
4,5,FedSgd,100,0.5,∞,1,0.02,42,7.295458,500,0.2788
