Expose SQL to GCS Metadata (#24382)
patricker committed Jun 13, 2022
1 parent 8e0bdda commit 94257f4
Showing 8 changed files with 185 additions and 40 deletions.
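In short, this change makes BaseSQLToGCSOperator.execute() return a summary of the exported files (bucket name, per-file row counts, and totals) and, when the new upload_metadata flag is enabled, attaches each data file's row count to the uploaded object as GCS blob metadata. A minimal usage sketch, assuming a MySQL source; the task id, query, bucket, and filenames below are placeholders and do not appear in this commit:

from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

# Hypothetical task: export a table to GCS and attach row counts as blob metadata.
export_orders = MySQLToGCSOperator(
    task_id='export_orders',
    sql='SELECT * FROM orders',          # placeholder query
    bucket='my-export-bucket',           # placeholder bucket
    filename='exports/orders_{}.json',   # {} is replaced with the file number
    export_format='json',
    upload_metadata=True,                # new flag introduced by this commit
)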
42 changes: 40 additions & 2 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param upload_metadata: whether to upload the row count as blob metadata on each data file uploaded to GCS
:param exclude_columns: set of columns to exclude from transmission
"""

@@ -104,6 +105,7 @@ def __init__(
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
upload_metadata: bool = False,
exclude_columns=None,
**kwargs,
) -> None:
@@ -125,6 +127,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.upload_metadata = upload_metadata
self.exclude_columns = exclude_columns

def execute(self, context: 'Context'):
@@ -144,6 +147,9 @@ def execute(self, context: 'Context'):
schema_file['file_handle'].close()

counter = 0
files = []
total_row_count = 0
total_files = 0
self.log.info('Writing local data files')
for file_to_upload in self._write_local_data_files(cursor):
# Flush file before uploading
@@ -154,8 +160,29 @@

self.log.info('Removing local file')
file_to_upload['file_handle'].close()

# Metadata to be returned by execute() and pushed to XCom
total_row_count += file_to_upload['file_row_count']
total_files += 1
files.append(
{
'file_name': file_to_upload['file_name'],
'file_mime_type': file_to_upload['file_mime_type'],
'file_row_count': file_to_upload['file_row_count'],
}
)

counter += 1

file_meta = {
'bucket': self.bucket,
'total_row_count': total_row_count,
'total_files': total_files,
'files': files,
}

return file_meta

def convert_types(self, schema, col_type_dict, row, stringify_dict=False) -> list:
"""Convert values from DBAPI to output-friendly formats."""
return [
@@ -188,6 +215,7 @@ def _write_local_data_files(self, cursor):
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
'file_mime_type': file_mime_type,
'file_row_count': 0,
}

if self.export_format == 'csv':
@@ -197,6 +225,7 @@
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

for row in cursor:
file_to_upload['file_row_count'] += 1
if self.export_format == 'csv':
row = self.convert_types(schema, col_type_dict, row)
if self.null_marker is not None:
@@ -232,14 +261,17 @@ def _write_local_data_files(self, cursor):
'file_name': self.filename.format(file_no),
'file_handle': tmp_file_handle,
'file_mime_type': file_mime_type,
'file_row_count': 0,
}
if self.export_format == 'csv':
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == 'parquet':
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)
if self.export_format == 'parquet':
parquet_writer.close()
yield file_to_upload
# Last file may have 0 rows, don't yield if empty
if file_to_upload['file_row_count'] > 0:
yield file_to_upload

def _configure_csv_file(self, file_handle, schema):
"""Configure a csv writer with the file_handle and write schema
@@ -350,10 +382,16 @@ def _upload_to_gcs(self, file_to_upload):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
is_data_file = file_to_upload.get('file_name') != self.schema_filename
metadata = None
if is_data_file and self.upload_metadata:
metadata = {'row_count': file_to_upload['file_row_count']}

hook.upload(
self.bucket,
file_to_upload.get('file_name'),
file_to_upload.get('file_handle').name,
mime_type=file_to_upload.get('file_mime_type'),
gzip=self.gzip if file_to_upload.get('file_name') != self.schema_filename else False,
gzip=self.gzip if is_data_file else False,
metadata=metadata,
)
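As a rough illustration of the new return value (the keys mirror the file_meta dict above), a downstream task could read the metadata that execute() pushes to XCom. The task and wiring below are hypothetical and reuse export_orders from the sketch near the top of this page:

from airflow.decorators import task

@task
def report_export(meta: dict):
    # meta has the shape built in execute():
    # {'bucket': ..., 'total_row_count': ..., 'total_files': ..., 'files': [...]}
    print(
        f"Exported {meta['total_row_count']} rows in {meta['total_files']} file(s) "
        f"to gs://{meta['bucket']}"
    )

# .output exposes the operator's return value (its XCom) to downstream tasks.
report_export(export_orders.output)

Note that with upload_metadata=True each uploaded data object also carries {'row_count': <count>} in its GCS blob metadata, while the schema file is uploaded without metadata.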
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/transfers/test_mssql_to_gcs.py
@@ -97,7 +97,7 @@ def test_exec_success_json(self, gcs_hook_mock_class, mssql_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert JSON_FILENAME.format(0) == obj
assert 'application/json' == mime_type
@@ -126,7 +126,7 @@ def test_file_splitting(self, gcs_hook_mock_class, mssql_hook_mock_class):
JSON_FILENAME.format(1): NDJSON_LINES[2],
}

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert 'application/json' == mime_type
assert GZIP == gzip
@@ -154,7 +154,7 @@ def test_schema_file(self, gcs_hook_mock_class, mssql_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
with open(tmp_filename, 'rb') as file:
assert b''.join(SCHEMA_JSON) == file.read()
14 changes: 7 additions & 7 deletions tests/providers/google/cloud/transfers/test_mysql_to_gcs.py
@@ -124,7 +124,7 @@ def test_exec_success_json(self, gcs_hook_mock_class, mysql_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert JSON_FILENAME.format(0) == obj
assert 'application/json' == mime_type
@@ -158,7 +158,7 @@ def test_exec_success_csv(self, gcs_hook_mock_class, mysql_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert CSV_FILENAME.format(0) == obj
assert 'text/csv' == mime_type
@@ -193,7 +193,7 @@ def test_exec_success_csv_ensure_utc(self, gcs_hook_mock_class, mysql_hook_mock_

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert CSV_FILENAME.format(0) == obj
assert 'text/csv' == mime_type
@@ -228,7 +228,7 @@ def test_exec_success_csv_with_delimiter(self, gcs_hook_mock_class, mysql_hook_m

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert CSV_FILENAME.format(0) == obj
assert 'text/csv' == mime_type
@@ -257,7 +257,7 @@ def test_file_splitting(self, gcs_hook_mock_class, mysql_hook_mock_class):
JSON_FILENAME.format(1): NDJSON_LINES[2],
}

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert 'application/json' == mime_type
assert not gzip
@@ -285,7 +285,7 @@ def test_schema_file(self, gcs_hook_mock_class, mysql_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
assert not gzip
with open(tmp_filename, 'rb') as file:
@@ -311,7 +311,7 @@ def test_schema_file_with_custom_schema(self, gcs_hook_mock_class, mysql_hook_mo

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
assert not gzip
with open(tmp_filename, 'rb') as file:
6 changes: 3 additions & 3 deletions tests/providers/google/cloud/transfers/test_oracle_to_gcs.py
@@ -70,7 +70,7 @@ def test_exec_success_json(self, gcs_hook_mock_class, oracle_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert JSON_FILENAME.format(0) == obj
assert 'application/json' == mime_type
@@ -99,7 +99,7 @@ def test_file_splitting(self, gcs_hook_mock_class, oracle_hook_mock_class):
JSON_FILENAME.format(1): NDJSON_LINES[2],
}

def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False):
def _assert_upload(bucket, obj, tmp_filename, mime_type=None, gzip=False, metadata=None):
assert BUCKET == bucket
assert 'application/json' == mime_type
assert GZIP == gzip
@@ -127,7 +127,7 @@ def test_schema_file(self, gcs_hook_mock_class, oracle_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
with open(tmp_filename, 'rb') as file:
assert b''.join(SCHEMA_JSON) == file.read()
@@ -92,7 +92,7 @@ def test_init(self):
assert op.bucket == BUCKET
assert op.filename == FILENAME

def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip):
def _assert_uploaded_file_content(self, bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
assert BUCKET == bucket
assert FILENAME.format(0) == obj
assert 'application/json' == mime_type
@@ -159,7 +159,7 @@ def test_file_splitting(self, gcs_hook_mock_class):
FILENAME.format(1): NDJSON_LINES[2],
}

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
assert BUCKET == bucket
assert 'application/json' == mime_type
assert not gzip
@@ -183,7 +183,7 @@ def test_schema_file(self, gcs_hook_mock_class):

gcs_hook_mock = gcs_hook_mock_class.return_value

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
with open(tmp_filename, 'rb') as file:
assert SCHEMA_JSON == file.read()
12 changes: 6 additions & 6 deletions tests/providers/google/cloud/transfers/test_presto_to_gcs.py
@@ -65,7 +65,7 @@ def test_init(self):
@patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
@patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
def test_save_as_json(self, mock_gcs_hook, mock_presto_hook):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
assert BUCKET == bucket
assert FILENAME.format(0) == obj
assert "application/json" == mime_type
@@ -120,7 +120,7 @@ def test_save_as_json_with_file_splitting(self, mock_gcs_hook, mock_presto_hook)
FILENAME.format(1): NDJSON_LINES[2],
}

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
assert BUCKET == bucket
assert "application/json" == mime_type
assert not gzip
@@ -160,7 +160,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
"""Test writing schema files."""

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
with open(tmp_filename, "rb") as file:
assert SCHEMA_JSON == file.read()
@@ -199,7 +199,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
@patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
@patch("airflow.providers.google.cloud.transfers.presto_to_gcs.PrestoHook")
def test_save_as_csv(self, mock_presto_hook, mock_gcs_hook):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
assert BUCKET == bucket
assert FILENAME.format(0) == obj
assert "text/csv" == mime_type
@@ -255,7 +255,7 @@ def test_save_as_csv_with_file_splitting(self, mock_gcs_hook, mock_presto_hook):
FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]),
}

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
assert BUCKET == bucket
assert "text/csv" == mime_type
assert not gzip
@@ -296,7 +296,7 @@ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_presto_hook):
"""Test writing schema files."""

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
if obj == SCHEMA_FILENAME:
with open(tmp_filename, "rb") as file:
assert SCHEMA_JSON == file.read()
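The test changes shown here only widen the upload-mock signatures with an extra metadata=None parameter. A minimal sketch of an assertion that could exercise the new flag (illustrative only; SCHEMA_FILENAME and the expected row count are specific to each test module):

def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip, metadata=None):
    # With upload_metadata=True, data files should carry their row count,
    # while the schema file is uploaded without metadata.
    if obj == SCHEMA_FILENAME:
        assert metadata is None
    else:
        assert metadata == {'row_count': 3}  # hypothetical expected count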
