Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom backend XCom S3 #820

Merged
merged 4 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Empty file.
104 changes: 104 additions & 0 deletions astronomer/providers/amazon/aws/xcom_backends/s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import gzip
import json
import os
import pickle # nosec
import uuid
from datetime import date, datetime
from typing import Any

import pandas as pd
from airflow.configuration import conf
from airflow.models.xcom import BaseXCom
from airflow.providers.amazon.aws.hooks.s3 import S3Hook


class S3XComBackend(BaseXCom):
    """
    Custom XCom backend that stores serialized XCom payloads in S3.

    Wraps ``BaseXCom`` so that common data types are uploaded to an S3 bucket
    on serialization and fetched back on deserialization. This overrides the
    ``TaskInstance.XCom`` object with this wrapper.
    """

    @staticmethod
    def serialize_value(value: Any) -> Any:  # type: ignore[override]
        """Upload ``value`` to S3 and serialize the returned key via ``BaseXCom``."""
        uploaded_key = _S3XComBackend.write_and_upload_value(value)
        return BaseXCom.serialize_value(uploaded_key)

    @staticmethod
    def deserialize_value(result: Any) -> Any:
        """Deserialize via ``BaseXCom`` and, for S3-backed keys, download the stored value."""
        deserialized = BaseXCom.deserialize_value(result)
        if isinstance(deserialized, str) and deserialized.startswith(_S3XComBackend.PREFIX):
            return _S3XComBackend.download_and_read_value(deserialized)
        return deserialized

    def orm_deserialize_value(self) -> str:
        """
        Deserialize method used to reconstruct the ORM XCom object.

        Overridden to avoid unnecessary requests or other resource-consuming
        operations when creating the XCom ORM model.
        """
        return f"XCOM is uploaded into S3 bucket: {_S3XComBackend.BUCKET_NAME}"


