Add hook and operator for Google Cloud Life Sciences (#8481)
ephraimbuddy committed Apr 25, 2020
1 parent 1ea9fa7 commit 14b22e6
Showing 9 changed files with 824 additions and 0 deletions.
100 changes: 100 additions & 0 deletions airflow/providers/google/cloud/example_dags/example_life_sciences.py
@@ -0,0 +1,100 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os

from airflow import models
from airflow.providers.google.cloud.operators.life_sciences import LifeSciencesRunPipelineOperator
from airflow.utils import dates

PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project-id")
BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-bucket")
FILENAME = os.environ.get("GCP_GCS_LIFE_SCIENCES_FILENAME", "input.in")
LOCATION = os.environ.get("GCP_LIFE_SCIENCES_LOCATION", "us-central1")


# [START howto_configure_simple_action_pipeline]
SIMPLE_ACTION_PIPELINE = {
    "pipeline": {
        "actions": [
            {
                "imageUri": "bash",
                "commands": ["-c", "echo Hello, world"]
            },
        ],
        "resources": {
            "regions": [LOCATION],
            "virtualMachine": {
                "machineType": "n1-standard-1",
            }
        }
    },
}
# [END howto_configure_simple_action_pipeline]
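# Both request bodies in this example follow the Life Sciences v2beta
# pipelines.run schema: each entry in "actions" runs sequentially in its own
# container on the provisioned worker VM, while "resources" selects the
# regions and machine type (see the Cloud Life Sciences API reference for
# the full schema).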

# [START howto_configure_multiple_action_pipeline]
MULTI_ACTION_PIPELINE = {
    "pipeline": {
        "actions": [
            {
                "imageUri": "google/cloud-sdk",
                "commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME), "/tmp"]
            },
            {
                "imageUri": "bash",
                "commands": ["-c", "echo Hello, world"]
            },
            {
                "imageUri": "google/cloud-sdk",
                "commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME),
                             "gs://{}/output.in".format(BUCKET)]
            },
        ],
        "resources": {
            "regions": [LOCATION],
            "virtualMachine": {
                "machineType": "n1-standard-1",
            }
        }
    }
}
# [END howto_configure_multiple_action_pipeline]

with models.DAG("example_gcp_life_sciences",
                default_args=dict(start_date=dates.days_ago(1)),
                schedule_interval=None,
                tags=['example'],
                ) as dag:

    # [START howto_run_pipeline]
    simple_life_science_action_pipeline = LifeSciencesRunPipelineOperator(
        task_id='simple-action-pipeline',
        body=SIMPLE_ACTION_PIPELINE,
        project_id=PROJECT_ID,
        location=LOCATION
    )
    # [END howto_run_pipeline]

    multiple_life_science_action_pipeline = LifeSciencesRunPipelineOperator(
        task_id='multi-action-pipeline',
        body=MULTI_ACTION_PIPELINE,
        project_id=PROJECT_ID,
        location=LOCATION
    )

    simple_life_science_action_pipeline >> multiple_life_science_action_pipeline
150 changes: 150 additions & 0 deletions airflow/providers/google/cloud/hooks/life_sciences.py
@@ -0,0 +1,150 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hook for Google Cloud Life Sciences service"""

import time
from typing import Any, Dict, Optional

import google.api_core.path_template
from googleapiclient.discovery import build

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

# Time to sleep between active checks of the operation results
TIME_TO_SLEEP_IN_SECONDS = 5


# noinspection PyAbstractClass
class LifeSciencesHook(GoogleBaseHook):
    """
    Hook for the Google Cloud Life Sciences APIs.

    All hook methods that take a project_id must be called with keyword
    arguments rather than positional arguments.

    :param api_version: API version used (for example v2beta).
    :type api_version: str
    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: str
    """

    _conn = None  # type: Optional[Any]

    def __init__(
        self,
        api_version: str = "v2beta",
        gcp_conn_id: str = "google_cloud_default",
        delegate_to: Optional[str] = None
    ) -> None:
        super().__init__(gcp_conn_id, delegate_to)
        self.api_version = api_version

    def get_conn(self):
        """
        Retrieves the connection to Cloud Life Sciences.

        :return: Google Cloud Life Sciences service object.
        """
        if not self._conn:
            http_authorized = self._authorize()
            self._conn = build("lifesciences", self.api_version,
                               http=http_authorized, cache_discovery=False)
        return self._conn

    @GoogleBaseHook.fallback_to_default_project_id
    def run_pipeline(self, body: Dict, location: str, project_id: str):
        """
        Runs a pipeline and waits for the resulting operation to complete.

        :param body: The request body.
        :type body: dict
        :param location: The location of the project. For example: "us-east1".
        :type location: str
        :param project_id: Optional, Google Cloud Project project_id where the function belongs.
            If set to None or missing, the default project_id from the GCP connection is used.
        :type project_id: str
        :return: The response returned by the initial run request.
        :rtype: dict
        """
        parent = self._location_path(project_id=project_id, location=location)
        service = self.get_conn()

        request = (service.projects()  # pylint: disable=no-member
                   .locations()
                   .pipelines()
                   .run(parent=parent, body=body)
                   )

        response = request.execute(num_retries=self.num_retries)

        # Block until the long-running operation finishes.
        operation_name = response['name']
        self._wait_for_operation_to_complete(operation_name)

        return response

    @GoogleBaseHook.fallback_to_default_project_id
    def _location_path(self, project_id: str, location: str):
        """
        Return a location string.

        :param project_id: Optional, Google Cloud Project project_id where the
            function belongs. If set to None or missing, the default project_id
            from the GCP connection is used.
        :type project_id: str
        :param location: The location of the project. For example: "us-east1".
        :type location: str
        """
        return google.api_core.path_template.expand(
            'projects/{project}/locations/{location}',
            project=project_id,
            location=location,
        )

    def _wait_for_operation_to_complete(self, operation_name: str) -> Dict:
        """
        Waits for the named operation to complete by polling the status of the
        asynchronous call.

        :param operation_name: The name of the operation.
        :type operation_name: str
        :return: The response returned by the operation.
        :rtype: dict
        :exception: AirflowException in case error is returned.
        """
        service = self.get_conn()
        while True:
            operation_response = (service.projects()  # pylint: disable=no-member
                                  .locations()
                                  .operations()
                                  .get(name=operation_name)
                                  .execute(num_retries=self.num_retries))
            self.log.info('Waiting for pipeline operation to complete')
            if operation_response.get("done"):
                response = operation_response.get("response")
                error = operation_response.get("error")
                # Note: according to the documentation, exactly one of response
                # or error is set when "done" is True.
                if error:
                    raise AirflowException(str(error))
                return response
            time.sleep(TIME_TO_SLEEP_IN_SECONDS)
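For orientation, a minimal sketch of calling the hook directly, outside of an operator; the connection ID, project, location, and pipeline body below are illustrative placeholders rather than values mandated by this change:

from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook

# Hypothetical values for illustration only. run_pipeline blocks until the
# underlying long-running operation reports "done", then returns the initial
# run response.
hook = LifeSciencesHook(api_version="v2beta", gcp_conn_id="google_cloud_default")
response = hook.run_pipeline(
    body={
        "pipeline": {
            "actions": [{"imageUri": "bash", "commands": ["-c", "echo Hello"]}],
            "resources": {
                "regions": ["us-central1"],
                "virtualMachine": {"machineType": "n1-standard-1"},
            },
        }
    },
    location="us-central1",
    project_id="example-project-id",  # optional; falls back to the connection default
)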
74 changes: 74 additions & 0 deletions airflow/providers/google/cloud/operators/life_sciences.py
@@ -0,0 +1,74 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Operators that interact with Google Cloud Life Sciences service."""

from typing import Iterable, Optional

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook
from airflow.utils.decorators import apply_defaults


class LifeSciencesRunPipelineOperator(BaseOperator):
    """
    Runs a Life Sciences pipeline.

    :param body: The request body.
    :type body: dict
    :param location: The location of the project.
    :type location: str
    :param project_id: ID of the Google Cloud project; if None, the
        default project_id is used.
    :type project_id: str
    :param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param api_version: API version used (for example v2beta).
    :type api_version: str
    """

    template_fields = ("body", "gcp_conn_id", "api_version")  # type: Iterable[str]

    @apply_defaults
    def __init__(self,
                 body: dict,
                 location: str,
                 project_id: Optional[str] = None,
                 gcp_conn_id: str = "google_cloud_default",
                 api_version: str = "v2beta",
                 *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.body = body
        self.location = location
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.api_version = api_version
        self._validate_inputs()

    def _validate_inputs(self):
        if not self.body:
            raise AirflowException("The required parameter 'body' is missing")
        if not self.location:
            raise AirflowException("The required parameter 'location' is missing")

    def execute(self, context):
        hook = LifeSciencesHook(gcp_conn_id=self.gcp_conn_id, api_version=self.api_version)

        return hook.run_pipeline(body=self.body,
                                 location=self.location,
                                 project_id=self.project_id)
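Because "body" is listed in template_fields, Jinja expressions inside the request body are rendered at runtime. A brief sketch, assuming the usual rendering of nested template fields; the bucket name and filename pattern are illustrative, and the task would be defined inside a DAG as in the example file above:

templated_pipeline = LifeSciencesRunPipelineOperator(
    task_id="templated-action-pipeline",
    body={
        "pipeline": {
            "actions": [
                {
                    "imageUri": "google/cloud-sdk",
                    # "{{ ds }}" renders to the task's execution date at run time.
                    "commands": ["gsutil", "cp",
                                 "gs://example-bucket/{{ ds }}.in", "/tmp"],
                },
            ],
            "resources": {
                "regions": ["us-central1"],
                "virtualMachine": {"machineType": "n1-standard-1"},
            },
        }
    },
    location="us-central1",
)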
