Skip to content

Commit

Permalink
Add encoding parameter to GCSToLocalFilesystemOperator to fix #20901 (#20919)
Browse files Browse the repository at this point in the history

* Fixes #20901

Adds a `file_encoding` parameter to `GCSToLocalFilesystemOperator` that is used to decode `file_bytes` into a serializable string for XCom.
  • Loading branch information
danneaves-ee committed Jan 19, 2022
1 parent b171e03 commit b8526ab
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
8 changes: 7 additions & 1 deletion airflow/providers/google/cloud/transfers/gcs_to_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class GCSToLocalFilesystemOperator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
:param file_encoding: Optional encoding used to decode file_bytes into a serializable
string that is suitable for storing to XCom; defaults to 'utf-8' (templated).
:type file_encoding: str
"""

template_fields: Sequence[str] = (
Expand All @@ -80,6 +83,7 @@ class GCSToLocalFilesystemOperator(BaseOperator):
'filename',
'store_to_xcom_key',
'impersonation_chain',
'file_encoding',
)
ui_color = '#f0eee4'

Expand All @@ -94,6 +98,7 @@ def __init__(
google_cloud_storage_conn_id: Optional[str] = None,
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
file_encoding: Optional[str] = 'utf-8',
**kwargs,
) -> None:
# To preserve backward compatibility
Expand Down Expand Up @@ -126,6 +131,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.file_encoding = file_encoding

def execute(self, context: 'Context'):
self.log.info('Executing download: %s, %s, %s', self.bucket, self.object_name, self.filename)
Expand All @@ -139,7 +145,7 @@ def execute(self, context: 'Context'):
file_size = hook.get_size(bucket_name=self.bucket, object_name=self.object_name)
if file_size < MAX_XCOM_SIZE:
file_bytes = hook.download(bucket_name=self.bucket, object_name=self.object_name)
context['ti'].xcom_push(key=self.store_to_xcom_key, value=str(file_bytes))
context['ti'].xcom_push(key=self.store_to_xcom_key, value=str(file_bytes, self.file_encoding))
else:
raise AirflowException('The size of the downloaded file is too large to push to XCom!')
else:
Expand Down
34 changes: 30 additions & 4 deletions tests/providers/google/cloud/transfers/test_gcs_to_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
TEST_OBJECT = "dir1/test-object"
LOCAL_FILE_PATH = "/home/airflow/gcp/test-object"
XCOM_KEY = "some_xkom_key"
FILE_CONTENT = "some file content"
FILE_CONTENT_STR = "some file content"
FILE_CONTENT_BYTES_UTF8 = b"some file content"
FILE_CONTENT_BYTES_UTF16 = (
b'\xff\xfes\x00o\x00m\x00e\x00 \x00f\x00i\x00l\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00n\x00t\x00'
)


class TestGoogleCloudStorageDownloadOperator(unittest.TestCase):
Expand All @@ -61,7 +65,7 @@ def test_size_lt_max_xcom_size(self, mock_hook):
store_to_xcom_key=XCOM_KEY,
)
context = {"ti": MagicMock()}
mock_hook.return_value.download.return_value = FILE_CONTENT
mock_hook.return_value.download.return_value = FILE_CONTENT_BYTES_UTF8
mock_hook.return_value.get_size.return_value = MAX_XCOM_SIZE - 1

operator.execute(context=context)
Expand All @@ -71,7 +75,7 @@ def test_size_lt_max_xcom_size(self, mock_hook):
mock_hook.return_value.download.assert_called_once_with(
bucket_name=TEST_BUCKET, object_name=TEST_OBJECT
)
context["ti"].xcom_push.assert_called_once_with(key=XCOM_KEY, value=FILE_CONTENT)
context["ti"].xcom_push.assert_called_once_with(key=XCOM_KEY, value=FILE_CONTENT_STR)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_local.GCSHook")
def test_size_gt_max_xcom_size(self, mock_hook):
Expand All @@ -82,8 +86,30 @@ def test_size_gt_max_xcom_size(self, mock_hook):
store_to_xcom_key=XCOM_KEY,
)
context = {"ti": MagicMock()}
mock_hook.return_value.download.return_value = FILE_CONTENT
mock_hook.return_value.download.return_value = FILE_CONTENT_BYTES_UTF8
mock_hook.return_value.get_size.return_value = MAX_XCOM_SIZE + 1

with pytest.raises(AirflowException, match="file is too large"):
operator.execute(context=context)

@mock.patch("airflow.providers.google.cloud.transfers.gcs_to_local.GCSHook")
def test_xcom_encoding(self, mock_hook):
    """Check that a UTF-16 payload is decoded via ``file_encoding`` before the XCom push."""
    task = GCSToLocalFilesystemOperator(
        task_id=TASK_ID,
        bucket=TEST_BUCKET,
        object_name=TEST_OBJECT,
        store_to_xcom_key=XCOM_KEY,
        file_encoding='utf-16',
    )
    task_instance = MagicMock()

    # Stub the hook: UTF-16 bytes whose reported size is one below the XCom limit,
    # so the operator takes the "push to XCom" branch instead of raising.
    gcs_hook = mock_hook.return_value
    gcs_hook.download.return_value = FILE_CONTENT_BYTES_UTF16
    gcs_hook.get_size.return_value = MAX_XCOM_SIZE - 1

    task.execute(context={"ti": task_instance})

    # Both hook calls must target the same bucket/object exactly once.
    expected_location = {"bucket_name": TEST_BUCKET, "object_name": TEST_OBJECT}
    gcs_hook.get_size.assert_called_once_with(**expected_location)
    gcs_hook.download.assert_called_once_with(**expected_location)

    # The pushed value is the decoded text, not the raw bytes.
    task_instance.xcom_push.assert_called_once_with(key=XCOM_KEY, value=FILE_CONTENT_STR)

0 comments on commit b8526ab

Please sign in to comment.