Merge pull request #11745 from [BEAM-9692] Add to/from_runner_api_parameters to WriteToBigQuery

Add to/from_runner_api_parameters to WriteToBigQuery
pabloem committed May 27, 2020
2 parents eb59a84 + 6cf105a commit 6658d62
Showing 2 changed files with 139 additions and 0 deletions.
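
The new methods convert a configured WriteToBigQuery to and from a runner API representation: a URN plus a pickled dict of constructor arguments. The following is a minimal sketch of that round trip, not part of the commit; it assumes the GCP extras are installed, the table and schema values are placeholders, and an empty PipelineContext stands in for the context a runner would supply.

from apache_beam.internal import pickler
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.runners.pipeline_context import PipelineContext

write = WriteToBigQuery('my-project:my_dataset.my_table', schema='a:STRING')
urn, payload = write.to_runner_api_parameter(PipelineContext())
assert urn == 'beam:transform:write_to_big_query:v0'

# The payload is just the pickled configuration dict built in
# to_runner_api_parameter below.
config = pickler.loads(payload)
assert config['table'] == 'my-project:my_dataset.my_table'
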
72 changes: 72 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery.py
@@ -271,6 +271,8 @@ def compute_table_name(row):
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
from apache_beam.transforms.sideinputs import get_sideinput_index
from apache_beam.transforms.window import GlobalWindows
from apache_beam.utils import retry
from apache_beam.utils.annotations import deprecated
@@ -1396,6 +1398,9 @@ def __init__(
and
https://cloud.google.com/bigquery/docs/loading-data-cloud-storage-json.
"""
    self._table = table
    self._dataset = dataset
    self._project = project
    self.table_reference = bigquery_tools.parse_table_reference(
        table, dataset, project)
    self.create_disposition = BigQueryDisposition.validate_create(
@@ -1523,6 +1528,73 @@ def display_data(self):
    res['table'] = DisplayDataItem(tableSpec, label='Table')
    return res

  def to_runner_api_parameter(self, context):
    from apache_beam.internal import pickler

    # It'd be nice to name these according to their actual
    # names/positions in the original argument list, but such a
    # transformation is currently irreversible given how
    # remove_objects_from_args and insert_values_in_args
    # are currently implemented.
    def serialize(side_inputs):
      return {
          (SIDE_INPUT_PREFIX + '%s') % ix:
              si.to_runner_api(context).SerializeToString()
          for ix, si in enumerate(side_inputs)
      }

    table_side_inputs = serialize(self.table_side_inputs)
    schema_side_inputs = serialize(self.schema_side_inputs)

    config = {
        'table': self._table,
        'dataset': self._dataset,
        'project': self._project,
        'schema': self.schema,
        'create_disposition': self.create_disposition,
        'write_disposition': self.write_disposition,
        'kms_key': self.kms_key,
        'batch_size': self.batch_size,
        'max_file_size': self.max_file_size,
        'max_files_per_bundle': self.max_files_per_bundle,
        'custom_gcs_temp_location': self.custom_gcs_temp_location,
        'method': self.method,
        'insert_retry_strategy': self.insert_retry_strategy,
        'additional_bq_parameters': self.additional_bq_parameters,
        'table_side_inputs': table_side_inputs,
        'schema_side_inputs': schema_side_inputs,
        'triggering_frequency': self.triggering_frequency,
        'validate': self._validate,
        'temp_file_format': self._temp_file_format,
    }
    return 'beam:transform:write_to_big_query:v0', pickler.dumps(config)

  @PTransform.register_urn('beam:transform:write_to_big_query:v0', bytes)
  def from_runner_api(unused_ptransform, payload, context):
    from apache_beam.internal import pickler
    from apache_beam.portability.api.beam_runner_api_pb2 import SideInput

    config = pickler.loads(payload)

    def deserialize(side_inputs):
      deserialized_side_inputs = {}
      for k, v in side_inputs.items():
        side_input = SideInput()
        side_input.ParseFromString(v)
        deserialized_side_inputs[k] = side_input

      # This is an ordered list stored as a dict (see the comments in
      # to_runner_api_parameter above).
      indexed_side_inputs = [
          (get_sideinput_index(tag),
           pvalue.AsSideInput.from_runner_api(si, context))
          for tag, si in deserialized_side_inputs.items()
      ]
      return [si for _, si in sorted(indexed_side_inputs)]

    config['table_side_inputs'] = deserialize(config['table_side_inputs'])
    config['schema_side_inputs'] = deserialize(config['schema_side_inputs'])

    return WriteToBigQuery(**config)


class _PassThroughThenCleanup(PTransform):
"""A PTransform that invokes a DoFn after the input PCollection has been
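The serialize and deserialize helpers above store what is logically an ordered list of side inputs as a dict whose keys encode each position (SIDE_INPUT_PREFIX plus the index), since the original argument names are not recoverable. A minimal standalone sketch of that key scheme, not part of the commit:

from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX, get_sideinput_index

# Keys are built exactly as in to_runner_api_parameter above.
tags = [(SIDE_INPUT_PREFIX + '%s') % ix for ix in range(3)]

# get_sideinput_index recovers the position from the tag alone, so the
# original ordering can be rebuilt after deserialization.
assert [get_sideinput_index(tag) for tag in tags] == [0, 1, 2]
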
67 changes: 67 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -584,6 +584,73 @@ def test_schema_autodetect_not_allowed_with_avro_file_loads(self):
schema=beam.io.gcp.bigquery.SCHEMA_AUTODETECT,
temp_file_format=bigquery_tools.FileFormat.AVRO))

  def test_to_from_runner_api(self):
    """Tests that serialization of WriteToBigQuery is correct.

    This is not intended to be a change-detector test. As such, this only
    tests the more complicated serialization logic of parameters:
    ValueProviders, callables, and side inputs.
    """
    FULL_OUTPUT_TABLE = 'test_project:output_table'

    p = TestPipeline(
        additional_pipeline_args=["--experiments=use_beam_bq_sink"])

    # Used for testing side input parameters.
    table_record_pcv = beam.pvalue.AsDict(
        p | "MakeTable" >> beam.Create([('table', FULL_OUTPUT_TABLE)]))

    # Used for testing value provider parameters.
    schema = value_provider.StaticValueProvider(str, '"a:str"')

    original = WriteToBigQuery(
        table=lambda _, side_input: side_input['table'],
        table_side_inputs=(table_record_pcv, ),
        schema=schema)

    # pylint: disable=expression-not-assigned
    p | 'MyWriteToBigQuery' >> original

    # Run the pipeline through to generate a pipeline proto from an empty
    # context. This ensures that the serialization code ran.
    pipeline_proto, context = TestPipeline.from_runner_api(
        p.to_runner_api(), p.runner, p.get_pipeline_options()).to_runner_api(
            return_context=True)

    # Find the transform from the context.
    write_to_bq_id = [
        k for k, v in pipeline_proto.components.transforms.items()
        if v.unique_name == 'MyWriteToBigQuery'
    ][0]
    deserialized_node = context.transforms.get_by_id(write_to_bq_id)
    deserialized = deserialized_node.transform
    self.assertIsInstance(deserialized, WriteToBigQuery)

    # Test that the serialization of a value provider is correct.
    self.assertEqual(original.schema, deserialized.schema)

    # Test that the serialization of a callable is correct.
    self.assertEqual(
        deserialized._table(None, {'table': FULL_OUTPUT_TABLE}),
        FULL_OUTPUT_TABLE)

    # Test that the serialization of a side input is correct.
    self.assertEqual(
        len(original.table_side_inputs), len(deserialized.table_side_inputs))
    original_side_input_data = original.table_side_inputs[0]._side_input_data()
    deserialized_side_input_data = deserialized.table_side_inputs[
        0]._side_input_data()
    self.assertEqual(
        original_side_input_data.access_pattern,
        deserialized_side_input_data.access_pattern)
    self.assertEqual(
        original_side_input_data.window_mapping_fn,
        deserialized_side_input_data.window_mapping_fn)
    self.assertEqual(
        original_side_input_data.view_fn, deserialized_side_input_data.view_fn)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class BigQueryStreamingInsertTransformTests(unittest.TestCase):
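For context, the test above enables the new BigQuery sink with --experiments=use_beam_bq_sink, which is the code path these serialization hooks target. Below is a hedged usage sketch, not part of the commit; project, dataset, table, and bucket names are placeholders, and running it for real requires GCP credentials and a writable GCS temp location.

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions(['--experiments=use_beam_bq_sink'])
with beam.Pipeline(options=options) as p:
  (p
   | beam.Create([{'a': 'hello'}])
   | beam.io.WriteToBigQuery(
       'my-project:my_dataset.my_table',
       schema='a:STRING',
       custom_gcs_temp_location='gs://my-bucket/tmp'))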
