# Tutorial 1A

The first lab tutorial presents the findings and uses part of the experimental methodology from the [original Federated Learning](https://arxiv.org/pdf/1602.05629.pdf) paper. In horizontal federated learning, all clients have access to the same complete model architecture, which they train on local data, sharing information about model updates but not their data.

Before starting, make sure to follow the overall setup for the labs.

<a href="https://blogs.nvidia.com/blog/what-is-federated-learning/" target="_blank">
    <img src="https://blogs.nvidia.com/wp-content/uploads/2019/10/federated_learning_animation_still_white.png" alt="FL Visualization" style="width:50%;">
</a>

---

Before anything else, we download, load, and preprocess the [MNIST dataset](https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits), which we will use for all experiments.

In [1]:
import typing

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

data_path = "./data"
ETA = "\N{GREEK SMALL LETTER ETA}"

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    print(f"Using accelerator '{device}'")

    if device.type == "cuda":
        torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")
    print("WARN: No accelerator found, running on CPU")


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # normalize by mean and standard devia,
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=transform,
)

test_loader = DataLoader(
    datasets.MNIST(data_path, train=False, download=False, transform=transform),
    shuffle=False,
    drop_last=False,
    batch_size=10000,
    generator=torch.Generator(),
)

Using accelerator 'mps'


We can then define a small convolutional neural network that will serve as our model.

In [2]:
import torch.nn as nn
import torch.nn.functional as F


