# Tutorial 1A

---

Before anything else, we download, load, and preprocess the [MNIST dataset](https://archive.ics.uci.edu/dataset/683/mnist+database+of+handwritten+digits), which we will use for all experiments.

In [1]:
import typing

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

data_path = "../../datasets"

if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
    print(f"Using accelerator '{device}'")

    if device.type == "cuda":
        torch.backends.cudnn.deterministic = True
else:
    device = torch.device("cpu")
    print("WARN: No accelerator found, running on CPU")


# Preprocessing applied to every image: convert it to a tensor, then
# normalize by mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training split; downloaded into ``data_path`` if it is not there yet.
train_dataset = datasets.MNIST(
    root=data_path, train=True, transform=transform, download=True
)

# Test split served as a single, unshuffled batch of up to 10000 samples.
# No download is requested here; the files are expected to exist already.
test_loader = DataLoader(
    dataset=datasets.MNIST(
        root=data_path, train=False, transform=transform, download=False
    ),
    batch_size=10000,
    shuffle=False,
    drop_last=False,
    generator=torch.Generator(),
)

Using accelerator 'mps'


We can then define a small convolutional neural network that will serve as our model.

In [2]:
class MnistCnn(torch.nn.Module):
    """Small convolutional classifier with two conv and two linear layers.

    ``forward`` returns log-probabilities over 10 classes (one row per
    sample), which is the input format a negative log likelihood loss expects.
    """

    def __init__(self):
        super().__init__()

        # Two 3x3 convolutions: 1 -> 32 -> 64 channels.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3)

        # Dropout after pooling (p=0.25) and after the hidden layer (p=0.5).
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)

        # Classifier head. 9216 = 64 channels * 12 * 12, which is what the
        # layers above produce for a 28x28 input.
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities."""
        fn = torch.nn.functional

        # Convolutional stage, then 2x2 max pooling and dropout.
        x = fn.relu(self.conv1(x))
        x = fn.relu(self.conv2(x))
        x = self.dropout1(fn.max_pool2d(x, 2))

        # Flatten everything except the batch dimension and classify.
        x = torch.flatten(x, 1)
        x = self.dropout2(fn.relu(self.fc1(x)))
        logits = self.fc2(x)

        # Log-probabilities (not probabilities) because the loss is NLL; the
        # fused log_softmax is numerically more stable than softmax + log.
        return fn.log_softmax(logits, dim=1)

With that, we can define a helper function that, given a model, a loader for iterating through a set of data, and an optimizer for updating the model, trains the model for one epoch (i.e., goes through all the available data once).

In [3]:
# Negative log likelihood loss, applied to the model output and the targets.
loss_function = torch.nn.functional.nll_loss


def train_epoch(
    model: torch.nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer
) -> None:
    """Train ``model`` for one full pass over ``loader``.

    Args:
        model: Network to update; it is switched to training mode first.
        loader: Yields ``(features, target)`` batches.
        optimizer: Optimizer that applies one update step per batch.
    """
    model.train()

    for features, targets in loader:
        # cast() only narrows the static type; the tensors themselves are
        # moved to the module-level ``device``.
        features = typing.cast(torch.Tensor, features).to(device)
        targets = typing.cast(torch.Tensor, targets).to(device)

        # Clear the gradients left over from the previous batch.
        optimizer.zero_grad()

        loss = loss_function(model(features), targets)
        loss.backward()
        optimizer.step()

We also define another utility method that splits the dataset into several chunks.

Samples are assigned to chunks either in an IID (independent and identically distributed) fashion or in such a way that each chunk contains only two labels.

We define a short class for holding the results of training runs and the parameters used.

In [4]:
import time

from components.metrics import RoundParameters, RunMetrics, RunResult
from components.server import AbstractServer


class CentralizedServer(AbstractServer):
    """Trains one model on the whole training set with plain SGD.

    The run is described with the same ``RoundParameters`` record used for
    other servers: a single client, one local epoch per round, and NaN for
    the active-clients fraction.
    """

    def __init__(
        self, model: torch.nn.Module, learning_rate: float, batch_size: int, seed: int
    ) -> None:
        """Set up the optimizer and the shuffled training loader.

        Args:
            model: Network to train.
            learning_rate: Step size for SGD.
            batch_size: Number of samples per training batch.
            seed: Base seed used to reseed the shuffling generator each epoch.
        """
        round_parameters = RoundParameters(
            clients_count=1,
            active_clients_fraction=float("nan"),
            batch_size=batch_size,
            local_epochs_count=1,
            learning_rate=learning_rate,
            seed=seed,
        )
        super().__init__(model, round_parameters, device)

        # NOTE(review): ``self.model`` is presumably set by AbstractServer.__init__
        # — confirm against components.server.
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=learning_rate)

        # Dedicated generator so the shuffling order can be reseeded per epoch.
        self.generator = torch.Generator()
        self.loader_train = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            generator=self.generator,
        )

    def run(self, rounds: int) -> RunResult:
        """Train for ``rounds`` epochs, recording metrics after each one.

        Args:
            rounds: Number of epochs to train.

        Returns:
            A ``RunResult`` labelled "centralized" carrying the run parameters
            and the per-epoch test accuracy, training wall time and message
            count.
        """
        metrics = RunMetrics()

        for round_index in tqdm(range(rounds), "epoch", leave=False):
            # A different but reproducible shuffling seed for every epoch.
            self.generator.manual_seed(self.parameters.seed + round_index + 1)

            # Only the training pass is timed; evaluation is excluded.
            started_at = time.perf_counter()
            train_epoch(self.model, self.loader_train, self.optimizer)
            elapsed = time.perf_counter() - started_at

            metrics.test_accuracy.append(self.evaluate_accuracy(test_loader))
            metrics.wall_time.append(elapsed)
            # NOTE(review): -1 looks like a "not applicable" placeholder for
            # the message count — confirm with RunMetrics consumers.
            metrics.message_count.append(-1)

        return RunResult("centralized", self.parameters, metrics)

In [5]:
# Centralized baseline: a fresh MnistCnn on the selected device, trained for
# five epochs.
centralized_server = CentralizedServer(
    seed=42,
    batch_size=1024,
    learning_rate=0.5,
    model=MnistCnn().to(device),
)
result_centralized = centralized_server.run(5)

epoch:   0%|          | 0/5 [00:00<?, ?it/s]

correct:  8027 total:  10000
correct:  9653 total:  10000
correct:  9745 total:  10000
correct:  9800 total:  10000
correct:  9836 total:  10000


In [6]:
# Convert the run result into a table via as_df() (one row per round, judging
# by the cell output below).
centralized_df = result_centralized.as_df()
# Bare expression as the last line of the cell, so the notebook displays it.
centralized_df

Unnamed: 0,round,algorithm,clients_count,active_clients_fraction,batch_size,local_epochs_count,η,seed,wall_time,message_count (sum),test_accuracy
0,1,centralized,1,,1024,1,0.5,42,6.356011,-1,0.8027
1,2,centralized,1,,1024,1,0.5,42,6.104182,-1,0.9653
2,3,centralized,1,,1024,1,0.5,42,6.195448,-1,0.9745
3,4,centralized,1,,1024,1,0.5,42,6.093476,-1,0.98
4,5,centralized,1,,1024,1,0.5,42,6.204228,-1,0.9836