class _S3XComBackend:
    """
    Custom XCom persistence class extends base to support various datatypes.
    To use this XCom Backend, add the environment variable `AIRFLOW__CORE__XCOM_BACKEND`
    to your environment and set it to
    `astronomer.providers.amazon.aws.xcom_backends.s3.S3XComBackend`
    """

    # S3 key prefix used to recognize keys produced by this backend.
    PREFIX = os.getenv("XCOM_BACKEND_PREFIX", "s3_xcom_")
    # Airflow connection id used to reach S3.
    AWS_CONN_ID = os.getenv("CONNECTION_NAME", "aws_default")
    # Bucket that stores the XCom payloads.
    BUCKET_NAME = os.getenv("XCOM_BACKEND_BUCKET_NAME", "airflow_xcom_backend_default_bucket")
    # Parse the env var as a real boolean. ``os.getenv(..., False)`` returned the raw
    # string, so ANY non-empty value -- including "False" -- enabled gzip uploads.
    UPLOAD_CONTENT_AS_GZIP = os.getenv("XCOM_BACKEND_UPLOAD_CONTENT_AS_GZIP", "").strip().lower() in (
        "1",
        "true",
        "yes",
    )
    # Key suffixes marking the serialization format of the stored payload.
    PANDAS_DATAFRAME = "dataframe"
    DATETIME_OBJECT = "datetime"

    @staticmethod
    def write_and_upload_value(value: Any) -> str:
        """
        Serialize ``value``, upload it to the configured S3 bucket and return the key.

        DataFrames and ``date``/``datetime`` objects get a type suffix on the key so
        that :meth:`download_and_read_value` can pick the matching deserializer.
        """
        key_str = f"{_S3XComBackend.PREFIX}{uuid.uuid4()}"
        hook = S3Hook(aws_conn_id=_S3XComBackend.AWS_CONN_ID)
        if conf.getboolean("core", "enable_xcom_pickling"):
            # NOTE(review): pickle.dumps returns ``bytes`` while S3Hook.load_string
            # expects ``str`` -- confirm, or consider switching to ``hook.load_bytes``.
            value = pickle.dumps(value)
        elif isinstance(value, pd.DataFrame):
            value = value.to_json()
            key_str = f"{key_str}_{_S3XComBackend.PANDAS_DATAFRAME}"
        elif isinstance(value, date):
            # ``date``/``datetime`` round-trip through ISO-8601 text.
            key_str = f"{key_str}_{_S3XComBackend.DATETIME_OBJECT}"
            value = value.isoformat()
        else:
            value = json.dumps(value)
        if _S3XComBackend.UPLOAD_CONTENT_AS_GZIP:
            key_str = f"{key_str}.gz"
            hook.load_string(
                bucket_name=_S3XComBackend.BUCKET_NAME, key=key_str, string_data=value, compression="gzip"
            )
        else:
            hook.load_string(bucket_name=_S3XComBackend.BUCKET_NAME, key=key_str, string_data=value)
        return key_str

    @staticmethod
    def download_and_read_value(filename: str) -> Any:
        """
        Download ``filename`` from S3 and deserialize it based on the key suffix.
        """
        # Here we download the file from S3
        hook = S3Hook(aws_conn_id=_S3XComBackend.AWS_CONN_ID)
        file = hook.download_file(
            bucket_name=_S3XComBackend.BUCKET_NAME, key=filename, preserve_file_name=True
        )
        with open(file, "rb") as f:
            data = f.read()
        if filename.endswith(".gz"):
            data = gzip.decompress(data)
            filename = filename.replace(".gz", "")
        if conf.getboolean("core", "enable_xcom_pickling"):
            return pickle.loads(data)  # nosec
        elif filename.endswith(_S3XComBackend.PANDAS_DATAFRAME):
            return pd.read_json(data)
        elif filename.endswith(_S3XComBackend.DATETIME_OBJECT):
            # The file is opened in binary mode, so ``data`` is normally ``bytes``;
            # the previous ``str(data)`` produced "b'...'" which is not valid
            # ISO-8601 and made ``fromisoformat`` raise ValueError.
            iso_value = data.decode("utf-8") if isinstance(data, bytes) else data
            return datetime.fromisoformat(iso_value)
        return json.loads(data)
Empty file.
264 changes: 264 additions & 0 deletions tests/amazon/aws/xcom_backends/test_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
from __future__ import annotations

import contextlib
import gzip
import json
import os
import pickle # nosec
from datetime import datetime
from unittest import mock

import pandas as pd
import pytest
from airflow import settings
from airflow.configuration import conf
from airflow.models.xcom import BaseXCom
from pandas.util.testing import assert_frame_equal

from astronomer.providers.amazon.aws.xcom_backends.s3 import (
S3XComBackend,
_S3XComBackend,
)


@contextlib.contextmanager
def conf_vars(overrides):
    """
    Temporarily apply Airflow config ``overrides`` (a ``{(section, key): value}``
    mapping; a ``None`` value removes the option), restoring both the config and
    any shadowing environment variables on exit.
    """
    original = {}
    original_env_vars = {}
    for (section, key), value in overrides.items():

        # Environment variables take precedence over airflow.cfg, so pop any
        # matching AIRFLOW__<SECTION>__<KEY> variable for the duration of the patch.
        env = conf._env_var_name(section, key)
        if env in os.environ:
            original_env_vars[env] = os.environ.pop(env)

        if conf.has_option(section, key):
            original[(section, key)] = conf.get(section, key)
        else:
            # Remember the option did not exist so it can be removed on exit.
            original[(section, key)] = None
        if value is not None:
            if not conf.has_section(section):
                conf.add_section(section)
            conf.set(section, key, value)
        else:
            conf.remove_option(section, key)
    # Re-derive settings that are computed from the (now patched) config.
    settings.configure_vars()
    try:
        yield
    finally:
        # Restore previous config values and environment variables.
        for (section, key), value in original.items():
            if value is not None:
                conf.set(section, key, value)
            else:
                conf.remove_option(section, key)
        for env, value in original_env_vars.items():
            os.environ[env] = value
        settings.configure_vars()