class MnistCnn(nn.Module):
    def __init__(self):
        super(MnistCnn, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        # Log softmax rather than softmax due to negative log likelihood loss.
        # log_softmax rather than two separate operations for numerical stability
        output = F.log_softmax(x, dim=1)

        return output

With that, we can define a helper method, which, given a model, a loader for iterating through a set of data, and an optimizer for updating the model trains one epoch (i.e., learns going through all the available data once).

In [3]:
from torch.optim import Optimizer


# negative log likelihood loss
loss_function = F.nll_loss


def train_epoch(
    model: torch.nn.Module, loader: DataLoader, optimizer: Optimizer
) -> None:
    model.train()

    for batch_features, batch_target in loader:
        batch_features = typing.cast(torch.Tensor, batch_features).to(device)
        batch_target = typing.cast(torch.Tensor, batch_target).to(device)

        optimizer.zero_grad()
        batch_output = model(batch_features)

        batch_loss = loss_function(batch_output, batch_target)
        batch_loss.backward()
        optimizer.step()

We also define another utility method that splits the dataset into several chunks.

We assign samples within chunks in an IID (independent and identically distributed) fashion or allow only two labels to exist in each.

In [4]:
from collections import Counter


def count_distinct_values(targets: typing.Iterable[typing.Any]):
    targets_count = Counter(targets)

    for target, count in targets_count.items():
        print(f"Target {target}: {count} records")

    print(f"Total: {targets_count.total()}")

In [5]:
count_distinct_values(train_dataset.targets.numpy())


Target 5: 5421 records
Target 0: 5923 records
Target 4: 5842 records
Target 1: 6742 records
Target 9: 5949 records
Target 2: 5958 records
Target 3: 6131 records
Target 6: 5918 records
Target 7: 6265 records
Target 8: 5851 records
Total: 60000


In [6]:
import typing

import numpy as np
import numpy.random as npr
from torch.utils.data import Dataset, Subset

type IndexVector = np.ndarray[tuple[int], np.dtype[np.long]]


def _get_rng(generator_or_seed: npr.Generator | int) -> npr.Generator:
    if type(generator_or_seed) == int:
        return npr.default_rng(generator_or_seed)
    else:
        return typing.cast(npr.Generator, generator_or_seed)


def index_uniformly(
    partitions_count: int, generator_or_seed: npr.Generator | int
) -> list[IndexVector]:
    generator = _get_rng(generator_or_seed)

    shuffled_indices: IndexVector = generator.permutation(len(train_dataset))
    return np.array_split(shuffled_indices, partitions_count)


def _combine_partitions(
    mini_partitions: list[IndexVector],
    *,
    mini_partitions_per_partition: int,
    generator: npr.Generator,
) -> list[IndexVector]:
    if len(mini_partitions) % mini_partitions_per_partition != 0:
        raise ValueError(
            f"expected to have exactly {mini_partitions_per_partition} mini-partitions per partition,"
            f"got {len(mini_partitions)} mini-partitions"
        )

    partitions_count = len(mini_partitions) // mini_partitions_per_partition
    shuffled_partition_indices = generator.permutation(len(mini_partitions))

    return [
        np.concatenate(
            [mini_partitions[partition_idx] for partition_idx in mini_partition_indices]
        )
        for mini_partition_indices in shuffled_partition_indices.reshape(
            partitions_count, mini_partitions_per_partition
        )
    ]


def index_by_approximate_binary_target_partitions(
    partitions_count: int, generator_or_seed: npr.Generator | int
) -> list[IndexVector]:
    generator = _get_rng(generator_or_seed)

    targets = train_dataset.targets.numpy().copy()
    generator.shuffle(targets)

    sorted_indices: IndexVector = np.argsort(train_dataset.targets)
    sorted_indices_partitions: list[IndexVector] = np.array_split(
        sorted_indices, 2 * partitions_count
    )

    return _combine_partitions(
        sorted_indices_partitions, mini_partitions_per_partition=2, generator=generator
    )


def index_by_binary_target_partitions(
    partitions_count: int, generator_or_seed: npr.Generator | int
) -> list[IndexVector]:
    generator = _get_rng(generator_or_seed)

    targets = train_dataset.targets.numpy()
    generator.shuffle(targets)

    client_indices = []
    unique_targets = np.unique(targets)

    unique_targets_count = unique_targets.shape[0]
    mini_partitions_per_label = partitions_count // unique_targets_count
    if partitions_count % unique_targets_count != 0:
        raise ValueError(
            "expected number of partitions to be a multiple of the number of unique Ïtargets, "
            f"got {partitions_count} partitions and {unique_targets_count} unique targets"
        )

    for target in unique_targets:
        label_indices = np.where(targets == target)[0]
        label_shards = np.array_split(label_indices, mini_partitions_per_label)
        client_indices.extend(label_shards)

    return _combine_partitions(
        client_indices, mini_partitions_per_partition=2, generator=generator
    )

In [7]:
def partition_dataset(
    dataset: Dataset, partitions: list[IndexVector]
) -> list[Subset[typing.Any]]:
    return [
        Subset(dataset, typing.cast(typing.Sequence[int], partition))
        for partition in partitions
    ]

In [8]:
sample_index = index_by_approximate_binary_target_partitions(10, 42)
subsets = partition_dataset(train_dataset, sample_index)
# 
for subset_idx, subset in enumerate(subsets):
    print(f"==== subset #{subset_idx} ====")
    subset_targets = list(int(it) for _, it in DataLoader(subset))
    count_distinct_values(subset_targets)

==== subset #0 ====
Target 2: 3000 records
Target 7: 3000 records
Total: 6000
==== subset #1 ====
Target 3: 3000 records
Target 6: 2935 records
Target 7: 65 records
Total: 6000
==== subset #2 ====
Target 1: 3665 records
Target 2: 2335 records
Total: 6000
==== subset #3 ====
Target 2: 623 records
Target 3: 2377 records
Target 5: 17 records
Target 6: 2983 records
Total: 6000
==== subset #4 ====
Target 3: 754 records
Target 4: 2246 records
Target 0: 2923 records
Target 1: 77 records
Total: 6000
==== subset #5 ====
Target 8: 51 records
Target 9: 2949 records
Target 0: 3000 records
Total: 6000
==== subset #6 ====
Target 4: 596 records
Target 5: 2404 records
Target 7: 200 records
Target 8: 2800 records
Total: 6000
==== subset #7 ====
Target 5: 3000 records
Target 4: 3000 records
Total: 6000
==== subset #8 ====
Target 7: 3000 records
Target 8: 3000 records
Total: 6000
==== subset #9 ====
Target 9: 3000 records
Target 1: 3000 records
Total: 6000


We define a short class for holding the results of training runs and the parameters used.

In [9]:
import dataclasses

from pandas import DataFrame


@dataclasses.dataclass
class RoundParameters:
    clients_count: int
    active_clients_fraction: float
    batch_size: int
    local_epochs_count: int
    learning_rate: float
    seed: int


@dataclasses.dataclass
class RunMetrics:
    wall_time: list[float] = dataclasses.field(default_factory=list)
    message_count: list[int] = dataclasses.field(default_factory=list)
    test_accuracy: list[float] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class RunResult:
    algorithm: str
    parameters: RoundParameters
    metrics: RunMetrics = dataclasses.field(default_factory=RunMetrics)

    def as_df(self) -> DataFrame:
        table_data = {
            "round": range(1, len(self.metrics.wall_time) + 1),
            "algorithm": self.algorithm,
            **dataclasses.asdict(self.parameters),
            **dataclasses.asdict(self.metrics),
        }

        if table_data["batch_size"] == -1:
            table_data["batch_size"] = "\N{INFINITY}"

        df = DataFrame(table_data)
        df = df.rename(
            columns={"learning_rate": ETA, "message_count": "message_count (sum)"}
        )

        return df

We create an abstract class as a template for all distributed learning clients, defining a method for outputting an update after training a given model on local data.

In [10]:
from abc import ABC


class Client(typing.Protocol):
    def update(self, weights: list[torch.Tensor], seed: int) -> list[torch.Tensor]:
        ...


class AbstractClient(ABC, Client):
    def __init__(self, client_data: Subset, batch_size: int) -> None:
        self.model = MnistCnn().to(device)
        self.generator = torch.Generator()
        self.loader_train = DataLoader(
            client_data, batch_size=batch_size, shuffle=True,
            drop_last=False, generator=self.generator)
        
    def build_local_model(self, weights: list[torch.Tensor]) -> torch.nn.Module:
        model = MnistCnn().to(device)
        with torch.no_grad():
            for client_parameter, server_parameter_values in zip(
                model.parameters(), weights
            ):
                client_parameter[:] = server_parameter_values
                client_parameter.grad = None

        return model


On the flip side, a server needs to be able to run the (distributed) training process for a given number of rounds and test the current model it possesses.

In [11]:
class Server(typing.Protocol):
    def run(self, rounds: int) -> RunResult: ...


class AbstractServer(ABC, Server):
    def __init__(self, parameters: RoundParameters) -> None:
        torch.manual_seed(parameters.seed)
        self.parameters = parameters
        self.model = MnistCnn().to(device)

    def evaluate_accuracy(self) -> float:
        self.model.eval()

        correct_predictions = 0
        total_predictions = 0
        with torch.no_grad():
            for batch_features, batch_targets in test_loader:
                batch_features = typing.cast(torch.Tensor, batch_features).to(device)
                batch_targets = typing.cast(torch.Tensor, batch_targets).to(device)

                batch_output: torch.Tensor = self.model(batch_features)

                # index of output neuron/logit corresponds to label
                batch_predictions = batch_output.argmax(dim=1, keepdim=True)

                correct_predictions += batch_predictions.eq(batch_targets.view_as(batch_predictions)).sum().item()
                total_predictions += batch_predictions.size(dim=0)

        print("correct: ", correct_predictions, "total: ", total_predictions)
        return correct_predictions / total_predictions

Over the previously defined server template, we can even formulate a centralized variant, which does not involve clients, as a precursor to distributed versions.

In [12]:
import time
from torch.optim import SGD
from tqdm import tqdm


class CentralizedServer(AbstractServer):
    def __init__(self, learning_rate: float, batch_size: int, seed: int) -> None:
        super().__init__(
            RoundParameters(
                clients_count=1,
                active_clients_fraction=float("nan"),
                batch_size=batch_size,
                local_epochs_count=1,
                learning_rate=learning_rate,
                seed=seed,
            )
        )
        self.optimizer = SGD(params=self.model.parameters(), lr=learning_rate)
        self.generator = torch.Generator()
        self.loader_train = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            generator=self.generator,
        )

    def run(self, rounds: int) -> RunResult:
        metrics = RunMetrics()

        for epoch in tqdm(range(rounds), "epoch", leave=False):
            self.generator.manual_seed(self.parameters.seed + epoch + 1)

            wall_time_start = time.perf_counter()
            train_epoch(self.model, self.loader_train, self.optimizer)
            wall_time_end = time.perf_counter()

            accuracy = self.evaluate_accuracy()
            execution_time = wall_time_end - wall_time_start

            metrics.test_accuracy.append(accuracy)
            metrics.wall_time.append(execution_time)
            metrics.message_count.append(-1)

        return RunResult(
            "centralized",
            self.parameters,
            metrics
        )

