From 476e432c69a0e3671ad1c050770c207406ea367f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Wed, 9 Oct 2019 12:59:55 +0200 Subject: [PATCH 1/2] [AIRFLOW-5625] Update MLEngine integration doc and typehint --- airflow/gcp/hooks/cloud_build.py | 2 +- airflow/gcp/hooks/mlengine.py | 106 ++++++++++++++++++++++-- airflow/gcp/operators/mlengine.py | 133 +++++++++++------------------- 3 files changed, 147 insertions(+), 94 deletions(-) diff --git a/airflow/gcp/hooks/cloud_build.py b/airflow/gcp/hooks/cloud_build.py index b7cdb7373a981..6e08ec90351aa 100644 --- a/airflow/gcp/hooks/cloud_build.py +++ b/airflow/gcp/hooks/cloud_build.py @@ -61,7 +61,7 @@ def __init__( def get_conn(self): """ - Retrieves the connection to Cloud Functions. + Retrieves the connection to Cloud Build. :return: Google Cloud Build services object. """ diff --git a/airflow/gcp/hooks/mlengine.py b/airflow/gcp/hooks/mlengine.py index f93825c84d2f4..6e5035b8bf779 100644 --- a/airflow/gcp/hooks/mlengine.py +++ b/airflow/gcp/hooks/mlengine.py @@ -64,7 +64,9 @@ class MLEngineHook(GoogleCloudBaseHook): """ def get_conn(self): """ - Returns a Google MLEngine service object. + Retrieves the connection to MLEngine. + + :return: Google MLEngine services object. """ authed_http = self._authorize() return build('ml', 'v1', http=authed_http, cache_discovery=False) @@ -142,8 +144,13 @@ def create_job( def _get_job(self, project_id: str, job_id: str) -> Dict: """ - Gets a MLEngine job based on the job name. + Gets a MLEngine job based on the job id. + :param project_id: The project in which the Job is located. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str + :param job_id: A unique id for the Google MLEngine job. (templated) + :type job_id: str :return: MLEngine job object if succeed. 
:rtype: dict :raises: googleapiclient.errors.HttpError @@ -168,6 +175,14 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30): This method will periodically check the job state until the job reach a terminal state. + + :param project_id: The project in which the Job is located. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str + :param job_id: A unique id for the Google MLEngine job. (templated) + :type job_id: str + :param interval: Time expressed in seconds after which the job status is checked again. (templated) + :type interval: int :raises: googleapiclient.errors.HttpError """ if interval <= 0: @@ -188,8 +203,18 @@ def create_version( """ Creates the Version on Google Cloud ML Engine. - Returns the operation if the version was created successfully and - raises an error otherwise. + :param version_spec: A dictionary containing the information about the version. (templated) + :type version_spec: dict + :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. + (templated) + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. + (templated) + :type project_id: str + :return: If the version was created successfully, returns the operation. + Otherwise raises an error. + :rtype: dict """ hook = self.get_conn() parent_name = 'projects/{}/models/{}'.format(project_id, model_name) @@ -214,6 +239,19 @@ def set_default_version( ) -> Dict: """ Sets a version to be the default. Blocks until finished. + + :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. + (templated) + :type model_name: str + :param version_name: A name to use for the version being operated upon. 
(templated) + :type version_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str + :return: If successful, return an instance of Version. + Otherwise raises an error. + :rtype: dict + :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() full_version_name = 'projects/{}/models/{}/versions/{}'.format( @@ -237,6 +275,16 @@ def list_versions( ) -> List[Dict]: """ Lists all available versions of a model. Blocks until finished. + + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str + :return: A list of Version instances. + :rtype: List[Dict] + :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() result = [] # type: List[Dict] @@ -261,10 +309,22 @@ def delete_version( model_name: str, version_name: str, project_id: Optional[str] = None, - ): + ) -> Dict: """ Deletes the given version of a model. Blocks until finished. + + :param model_name: The name of the Google Cloud ML Engine model that the version + belongs to. (templated) + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine + model belongs. + :type project_id: str + :return: If the version was deleted successfully, returns the operation. + Otherwise raises an error. + :rtype: Dict """ + assert project_id is None + hook = self.get_conn() full_name = 'projects/{}/models/{}/versions/{}'.format( project_id, model_name, version_name) @@ -288,6 +348,16 @@ def create_model( ) -> Dict: """ Create a Model. Blocks until finished. + + :param model: A dictionary containing the information about the model. 
+ :type model: dict + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str + :return: If the model was created successfully, returns the instance of Model. + Otherwise raises an error. + :rtype: Dict + :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() if not model['name']: @@ -307,6 +377,16 @@ def get_model( ) -> Optional[Dict]: """ Gets a Model. Blocks until finished. + + :param model_name: The name of the model. + :type model_name: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str + :return: If the model exists, returns the instance of Model. + Otherwise returns None. + :rtype: Dict + :raises: googleapiclient.errors.HttpError """ hook = self.get_conn() if not model_name: @@ -332,8 +412,22 @@ def delete_model( ) -> None: """ Delete a Model. Blocks until finished. + + :param model_name: The name of the model. + :type model_name: str + :param delete_contents: Whether to force the deletion even if the model is not empty. + Will delete all versions (if any) in the model if set to True. + The default value is False. + :type delete_contents: bool + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. 
(templated) + :type project_id: str + :raises: googleapiclient.errors.HttpError """ + assert project_id is not None + hook = self.get_conn() + if not model_name: raise ValueError("Model name must be provided and it could not be an empty string") model_path = 'projects/{}/models/{}'.format(project_id, model_name) @@ -348,7 +442,7 @@ def delete_model( return raise - def _delete_all_versions(self, model_name, project_id): + def _delete_all_versions(self, model_name: str, project_id: str): versions = self.list_versions(project_id=project_id, model_name=model_name) # The default version can only be deleted when it is the last one in the model non_default_versions = (version for version in versions if not version.get('isDefault', False)) diff --git a/airflow/gcp/operators/mlengine.py b/airflow/gcp/operators/mlengine.py index f0db45c784380..cab4eee480f9e 100644 --- a/airflow/gcp/operators/mlengine.py +++ b/airflow/gcp/operators/mlengine.py @@ -38,11 +38,10 @@ def _normalize_mlengine_job_id(job_id: str) -> str: This also adds a leading 'z' in case job_id starts with an invalid character. - Args: - job_id: A job_id str that may have invalid characters. - - Returns: - A valid job_id representation. + :param job_id: A job_id str that may have invalid characters. + :type job_id: str + :return: A valid job_id representation. + :rtype: str """ # Add a prefix when a job_id starts with a digit or a template @@ -96,71 +95,57 @@ class MLEngineBatchPredictionOperator(BaseOperator): See https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs for further documentation on the parameters. - :param project_id: The Google Cloud project name where the prediction job is submitted. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str - :param job_id: A unique id for the prediction job on Google Cloud ML Engine. (templated) :type job_id: str - :param data_format: The format of the input data. 
It will default to 'DATA_FORMAT_UNSPECIFIED' if is not provided or is not one of ["TEXT", "TF_RECORD", "TF_RECORD_GZIP"]. :type data_format: str - :param input_paths: A list of GCS paths of input data for batch prediction. Accepting wildcard operator ``*``, but only at the end. (templated) :type input_paths: list[str] - :param output_path: The GCS path where the prediction results are written to. (templated) :type output_path: str - :param region: The Google Compute Engine region to run the prediction job in. (templated) :type region: str - :param model_name: The Google Cloud ML Engine model to use for prediction. If version_name is not provided, the default version of this model will be used. Should not be None if version_name is provided. Should be None if uri is provided. (templated) :type model_name: str - :param version_name: The Google Cloud ML Engine model version to use for prediction. Should be None if uri is provided. (templated) :type version_name: str - :param uri: The GCS path of the saved model to use for prediction. Should be None if model_name is provided. It should be a GCS path pointing to a tensorflow SavedModel. (templated) :type uri: str - :param max_worker_count: The maximum number of workers to be used for parallel processing. Defaults to 10 if not specified. Should be a string representing the worker count ("10" instead of 10, "50" instead of 50, etc.) :type max_worker_count: string - :param runtime_version: The Google Cloud ML Engine runtime version to use for batch prediction. :type runtime_version: str - :param signature_name: The name of the signature defined in the SavedModel to use for this job. :type signature_name: str - + :param project_id: The Google Cloud project name where the prediction job is submitted. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID used for connection to Google Cloud Platform. 
:type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str - :raises: ``ValueError``: if a unique model/version origin cannot be determined. """ @@ -298,9 +283,6 @@ class MLEngineModelOperator(BaseOperator): This operator is deprecated. Consider using operators for specific operations: MLEngineCreateModelOperator, MLEngineGetModelOperator. - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str :param model: A dictionary containing the information about the model. If the `operation` is `create`, then the `model` parameter should contain all the information about this model such as `name`. @@ -313,6 +295,9 @@ class MLEngineModelOperator(BaseOperator): * ``create``: Creates a new model as provided by the `model` parameter. * ``get``: Gets a particular model where the name is specified in `model`. :type operation: str + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str :param delegate_to: The account to impersonate, if any. @@ -322,6 +307,7 @@ class MLEngineModelOperator(BaseOperator): """ template_fields = [ + '_project_id', '_model', ] @@ -362,13 +348,15 @@ def execute(self, context): class MLEngineCreateModelOperator(BaseOperator): """ - Operator for managing a Google Cloud ML Engine model. + Creates a new model. + + The model should be provided by the `model` parameter. + :param model: A dictionary containing the information about the model. 
+ :type model: dict :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or missing, the default project_id from the GCP connection is used. (templated) :type project_id: str - :param model: A dictionary containing the information about the model. - :type model: dict :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str :param delegate_to: The account to impersonate, if any. @@ -378,6 +366,7 @@ class MLEngineCreateModelOperator(BaseOperator): """ template_fields = [ + '_project_id', '_model', ] @@ -403,13 +392,15 @@ def execute(self, context): class MLEngineGetModelOperator(BaseOperator): """ - Operator for managing a Google Cloud ML Engine model. + Gets a particular model. + The name of the model should be specified in `model_name`. + + :param model_name: The name of the model. + :type model_name: str :param project_id: The Google Cloud project name to which MLEngine model belongs. If set to None or missing, the default project_id from the GCP connection is used. (templated) :type project_id: str - :param model_name: The name of the model. - :type model_name: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str :param delegate_to: The account to impersonate, if any. @@ -419,6 +410,7 @@ class MLEngineGetModelOperator(BaseOperator): """ template_fields = [ + '_project_id', '_model_name', ] @@ -447,15 +439,15 @@ class MLEngineDeleteModelOperator(BaseOperator): The model should be provided by the `model_name` parameter. - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str :param model_name: The name of the model. :type model_name: str :param delete_contents: (Optional) Whether to force the deletion even if the models is not empty. 
Will delete all version (if any) in the dataset if set to True. The default value is False. :type delete_contents: bool + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str :param delegate_to: The account to impersonate, if any. @@ -465,6 +457,7 @@ class MLEngineDeleteModelOperator(BaseOperator): """ template_fields = [ + '_project_id', '_model_name', ] @@ -501,20 +494,14 @@ class MLEngineVersionOperator(BaseOperator): MLEngineCreateVersionOperator, MLEngineSetDefaultVersionOperator, MLEngineListVersionsOperator, MLEngineDeleteVersionOperator. - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str - :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated) :type model_name: str - :param version_name: A name to use for the version being operated upon. If not None and the `version` argument is None or does not have a value for the `name` key, then this will be populated in the payload for the `name` key. (templated) :type version_name: str - :param version: A dictionary containing the information about the version. If the `operation` is `create`, `version` should contain all the information about this version such as name, and deploymentUrl. @@ -522,7 +509,6 @@ class MLEngineVersionOperator(BaseOperator): should contain the `name` of the version. If it is None, the only `operation` possible would be `list`. (templated) :type version: dict - :param operation: The operation to perform. 
Available operations are: * ``create``: Creates a new version in the model specified by `model_name`, @@ -542,10 +528,11 @@ class MLEngineVersionOperator(BaseOperator): The name of the version should be specified in the `version` parameter. :type operation: str - + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. @@ -630,19 +617,15 @@ class MLEngineCreateVersionOperator(BaseOperator): Model should be specified by `model_name`, in which case the `version` parameter should contain all the information to create that version - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str - :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated) :type model_name: str - :param version: A dictionary containing the information about the version. (templated) :type version: dict - + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. @@ -695,21 +678,17 @@ class MLEngineSetDefaultVersionOperator(BaseOperator): Sets a version in the model. 
The model should be specified by `model_name` to be the default. The name of the version should be - specified in the `version` parameter. - - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str + specified in the `version_name` parameter. :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated) :type model_name: str - :param version_name: A name to use for the version being operated upon. (templated) :type version_name: str - + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. @@ -763,17 +742,14 @@ class MLEngineListVersionsOperator(BaseOperator): The model should be specified by `model_name`. - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str - :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated) :type model_name: str - :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str - + :param project_id: The Google Cloud project name to which MLEngine model belongs. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param delegate_to: The account to impersonate, if any. 
For this to work, the service account making the request must have domain-wide delegation enabled. @@ -820,20 +796,16 @@ class MLEngineDeleteVersionOperator(BaseOperator): The name of the version should be specified in `version_name` parameter from the model specified by `model_name`. - :param project_id: The Google Cloud project name to which MLEngine - model belongs. - :type project_id: str - :param model_name: The name of the Google Cloud ML Engine model that the version belongs to. (templated) :type model_name: str - :param version_name: A name to use for the version being operated upon. (templated) :type version_name: str - + :param project_id: The Google Cloud project name to which MLEngine + model belongs. + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. @@ -884,57 +856,44 @@ class MLEngineTrainingOperator(BaseOperator): """ Operator for launching a MLEngine training job. - :param project_id: The Google Cloud project name within which MLEngine training job should run. - If set to None or missing, the default project_id from the GCP connection is used. (templated) - :type project_id: str - :param job_id: A unique templated id for the submitted Google MLEngine training job. (templated) :type job_id: str - :param package_uris: A list of package locations for MLEngine training job, which should include the main training program + any additional dependencies. (templated) :type package_uris: List[str] - :param training_python_module: The Python module name to run within MLEngine training job after installing 'package_uris' packages. (templated) :type training_python_module: str - :param training_args: A list of templated command line arguments to pass to the MLEngine training program. 
(templated) :type training_args: List[str] - :param region: The Google Compute Engine region to run the MLEngine training job in (templated). :type region: str - :param scale_tier: Resource tier for MLEngine training job. (templated) :type scale_tier: str - :param master_type: Cloud ML Engine machine name. Must be set when scale_tier is CUSTOM. (templated) :type master_type: str - :param runtime_version: The Google Cloud ML runtime version to use for training. (templated) :type runtime_version: str - :param python_version: The version of Python used in training. (templated) :type python_version: str - :param job_dir: A Google Cloud Storage path in which to store training outputs and other data needed for training. (templated) :type job_dir: str - + :param project_id: The Google Cloud project name within which MLEngine training job should run. + If set to None or missing, the default project_id from the GCP connection is used. (templated) + :type project_id: str :param gcp_conn_id: The connection ID to use when fetching connection info. :type gcp_conn_id: str - :param delegate_to: The account to impersonate, if any. For this to work, the service account making the request must have domain-wide delegation enabled. :type delegate_to: str - :param mode: Can be one of 'DRY_RUN'/'CLOUD'. In 'DRY_RUN' mode, no real training job will be launched, but the MLEngine training job request will be printed out. In 'CLOUD' mode, a real MLEngine training job From 15b52b5be22fdc8700dba74b7e86abb54fb251c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Wed, 9 Oct 2019 14:18:51 +0200 Subject: [PATCH 2/2] fixup! 
[AIRFLOW-5625] Update MLEngine integration doc and typehint --- airflow/gcp/hooks/mlengine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/gcp/hooks/mlengine.py b/airflow/gcp/hooks/mlengine.py index 6e5035b8bf779..5eddf30b08ee9 100644 --- a/airflow/gcp/hooks/mlengine.py +++ b/airflow/gcp/hooks/mlengine.py @@ -323,7 +323,7 @@ def delete_version( Otherwise raises an error. :rtype: Dict """ - assert project_id is None + assert project_id is not None hook = self.get_conn() full_name = 'projects/{}/models/{}/versions/{}'.format(