Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ def instrument(self):
self._pipeline
"""
cacheable_inputs = set()
all_inputs = set()
all_outputs = set()
unbounded_source_pcolls = set()

class InstrumentVisitor(PipelineVisitor):
Expand All @@ -418,10 +420,16 @@ def visit_transform(self, transform_node):
tuple(ie.current_env().options.capturable_sources)):
unbounded_source_pcolls.update(transform_node.outputs.values())
cacheable_inputs.update(self._pin._cacheable_inputs(transform_node))
ins, outs = self._pin._all_inputs_outputs(transform_node)
all_inputs.update(ins)
all_outputs.update(outs)

v = InstrumentVisitor(self)
self._pipeline.visit(v)

# Every output PCollection that is never used as an input PCollection is
# considered as a side effect of the pipeline run and should be included.
self._extended_targets.update(all_outputs.difference(all_inputs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to track or mark side effects differently? Do users want to specifically track these PCollections?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not necessary. The intended behavior is not ambiguous: when the user uses the show, head, or collect APIs, these PCollections are excluded completely, as the user explicitly wishes. And when the user invokes p.run(), all transforms in the pipeline should be executed as expected.

This change is only to make sure that the prune logic doesn't affect the above intended behavior.

# Add the unbounded source pcollections to the cacheable inputs. This allows
# for the caching of unbounded sources without a variable reference.
cacheable_inputs.update(unbounded_source_pcolls)
Expand Down Expand Up @@ -720,6 +728,15 @@ def _cacheable_inputs(self, transform):
inputs.add(in_pcoll)
return inputs

def _all_inputs_outputs(self, transform):
  """Collects every input and output PCollection of the given transform.

  Args:
    transform: a transform node exposing an iterable ``inputs`` attribute
      and a mapping-like ``outputs`` attribute.

  Returns:
    A ``(inputs, outputs)`` tuple of two newly created sets: the elements of
    ``transform.inputs`` and the values of ``transform.outputs``.
  """
  # Only the output values matter here; the keys of the outputs mapping are
  # not needed, so build both sets directly instead of looping over items().
  return set(transform.inputs), set(transform.outputs.values())

def _cacheable_key(self, pcoll):
"""Gets the key a cacheable PCollection is tracked within the instrument."""
return cacheable_key(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,18 @@ def test_pipeline_pruned_when_input_pcoll_is_cached(self):
assert_pipeline_proto_contain_top_level_transform(
self, full_proto, 'Init Source')

def test_side_effect_pcoll_is_included(self):
  # A pipeline whose only transform output is never consumed and never
  # bound to a name must still yield extended targets after instrumenting.
  runner = interactive_runner.InteractiveRunner()
  side_effect_pipeline = beam.Pipeline(runner)
  # The result of applying the transform is intentionally discarded so the
  # transform acts purely as a "side effect". Nothing defined locally in
  # this pipeline is watched either.
  # pylint: disable=range-builtin-not-iterating,expression-not-assigned
  side_effect_pipeline | 'Init Create' >> beam.Create(range(10))
  instrument = instr.build_pipeline_instrument(side_effect_pipeline)
  self.assertTrue(instrument._extended_targets)


# Run the unittest runner only when this module is executed as a script.
if __name__ == '__main__':
  unittest.main()