[BEAM-13355] add Big Query parameter to enable users to specify load_job_project_id
johnjcasey committed Dec 9, 2021
1 parent 5b3f70b commit a2b1400
Showing 5 changed files with 83 additions and 14 deletions.
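For context, a minimal sketch of how a pipeline author could use the new option once this change lands. This is illustrative only: the project, table, schema, and temp-bucket names are placeholders rather than values from this commit, and FILE_LOADS is the write method the new parameter applies to.

import apache_beam as beam
from apache_beam.io.gcp.bigquery import WriteToBigQuery

# Sketch only: all ids below are placeholders.
with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{'name': 'a', 'value': 1}, {'name': 'b', 'value': 2}])
      | WriteToBigQuery(
          'table-project:dataset.table',
          schema='name:STRING,value:INTEGER',
          method=WriteToBigQuery.Method.FILE_LOADS,
          custom_gcs_temp_location='gs://example-bucket/tmp',
          # New in this commit: bill the BigQuery load jobs to a project
          # other than the one that owns the destination table.
          load_job_project_id='billing-project'))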
9 changes: 7 additions & 2 deletions sdks/python/apache_beam/io/gcp/bigquery.py
@@ -1926,7 +1926,8 @@ def __init__(
ignore_insert_ids=False,
# TODO(BEAM-11857): Switch the default when the feature is mature.
with_auto_sharding=False,
ignore_unknown_columns=False):
ignore_unknown_columns=False,
load_job_project_id=None):
"""Initialize a WriteToBigQuery transform.
Args:
@@ -2058,6 +2059,8 @@ def __init__(
which treats unknown values as errors. This option is only valid for
method=STREAMING_INSERTS. See reference:
https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/insertAll
load_job_project_id: Specifies an alternate GCP project id to use for billing
batch file loads. By default, the project id of the table is used.
"""
self._table = table
self._dataset = dataset
@@ -2092,6 +2095,7 @@ def __init__(
self.schema_side_inputs = schema_side_inputs or ()
self._ignore_insert_ids = ignore_insert_ids
self._ignore_unknown_columns = ignore_unknown_columns
self.load_job_project_id = load_job_project_id

# Dict/schema methods were moved to bigquery_tools, but keep references
# here for backward compatibility.
@@ -2185,7 +2189,8 @@ def expand(self, pcoll):
schema_side_inputs=self.schema_side_inputs,
additional_bq_parameters=self.additional_bq_parameters,
validate=self._validate,
is_streaming_pipeline=is_streaming_pipeline)
is_streaming_pipeline=is_streaming_pipeline,
load_job_project_id=self.load_job_project_id)

def display_data(self):
res = {}
36 changes: 26 additions & 10 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
@@ -345,12 +345,14 @@ def __init__(
test_client=None,
additional_bq_parameters=None,
step_name=None,
source_format=None):
source_format=None,
load_job_project_id=None):
self._test_client = test_client
self._write_disposition = write_disposition
self._additional_bq_parameters = additional_bq_parameters or {}
self._step_name = step_name
self._source_format = source_format
self._load_job_project_id = load_job_project_id

def setup(self):
self._bq_wrapper = bigquery_tools.BigQueryWrapper(client=self._test_client)
@@ -439,7 +441,8 @@ def process(self, element, schema_mod_job_name_prefix):
create_disposition='CREATE_NEVER',
additional_load_parameters=additional_parameters,
job_labels=self._bq_io_metadata.add_additional_bq_job_labels(),
source_format=self._source_format)
source_format=self._source_format,
load_job_project_id=self._load_job_project_id)
yield (destination, schema_update_job_reference)


@@ -462,13 +465,15 @@ def __init__(
create_disposition=None,
write_disposition=None,
test_client=None,
step_name=None):
step_name=None,
load_job_project_id=None):
self.create_disposition = create_disposition
self.write_disposition = write_disposition
self.test_client = test_client
self._observed_tables = set()
self.bq_io_metadata = None
self._step_name = step_name
self.load_job_project_id = load_job_project_id

def display_data(self):
return {
@@ -527,8 +532,10 @@ def process(self, element, job_name_prefix=None, unused_schema_mod_jobs=None):

if not self.bq_io_metadata:
self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)

project_id = (
    copy_to_reference.projectId
    if self.load_job_project_id is None else self.load_job_project_id)
job_reference = self.bq_wrapper._insert_copy_job(
copy_to_reference.projectId,
project_id,
copy_job_name,
copy_from_reference,
copy_to_reference,
@@ -559,14 +566,16 @@ def __init__(
temporary_tables=False,
additional_bq_parameters=None,
source_format=None,
step_name=None):
step_name=None,
load_job_project_id=None):
self.schema = schema
self.test_client = test_client
self.temporary_tables = temporary_tables
self.additional_bq_parameters = additional_bq_parameters or {}
self.source_format = source_format
self.bq_io_metadata = None
self._step_name = step_name
self.load_job_project_id = load_job_project_id
if self.temporary_tables:
# If we are loading into temporary tables, we rely on the default create
# and write dispositions, which mean that a new table will be created.
@@ -663,7 +672,8 @@ def process(self, element, load_job_name_prefix, *schema_side_inputs):
create_disposition=create_disposition,
additional_load_parameters=additional_parameters,
source_format=self.source_format,
job_labels=self.bq_io_metadata.add_additional_bq_job_labels())
job_labels=self.bq_io_metadata.add_additional_bq_job_labels(),
load_job_project_id=self.load_job_project_id)
yield (destination, job_reference)


@@ -789,7 +799,8 @@ def __init__(
schema_side_inputs=None,
test_client=None,
validate=True,
is_streaming_pipeline=False):
is_streaming_pipeline=False,
load_job_project_id=None):
self.destination = destination
self.create_disposition = create_disposition
self.write_disposition = write_disposition
@@ -823,6 +834,7 @@ def __init__(
self.schema_side_inputs = schema_side_inputs or ()

self.is_streaming_pipeline = is_streaming_pipeline
self.load_job_project_id = load_job_project_id
self._validate = validate
if self._validate:
self.verify()
@@ -1005,7 +1017,8 @@ def _load_data(
temporary_tables=True,
additional_bq_parameters=self.additional_bq_parameters,
source_format=self._temp_file_format,
step_name=step_name),
step_name=step_name,
load_job_project_id=self.load_job_project_id),
load_job_name_pcv,
*self.schema_side_inputs).with_outputs(
TriggerLoadJobs.TEMP_TABLES, main='main'))
@@ -1029,6 +1042,7 @@ def _load_data(
additional_bq_parameters=self.additional_bq_parameters,
step_name=step_name,
source_format=self._temp_file_format,
load_job_project_id=self.load_job_project_id
),
schema_mod_job_name_pcv))

@@ -1046,7 +1060,8 @@ def _load_data(
create_disposition=self.create_disposition,
write_disposition=self.write_disposition,
test_client=self.test_client,
step_name=step_name),
step_name=step_name,
load_job_project_id=self.load_job_project_id),
copy_job_name_pcv,
pvalue.AsIter(finished_schema_mod_jobs_pc)))

@@ -1084,7 +1099,8 @@ def _load_data(
temporary_tables=False,
additional_bq_parameters=self.additional_bq_parameters,
source_format=self._temp_file_format,
step_name=step_name),
step_name=step_name,
load_job_project_id=self.load_job_project_id),
load_job_name_pcv,
*self.schema_side_inputs))

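Both TriggerCopyJobs above and BigQueryWrapper.perform_load_job further below resolve the billing project with the same inline conditional. A minimal sketch of that fallback as a standalone helper, purely for illustration; _resolve_job_project is a hypothetical name and not part of the Beam API.

# Hedged sketch of the project-id fallback added in this commit;
# the helper name is hypothetical.
def _resolve_job_project(table_project_id, load_job_project_id=None):
  # Bill the job to load_job_project_id when the user supplied one,
  # otherwise fall back to the project that owns the destination table.
  if load_job_project_id is None:
    return table_project_id
  return load_job_project_id

assert _resolve_job_project('table-project') == 'table-project'
assert _resolve_job_project('table-project', 'billing-project') == 'billing-project'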
33 changes: 33 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -459,6 +459,39 @@ def test_records_traverse_transform_with_mocks(self):

assert_that(jobs, equal_to([job_reference]), label='CheckJobs')

def test_load_job_id_used(self):
job_reference = bigquery_api.JobReference()
job_reference.projectId = 'loadJobId'
job_reference.jobId = 'job_name1'

result_job = bigquery_api.Job()
result_job.jobReference = job_reference

mock_job = mock.Mock()
mock_job.status.state = 'DONE'
mock_job.status.errorResult = None
mock_job.jobReference = job_reference

bq_client = mock.Mock()
bq_client.jobs.Get.return_value = mock_job

bq_client.jobs.Insert.return_value = result_job

transform = bqfl.BigQueryBatchFileLoads(
'project1:dataset1.table1',
custom_gcs_temp_location=self._new_tempdir(),
test_client=bq_client,
validate=False,
load_job_project_id='loadJobId'
)

with TestPipeline('DirectRunner') as p:
outputs = p | beam.Create(_ELEMENTS) | transform
jobs = outputs[bqfl.BigQueryBatchFileLoads.DESTINATION_JOBID_PAIRS] \
| "GetJobs" >> beam.Map(lambda x: x[1])

assert_that(jobs, equal_to([job_reference]), label='CheckJobProjectIds')

@mock.patch('time.sleep')
def test_wait_for_job_completion(self, sleep_mock):
job_references = [bigquery_api.JobReference(), bigquery_api.JobReference()]
7 changes: 5 additions & 2 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
@@ -988,7 +988,8 @@ def perform_load_job(
create_disposition=None,
additional_load_parameters=None,
source_format=None,
job_labels=None):
job_labels=None,
load_job_project_id=None):
"""Starts a job to load data into BigQuery.
Returns:
@@ -1005,8 +1006,10 @@
'Only one of source_uris and source_stream may be specified. '
'Got both.')

project_id = (
    destination.projectId
    if load_job_project_id is None else load_job_project_id)

return self._insert_load_job(
destination.projectId,
project_id,
job_id,
destination,
source_uris=source_uris,
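As a rough illustration of the wrapper-level change, the call below mirrors the unit test added in bigquery_tools_test.py below. The table, bucket, and project ids are placeholders, and constructing BigQueryWrapper() without a client assumes default GCP credentials are available.

from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper, parse_table_reference

# Sketch only: ids are placeholders; a real client and credentials are
# needed when no test client is injected.
wrapper = BigQueryWrapper()
wrapper.perform_load_job(
    destination=parse_table_reference('table-project:dataset.table'),
    job_id='example_load_job',
    source_uris=['gs://example-bucket/files/*'],
    # The load job is inserted under 'billing-project' rather than
    # 'table-project', via the fallback added in this commit.
    load_job_project_id='billing-project')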
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
Expand Up @@ -430,6 +430,18 @@ def test_perform_load_job_with_source_stream(self):
upload = client.jobs.Insert.call_args[1]["upload"]
self.assertEqual(b'some,data', upload.stream.read())

def test_perform_load_job_with_load_job_id(self):
client = mock.Mock()
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

wrapper.perform_load_job(
destination=parse_table_reference('project:dataset.table'),
job_id='job_id',
source_uris=['gs://example.com/*'],
load_job_project_id='loadId')
call_args = client.jobs.Insert.call_args
self.assertEqual('loadId', call_args[0][0].projectId)

def verify_write_call_metric(
self, project_id, dataset_id, table_id, status, count):
"""Check if an metric was recorded for the BQ IO write API call."""
