Skip to content

Commit

Permalink
Add min_responses to send_model_and_wait() (#2675)
Browse files: browse the repository at this point in the history
* expose min_responses in send_model_and_wait() and send_model(); hide wait_time_after_min_received from both

* change the default of min_responses to None (meaning: wait for responses from all targeted clients)

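* add a comment explaining that min_responses=0 is used internally by the controller's broadcast to represent all targets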
  • Loading branch information
SYangster committed Jul 1, 2024
1 parent 0665625 commit eac704d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
12 changes: 8 additions & 4 deletions nvflare/app_common/workflows/base_model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ def broadcast_model(
task_name: str = AppConstants.TASK_TRAIN,
data: FLModel = None,
targets: Union[List[Client], List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
wait_time_after_min_received: int = 10,
wait_time_after_min_received: int = 0,
blocking: bool = True,
callback: Callable[[FLModel], None] = None,
) -> List:
Expand All @@ -113,9 +114,11 @@ def broadcast_model(
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. If no data is given, send empty FLModel.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
timeout (int, optional): time to wait for clients to perform task. Defaults to 0, i.e., never time out.
wait_time_after_min_received (int, optional): time to wait after
minimum number of clients responses has been received. Defaults to 10.
minimum number of clients responses has been received. Defaults to 0.
blocking (bool, optional): whether to block to wait for task result. Defaults to True.
callback (Callable[[FLModel], None], optional): callback when a result is received, only called when blocking=False. Defaults to None.
Expand All @@ -127,6 +130,9 @@ def broadcast_model(
raise TypeError("task_name must be a string but got {}".format(type(task_name)))
if data and not isinstance(data, FLModel):
raise TypeError("data must be a FLModel or None but got {}".format(type(data)))
if min_responses is None:
min_responses = 0 # this is internally used by controller's broadcast to represent all targets
check_non_negative_int("min_responses", min_responses)
check_non_negative_int("timeout", timeout)
check_non_negative_int("wait_time_after_min_received", wait_time_after_min_received)
if not blocking and not isinstance(callback, Callable):
Expand All @@ -140,10 +146,8 @@ def broadcast_model(

if targets:
targets = [client.name if isinstance(client, Client) else client for client in targets]
min_responses = len(targets)
self.info(f"Sending task {task_name} to {targets}")
else:
min_responses = len(self.engine.get_clients())
self.info(f"Sending task {task_name} to all clients")

if blocking:
Expand Down
14 changes: 8 additions & 6 deletions nvflare/app_common/workflows/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ def send_model_and_wait(
task_name: str = "train",
data: FLModel = None,
targets: Union[List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
wait_time_after_min_received: int = 10,
) -> List[FLModel]:
"""Send a task with data to targets and wait for results.
Args:
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. Defaults to None.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
timeout (int, optional): time to wait for clients to perform task. Defaults to 0 (never time out).
wait_time_after_min_received (int, optional): time to wait after minimum number of client responses have been received. Defaults to 10.
Returns:
List[FLModel]
Expand All @@ -62,17 +63,17 @@ def send_model_and_wait(
task_name=task_name,
data=data,
targets=targets,
min_responses=min_responses,
timeout=timeout,
wait_time_after_min_received=wait_time_after_min_received,
)

def send_model(
self,
task_name: str = "train",
data: FLModel = None,
targets: Union[List[str], None] = None,
min_responses: int = None,
timeout: int = 0,
wait_time_after_min_received: int = 10,
callback: Callable[[FLModel], None] = None,
) -> None:
"""Send a task with data to targets (non-blocking). Callback is called when a result is received.
Expand All @@ -81,8 +82,9 @@ def send_model(
task_name (str, optional): name of the task. Defaults to "train".
data (FLModel, optional): FLModel to be sent to clients. Defaults to None.
targets (List[str], optional): the list of target client names or None (all clients). Defaults to None.
min_responses (int, optional): the minimum number of responses expected. If None, must receive responses from
all clients that the task has been sent to. Defaults to None.
timeout (int, optional): time to wait for clients to perform task. Defaults to 0 (never time out).
wait_time_after_min_received (int, optional): time to wait after minimum number of client responses have been received. Defaults to 10.
callback (Callable[[FLModel], None], optional): callback when a result is received. Defaults to None.
Returns:
Expand All @@ -92,8 +94,8 @@ def send_model(
task_name=task_name,
data=data,
targets=targets,
min_responses=min_responses,
timeout=timeout,
wait_time_after_min_received=wait_time_after_min_received,
blocking=False,
callback=callback,
)
Expand Down

0 comments on commit eac704d

Please sign in to comment.