Skip to content

Commit

Permalink
ADD CODE TO ADDRESS THE NEW DESIGN CHANGES.
Browse files Browse the repository at this point in the history
There is a bug in the model update ( where the 2nd round missing keys)
  • Loading branch information
chesterxgchen committed Jan 30, 2024
1 parent 8684609 commit aceb96a
Show file tree
Hide file tree
Showing 17 changed files with 1,048 additions and 35 deletions.
172 changes: 172 additions & 0 deletions examples/hello-world/hello-fedavg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# FedAvg: simplified

This example illustrates how to use the new Workflow Communicator API to construct a workflow: there is no need to write a controller.

## FLARE Workflow Communicator API

The FLARE Workflow Communicator API has only a small set of methods:

```
class WFCommAPISpec(ABC):
@abstractmethod
def broadcast_and_wait(self, msg_payload: Dict):
pass
@abstractmethod
def send_and_wait(self, msg_payload: Dict):
pass
@abstractmethod
def relay_and_wait(self, msg_payload: Dict):
pass
@abstractmethod
def broadcast(self, msg_payload: Dict):
pass
@abstractmethod
def send(self, msg_payload: Dict):
pass
@abstractmethod
def relay(self, msg_payload: Dict):
pass
@abstractmethod
def get_site_names(self) -> List[str]:
pass
@abstractmethod
def wait_all(self, min_responses: int, resp_max_wait_time: Optional[float]) -> Dict[str, Dict[str, FLModel]]:
pass
@abstractmethod
def wait_one(self, resp_max_wait_time: Optional[float] = None) -> Tuple[str, str, FLModel]:
pass
```


## Writing a new Workflow

With this new API, writing a new workflow is really simple:

* Workflow (Server)

```
from nvflare.app_common.workflows import wf_comm as flare
class FedAvg:
def __init__(
self,
min_clients: int,
num_rounds: int,
output_path: str,
start_round: int = 1,
stop_cond: str = None,
model_selection_rule: str = None,
):
super(FedAvg, self).__init__()
<skip init code>
self.flare_comm = flare.get_wf_comm_api()
def run(self):
self.logger.info("start Fed Avg Workflow\n \n")
start = self.start_round
end = self.start_round + self.num_rounds
model = self.init_model()
for current_round in range(start, end):
self.logger.info(f"Round {current_round}/{self.num_rounds} started. {start=}, {end=}")
self.current_round = current_round
sag_results = self.scatter_and_gather(model, current_round)
aggr_result = self.aggr_fn(sag_results)
self.logger.info(f"aggregate metrics = {aggr_result.metrics}")
model = update_model(model, aggr_result)
self.select_best_model(model)
self.save_model(self.best_model, self.output_path)
self.logger.info("end Fed Avg Workflow\n \n")
```
Scatter and Gather (SAG):

SAG simply asks the WFController to broadcast the model to all clients and wait for their results:

```
def scatter_and_gather(self, model: FLModel, current_round):
msg_payload = {"min_responses": self.min_clients,
"current_round": current_round,
"num_round": self.num_rounds,
"start_round": self.start_round,
"data": model}
# (2) broadcast and wait
results = self.flare_comm.broadcast_and_wait(msg_payload)
return results
```

## Configurations

### client-side configuration

This is the same as the FLARE Client API configuration.

### server-side configuration

The server-side controller configuration is really simple: all we need is to use WFController with the newly defined workflow class.


```
{
# version of the configuration
format_version = 2
task_data_filters =[]
task_result_filters = []
workflows = [
{
id = "fed_avg"
path = "nvflare.app_opt.pt.wf_controller.PTWFController"
args {
comm_msg_pull_interval = 5
task_name = "train"
wf_class_path = "fedavg_pt.PTFedAvg",
wf_args {
min_clients = 2
num_rounds = 10
output_path = "/tmp/nvflare/fedavg/mode.pth"
stop_cond = "accuracy >= 55"
model_selection_rule = "accuracy >="
}
}
}
]
components = []
}
```


## Run the job

Assume the current working directory is the `hello-fedavg` directory:

```
nvflare simulator -n 2 -t 2 jobs/fedavg -w /tmp/fedavg
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
{
# Client-side job configuration: binds the "train" task to a launcher-based
# executor and declares the components that the executor refers to by id.
format_version = 2
# Script name and argument string; they appear as the {app_script} and
# {app_config} placeholders in the launcher's "script" command below.
app_script = "cifar10_fl.py"
app_config = ""
executors = [
{
# Task names routed to this executor.
tasks = [
"train"
]
executor {
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
args {
# These ids match the "launcher" and "pipe" entries in the components list below.
launcher_id = "launcher"
pipe_id = "pipe"
# NOTE(review): unit is not stated in this file — presumably seconds; confirm.
heartbeat_timeout = 60
params_exchange_format = "pytorch"
# NOTE(review): "DIFF" presumably means only parameter differences are
# sent back rather than full weights — confirm against the executor docs.
params_transfer_type = "DIFF"
train_with_evaluation = true
}
}
}
]
# No task data or task result filters are configured.
task_data_filters = []
task_result_filters = []
components = [
{
# Component referenced by the executor's launcher_id.
id = "launcher"
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
args {
# Command line built from the app_script / app_config values defined at the top of this file.
script = "python3 custom/{app_script} {app_config} "
launch_once = true
}
}
{
# Component referenced by the executor's pipe_id.
# The {SITE_NAME}, {JOB_ID}, {ROOT_URL}, {SECURE_MODE} and {WORKSPACE}
# values are placeholders — presumably resolved at runtime; confirm.
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
# Second pipe with the same arguments as "pipe"; referenced by
# "metric_relay" below via its pipe_id.
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
# Metric relay attached to "metrics_pipe".
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
# NOTE(review): unit is not stated in this file — presumably seconds; confirm.
read_interval = 0.1
}
}
{
# Lists "metric_relay" in component_ids — presumably so that component's
# configuration is made available to the launched script; confirm.
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = [
"metric_relay"
]
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
# Server-side job configuration: declares one workflow ("fed_avg") and one
# supporting component ("decomposer_register").
# version of the configuration
format_version = 2
# No task data or task result filters are configured.
task_data_filters =[]
task_result_filters = []

workflows = [
{
# The workflow class is referenced directly by path; the entries under
# "args" correspond to the keyword arguments listed for the workflow.
id = "fed_avg"
path = "fedavg_pt.PTFedAvg"
args {
min_clients = 2
num_rounds = 2
# NOTE(review): file name is "mode.pth" — possibly intended to be
# "model.pth"; confirm before relying on this path.
output_path = "/tmp/nvflare/fedavg/mode.pth"
# Stop condition is disabled in this configuration; uncomment to enable.
# stop_cond = "accuracy >= 55"
}
}
]

components = [
{
# Lists the decomposer classes to use — presumably registered so tensor
# payloads can be serialized between server and clients; confirm.
id = "decomposer_register"
path = "nvflare.app_common.wf_comm.decomposer_register.DecomposerRegister"
args {
decomposers = [ "nvflare.app_opt.pt.decomposers.TensorDecomposer"]
}
}
]
}
Loading

0 comments on commit aceb96a

Please sign in to comment.