From 94257f48f4a3f123918b0d55c34753c7c413eb74 Mon Sep 17 00:00:00 2001
From: Peter Wicks
Date: Mon, 13 Jun 2022 00:55:12 -0600
Subject: [PATCH] Expose SQL to GCS Metadata (#24382)

---
 .../google/cloud/transfers/sql_to_gcs.py      |  42 +++++-
 .../cloud/transfers/test_mssql_to_gcs.py      |   6 +-
 .../cloud/transfers/test_mysql_to_gcs.py      |  14 +-
 .../cloud/transfers/test_oracle_to_gcs.py     |   6 +-
 .../cloud/transfers/test_postgres_to_gcs.py   |   6 +-
 .../cloud/transfers/test_presto_to_gcs.py     |  12 +-
 .../google/cloud/transfers/test_sql_to_gcs.py | 127 ++++++++++++++++--
 .../cloud/transfers/test_trino_to_gcs.py      |  12 +-
 8 files changed, 185 insertions(+), 40 deletions(-)

diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 46e1ad505d784..c2044790242b0 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param upload_metadata: whether to upload the row count metadata as blob metadata
     :param exclude_columns: set of columns to exclude from transmission
     """
 
@@ -104,6 +105,7 @@ def __init__(
         gcp_conn_id: str = 'google_cloud_default',
         delegate_to: Optional[str] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        upload_metadata: bool = False,
         exclude_columns=None,
         **kwargs,
     ) -> None:
@@ -125,6 +127,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
+        self.upload_metadata = upload_metadata
         self.exclude_columns = exclude_columns
 
     def execute(self, context: 'Context'):
@@ -144,6 +147,9 @@ def execute(self, context: 'Context'):
             schema_file['file_handle'].close()
 
         counter = 0
+        files = []
+        total_row_count = 0
+        total_files = 0
         self.log.info('Writing local data files')
         for file_to_upload in self._write_local_data_files(cursor):
             # Flush file before uploading
@@ -154,8 +160,29 @@ def execute(self, context: 'Context'):
 
             self.log.info('Removing local file')
             file_to_upload['file_handle'].close()
+
+            # Metadata to be outputted to Xcom
+            total_row_count += file_to_upload['file_row_count']
+            total_files += 1
+            files.append(
+                {
+                    'file_name': file_to_upload['file_name'],
+                    'file_mime_type': file_to_upload['file_mime_type'],
+                    'file_row_count': file_to_upload['file_row_count'],
+                }
+            )
+
             counter += 1
 
+        file_meta = {
+            'bucket': self.bucket,
+            'total_row_count': total_row_count,
+            'total_files': total_files,
+            'files': files,
+        }
+
+        return file_meta
+
     def convert_types(self, schema, col_type_dict, row, stringify_dict=False) -> list:
         """Convert values from DBAPI to output-friendly formats."""
         return [
@@ -188,6 +215,7 @@ def _write_local_data_files(self, cursor):
             'file_name': self.filename.format(file_no),
             'file_handle': tmp_file_handle,
             'file_mime_type': file_mime_type,
+            'file_row_count': 0,
         }
 
         if self.export_format == 'csv':
@@ -197,6 +225,7 @@ def _write_local_data_files(self, cursor):
             parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
 
         for row in cursor:
+            file_to_upload['file_row_count'] += 1
             if self.export_format == 'csv':
                 row = self.convert_types(schema, col_type_dict, row)
                 if self.null_marker is not None:
@@ -232,6 +261,7 @@ def _write_local_data_files(self, cursor):
                     'file_name': self.filename.format(file_no),
                     'file_handle': tmp_file_handle,
                     'file_mime_type': file_mime_type,
+                    'file_row_count': 0,
                 }
                 if self.export_format == 'csv':
                     csv_writer = self._configure_csv_file(tmp_file_handle, schema)
@@ -239,7 +269,9 @@ def _write_local_data_files(self, cursor):
                     parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
         if self.export_format == 'parquet':
             parquet_writer.close()
-        yield file_to_upload
+        # Last file may have 0 rows, don't yield if empty
+        if file_to_upload['file_row_count'] > 0:
+            yield file_to_upload
 
     def _configure_csv_file(self, file_handle, schema):
         """Configure a csv writer with the file_handle and write schema
@@ -350,10 +382,16 @@ def _upload_to_gcs(self, file_to_upload):
             delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
+        is_data_file = file_to_upload.get('file_name') != self.schema_filename
+        metadata = None
+        if is_data_file and self.upload_metadata:
+            metadata = {'row_count': file_to_upload['file_row_count']}
+
         hook.upload(
             self.bucket,
             file_to_upload.get('file_name'),
             file_to_upload.get('file_handle').name,
             mime_type=file_to_upload.get('file_mime_type'),
-            gzip=self.gzip if file_to_upload.get('file_name') != self.schema_filename else False,
+            gzip=self.gzip if is_data_file else False,
+            metadata=metadata,
         )
diff --git a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
index b388f4548c200..8b9d820221166 100644
--- a/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
@@ -97,7 +97,7 @@ def test_exec_success_json(self, gcs_hook_mock_class, mssql_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert JSON_FILENAME.format(0) == obj
             assert 'application/json' == mime_type
@@ -126,7 +126,7 @@ def test_file_splitting(self, gcs_hook_mock_class, mssql_hook_mock_class):
             JSON_FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert GZIP == gzip
@@ -154,7 +154,7 @@ def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, 'rb') as file:
                     assert b''.join(SCHEMA_JSON) == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
index c006c230d388f..8d87ea9867225 100644
--- a/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
@@ -124,7 +124,7 @@ def test_exec_success_json(self, gcs_hook_mock_class, mysql_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert JSON_FILENAME.format(0) == obj
             assert 'application/json' == mime_type
@@ -158,7 +158,7 @@ def test_exec_success_csv(self, gcs_hook_mock_class, mysql_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert CSV_FILENAME.format(0) == obj
             assert 'text/csv' == mime_type
@@ -193,7 +193,7 @@ def test_exec_success_csv_ensure_utc(self, gcs_hook_mock_class, mysql_hook_mock_
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert CSV_FILENAME.format(0) == obj
             assert 'text/csv' == mime_type
@@ -228,7 +228,7 @@ def test_exec_success_csv_with_delimiter(self, gcs_hook_mock_class, mysql_hook_m
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert CSV_FILENAME.format(0) == obj
             assert 'text/csv' == mime_type
@@ -257,7 +257,7 @@ def test_file_splitting(self, gcs_hook_mock_class, mysql_hook_mock_class):
             JSON_FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert not gzip
@@ -285,7 +285,7 @@ def test_schema_file(self, gcs_hook_mock_class, mysql_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 assert not gzip
                 with open(tmp_filename, 'rb') as file:
@@ -311,7 +311,7 @@ def test_schema_file_with_custom_schema(self, gcs_hook_mock_class, mysql_hook_mo
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 assert not gzip
                 with open(tmp_filename, 'rb') as file:
diff --git a/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py b/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
index a49c224c7aab3..b90510cbae19c 100644
--- a/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
@@ -70,7 +70,7 @@ def test_exec_success_json(self, gcs_hook_mock_class, oracle_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert JSON_FILENAME.format(0) == obj
             assert 'application/json' == mime_type
@@ -99,7 +99,7 @@ def test_file_splitting(self, gcs_hook_mock_class, oracle_hook_mock_class):
             JSON_FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert GZIP == gzip
@@ -127,7 +127,7 @@ def test_schema_file(self, gcs_hook_mock_class, oracle_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, 'rb') as file:
                     assert b''.join(SCHEMA_JSON) == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
index ff653292c4ea8..e8007fc427d00 100644
--- a/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_postgres_to_gcs.py
@@ -92,7 +92,7 @@ def test_init(self):
         assert op.bucket == BUCKET
         assert op.filename == FILENAME
 
-    def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip):
+    def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
         assert BUCKET == bucket
         assert FILENAME.format(0) == obj
         assert 'application/json' == mime_type
@@ -159,7 +159,7 @@ def test_file_splitting(self, gcs_hook_mock_class):
             FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert 'application/json' == mime_type
             assert not gzip
@@ -183,7 +183,7 @@ def test_schema_file(self, gcs_hook_mock_class):
 
         gcs_hook_mock = gcs_hook_mock_class.return_value
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, 'rb') as file:
                     assert SCHEMA_JSON == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
index 80a5a50386965..46b76621f213d 100644
--- a/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_presto_to_gcs.py
@@ -65,7 +65,7 @@ def test_init(self):
     @patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     def test_save_as_json(self, mock_gcs_hook, mock_presto_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "application/json" == mime_type
@@ -120,7 +120,7 @@ def test_save_as_json_with_file_splitting(self, mock_gcs_hook, mock_presto_hook)
             FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "application/json" == mime_type
             assert not gzip
@@ -160,7 +160,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
     def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
@@ -199,7 +199,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     @patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
     def test_save_as_csv(self, mock_presto_hook, mock_gcs_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "text/csv" == mime_type
@@ -255,7 +255,7 @@ def test_save_as_csv_with_file_splitting(self, mock_gcs_hook, mock_presto_hook):
             FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]),
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "text/csv" == mime_type
             assert not gzip
@@ -296,7 +296,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
     def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 824ab8ff317f3..918450e0e5331 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -127,8 +127,20 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
             gzip=True,
             schema=SCHEMA,
             gcp_conn_id='google_cloud_default',
+            upload_metadata=True,
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 3,
+            'files': [
+                {'file_name': 'test_results_0.csv', 'file_mime_type': 'text/csv', 'file_row_count': 1},
+                {'file_name': 'test_results_1.csv', 'file_mime_type': 'text/csv', 'file_row_count': 1},
+                {'file_name': 'test_results_2.csv', 'file_mime_type': 'text/csv', 'file_row_count': 1},
+            ],
+        }
 
         mock_query.assert_called_once()
         mock_writerow.assert_has_calls(
@@ -142,16 +154,25 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
                 mock.call(COLUMNS),
             ]
         )
-        mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
+        mock_flush.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call()])
         csv_calls = []
         for i in range(0, 3):
             csv_calls.append(
-                mock.call(BUCKET, FILENAME.format(i), TMP_FILE_NAME, mime_type='text/csv', gzip=True)
+                mock.call(
+                    BUCKET,
+                    FILENAME.format(i),
+                    TMP_FILE_NAME,
+                    mime_type='text/csv',
+                    gzip=True,
+                    metadata={'row_count': 1},
+                )
             )
-        json_call = mock.call(BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False)
+        json_call = mock.call(
+            BUCKET, SCHEMA_FILE, TMP_FILE_NAME, mime_type=APP_JSON, gzip=False, metadata=None
+        )
         upload_calls = [json_call, csv_calls[0], csv_calls[1], csv_calls[2]]
         mock_upload.assert_has_calls(upload_calls)
-        mock_close.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call(), mock.call()])
+        mock_close.assert_has_calls([mock.call(), mock.call(), mock.call(), mock.call()])
 
         mock_query.reset_mock()
         mock_flush.reset_mock()
@@ -165,7 +186,16 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
         operator = DummySQLToGCSOperator(
             sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="json", schema=SCHEMA
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [
+                {'file_name': 'test_results_0.csv', 'file_mime_type': 'application/json', 'file_row_count': 3}
+            ],
+        }
 
         mock_query.assert_called_once()
         mock_write.assert_has_calls(
@@ -180,7 +210,59 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
         )
         mock_flush.assert_called_once()
         mock_upload.assert_called_once_with(
-            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type=APP_JSON, gzip=False
+            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type=APP_JSON, gzip=False, metadata=None
         )
         mock_close.assert_called_once()
+
+        mock_query.reset_mock()
+        mock_flush.reset_mock()
+        mock_upload.reset_mock()
+        mock_close.reset_mock()
+        cursor_mock.reset_mock()
+
+        cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
+
+        # Test Metadata Upload
+        operator = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            export_format="json",
+            schema=SCHEMA,
+            upload_metadata=True,
+        )
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [
+                {'file_name': 'test_results_0.csv', 'file_mime_type': 'application/json', 'file_row_count': 3}
+            ],
+        }
+
+        mock_query.assert_called_once()
+        mock_write.assert_has_calls(
+            [
+                mock.call(OUTPUT_DATA),
+                mock.call(b"\n"),
+                mock.call(OUTPUT_DATA),
+                mock.call(b"\n"),
+                mock.call(OUTPUT_DATA),
+                mock.call(b"\n"),
+            ]
+        )
+
+        mock_flush.assert_called_once()
+        mock_upload.assert_called_once_with(
+            BUCKET,
+            FILENAME.format(0),
+            TMP_FILE_NAME,
+            mime_type=APP_JSON,
+            gzip=False,
+            metadata={'row_count': 3},
+        )
+        mock_close.assert_called_once()
 
@@ -196,12 +278,30 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
         operator = DummySQLToGCSOperator(
             sql=SQL, bucket=BUCKET, filename=FILENAME, task_id=TASK_ID, export_format="parquet", schema=SCHEMA
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [
+                {
+                    'file_name': 'test_results_0.csv',
+                    'file_mime_type': 'application/octet-stream',
+                    'file_row_count': 3,
+                }
+            ],
+        }
 
         mock_query.assert_called_once()
         mock_flush.assert_called_once()
         mock_upload.assert_called_once_with(
-            BUCKET, FILENAME.format(0), TMP_FILE_NAME, mime_type='application/octet-stream', gzip=False
+            BUCKET,
+            FILENAME.format(0),
+            TMP_FILE_NAME,
+            mime_type='application/octet-stream',
+            gzip=False,
+            metadata=None,
         )
         mock_close.assert_called_once()
 
@@ -217,7 +317,14 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
             export_format="csv",
             null_marker="NULL",
         )
-        operator.execute(context=dict())
+        result = operator.execute(context=dict())
+
+        assert result == {
+            'bucket': 'TEST-BUCKET-1',
+            'total_row_count': 3,
+            'total_files': 1,
+            'files': [{'file_name': 'test_results_0.csv', 'file_mime_type': 'text/csv', 'file_row_count': 3}],
+        }
 
         mock_writerow.assert_has_calls(
             [
diff --git a/tests/providers/google/cloud/transfers/test_trino_to_gcs.py b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
index 1e5443f6795b7..50828a36ea544 100644
--- a/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
@@ -65,7 +65,7 @@ def test_init(self):
     @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     def test_save_as_json(self, mock_gcs_hook, mock_trino_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "application/json" == mime_type
@@ -120,7 +120,7 @@ def test_save_as_json_with_file_splitting(self, mock_gcs_hook, mock_trino_hook):
             FILENAME.format(1): NDJSON_LINES[2],
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "application/json" == mime_type
             assert not gzip
@@ -160,7 +160,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
     def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_trino_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
@@ -199,7 +199,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
     @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
     @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
    def test_save_as_csv(self, mock_trino_hook, mock_gcs_hook):
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert FILENAME.format(0) == obj
             assert "text/csv" == mime_type
@@ -255,7 +255,7 @@ def test_save_as_csv_with_file_splitting(self, mock_gcs_hook, mock_trino_hook):
             FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]),
         }
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             assert BUCKET == bucket
             assert "text/csv" == mime_type
             assert not gzip
@@ -296,7 +296,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
     def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_trino_hook):
         """Test writing schema files."""
 
-        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+        def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
             if obj == SCHEMA_FILENAME:
                 with open(tmp_filename, "rb") as file:
                     assert SCHEMA_JSON == file.read()
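
Usage note (not part of the patch above): with upload_metadata=True each data file uploaded to GCS carries a row_count entry in its blob metadata, and execute() now returns a summary dict ('bucket', 'total_row_count', 'total_files', 'files') that Airflow pushes to XCom as the task's return value. The sketch below shows one way a downstream task might consume that summary. The DAG id, bucket, SQL, and filenames are made-up examples, and MySQLToGCSOperator is used only as a stand-in for any BaseSQLToGCSOperator subclass.

# Illustrative sketch only: DAG id, bucket, SQL, and connection defaults are assumptions,
# not part of the patch. Any BaseSQLToGCSOperator subclass behaves the same way.
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

with DAG(dag_id="example_sql_to_gcs_metadata", start_date=datetime(2022, 6, 1), schedule_interval=None) as dag:
    export = MySQLToGCSOperator(
        task_id="export_orders",
        sql="SELECT * FROM orders",
        bucket="my-bucket",
        filename="orders/export_{}.csv",  # {} is replaced with the file counter
        export_format="csv",
        upload_metadata=True,  # attach {'row_count': ...} to each uploaded data blob
    )

    @task
    def report(file_meta: dict) -> None:
        # file_meta is the dict returned by execute() and pushed to XCom:
        # {'bucket': ..., 'total_row_count': ..., 'total_files': ..., 'files': [...]}
        print(f"wrote {file_meta['total_row_count']} rows across {file_meta['total_files']} file(s)")
        for f in file_meta["files"]:
            print(f"  {f['file_name']}: {f['file_row_count']} rows ({f['file_mime_type']})")

    report(export.output)  # export.output is the XComArg for the task's return value

The assertions added to test_sql_to_gcs.py above check this same return shape for the CSV, JSON, and Parquet paths.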