Skip to content

Commit

Permalink
[BEAM-9291] Upload graph option in dataflow's python sdk (#10829)
Browse files (browse the repository at this point in the history)
[BEAM-9291] Upload graph option in dataflow's python sdk (#10829)
  • Loading branch information
stankiewicz committed Feb 18, 2020
1 parent 945b0bc commit b483ddb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
11 changes: 11 additions & 0 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,17 @@ def test_streaming_engine_flag_adds_windmill_experiments(self):
self.assertIn('enable_windmill_service', experiments_for_job)
self.assertIn('some_other_experiment', experiments_for_job)

def test_upload_graph_experiment(self):
    """Check that ``--experiment=upload_graph`` shows up on the job's options.

    A trivial pipeline is run with a ``DataflowRunner`` and the resulting
    job's ``DebugOptions.experiments`` must contain ``upload_graph``.
    """
    runner = DataflowRunner()
    self.default_properties.append('--experiment=upload_graph')
    options = PipelineOptions(self.default_properties)

    # Leaving the ``with`` block is what runs the pipeline on the runner.
    with Pipeline(runner, options) as p:
        p | ptransform.Create([1]) # pylint: disable=expression-not-assigned

    debug_options = runner.job.options.view_as(DebugOptions)
    self.assertIn('upload_graph', debug_options.experiments)

def test_dataflow_worker_jar_flag_non_fnapi_noop(self):
remote_runner = DataflowRunner()
self.default_properties.append('--experiment=some_other_experiment')
Expand Down
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,16 @@ def create_job(self, job):
self.stage_file(
gcs_or_local_path, file_name, io.BytesIO(job.json().encode('utf-8')))

if job.options.view_as(DebugOptions).lookup_experiment('upload_graph'):
self.stage_file(
job.options.view_as(GoogleCloudOptions).staging_location,
"dataflow_graph.json",
io.BytesIO(job.json().encode('utf-8')))
del job.proto.steps[:]
job.proto.stepsLocation = FileSystems.join(
job.options.view_as(GoogleCloudOptions).staging_location,
"dataflow_graph.json")

if not template_location:
return self.submit_job_description(job)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,33 @@ def test_get_response_encoding(self):

assert encoding == version_to_encoding[sys.version_info[0]]

@unittest.skipIf(apiclient is None, 'GCP dependencies are not installed')
def test_graph_is_uploaded(self):
    """Check that ``create_job`` stages ``dataflow_graph.json`` exactly once.

    The client's ``stage_file``, ``create_job_description`` and
    ``submit_job_description`` are replaced with mocks, ``create_job`` is
    invoked with the ``upload_graph`` experiment enabled, and the recorded
    calls on the first two mocks are verified.
    """
    flags = [
        '--project', 'test_project',
        '--job_name', 'test_job_name',
        '--temp_location', 'gs://test-location/temp',
        '--experiments', 'beam_fn_api',
        '--experiments', 'upload_graph',
    ]
    pipeline_options = PipelineOptions(flags)
    job = apiclient.Job(pipeline_options, FAKE_PIPELINE_URL)
    client = apiclient.DataflowApplicationClient(pipeline_options)

    def stub(method_name):
        # Patch a single client method with a mock for the ``with`` scope.
        return mock.patch.object(client, method_name, side_effect=None)

    with stub('stage_file') as stage_file_mock:
        with stub('create_job_description') as describe_mock:
            with stub('submit_job_description'):
                client.create_job(job)
                # Assertions stay inside the patches so the mocks are
                # still installed on the client while they are checked.
                stage_file_mock.assert_called_once_with(
                    mock.ANY, "dataflow_graph.json", mock.ANY)
                describe_mock.assert_called_once()


# Run this module's unittest test cases when it is executed as a script.
if __name__ == '__main__':
unittest.main()

0 comments on commit b483ddb

Please sign in to comment.