Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,29 @@

from __future__ import absolute_import

import apache_beam as beam
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.utils import timestamp
from apache_beam.utils.timestamp import Timestamp


class StreamingCacheReceiver(beam.transforms.PTransform):
  """Marks a PCollection to be read from cache.

  The PipelineInstrument uses this transform to flag that an unbounded
  PCollection should be read from cache. This is necessary because the
  TestStream needs to know all of the PCollections before it is created.
  """

  def expand(self, pbegin):
    # This transform can only be applied at the root of a pipeline.
    assert isinstance(pbegin, beam.pvalue.PBegin)

    # Remember the owning pipeline so later instrumentation stages can
    # wire the cached read into it.
    pipeline = pbegin.pipeline
    self.pipeline = pipeline

    # The output stands in for an unbounded stream replayed from cache.
    return beam.pvalue.PCollection(pipeline, is_bounded=False)

  def get_windowing(self, unused_inputs):
    # Cached elements are replayed into the global window.
    return beam.Windowing(beam.window.GlobalWindows())


class StreamingCache(object):
"""Abstraction that holds the logic for reading and writing to cache.
"""
Expand Down
58 changes: 52 additions & 6 deletions sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.caching import streaming_cache
from apache_beam.testing import test_stream

READ_CACHE = "_ReadCache_"
WRITE_CACHE = "_WriteCache_"
Expand Down Expand Up @@ -400,12 +402,15 @@ def _read_cache(self, pipeline, pcoll):
# Can only read from cache when the cache with expected key exists.
if self._cache_manager.exists('full', key):
if key not in self._cached_pcoll_read:
if pcoll.is_bounded:
read_transform = cache.ReadCache(self._cache_manager, key)
else:
read_transform = streaming_cache.StreamingCacheReceiver()

# Mutates the pipeline with cache read transform attached
# to root of the pipeline.
pcoll_from_cache = (
pipeline
| '{}{}'.format(READ_CACHE, key) >> cache.ReadCache(
self._cache_manager, key))
pcoll_from_cache = (pipeline
| '{}{}'.format(READ_CACHE, key) >> read_transform)
self._cached_pcoll_read[key] = pcoll_from_cache
# else: NOOP when cache doesn't exist, just compute the original graph.

Expand All @@ -418,6 +423,43 @@ def _replace_with_cached_inputs(self, pipeline):
cache, noop.
"""

# Find all cached unbounded PCollections.
class CacheableUnboundedPCollectionVisitor(PipelineVisitor):
def __init__(self, pin):
self._pin = pin
self.unbounded_pcolls = set()

def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)

def visit_transform(self, transform_node):
if transform_node.inputs:
for input_pcoll in transform_node.inputs:
key = self._pin.cache_key(input_pcoll)
if (key in self._pin._cached_pcoll_read and
not input_pcoll.is_bounded):
self.unbounded_pcolls.add(key)

v = CacheableUnboundedPCollectionVisitor(self)
pipeline.visit(v)

# The set of keys from the cached unbounded PCollections will be used as the
# output tags for the TestStream. This is to remember what cache-key is
# associated with which PCollection.
output_tags = v.unbounded_pcolls

# Take the PCollections that will be read from the TestStream and insert
# them back into the dictionary of cached PCollections. The next step will
# replace the downstream consumer of the non-cached PCollections with these
# PCollections.
if output_tags:
output_pcolls = pipeline | test_stream.TestStream(output_tags=output_tags)
if len(output_tags) == 1:
self._cached_pcoll_read[None] = output_pcolls
else:
self._cached_pcoll_read.update(output_pcolls)


class ReadCacheWireVisitor(PipelineVisitor):
"""Visitor wires cache read as inputs to replace corresponding original
input PCollections in pipeline.
Expand All @@ -433,10 +475,14 @@ def enter_composite_transform(self, transform_node):
def visit_transform(self, transform_node):
if transform_node.inputs:
input_list = list(transform_node.inputs)
for i in range(len(input_list)):
key = self._pin.cache_key(input_list[i])
for i, input_pcoll in enumerate(input_list):
key = self._pin.cache_key(input_pcoll)

# Replace the input pcollection with the cached pcollection (if it
# has been cached).
if key in self._pin._cached_pcoll_read:
input_list[i] = self._pin._cached_pcoll_read[key]
# Update the transform with its new inputs.
transform_node.inputs = tuple(input_list)