@mock.patch("astronomer.providers.amazon.aws.xcom_backends.s3._S3XComBackend.write_and_upload_value")
def test_custom_xcom_s3_serialize(write_mock):
    """
    Assert that the S3 backend serializes the uploaded key through BaseXCom.
    """
    uploaded_key = "12345_hash"
    write_mock.return_value = uploaded_key
    serialized = S3XComBackend.serialize_value(uploaded_key)
    assert serialized == json.dumps(uploaded_key).encode("UTF-8")


@pytest.mark.parametrize(
    "xcom_value",
    ["1234567890", {"a": "b"}, ["123"]],
)
@mock.patch("uuid.uuid4")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_string")
def test_custom_xcom_s3_write_and_upload(load_string_mock, uuid_mock, xcom_value):
    """
    Assert that the value is uploaded to S3 and the generated key is returned.
    """
    uuid_mock.return_value = "12345667890"
    key = _S3XComBackend().write_and_upload_value(xcom_value)
    assert key == "s3_xcom_12345667890"


@conf_vars({("core", "enable_xcom_pickling"): "True"})
@pytest.mark.parametrize(
    "xcom_value",
    ["1234567890", {"a": "b"}, ["123"]],
)
@mock.patch("uuid.uuid4")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_string")
def test_custom_xcom_s3_write_and_upload_pickle(load_string_mock, uuid_mock, xcom_value):
    """
    Assert that pickled data is uploaded and the generated key is returned.
    """
    uuid_mock.return_value = "12345667890"
    key = _S3XComBackend().write_and_upload_value(xcom_value)
    assert key == "s3_xcom_12345667890"


@mock.patch("uuid.uuid4")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_string")
def test_custom_xcom_s3_write_and_upload_pandas(load_string_mock, uuid_mock):
    """
    Assert that a DataFrame is uploaded and the key carries the dataframe suffix.
    """
    uuid_mock.return_value = "12345667890"
    frame = pd.DataFrame({"numbers": [1], "colors": ["red"]})
    key = _S3XComBackend().write_and_upload_value(frame)
    assert key == "s3_xcom_12345667890_dataframe"


@mock.patch("uuid.uuid4")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_string")
def test_custom_xcom_s3_write_and_upload_datetime(load_string_mock, uuid_mock):
    """
    Assert that a datetime value is uploaded and the key carries the datetime suffix.
    """
    uuid_mock.return_value = "12345667890"
    key = _S3XComBackend().write_and_upload_value(datetime.now())
    assert key == "s3_xcom_12345667890_datetime"


@pytest.mark.parametrize(
    "job_id",
    ["1234567890", {"a": "b"}, ["123"]],
)
@mock.patch("uuid.uuid4")
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.load_string")
def test_custom_xcom_s3_write_and_upload_as_gzip(mock_upload, mock_uuid, job_id):
    """
    Asserts that custom xcom as gzip is uploaded and returns the key
    """
    mock_uuid.return_value = "12345667890"
    # Set the class flag explicitly instead of assigning an (unstarted)
    # ``mock.patch.dict`` patcher object -- which was only truthy by accident --
    # and restore it afterwards so the flag does not leak into other tests.
    original_flag = _S3XComBackend.UPLOAD_CONTENT_AS_GZIP
    _S3XComBackend.UPLOAD_CONTENT_AS_GZIP = True
    try:
        result = _S3XComBackend().write_and_upload_value(job_id)
    finally:
        _S3XComBackend.UPLOAD_CONTENT_AS_GZIP = original_flag
    assert result == "s3_xcom_" + "12345667890.gz"


