Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding advanced_pytorch example #1007

Merged
merged 29 commits into from May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
915812e
adding notebook for local simulation
cozek Jan 15, 2022
9750ef8
updating code
cozek Jan 16, 2022
6f5e0b1
adding server code
cozek Jan 16, 2022
6e15fcb
adding client server scripts
cozek Jan 16, 2022
4234c19
updating code
cozek Jan 16, 2022
d19e1a7
removing notebook
cozek Jan 16, 2022
6af58a3
black formatting
cozek Jan 16, 2022
426758d
Add advanced_pytorch example.
cozek Jan 17, 2022
d36170c
Merge branch 'main' into advanced-pytorch
cozek Jan 21, 2022
298074e
Merge branch 'main' into advanced-pytorch
cozek Feb 16, 2022
4e86244
Merge branch 'advanced-pytorch' of https://github.com/cozek/flower in…
cozek Feb 16, 2022
b04cda6
Add CIFAR download to run.sh; change device to CPU
cozek Feb 16, 2022
dd5430e
Merge branch 'main' into advanced-pytorch
pedropgusmao Feb 18, 2022
82fbc43
Merge branch 'main' into advanced-pytorch
cozek Mar 5, 2022
b434ace
download model in run.sh; move model to cpu after training; return ev…
cozek Mar 5, 2022
ce59a09
Merge branch 'main' into advanced-pytorch
pedropgusmao Mar 7, 2022
defcab1
fixes
cozek Mar 13, 2022
e16b652
Merge branch 'main' into advanced-pytorch
cozek Mar 13, 2022
3156030
Merge branch 'main' into advanced-pytorch
cozek Apr 13, 2022
f6eebab
Merge branch 'main' into advanced-pytorch
tanertopal Apr 13, 2022
9b8a9cc
Merge branch 'main' into advanced-pytorch
danieljanes Apr 13, 2022
7f6ec31
Merge branch 'main' into advanced-pytorch
cozek Apr 19, 2022
b65deb9
Merge branch 'main' into advanced-pytorch
cozek Apr 25, 2022
fa405d1
Merge branch 'main' into advanced-pytorch
danieljanes Apr 25, 2022
b6a2351
Merge branch 'main' into advanced-pytorch
pedropgusmao Apr 27, 2022
c9b885b
client can now specify which device to use for training; adding arg f…
cozek Apr 28, 2022
95f1745
minor changes
cozek Apr 28, 2022
d66cdf2
Merge branch 'main' into advanced-pytorch
cozek Apr 28, 2022
b9bd2b7
Merge branch 'main' into advanced-pytorch
danieljanes May 4, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -118,6 +118,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
* [MXNet: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/mxnet_from_centralized_to_federated)
* [JAX: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/jax_from_centralized_to_federated)
* [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced_tensorflow)
* [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced_pytorch)
* [Single-Machine Simulation of Federated Learning Systems](https://github.com/adap/flower/tree/main/examples/simulation)

## Community
Expand Down
53 changes: 53 additions & 0 deletions examples/advanced_pytorch/README.md
@@ -0,0 +1,53 @@
# Advanced Flower Example (PyTorch)

This example demonstrates an advanced federated learning setup using Flower with PyTorch. It differs from the quickstart example in the following ways:

- 10 clients (instead of just 2)
- Each client holds a local dataset of 5000 training examples and 1000 test examples
- Server-side model evaluation after parameter aggregation
- Hyperparameter schedule using config functions
- Custom return values
- Server-side parameter initialization

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/advanced_pytorch . && rm -rf flower && cd advanced_pytorch
```

This will create a new directory called `advanced_pytorch` containing the following files:

```shell
-- pyproject.toml
-- client.py
-- server.py
-- README.md
-- run.sh
```

Project dependencies (such as `pytorch` and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!

# Run Federated Learning with PyTorch and Flower

The included `run.sh` will start the Flower server (using `server.py`), sleep for 2 seconds to ensure the the server is up, and then start 10 Flower clients (using `client.py`). You can simply start everything in a terminal as follows:

```shell
poetry run ./run.sh
```

The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anyhting goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).
157 changes: 157 additions & 0 deletions examples/advanced_pytorch/client.py
@@ -0,0 +1,157 @@
import utils
from torch.utils.data import DataLoader
import torchvision.datasets
import torch
import flwr as fl
import argparse
from collections import OrderedDict
import warnings

warnings.filterwarnings("ignore")


class CifarClient(fl.client.NumPyClient):
def __init__(
self,
trainset: torchvision.datasets,
testset: torchvision.datasets,
device: str,
validation_split: int = 0.1,
cozek marked this conversation as resolved.
Show resolved Hide resolved
):
self.device = device
self.trainset = trainset
self.testset = testset
cozek marked this conversation as resolved.
Show resolved Hide resolved
self.validation_split = validation_split

def get_parameters(self):
"""Get parameters of the local model."""
raise Exception("Not implemented (server-side parameter initialization)")

def set_parameters(self, parameters):
"""Loads a efficientnet model and replaces it parameters with
the ones given"""
model = utils.load_efficientnet(classes=10)
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
return model

def fit(self, parameters, config):
"""Train parameters on the locally held training set."""

# Update local model parameters
model = self.set_parameters(parameters)

# Get hyperparameters for this round
batch_size: int = config["batch_size"]
epochs: int = config["local_epochs"]

n_valset = int(len(self.trainset) * self.validation_split)

valset = torch.utils.data.Subset(self.trainset, range(0, n_valset))
trainset = torch.utils.data.Subset(
self.trainset, range(n_valset, len(self.trainset))
)

trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
valLoader = DataLoader(valset, batch_size=batch_size)

results = utils.train(model, trainLoader, valLoader, epochs, self.device)

parameters_prime = utils.get_model_params(model)
num_examples_train = len(trainset)

return parameters_prime, num_examples_train, results

def evaluate(self, parameters, config):
"""Evaluate parameters on the locally held test set."""
# Update local model parameters
model = self.set_parameters(parameters)

# Get config values
steps: int = config["val_steps"]

# Evaluate global model parameters on the local test data and return results
testloader = DataLoader(self.testset, batch_size=16)

loss, accuracy = utils.test(model, testloader, steps, self.device)
return float(loss), len(self.testset), {"accuracy": float(accuracy)}


def client_dry_run(device: str = "cpu"):
"""Weak tests to check whether all client methods are working as expected."""

model = utils.load_efficientnet(classes=10)
trainset, testset = utils.load_partition(0)
trainset = torch.utils.data.Subset(trainset, range(10))
testset = torch.utils.data.Subset(testset, range(10))
client = CifarClient(trainset, testset, device)
client.fit(
utils.get_model_params(model),
{"batch_size": 16, "local_epochs": 1},
)

client.evaluate(utils.get_model_params(model), {"val_steps": 32})

print("Dry Run Successful")


def main() -> None:
# Parse command line argument `partition`
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--dry",
type=bool,
default=False,
required=False,
help="Do a dry-run to check the client",
)
parser.add_argument(
"--partition",
type=int,
default=0,
choices=range(0, 10),
required=False,
help="Specifies the artificial data partition of CIFAR10 to be used. \
Picks partition 0 by default",
)
parser.add_argument(
"--toy",
type=bool,
default=False,
required=False,
help="Set to true to quicky run the client using only 10 datasamples. \
Useful for testing purposes. Default: False",
)
parser.add_argument(
"--use_cuda",
type=bool,
default=False,
required=False,
help="Set to true to use GPU. Default: False",
)

cozek marked this conversation as resolved.
Show resolved Hide resolved
args = parser.parse_args()

device = torch.device(
"cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu"
)

if args.dry:
client_dry_run(device)
else:
# Load a subset of CIFAR-10 to simulate the local data partition
trainset, testset = utils.load_partition(args.partition)

if args.toy:
trainset = torch.utils.data.Subset(trainset, range(10))
testset = torch.utils.data.Subset(testset, range(10))

cozek marked this conversation as resolved.
Show resolved Hide resolved
# Start Flower client
client = CifarClient(trainset, testset, device)

fl.client.start_numpy_client("0.0.0.0:8080", client=client)


if __name__ == "__main__":
main()
21 changes: 21 additions & 0 deletions examples/advanced_pytorch/pyproject.toml
@@ -0,0 +1,21 @@
[build-system]
requires = [
"poetry==1.1.12",
]
build-backend = "poetry.masonry.api"

[tool.poetry]
name = "advanced_flwr_pytorch"
version = "0.1.0"
description = "Advanced Flower/PyTorch Example"
authors = [
"The Flower Authors <enquiries@flower.dev>",
"Kaushik Amar Das <kaushik.das@iiitg.ac.in>"
]

[tool.poetry.dependencies]
python = "^3.6.2"
flwr = "^0.17.0" # For development: { path = "../../", develop = true }
torch = "1.9.0"
torchvision = "0.10.0"
validators = "0.18.2"
23 changes: 23 additions & 0 deletions examples/advanced_pytorch/run.sh
@@ -0,0 +1,23 @@
#!/bin/bash

# download the CIFAR10 dataset and the efficientnet model
# subsequent runs do not redownload
python -c "from torchvision.datasets import CIFAR10; \
CIFAR10('./dataset', train=True, download=True)"

python -c "import torch; torch.hub.load( \
'NVIDIA/DeepLearningExamples:torchhub', \
'nvidia_efficientnet_b0', pretrained=True)"

python server.py &
sleep 2 # Sleep for 2s to give the server enough time to start

for i in `seq 0 9`; do
echo "Starting client $i"
python client.py --partition=${i} &
done

# This will allow you to use CTRL+C to stop all background processes
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM
# Wait for all background processes to complete
wait
110 changes: 110 additions & 0 deletions examples/advanced_pytorch/server.py
@@ -0,0 +1,110 @@
from typing import Dict, Optional, Tuple
from collections import OrderedDict
import argparse
from torch.utils.data import DataLoader

import flwr as fl
import torch

import utils

import warnings

warnings.filterwarnings("ignore")


def fit_config(rnd: int):
"""Return training configuration dict for each round.
Keep batch size fixed at 32, perform two rounds of training with one
local epoch, increase to two local epochs afterwards.
"""
config = {
"batch_size": 16,
"local_epochs": 1 if rnd < 2 else 2,
}
return config


def evaluate_config(rnd: int):
"""Return evaluation configuration dict for each round.
Perform five local evaluation steps on each client (i.e., use five
batches) during rounds one to three, then increase to ten local
evaluation steps.
"""
val_steps = 5 if rnd < 4 else 10
return {"val_steps": val_steps}


def get_eval_fn(model: torch.nn.Module, toy: bool):
"""Return an evaluation function for server-side evaluation."""

# Load data and model here to avoid the overhead of doing it in `evaluate` itself
trainset, _, _ = utils.load_data()

n_train = len(trainset)
if toy:
# use only 10 samples as validation set
valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train))
else:
# Use the last 5k training examples as a validation set
valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train))

valLoader = DataLoader(valset, batch_size=16)
# The `evaluate` function will be called after every round
def evaluate(
weights: fl.common.Weights,
) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
# Update model with the latest parameters
params_dict = zip(model.state_dict().keys(), weights)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)

loss, accuracy = utils.test(model, valLoader)
return loss, {"accuracy": accuracy}

return evaluate


def main():
"""
# Load model for
# 1. server-side parameter initialization
# 2. server-side parameter evaluation
"""

# Parse command line argument `partition`
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--toy",
type=bool,
default=False,
required=False,
help="Set to true to use only 10 datasamples for validation. \
Useful for testing purposes. Default: False",
)

args = parser.parse_args()

model = utils.load_efficientnet(classes=10)

model_weights = [val.cpu().numpy() for _, val in model.state_dict().items()]

# Create strategy
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.2,
fraction_eval=0.2,
min_fit_clients=2,
min_eval_clients=2,
min_available_clients=10,
eval_fn=get_eval_fn(model, args.toy),
on_fit_config_fn=fit_config,
on_evaluate_config_fn=evaluate_config,
initial_parameters=fl.common.weights_to_parameters(model_weights),
)

# Start Flower server for four rounds of federated learning
fl.server.start_server("0.0.0.0:8080", config={"num_rounds": 4}, strategy=strategy)


if __name__ == "__main__":
main()