In [13]:
centralized_server = CentralizedServer(learning_rate=0.5, batch_size=1024, seed=42)
result_centralized = centralized_server.run(rounds=5)

epoch:  20%|██        | 1/5 [00:07<00:30,  7.60s/it]

correct:  8131 total:  10000


epoch:  40%|████      | 2/5 [00:14<00:21,  7.16s/it]

correct:  9645 total:  10000


epoch:  60%|██████    | 3/5 [00:21<00:14,  7.04s/it]

correct:  9786 total:  10000


epoch:  80%|████████  | 4/5 [00:28<00:06,  6.96s/it]

correct:  9745 total:  10000


                                                    

correct:  9802 total:  10000




In [14]:
centralized_df = result_centralized.as_df()
centralized_df

Unnamed: 0,round,algorithm,clients_count,active_clients_fraction,batch_size,local_epochs_count,η,seed,wall_time,message_count (sum),test_accuracy
0,1,centralized,1,,1024,1,0.5,42,6.78824,-1,0.8131
1,2,centralized,1,,1024,1,0.5,42,6.122064,-1,0.9645
2,3,centralized,1,,1024,1,0.5,42,6.140187,-1,0.9786
3,4,centralized,1,,1024,1,0.5,42,6.120686,-1,0.9745
4,5,centralized,1,,1024,1,0.5,42,6.140917,-1,0.9802


