
Commit

Add OpenLineage support to S3Operators - Copy, Delete and Create Object (#35796)
kacpermuda committed Nov 22, 2023
1 parent fcb91f4 commit 9e159fc
Showing 2 changed files with 232 additions and 0 deletions.
86 changes: 86 additions & 0 deletions airflow/providers/amazon/aws/operators/s3.py
@@ -321,6 +321,33 @@ def execute(self, context: Context):
            self.acl_policy,
        )

    def get_openlineage_facets_on_start(self):
        from openlineage.client.run import Dataset

        from airflow.providers.openlineage.extractors import OperatorLineage

        dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
            self.dest_bucket_name, self.dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
        )

        source_bucket_name, source_bucket_key = S3Hook.get_s3_bucket_key(
            self.source_bucket_name, self.source_bucket_key, "source_bucket_name", "source_bucket_key"
        )

        input_dataset = Dataset(
            namespace=f"s3://{source_bucket_name}",
            name=source_bucket_key,
        )
        output_dataset = Dataset(
            namespace=f"s3://{dest_bucket_name}",
            name=dest_bucket_key,
        )

        return OperatorLineage(
            inputs=[input_dataset],
            outputs=[output_dataset],
        )
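
For illustration, a minimal usage sketch of the new method on S3CopyObjectOperator (the task id, bucket, and key names below are hypothetical):

# Hypothetical sketch: what the copy operator now reports to OpenLineage.
from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator

copy = S3CopyObjectOperator(
    task_id="copy_report",                  # assumed name
    source_bucket_name="raw-data",          # assumed bucket/key names
    source_bucket_key="reports/2023.csv",
    dest_bucket_name="curated-data",
    dest_bucket_key="reports/2023.csv",
)

lineage = copy.get_openlineage_facets_on_start()
# lineage.inputs  -> [Dataset(namespace="s3://raw-data", name="reports/2023.csv")]
# lineage.outputs -> [Dataset(namespace="s3://curated-data", name="reports/2023.csv")]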


class S3CreateObjectOperator(BaseOperator):
    """
@@ -409,6 +436,22 @@ def execute(self, context: Context):
        else:
            s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy)

    def get_openlineage_facets_on_start(self):
        from openlineage.client.run import Dataset

        from airflow.providers.openlineage.extractors import OperatorLineage

        bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")

        output_dataset = Dataset(
            namespace=f"s3://{bucket}",
            name=key,
        )

        return OperatorLineage(
            outputs=[output_dataset],
        )
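
Similarly, a minimal sketch (names hypothetical) for the create operator, which reports only an output dataset; the key may also be passed as a full s3:// URL with s3_bucket left unset, as exercised by the tests below:

# Hypothetical sketch: the create operator emits a single output dataset.
from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator

create = S3CreateObjectOperator(
    task_id="write_marker",          # assumed name
    s3_bucket="curated-data",        # assumed bucket/key names
    s3_key="markers/_SUCCESS",
    data="done",
)

lineage = create.get_openlineage_facets_on_start()
# lineage.inputs  -> []
# lineage.outputs -> [Dataset(namespace="s3://curated-data", name="markers/_SUCCESS")]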


class S3DeleteObjectsOperator(BaseOperator):
    """
@@ -462,6 +505,8 @@ def __init__(
        self.aws_conn_id = aws_conn_id
        self.verify = verify

        self._keys: str | list[str] = ""

        if not exactly_one(prefix is None, keys is None):
            raise AirflowException("Either keys or prefix should be set.")

@@ -476,6 +521,47 @@ def execute(self, context: Context):
        keys = self.keys or s3_hook.list_keys(bucket_name=self.bucket, prefix=self.prefix)
        if keys:
            s3_hook.delete_objects(bucket=self.bucket, keys=keys)
            self._keys = keys

    def get_openlineage_facets_on_complete(self, task_instance):
        """Implement _on_complete because object keys are resolved in execute()."""
        from openlineage.client.facet import (
            LifecycleStateChange,
            LifecycleStateChangeDatasetFacet,
            LifecycleStateChangeDatasetFacetPreviousIdentifier,
        )
        from openlineage.client.run import Dataset

        from airflow.providers.openlineage.extractors import OperatorLineage

        if not self._keys:
            return OperatorLineage()

        keys = self._keys
        if isinstance(keys, str):
            keys = [keys]

        bucket_url = f"s3://{self.bucket}"
        input_datasets = [
            Dataset(
                namespace=bucket_url,
                name=key,
                facets={
                    "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
                        lifecycleStateChange=LifecycleStateChange.DROP.value,
                        previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
                            namespace=bucket_url,
                            name=key,
                        ),
                    )
                },
            )
            for key in keys
        ]

        return OperatorLineage(
            inputs=input_datasets,
        )
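
Because the deleted keys are only known after execute() resolves them (possibly from a prefix), lineage for the delete operator is reported on completion rather than on start. A minimal sketch, with hypothetical names and assuming real or mocked S3 access:

# Hypothetical sketch: deleted keys become input datasets marked as dropped.
from airflow.providers.amazon.aws.operators.s3 import S3DeleteObjectsOperator

delete = S3DeleteObjectsOperator(
    task_id="cleanup_tmp",                      # assumed name
    bucket="curated-data",                      # assumed bucket/key names
    keys=["tmp/part-0.csv", "tmp/part-1.csv"],
)
delete.execute(None)  # deletes the objects; needs S3 access (the tests below use moto)

lineage = delete.get_openlineage_facets_on_complete(None)
# lineage.inputs -> two Dataset entries, one per deleted key; each carries a
# lifecycleStateChange facet with LifecycleStateChange.DROP.value and the
# key itself as previousIdentifier.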


class S3FileTransformOperator(BaseOperator):
146 changes: 146 additions & 0 deletions tests/providers/amazon/aws/operators/test_s3.py
@@ -28,6 +28,12 @@
import boto3
import pytest
from moto import mock_s3
from openlineage.client.facet import (
    LifecycleStateChange,
    LifecycleStateChangeDatasetFacet,
    LifecycleStateChangeDatasetFacetPreviousIdentifier,
)
from openlineage.client.run import Dataset

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -44,6 +50,7 @@
    S3ListPrefixesOperator,
    S3PutBucketTaggingOperator,
)
from airflow.providers.openlineage.extractors import OperatorLineage

BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-bucket")
S3_KEY = "test-airflow-key"
@@ -409,6 +416,55 @@ def test_s3_copy_object_arg_combination_2(self):
        # the object found should be consistent with dest_key specified earlier
        assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key

    def test_get_openlineage_facets_on_start_combination_1(self):
        expected_input = Dataset(
            namespace=f"s3://{self.source_bucket}",
            name=self.source_key,
        )
        expected_output = Dataset(
            namespace=f"s3://{self.dest_bucket}",
            name=self.dest_key,
        )

        op = S3CopyObjectOperator(
            task_id="test",
            source_bucket_name=self.source_bucket,
            source_bucket_key=self.source_key,
            dest_bucket_name=self.dest_bucket,
            dest_bucket_key=self.dest_key,
        )

        lineage = op.get_openlineage_facets_on_start()
        assert len(lineage.inputs) == 1
        assert len(lineage.outputs) == 1
        assert lineage.inputs[0] == expected_input
        assert lineage.outputs[0] == expected_output

    def test_get_openlineage_facets_on_start_combination_2(self):
        expected_input = Dataset(
            namespace=f"s3://{self.source_bucket}",
            name=self.source_key,
        )
        expected_output = Dataset(
            namespace=f"s3://{self.dest_bucket}",
            name=self.dest_key,
        )

        source_key_s3_url = f"s3://{self.source_bucket}/{self.source_key}"
        dest_key_s3_url = f"s3://{self.dest_bucket}/{self.dest_key}"

        op = S3CopyObjectOperator(
            task_id="test",
            source_bucket_key=source_key_s3_url,
            dest_bucket_key=dest_key_s3_url,
        )

        lineage = op.get_openlineage_facets_on_start()
        assert len(lineage.inputs) == 1
        assert len(lineage.outputs) == 1
        assert lineage.inputs[0] == expected_input
        assert lineage.outputs[0] == expected_output


@mock_s3
class TestS3DeleteObjectsOperator:
@@ -575,6 +631,82 @@ def test_validate_keys_and_prefix_in_execute(self, keys, prefix):
        # the object found should be consistent with dest_key specified earlier
        assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test

@pytest.mark.parametrize("keys", ("path/data.txt", ["path/data.txt"]))
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys):
bucket = "testbucket"
expected_input = Dataset(
namespace=f"s3://{bucket}",
name="path/data.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=f"s3://{bucket}",
name="path/data.txt",
),
)
},
)

op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 1
assert lineage.inputs[0] == expected_input

@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
bucket = "testbucket"
keys = ["path/data1.txt", "path/data2.txt"]
expected_inputs = [
Dataset(
namespace=f"s3://{bucket}",
name="path/data1.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=f"s3://{bucket}",
name="path/data1.txt",
),
)
},
),
Dataset(
namespace=f"s3://{bucket}",
name="path/data2.txt",
facets={
"lifecycleStateChange": LifecycleStateChangeDatasetFacet(
lifecycleStateChange=LifecycleStateChange.DROP.value,
previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
namespace=f"s3://{bucket}",
name="path/data2.txt",
),
)
},
),
]

op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 2
assert lineage.inputs == expected_inputs

@pytest.mark.parametrize("keys", ("", []))
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_get_openlineage_facets_on_complete_no_objects(self, mock_hook, keys):
op = S3DeleteObjectsOperator(
task_id="test_task_s3_delete_single_object", bucket="testbucket", keys=keys
)
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
assert lineage == OperatorLineage()


class TestS3CreateObjectOperator:
    @mock.patch.object(S3Hook, "load_string")
@@ -614,3 +746,17 @@ def test_execute_if_s3_bucket_not_provided(self, mock_load_string):
        operator.execute(None)

        mock_load_string.assert_called_once_with(data, S3_KEY, BUCKET_NAME, False, False, None, None, None)

@pytest.mark.parametrize(("bucket", "key"), (("bucket", "file.txt"), (None, "s3://bucket/file.txt")))
def test_get_openlineage_facets_on_start(self, bucket, key):
expected_output = Dataset(
namespace="s3://bucket",
name="file.txt",
)

op = S3CreateObjectOperator(task_id="test", s3_bucket=bucket, s3_key=key, data="test")

lineage = op.get_openlineage_facets_on_start()
assert len(lineage.inputs) == 0
assert len(lineage.outputs) == 1
assert lineage.outputs[0] == expected_output
