Skip to content
Permalink
Browse files
TextToSpeech assets & system tests migration (AIP-47) (#23247)
  • Loading branch information
bhirsz committed May 4, 2022
1 parent 3fb8e0b commit dfe0f759381c13a2c81212368d3c0c43f57da660
Showing 7 changed files with 74 additions and 63 deletions.
@@ -28,6 +28,7 @@
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook
from airflow.providers.google.common.links.storage import FileDetailsLink

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -80,6 +81,7 @@ class CloudTextToSpeechSynthesizeOperator(BaseOperator):
"impersonation_chain",
)
# [END gcp_text_to_speech_synthesize_template_fields]
operator_extra_links = (FileDetailsLink(),)

def __init__(
self,
@@ -141,3 +143,9 @@ def execute(self, context: 'Context') -> None:
cloud_storage_hook.upload(
bucket_name=self.target_bucket_name, object_name=self.target_filename, filename=temp_file.name
)
FileDetailsLink.persist(
context=context,
task_instance=self,
uri=f"{self.target_bucket_name}/{self.target_filename}",
project_id=cloud_storage_hook.project_id,
)
@@ -15,12 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains a link for GCS Storage assets."""
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.base import BaseGoogleLink

BASE_LINK = "https://console.cloud.google.com"
GCS_STORAGE_LINK = BASE_LINK + "/storage/browser/{uri};tab=objects?project={project_id}"
GCS_FILE_DETAILS_LINK = BASE_LINK + "/storage/browser/_details/{uri};tab=live_object?project={project_id}"

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -40,3 +42,19 @@ def persist(context: "Context", task_instance, uri: str):
key=StorageLink.key,
value={"uri": uri, "project_id": task_instance.project_id},
)


class FileDetailsLink(BaseGoogleLink):
    """Helper for building the Cloud Console link to a single GCS file's details page."""

    name = "GCS File Details"
    key = "file_details"
    format_str = GCS_FILE_DETAILS_LINK

    @staticmethod
    def persist(context: "Context", task_instance: BaseOperator, uri: str, project_id: Optional[str]):
        """Push the link parameters to XCom under ``FileDetailsLink.key``.

        :param context: execution context, forwarded unchanged to ``xcom_push``
        :param task_instance: operator whose ``xcom_push`` is invoked
        :param uri: stored under the ``"uri"`` key; presumably fills the ``{uri}``
            placeholder of ``GCS_FILE_DETAILS_LINK`` (rendering is done by the base
            class, which is not visible here -- confirm)
        :param project_id: stored under the ``"project_id"`` key; may be ``None``
        """
        link_params = {"uri": uri, "project_id": project_id}
        task_instance.xcom_push(context=context, key=FileDetailsLink.key, value=link_params)
@@ -912,6 +912,7 @@ extra-links:
- airflow.providers.google.cloud.links.bigtable.BigtableClusterLink
- airflow.providers.google.cloud.links.bigtable.BigtableTablesLink
- airflow.providers.google.common.links.storage.StorageLink
- airflow.providers.google.common.links.storage.FileDetailsLink

additional-extras:
apache.beam: apache-beam[gcp]
@@ -42,22 +42,22 @@ The ``input``, ``voice`` and ``audio_config`` arguments need to be dicts or obje

for more information, see: https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/api.html#google.cloud.texttospeech_v1.TextToSpeechClient.synthesize_speech

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_text_to_speech.py
.. exampleinclude:: /../../tests/system/providers/google/text_to_speech/example_text_to_speech.py
:language: python
:start-after: [START howto_operator_text_to_speech_api_arguments]
:end-before: [END howto_operator_text_to_speech_api_arguments]

The ``filename`` argument is a simple string argument:

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_text_to_speech.py
.. exampleinclude:: /../../tests/system/providers/google/text_to_speech/example_text_to_speech.py
:language: python
:start-after: [START howto_operator_text_to_speech_gcp_filename]
:end-before: [END howto_operator_text_to_speech_gcp_filename]

Using the operator
""""""""""""""""""

.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_text_to_speech.py
.. exampleinclude:: /../../tests/system/providers/google/text_to_speech/example_text_to_speech.py
:language: python
:dedent: 4
:start-after: [START howto_operator_text_to_speech_synthesize]
@@ -17,7 +17,7 @@
# under the License.

import unittest
from unittest.mock import ANY, Mock, PropertyMock, patch
from unittest.mock import ANY, MagicMock, Mock, PropertyMock, patch

import pytest
from google.api_core.gapic_v1.method import DEFAULT
@@ -42,6 +42,7 @@ class TestGcpTextToSpeech(unittest.TestCase):
def test_synthesize_text_green_path(self, mock_text_to_speech_hook, mock_gcp_hook):
mocked_response = Mock()
type(mocked_response).audio_content = PropertyMock(return_value=b"audio")
mocked_context = MagicMock()

mock_text_to_speech_hook.return_value.synthesize_speech.return_value = mocked_response
mock_gcp_hook.return_value.upload.return_value = True
@@ -56,7 +57,7 @@ def test_synthesize_text_green_path(self, mock_text_to_speech_hook, mock_gcp_hoo
target_filename=TARGET_FILENAME,
task_id="id",
impersonation_chain=IMPERSONATION_CHAIN,
).execute(context={"task_instance": Mock()})
).execute(context=mocked_context)

