@@ -93,6 +93,9 @@ def __init__(self, cache_manager=None):
self._background_caching_pipeline_results = {}
self._cached_source_signature = {}
self._tracked_user_pipelines = set()
# Tracks the computation completeness of PCollections. PCollections tracked
# here don't need to be re-computed when data introspection is needed.
self._computed_pcolls = set()
# Always watch __main__ module.
self.watch('__main__')
# Log a warning if the current Python version is below 3.6.
@@ -278,3 +281,23 @@ def track_user_pipelines(self):
@property
def tracked_user_pipelines(self):
return self._tracked_user_pipelines

def mark_pcollection_computed(self, pcolls):
"""Marks computation completeness for the given pcolls.

Interactive Beam can use this information to determine if a computation is
needed to introspect the data of any given PCollection.
"""
self._computed_pcolls.update(pcolls)

def evict_computed_pcollections(self):
"""Evicts all computed PCollections.

Interactive Beam will treat none of the PCollections in any given pipeline
as completely computed.
"""
self._computed_pcolls = set()

@property
def computed_pcollections(self):
return self._computed_pcolls
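
For reference, a minimal sketch of how this completeness-tracking API can be exercised; the pipeline and PCollection below are hypothetical and not part of this change:

import apache_beam as beam
from apache_beam.runners.interactive import interactive_environment as ie

p = beam.Pipeline()
pcoll = p | 'Create' >> beam.Create([1, 2, 3])

env = ie.current_env()
# A successful interactive run would mark its PCollections as computed.
env.mark_pcollection_computed([pcoll])
assert pcoll in env.computed_pcollections

# Evicting resets the tracking, e.g. before a forced recomputation.
env.evict_computed_pcollections()
assert pcoll not in env.computed_pcollections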
@@ -55,7 +55,8 @@ def __init__(self,
cache_dir=None,
cache_format='text',
render_option=None,
skip_display=False):
skip_display=False,
Contributor:

I know you didn't add it in this PR, but in general it is easier to understand "positive" parameters: enable_display with a default of True is easier to understand than skip_display with a default of False (since skip_display set to False is a double negative).

Contributor Author:

This was added by an internal Googler, and their code in Google3 needs to set this value, so let's leave it be for now; we cannot make an atomic change that wouldn't break Google3 at some point. I'll contact the author once our changes have settled.

Contributor:

Ah gotcha, np then.

force_compute=True):
"""Constructor of InteractiveRunner.

Args:
@@ -68,6 +69,12 @@ def __init__(self,
skip_display: (bool) whether to skip display operations when running the
pipeline. Useful if running large pipelines when display is not
needed.
force_compute: (bool) whether to force the recomputation of PCollections
on sequential pipeline runs instead of reusing data cached by previous
runs (including show API invocations from the interactive_beam module).
If True, always runs the whole pipeline and recomputes the data of every
PCollection. If False, reuses available cached data and runs only the
minimum pipeline fragment needed to compute the data not yet available.
"""
self._underlying_runner = (underlying_runner
or direct_runner.DirectRunner())
@@ -79,6 +86,7 @@ def __init__(self,
self._render_option = render_option
self._in_session = False
self._skip_display = skip_display
self._force_compute = force_compute

def is_fnapi_compatible(self):
# TODO(BEAM-8436): return self._underlying_runner.is_fnapi_compatible()
@@ -127,6 +135,9 @@ def apply(self, transform, pvalueish, options):
return self._underlying_runner.apply(transform, pvalueish, options)

def run_pipeline(self, pipeline, options):
if self._force_compute:
ie.current_env().evict_computed_pcollections()

pipeline_instrument = inst.pin(pipeline, options)

# The user_pipeline analyzed might be None if the pipeline given has nothing
@@ -163,6 +174,11 @@ def run_pipeline(self, pipeline, options):
is_main_job=True)
main_job_result.wait_until_finish()

if main_job_result.state is beam.runners.runner.PipelineState.DONE:
# pylint: disable=dict-values-not-iterating
ie.current_env().mark_pcollection_computed(
pipeline_instrument.runner_pcoll_to_user_pcoll.values())

return main_job_result
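
For illustration, a hedged sketch of what the force_compute flag means for sequential runs; the pipeline below is hypothetical:

import apache_beam as beam
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

# force_compute=False lets a later p.run() reuse data of PCollections
# computed by earlier runs and execute only the missing fragment.
p = beam.Pipeline(InteractiveRunner(force_compute=False))
squares = (p
           | 'Create' >> beam.Create(range(5))
           | 'Square' >> beam.Map(lambda x: x * x))
p.run()  # Computes and caches `squares`.
p.run()  # May skip recomputing `squares` since it is marked computed.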


@@ -33,6 +33,14 @@
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import interactive_runner
from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython

# TODO(BEAM-8288): clean up the work-around of nose tests using Python2 without
# unittest.mock module.
try:
from unittest.mock import patch
except ImportError:
from mock import patch


def print_with_message(msg):
@@ -123,6 +131,31 @@ def __exit__(self, exc_type, exc_val, exc_tb):
runner.end_session()
self.assertFalse(underlying_runner._in_session)

@unittest.skipIf(not ie.current_env().is_interactive_ready,
'[interactive] dependency is not installed.')
@patch('IPython.get_ipython', new_callable=mock_get_ipython)
def test_mark_pcollection_completed_after_successful_run(self, cell):
with cell: # Cell 1
p = beam.Pipeline(interactive_runner.InteractiveRunner())
ib.watch({'p': p})

with cell: # Cell 2
# pylint: disable=range-builtin-not-iterating
init = p | 'Init' >> beam.Create(range(5))

with cell: # Cell 3
square = init | 'Square' >> beam.Map(lambda x: x * x)
cube = init | 'Cube' >> beam.Map(lambda x: x ** 3)

ib.watch(locals())
result = p.run()
self.assertTrue(init in ie.current_env().computed_pcollections)
self.assertEqual([0, 1, 2, 3, 4], result.get(init))
self.assertTrue(square in ie.current_env().computed_pcollections)
self.assertEqual([0, 1, 4, 9, 16], result.get(square))
self.assertTrue(cube in ie.current_env().computed_pcollections)
self.assertEqual([0, 1, 8, 27, 64], result.get(cube))


if __name__ == '__main__':
unittest.main()
sdks/python/apache_beam/runners/interactive/pipeline_fragment.py (new file: 225 additions, 0 deletions)
@@ -0,0 +1,225 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module to build pipeline fragment that produces given PCollections.

For internal use only; no backwards-compatibility guarantees.
"""
from __future__ import absolute_import

import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor


class PipelineFragment(object):
"""A fragment of a pipeline definition.

A pipeline fragment is built from the original pipeline definition to include
only PTransforms that are necessary to produce the given PCollections.
"""

def __init__(self, pcolls, options=None):
"""Constructor of PipelineFragment.

Args:
pcolls: (List[PCollection]) a list of PCollections to build pipeline
fragment for.
options: (PipelineOptions) the pipeline options for the implicit
pipeline run.
"""
assert len(pcolls) > 0, (
'Need at least 1 PCollection as the target data to build a pipeline '
'fragment that produces it.')
for pcoll in pcolls:
assert isinstance(pcoll, beam.pvalue.PCollection), (
'{} is not an apache_beam.pvalue.PCollection.'.format(pcoll))
# No modification to self._user_pipeline is allowed.
self._user_pipeline = pcolls[0].pipeline
# These are user PCollections. Do not use them to deduce anything that
# will be executed by any runner. Instead, use
# `self._runner_pcolls_to_user_pcolls.keys()` to get copied PCollections.
self._pcolls = set(pcolls)
for pcoll in self._pcolls:
assert pcoll.pipeline is self._user_pipeline, (
'{} belongs to a different user pipeline than other PCollections '
'given and cannot be used to build a pipeline fragment that produces '
'the given PCollections.'.format(pcoll))
self._options = options

# A copied pipeline instance for modification without changing the user
# pipeline instance held by the end user. This instance can be processed
# into a pipeline fragment that is later run by the underlying runner.
self._runner_pipeline = self._build_runner_pipeline()
_, self._context = self._runner_pipeline.to_runner_api(
return_context=True, use_fake_coders=True)
from apache_beam.runners.interactive import pipeline_instrument as instr
self._runner_pcoll_to_id = instr.pcolls_to_pcoll_id(
self._runner_pipeline, self._context)
# Correlate components in the runner pipeline to components in the user
# pipeline. The target pcolls are the pcolls given and defined in the user
# pipeline.
self._id_to_target_pcoll = self._calculate_target_pcoll_ids()
self._label_to_user_transform = self._calculate_user_transform_labels()
# Below will give us the 1:1 correlation between
# PCollections/AppliedPTransforms from the copied runner pipeline and
# PCollections/AppliedPTransforms from the user pipeline.
# (Dict[PCollection, PCollection])
(self._runner_pcolls_to_user_pcolls,
# (Dict[AppliedPTransform, AppliedPTransform])
self._runner_transforms_to_user_transforms
) = self._build_correlation_between_pipelines(
self._runner_pcoll_to_id,
self._id_to_target_pcoll,
self._label_to_user_transform)

# The operations below are performed on the runner pipeline.
(self._necessary_transforms,
self._necessary_pcollections) = self._mark_necessary_transforms_and_pcolls(
self._runner_pcolls_to_user_pcolls)
self._runner_pipeline = self._prune_runner_pipeline_to_fragment(
self._runner_pipeline,
self._necessary_transforms)

def deduce_fragment(self):
"""Deduce the pipeline fragment as an apache_beam.Pipeline instance."""
return beam.pipeline.Pipeline.from_runner_api(
self._runner_pipeline.to_runner_api(use_fake_coders=True),
self._runner_pipeline.runner,
self._options)

def run(self, display_pipeline_graph=False, use_cache=True):
"""Shorthand to run the pipeline fragment."""
try:
skip_pipeline_graph = self._runner_pipeline.runner._skip_display
force_compute = self._runner_pipeline.runner._force_compute
self._runner_pipeline.runner._skip_display = not display_pipeline_graph
self._runner_pipeline.runner._force_compute = not use_cache
return self.deduce_fragment().run()
finally:
self._runner_pipeline.runner._skip_display = skip_pipeline_graph
self._runner_pipeline.runner._force_compute = force_compute

def _build_runner_pipeline(self):
return beam.pipeline.Pipeline.from_runner_api(
self._user_pipeline.to_runner_api(use_fake_coders=True),
self._user_pipeline.runner,
self._options)

def _calculate_target_pcoll_ids(self):
pcoll_id_to_target_pcoll = {}
for pcoll in self._pcolls:
pcoll_id_to_target_pcoll[self._runner_pcoll_to_id.get(str(pcoll),
'')] = pcoll
return pcoll_id_to_target_pcoll

def _calculate_user_transform_labels(self):
label_to_user_transform = {}

class UserTransformVisitor(PipelineVisitor):

def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)

def visit_transform(self, transform_node):
if transform_node is not None:
label_to_user_transform[transform_node.full_label] = transform_node

v = UserTransformVisitor()
self._runner_pipeline.visit(v)
return label_to_user_transform

def _build_correlation_between_pipelines(self,
runner_pcoll_to_id,
id_to_target_pcoll,
label_to_user_transform):
runner_pcolls_to_user_pcolls = {}
runner_transforms_to_user_transforms = {}

class CorrelationVisitor(PipelineVisitor):

def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)

def visit_transform(self, transform_node):
self._process_transform(transform_node)
for in_pcoll in transform_node.inputs:
self._process_pcoll(in_pcoll)
for out_pcoll in transform_node.outputs.values():
self._process_pcoll(out_pcoll)

def _process_pcoll(self, pcoll):
pcoll_id = runner_pcoll_to_id.get(str(pcoll), '')
if pcoll_id in id_to_target_pcoll:
runner_pcolls_to_user_pcolls[pcoll] = (
id_to_target_pcoll[pcoll_id])

def _process_transform(self, transform_node):
if transform_node.full_label in label_to_user_transform:
runner_transforms_to_user_transforms[transform_node] = (
label_to_user_transform[transform_node.full_label])

v = CorrelationVisitor()
self._runner_pipeline.visit(v)
return runner_pcolls_to_user_pcolls, runner_transforms_to_user_transforms

def _mark_necessary_transforms_and_pcolls(self,
runner_pcolls_to_user_pcolls):
necessary_transforms = set()
all_inputs = set()
updated_all_inputs = set(runner_pcolls_to_user_pcolls.keys())
# Iterate until no new PCollections are recorded.
while len(updated_all_inputs) != len(all_inputs):
all_inputs = set(updated_all_inputs)
for pcoll in all_inputs:
producer = pcoll.producer
while producer:
if producer in necessary_transforms:
break
# Mark the AppliedPTransform as necessary.
necessary_transforms.add(producer)
# Record all necessary input and side input PCollections.
updated_all_inputs.update(producer.inputs)
# pylint: disable=map-builtin-not-iterating
side_input_pvalues = set(
map(lambda side_input: side_input.pvalue,
producer.side_inputs))
updated_all_inputs.update(side_input_pvalues)
# Go to its parent AppliedPTransform.
producer = producer.parent
return necessary_transforms, all_inputs

def _prune_runner_pipeline_to_fragment(self,
runner_pipeline,
necessary_transforms):

class PruneVisitor(PipelineVisitor):

def enter_composite_transform(self, transform_node):
pruned_parts = list(transform_node.parts)
for part in transform_node.parts:
if part not in necessary_transforms:
pruned_parts.remove(part)
transform_node.parts = tuple(pruned_parts)
self.visit_transform(transform_node)

def visit_transform(self, transform_node):
if transform_node not in necessary_transforms:
transform_node.parent = None

v = PruneVisitor()
runner_pipeline.visit(v)
return runner_pipeline
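
A hedged usage sketch for PipelineFragment (the pipeline here is hypothetical); only the transforms needed to produce the requested PCollections survive the pruning:

import apache_beam as beam
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
from apache_beam.runners.interactive.pipeline_fragment import PipelineFragment

p = beam.Pipeline(InteractiveRunner())
init = p | 'Init' >> beam.Create(range(5))
square = init | 'Square' >> beam.Map(lambda x: x * x)
cube = init | 'Cube' >> beam.Map(lambda x: x ** 3)

# Deducing a fragment for `square` keeps 'Init' and 'Square' but prunes
# 'Cube', since it is not needed to produce the requested PCollection.
fragment = PipelineFragment([square]).deduce_fragment()
fragment.run()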