# Federated Learning using PyTorch and Flower

This tutorial will show how to leverage PyTorch to federate the training of a convolutional neural network over multiple clients using [Flower](https://flower.dev/). More specifically, we will use a ResNet model for the CIFAR10 dataset. The end goal is to classify 10 different objects in small images. 

In [2]:
!pip install -q flwr["simulation"] torch torchvision

You should consider upgrading via the '/home/w.lindskog/.pyenv/versions/3.9.5/bin/python3.9 -m pip install --upgrade pip' command.[0m


We can now import the relevant modules.

In [3]:
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple, Union
import os
import random
import warnings

import flwr as fl
import torch
import torchvision
from torchvision import transforms

from torch.utils.data import DataLoader

Next we will set some global variables and disable some of the logging to clear out our output.

In [22]:
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.simplefilter('ignore')

DEVICE = torch.device("cuda:0")
NUM_CLIENTS = 2
NUM_ROUNDS = 3

Next, we'll create a function that fetches the data from torchvision. 

In [5]:
def load_data() -> Union[DataLoader, DataLoader]:
    """ Load data CIFAR10 from torchvision.datasets """
    # Set transform
    transform = transforms.Compose(
        [
            # cifar10 optimal transform
            transforms.Resize((224, 224)),
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.49139968, 0.48215841, 0.44653091],
                std=[0.24703223, 0.24348513, 0.26158784],
            ),
        ]
    )

    train_dataset = torchvision.datasets.CIFAR10(
        root="~/torch_datasets", train=True, transform=transform, download=True
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root="~/torch_datasets", train=False, transform=transform, download=True
    )

    # Set to data loader
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True
    )

    return train_loader, test_loader

Next, we get a pretrained resNet model. 

In [6]:
# Get resnet18 model
def get_model() -> torch.nn.Module:
    """ Get ResNet18 model """
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, 10)
    return model

In [23]:
def train(net, trainloader, epochs):
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()


def test(net, testloader):
    loss = 0
    net.eval()
    total, correct = 0, 0
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        total += batch["labels"].size(0)
        correct += (predictions == batch["labels"]).sum().item()
    accuracy = correct / total
    loss /= len(testloader.dataset)
    return loss, accuracy

In [24]:
model = get_model().to(DEVICE)

RuntimeError: No CUDA GPUs are available

In [9]:
class CIFAR10Client(fl.client.NumPyClient):
    def __init__(self, net, trainloader, testloader):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        print("Training Started...")
        train(self.net, self.trainloader, epochs=1)
        print("Training Finished.")
        return self.get_parameters(config={}), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        return float(loss), len(self.testloader), {"accuracy": float(accuracy), "loss": float(loss)}

In [10]:
trainloader, testloader = load_data()
def client_fn(cid):
  return CIFAR10Client(net, trainloader, testloader)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /home/w.lindskog/torch_datasets/cifar-10-python.tar.gz


100.0%


Extracting /home/w.lindskog/torch_datasets/cifar-10-python.tar.gz to /home/w.lindskog/torch_datasets
Files already downloaded and verified


In [25]:
def weighted_average(metrics):
  accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
  losses = [num_examples * m["loss"] for num_examples, m in metrics]
  examples = [num_examples for num_examples, _ in metrics]
  return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)}

strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    evaluate_metrics_aggregation_fn=weighted_average,
)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources={"num_cpus": 1, "num_gpus": 1},
    ray_init_args={"log_to_driver": False, "num_cpus": 1, "num_gpus": 1}
)

INFO flwr 2023-11-02 16:40:37,211 | app.py:175 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2023-11-02 16:40:40,222	INFO worker.py:1612 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
INFO flwr 2023-11-02 16:40:40,477 | app.py:210 | Flower VCE: Ray initialized with resources: {'CPU': 1.0, 'memory': 17872746087.0, 'GPU': 1.0, 'object_store_memory': 8936373043.0, 'node:172.20.106.81': 1.0, 'node:__internal_head__': 1.0}
INFO flwr 2023-11-02 16:40:40,477 | app.py:224 | Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 1}
INFO flwr 2023-11-02 16:40:40,481 | app.py:270 | Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
INFO flwr 2023-11-02 16:40:40,482 | server.py:89 | Initializing global parameters
INFO flwr 2023-11-02 16:40:40,482 | server.py:276 | Requesting initial parameters from one random client
ERROR flwr 2023-11-02 16:40:41,703 | ray_client_proxy.py:147 | Traceb

