diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index 130b473300f3c..174edec2848ed 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -271,6 +271,8 @@ def compute_table_name(row):
 from apache_beam.transforms import ParDo
 from apache_beam.transforms import PTransform
 from apache_beam.transforms.display import DisplayDataItem
+from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
+from apache_beam.transforms.sideinputs import get_sideinput_index
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.utils import retry
 from apache_beam.utils.annotations import deprecated
@@ -1396,6 +1398,9 @@ def __init__(
       and
       https://cloud.google.com/bigquery/docs/loading-data-cloud-storage-json.
     """
+    self._table = table
+    self._dataset = dataset
+    self._project = project
     self.table_reference = bigquery_tools.parse_table_reference(
         table, dataset, project)
     self.create_disposition = BigQueryDisposition.validate_create(
@@ -1523,6 +1528,73 @@ def display_data(self):
       res['table'] = DisplayDataItem(tableSpec, label='Table')
     return res
 
+  def to_runner_api_parameter(self, context):
+    from apache_beam.internal import pickler
+
+    # It'd be nice to name these according to their actual
+    # names/positions in the original argument list, but such a
+    # transformation is currently irreversible given how
+    # remove_objects_from_args and insert_values_in_args
+    # are currently implemented.
+    def serialize(side_inputs):
+      return {(SIDE_INPUT_PREFIX + '%s') % ix:
+              si.to_runner_api(context).SerializeToString()
+              for ix,
+              si in enumerate(side_inputs)}
+
+    table_side_inputs = serialize(self.table_side_inputs)
+    schema_side_inputs = serialize(self.schema_side_inputs)
+
+    config = {
+        'table': self._table,
+        'dataset': self._dataset,
+        'project': self._project,
+        'schema': self.schema,
+        'create_disposition': self.create_disposition,
+        'write_disposition': self.write_disposition,
+        'kms_key': self.kms_key,
+        'batch_size': self.batch_size,
+        'max_file_size': self.max_file_size,
+        'max_files_per_bundle': self.max_files_per_bundle,
+        'custom_gcs_temp_location': self.custom_gcs_temp_location,
+        'method': self.method,
+        'insert_retry_strategy': self.insert_retry_strategy,
+        'additional_bq_parameters': self.additional_bq_parameters,
+        'table_side_inputs': table_side_inputs,
+        'schema_side_inputs': schema_side_inputs,
+        'triggering_frequency': self.triggering_frequency,
+        'validate': self._validate,
+        'temp_file_format': self._temp_file_format,
+    }
+    return 'beam:transform:write_to_big_query:v0', pickler.dumps(config)
+
+  @PTransform.register_urn('beam:transform:write_to_big_query:v0', bytes)
+  def from_runner_api(unused_ptransform, payload, context):
+    from apache_beam.internal import pickler
+    from apache_beam.portability.api.beam_runner_api_pb2 import SideInput
+
+    config = pickler.loads(payload)
+
+    def deserialize(side_inputs):
+      deserialized_side_inputs = {}
+      for k, v in side_inputs.items():
+        side_input = SideInput()
+        side_input.ParseFromString(v)
+        deserialized_side_inputs[k] = side_input
+
+      # This is an ordered list stored as a dict (see the comments in
+      # to_runner_api_parameter above).
+      indexed_side_inputs = [(
+          get_sideinput_index(tag),
+          pvalue.AsSideInput.from_runner_api(si, context)) for tag,
+          si in deserialized_side_inputs.items()]
+      return [si for _, si in sorted(indexed_side_inputs)]
+
+    config['table_side_inputs'] = deserialize(config['table_side_inputs'])
+    config['schema_side_inputs'] = deserialize(config['schema_side_inputs'])
+
+    return WriteToBigQuery(**config)
+
 
 class _PassThroughThenCleanup(PTransform):
   """A PTransform that invokes a DoFn after the input PCollection has been
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index 8c2bfe8f0d766..5c05978695170 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -584,6 +584,73 @@ def test_schema_autodetect_not_allowed_with_avro_file_loads(self):
             schema=beam.io.gcp.bigquery.SCHEMA_AUTODETECT,
             temp_file_format=bigquery_tools.FileFormat.AVRO))
 
+  def test_to_from_runner_api(self):
+    """Tests that serialization of WriteToBigQuery is correct.
+
+    This is not intended to be a change-detector test. As such, this only tests
+    the more complicated serialization logic of parameters: ValueProviders,
+    callables, and side inputs.
+    """
+    FULL_OUTPUT_TABLE = 'test_project:output_table'
+
+    p = TestPipeline(
+        additional_pipeline_args=["--experiments=use_beam_bq_sink"])
+
+    # Used for testing side input parameters.
+    table_record_pcv = beam.pvalue.AsDict(
+        p | "MakeTable" >> beam.Create([('table', FULL_OUTPUT_TABLE)]))
+
+    # Used for testing value provider parameters.
+    schema = value_provider.StaticValueProvider(str, '"a:str"')
+
+    original = WriteToBigQuery(
+        table=lambda _,
+        side_input: side_input['table'],
+        table_side_inputs=(table_record_pcv, ),
+        schema=schema)
+
+    # pylint: disable=expression-not-assigned
+    p | 'MyWriteToBigQuery' >> original
+
+    # Run the pipeline through to generate a pipeline proto from an empty
+    # context. This ensures that the serialization code ran.
+    pipeline_proto, context = TestPipeline.from_runner_api(
+        p.to_runner_api(), p.runner, p.get_pipeline_options()).to_runner_api(
+            return_context=True)
+
+    # Find the transform from the context.
+    write_to_bq_id = [
+        k for k,
+        v in pipeline_proto.components.transforms.items()
+        if v.unique_name == 'MyWriteToBigQuery'
+    ][0]
+    deserialized_node = context.transforms.get_by_id(write_to_bq_id)
+    deserialized = deserialized_node.transform
+    self.assertIsInstance(deserialized, WriteToBigQuery)
+
+    # Test that the serialization of a value provider is correct.
+    self.assertEqual(original.schema, deserialized.schema)
+
+    # Test that the serialization of a callable is correct.
+    self.assertEqual(
+        deserialized._table(None, {'table': FULL_OUTPUT_TABLE}),
+        FULL_OUTPUT_TABLE)
+
+    # Test that the serialization of a side input is correct.
+    self.assertEqual(
+        len(original.table_side_inputs), len(deserialized.table_side_inputs))
+    original_side_input_data = original.table_side_inputs[0]._side_input_data()
+    deserialized_side_input_data = deserialized.table_side_inputs[
+        0]._side_input_data()
+    self.assertEqual(
+        original_side_input_data.access_pattern,
+        deserialized_side_input_data.access_pattern)
+    self.assertEqual(
+        original_side_input_data.window_mapping_fn,
+        deserialized_side_input_data.window_mapping_fn)
+    self.assertEqual(
+        original_side_input_data.view_fn, deserialized_side_input_data.view_fn)
+
 
 @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
 class BigQueryStreamingInsertTransformTests(unittest.TestCase):
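
A minimal sketch (not part of the patch) of the round-trip the new methods enable, calling to_runner_api_parameter and the URN-registered from_runner_api directly. The table spec and schema string are illustrative values, and an empty PipelineContext is assumed to suffice because no table or schema side inputs are attached.

# Sketch only: exercises the serialization added above outside a pipeline.
from apache_beam.io.gcp.bigquery import WriteToBigQuery
from apache_beam.runners.pipeline_context import PipelineContext

context = PipelineContext()
original = WriteToBigQuery(
    table='some_project:some_dataset.some_table',  # illustrative value
    schema='column:STRING')  # illustrative value

# to_runner_api_parameter pickles the constructor configuration under a
# fixed URN...
urn, payload = original.to_runner_api_parameter(context)
assert urn == 'beam:transform:write_to_big_query:v0'

# ...and the function registered for that URN rebuilds an equivalent
# transform from the payload.
rebuilt = WriteToBigQuery.from_runner_api(None, payload, context)
assert isinstance(rebuilt, WriteToBigQuery)
assert rebuilt._table == original._table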