diff --git a/examples/quickstart-pytorch/README.md b/examples/quickstart-pytorch/README.md index 02c9b4b3849..978191cc0ec 100644 --- a/examples/quickstart-pytorch/README.md +++ b/examples/quickstart-pytorch/README.md @@ -47,6 +47,8 @@ Write the command below in your terminal to install the dependencies according t pip install -r requirements.txt ``` +______________________________________________________________________ + ## Run Federated Learning with PyTorch and Flower Afterwards you are ready to start the Flower server as well as the clients. You can simply start the server in a terminal as follows: @@ -72,3 +74,29 @@ python3 client.py --partition-id 1 ``` You will see that PyTorch is starting a federated training. Look at the [code](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch) for a detailed explanation. + +______________________________________________________________________ + +## Run Federated Learning with PyTorch and `Flower Next` + +### 1. Start the long-running Flower server (SuperLink) + +```bash +flower-superlink --insecure +``` + +### 2. Start the long-running Flower clients (SuperNodes) + +Start 2 Flower `SuperNodes` in 2 separate terminal windows, using: + +```bash +flower-client-app client:app --insecure +``` + +### 3. Run the Flower App + +With both the long-running server (SuperLink) and two clients (SuperNode) up and running, we can now run the actual Flower App: + +```bash +flower-server-app server:app --insecure +``` diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index e640ce111df..e58dbf7ea0b 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -2,7 +2,7 @@ import warnings from collections import OrderedDict -import flwr as fl +from flwr.client import NumPyClient, ClientApp from flwr_datasets import FederatedDataset import torch import torch.nn as nn @@ -99,11 +99,11 @@ def apply_transforms(batch): parser.add_argument( "--partition-id", choices=[0, 1, 2], - required=True, + default=0, type=int, help="Partition of the dataset divided into 3 iid partitions created artificially.", ) -partition_id = parser.parse_args().partition_id +partition_id = parser.parse_known_args()[0].partition_id # Load model and data (simple CNN, CIFAR-10) net = Net().to(DEVICE) @@ -111,7 +111,7 @@ def apply_transforms(batch): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] @@ -131,8 +131,22 @@ def evaluate(self, parameters, config): return loss, len(testloader.dataset), {"accuracy": accuracy} -# Start Flower client -fl.client.start_client( - server_address="127.0.0.1:8080", - client=FlowerClient().to_client(), +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, ) + + +# Legacy mode +if __name__ == "__main__": + from flwr.client import start_client + + start_client( + server_address="127.0.0.1:8080", + client=FlowerClient().to_client(), + ) diff --git a/examples/quickstart-pytorch/pyproject.toml b/examples/quickstart-pytorch/pyproject.toml index d8e1503dd8a..7255e627471 100644 --- a/examples/quickstart-pytorch/pyproject.toml +++ b/examples/quickstart-pytorch/pyproject.toml @@ -10,7 +10,7 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" +flwr = ">=1.8.0,<2.0" flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" } torch = "2.1.1" torchvision = "0.16.1" diff --git a/examples/quickstart-pytorch/requirements.txt b/examples/quickstart-pytorch/requirements.txt index 4e321e2cd0c..12627809551 100644 --- a/examples/quickstart-pytorch/requirements.txt +++ b/examples/quickstart-pytorch/requirements.txt @@ -1,4 +1,4 @@ -flwr>=1.0, <2.0 +flwr>=1.8.0, <2.0 flwr-datasets[vision]>=0.0.2, <1.0.0 torch==2.1.1 torchvision==0.16.1 diff --git a/examples/quickstart-pytorch/server.py b/examples/quickstart-pytorch/server.py index fe691a88aba..4034703ca69 100644 --- a/examples/quickstart-pytorch/server.py +++ b/examples/quickstart-pytorch/server.py @@ -1,6 +1,7 @@ from typing import List, Tuple -import flwr as fl +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg from flwr.common import Metrics @@ -15,11 +16,26 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # Define strategy -strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average) +strategy = FedAvg(evaluate_metrics_aggregation_fn=weighted_average) -# Start Flower server -fl.server.start_server( - server_address="0.0.0.0:8080", - config=fl.server.ServerConfig(num_rounds=3), + +# Define config +config = ServerConfig(num_rounds=3) + + +# Flower ServerApp +app = ServerApp( + config=config, strategy=strategy, ) + + +# Legacy mode +if __name__ == "__main__": + from flwr.server import start_server + + start_server( + server_address="0.0.0.0:8080", + config=config, + strategy=strategy, + )