-
Notifications
You must be signed in to change notification settings - Fork 792
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding sklearn logreg example (#748)
- Loading branch information
Showing
7 changed files
with
244 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Flower Example using scikit-learn | ||
|
||
This example of Flower uses `scikit-learn`'s `LogisticRegression` model to train a federated learning system. It will help you understand how to adapt Flower for use with `scikit-learn`. | ||
Running this example in itself is quite easy. | ||
|
||
## Project Setup | ||
|
||
Start by cloning the example project. We prepared a single-line command that you can copy into your shell which will checkout the example for you: | ||
|
||
```shell | ||
git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/sklearn-logreg-mnist . && rm -rf flower && cd sklearn-logreg-mnist | ||
``` | ||
|
||
This will create a new directory called `sklearn-logreg-mnist` containing the following files: | ||
|
||
```shell | ||
-- pyproject.toml | ||
-- client.py | ||
-- server.py | ||
-- utils.py | ||
-- README.md | ||
``` | ||
|
||
Project dependencies (such as `scikit-learn` and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. | ||
|
||
```shell | ||
poetry install | ||
poetry shell | ||
``` | ||
|
||
Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: | ||
|
||
```shell | ||
poetry run python3 -c "import flwr" | ||
``` | ||
|
||
If you don't see any errors you're good to go! | ||
|
||
# Run Federated Learning with scikit-learn and Flower | ||
|
||
Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: | ||
|
||
```shell | ||
poetry run python3 server.py | ||
``` | ||
|
||
Now you are ready to start the Flower clients which will participate in the learning. To do so simply open two more terminals and run the following command in each: | ||
|
||
```shell | ||
poetry run python3 client.py | ||
``` | ||
|
||
Alternatively you can run all of it in one shell as follows: | ||
|
||
```shell | ||
poetry run python3 server.py & | ||
poetry run python3 client.py & | ||
poetry run python3 client.py | ||
``` | ||
|
||
You will see that Flower is starting a federated training. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import warnings | ||
import flwr as fl | ||
import numpy as np | ||
|
||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.metrics import log_loss | ||
|
||
import utils | ||
|
||
if __name__ == "__main__": | ||
# Load MNIST dataset from https://www.openml.org/d/554 | ||
(X_train, y_train), (X_test, y_test) = utils.load_mnist() | ||
|
||
# Split train set into 10 partitions and randomly use one for training. | ||
partition_id = np.random.choice(10) | ||
(X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id] | ||
|
||
# Create LogisticRegression Model | ||
model = LogisticRegression( | ||
penalty="l2", | ||
max_iter=1, # local epoch | ||
warm_start=True, # prevent refreshing weights when fitting | ||
) | ||
|
||
# Setting initial parameters, akin to model.compile for keras models | ||
utils.set_initial_params(model) | ||
|
||
# Define Flower client | ||
class MnistClient(fl.client.NumPyClient): | ||
def get_parameters(self): # type: ignore | ||
return utils.get_model_parameters(model) | ||
|
||
def fit(self, parameters, config): # type: ignore | ||
utils.set_model_params(model, parameters) | ||
# Ignore convergence failure due to low local epochs | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore") | ||
model.fit(X_train, y_train) | ||
print(f"Training finished for round {config['rnd']}") | ||
return utils.get_model_parameters(model), len(X_train), {} | ||
|
||
def evaluate(self, parameters, config): # type: ignore | ||
utils.set_model_params(model, parameters) | ||
loss = log_loss(y_test, model.predict_proba(X_test)) | ||
accuracy = model.score(X_test, y_test) | ||
return loss, len(X_test), {"accuracy": accuracy} | ||
|
||
# Start Flower client | ||
fl.client.start_numpy_client("0.0.0.0:8080", client=MnistClient()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
[build-system] | ||
requires = [ | ||
"poetry==1.1.6", | ||
] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "sklearn-mnist" | ||
version = "0.1.0" | ||
description = "Federated training with scikit-learn" | ||
authors = [ | ||
"The Flower Authors <enquiries@flower.dev>", | ||
"Kaushik Amar Das <kaushik.das@iiitg.ac.in>" | ||
] | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.8" | ||
flwr = "^0.16.0" | ||
scikit-learn = "^0.24.2" | ||
openml = "^0.12.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import flwr as fl | ||
import utils | ||
from sklearn.metrics import log_loss | ||
from sklearn.linear_model import LogisticRegression | ||
from typing import Dict | ||
|
||
|
||
def fit_round(rnd: int) -> Dict: | ||
"""Send round number to client""" | ||
return {"rnd": rnd} | ||
|
||
|
||
def get_eval_fn(model: LogisticRegression): | ||
"""Return an evaluation function for server-side evaluation.""" | ||
|
||
# Load test data here to avoid the overhead of doing it in `evaluate` itself | ||
_, (X_test, y_test) = utils.load_mnist() | ||
|
||
# The `evaluate` function will be called after every round | ||
def evaluate(parameters: fl.common.Weights): | ||
# Update model with the latest parameters | ||
utils.set_model_params(model, parameters) | ||
loss = log_loss(y_test, model.predict_proba(X_test)) | ||
accuracy = model.score(X_test, y_test) | ||
return loss, {"accuracy": accuracy} | ||
|
||
return evaluate | ||
|
||
|
||
# Start Flower server for five rounds of federated learning | ||
if __name__ == "__main__": | ||
model = LogisticRegression() | ||
utils.set_initial_params(model) | ||
strategy = fl.server.strategy.FedAvg( | ||
min_available_clients=2, | ||
eval_fn=get_eval_fn(model), | ||
on_fit_config_fn=fit_round, | ||
) | ||
fl.server.start_server("0.0.0.0:8080", strategy=strategy, config={"num_rounds": 5}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Tuple, Union, List | ||
import numpy as np | ||
from sklearn.linear_model import LogisticRegression | ||
import openml | ||
|
||
XY = Tuple[np.ndarray, np.ndarray] | ||
Dataset = Tuple[XY, XY] | ||
LogRegParams = Union[XY, Tuple[np.ndarray]] | ||
XYList = List[XY] | ||
|
||
|
||
def get_model_parameters(model: LogisticRegression) -> LogRegParams: | ||
"""Returns the paramters of a sklearn LogisticRegression model""" | ||
if model.fit_intercept: | ||
params = (model.coef_, model.intercept_) | ||
else: | ||
params = (model.coef_,) | ||
return params | ||
|
||
|
||
def set_model_params( | ||
model: LogisticRegression, params: LogRegParams | ||
) -> LogisticRegression: | ||
"""Sets the parameters of a sklean LogisticRegression model""" | ||
model.coef_ = params[0] | ||
if model.fit_intercept: | ||
model.intercept_ = params[1] | ||
return model | ||
|
||
|
||
def set_initial_params(model: LogisticRegression): | ||
""" | ||
Sets initial parameters as zeros | ||
Required since model params are uninitialized until model.fit is called. | ||
But server asks for initial parameters from clients at launch. | ||
Refer to sklearn.linear_model.LogisticRegression documentation | ||
for more information. | ||
""" | ||
n_classes = 10 # MNIST has 10 classes | ||
n_features = 784 # Number of features in dataset | ||
model.classes_ = np.array([i for i in range(10)]) | ||
|
||
model.coef_ = np.zeros((n_classes, n_features)) | ||
if model.fit_intercept: | ||
model.intercept_ = np.zeros((n_classes,)) | ||
|
||
|
||
def load_mnist() -> Dataset: | ||
""" | ||
Loads the MNIST dataset using OpenML | ||
Dataset link: https://www.openml.org/d/554 | ||
""" | ||
mnist_openml = openml.datasets.get_dataset(554) | ||
Xy, _, _, _ = mnist_openml.get_data(dataset_format="array") | ||
X = Xy[:, :-1] # the last column contains labels | ||
y = Xy[:, -1] | ||
# First 60000 samples consist of the train set | ||
x_train, y_train = X[:60000], y[:60000] | ||
x_test, y_test = X[60000:], y[60000:] | ||
return (x_train, y_train), (x_test, y_test) | ||
|
||
|
||
def shuffle(X: np.ndarray, y: np.ndarray) -> XY: | ||
"""Shuffle X and y.""" | ||
rng = np.random.default_rng() | ||
idx = rng.permutation(len(X)) | ||
return X[idx], y[idx] | ||
|
||
|
||
def partition(X: np.ndarray, y: np.ndarray, num_partitions: int) -> XYList: | ||
"""Split X and y into a number of partitions.""" | ||
return list( | ||
zip(np.array_split(X, num_partitions), np.array_split(y, num_partitions)) | ||
) |