Skip to content

Commit

Permalink
Add OpenLineage support to GcsOperators - Delete, Transform and TimeS…
Browse files Browse the repository at this point in the history
…panTransform (#35838)
  • Loading branch information
kacpermuda committed Nov 27, 2023
1 parent 1e730f2 commit 99b68e2
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 10 deletions.
95 changes: 85 additions & 10 deletions airflow/providers/google/cloud/operators/gcs.py
Expand Up @@ -313,6 +313,7 @@ def __init__(
)
raise ValueError(err_message)

self._objects: list[str] = []
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
Expand All @@ -322,13 +323,47 @@ def execute(self, context: Context) -> None:
)

if self.objects is not None:
objects = self.objects
self._objects = self.objects
else:
objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix)
self.log.info("Deleting %s objects from %s", len(objects), self.bucket_name)
for object_name in objects:
self._objects = hook.list(bucket_name=self.bucket_name, prefix=self.prefix)
self.log.info("Deleting %s objects from %s", len(self._objects), self.bucket_name)
for object_name in self._objects:
hook.delete(bucket_name=self.bucket_name, object_name=object_name)

def get_openlineage_facets_on_complete(self, task_instance):
"""Implementing on_complete as execute() resolves object names."""
from openlineage.client.facet import (
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
LifecycleStateChangeDatasetFacetPreviousIdentifier,
)
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

if not self._objects:
return OperatorLineage()

bucket_url = f"gs://{self.bucket_name}"
input_datasets = [
Dataset(
namespace=bucket_url,
name=object_name,
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=bucket_url,
name=object_name,
),
)
},
)
for object_name in self._objects
]

return OperatorLineage(inputs=input_datasets)


class GCSBucketCreateAclEntryOperator(GoogleCloudBaseOperator):
"""
Expand Down Expand Up @@ -596,6 +631,22 @@ def execute(self, context: Context) -> None:
filename=destination_file.name,
)

def get_openlineage_facets_on_start(self):
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

input_dataset = Dataset(
namespace=f"gs://{self.source_bucket}",
name=self.source_object,
)
output_dataset = Dataset(
namespace=f"gs://{self.destination_bucket}",
name=self.destination_object,
)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])


class GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
"""
Expand Down Expand Up @@ -722,6 +773,9 @@ def __init__(
self.upload_continue_on_fail = upload_continue_on_fail
self.upload_num_attempts = upload_num_attempts

self._source_object_names: list[str] = []
self._destination_object_names: list[str] = []

def execute(self, context: Context) -> list[str]:
# Define intervals and prefixes.
try:
Expand Down Expand Up @@ -773,7 +827,7 @@ def execute(self, context: Context) -> list[str]:
)

# Fetch list of files.
blobs_to_transform = source_hook.list_by_timespan(
self._source_object_names = source_hook.list_by_timespan(
bucket_name=self.source_bucket,
prefix=source_prefix_interp,
timespan_start=timespan_start,
Expand All @@ -785,7 +839,7 @@ def execute(self, context: Context) -> list[str]:
temp_output_dir_path = Path(temp_output_dir)

# TODO: download in parallel.
for blob_to_transform in blobs_to_transform:
for blob_to_transform in self._source_object_names:
destination_file = temp_input_dir_path / blob_to_transform
destination_file.parent.mkdir(parents=True, exist_ok=True)
try:
Expand Down Expand Up @@ -822,8 +876,6 @@ def execute(self, context: Context) -> list[str]:

self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)

files_uploaded = []

# TODO: upload in parallel.
for upload_file in temp_output_dir_path.glob("**/*"):
if upload_file.is_dir():
Expand All @@ -844,12 +896,35 @@ def execute(self, context: Context) -> list[str]:
chunk_size=self.chunk_size,
num_max_attempts=self.upload_num_attempts,
)
files_uploaded.append(str(upload_file_name))
self._destination_object_names.append(str(upload_file_name))
except GoogleCloudError:
if not self.upload_continue_on_fail:
raise

return files_uploaded
return self._destination_object_names

def get_openlineage_facets_on_complete(self, task_instance):
"""Implementing on_complete as execute() resolves object names."""
from openlineage.client.run import Dataset

from airflow.providers.openlineage.extractors import OperatorLineage

input_datasets = [
Dataset(
namespace=f"gs://{self.source_bucket}",
name=object_name,
)
for object_name in self._source_object_names
]
output_datasets = [
Dataset(
namespace=f"gs://{self.destination_bucket}",
name=object_name,
)
for object_name in self._destination_object_names
]

return OperatorLineage(inputs=input_datasets, outputs=output_datasets)


class GCSDeleteBucketOperator(GoogleCloudBaseOperator):
Expand Down
160 changes: 160 additions & 0 deletions tests/providers/google/cloud/operators/test_gcs.py
Expand Up @@ -21,6 +21,13 @@
from pathlib import Path
from unittest import mock

