In [None]:
# default_exp executor

In [None]:
#export
from datetime import datetime
from functools import lru_cache
import uuid
import os
from pathlib import Path
import tempfile
import yaml

from blocks.filesystem import GCSFileSystem as gcsfs

In [None]:
#export
class GCPConfig:
    """Process-wide GCP settings (bucket and project id), resolved lazily.

    Each value is computed on first use and memoised with ``lru_cache``, so the
    environment is read (or the user prompted) at most once per process.
    """

    @staticmethod
    @lru_cache(1)
    def bucket():
        """Return the bucket path from ``$BUCKET``, prompting the user if it is unset/empty."""
        return os.getenv("BUCKET") or input("Please enter the bucket path: ").strip()

    @staticmethod
    @lru_cache(1)
    def project_id():
        """Return the GCP project id.

        Resolution order: ``$PROJECT_ID``; the project reported by Application
        Default Credentials; finally an interactive prompt.
        """
        project = os.getenv("PROJECT_ID")
        if not project:
            project = GCPConfig._default_credentials_project()
        return project or input("Please enter the project id: ").strip()

    @staticmethod
    def _default_credentials_project():
        """Return the project id from Application Default Credentials, or None if unavailable."""
        # Imported lazily: ``google.auth`` was referenced without ever being
        # imported, and importing here avoids needing it when $PROJECT_ID is set.
        import google.auth
        import google.auth.exceptions

        try:
            return google.auth.default()[1]
        except google.auth.exceptions.DefaultCredentialsError:
            # No ADC configured -- let the caller fall back to prompting.
            return None


In [None]:
#export
from googleapiclient import discovery
import warnings

