In [None]:
import typing

import torch
from components.metrics import RunMetrics, RunResult
from components.tensor_types import IndexVector
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

data_path = "../../datasets"

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    print(f"Using accelerator '{device}'")

    if device.type == "cuda":
        torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")
    print("WARN: No accelerator found, running on CPU")


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Normalize by the MNIST mean and standard deviation.
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    data_path,
    train=True,
    download=True,
    transform=transform,
)

# The test split requests its own download rather than relying on the train
# split's download side effect; it is a no-op when the files already exist.
# The whole test set is evaluated as a single batch, in a fixed order.
test_loader = DataLoader(
    datasets.MNIST(data_path, train=False, download=True, transform=transform),
    shuffle=False,
    drop_last=False,
    batch_size=10000,
    generator=torch.Generator(),
)

In [None]:
class MnistCnn(torch.nn.Module):
    """Small convolutional classifier for single-channel 28x28 digit images.

    Produces per-class log-probabilities of shape (batch, 10), intended to be
    paired with a negative log-likelihood loss.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: 1 -> 32 -> 64 channels with 3x3 kernels (no padding).
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.dropout1 = torch.nn.Dropout(p=0.25)
        self.dropout2 = torch.nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12: a 28x28 input shrinks to 24x24 after the
        # two unpadded 3x3 convolutions, then to 12x12 after the 2x2 max-pool.
        self.fc1 = torch.nn.Linear(in_features=9216, out_features=128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        F = torch.nn.functional

        feature_maps = F.max_pool2d(F.relu(self.conv2(F.relu(self.conv1(x)))), 2)
        flat = torch.flatten(self.dropout1(feature_maps), 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        logits = self.fc2(hidden)

        # Fused log_softmax is numerically safer than log(softmax(...)), and
        # log-probabilities are what the NLL loss expects.
        return F.log_softmax(logits, dim=1)

In [None]:
from components.client import AbstractClient


class WeightClient(AbstractClient):
    """Federated client that trains from the server's weights and returns its own.

    Each call to ``update`` loads the weights it is given into ``self.model``,
    runs ``nr_epochs`` epochs of mini-batch SGD on this client's data subset and
    returns the resulting parameters.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        client_data: Subset,
        lr: float,
        batch_size: int,
        nr_epochs: int,
    ) -> None:
        """
        Args:
            model: Network to train. NOTE(review): FedAvgServer passes the same
                instance to every client; confirm whether AbstractClient copies it.
            client_data: This client's share of the training set.
            lr: SGD learning rate.
            batch_size: Local mini-batch size.
            nr_epochs: Passes over ``client_data`` per call to ``update``.
        """
        super().__init__(model, client_data, batch_size)
        # Fully qualified through ``torch``: the notebook never imports bare
        # ``nn`` or ``SGD``, so the old names raised NameError at definition time.
        self.optimizer = torch.optim.SGD(params=self.model.parameters(), lr=lr)
        self.nr_epochs = nr_epochs
        # Kept locally so ``update`` can build a seeded loader without depending
        # on AbstractClient's internal attribute names.
        self._client_data = client_data
        self._batch_size = batch_size

    def update(self, weights: list[torch.Tensor], seed: int) -> list[torch.Tensor]:
        """Train locally, starting from the server's weights.

        Args:
            weights: One tensor per entry of ``self.model.parameters()``, in the
                same order.
            seed: Seeds the shuffling of the local data so a round is reproducible.

        Returns:
            Detached copies of the locally trained parameters, in
            ``self.model.parameters()`` order.

        Raises:
            ValueError: If ``weights`` does not have one tensor per model parameter.
        """
        params = list(self.model.parameters())
        if len(weights) != len(params):
            raise ValueError(f"expected {len(params)} weight tensors, got {len(weights)}")

        # Configure the model with the server's weights before training.
        with torch.no_grad():
            for param, weight in zip(params, weights):
                param.copy_(weight)

        model_device = params[0].device
        loader = DataLoader(
            self._client_data,
            batch_size=self._batch_size,
            shuffle=True,
            generator=torch.Generator().manual_seed(seed),
        )

        self.model.train()
        for _ in range(self.nr_epochs):
            for inputs, targets in loader:
                inputs, targets = inputs.to(model_device), targets.to(model_device)
                self.optimizer.zero_grad()
                # MnistCnn outputs log_softmax, so NLL is the matching loss.
                loss = torch.nn.functional.nll_loss(self.model(inputs), targets)
                loss.backward()
                self.optimizer.step()

        return [param.detach().clone() for param in params]

