Simulator end run for all clients #2514

Merged · 13 commits · Apr 19, 2024
7 changes: 7 additions & 0 deletions nvflare/private/fed/app/simulator/simulator.py
@@ -30,6 +30,12 @@ def define_simulator_parser(simulator_parser):
simulator_parser.add_argument("-t", "--threads", type=int, help="number of parallel running clients")
simulator_parser.add_argument("-gpu", "--gpu", type=str, help="list of GPU Device Ids, comma separated")
simulator_parser.add_argument("-m", "--max_clients", type=int, default=100, help="max number of clients")
+    simulator_parser.add_argument(
+        "--end_run_for_all",
+        default=False,
+        action="store_true",
+        help="flag to indicate whether to run the END_RUN event for all clients",
+    )


def run_simulator(simulator_args):
@@ -41,6 +47,7 @@ def run_simulator(simulator_args):
        threads=simulator_args.threads,
        gpu=simulator_args.gpu,
        max_clients=simulator_args.max_clients,
+        end_run_for_all=simulator_args.end_run_for_all,
    )
    run_status = simulator.run()

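For context on how the new flag is consumed, here is a minimal sketch of driving the simulator programmatically with end_run_for_all enabled. The job folder and workspace paths are hypothetical placeholders; the constructor arguments mirror the SimulatorRunner signature in this diff:

from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner

# Hypothetical paths; point these at a real NVFlare job folder and a scratch workspace.
runner = SimulatorRunner(
    job_folder="jobs/hello-numpy-sag",
    workspace="/tmp/simulator_workspace",
    n_clients=2,
    threads=2,
    end_run_for_all=True,  # fire the END_RUN event on every client, not just the still-active ones
)
run_status = runner.run()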
60 changes: 49 additions & 11 deletions nvflare/private/fed/app/simulator/simulator_runner.py
@@ -69,7 +69,15 @@

class SimulatorRunner(FLComponent):
    def __init__(
-        self, job_folder: str, workspace: str, clients=None, n_clients=None, threads=None, gpu=None, max_clients=100
+        self,
+        job_folder: str,
+        workspace: str,
+        clients=None,
+        n_clients=None,
+        threads=None,
+        gpu=None,
+        max_clients=100,
+        end_run_for_all=False,
    ):
        super().__init__()

@@ -80,6 +88,7 @@ def __init__(
        self.threads = threads
        self.gpu = gpu
        self.max_clients = max_clients
+        self.end_run_for_all = end_run_for_all

        self.ask_to_stop = False

@@ -142,6 +151,7 @@ def setup(self):
        self.args.env = os.path.join("config", AppFolderConstants.CONFIG_ENV)
        cwd = os.getcwd()
        self.args.job_folder = os.path.join(cwd, self.args.job_folder)
+        self.args.end_run_for_all = self.end_run_for_all

        if not os.path.exists(self.args.workspace):
            os.makedirs(self.args.workspace)
@@ -523,7 +533,7 @@ def __init__(self, args, clients: [], client_config, deploy_args, build_ctx):
        self.kv_list = parse_vars(args.set)
        self.logging_config = os.path.join(self.args.workspace, "local", WorkspaceConstants.LOGGING_CONFIG)

-        self.end_run_clients = []
+        self.clients_finished_end_run = []

    def run(self, gpu):
        try:
@@ -533,17 +543,14 @@ def run(self, gpu):
            lock = threading.Lock()
            timeout = self.kv_list.get("simulator_worker_timeout", 60.0)
            for i in range(self.args.threads):
-                executor.submit(lambda p: self.run_client_thread(*p), [self.args.threads, gpu, lock, i, timeout])
+                executor.submit(
+                    lambda p: self.run_client_thread(*p),
+                    [self.args.threads, gpu, lock, self.args.end_run_for_all, timeout],
+                )

            # wait for the server and client running threads to finish.
            executor.shutdown()

-            for client in self.federated_clients:
-                if client.client_name not in self.end_run_clients:
-                    self.do_one_task(
-                        client, self.args.threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN
-                    )
-
        except Exception as e:
            self.logger.error(f"SimulatorClientRunner run error: {secure_format_exception(e)}")
        finally:
@@ -562,7 +569,7 @@ def _shutdown_client(self, client):
            # Ignore the exception for the simulator client shutdown
            self.logger.warn(f"Exception happened to client {client.name} during shutdown")

-    def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60):
+    def run_client_thread(self, num_of_threads, gpu, lock, end_run_for_all, timeout=60):
        stop_run = False
        interval = 1
        client_to_run = None  # indicates the next client to run
@@ -582,12 +589,43 @@ def run_client_thread(self, num_of_threads, gpu, lock, rank, timeout=60):
                )
                if end_run_client:
                    with lock:
-                        self.end_run_clients.append(end_run_client)
+                        self.clients_finished_end_run.append(end_run_client)

                client.simulate_running = False

+            if end_run_for_all:
+                self._end_run_clients(gpu, lock, num_of_threads, timeout)
        except Exception as e:
            self.logger.error(f"run_client_thread error: {secure_format_exception(e)}")

+    def _end_run_clients(self, gpu, lock, num_of_threads, timeout):
+        """After the workflow reaches END_RUN, each running thread tries to pick up one of the remaining
+        clients that has not yet run the END_RUN event and executes its END_RUN handler, until all the
+        clients have done so. These client END_RUN handlers execute only when "end_run_for_all" is set.
+
+        Multiple client-running threads pick up clients from the same shared client pool.
+        """
+        # Each thread keeps picking up NOT-DONE clients until all clients have run the END_RUN event.
+        while len(self.clients_finished_end_run) != len(self.federated_clients):
+            with lock:
+                end_run_client = self._pick_next_client()
+            if end_run_client:
+                self.do_one_task(
+                    end_run_client, num_of_threads, gpu, lock, timeout=timeout, task_name=RunnerTask.END_RUN
+                )
+                with lock:
+                    end_run_client.simulate_running = False

+    def _pick_next_client(self):
+        for client in self.federated_clients:
+            # Ensure the client has not run the END_RUN event and is not currently running
+            if client.client_name not in self.clients_finished_end_run and not client.simulate_running:
+                client.simulate_running = True
+                self.clients_finished_end_run.append(client.client_name)
+                return client
+        return None

    def do_one_task(self, client, num_of_threads, gpu, lock, timeout=60.0, task_name=RunnerTask.TASK_EXEC):
        open_port = get_open_ports(1)[0]
        client_workspace = os.path.join(self.args.workspace, client.client_name)
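The draining pattern added in _end_run_clients and _pick_next_client — worker threads repeatedly claiming an unfinished client under a lock until every client has fired END_RUN — can be illustrated in isolation. This is a simplified stand-in, not the NVFlare implementation: the Client class, the client names, and the print in place of do_one_task are all invented for the sketch:

import threading

class Client:
    def __init__(self, name):
        self.client_name = name
        self.simulate_running = False

clients = [Client("site-1"), Client("site-2"), Client("site-3")]
finished_end_run = []  # names of clients whose END_RUN has been claimed
lock = threading.Lock()

def pick_next_client():
    # Call with the lock held: mark the client as claimed before the lock
    # is released, so no other worker picks it up.
    for client in clients:
        if client.client_name not in finished_end_run and not client.simulate_running:
            client.simulate_running = True
            finished_end_run.append(client.client_name)
            return client
    return None

def end_run_worker():
    # Keep draining the pool until every client has run END_RUN.
    while len(finished_end_run) != len(clients):
        with lock:
            client = pick_next_client()
        if client:
            print(f"END_RUN fired for {client.client_name}")  # stand-in for do_one_task(..., RunnerTask.END_RUN)
            with lock:
                client.simulate_running = False

workers = [threading.Thread(target=end_run_worker) for _ in range(2)]
for w in workers:
    w.start()
for w in workers:
    w.join()

One consequence of appending to finished_end_run at claim time, visible in this sketch, is that a worker can observe the pool as fully claimed and exit while another worker is still executing a claimed client's END_RUN handler; here the join() calls are what guarantee all handlers have completed before the program proceeds.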