Skip to content

Commit

Permalink
Simulator end run for all clients (#2514)
Browse files Browse the repository at this point in the history
* Provide an option to run END_RUN for all clients.

* Added end_run_all option for simulator to run END_RUN event for all clients.

* Fixed a add_argument type, added help message.

* Changed to use add_argument(() compatible with python 3.8.

* reformat.

* rewrite the _end_run_clients() and add docstring for easier understanding.

* reformat.

* adjusting the locking in the _end_run_clients.

* Fixed a potential None pointer error.

* renamed the clients_finished_end_run variable.

---------

Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
Co-authored-by: Sean Yang <seany314@gmail.com>
Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
4 people committed Apr 19, 2024
1 parent bf978b7 commit b93fe3b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 11 deletions.
7 changes: 7 additions & 0 deletions nvflare/private/fed/app/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def define_simulator_parser(simulator_parser):
simulator_parser.add_argument("-t", "--threads", type=int, help="number of parallel running clients")
simulator_parser.add_argument("-gpu", "--gpu", type=str, help="list of GPU Device Ids, comma separated")
simulator_parser.add_argument("-m", "--max_clients", type=int, default=100, help="max number of clients")
simulator_parser.add_argument(
"--end_run_for_all",
default=False,
action="store_true",
help="flag to indicate if running END_RUN event for all clients",
)


def run_simulator(simulator_args):
Expand All @@ -41,6 +47,7 @@ def run_simulator(simulator_args):
threads=simulator_args.threads,
gpu=simulator_args.gpu,
max_clients=simulator_args.max_clients,
end_run_for_all=simulator_args.end_run_for_all,
)
run_status = simulator.run()

Expand Down
60 changes: 49 additions & 11 deletions nvflare/private/fed/app/simulator/simulator_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,15 @@

class SimulatorRunner(FLComponent):
def __init__(
self, job_folder: str, workspace: str, clients=None, n_clients=None, threads=None, gpu=None, max_clients=100
self,
job_folder: str,
workspace: str,
clients=None,
n_clients=None,
threads=None,
gpu=None,
max_clients=100,
end_run_for_all=False,
):
super().__init__()

Expand All @@ -80,6 +88,7 @@ def __init__(
self.threads = threads
self.gpu = gpu
self.max_clients = max_clients
self.end_run_for_all = end_run_for_all

self.ask_to_stop = False

Expand Down Expand Up @@ -142,6 +151,7 @@ def setup(self):
self.args.env = os.path.join("config", AppFolderConstants.CONFIG_ENV)
cwd = os.getcwd()
self.args.job_folder = os.path.join(cwd, self.args.job_folder)
self.args.end_run_for_all = self.end_run_for_all

if not os.path.exists(self.args.workspace):
os.makedirs(self.args.workspace)
Expand Down Expand Up @@ -523,7 +533,7 @@ def __init__(self, args, clients: [], client_config, deploy_args, build_ctx):
self.kv_list = parse_vars(args.set)
self.logging_config = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG)

self.end_run_clients = []
self.clients_finished_end_run = []

def run(self, gpu):
try:
Expand All @@ -533,17 +543,14 @@ def run(self, gpu):
lock = threading.Lock()
timeout = self.kv_list.get("simulator_worker_timeout", 60.0)
for i in range(self.args.threads):
executor.submit(lambda p: self.run_client_thread(*p), [self.args.threads, gpu, lock, i, timeout])
executor.submit(
lambda p: self.run_client_thread(*p),
[self.args.threads, gpu, lock, self.args.end_run_for_all, timeout],
)

# wait for the server and client running thread to finish.
executor.shutdown()

for client in self.federated_clients:
if client.client_name not in self.end_run_clients:
self.do_one_task(
client, self.args.threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN
)

except Exception as e:
self.logger.error(f"SimulatorClientRunner run error: {secure_format_exception(e)}")
finally:
Expand All @@ -562,7 +569,7 @@ def _shutdown_client(self, client):
# Ignore the exception for the simulator client shutdown
self.logger.warn(f"Exception happened to client{client.name} during shutdown ")

def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60):
def run_client_thread(self, num_of_threads, gpu, lock, end_run_for_all, timeout=60):
stop_run = False
interval = 1
client_to_run = None # indicates the next client to run
Expand All @@ -582,12 +589,43 @@ def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60):
)
if end_run_client:
with lock:
self.end_run_clients.append(end_run_client)
self.clients_finished_end_run.append(end_run_client)

client.simulate_running = False

if end_run_for_all:
self._end_run_clients(gpu, lock, num_of_threads, timeout)
except Exception as e:
self.logger.error(f"run_client_thread error: {secure_format_exception(e)}")

def _end_run_clients(self, gpu, lock, num_of_threads, timeout):
"""After the WF reaches the END_RUN, each running thread will try to pick up one of the remaining client
which has not run the END_RUN yet, then execute the END_RUN handler, until all the clients have done so.
These client END_RUN event handler only execute when "end_run_for_all" has been set.
Multiple client running threads will try to pick up the client from the same clients pool.
"""
# Each thread only stop picking up the NOT-DONE client until all clients have run the END_RUN event.
while len(self.clients_finished_end_run) != len(self.federated_clients):
with lock:
end_run_client = self._pick_next_client()
if end_run_client:
self.do_one_task(
end_run_client, num_of_threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN
)
with lock:
end_run_client.simulate_running = False

def _pick_next_client(self):
for client in self.federated_clients:
# Ensure the client has not run the END_RUN event
if client.client_name not in self.clients_finished_end_run and not client.simulate_running:
client.simulate_running = True
self.clients_finished_end_run.append(client.client_name)
return client
return None

def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name=RunnerTask.TASK_EXEC):
open_port = get_open_ports(1)[0]
client_workspace = os.path.join(self.args.workspace, client.client_name)
Expand Down

0 comments on commit b93fe3b

Please sign in to comment.