From dc9e6ff2bf03174ae7ea0121ff9a52adfb4149e1 Mon Sep 17 00:00:00 2001 From: tonyjohnchen Date: Mon, 1 May 2023 14:59:36 -0700 Subject: [PATCH] As b/279912173 requested, replaced `absl` with `argparse`, so `multihost_job.py` will be fully `dependency free` --- multihost_job.py | 101 ++++++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 46 deletions(-) diff --git a/multihost_job.py b/multihost_job.py index a9ea04f55e..a1892c57f7 100644 --- a/multihost_job.py +++ b/multihost_job.py @@ -45,8 +45,7 @@ and gcloud auth application-default login """ - -from absl import app, flags +import argparse import sys import subprocess from datetime import datetime @@ -55,32 +54,43 @@ import shutil ##### Define flags ##### -FLAGS = flags.FLAGS -tpu_type_flag = flags.DEFINE_string("TPU_TYPE", "v4-8", "The type of the TPU") -tpu_runtime_version_flag = flags.DEFINE_string("VERSION", "tpu-vm-v4-base", "The runtime version of the TPU") -num_slices_flag = flags.DEFINE_integer("NUM_SLICES", 2, "The number of slices to run the job on") -script_dir_flag = flags.DEFINE_string("SCRIPT_DIR", os.getcwd(), "The local location of the directory to copy to"\ - " the TPUs and run the main command from. Defaults to current working directory.") -command_flag = flags.DEFINE_string("COMMAND", None, "Main command to run on each TPU. This command is run from"\ - " a copied version of SCRIPT_DIR on each TPU worker. You must include your dependency installations here, e.g."\ - "--COMMAND='bash setup.sh && python3 train.py'") -bucket_name_flag = flags.DEFINE_string("BUCKET_NAME", None, "Name of GCS bucket, e.g. my-bucket") -bucket_dir_flag = flags.DEFINE_string("BUCKET_DIR", "", "Directory within the GCS bucket, can be None, e.g. my-dir") -project_flag = flags.DEFINE_string("PROJECT", None, "GCE project name, defaults to gcloud config project") -zone_flag = flags.DEFINE_string("ZONE", None, "GCE zone, e.g.
us-central2-b, defaults to gcloud config compute/zone") -endpoint_flag = flags.DEFINE_string("ENDPOINT", "tpu.googleapis.com", "The endpoint for google API requests.") -run_name_flag = flags.DEFINE_string("RUN_NAME", None, "Run name used for temporary files, defaults to timestamp.") -network_flag = flags.DEFINE_string("NETWORK", "default", "Gcloud compute engine network.") -subnetwork_flag = flags.DEFINE_string("SUBNETWORK", "default", "Gcloud compute engine subnetwork.") -resource_pool_flag = flags.DEFINE_string("RESOURCE_POOL", "on-demand", "The resource pool to use, either"\ - "'reserved', 'on-demand', or 'best-effort'.") -service_account_flag = flags.DEFINE_string("SERVICE_ACCOUNT", None, "Service account for the TPU VMs.") - -flags.mark_flag_as_required('COMMAND') -flags.mark_flag_as_required('BUCKET_NAME') -flags.register_validator('RESOURCE_POOL', - lambda value: value in ["reserved", "on-demand", "best-effort"], - message="--RESOURCE_POOL must be 'reserved', 'on-demand' or 'best-effort'") +parser = argparse.ArgumentParser(description='TPU configuration options') +parser.add_argument('--TPU_TYPE', type=str, default='v4-8', + help='The type of the TPU') +parser.add_argument('--VERSION', type=str, default='tpu-vm-v4-base', + help='The runtime version of the TPU') +parser.add_argument('--NUM_SLICES', type=int, default=2, + help='The number of slices to run the job on') +parser.add_argument('--SCRIPT_DIR', type=str, default=os.getcwd(), + help='The local location of the directory to copy to the TPUs and run the main command from. \ + Defaults to current working directory.') +parser.add_argument('--COMMAND', type=str, default=None, required=True, + help='Main command to run on each TPU. \ + This command is run from a copied version of SCRIPT_DIR on each TPU worker. \ + You must include your dependency installations here, \ + e.g. 
--COMMAND=\'bash setup.sh && python3 train.py\'') +parser.add_argument('--BUCKET_NAME', type=str, default=None, required=True, + help='Name of GCS bucket, e.g. my-bucket') +parser.add_argument('--BUCKET_DIR', type=str, default="", + help='Directory within the GCS bucket, can be None, e.g. my-dir') +parser.add_argument('--PROJECT', type=str, default=None, + help='GCE project name, defaults to gcloud config project') +parser.add_argument('--ZONE', type=str, default=None, + help='GCE zone, e.g. us-central2-b, defaults to gcloud config compute/zone') +parser.add_argument('--ENDPOINT', type=str, default='tpu.googleapis.com', + help='The endpoint for google API requests.') +parser.add_argument('--RUN_NAME', type=str, default=None, + help='Run name used for temporary files, defaults to timestamp.') +parser.add_argument('--NETWORK', type=str, default='default', + help='Gcloud compute engine network.') +parser.add_argument('--SUBNETWORK', type=str, default='default', + help='Gcloud compute engine subnetwork.') +parser.add_argument('--RESOURCE_POOL', type=str, + default='on-demand', choices=["reserved", "on-demand", "best-effort"], + help='The resource pool to use, either \'reserved\', \'on-demand\', or \'best-effort\'.') +parser.add_argument('--SERVICE_ACCOUNT', type=str, + default=None, help='Service account for the TPU VMs.') +args = parser.parse_args() def get_project(): completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) @@ -303,26 +313,25 @@ def gcs_bucket_url(bucket_name, bucket_dir, project): return f"https://console.cloud.google.com/storage/browser/{bucket_path}?project={project}" ################### Main ################### -def main(argv) -> None: +def main() -> None: print("\nStarting multihost_job...\n", flush=True) #### Parse flags #### - FLAGS(argv) # parses the python command inputs into FLAG objects - tpu_type = tpu_type_flag.value - tpu_runtime_version = tpu_runtime_version_flag.value - num_slices = 
num_slices_flag.value - script_dir = script_dir_flag.value - main_command = command_flag.value - bucket_name = bucket_name_flag.value - bucket_dir = bucket_dir_flag.value - endpoint = endpoint_flag.value - project = project_flag.value - zone = zone_flag.value - run_name = run_name_flag.value - network = network_flag.value - subnetwork = subnetwork_flag.value - resource_pool = resource_pool_flag.value - service_account = service_account_flag.value + tpu_type = args.TPU_TYPE + tpu_runtime_version = args.VERSION + num_slices = args.NUM_SLICES + script_dir = args.SCRIPT_DIR + main_command = args.COMMAND + bucket_name = args.BUCKET_NAME + bucket_dir = args.BUCKET_DIR + endpoint = args.ENDPOINT + project = args.PROJECT + zone = args.ZONE + run_name = args.RUN_NAME + network = args.NETWORK + subnetwork = args.SUBNETWORK + resource_pool = args.RESOURCE_POOL + service_account = args.SERVICE_ACCOUNT if not project: project = get_project() @@ -398,4 +407,4 @@ def main(argv) -> None: return 0 if __name__ == '__main__': - app.run(main) + main()