Skip to content

Commit

Permalink
use open to read and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rajaths010494 committed Dec 20, 2022
1 parent 7f87c8a commit 71e9249
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 4 deletions.
9 changes: 5 additions & 4 deletions astronomer/providers/amazon/aws/xcom_backends/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pickle # nosec
import uuid
from datetime import date, datetime
from io import BytesIO
from typing import Any

import pandas as pd
Expand Down Expand Up @@ -88,9 +87,11 @@ def download_and_read_value(filename: str) -> Any:
"""Download the file from S3"""
# Here we download the file from S3
hook = S3Hook(aws_conn_id=_S3XComBackend.AWS_CONN_ID)
f = BytesIO()
hook.get_conn().download_fileobj(_S3XComBackend.BUCKET_NAME, filename, f)
data = f.getvalue()
file = hook.download_file(
bucket_name=_S3XComBackend.BUCKET_NAME, key=filename, preserve_file_name=True
)
with open(file, "rb") as f:
data = f.read()
if filename.endswith(".gz"):
data = gzip.decompress(data)
filename = filename.replace(".gz", "")
Expand Down
105 changes: 105 additions & 0 deletions tests/amazon/aws/xcom_backends/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import contextlib
import gzip
import json
import os
import pickle # nosec
from datetime import datetime
from unittest import mock

Expand All @@ -11,6 +13,7 @@
from airflow import settings
from airflow.configuration import conf
from airflow.models.xcom import BaseXCom
from pandas.util.testing import assert_frame_equal

from astronomer.providers.amazon.aws.xcom_backends.s3 import (
S3XComBackend,
Expand Down Expand Up @@ -150,6 +153,108 @@ def test_custom_xcom_s3_deserialize(mock_download, job_id):
assert result == job_id


@pytest.mark.parametrize("job_id", ["gcs_xcom_1234"])
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value(mock_open, mock_download, job_id):
    """
    Check that a JSON-encoded string read from the (mocked) downloaded file
    is returned by ``download_and_read_value`` as the original string.
    """
    # The stubbed S3 download hands back the key itself as the local file path.
    mock_download.return_value = job_id
    # A single fake file handle whose read() yields the JSON payload.
    fake_handle = mock.mock_open(read_data=json.dumps(job_id)).return_value
    mock_open.side_effect = [fake_handle]

    backend = _S3XComBackend()
    value = backend.download_and_read_value(job_id)

    assert value == job_id


@pytest.mark.parametrize(
    "job_id",
    ["gcs_xcom_1234_dataframe"],
)
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_pandas(mock_open, mock_download, job_id):
    """
    Check that a DataFrame serialized with ``to_json`` and read from the
    (mocked) downloaded file comes back from ``download_and_read_value`` as
    an equal DataFrame.
    """
    # Build the expected frame once so the mocked payload and the assertion
    # are guaranteed to describe the same data.
    expected = pd.DataFrame({"numbers": [1], "colors": ["red"]})
    mock_open.side_effect = [mock.mock_open(read_data=expected.to_json()).return_value]
    # The stubbed S3 download hands back the key itself as the local file path.
    mock_download.return_value = job_id
    result = _S3XComBackend().download_and_read_value(job_id)
    # Use the public ``pandas.testing`` API; ``pandas.util.testing`` is
    # deprecated and no longer available in pandas 2.x.
    pd.testing.assert_frame_equal(result, expected)


@pytest.mark.parametrize("job_id", ["gcs_xcom_1234_datetime"])
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_datetime(mock_open, mock_download, job_id):
    """
    Check that an ISO-8601 timestamp read from the (mocked) downloaded file
    is returned by ``download_and_read_value`` as the matching ``datetime``.
    """
    expected = datetime.now()
    # The stubbed S3 download hands back the key itself as the local file path.
    mock_download.return_value = job_id
    # A single fake file handle whose read() yields the ISO-formatted timestamp.
    fake_handle = mock.mock_open(read_data=expected.isoformat()).return_value
    mock_open.side_effect = [fake_handle]

    backend = _S3XComBackend()
    value = backend.download_and_read_value(job_id)

    assert value == expected


@conf_vars({("core", "enable_xcom_pickling"): "True"})
@pytest.mark.parametrize("job_id", ["gcs_xcom_1234"])
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_pickle(mock_open, mock_download, job_id):
    """
    Check that, with ``core.enable_xcom_pickling`` set to ``"True"``, a pickled
    string read from the (mocked) downloaded file is returned by
    ``download_and_read_value`` as the original string.
    """
    # The stubbed S3 download hands back the key itself as the local file path.
    mock_download.return_value = job_id
    # A single fake file handle whose read() yields the pickled payload.
    pickled_payload = pickle.dumps(job_id)
    fake_handle = mock.mock_open(read_data=pickled_payload).return_value
    mock_open.side_effect = [fake_handle]

    backend = _S3XComBackend()
    value = backend.download_and_read_value(job_id)

    assert value == job_id


@pytest.mark.parametrize("job_id", ["gcs_xcom_1234"])
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_bytes(mock_open, mock_download, job_id):
    """
    Check that raw JSON bytes read from the (mocked) downloaded file are
    returned by ``download_and_read_value`` as the decoded dictionary.
    """
    # The stubbed S3 download hands back the key itself as the local file path.
    mock_download.return_value = job_id
    # A single fake file handle whose read() yields the JSON document as bytes.
    raw_payload = b'{ "Class": "Email addresses"}'
    fake_handle = mock.mock_open(read_data=raw_payload).return_value
    mock_open.side_effect = [fake_handle]

    backend = _S3XComBackend()
    value = backend.download_and_read_value(job_id)

    assert value == {"Class": "Email addresses"}


@pytest.mark.parametrize("job_id", ["gcs_xcom_1234.gz"])
@mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.download_file")
@mock.patch("builtins.open", create=True)
def test_custom_xcom_s3_download_and_read_value_gzip(mock_open, mock_download, job_id):
    """
    Check that gzip-compressed JSON bytes read from the (mocked) downloaded
    file, stored under a key ending in ``.gz``, are returned by
    ``download_and_read_value`` as the decoded dictionary.
    """
    # The stubbed S3 download hands back the key itself as the local file path.
    mock_download.return_value = job_id
    # A single fake file handle whose read() yields the compressed JSON document.
    compressed_payload = gzip.compress(b'{"Class": "Email addresses"}')
    fake_handle = mock.mock_open(read_data=compressed_payload).return_value
    mock_open.side_effect = [fake_handle]

    backend = _S3XComBackend()
    value = backend.download_and_read_value(job_id)

    assert value == {"Class": "Email addresses"}


def test_custom_xcom_s3_orm_deserialize_value():
"""
Asserts that custom xcom has called the orm deserialized
Expand Down

0 comments on commit 71e9249

Please sign in to comment.