Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/gcp/hooks/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(

def get_conn(self):
"""
Retrieves the connection to Cloud Functions.
Retrieves the connection to Cloud Build.

:return: Google Cloud Build services object.
"""
Expand Down
106 changes: 100 additions & 6 deletions airflow/gcp/hooks/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ class MLEngineHook(GoogleCloudBaseHook):
"""
def get_conn(self):
"""
Returns a Google MLEngine service object.
Retrieves the connection to MLEngine.

:return: Google MLEngine services object.
"""
authed_http = self._authorize()
return build('ml', 'v1', http=authed_http, cache_discovery=False)
Expand Down Expand Up @@ -142,8 +144,13 @@ def create_job(

def _get_job(self, project_id: str, job_id: str) -> Dict:
"""
Gets a MLEngine job based on the job name.
Gets a MLEngine job based on the job id.

:param project_id: The project in which the Job is located.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:param job_id: A unique id for the Google MLEngine job. (templated)
:type job_id: str
:return: MLEngine job object if succeed.
:rtype: dict
:raises: googleapiclient.errors.HttpError
Expand All @@ -168,6 +175,14 @@ def _wait_for_job_done(self, project_id: str, job_id: str, interval: int = 30):

This method will periodically check the job state until the job reach
a terminal state.

:param project_id: The project in which the Job is located.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:param job_id: A unique id for the Google MLEngine job. (templated)
:type job_id: str
:param interval: Time expressed in seconds after which the job status is checked again. (templated)
:type interval: int
:raises: googleapiclient.errors.HttpError
"""
if interval <= 0:
Expand All @@ -188,8 +203,18 @@ def create_version(
"""
Creates the Version on Google Cloud ML Engine.

Returns the operation if the version was created successfully and
raises an error otherwise.
:param version_spec: A dictionary containing the information about the version. (templated)
:type version_spec: dict
:param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
(templated)
:type model_name: str
:param project_id: The Google Cloud project name to which MLEngine model belongs.
If set to None or missing, the default project_id from the GCP connection is used.
(templated)
:type project_id: str
:return: If the version was created successfully, returns the operation.
Otherwise raises an error .
:rtype: dict
"""
hook = self.get_conn()
parent_name = 'projects/{}/models/{}'.format(project_id, model_name)
Expand All @@ -214,6 +239,19 @@ def set_default_version(
) -> Dict:
"""
Sets a version to be the default. Blocks until finished.

:param model_name: The name of the Google Cloud ML Engine model that the version belongs to.
(templated)
:type model_name: str
:param version_name: A name to use for the version being operated upon. (templated)
:type version_name: str
:param project_id: The Google Cloud project name to which MLEngine model belongs.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:return: If successful, return an instance of Version.
Otherwise raises an error.
:rtype: dict
:raises: googleapiclient.errors.HttpError
"""
hook = self.get_conn()
full_version_name = 'projects/{}/models/{}/versions/{}'.format(
Expand All @@ -237,6 +275,16 @@ def list_versions(
) -> List[Dict]:
"""
Lists all available versions of a model. Blocks until finished.

:param model_name: The name of the Google Cloud ML Engine model that the version
belongs to. (templated)
:type model_name: str
:param project_id: The Google Cloud project name to which MLEngine model belongs.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:return: return an list of instance of Version.
:rtype: List[Dict]
:raises: googleapiclient.errors.HttpError
"""
hook = self.get_conn()
result = [] # type: List[Dict]
Expand All @@ -261,10 +309,22 @@ def delete_version(
model_name: str,
version_name: str,
project_id: Optional[str] = None,
):
) -> Dict:
"""
Deletes the given version of a model. Blocks until finished.

:param model_name: The name of the Google Cloud ML Engine model that the version
belongs to. (templated)
:type model_name: str
:param project_id: The Google Cloud project name to which MLEngine
model belongs.
:type project_id: str
:return: If the version was deleted successfully, returns the operation.
Otherwise raises an error.
:rtype: Dict
"""
assert project_id is not None

hook = self.get_conn()
full_name = 'projects/{}/models/{}/versions/{}'.format(
project_id, model_name, version_name)
Expand All @@ -288,6 +348,16 @@ def create_model(
) -> Dict:
"""
Create a Model. Blocks until finished.

:param model: A dictionary containing the information about the model.
:type model: dict
:param project_id: The Google Cloud project name to which MLEngine model belongs.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:return: If the version was created successfully, returns the instance of Model.
Otherwise raises an error.
:rtype: Dict
:raises: googleapiclient.errors.HttpError
"""
hook = self.get_conn()
if not model['name']:
Expand All @@ -307,6 +377,16 @@ def get_model(
) -> Optional[Dict]:
"""
Gets a Model. Blocks until finished.

:param model_name: The name of the model.
:type model_name: str
:param project_id: The Google Cloud project name to which MLEngine model belongs.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:return: If the model exists, returns the instance of Model.
Otherwise return None.
:rtype: Dict
:raises: googleapiclient.errors.HttpError
"""
hook = self.get_conn()
if not model_name:
Expand All @@ -332,8 +412,22 @@ def delete_model(
) -> None:
"""
Delete a Model. Blocks until finished.

:param model_name: The name of the model.
:type model_name: str
:param delete_contents: Whether to force the deletion even if the models is not empty.
Will delete all version (if any) in the dataset if set to True.
The default value is False.
:type delete_contents: bool
:param project_id: The Google Cloud project name to which MLEngine model belongs.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:raises: googleapiclient.errors.HttpError
"""
assert project_id is not None

hook = self.get_conn()

if not model_name:
raise ValueError("Model name must be provided and it could not be an empty string")
model_path = 'projects/{}/models/{}'.format(project_id, model_name)
Expand All @@ -348,7 +442,7 @@ def delete_model(
return
raise

def _delete_all_versions(self, model_name, project_id):
def _delete_all_versions(self, model_name: str, project_id: str):
versions = self.list_versions(project_id=project_id, model_name=model_name)
# The default version can only be deleted when it is the last one in the model
non_default_versions = (version for version in versions if not version.get('isDefault', False))
Expand Down
Loading