From a2b14003b210756176d53b3876f323e8405f8aa1 Mon Sep 17 00:00:00 2001
From: johnjcasey
Date: Thu, 9 Dec 2021 11:09:32 -0500
Subject: [PATCH] [BEAM-13355] add Big Query parameter to enable users to
 specify load_job_project_id

---
 sdks/python/apache_beam/io/gcp/bigquery.py    |  9 +++--
 .../apache_beam/io/gcp/bigquery_file_loads.py | 36 +++++++++++++------
 .../io/gcp/bigquery_file_loads_test.py        | 33 +++++++++++++++++
 .../apache_beam/io/gcp/bigquery_tools.py      |  7 ++--
 .../apache_beam/io/gcp/bigquery_tools_test.py | 12 +++++++
 5 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 0d98d7b6fc6a5..f38b8d0a1b147 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -1926,7 +1926,8 @@ def __init__(
       ignore_insert_ids=False,
       # TODO(BEAM-11857): Switch the default when the feature is mature.
       with_auto_sharding=False,
-      ignore_unknown_columns=False):
+      ignore_unknown_columns=False,
+      load_job_project_id=None):
     """Initialize a WriteToBigQuery transform.
 
     Args:
@@ -2058,6 +2059,8 @@ def __init__(
         which treats unknown values as errors. This option is only valid for
         method=STREAMING_INSERTS. See reference:
         https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/insertAll
+      load_job_project_id: Specifies an alternate GCP project id to use for
+        billing Batch File Loads. By default, the project id of the table is used.
     """
     self._table = table
     self._dataset = dataset
@@ -2092,6 +2095,7 @@ def __init__(
     self.schema_side_inputs = schema_side_inputs or ()
     self._ignore_insert_ids = ignore_insert_ids
     self._ignore_unknown_columns = ignore_unknown_columns
+    self.load_job_project_id = load_job_project_id
 
     # Dict/schema methods were moved to bigquery_tools, but keep references
     # here for backward compatibility.
@@ -2185,7 +2189,8 @@ def expand(self, pcoll):
           schema_side_inputs=self.schema_side_inputs,
           additional_bq_parameters=self.additional_bq_parameters,
           validate=self._validate,
-          is_streaming_pipeline=is_streaming_pipeline)
+          is_streaming_pipeline=is_streaming_pipeline,
+          load_job_project_id=self.load_job_project_id)
 
   def display_data(self):
     res = {}
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
index e211ab4ebf8d5..4261caacd40b1 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
@@ -345,12 +345,14 @@ def __init__(
       test_client=None,
       additional_bq_parameters=None,
       step_name=None,
-      source_format=None):
+      source_format=None,
+      load_job_project_id=None):
     self._test_client = test_client
     self._write_disposition = write_disposition
     self._additional_bq_parameters = additional_bq_parameters or {}
     self._step_name = step_name
     self._source_format = source_format
+    self._load_job_project_id = load_job_project_id
 
   def setup(self):
     self._bq_wrapper = bigquery_tools.BigQueryWrapper(client=self._test_client)
@@ -439,7 +441,8 @@ def process(self, element, schema_mod_job_name_prefix):
         create_disposition='CREATE_NEVER',
         additional_load_parameters=additional_parameters,
         job_labels=self._bq_io_metadata.add_additional_bq_job_labels(),
-        source_format=self._source_format)
+        source_format=self._source_format,
+        load_job_project_id=self._load_job_project_id)
 
     yield (destination, schema_update_job_reference)
 
@@ -462,13 +465,15 @@ def __init__(
       create_disposition=None,
       write_disposition=None,
       test_client=None,
-      step_name=None):
+      step_name=None,
+      load_job_project_id=None):
     self.create_disposition = create_disposition
     self.write_disposition = write_disposition
     self.test_client = test_client
     self._observed_tables = set()
     self.bq_io_metadata = None
     self._step_name = step_name
+    self.load_job_project_id = load_job_project_id
 
   def display_data(self):
     return {
@@ -527,8 +532,10 @@ def process(self, element, job_name_prefix=None, unused_schema_mod_jobs=None):
 
     if not self.bq_io_metadata:
       self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)
+
+    project_id = copy_to_reference.projectId if self.load_job_project_id is None else self.load_job_project_id
     job_reference = self.bq_wrapper._insert_copy_job(
-        copy_to_reference.projectId,
+        project_id,
         copy_job_name,
         copy_from_reference,
         copy_to_reference,
@@ -559,7 +566,8 @@ def __init__(
       temporary_tables=False,
       additional_bq_parameters=None,
       source_format=None,
-      step_name=None):
+      step_name=None,
+      load_job_project_id=None):
     self.schema = schema
     self.test_client = test_client
     self.temporary_tables = temporary_tables
@@ -567,6 +575,7 @@ def __init__(
     self.source_format = source_format
     self.bq_io_metadata = None
     self._step_name = step_name
+    self.load_job_project_id = load_job_project_id
     if self.temporary_tables:
       # If we are loading into temporary tables, we rely on the default create
       # and write dispositions, which mean that a new table will be created.
@@ -663,7 +672,8 @@ def process(self, element, load_job_name_prefix, *schema_side_inputs):
         create_disposition=create_disposition,
         additional_load_parameters=additional_parameters,
         source_format=self.source_format,
-        job_labels=self.bq_io_metadata.add_additional_bq_job_labels())
+        job_labels=self.bq_io_metadata.add_additional_bq_job_labels(),
+        load_job_project_id=self.load_job_project_id)
 
     yield (destination, job_reference)
 
@@ -789,7 +799,8 @@ def __init__(
       schema_side_inputs=None,
       test_client=None,
       validate=True,
-      is_streaming_pipeline=False):
+      is_streaming_pipeline=False,
+      load_job_project_id=None):
     self.destination = destination
     self.create_disposition = create_disposition
     self.write_disposition = write_disposition
@@ -823,6 +834,7 @@ def __init__(
     self.schema_side_inputs = schema_side_inputs or ()
 
     self.is_streaming_pipeline = is_streaming_pipeline
+    self.load_job_project_id = load_job_project_id
     self._validate = validate
     if self._validate:
       self.verify()
@@ -1005,7 +1017,8 @@ def _load_data(
                 temporary_tables=True,
                 additional_bq_parameters=self.additional_bq_parameters,
                 source_format=self._temp_file_format,
-                step_name=step_name),
+                step_name=step_name,
+                load_job_project_id=self.load_job_project_id),
             load_job_name_pcv,
             *self.schema_side_inputs).with_outputs(
                 TriggerLoadJobs.TEMP_TABLES, main='main'))
@@ -1029,6 +1042,7 @@ def _load_data(
                 additional_bq_parameters=self.additional_bq_parameters,
                 step_name=step_name,
                 source_format=self._temp_file_format,
+                load_job_project_id=self.load_job_project_id
             ),
             schema_mod_job_name_pcv))
 
@@ -1046,7 +1060,8 @@ def _load_data(
                 create_disposition=self.create_disposition,
                 write_disposition=self.write_disposition,
                 test_client=self.test_client,
-                step_name=step_name),
+                step_name=step_name,
+                load_job_project_id=self.load_job_project_id),
             copy_job_name_pcv,
             pvalue.AsIter(finished_schema_mod_jobs_pc)))
 
@@ -1084,7 +1099,8 @@ def _load_data(
                 temporary_tables=False,
                 additional_bq_parameters=self.additional_bq_parameters,
                 source_format=self._temp_file_format,
-                step_name=step_name),
+                step_name=step_name,
+                load_job_project_id=self.load_job_project_id),
             load_job_name_pcv,
             *self.schema_side_inputs))
 
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
index 1e227f8b00fef..ddde3f4756ca0 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -459,6 +459,39 @@ def test_records_traverse_transform_with_mocks(self):
 
       assert_that(jobs, equal_to([job_reference]), label='CheckJobs')
 
+  def test_load_job_id_used(self):
+    job_reference = bigquery_api.JobReference()
+    job_reference.projectId = 'loadJobId'
+    job_reference.jobId = 'job_name1'
+
+    result_job = bigquery_api.Job()
+    result_job.jobReference = job_reference
+
+    mock_job = mock.Mock()
+    mock_job.status.state = 'DONE'
+    mock_job.status.errorResult = None
+    mock_job.jobReference = job_reference
+
+    bq_client = mock.Mock()
+    bq_client.jobs.Get.return_value = mock_job
+
+    bq_client.jobs.Insert.return_value = result_job
+
+    transform = bqfl.BigQueryBatchFileLoads(
+        'project1:dataset1.table1',
+        custom_gcs_temp_location=self._new_tempdir(),
+        test_client=bq_client,
+        validate=False,
+        load_job_project_id='loadJobId'
+    )
+
+    with TestPipeline('DirectRunner') as p:
+      outputs = p | beam.Create(_ELEMENTS) | transform
+      jobs = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS] \
+          | "GetJobs" >> beam.Map(lambda x: x[1])
+
+      assert_that(jobs, equal_to([job_reference]), label='CheckJobProjectIds')
+
   @mock.patch('time.sleep')
   def test_wait_for_job_completion(self, sleep_mock):
     job_references = [bigquery_api.JobReference(), bigquery_api.JobReference()]
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
index 56425f7259e9b..256cd7b10e322 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -988,7 +988,8 @@ def perform_load_job(
       create_disposition=None,
       additional_load_parameters=None,
       source_format=None,
-      job_labels=None):
+      job_labels=None,
+      load_job_project_id=None):
     """Starts a job to load data into BigQuery.
 
     Returns:
@@ -1005,8 +1006,10 @@ def perform_load_job(
           'Only one of source_uris and source_stream may be specified. '
           'Got both.')
 
+    project_id = destination.projectId if load_job_project_id is None else load_job_project_id
+
     return self._insert_load_job(
-        destination.projectId,
+        project_id,
         job_id,
         destination,
         source_uris=source_uris,
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
index f9174f5eac74b..38d3998d90724 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
@@ -430,6 +430,18 @@ def test_perform_load_job_with_source_stream(self):
     upload = client.jobs.Insert.call_args[1]["upload"]
     self.assertEqual(b'some,data', upload.stream.read())
 
+  def test_perform_load_job_with_load_job_id(self):
+    client = mock.Mock()
+    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
+
+    wrapper.perform_load_job(
+        destination=parse_table_reference('project:dataset.table'),
+        job_id='job_id',
+        source_uris=['gs://example.com/*'],
+        load_job_project_id='loadId')
+    call_args = client.jobs.Insert.call_args
+    self.assertEqual('loadId', call_args[0][0].projectId)
+
   def verify_write_call_metric(
       self, project_id, dataset_id, table_id, status, count):
     """Check if an metric was recorded for the BQ IO write API call."""
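
Usage note (editor's sketch, not part of the patch): with this change applied, a pipeline author can bill batch load jobs to a project other than the one that owns the destination table by passing load_job_project_id to WriteToBigQuery; leaving it as None keeps the default of billing the table's own project. The snippet below is a minimal sketch of that call, and the project ids, table, schema, and GCS bucket names are placeholders.

import apache_beam as beam

with beam.Pipeline() as p:
  (
      p
      | beam.Create([{'name': 'a', 'value': 1}, {'name': 'b', 'value': 2}])
      | beam.io.WriteToBigQuery(
          'data-project:dataset.table',  # placeholder destination table
          schema='name:STRING,value:INTEGER',
          method=beam.io.WriteToBigQuery.Method.FILE_LOADS,
          custom_gcs_temp_location='gs://example-bucket/tmp',  # placeholder bucket
          # New parameter from this patch: load jobs are billed to this project
          # instead of the table's own project (the default when None).
          load_job_project_id='billing-project'))  # placeholder billing project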