mock_text_to_speech_hook.assert_called_once_with(
gcp_conn_id="gcp-conn-id",
@@ -95,6 +96,8 @@ def test_missing_arguments(
mock_text_to_speech_hook,
mock_gcp_hook,
):
mocked_context = Mock()

with pytest.raises(AirflowException) as ctx:
CloudTextToSpeechSynthesizeOperator(
project_id="project-id",
@@ -104,7 +107,7 @@ def test_missing_arguments(
target_bucket_name=target_bucket_name,
target_filename=target_filename,
task_id="id",
).execute(context={"task_instance": Mock()})
).execute(context=mocked_context)

err = ctx.value
assert missing_arg in str(err)

This file was deleted.

@@ -20,10 +20,15 @@
from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.operators.text_to_speech import CloudTextToSpeechSynthesizeOperator
from airflow.utils.trigger_rule import TriggerRule

GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
BUCKET_NAME = os.environ.get("GCP_TEXT_TO_SPEECH_BUCKET", "gcp-text-to-speech-test-bucket")
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
DAG_ID = "text_to_speech"

BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}"

# [START howto_operator_text_to_speech_gcp_filename]
FILENAME = "gcp-speech-test-file"
@@ -36,31 +41,48 @@
# [END howto_operator_text_to_speech_api_arguments]

with models.DAG(
"example_gcp_text_to_speech",
schedule_interval='@once', # Override to match your needs
DAG_ID,
schedule_interval="@once",
start_date=datetime(2021, 1, 1),
catchup=False,
tags=['example'],
tags=["example", "text_to_speech"],
) as dag:
create_bucket = GCSCreateBucketOperator(
task_id="create_bucket", bucket_name=BUCKET_NAME, project_id=PROJECT_ID
)

# [START howto_operator_text_to_speech_synthesize]
text_to_speech_synthesize_task = CloudTextToSpeechSynthesizeOperator(
    # PROJECT_ID is the constant the migrated system test defines (from
    # SYSTEM_TESTS_GCP_PROJECT) and the one the bucket setup task uses;
    # GCP_PROJECT_ID is no longer defined after the migration and would
    # raise NameError when the DAG file is parsed.
    project_id=PROJECT_ID,
    input_data=INPUT,
    voice=VOICE,
    audio_config=AUDIO_CONFIG,
    target_bucket_name=BUCKET_NAME,
    target_filename=FILENAME,
    task_id="text_to_speech_synthesize_task",
)
text_to_speech_synthesize_task2 = CloudTextToSpeechSynthesizeOperator(
input_data=INPUT,
voice=VOICE,
audio_config=AUDIO_CONFIG,
target_bucket_name=BUCKET_NAME,
target_filename=FILENAME,
task_id="text_to_speech_synthesize_task2",
)
# [END howto_operator_text_to_speech_synthesize]

text_to_speech_synthesize_task >> text_to_speech_synthesize_task2
delete_bucket = GCSDeleteBucketOperator(
task_id="delete_bucket", bucket_name=BUCKET_NAME, trigger_rule=TriggerRule.ALL_DONE
)

(
# TEST SETUP
create_bucket
# TEST BODY
>> text_to_speech_synthesize_task
# TEST TEARDOWN
>> delete_bucket
)

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()


from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

0 comments on commit dfe0f75

Please sign in to comment.