Skip to content

Commit

Permalink
[AIRFLOW-5626] Add labels to MLEngine resources
Browse files Browse the repository at this point in the history
  • Loading branch information
Kamil Breguła committed Oct 13, 2019
1 parent c0d98a7 commit 9d3bed5
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 6 deletions.
14 changes: 14 additions & 0 deletions airflow/gcp/hooks/mlengine.py
Expand Up @@ -28,6 +28,9 @@

from airflow.gcp.hooks.base import GoogleCloudBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.version import version as airflow_version

_AIRFLOW_VERSION = 'v' + airflow_version.replace('.', '-').replace('+', '-')


def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
Expand Down Expand Up @@ -111,6 +114,8 @@ def create_job(

hook = self.get_conn()

self._append_label(job)

request = hook.projects().jobs().create( # pylint: disable=no-member
parent='projects/{}'.format(project_id),
body=job)
Expand Down Expand Up @@ -193,6 +198,9 @@ def create_version(
"""
hook = self.get_conn()
parent_name = 'projects/{}/models/{}'.format(project_id, model_name)

self._append_label(version_spec)

create_request = hook.projects().models().versions().create( # pylint: disable=no-member
parent=parent_name, body=version_spec)
response = create_request.execute()
Expand Down Expand Up @@ -295,6 +303,8 @@ def create_model(
"could not be an empty string")
project = 'projects/{}'.format(project_id)

self._append_label(model)

request = hook.projects().models().create( # pylint: disable=no-member
parent=project, body=model)
return request.execute()
Expand Down Expand Up @@ -359,3 +369,7 @@ def _delete_all_versions(self, model_name, project_id):
for version in default_versions:
_, _, version_name = version['name'].rpartition('/')
self.delete_version(project_id=project_id, model_name=model_name, version_name=version_name)

def _append_label(self, model: Dict) -> None:
model['labels'] = model.get('labels', {})
model['labels']['airflow-version'] = _AIRFLOW_VERSION
176 changes: 170 additions & 6 deletions tests/gcp/hooks/test_mlengine.py
Expand Up @@ -16,6 +16,7 @@
# under the License.

import unittest
from copy import deepcopy
from unittest import mock

from googleapiclient.errors import HttpError
Expand Down Expand Up @@ -46,10 +47,67 @@ def test_mle_engine_client_creation(self, mock_build, mock_authorize):

@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_version(self, mock_get_conn):
project_id = 'test-project'
model_name = 'test-model'
version_name = 'test-version'
version = {
'name': version_name,
'labels': {'other-label': 'test-value'}
}
version_with_airflow_version = {
'name': 'test-version',
'labels': {
'other-label': 'test-value',
'airflow-version': hook._AIRFLOW_VERSION
}
}
operation_path = 'projects/{}/operations/test-operation'.format(project_id)
model_path = 'projects/{}/models/{}'.format(project_id, model_name)
operation_done = {'name': operation_path, 'done': True}

(
mock_get_conn.return_value.
projects.return_value.
models.return_value.
versions.return_value.
create.return_value.
execute.return_value
) = version
(
mock_get_conn.return_value.
projects.return_value.
operations.return_value.
get.return_value.
execute.return_value
) = {'name': operation_path, 'done': True}

create_version_response = self.hook.create_version(
project_id=project_id,
model_name=model_name,
version_spec=deepcopy(version)
)

self.assertEqual(create_version_response, operation_done)

mock_get_conn.assert_has_calls([
mock.call().projects().models().versions().create(
body=version_with_airflow_version,
parent=model_path
),
mock.call().projects().models().versions().create().execute(),
mock.call().projects().operations().get(name=version_name),
], any_order=True)

@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_version_with_labels(self, mock_get_conn):
project_id = 'test-project'
model_name = 'test-model'
version_name = 'test-version'
version = {'name': version_name}
version_with_airflow_version = {
'name': 'test-version',
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}
operation_path = 'projects/{}/operations/test-operation'.format(project_id)
model_path = 'projects/{}/models/{}'.format(project_id, model_name)
operation_done = {'name': operation_path, 'done': True}
Expand All @@ -73,12 +131,16 @@ def test_create_version(self, mock_get_conn):
create_version_response = self.hook.create_version(
project_id=project_id,
model_name=model_name,
version_spec=version
version_spec=deepcopy(version)
)

self.assertEqual(create_version_response, operation_done)

mock_get_conn.assert_has_calls([
mock.call().projects().models().versions().create(body=version, parent=model_path),
mock.call().projects().models().versions().create(
body=version_with_airflow_version,
parent=model_path
),
mock.call().projects().models().versions().create().execute(),
mock.call().projects().operations().get(name=version_name),
], any_order=True)
Expand Down Expand Up @@ -108,6 +170,7 @@ def test_set_default_version(self, mock_get_conn):
)

self.assertEqual(set_default_version_response, operation_done)

mock_get_conn.assert_has_calls([
mock.call().projects().models().versions().setDefault(body={}, name=version_path),
mock.call().projects().models().versions().setDefault().execute()
Expand Down Expand Up @@ -200,6 +263,10 @@ def test_create_model(self, mock_get_conn):
model = {
'name': model_name,
}
model_with_airflow_version = {
'name': model_name,
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}
project_path = 'projects/{}'.format(project_id)

(
Expand All @@ -211,12 +278,47 @@ def test_create_model(self, mock_get_conn):
) = model

create_model_response = self.hook.create_model(
project_id=project_id, model=model
project_id=project_id, model=deepcopy(model)
)

self.assertEqual(create_model_response, model)
mock_get_conn.assert_has_calls([
mock.call().projects().models().create(body=model, parent=project_path),
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute()
])

@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_model_with_labels(self, mock_get_conn):
project_id = 'test-project'
model_name = 'test-model'
model = {
'name': model_name,
'labels': {'other-label': 'test-value'}
}
model_with_airflow_version = {
'name': model_name,
'labels': {
'other-label': 'test-value',
'airflow-version': hook._AIRFLOW_VERSION
}
}
project_path = 'projects/{}'.format(project_id)

(
mock_get_conn.return_value.
projects.return_value.
models.return_value.
create.return_value.
execute.return_value
) = model

create_model_response = self.hook.create_model(
project_id=project_id, model=deepcopy(model)
)

self.assertEqual(create_model_response, model)
mock_get_conn.assert_has_calls([
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute()
])

Expand Down Expand Up @@ -360,6 +462,12 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep):
'jobId': job_id,
'foo': 4815162342,
}
new_job_with_airflow_version = {
'jobId': job_id,
'foo': 4815162342,
'labels': {'airflow-version': hook._AIRFLOW_VERSION}
}

job_succeeded = {
'jobId': job_id,
'state': 'SUCCEEDED',
Expand All @@ -385,12 +493,68 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep):
) = [job_queued, job_succeeded]

create_job_response = self.hook.create_job(
project_id=project_id, job=new_job
project_id=project_id, job=deepcopy(new_job)
)

self.assertEqual(create_job_response, job_succeeded)
mock_get_conn.assert_has_calls([
mock.call().projects().jobs().create(body=new_job, parent=project_path),
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute()
], any_order=True)

@mock.patch("airflow.gcp.hooks.mlengine.time.sleep")
@mock.patch("airflow.gcp.hooks.mlengine.MLEngineHook.get_conn")
def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep):
project_id = 'test-project'
job_id = 'test-job-id'
project_path = 'projects/{}'.format(project_id)
job_path = 'projects/{}/jobs/{}'.format(project_id, job_id)
new_job = {
'jobId': job_id,
'foo': 4815162342,
'labels': {'other-label': 'test-value'}
}
new_job_with_airflow_version = {
'jobId': job_id,
'foo': 4815162342,
'labels': {
'other-label': 'test-value',
'airflow-version': hook._AIRFLOW_VERSION
}
}

job_succeeded = {
'jobId': job_id,
'state': 'SUCCEEDED',
}
job_queued = {
'jobId': job_id,
'state': 'QUEUED',
}

(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
create.return_value.
execute.return_value
) = job_queued
(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
get.return_value.
execute.side_effect
) = [job_queued, job_succeeded]

create_job_response = self.hook.create_job(
project_id=project_id, job=deepcopy(new_job)
)

self.assertEqual(create_job_response, job_succeeded)
mock_get_conn.assert_has_calls([
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute()
], any_order=True)
Expand Down

0 comments on commit 9d3bed5

Please sign in to comment.