-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ADD CODE TO ADDRESS THE NEW DESIGN CHANGES.
There is a bug in the model update (the 2nd round is missing keys).
- Loading branch information
1 parent
8684609
commit aceb96a
Showing
17 changed files
with
1,048 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# FedAvg: simplified | ||
|
||
This example illustrates how to use the new Workflow Communicator API to construct a workflow without having to write a controller. | ||
|
||
## FLARE Workflow Communicator API | ||
|
||
The FLARE Workflow Communicator API has only a small set of methods: | ||
|
||
``` | ||
class WFCommAPISpec(ABC): | ||
@abstractmethod | ||
def broadcast_and_wait(self, msg_payload: Dict): | ||
pass | ||
@abstractmethod | ||
def send_and_wait(self, msg_payload: Dict): | ||
pass | ||
@abstractmethod | ||
def relay_and_wait(self, msg_payload: Dict): | ||
pass | ||
@abstractmethod | ||
def broadcast(self, msg_payload: Dict): | ||
pass | ||
@abstractmethod | ||
def send(self, msg_payload: Dict): | ||
pass | ||
@abstractmethod | ||
def relay(self, msg_payload: Dict): | ||
pass | ||
@abstractmethod | ||
def get_site_names(self) -> List[str]: | ||
pass | ||
@abstractmethod | ||
def wait_all(self, min_responses: int, resp_max_wait_time: Optional[float]) -> Dict[str, Dict[str, FLModel]]: | ||
pass | ||
@abstractmethod | ||
def wait_one(self, resp_max_wait_time: Optional[float] = None) -> Tuple[str, str, FLModel]: | ||
pass | ||
``` | ||
|
||
|
||
## Writing a new Workflow | ||
|
||
With this new API, writing a new workflow is really simple: | ||
|
||
* Workflow (Server) | ||
|
||
``` | ||
from nvflare.app_common.workflows import wf_comm as flare | ||
class FedAvg: | ||
def __init__( | ||
self, | ||
min_clients: int, | ||
num_rounds: int, | ||
output_path: str, | ||
start_round: int = 1, | ||
stop_cond: str = None, | ||
model_selection_rule: str = None, | ||
): | ||
super(FedAvg, self).__init__() | ||
<skip init code> | ||
self.flare_comm = flare.get_wf_comm_api() | ||
def run(self): | ||
self.logger.info("start Fed Avg Workflow\n \n") | ||
start = self.start_round | ||
end = self.start_round + self.num_rounds | ||
model = self.init_model() | ||
for current_round in range(start, end): | ||
self.logger.info(f"Round {current_round}/{self.num_rounds} started. {start=}, {end=}") | ||
self.current_round = current_round | ||
sag_results = self.scatter_and_gather(model, current_round) | ||
aggr_result = self.aggr_fn(sag_results) | ||
self.logger.info(f"aggregate metrics = {aggr_result.metrics}") | ||
model = update_model(model, aggr_result) | ||
self.select_best_model(model) | ||
self.save_model(self.best_model, self.output_path) | ||
self.logger.info("end Fed Avg Workflow\n \n") | ||
``` | ||
Scatter and Gather (SAG): | ||
|
||
SAG simply asks the WFController to broadcast the model to all clients and wait for their responses. | ||
|
||
``` | ||
def scatter_and_gather(self, model: FLModel, current_round): | ||
msg_payload = {"min_responses": self.min_clients, | ||
"current_round": current_round, | ||
"num_round": self.num_rounds, | ||
"start_round": self.start_round, | ||
"data": model} | ||
# (2) broadcast and wait | ||
results = self.flare_comm.broadcast_and_wait(msg_payload) | ||
return results | ||
``` | ||
|
||
## Configurations | ||
|
||
### client-side configuration | ||
|
||
This is the same as the FLARE Client API configuration. | ||
|
||
### server-side configuration | ||
|
||
The server-side controller is really simple: all we need to do is use WFController with the newly defined workflow class. | ||
|
||
|
||
``` | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
task_data_filters =[] | ||
task_result_filters = [] | ||
workflows = [ | ||
{ | ||
id = "fed_avg" | ||
path = "nvflare.app_opt.pt.wf_controller.PTWFController" | ||
args { | ||
comm_msg_pull_interval = 5 | ||
task_name = "train" | ||
wf_class_path = "fedavg_pt.PTFedAvg", | ||
wf_args { | ||
min_clients = 2 | ||
num_rounds = 10 | ||
output_path = "/tmp/nvflare/fedavg/mode.pth" | ||
stop_cond = "accuracy >= 55" | ||
model_selection_rule = "accuracy >=" | ||
} | ||
} | ||
} | ||
] | ||
components = [] | ||
} | ||
``` | ||
|
||
|
||
## Run the job | ||
|
||
Assume the current working directory is the ```hello-fedavg``` directory. | ||
|
||
``` | ||
nvflare simulator -n 2 -t 2 jobs/fedavg -w /tmp/fedavg | ||
``` |
77 changes: 77 additions & 0 deletions
77
examples/hello-world/hello-fedavg/jobs/fedavg/app/config/config_fed_client.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
{ | ||
format_version = 2 | ||
app_script = "cifar10_fl.py" | ||
app_config = "" | ||
executors = [ | ||
{ | ||
tasks = [ | ||
"train" | ||
] | ||
executor { | ||
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor" | ||
args { | ||
launcher_id = "launcher" | ||
pipe_id = "pipe" | ||
heartbeat_timeout = 60 | ||
params_exchange_format = "pytorch" | ||
params_transfer_type = "DIFF" | ||
train_with_evaluation = true | ||
} | ||
} | ||
} | ||
] | ||
task_data_filters = [] | ||
task_result_filters = [] | ||
components = [ | ||
{ | ||
id = "launcher" | ||
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" | ||
args { | ||
script = "python3 custom/{app_script} {app_config} " | ||
launch_once = true | ||
} | ||
} | ||
{ | ||
id = "pipe" | ||
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" | ||
args { | ||
mode = "PASSIVE" | ||
site_name = "{SITE_NAME}" | ||
token = "{JOB_ID}" | ||
root_url = "{ROOT_URL}" | ||
secure_mode = "{SECURE_MODE}" | ||
workspace_dir = "{WORKSPACE}" | ||
} | ||
} | ||
{ | ||
id = "metrics_pipe" | ||
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" | ||
args { | ||
mode = "PASSIVE" | ||
site_name = "{SITE_NAME}" | ||
token = "{JOB_ID}" | ||
root_url = "{ROOT_URL}" | ||
secure_mode = "{SECURE_MODE}" | ||
workspace_dir = "{WORKSPACE}" | ||
} | ||
} | ||
{ | ||
id = "metric_relay" | ||
path = "nvflare.app_common.widgets.metric_relay.MetricRelay" | ||
args { | ||
pipe_id = "metrics_pipe" | ||
event_type = "fed.analytix_log_stats" | ||
read_interval = 0.1 | ||
} | ||
} | ||
{ | ||
id = "config_preparer" | ||
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" | ||
args { | ||
component_ids = [ | ||
"metric_relay" | ||
] | ||
} | ||
} | ||
] | ||
} |
29 changes: 29 additions & 0 deletions
29
examples/hello-world/hello-fedavg/jobs/fedavg/app/config/config_fed_server.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
task_data_filters =[] | ||
task_result_filters = [] | ||
|
||
workflows = [ | ||
{ | ||
id = "fed_avg" | ||
path = "fedavg_pt.PTFedAvg" | ||
args { | ||
min_clients = 2 | ||
num_rounds = 2 | ||
output_path = "/tmp/nvflare/fedavg/mode.pth" | ||
# stop_cond = "accuracy >= 55" | ||
} | ||
} | ||
] | ||
|
||
components = [ | ||
{ | ||
id = "decomposer_register" | ||
path = "nvflare.app_common.wf_comm.decomposer_register.DecomposerRegister" | ||
args { | ||
decomposers = [ "nvflare.app_opt.pt.decomposers.TensorDecomposer"] | ||
} | ||
} | ||
] | ||
} |
Oops, something went wrong.