Skip to content

Commit

Permalink
[2.4] Add custom order and early termination to cyclic controller (#2422)
Browse files Browse the repository at this point in the history

* Add custom order and early termination to CyclicController and add tests

* Add more error handling
  • Loading branch information
YuanTingHsieh committed Mar 19, 2024
1 parent cd9237f commit fb8f7c5
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 35 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Expand Up @@ -39,6 +39,7 @@ class ReturnCode(object):
VALIDATE_TYPE_UNKNOWN = "VALIDATE_TYPE_UNKNOWN"
EMPTY_RESULT = "EMPTY_RESULT"
UNSAFE_JOB = "UNSAFE_JOB"
EARLY_TERMINATION = "EARLY_TERMINATION"
SERVER_NOT_READY = "SERVER_NOT_READY"
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"

Expand Down
114 changes: 79 additions & 35 deletions nvflare/app_common/workflows/cyclic_ctl.py
Expand Up @@ -14,6 +14,7 @@

import gc
import random
from typing import List, Union

from nvflare.apis.client import Client
from nvflare.apis.fl_constant import ReturnCode
Expand Down Expand Up @@ -48,7 +49,8 @@ def __init__(
task_check_period: float = 0.5,
persist_every_n_rounds: int = 1,
snapshot_every_n_rounds: int = 1,
order: str = RelayOrder.FIXED,
order: Union[str, List[str]] = RelayOrder.FIXED,
allow_early_termination=False,
):
"""A sample implementation to demonstrate how to use relay method for Cyclic Federated Learning.
Expand All @@ -65,11 +67,13 @@ def __init__(
If n is 0 then no persist.
snapshot_every_n_rounds (int, optional): persist the server state every n rounds. Defaults to 1.
If n is 0 then no persist.
order (str, optional): the order of relay.
If FIXED means the same order for every round.
If RANDOM means random order for every round.
If RANDOM_WITHOUT_SAME_IN_A_ROW means every round the order gets shuffled but a client will never be
run twice in a row (in different round).
order (Union[str, List[str]], optional): The order of relay.
- If a string is provided:
- "FIXED": Same order for every round.
- "RANDOM": Random order for every round.
- "RANDOM_WITHOUT_SAME_IN_A_ROW": Shuffled order, no repetition in consecutive rounds.
- If a list of strings is provided, it represents a custom order for relay.
allow_early_termination: whether to allow early workflow termination from clients
Raises:
TypeError: when any of input arguments does not have correct type
Expand All @@ -88,13 +92,14 @@ def __init__(
if not isinstance(task_name, str):
raise TypeError("task_name must be a string but got {}".format(type(task_name)))

if order not in SUPPORTED_ORDERS:
raise ValueError(f"order must be in {SUPPORTED_ORDERS}")
if order not in SUPPORTED_ORDERS and not isinstance(order, list):
raise ValueError(f"order must be in {SUPPORTED_ORDERS} or a list")

self._num_rounds = num_rounds
self._start_round = 0
self._end_round = self._start_round + self._num_rounds
self._current_round = 0
self._is_done = False
self._last_learnable = None
self.persistor_id = persistor_id
self.shareable_generator_id = shareable_generator_id
Expand All @@ -107,6 +112,7 @@ def __init__(
self._participating_clients = None
self._last_client = None
self._order = order
self._allow_early_termination = allow_early_termination

def start_controller(self, fl_ctx: FLContext):
self.log_debug(fl_ctx, "starting controller")
Expand All @@ -127,46 +133,79 @@ def start_controller(self, fl_ctx: FLContext):
fl_ctx.set_prop(AppConstants.NUM_ROUNDS, self._num_rounds, private=True, sticky=True)
self.fire_event(AppEventType.INITIAL_MODEL_LOADED, fl_ctx)

self._participating_clients = self._engine.get_clients()
self._participating_clients: List[Client] = self._engine.get_clients()
if len(self._participating_clients) <= 1:
self.system_panic("Not enough client sites.", fl_ctx)
self._last_client = None

def _get_relay_orders(self, fl_ctx: FLContext):
targets = list(self._participating_clients)
if len(targets) <= 1:
self.system_panic("Not enough client sites.", fl_ctx)
if self._order == RelayOrder.RANDOM:
random.shuffle(targets)
elif self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
random.shuffle(targets)
if self._last_client == targets[0]:
targets = targets.append(targets.pop(0))
def _get_relay_orders(self, fl_ctx: FLContext) -> Union[List[Client], None]:
if len(self._participating_clients) <= 1:
self.system_panic(f"Not enough client sites ({len(self._participating_clients)}).", fl_ctx)
return None

if isinstance(self._order, list):
targets = []
active_clients_map = {t.name: t for t in self._participating_clients}
for c_name in self._order:
if c_name not in active_clients_map:
self.system_panic(f"Required client site ({c_name}) is not in active clients.", fl_ctx)
return None
targets.append(active_clients_map[c_name])
else:
targets = list(self._participating_clients)
if self._order == RelayOrder.RANDOM or self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW:
random.shuffle(targets)
if self._order == RelayOrder.RANDOM_WITHOUT_SAME_IN_A_ROW and self._last_client == targets[0]:
targets.append(targets.pop(0))
self._last_client = targets[-1]
return targets

def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
result = client_task.result
rc = result.get_return_code()
client_name = client_task.client.name

# Raise errors if ReturnCode is not OK.
if rc and rc != ReturnCode.OK:
self.system_panic(
f"Result from {client_name} is bad, error code: {rc}. "
f"{self.__class__.__name__} exiting at round {self._current_round}.",
fl_ctx=fl_ctx,
)
return False
def _stop_workflow(self, task: Task):
self.cancel_task(task)
self._is_done = True

def _process_result(self, client_task: ClientTask, fl_ctx: FLContext):
# submitted shareable is stored in client_task.result
# we need to update task.data with that shareable so the next target
# will get the updated shareable
task = client_task.task

# update the global learnable with the received result (shareable)
# e.g. the received result could be weight_diffs, the learnable could be full weights.
self._last_learnable = self.shareable_generator.shareable_to_learnable(client_task.result, fl_ctx)
result = client_task.result
if isinstance(result, Shareable):
# update the global learnable with the received result (shareable)
# e.g. the received result could be weight_diffs, the learnable could be full weights.
rc = result.get_return_code()
try:
self._last_learnable = self.shareable_generator.shareable_to_learnable(result, fl_ctx)
except Exception as ex:
if rc != ReturnCode.EARLY_TERMINATION:
self._stop_workflow(task)
self.log_error(fl_ctx, f"exception {secure_format_exception(ex)} from shareable_to_learnable")
return
else:
self.log_warning(
fl_ctx,
f"ignored {secure_format_exception(ex)} from shareable_to_learnable in early termination",
)

if rc == ReturnCode.EARLY_TERMINATION:
if self._allow_early_termination:
# the workflow is done
self._stop_workflow(task)
self.log_info(fl_ctx, f"Stopping workflow due to {rc} from client {client_task.client.name}")
return
else:
self.log_warning(
fl_ctx,
f"Ignored {rc} from client {client_task.client.name} because early termination is not allowed",
)
else:
self._stop_workflow(task)
self.log_error(
fl_ctx,
f"Stopping workflow due to result from client {client_task.client.name} is not a Shareable",
)
return

# prepare task shareable data for next client
task.data = self.shareable_generator.learnable_to_shareable(self._last_learnable, fl_ctx)
Expand All @@ -179,6 +218,9 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
self.log_debug(fl_ctx, "Cyclic starting.")

for self._current_round in range(self._start_round, self._end_round):
if self._is_done:
return

if abort_signal.triggered:
return

Expand All @@ -187,6 +229,8 @@ def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):

# Task for one cyclic
targets = self._get_relay_orders(fl_ctx)
if targets is None:
return
targets_names = [t.name for t in targets]
self.log_debug(fl_ctx, f"Relay on {targets_names}")

Expand Down
134 changes: 134 additions & 0 deletions tests/unit_test/app_common/workflow/cyclic_ctl_test.py
@@ -0,0 +1,134 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import uuid
from unittest.mock import Mock, patch

import pytest

from nvflare.apis.client import Client
from nvflare.apis.controller_spec import ClientTask, Task
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.workflows.cyclic_ctl import CyclicController, RelayOrder

# One token per simulated site, generated once so that every Client built for the
# same site name carries the same token.
SITE_1_ID, SITE_2_ID, SITE_3_ID = (uuid.uuid4() for _ in range(3))

_SITE_TOKENS = {"site-1": SITE_1_ID, "site-2": SITE_2_ID, "site-3": SITE_3_ID}


def _clients(*site_names):
    """Return fresh Client objects for the given site names, in the given order."""
    return [Client(site_name, _SITE_TOKENS[site_name]) for site_name in site_names]


# Each case is (order, active_clients, expected_result).
ORDER_TEST_CASES = [
    (RelayOrder.FIXED, _clients("site-1", "site-2"), _clients("site-1", "site-2")),
    (["site-1", "site-2"], _clients("site-1", "site-2"), _clients("site-1", "site-2")),
    (["site-2", "site-1"], _clients("site-1", "site-2"), _clients("site-2", "site-1")),
    (
        ["site-2", "site-1", "site-3"],
        _clients("site-3", "site-1", "site-2"),
        _clients("site-2", "site-1", "site-3"),
    ),
]


def gen_shareable(is_early_termination: bool = False, is_not_shareable: bool = False):
    """Build a fake client result for the ``_process_result`` tests.

    Args:
        is_early_termination: when True, the returned Shareable carries the
            ``ReturnCode.EARLY_TERMINATION`` return code.
        is_not_shareable: when True, a plain list is returned instead of a Shareable;
            this flag takes precedence over ``is_early_termination``.

    Returns:
        ``[1, 2, 3]`` when ``is_not_shareable`` is set, otherwise a new Shareable.
    """
    if is_not_shareable:
        return [1, 2, 3]

    shareable = Shareable()
    if is_early_termination:
        shareable.set_return_code(ReturnCode.EARLY_TERMINATION)
    return shareable


# Results for which test_process_result expects the task to be cancelled and the
# controller to be marked done: an early-termination Shareable and a non-Shareable.
PROCESS_RESULT_TEST_CASES = [
    gen_shareable(is_early_termination=True),
    gen_shareable(is_not_shareable=True),
]


class TestCyclicController:
    """Unit tests for CyclicController relay ordering, control flow and result handling."""

    @pytest.mark.parametrize("order,active_clients,expected_result", ORDER_TEST_CASES)
    def test_get_relay_orders(self, order, active_clients, expected_result):
        """The relay targets follow the configured order for the given active clients."""
        ctl = CyclicController(order=order)
        ctx = FLContext()
        ctl._participating_clients = active_clients
        targets = ctl._get_relay_orders(ctx)

        # zip() stops at the shorter input, so pin the length first; otherwise a
        # truncated or empty result would pass the element-wise comparison vacuously.
        assert targets is not None
        assert len(targets) == len(expected_result)
        for c, e_c in zip(targets, expected_result):
            assert c.name == e_c.name
            assert c.token == e_c.token

    def test_control_flow_call_relay_and_wait(self):
        """A single-round control_flow run calls relay_and_wait exactly once."""
        with patch("nvflare.app_common.workflows.cyclic_ctl.CyclicController.relay_and_wait") as mock_method:
            ctl = CyclicController(persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1)
            ctl.shareable_generator = Mock()
            ctl._participating_clients = [
                Client("site-3", SITE_3_ID),
                Client("site-1", SITE_1_ID),
                Client("site-2", SITE_2_ID),
            ]

            abort_signal = Signal()
            fl_ctx = FLContext()

            with patch.object(ctl.shareable_generator, "learnable_to_shareable") as mock_method1, patch.object(
                ctl.shareable_generator, "shareable_to_learnable"
            ) as mock_method2:
                mock_method1.return_value = Shareable()
                mock_method2.return_value = Learnable()

                ctl.control_flow(abort_signal, fl_ctx)

            mock_method.assert_called_once()

    @pytest.mark.parametrize("return_result", PROCESS_RESULT_TEST_CASES)
    def test_process_result(self, return_result):
        """Early-termination and non-Shareable results cancel the task and mark the workflow done."""
        ctl = CyclicController(
            persist_every_n_rounds=0, snapshot_every_n_rounds=0, num_rounds=1, allow_early_termination=True
        )
        ctl.shareable_generator = Mock()
        ctl._participating_clients = [
            Client("site-3", SITE_3_ID),
            Client("site-1", SITE_1_ID),
            Client("site-2", SITE_2_ID),
        ]

        fl_ctx = FLContext()
        with patch.object(ctl, "cancel_task") as mock_method, patch.object(
            ctl.shareable_generator, "learnable_to_shareable"
        ) as mock_method1, patch.object(ctl.shareable_generator, "shareable_to_learnable") as mock_method2:
            mock_method1.return_value = Shareable()
            mock_method2.return_value = Learnable()

            client_task = ClientTask(
                client=Mock(),
                task=Task(
                    name="__test_task",
                    data=Shareable(),
                ),
            )
            client_task.result = return_result
            ctl._process_result(client_task, fl_ctx)
            mock_method.assert_called_once()
            assert ctl._is_done is True

0 comments on commit fb8f7c5

Please sign in to comment.