@pytest.mark.parametrize(
    "xcom_key",
    ["s3_xcom__1234"],
)
@mock.patch("astronomer.providers.amazon.aws.xcom_backends.s3._S3XComBackend.download_and_read_value")
def test_custom_xcom_s3_deserialize(download_mock, xcom_key):
    """
    Assert that a stored S3 key is resolved back into the downloaded value.
    """
    download_mock.return_value = xcom_key
    xcom_row = BaseXCom(value=json.dumps(xcom_key).encode("UTF-8"))
    assert S3XComBackend.deserialize_value(xcom_row) == xcom_key


@pytest.mark.parametrize(
    "xcom_key",
    ["gcs_xcom_1234"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value(open_mock, download_mock, xcom_key):
    """
    Assert that a plain JSON payload downloaded from S3 is deserialized.
    """
    open_mock.side_effect = [mock.mock_open(read_data=json.dumps(xcom_key)).return_value]
    download_mock.return_value = xcom_key
    assert _S3XComBackend().download_and_read_value(xcom_key) == xcom_key


@pytest.mark.parametrize(
    "xcom_key",
    ["gcs_xcom_1234_dataframe"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_pandas(open_mock, download_mock, xcom_key):
    """
    Assert that a dataframe-suffixed payload is deserialized into a DataFrame.
    """
    expected_frame = pd.DataFrame({"numbers": [1], "colors": ["red"]})
    open_mock.side_effect = [mock.mock_open(read_data=expected_frame.to_json()).return_value]
    download_mock.return_value = xcom_key
    result = _S3XComBackend().download_and_read_value(xcom_key)
    assert_frame_equal(result, expected_frame)


@pytest.mark.parametrize(
    "xcom_key",
    ["gcs_xcom_1234_datetime"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_datetime(open_mock, download_mock, xcom_key):
    """
    Assert that a datetime-suffixed payload is parsed back into a datetime.
    """
    now = datetime.now()
    open_mock.side_effect = [mock.mock_open(read_data=now.isoformat()).return_value]
    download_mock.return_value = xcom_key
    assert _S3XComBackend().download_and_read_value(xcom_key) == now


@conf_vars({("core", "enable_xcom_pickling"): "True"})
@pytest.mark.parametrize(
    "xcom_key",
    ["gcs_xcom_1234"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_pickle(open_mock, download_mock, xcom_key):
    """
    Assert that a pickled payload is unpickled when xcom pickling is enabled.
    """
    open_mock.side_effect = [mock.mock_open(read_data=pickle.dumps(xcom_key)).return_value]
    download_mock.return_value = xcom_key
    assert _S3XComBackend().download_and_read_value(xcom_key) == xcom_key


@pytest.mark.parametrize(
    "xcom_key",
    ["gcs_xcom_1234"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_bytes(open_mock, download_mock, xcom_key):
    """
    Assert that a JSON payload read as bytes is deserialized into a dict.
    """
    open_mock.side_effect = [mock.mock_open(read_data=b'{ "Class": "Email addresses"}').return_value]
    download_mock.return_value = xcom_key
    assert _S3XComBackend().download_and_read_value(xcom_key) == {"Class": "Email addresses"}


@pytest.mark.parametrize(
    "xcom_key",
    ["gcs_xcom_1234.gz"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_gzip(open_mock, download_mock, xcom_key):
    """
    Assert that a gzip-compressed payload is decompressed before deserializing.
    """
    compressed = gzip.compress(b'{"Class": "Email addresses"}')
    open_mock.side_effect = [mock.mock_open(read_data=compressed).return_value]
    download_mock.return_value = xcom_key
    assert _S3XComBackend().download_and_read_value(xcom_key) == {"Class": "Email addresses"}


def test_custom_xcom_s3_orm_deserialize_value():
    """
    Assert that orm_deserialize_value returns the bucket placeholder string
    instead of downloading the payload.
    """
    expected = "XCOM is uploaded into S3 bucket: airflow_xcom_backend_default_bucket"
    assert S3XComBackend().orm_deserialize_value() == expected