Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions airflow/gcp/hooks/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

from airflow.gcp.hooks.base import GoogleCloudBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.version import version as airflow_version

# Airflow version string rewritten for use as a label value: every '.' and '+'
# is replaced by '-' and the result is prefixed with 'v'.
_AIRFLOW_VERSION = 'v' + airflow_version.replace('.', '-').replace('+', '-')


def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
Expand Down Expand Up @@ -111,6 +114,8 @@ def create_job(

hook = self.get_conn()

self._append_label(job)

request = hook.projects().jobs().create( # pylint: disable=no-member
parent='projects/{}'.format(project_id),
body=job)
Expand Down Expand Up @@ -193,6 +198,9 @@ def create_version(
"""
hook = self.get_conn()
parent_name = 'projects/{}/models/{}'.format(project_id, model_name)

self._append_label(version_spec)

create_request = hook.projects().models().versions().create( # pylint: disable=no-member
parent=parent_name, body=version_spec)
response = create_request.execute()
Expand Down Expand Up @@ -295,6 +303,8 @@ def create_model(
"could not be an empty string")
project = 'projects/{}'.format(project_id)

self._append_label(model)

request = hook.projects().models().create( # pylint: disable=no-member
parent=project, body=model)
return request.execute()
Expand Down Expand Up @@ -359,3 +369,7 @@ def _delete_all_versions(self, model_name, project_id):
for version in default_versions:
_, _, version_name = version['name'].rpartition('/')
self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name)

def _append_label(self, model: Dict) -> None:
    """
    Add the ``airflow-version`` label to a request body, in place.

    Labels already present in ``model['labels']`` are kept; an existing
    ``airflow-version`` entry is overwritten.

    :param model: Request body (job, model or version specification) that
        is about to be sent to the API. It is mutated by this call.
    :type model: dict
    :return: None
    """
    # ``or {}`` covers both a missing key and an explicit ``'labels': None``;
    # ``model.get('labels', {})`` would hand back the ``None`` and the item
    # assignment below would then raise ``TypeError``.
    labels = model.get('labels') or {}
    labels['airflow-version'] = _AIRFLOW_VERSION
    model['labels'] = labels
176 changes: 170 additions & 6 deletions tests/gcp/hooks/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import unittest
from copy import deepcopy
from unittest import mock

from googleapiclient.errors import HttpError
Expand Down Expand Up @@ -46,10 +47,67 @@ def test_mle_engine_client_creation(self, mock_build, mock_authorize):

@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_version(self, mock_get_conn):
    """
    ``create_version`` must send the version spec with the
    ``airflow-version`` label added next to the label it already has.

    NOTE(review): despite the test name, the spec used here already
    carries a ``labels`` entry - confirm the naming is intentional.
    """
    project_id = 'test-project'
    model_name = 'test-model'
    version_name = 'test-version'
    model_path = 'projects/{}/models/{}'.format(project_id, model_name)
    operation_path = 'projects/{}/operations/test-operation'.format(project_id)
    operation_done = {'name': operation_path, 'done': True}

    version = {'name': version_name, 'labels': {'other-label': 'test-value'}}
    # Body the hook is expected to submit: the input plus the version label.
    expected_body = deepcopy(version)
    expected_body['labels']['airflow-version'] = hook._AIRFLOW_VERSION

    # Stub the two API calls: versions().create() and operations().get().
    projects_api = mock_get_conn.return_value.projects.return_value
    versions_api = projects_api.models.return_value.versions.return_value
    versions_api.create.return_value.execute.return_value = version
    operations_api = projects_api.operations.return_value
    operations_api.get.return_value.execute.return_value = dict(operation_done)

    # Pass a copy: the hook adds the label to the dict it receives.
    response = self.hook.create_version(
        project_id=project_id,
        model_name=model_name,
        version_spec=deepcopy(version),
    )

    self.assertEqual(response, operation_done)

    expected_calls = [
        mock.call().projects().models().versions().create(body=expected_body, parent=model_path),
        mock.call().projects().models().versions().create().execute(),
        mock.call().projects().operations().get(name=version_name),
    ]
    mock_get_conn.assert_has_calls(expected_calls, any_order=True)

@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_version_with_labels(self, mock_get_conn):
project_id = 'test-project'
model_name = 'test-model'
version_name = 'test-version'
version = {'name': version_name}
version_with_airflow_version = {
'name': 'test-version',
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}
operation_path = 'projects/{}/operations/test-operation'.format(project_id)
model_path = 'projects/{}/models/{}'.format(project_id, model_name)
operation_done = {'name': operation_path, 'done': True}
Expand All @@ -73,12 +131,16 @@ def test_create_version(self, mock_get_conn):
create_version_response = self.hook.create_version(
project_id=project_id,
model_name=model_name,
version_spec=version
version_spec=deepcopy(version)
)

self.assertEqual(create_version_response, operation_done)

mock_get_conn.assert_has_calls([
mock.call().projects().models().versions().create(body=version, parent=model_path),
mock.call().projects().models().versions().create(
body=version_with_airflow_version,
parent=model_path
),
mock.call().projects().models().versions().create().execute(),
mock.call().projects().operations().get(name=version_name),
], any_order=True)
Expand Down Expand Up @@ -108,6 +170,7 @@ def test_set_default_version(self, mock_get_conn):
)

self.assertEqual(set_default_version_response, operation_done)

mock_get_conn.assert_has_calls([
mock.call().projects().models().versions().setDefault(body={}, name=version_path),
mock.call().projects().models().versions().setDefault().execute()
Expand Down Expand Up @@ -200,6 +263,10 @@ def test_create_model(self, mock_get_conn):
model = {
'name': model_name,
}
model_with_airflow_version = {
'name': model_name,
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}
project_path = 'projects/{}'.format(project_id)

(
Expand All @@ -211,12 +278,47 @@ def test_create_model(self, mock_get_conn):
) = model

create_model_response = self.hook.create_model(
project_id=project_id, model=model
project_id=project_id, model=deepcopy(model)
)

self.assertEqual(create_model_response, model)
mock_get_conn.assert_has_calls([
mock.call().projects().models().create(body=model, parent=project_path),
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute()
])

@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_model_with_labels(self, mock_get_conn):
    """
    ``create_model`` must keep a label that is already on the model and
    add the ``airflow-version`` label to the submitted body.
    """
    project_id = 'test-project'
    model_name = 'test-model'
    project_path = 'projects/{}'.format(project_id)

    model = {'name': model_name, 'labels': {'other-label': 'test-value'}}
    # Body the hook is expected to submit: the input plus the version label.
    expected_body = deepcopy(model)
    expected_body['labels']['airflow-version'] = hook._AIRFLOW_VERSION

    # Stub models().create().execute() so it returns the original model.
    models_api = mock_get_conn.return_value.projects.return_value.models.return_value
    models_api.create.return_value.execute.return_value = model

    # Pass a copy: the hook adds the label to the dict it receives.
    response = self.hook.create_model(project_id=project_id, model=deepcopy(model))

    self.assertEqual(response, model)
    expected_calls = [
        mock.call().projects().models().create(body=expected_body, parent=project_path),
        mock.call().projects().models().create().execute(),
    ]
    mock_get_conn.assert_has_calls(expected_calls)

Expand Down Expand Up @@ -360,6 +462,12 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep):
'jobId': job_id,
'foo': 4815162342,
}
new_job_with_airflow_version = {
'jobId': job_id,
'foo': 4815162342,
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}

job_succeeded = {
'jobId': job_id,
'state': 'SUCCEEDED',
Expand All @@ -385,12 +493,68 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep):
) = [job_queued, job_succeeded]

create_job_response = self.hook.create_job(
project_id=project_id, job=new_job
project_id=project_id, job=deepcopy(new_job)
)

self.assertEqual(create_job_response, job_succeeded)
mock_get_conn.assert_has_calls([
mock.call().projects().jobs().create(body=new_job, parent=project_path),
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute()
], any_order=True)

@mock.patch("airflow.gcp.hooks.mlengine.time.sleep")
@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep):
    """
    ``create_job`` must keep a label that is already on the job and add
    the ``airflow-version`` label to the submitted body.
    """
    project_id = 'test-project'
    job_id = 'test-job-id'
    project_path = 'projects/{}'.format(project_id)
    job_path = 'projects/{}/jobs/{}'.format(project_id, job_id)

    new_job = {
        'jobId': job_id,
        'foo': 4815162342,
        'labels': {'other-label': 'test-value'},
    }
    # Body the hook is expected to submit: the input plus the version label.
    expected_body = deepcopy(new_job)
    expected_body['labels']['airflow-version'] = hook._AIRFLOW_VERSION

    job_queued = {'jobId': job_id, 'state': 'QUEUED'}
    job_succeeded = {'jobId': job_id, 'state': 'SUCCEEDED'}

    # create() reports the job as queued; successive get() calls report
    # it as queued and then as succeeded.
    jobs_api = mock_get_conn.return_value.projects.return_value.jobs.return_value
    jobs_api.create.return_value.execute.return_value = job_queued
    jobs_api.get.return_value.execute.side_effect = [job_queued, job_succeeded]

    # Pass a copy: the hook adds the label to the dict it receives.
    response = self.hook.create_job(project_id=project_id, job=deepcopy(new_job))

    self.assertEqual(response, job_succeeded)
    expected_calls = [
        mock.call().projects().jobs().create(body=expected_body, parent=project_path),
        mock.call().projects().jobs().get(name=job_path),
        mock.call().projects().jobs().get().execute(),
    ]
    mock_get_conn.assert_has_calls(expected_calls, any_order=True)
Expand Down