Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 55 additions & 46 deletions multihost_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@
and
gcloud auth application-default login
"""

from absl import app, flags
import argparse
import sys
import subprocess
from datetime import datetime
Expand All @@ -55,32 +54,43 @@
import shutil

##### Define flags #####
FLAGS = flags.FLAGS
tpu_type_flag = flags.DEFINE_string("TPU_TYPE", "v4-8", "The type of the TPU")
tpu_runtime_version_flag = flags.DEFINE_string("VERSION", "tpu-vm-v4-base", "The runtime version of the TPU")
num_slices_flag = flags.DEFINE_integer("NUM_SLICES", 2, "The number of slices to run the job on")
script_dir_flag = flags.DEFINE_string("SCRIPT_DIR", os.getcwd(), "The local location of the directory to copy to"\
" the TPUs and run the main command from. Defaults to current working directory.")
command_flag = flags.DEFINE_string("COMMAND", None, "Main command to run on each TPU. This command is run from"\
" a copied version of SCRIPT_DIR on each TPU worker. You must include your dependency installations here, e.g."\
"--COMMAND='bash setup.sh && python3 train.py'")
bucket_name_flag = flags.DEFINE_string("BUCKET_NAME", None, "Name of GCS bucket, e.g. my-bucket")
bucket_dir_flag = flags.DEFINE_string("BUCKET_DIR", "", "Directory within the GCS bucket, can be None, e.g. my-dir")
project_flag = flags.DEFINE_string("PROJECT", None, "GCE project name, defaults to gcloud config project")
zone_flag = flags.DEFINE_string("ZONE", None, "GCE zone, e.g. us-central2-b, defaults to gcloud config compute/zone")
endpoint_flag = flags.DEFINE_string("ENDPOINT", "tpu.googleapis.com", "The endpoint for google API requests.")
run_name_flag = flags.DEFINE_string("RUN_NAME", None, "Run name used for temporary files, defaults to timestamp.")
network_flag = flags.DEFINE_string("NETWORK", "default", "Gcloud compute engine network.")
subnetwork_flag = flags.DEFINE_string("SUBNETWORK", "default", "Gcloud compute engine subnetwork.")
resource_pool_flag = flags.DEFINE_string("RESOURCE_POOL", "on-demand", "The resource pool to use, either"\
"'reserved', 'on-demand', or 'best-effort'.")
service_account_flag = flags.DEFINE_string("SERVICE_ACCOUNT", None, "Service account for the TPU VMs.")

flags.mark_flag_as_required('COMMAND')
flags.mark_flag_as_required('BUCKET_NAME')
flags.register_validator('RESOURCE_POOL',
lambda value: value in ["reserved", "on-demand", "best-effort"],
message="--RESOURCE_POOL must be 'reserved', 'on-demand' or 'best-effort'")
parser = argparse.ArgumentParser(description='TPU configuration options')
parser.add_argument('--TPU_TYPE', type=str, default='v4-8',
help='The type of the TPU')
parser.add_argument('--VERSION', type=str, default='tpu-vm-v4-base',
help='The runtime version of the TPU')
parser.add_argument('--NUM_SLICES', type=int, default=2,
help='The number of slices to run the job on')
parser.add_argument('--SCRIPT_DIR', type=str, default=os.getcwd(),
help='The local location of the directory to copy to the TPUs and run the main command from. \
Defaults to current working directory.')
parser.add_argument('--COMMAND', type=str, default=None, required=True,
help='Main command to run on each TPU. \
This command is run from a copied version of SCRIPT_DIR on each TPU worker. \
You must include your dependency installations here, \
e.g. --COMMAND=\'bash setup.sh && python3 train.py\'')
parser.add_argument('--BUCKET_NAME', type=str, default=None, required=True,
help='Name of GCS bucket, e.g. my-bucket')
parser.add_argument('--BUCKET_DIR', type=str, default="",
help='Directory within the GCS bucket, can be None, e.g. my-dir')
parser.add_argument('--PROJECT', type=str, default=None,
help='GCE project name, defaults to gcloud config project')
parser.add_argument('--ZONE', type=str, default=None,
help='GCE zone, e.g. us-central2-b, defaults to gcloud config compute/zone')
parser.add_argument('--ENDPOINT', type=str, default='tpu.googleapis.com',
help='The endpoint for google API requests.')
parser.add_argument('--RUN_NAME', type=str, default=None,
help='Run name used for temporary files, defaults to timestamp.')
parser.add_argument('--NETWORK', type=str, default='default',
help='Gcloud compute engine network.')
parser.add_argument('--SUBNETWORK', type=str, default='default',
help='Gcloud compute engine subnetwork.')
parser.add_argument('--RESOURCE_POOL', type=str,
default='on-demand', choices=["reserved", "on-demand", "best-effort"],
help='The resource pool to use, either \'reserved\', \'on-demand\', or \'best-effort\'.')
parser.add_argument('--SERVICE_ACCOUNT', type=str,
default=None, help='Service account for the TPU VMs.')
args = parser.parse_args()

def get_project():
completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True)
Expand Down Expand Up @@ -303,26 +313,25 @@ def gcs_bucket_url(bucket_name, bucket_dir, project):
return f"https://console.cloud.google.com/storage/browser/{bucket_path}?project={project}"

################### Main ###################
def main(argv) -> None:
def main() -> None:
print("\nStarting multihost_job...\n", flush=True)

#### Parse flags ####
FLAGS(argv) # parses the python command inputs into FLAG objects
tpu_type = tpu_type_flag.value
tpu_runtime_version = tpu_runtime_version_flag.value
num_slices = num_slices_flag.value
script_dir = script_dir_flag.value
main_command = command_flag.value
bucket_name = bucket_name_flag.value
bucket_dir = bucket_dir_flag.value
endpoint = endpoint_flag.value
project = project_flag.value
zone = zone_flag.value
run_name = run_name_flag.value
network = network_flag.value
subnetwork = subnetwork_flag.value
resource_pool = resource_pool_flag.value
service_account = service_account_flag.value
tpu_type = args.TPU_TYPE
tpu_runtime_version = args.VERSION
num_slices = args.NUM_SLICES
script_dir = args.SCRIPT_DIR
main_command = args.COMMAND
bucket_name = args.BUCKET_NAME
bucket_dir = args.BUCKET_DIR
endpoint = args.ENDPOINT
project = args.PROJECT
zone = args.ZONE
run_name = args.RUN_NAME
network = args.NETWORK
subnetwork = args.SUBNETWORK
resource_pool = args.RESOURCE_POOL
service_account = args.SERVICE_ACCOUNT

if not project:
project = get_project()
Expand Down Expand Up @@ -398,4 +407,4 @@ def main(argv) -> None:
return 0

if __name__ == '__main__':
app.run(main)
main()