In [None]:
from components.server import DecentralizedServer


class FedAvgServer(DecentralizedServer):
    """Server for Federated Averaging (FedAvg) using ``WeightClient`` clients.

    One ``WeightClient`` is created per client subset. Each client trains
    locally for ``local_epochs`` epochs with the given learning rate and batch
    size. ``run`` is currently a stub (see the TODO inside it).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        learning_rate: float,
        batch_size: int,
        client_subsets: list[Subset],
        client_fraction: float,
        local_epochs: int,
        seed: int,
        device: typing.Optional[torch.device] = None,
    ) -> None:
        """
        Args:
            device: Device forwarded to DecentralizedServer. Defaults to the
                device of the model's first parameter (CPU if it has none),
                instead of silently reading a notebook-level ``device`` global.
        """
        if device is None:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        super().__init__(model, client_subsets, client_fraction, learning_rate, batch_size, seed, device)
        self.local_epochs_count = local_epochs
        # NOTE(review): every client receives the same ``model`` instance.
        self.clients = [
            WeightClient(model, subset, learning_rate, batch_size, local_epochs)
            for subset in client_subsets
        ]

    def run(self, rounds: int) -> RunResult:
        """Run ``rounds`` FedAvg rounds and return the final parameters plus metrics."""
        metrics = RunMetrics()

        # TODO: implement the round loop (select clients, call
        # ``client.update(weights, seed)``, average the returned weights, record
        # metrics). Until then ``rounds`` is unused and ``metrics`` stays empty.

        # Return the metrics object built above rather than a fresh, unrelated one.
        return RunResult(
            "FedAvg",
            self.parameters,
            metrics,
        )

In [None]:
# Partition the training set into equally sized, seeded IID client shards.
# ``sample_split`` was not defined anywhere in this notebook, so this cell
# raised NameError on a fresh kernel. NUM_CLIENTS is an assumption; replace
# with the project's own split if one exists.
NUM_CLIENTS = 100
shuffled_indices = torch.randperm(len(train_dataset), generator=torch.Generator().manual_seed(42))
sample_split = [
    Subset(train_dataset, shard.tolist())
    for shard in torch.tensor_split(shuffled_indices, NUM_CLIENTS)
]

fedavg_server = FedAvgServer(
    model=MnistCnn().to(device),
    learning_rate=0.02,
    batch_size=200,
    client_subsets=sample_split,
    client_fraction=0.2,
    local_epochs=2,
    seed=42,
)
result_fedavg = fedavg_server.run(rounds=5)
fedavg_df = result_fedavg.as_df()
fedavg_df

In [None]:
import pandas as pd
import seaborn as sns

# ``fedsgd_gradient_df`` is not computed anywhere in this notebook, so the old
# unconditional reference raised NameError on a fresh kernel. Include it only
# when it exists in the session.
# TODO: load the FedSGD results explicitly (e.g. from a saved file) instead.
comparison_frames = [fedavg_df]
if "fedsgd_gradient_df" in globals():
    comparison_frames.append(globals()["fedsgd_gradient_df"])

results_df = pd.concat(comparison_frames, ignore_index=True)
ax = sns.lineplot(results_df, x="Round", y="Test accuracy", hue="Algorithm", seed=0)
ax.set_xticks(results_df["Round"].unique())
ax.set_title("Test accuracy per round");