from openlineage.client.facet import (
LifecycleStateChange,
LifecycleStateChangeDatasetFacet,
LifecycleStateChangeDatasetFacetPreviousIdentifier,
)
from openlineage.client.run import Dataset

from airflow.providers.google.cloud.operators.gcs import (
GCSBucketCreateAclEntryOperator,
GCSCreateBucketOperator,
Expand Down Expand Up @@ -164,6 +171,49 @@ def test_delete_prefix_as_empty_string(self, mock_hook):
any_order=True,
)

@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
def test_get_openlineage_facets_on_complete(self, mock_hook):
bucket_url = f"gs://{TEST_BUCKET}"
expected_inputs = [
Dataset(
namespace=bucket_url,
name="folder/a.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=bucket_url,
name="folder/a.txt",
),
)
},
),
Dataset(
namespace=bucket_url,
name="b.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=bucket_url,
name="b.txt",
),
)
},
),
]

operator = GCSDeleteObjectsOperator(
task_id=TASK_ID, bucket_name=TEST_BUCKET, objects=["folder/a.txt", "b.txt"]
)

operator.execute(None)

lineage = operator.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 2
assert len(lineage.outputs) == 0
assert lineage.inputs == expected_inputs


class TestGoogleCloudStorageListOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
Expand Down Expand Up @@ -251,6 +301,31 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile):
filename=destination,
)

def test_get_openlineage_facets_on_start(self):
expected_input = Dataset(
namespace=f"gs://{TEST_BUCKET}",
name="folder/a.txt",
)
expected_output = Dataset(
namespace=f"gs://{TEST_BUCKET}2",
name="b.txt",
)

operator = GCSFileTransformOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_object="folder/a.txt",
destination_bucket=f"{TEST_BUCKET}2",
destination_object="b.txt",
transform_script="/path/to_script",
)

lineage = operator.get_openlineage_facets_on_start()
assert len(lineage.inputs) == 1
assert len(lineage.outputs) == 1
assert lineage.inputs[0] == expected_input
assert lineage.outputs[0] == expected_output


class TestGCSTimeSpanFileTransformOperatorDateInterpolation:
def test_execute(self):
Expand Down Expand Up @@ -408,6 +483,91 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir):
]
)

@mock.patch("airflow.providers.google.cloud.operators.gcs.TemporaryDirectory")
@mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess")
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
def test_get_openlineage_facets_on_complete(self, mock_hook, mock_subprocess, mock_tempdir):
source_bucket = TEST_BUCKET
source_prefix = "source_prefix"

destination_bucket = TEST_BUCKET + "_dest"
destination_prefix = "destination_prefix"
destination = "destination"

file1 = "file1"
file2 = "file2"

timespan_start = datetime(2015, 2, 1, 15, 16, 17, 345, tzinfo=timezone.utc)
mock_dag = mock.Mock()
mock_dag.following_schedule = lambda x: x + timedelta(hours=1)
context = dict(
execution_date=timespan_start,
dag=mock_dag,
ti=mock.Mock(),
)

mock_tempdir.return_value.__enter__.side_effect = ["source", destination]
mock_hook.return_value.list_by_timespan.return_value = [
f"{source_prefix}/{file1}",
f"{source_prefix}/{file2}",
]

mock_proc = mock.MagicMock()
mock_proc.returncode = 0
mock_proc.stdout.readline = lambda: b""
mock_proc.wait.return_value = None
mock_popen = mock.MagicMock()
mock_popen.return_value.__enter__.return_value = mock_proc

mock_subprocess.Popen = mock_popen
mock_subprocess.PIPE = "pipe"
mock_subprocess.STDOUT = "stdout"

op = GCSTimeSpanFileTransformOperator(
task_id=TASK_ID,
source_bucket=source_bucket,
source_prefix=source_prefix,
source_gcp_conn_id="",
destination_bucket=destination_bucket,
destination_prefix=destination_prefix,
destination_gcp_conn_id="",
transform_script="script.py",
)

with mock.patch.object(Path, "glob") as path_glob:
path_glob.return_value.__iter__.return_value = [
Path(f"{destination}/{file1}"),
Path(f"{destination}/{file2}"),
]
op.execute(context=context)

expected_inputs = [
Dataset(
namespace=f"gs://{source_bucket}",
name=f"{source_prefix}/{file1}",
),
Dataset(
namespace=f"gs://{source_bucket}",
name=f"{source_prefix}/{file2}",
),
]
expected_outputs = [
Dataset(
namespace=f"gs://{destination_bucket}",
name=f"{destination_prefix}/{file1}",
),
Dataset(
namespace=f"gs://{destination_bucket}",
name=f"{destination_prefix}/{file2}",
),
]

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 2
assert len(lineage.outputs) == 2
assert lineage.inputs == expected_inputs
assert lineage.outputs == expected_outputs


class TestGCSDeleteBucketOperator:
@mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook")
Expand Down

0 comments on commit 99b68e2

Please sign in to comment.