diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index b9251100ed999..cef0d7157c9e6 100644 --- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -42,9 +42,9 @@ class GCSToBigQueryOperator(BaseOperator): :param bucket: The bucket to load from. (templated) :type bucket: str - :param source_objects: List of Google Cloud Storage URIs to load from. (templated) + :param source_objects: String or List of Google Cloud Storage URIs to load from. (templated) If source_format is 'DATASTORE_BACKUP', the list must only contain a single URI. - :type source_objects: list[str] + :type source_objects: str, list[str] :param destination_project_dataset_table: The dotted ``(<project>.|<project>:)<dataset>.<table>`` BigQuery table to load data into. If ``<project>`` is not included, project will be the project defined in @@ -219,7 +219,7 @@ def __init__( if time_partitioning is None: time_partitioning = {} self.bucket = bucket - self.source_objects = source_objects + self.source_objects = source_objects if isinstance(source_objects, list) else [source_objects] self.schema_object = schema_object # BQ config diff --git a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py index b0c2b3d964fed..9bc3b3b6aa4a6 100644 --- a/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py @@ -26,6 +26,7 @@ TEST_BUCKET = 'test-bucket' MAX_ID_KEY = 'id' TEST_SOURCE_OBJECTS = ['test/objects/*'] +TEST_SOURCE_OBJECTS_AS_STRING = 'test/objects/*' LABELS = {'k1': 'v1'} DESCRIPTION = "Test Description" @@ -216,3 +217,75 @@ def test_description_external_table(self, bq_hook): description=DESCRIPTION, ) # fmt: on + + @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook') + def test_source_objects_as_list(self, 
bq_hook): + operator = GCSToBigQueryOperator( + task_id=TASK_ID, + bucket=TEST_BUCKET, + source_objects=TEST_SOURCE_OBJECTS, + destination_project_dataset_table=TEST_EXPLICIT_DEST, + ) + + operator.execute(None) + + bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with( + destination_project_dataset_table=mock.ANY, + schema_fields=mock.ANY, + source_uris=[f'gs://{TEST_BUCKET}/{source_object}' for source_object in TEST_SOURCE_OBJECTS], + source_format=mock.ANY, + autodetect=mock.ANY, + create_disposition=mock.ANY, + skip_leading_rows=mock.ANY, + write_disposition=mock.ANY, + field_delimiter=mock.ANY, + max_bad_records=mock.ANY, + quote_character=mock.ANY, + ignore_unknown_values=mock.ANY, + allow_quoted_newlines=mock.ANY, + allow_jagged_rows=mock.ANY, + encoding=mock.ANY, + schema_update_options=mock.ANY, + src_fmt_configs=mock.ANY, + time_partitioning=mock.ANY, + cluster_fields=mock.ANY, + encryption_configuration=mock.ANY, + labels=mock.ANY, + description=mock.ANY, + ) + + @mock.patch('airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook') + def test_source_objects_as_string(self, bq_hook): + operator = GCSToBigQueryOperator( + task_id=TASK_ID, + bucket=TEST_BUCKET, + source_objects=TEST_SOURCE_OBJECTS_AS_STRING, + destination_project_dataset_table=TEST_EXPLICIT_DEST, + ) + + operator.execute(None) + + bq_hook.return_value.get_conn.return_value.cursor.return_value.run_load.assert_called_once_with( + destination_project_dataset_table=mock.ANY, + schema_fields=mock.ANY, + source_uris=[f'gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}'], + source_format=mock.ANY, + autodetect=mock.ANY, + create_disposition=mock.ANY, + skip_leading_rows=mock.ANY, + write_disposition=mock.ANY, + field_delimiter=mock.ANY, + max_bad_records=mock.ANY, + quote_character=mock.ANY, + ignore_unknown_values=mock.ANY, + allow_quoted_newlines=mock.ANY, + allow_jagged_rows=mock.ANY, + encoding=mock.ANY, + 
schema_update_options=mock.ANY, + src_fmt_configs=mock.ANY, + time_partitioning=mock.ANY, + cluster_fields=mock.ANY, + encryption_configuration=mock.ANY, + labels=mock.ANY, + description=mock.ANY, + )