We can extend the template with some setup steps common to all decentralized algorithms.

In [15]:
class DecentralizedServer(AbstractServer):
    def __init__(
        self,
        client_subsets: list[Subset],
        active_clients_fraction: float,
        learning_rate: float,
        batch_size: int,
        seed: int,
    ) -> None:
        super().__init__(
            RoundParameters(
                clients_count=len(client_subsets),
                active_clients_fraction=active_clients_fraction,
                batch_size=batch_size,
                local_epochs_count=1,
                learning_rate=learning_rate,
                seed=seed,
            )
        )

        self.generator = npr.default_rng(seed)
        self.clients_per_round = max(1, int(len(client_subsets) * active_clients_fraction))

The two federated learning algorithms from the paper follow, alongside an overview of metric plotting.

---

For the FedSGD algorithm, the baseline from the paper, we first need to define the client, and we choose to pass gradients from the client as the update result.

In [16]:
class GradientClient(AbstractClient):
    def __init__(self, client_data: Subset) -> None:
        client_data_size = len(client_data)

        super().__init__(client_data, client_data_size)
        self.client_data_size = client_data_size

    def update(self, weights: list[torch.Tensor], seed: int) -> list[torch.Tensor]:
        self.generator.manual_seed(seed)
        model = self.build_local_model(weights)
        
        model.train()
        for batch_features, batch_target in self.loader_train:
            batch_features = typing.cast(torch.Tensor, batch_features).to(device)
            batch_target = typing.cast(torch.Tensor, batch_target).to(device)

            batch_output = model(batch_features)

            batch_loss = loss_function(batch_output, batch_target)
            batch_loss.backward()

        parameter_gradients = [typing.cast(torch.Tensor, param.grad).detach().cpu() for param in model.parameters()]
        return parameter_gradients

We then define the corresponding server.

In [17]:
class FedSgdGradientServer(DecentralizedServer):
    def __init__(
        self,
        client_subsets: list[Subset],
        active_clients_fraction: float,
        learning_rate: float,
        seed: int,
    ) -> None:
        super().__init__(
            client_subsets=client_subsets,
            active_clients_fraction=active_clients_fraction,
            learning_rate=learning_rate,
            batch_size=-1,
            seed=seed,
        )
        self.optimizer = SGD(params=self.model.parameters(), lr=learning_rate)
        self.clients: list[GradientClient] = [
            GradientClient(subset) for subset in client_subsets
        ]
        self.client_datasets = client_subsets

    def select_clients(self) -> IndexVector:
        return self.generator.choice(len(self.clients), self.clients_per_round)

    def calculate_gradient_fraction_for_client(
        self,
        client: GradientClient,
        weights: list[torch.Tensor],
        seed: int,
        total_epoch_dataset_size: int,
    ) -> list[torch.Tensor]:
        client_dataset_size = client.client_data_size
        print(
            f"running training on client with training dataset of size {client_dataset_size}"
        )

        return [
            client_dataset_size / total_epoch_dataset_size * gradient_component
            for gradient_component in client.update(weights, seed)
        ]

    def run_epoch(self, weights: list[torch.Tensor], epoch: int):
        client_indices = self.select_clients()
        client_dataset_size = sum(
            len(self.clients[client_idx].loader_train) for client_idx in client_indices
        )

        # N x M; N clients with gradients for M parameters each
        gradients = [
            self.calculate_gradient_fraction_for_client(
                self.clients[client_idx],
                weights,
                seed=self.parameters.seed + epoch + 1,
                total_epoch_dataset_size=client_dataset_size,
            )
            for client_idx in client_indices
        ]

        aggregated_gradient: list[torch.Tensor] = [
            # sum gradients parameter-wise; 'parameter_gradients' is a tuple that contains one gradient per client
            torch.stack(parameter_gradients, dim=0).sum(dim=0)
            for parameter_gradients in zip(*gradients)
        ]

        with torch.no_grad():
            for parameter, parameter_gradient in zip(
                self.model.parameters(), aggregated_gradient
            ):
                parameter.grad = parameter_gradient.to(device)

        self.model.train()
        self.optimizer.step()

    def run(self, rounds: int) -> RunResult:
        metrics = RunMetrics()

        for epoch in tqdm(range(rounds), "epoch", leave=False):
            weights = [
                parameter.detach().clone() for parameter in self.model.parameters()
            ]

            wall_clock_start = time.perf_counter()
            weights = self.run_epoch(weights, epoch)
            wall_clock_end = time.perf_counter()

            accuracy = self.evaluate_accuracy()
            execution_time_s = wall_clock_end - wall_clock_start

            metrics.test_accuracy.append(accuracy)
            metrics.wall_time.append(execution_time_s)
            metrics.message_count.append(2 * self.clients_per_round * (epoch + 1))

        return RunResult("FedSgd", self.parameters, metrics)

