Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BEAM-10603] Add describe and cancel to RecordingManager #12703

Merged
merged 3 commits into from Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -109,7 +109,7 @@ def update(self, e):
self._count += 1

def is_triggered(self):
return self._count > self._max_count
return self._count >= self._max_count


class ProcessingTimeLimiter(ElementLimiter):
Expand Down
Expand Up @@ -28,7 +28,7 @@ class CaptureLimitersTest(unittest.TestCase):
def test_count_limiter(self):
limiter = CountLimiter(5)

for e in range(5):
for e in range(4):
limiter.update(e)

self.assertFalse(limiter.is_triggered())
Expand Down
Expand Up @@ -63,6 +63,13 @@ def __hash__(self):
self.pcoll,
self.producer_version))

def to_key(self):
return CacheKey(
self.var,
self.version,
self.producer_version,
str(id(self.pcoll.pipeline)))


# TODO: turn this into a dataclass object when we finally get off of Python2.
class CacheKey:
Expand Down
84 changes: 68 additions & 16 deletions sdks/python/apache_beam/runners/interactive/recording_manager.py
Expand Up @@ -63,6 +63,12 @@ def var(self):
"""Returns the variable named that defined this PCollection."""
return self._var

def cache_key(self):
rohdesamuel marked this conversation as resolved.
Show resolved Hide resolved
# type: () -> str

"""Returns the cache key for this stream."""
return self._cache_key

def display_id(self, suffix):
# type: (str) -> str

Expand Down Expand Up @@ -135,7 +141,8 @@ def __init__(
result, # type: beam.runner.PipelineResult
pipeline_instrument, # type: beam.runners.interactive.PipelineInstrument
max_n, # type: int
max_duration_secs # type: float
max_duration_secs, # type: float
start_time_for_test=None # type: int
):

self._user_pipeline = user_pipeline
Expand All @@ -154,7 +161,8 @@ def __init__(
max_duration_secs)
for pcoll in pcolls
}
self._start = time.time()

self._start = start_time_for_test if start_time_for_test else time.time()
self._duration_secs = max_duration_secs
self._set_computed = bcj.is_cache_complete(str(id(user_pipeline)))

Expand All @@ -180,7 +188,7 @@ def _mark_all_computed(self):
self._result.cancel()
self._result.wait_until_finish()

time.sleep(0.5)
time.sleep(0.1)

# Mark the PCollection as computed so that Interactive Beam wouldn't need to
# re-compute.
Expand Down Expand Up @@ -231,13 +239,25 @@ def wait_until_finish(self):
self._mark_computed.join()
return self._result.state

def describe(self):
# type: () -> dict[str, int]

"""Returns a dictionary describing the cache and recording."""
cache_manager = ie.current_env().get_cache_manager(self._user_pipeline)

size = sum(
cache_manager.size('full', s.cache_key())
for s in self._streams.values())
return {'size': size, 'start': self._start}


class RecordingManager:
"""Manages recordings of PCollections for a given pipeline."""
def __init__(self, user_pipeline):
# type: (beam.Pipeline, List[Limiter]) -> None
self.user_pipeline = user_pipeline
self._pipeline_instrument = pi.PipelineInstrument(self.user_pipeline)
# type: (beam.Pipeline) -> None

self.user_pipeline = user_pipeline # type: beam.Pipeline
self._recordings = set() # type: Set[Recording]

def _watch(self, pcolls):
# type: (List[beam.pvalue.PCollection]) -> None
Expand All @@ -258,16 +278,44 @@ def _watch(self, pcolls):
ie.current_env().watch(
{'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})

def clear(self, pcolls):
def _clear(self, pipeline_instrument):
# type: (List[beam.pvalue.PCollection]) -> None

"""Clears the cache of the given PCollections."""
"""Clears the recording of all non-source PCollections."""

cache_manager = ie.current_env().get_cache_manager(self.user_pipeline)
for pc in pcolls:
cache_key = self._pipeline_instrument.cache_key(pc)

# Only clear the PCollections that aren't being populated from the
# BackgroundCachingJob.
all_cached = set(
str(c.to_key()) for c in pipeline_instrument.cacheables.values())
rohdesamuel marked this conversation as resolved.
Show resolved Hide resolved
source_pcolls = getattr(cache_manager, 'capture_keys', set())
to_clear = all_cached - source_pcolls

for cache_key in to_clear:
cache_manager.clear('full', cache_key)

def cancel(self):
# type: (None) -> None

"""Cancels the current background recording job."""

bcj.attempt_to_cancel_background_caching_job(self.user_pipeline)

for r in self._recordings:
r.wait_until_finish()
self._recordings = set()

def describe(self):
# type: () -> dict[str, int]

"""Returns a dictionary describing the cache and recording."""

descriptions = [r.describe() for r in self._recordings]
size = sum(d['size'] for d in descriptions)
start = min(d['start'] for d in descriptions)
return {'size': size, 'start': start}

def record(self, pcolls, max_n, max_duration_secs):
# type: (List[beam.pvalue.PCollection], int, int) -> Recording

Expand All @@ -292,6 +340,7 @@ def record(self, pcolls, max_n, max_duration_secs):
# watch it. No validation is needed here because the watch logic can handle
# arbitrary variables.
self._watch(pcolls)
pipeline_instrument = pi.PipelineInstrument(self.user_pipeline)

# Attempt to run background caching job to record any sources.
if ie.current_env().is_in_ipython:
Expand All @@ -313,23 +362,26 @@ def record(self, pcolls, max_n, max_duration_secs):
if uncomputed_pcolls:
# Clear the cache of the given uncomputed PCollections because they are
# incomplete.
self.clear(uncomputed_pcolls)
self._clear(pipeline_instrument)

warnings.filterwarnings(
'ignore',
'options is deprecated since First stable release. References to '
'<pipeline>.options will not be supported',
category=DeprecationWarning)
result = pf.PipelineFragment(
list(uncomputed_pcolls), self.user_pipeline.options).run()
ie.current_env().set_pipeline_result(self.user_pipeline, result)
pf.PipelineFragment(list(uncomputed_pcolls),
self.user_pipeline.options).run()
result = ie.current_env().pipeline_result(self.user_pipeline)
else:
result = None

return Recording(
recording = Recording(
self.user_pipeline,
pcolls,
result,
self._pipeline_instrument,
pipeline_instrument,
max_n,
max_duration_secs)
self._recordings.add(recording)

return recording