Commit

Adding sklearn logreg example (#748)
cozek committed Jun 30, 2021
1 parent ada3e12 commit 570788c
Showing 7 changed files with 244 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -56,6 +56,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
* [MXNet: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/mxnet_from_centralized_to_federated)
* [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced_tensorflow)
* [Single-Machine Simulation of Federated Learning Systems](https://github.com/adap/flower/tree/main/examples/simulation)
* [Federated learning example with scikit-learn](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist)

## Flower Baselines

1 change: 0 additions & 1 deletion doc/source/good-first-contributions.rst
@@ -25,7 +25,6 @@ We wish we had more time to write usage examples because we believe they help
users to get started with building what they want to build. Here are a few
ideas where we'd be happy to accept a PR:

- First scikit-learn example (MNIST)
- First MXNet 1.6 example (MNIST)
- ImageNet (PyTorch/TensorFlow)
- LSTM (PyTorch/TensorFlow)
61 changes: 61 additions & 0 deletions examples/sklearn-logreg-mnist/README.md
@@ -0,0 +1,61 @@
# Flower Example using scikit-learn

This Flower example uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`, and running the example itself is quite easy.
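In essence, a `scikit-learn` model's parameters (`coef_` and `intercept_`) are plain NumPy arrays, so a Flower `NumPyClient` can send and receive them directly. Below is a minimal sketch of that idea, not the full example; the complete helpers live in `utils.py` and `client.py` further down.

```python
# Minimal sketch: these hypothetical helpers mirror what
# utils.get_model_parameters / utils.set_model_params do in this project.
from sklearn.linear_model import LogisticRegression

model = LogisticRegression(max_iter=1, warm_start=True)


def get_parameters(model):
    # The NumPy arrays a Flower client sends to the server
    return [model.coef_, model.intercept_]


def set_parameters(model, parameters):
    # Write arrays received from the server back into the model
    model.coef_, model.intercept_ = parameters
    return model
```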

## Project Setup

Start by cloning the example project. We prepared a single-line command that you can copy into your shell, which will check out the example for you:

```shell
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/sklearn-logreg-mnist . && rm -rf flower && cd sklearn-logreg-mnist
```

This will create a new directory called `sklearn-logreg-mnist` containing the following files:

```shell
-- pyproject.toml
-- client.py
-- server.py
-- utils.py
-- README.md
```

Project dependencies (such as `scikit-learn` and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences.

```shell
poetry install
poetry shell
```

Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command:

```shell
poetry run python3 -c "import flwr"
```

If you don't see any errors you're good to go!
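If you prefer a slightly more informative check, here is an optional sketch that prints the installed versions of the main dependencies (package names as declared in `pyproject.toml`):

```python
# Optional sanity check: show which versions were installed.
import importlib.metadata as metadata  # Python 3.8+

for package in ("flwr", "scikit-learn", "openml"):
    print(package, metadata.version(package))
```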

## Run Federated Learning with scikit-learn and Flower

Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows:

```shell
poetry run python3 server.py
```

Now you are ready to start the Flower clients, which will participate in the learning. To do so, simply open two more terminals and run the following command in each:

```shell
poetry run python3 client.py
```

Alternatively you can run all of it in one shell as follows:

```shell
poetry run python3 server.py &
poetry run python3 client.py &
poetry run python3 client.py
```

You will see that Flower starts a federated training run.
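Once everything runs, you may want to experiment with the federated setup, for example by requiring more clients or running more rounds. A minimal sketch of such a change, assuming the same `flwr` 0.16 `FedAvg`/`start_server` API that the `server.py` shipped with this example uses; the values are illustrative:

```python
# Sketch only: a variant of server.py with more rounds and a higher
# minimum number of participating clients (both values are illustrative).
import flwr as fl

strategy = fl.server.strategy.FedAvg(
    min_available_clients=3,                    # wait for three clients
    on_fit_config_fn=lambda rnd: {"rnd": rnd},  # forward the round number
)
fl.server.start_server("0.0.0.0:8080", strategy=strategy, config={"num_rounds": 10})
```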
49 changes: 49 additions & 0 deletions examples/sklearn-logreg-mnist/client.py
@@ -0,0 +1,49 @@
import warnings
import flwr as fl
import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

import utils

if __name__ == "__main__":
    # Load MNIST dataset from https://www.openml.org/d/554
    (X_train, y_train), (X_test, y_test) = utils.load_mnist()

    # Split train set into 10 partitions and randomly use one for training.
    partition_id = np.random.choice(10)
    (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]

    # Create LogisticRegression Model
    model = LogisticRegression(
        penalty="l2",
        max_iter=1,  # local epoch
        warm_start=True,  # prevent refreshing weights when fitting
    )

    # Setting initial parameters, akin to model.compile for keras models
    utils.set_initial_params(model)

    # Define Flower client
    class MnistClient(fl.client.NumPyClient):
        def get_parameters(self):  # type: ignore
            return utils.get_model_parameters(model)

        def fit(self, parameters, config):  # type: ignore
            utils.set_model_params(model, parameters)
            # Ignore convergence failure due to low local epochs
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model.fit(X_train, y_train)
            print(f"Training finished for round {config['rnd']}")
            return utils.get_model_parameters(model), len(X_train), {}

        def evaluate(self, parameters, config):  # type: ignore
            utils.set_model_params(model, parameters)
            loss = log_loss(y_test, model.predict_proba(X_test))
            accuracy = model.score(X_test, y_test)
            return loss, len(X_test), {"accuracy": accuracy}

    # Start Flower client
    fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient())
20 changes: 20 additions & 0 deletions examples/sklearn-logreg-mnist/pyproject.toml
@@ -0,0 +1,20 @@
[build-system]
requires = [
    "poetry==1.1.6",
]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "sklearn-mnist"
version = "0.1.0"
description = "Federated training with scikit-learn"
authors = [
    "The Flower Authors <enquiries@flower.dev>",
    "Kaushik Amar Das <kaushik.das@iiitg.ac.in>"
]

[tool.poetry.dependencies]
python = "^3.8"
flwr = "^0.16.0"
scikit-learn = "^0.24.2"
openml = "^0.12.2"
39 changes: 39 additions & 0 deletions examples/sklearn-logreg-mnist/server.py
@@ -0,0 +1,39 @@
import flwr as fl
import utils
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict


def fit_round(rnd: int) -> Dict:
"""Send round number to client"""
return {"rnd": rnd}


def get_eval_fn(model: LogisticRegression):
"""Return an evaluation function for server-side evaluation."""

# Load test data here to avoid the overhead of doing it in `evaluate` itself
_, (X_test, y_test) = utils.load_mnist()

# The `evaluate` function will be called after every round
def evaluate(parameters: fl.common.Weights):
# Update model with the latest parameters
utils.set_model_params(model, parameters)
loss = log_loss(y_test, model.predict_proba(X_test))
accuracy = model.score(X_test, y_test)
return loss, {"accuracy": accuracy}

return evaluate


# Start Flower server for five rounds of federated learning
if __name__ == "__main__":
model = LogisticRegression()
utils.set_initial_params(model)
strategy = fl.server.strategy.FedAvg(
min_available_clients=2,
eval_fn=get_eval_fn(model),
on_fit_config_fn=fit_round,
)
fl.server.start_server("0.0.0.0:8080", strategy=strategy, config={"num_rounds": 5})
74 changes: 74 additions & 0 deletions examples/sklearn-logreg-mnist/utils.py
@@ -0,0 +1,74 @@
from typing import Tuple, Union, List
import numpy as np
from sklearn.linear_model import LogisticRegression
import openml

XY = Tuple[np.ndarray, np.ndarray]
Dataset = Tuple[XY, XY]
LogRegParams = Union[XY, Tuple[np.ndarray]]
XYList = List[XY]


def get_model_parameters(model: LogisticRegression) -> LogRegParams:
"""Returns the paramters of a sklearn LogisticRegression model"""
if model.fit_intercept:
params = (model.coef_, model.intercept_)
else:
params = (model.coef_,)
return params


def set_model_params(
model: LogisticRegression, params: LogRegParams
) -> LogisticRegression:
"""Sets the parameters of a sklean LogisticRegression model"""
model.coef_ = params[0]
if model.fit_intercept:
model.intercept_ = params[1]
return model


def set_initial_params(model: LogisticRegression):
"""
Sets initial parameters as zeros
Required since model params are uninitialized until model.fit is called.
But server asks for initial parameters from clients at launch.
Refer to sklearn.linear_model.LogisticRegression documentation
for more information.
"""
n_classes = 10 # MNIST has 10 classes
n_features = 784 # Number of features in dataset
model.classes_ = np.array([i for i in range(10)])

model.coef_ = np.zeros((n_classes, n_features))
if model.fit_intercept:
model.intercept_ = np.zeros((n_classes,))


def load_mnist() -> Dataset:
"""
Loads the MNIST dataset using OpenML
Dataset link: https://www.openml.org/d/554
"""
mnist_openml = openml.datasets.get_dataset(554)
Xy, _, _, _ = mnist_openml.get_data(dataset_format="array")
X = Xy[:, :-1] # the last column contains labels
y = Xy[:, -1]
# First 60000 samples consist of the train set
x_train, y_train = X[:60000], y[:60000]
x_test, y_test = X[60000:], y[60000:]
return (x_train, y_train), (x_test, y_test)


def shuffle(X: np.ndarray, y: np.ndarray) -> XY:
"""Shuffle X and y."""
rng = np.random.default_rng()
idx = rng.permutation(len(X))
return X[idx], y[idx]


def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList:
"""Split X and y into a number of partitions."""
return list(
zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions))
)