class AIP:
    """Submit custom-container training jobs to Google Cloud AI Platform.

    The project id and output bucket are resolved through ``GCPConfig``.
    """

    @property
    def job_id(self):
        """Job id for this instance, generated on first access and reused afterwards.

        Format: ``ai_run_<YYYYmmdd_HHMMSS>_<12 hex chars>`` (letters, digits and
        underscores only, starting with a letter).
        """
        # Cached on the instance. The previous ``@lru_cache(1)`` held a single
        # entry shared by *all* instances, so touching a second AIP evicted the
        # first one's id, which was then regenerated with a new timestamp.
        cached = getattr(self, "_job_id", None)
        if cached is None:
            stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Random suffix instead of uuid.getnode() (the host's hardware
            # address, constant per machine): two submissions from one machine
            # within the same second would otherwise get identical ids.
            suffix = uuid.uuid4().hex[:12]
            cached = self._job_id = f"ai_run_{stamp}_{suffix}"
        return cached

    @property
    def job_output(self):
        """Output location for jobs: the bucket configured in ``GCPConfig``."""
        # GCPConfig.bucket() is already memoised; no extra caching needed here.
        return GCPConfig.bucket()

    def run(self, image, machine_type: str = "n1-highmem-32", job_id=None, args=None, **overrides):
        """Submit a custom-container training job and return the API response.

        Args:
            image: Container image URI for the master replica (``masterConfig.imageUri``).
            machine_type: Machine type for the master replica.
            job_id: Job id to submit under; defaults to ``self.job_id``.
            args: Arguments passed to the container; defaults to none.
            **overrides: Extra ``trainingInput`` fields; they replace the
                defaults built here (e.g. ``region="europe-west1"``).

        Returns:
            The result of executing the ``projects.jobs.create`` request.
        """
        # TODO: dealing with hyperparameters
        if job_id is None:
            # Previously ``None`` was sent as ``jobId``; use this instance's id.
            job_id = self.job_id

        training_inputs = {
            "scaleTier": "CUSTOM",
            "masterType": machine_type,
            # Copy so the request never aliases a caller's (or a shared default) list.
            "args": list(args) if args is not None else [],
            "region": "us-central1",
            "masterConfig": {"imageUri": image},
        }
        training_inputs.update(overrides)

        job_spec = {"jobId": job_id, "trainingInput": training_inputs}
        parent = "projects/{}".format(GCPConfig.project_id())

        with warnings.catch_warnings():
            # Silence warnings emitted while building the discovery client.
            warnings.simplefilter("ignore")
            cloudml = discovery.build("ml", "v1", cache_discovery=False)
            request = cloudml.projects().jobs().create(body=job_spec, parent=parent)
            return request.execute()

In [None]:
# Smoke test: submits a real training job to AI Platform.
# NOTE(review): "gs://testjobsubmit" is passed as the container image URI
# (masterConfig.imageUri) but looks like a GCS bucket path -- confirm.
smoke_client = AIP()
smoke_client.run("gs://testjobsubmit")


                gcloud ai-platform jobs submit training ai_run_20200428_083645_0x8c8590a5b94c                 --job-dir gs:/testjobsubmit/output                 --package-path gs:/testjobsubmit/src                 --module-name gs:/testjobsubmit/src/CHANGE_THIS                 --region us-central1                 --runtime-version=2.1                 --python-version=3.7                 --scale-tier CUSTOM                 --config /var/folders/2k/b58ly_192yjgtv76zjxqj6f8_9cn2g/T/tmpl8pefm4j
            


In [None]:
#export
def _must_exist(key, dict_):
    assert key in dict_, "%r should be in the dictionary" % key

def _validate_config(conf_dict):
    """Check that a yoda run config carries every required key.

    Currently the only required key is ``"image"``; a missing key is reported
    by ``_must_exist``.
    """
    for required_key in ("image",):
        _must_exist(required_key, conf_dict)

def run_yoda_on_gcp(conf_dict):
    """Upload a yoda config to GCS and submit an AI Platform job that runs it.

    Args:
        conf_dict: yoda run configuration. Must contain ``"image"`` (container
            image for the job). ``"job_id"`` is used as the job id when
            present. Other keys except ``"args"`` are forwarded to ``AIP.run``
            (e.g. ``machine_type`` or ``trainingInput`` overrides).

    Returns:
        The job-creation response returned by ``AIP.run``.

    Raises:
        AssertionError: if a required key is missing (from ``_validate_config``).
    """
    _validate_config(conf_dict)

    aip = AIP()
    image = conf_dict["image"]
    job_id = conf_dict.get("job_id") or aip.job_id
    # Keys consumed here must not be forwarded again: ``run(image, ..., **conf_dict)``
    # raised "got multiple values for argument 'image'". ``args`` is built below.
    # NOTE(review): any remaining keys end up in trainingInput -- confirm the
    # config only holds fields AI Platform accepts.
    forwarded = {k: v for k, v in conf_dict.items() if k not in ("image", "job_id", "args")}

    # Upload under a per-job prefix so concurrent submissions cannot overwrite
    # each other's config before their containers read it. Joined with "/"
    # because os.path.join would use "\\" on Windows, which is wrong for gs:// URLs.
    gcs_config_path = "/".join([aip.job_output.rstrip("/"), job_id, "config.yaml"])
    with gcsfs().open(gcs_config_path, "w") as f:
        yaml.safe_dump(conf_dict, f)

    # A list keeps the config path one argument even if it contains spaces.
    args = ["yoda", "run", gcs_config_path]
    return aip.run(image, job_id=job_id, args=args, **forwarded)

In [None]:
import google.cloud.logging
from google.cloud.logging.handlers.handlers import CloudLoggingHandler, EXCLUDED_LOGGER_DEFAULTS

def _setup_logging():
    job = os.environ.get("CLOUD_ML_JOB_ID", None)
    trial = os.environ.get("CLOUD_ML_TRIAL_ID", None)
    project = os.environ.get("GCP_PROJECT", None)
    if job and project:
        client = google.cloud.logging.Client(project = project)
        resource = Resource(type = "ml_job", labels = dict(job_id = job, project_id = project, task_name = "master-replica-o"))
        # grouping by trial in AIP logs
        labels = {"ml.googleapis.com/trial_id": trial} if trial is not None else None
        handler = CloudLoggingHandler(client, resource=resource,labels=labels)
        logger = logging.getLogger()
        logger.handlers = []
        logger.setLevel(logging.DEBUG)
        logger.addHandler(handler)
        for logger_name in EXCLUDED_LOGGER_DEFAULTS:
            logging.getLogger(logger_name).propagate = False
    else:
        logger = logging.getLogger()
        logger.setLevel(logging.DEBUG)