v = ReadCacheWireVisitor(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import pipeline_instrument as instr
from apache_beam.runners.interactive import interactive_runner
from apache_beam.runners.interactive.caching import streaming_cache
from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_equal
from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
from apache_beam.testing.test_stream import TestStream

# Work around nose tests using Python2 without unittest.mock module.
try:
Expand Down Expand Up @@ -186,10 +188,16 @@ def test_background_caching_pipeline_proto(self):

assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

def _example_pipeline(self, watch=True):
def _example_pipeline(self, watch=True, bounded=True):
p = beam.Pipeline(interactive_runner.InteractiveRunner())
# pylint: disable=range-builtin-not-iterating
init_pcoll = p | 'Init Create' >> beam.Create(range(10))
if bounded:
source = beam.Create(range(10))
else:
source = beam.io.ReadFromPubSub(
subscription='projects/fake-project/subscriptions/fake_sub')

init_pcoll = p | 'Init Source' >> source
second_pcoll = init_pcoll | 'Second' >> beam.Map(lambda x: x * x)
if watch:
ib.watch(locals())
Expand Down Expand Up @@ -292,6 +300,47 @@ def test_find_out_correct_user_pipeline(self):
pipeline_instrument = instr.pin(runner_pipeline)
self.assertIs(pipeline_instrument.user_pipeline, user_pipeline)

def test_instrument_example_unbounded_pipeline_to_read_cache(self):
  """Tests that instrumenting an unbounded pipeline wires in a cache read.

  Builds the expected post-instrumentation pipeline by hand (p_origin) and
  compares it against the pipeline produced by the instrumenter (p_copy).
  """
  p_origin, init_pcoll, second_pcoll = self._example_pipeline(watch=True,
                                                              bounded=False)
  p_copy, _, _ = self._example_pipeline(watch=False, bounded=False)

  # Mock as if cacheable PCollections are cached.
  init_pcoll_cache_key = 'init_pcoll_' + str(
      id(init_pcoll)) + '_' + str(id(init_pcoll.producer))
  self._mock_write_cache(init_pcoll, init_pcoll_cache_key)
  second_pcoll_cache_key = 'second_pcoll_' + str(
      id(second_pcoll)) + '_' + str(id(second_pcoll.producer))
  self._mock_write_cache(second_pcoll, second_pcoll_cache_key)
  # Force every cache-existence check to succeed so that instrumentation
  # takes the read-from-cache path for all cached PCollections.
  ie.current_env().cache_manager().exists = MagicMock(return_value=True)
  instr.pin(p_copy)

  # Add the caching transforms.
  key = '_ReadCache_' + init_pcoll_cache_key
  cached_init_pcoll = (p_origin
                       | key >> streaming_cache.StreamingCacheReceiver())
  # NOTE(review): cached_init_pcoll is immediately rebound below, so the
  # StreamingCacheReceiver result above is unused except for the transform
  # it adds to the graph — confirm both applications are intentional and
  # mirror what the instrumenter produces.
  cached_init_pcoll = p_origin | TestStream(output_tags=[key])

  # second_pcoll is never used as input and there is no need to read cache.

  class TestReadCacheWireVisitor(PipelineVisitor):
    """Replace init_pcoll with cached_init_pcoll for all occurring inputs."""

    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      # Swap any input reference to the original PCollection with the
      # cached replacement, preserving input order.
      if transform_node.inputs:
        input_list = list(transform_node.inputs)
        for i in range(len(input_list)):
          if input_list[i] == init_pcoll:
            input_list[i] = cached_init_pcoll
        transform_node.inputs = tuple(input_list)

  v = TestReadCacheWireVisitor()
  p_origin.visit(v)
  assert_pipeline_equal(self, p_origin, p_copy)


# Run the test suite when this module is executed directly.
if __name__ == '__main__':
  unittest.main()
7 changes: 4 additions & 3 deletions sdks/python/apache_beam/testing/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,14 @@ class TestStream(PTransform):
output.
"""

def __init__(self, coder=coders.FastPrimitivesCoder(), events=None):
def __init__(self, coder=coders.FastPrimitivesCoder(), events=None,
output_tags=None):
super(TestStream, self).__init__()
assert coder is not None
self.coder = coder
self.watermarks = {None: timestamp.MIN_TIMESTAMP}
self._events = [] if events is None else list(events)
self.output_tags = set()
self._events = list(events) if events else list()
self.output_tags = set(output_tags) if output_tags else set()

def get_windowing(self, unused_inputs):
return core.Windowing(window.GlobalWindows())
Expand Down