In [18]:
client_datasets = partition_dataset(
    train_dataset,
    index_by_approximate_binary_target_partitions(
        partitions_count=20, generator_or_seed=42
    ),
)

fedsgd_gradient_server = FedSgdGradientServer(
    client_subsets=client_datasets,
    active_clients_fraction=0.2,
    learning_rate=0.02,
    seed=42,
)
result_fedsgd_gradient = fedsgd_gradient_server.run(5)
fedsgd_gradient_df = result_fedsgd_gradient.as_df()
fedsgd_gradient_df

epoch:   0%|          | 0/5 [00:00<?, ?it/s]

running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000


epoch:  20%|██        | 1/5 [00:02<00:11,  2.96s/it]

correct:  1435 total:  10000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000


epoch:  40%|████      | 2/5 [00:06<00:10,  3.36s/it]

correct:  1032 total:  10000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000


epoch:  60%|██████    | 3/5 [00:10<00:06,  3.49s/it]

correct:  980 total:  10000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000


epoch:  80%|████████  | 4/5 [00:13<00:03,  3.24s/it]

correct:  980 total:  10000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000
running training on client with training dataset of size 3000


                                                    

KeyboardInterrupt: 

The FedAvg algorithm is the paper's main contribution, requiring a client that passes around weights instead of gradients.

In [None]:
class WeightClient(AbstractClient):
    def __init__(
        self, client_data: Subset, lr: float, batch_size: int, nr_epochs: int
    ) -> None:
        super().__init__(client_data, batch_size)
        self.optimizer = SGD(params=self.model.parameters(), lr=lr)
        self.nr_epochs = nr_epochs

    def update(self, weights: list[torch.Tensor], seed: int) -> list[torch.Tensor]:
        # build new model and configure the server's weights

        return []

Following that, we define the actual server code for the method.

In [None]:
class FedAvgServer(DecentralizedServer):
    def __init__(
        self,
        lr: float,
        batch_size: int,
        client_subsets: list[Subset],
        client_fraction: float,
        nr_local_epochs: int,
        seed: int,
    ) -> None:
        super().__init__(client_subsets, client_fraction, lr, batch_size, seed)
        self.local_epochs_count = nr_local_epochs
        self.clients = [
            WeightClient(subset, lr, batch_size, nr_local_epochs)
            for subset in client_subsets
        ]

    def run(self, rounds: int) -> RunResult:
        metrics = RunMetrics()

    

        return RunResult(
            "FedAvg",
            self.parameters,
            RunMetrics()
        )

In [None]:
fedavg_server = FedAvgServer(0.02, 200, sample_split, 0.2, 2, 42)
result_fedavg = fedavg_server.run(5)
fedavg_df = result_fedavg.as_df()
fedavg_df

NameError: name 'sample_split' is not defined

Finally, we look at a quick example of plotting the accuracy per round of the two algorithms.

In [None]:
import pandas as pd
import seaborn as sns

df = pd.concat([fedavg_df, fedsgd_gradient_df], ignore_index=True)
ax = sns.lineplot(df, x="Round", y="Test accuracy", hue="Algorithm", seed=0)
_ = ax.set_xticks(df["Round"].unique())