Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add hook and operator for Google Cloud Life Sciences (#8481)
- Loading branch information
1 parent
1ea9fa7
commit 14b22e6
Showing
9 changed files
with
824 additions
and
0 deletions.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
airflow/providers/google/cloud/example_dags/example_life_sciences.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import os | ||
|
||
from airflow import models | ||
from airflow.providers.google.cloud.operators.life_sciences import LifeSciencesRunPipelineOperator | ||
from airflow.utils import dates | ||
|
||
PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project-id") | ||
BUCKET = os.environ.get("GCP_GCS_BUCKET", "example-bucket") | ||
FILENAME = os.environ.get("GCP_GCS_LIFE_SCIENCES_FILENAME", 'input.in') | ||
LOCATION = os.environ.get("GCP_LIFE_SCIENCES_LOCATION", 'us-central1') | ||
|
||
|
||
# [START howto_configure_simple_action_pipeline] | ||
SIMPLE_ACTION_PIEPELINE = { | ||
"pipeline": { | ||
"actions": [ | ||
{ | ||
"imageUri": "bash", | ||
"commands": ["-c", "echo Hello, world"] | ||
}, | ||
], | ||
"resources": { | ||
"regions": ["{}".format(LOCATION)], | ||
"virtualMachine": { | ||
"machineType": "n1-standard-1", | ||
} | ||
} | ||
}, | ||
} | ||
# [END howto_configure_simple_action_pipeline] | ||
|
||
# [START howto_configure_multiple_action_pipeline] | ||
MULTI_ACTION_PIPELINE = { | ||
"pipeline": { | ||
"actions": [ | ||
{ | ||
"imageUri": "google/cloud-sdk", | ||
"commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME), "/tmp"] | ||
}, | ||
{ | ||
"imageUri": "bash", | ||
"commands": ["-c", "echo Hello, world"] | ||
}, | ||
{ | ||
"imageUri": "google/cloud-sdk", | ||
"commands": ["gsutil", "cp", "gs://{}/{}".format(BUCKET, FILENAME), | ||
"gs://{}/output.in".format(BUCKET)] | ||
}, | ||
], | ||
"resources": { | ||
"regions": ["{}".format(LOCATION)], | ||
"virtualMachine": { | ||
"machineType": "n1-standard-1", | ||
} | ||
} | ||
} | ||
} | ||
# [END howto_configure_multiple_action_pipeline] | ||
|
||
with models.DAG("example_gcp_life_sciences", | ||
default_args=dict(start_date=dates.days_ago(1)), | ||
schedule_interval=None, | ||
tags=['example'], | ||
) as dag: | ||
|
||
# [START howto_run_pipeline] | ||
simple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( | ||
task_id='simple-action-pipeline', | ||
body=SIMPLE_ACTION_PIEPELINE, | ||
project_id=PROJECT_ID, | ||
location=LOCATION | ||
) | ||
# [END howto_run_pipeline] | ||
|
||
multiple_life_science_action_pipeline = LifeSciencesRunPipelineOperator( | ||
task_id='multi-action-pipeline', | ||
body=MULTI_ACTION_PIPELINE, | ||
project_id=PROJECT_ID, | ||
location=LOCATION | ||
) | ||
|
||
simple_life_science_action_pipeline >> multiple_life_science_action_pipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Hook for Google Cloud Life Sciences service""" | ||
|
||
import time | ||
from typing import Any, Dict, Optional | ||
|
||
import google.api_core.path_template | ||
from googleapiclient.discovery import build | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook | ||
|
||
# Time to sleep between active checks of the operation results | ||
TIME_TO_SLEEP_IN_SECONDS = 5 | ||
|
||
|
||
# noinspection PyAbstractClass | ||
class LifeSciencesHook(GoogleBaseHook): | ||
""" | ||
Hook for the Google Cloud Life Sciences APIs. | ||
All the methods in the hook where project_id is used must be called with | ||
keyword arguments rather than positional. | ||
:param api_version: API version used (for example v1 or v1beta1). | ||
:type api_version: str | ||
:param gcp_conn_id: The connection ID to use when fetching connection info. | ||
:type gcp_conn_id: str | ||
:param delegate_to: The account to impersonate, if any. | ||
For this to work, the service account making the request must have | ||
domain-wide delegation enabled. | ||
:type delegate_to: str | ||
""" | ||
|
||
_conn = None # type: Optional[Any] | ||
|
||
def __init__( | ||
self, | ||
api_version: str = "v2beta", | ||
gcp_conn_id: str = "google_cloud_default", | ||
delegate_to: Optional[str] = None | ||
) -> None: | ||
super().__init__(gcp_conn_id, delegate_to) | ||
self.api_version = api_version | ||
|
||
def get_conn(self): | ||
""" | ||
Retrieves the connection to Cloud Life Sciences. | ||
:return: Google Cloud Life Sciences service object. | ||
""" | ||
if not self._conn: | ||
http_authorized = self._authorize() | ||
self._conn = build("lifesciences", self.api_version, | ||
http=http_authorized, cache_discovery=False) | ||
return self._conn | ||
|
||
@GoogleBaseHook.fallback_to_default_project_id | ||
def run_pipeline(self, body: Dict, location: str, project_id: str): | ||
""" | ||
Runs a pipeline | ||
:param body: The request body. | ||
:type body: dict | ||
:param location: The location of the project. For example: "us-east1". | ||
:type location: str | ||
:param project_id: Optional, Google Cloud Project project_id where the function belongs. | ||
If set to None or missing, the default project_id from the GCP connection is used. | ||
:type project_id: str | ||
:rtype: dict | ||
""" | ||
parent = self._location_path(project_id=project_id, location=location) | ||
service = self.get_conn() | ||
|
||
request = (service.projects() # pylint: disable=no-member | ||
.locations() | ||
.pipelines() | ||
.run(parent=parent, body=body) | ||
) | ||
|
||
response = request.execute(num_retries=self.num_retries) | ||
|
||
# wait | ||
operation_name = response['name'] | ||
self._wait_for_operation_to_complete(operation_name) | ||
|
||
return response | ||
|
||
@GoogleBaseHook.fallback_to_default_project_id | ||
def _location_path(self, project_id: str, location: str): | ||
""" | ||
Return a location string. | ||
:param project_id: Optional, Google Cloud Project project_id where the | ||
function belongs. If set to None or missing, the default project_id | ||
from the GCP connection is used. | ||
:type project_id: str | ||
:param location: The location of the project. For example: "us-east1". | ||
:type location: str | ||
""" | ||
return google.api_core.path_template.expand( | ||
'projects/{project}/locations/{location}', | ||
project=project_id, | ||
location=location, | ||
) | ||
|
||
def _wait_for_operation_to_complete(self, operation_name: str) -> None: | ||
""" | ||
Waits for the named operation to complete - checks status of the | ||
asynchronous call. | ||
:param operation_name: The name of the operation. | ||
:type operation_name: str | ||
:return: The response returned by the operation. | ||
:rtype: dict | ||
:exception: AirflowException in case error is returned. | ||
""" | ||
service = self.get_conn() | ||
while True: | ||
operation_response = (service.projects() # pylint: disable=no-member | ||
.locations() | ||
.operations() | ||
.get(name=operation_name) | ||
.execute(num_retries=self.num_retries)) | ||
self.log.info('Waiting for pipeline operation to complete') | ||
if operation_response.get("done"): | ||
response = operation_response.get("response") | ||
error = operation_response.get("error") | ||
# Note, according to documentation always either response or error is | ||
# set when "done" == True | ||
if error: | ||
raise AirflowException(str(error)) | ||
return response | ||
time.sleep(TIME_TO_SLEEP_IN_SECONDS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Operators that interact with Google Cloud Life Sciences service.""" | ||
|
||
from typing import Iterable, Optional | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.models import BaseOperator | ||
from airflow.providers.google.cloud.hooks.life_sciences import LifeSciencesHook | ||
from airflow.utils.decorators import apply_defaults | ||
|
||
|
||
class LifeSciencesRunPipelineOperator(BaseOperator): | ||
""" | ||
Runs a Life Sciences Pipeline | ||
:param body: The request body | ||
:type body: dict | ||
:param location: The location of the project | ||
:type location: str | ||
:param project_id: ID of the Google Cloud project if None then | ||
default project_id is used. | ||
:param project_id: str | ||
:param gcp_conn_id: The connection ID to use to connect to Google Cloud Platform. | ||
:type gcp_conn_id: str | ||
:param api_version: API version used (for example v2beta). | ||
:type api_version: str | ||
""" | ||
|
||
template_fields = ("body", "gcp_conn_id", "api_version") # type: Iterable[str] | ||
|
||
@apply_defaults | ||
def __init__(self, | ||
body: dict, | ||
location: str, | ||
project_id: Optional[str] = None, | ||
gcp_conn_id: str = "google_cloud_default", | ||
api_version: str = "v2beta", | ||
*args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.body = body | ||
self.location = location | ||
self.project_id = project_id | ||
self.gcp_conn_id = gcp_conn_id | ||
self.api_version = api_version | ||
self._validate_inputs() | ||
|
||
def _validate_inputs(self): | ||
if not self.body: | ||
raise AirflowException("The required parameter 'body' is missing") | ||
if not self.location: | ||
raise AirflowException("The required parameter 'location' is missing") | ||
|
||
def execute(self, context): | ||
hook = LifeSciencesHook(gcp_conn_id=self.gcp_conn_id, api_version=self.api_version) | ||
|
||
return hook.run_pipeline(body=self.body, | ||
location=self.location, | ||
project_id=self.project_id) |
Oops, something went wrong.