Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,17 @@ def set_cache_manager(self, cache_manager, pipeline):
if self.get_cache_manager(pipeline) is cache_manager:
# NOOP if setting to the same cache_manager.
return
# Check if the pipeline is already tracked as a user pipeline before cleanup.
is_user_pipeline = self._tracked_user_pipelines.get_user_pipeline(
pipeline) is pipeline
if self.get_cache_manager(pipeline):
# Invoke cleanup routine when a new cache_manager is forcefully set and
# current cache_manager is not None.
self.cleanup(pipeline)
self._cache_managers[str(id(pipeline))] = cache_manager
if is_user_pipeline:
# Re-track the user pipeline because the self.cleanup() call above evicts it.
self.add_user_pipeline(pipeline)

def get_cache_manager(self, pipeline, create_if_absent=False):
"""Gets the cache manager held by current Interactive Environment for the
Expand Down Expand Up @@ -468,8 +474,8 @@ def evict_recording_manager(self, pipeline):
def describe_all_recordings(self):
"""Returns a description of the recording for all watched pipelnes."""
return {
self.pipeline_id_to_pipeline(pid): rm.describe()
for pid, rm in self._recording_managers.items()
rm.user_pipeline: rm.describe()
for rm in self._recording_managers.values()
}

def set_pipeline_result(self, pipeline, result):
Expand Down
109 changes: 63 additions & 46 deletions sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""

import shutil
import threading
from typing import Iterator
from typing import Optional

Expand All @@ -39,13 +40,16 @@ class UserPipelineTracker:
derived pipelines.
"""
def __init__(self):
self._lock = threading.RLock()
self._user_pipelines: dict[beam.Pipeline, list[beam.Pipeline]] = {}
self._derived_pipelines: dict[beam.Pipeline] = {}
self._pid_to_pipelines: dict[beam.Pipeline] = {}
self._derived_pipelines: dict[beam.Pipeline, beam.Pipeline] = {}
self._pid_to_pipelines: dict[str, beam.Pipeline] = {}

def __iter__(self) -> Iterator[beam.Pipeline]:
"""Iterates through all the user pipelines."""
for p in self._user_pipelines:
with self._lock:
pipelines = list(self._user_pipelines.keys())
for p in pipelines:
yield p

def _key(self, pipeline: beam.Pipeline) -> str:
Expand All @@ -57,45 +61,57 @@ def evict(self, pipeline: beam.Pipeline) -> None:
Removes the given pipeline and derived pipelines if a user pipeline.
Otherwise, removes the given derived pipeline.
"""
user_pipeline = self.get_user_pipeline(pipeline)
if user_pipeline:
for d in self._user_pipelines[user_pipeline]:
del self._derived_pipelines[d]
del self._user_pipelines[user_pipeline]
elif pipeline in self._derived_pipelines:
del self._derived_pipelines[pipeline]
with self._lock:
if pipeline in self._user_pipelines:
for d in self._user_pipelines[pipeline]:
self._derived_pipelines.pop(d, None)
self._pid_to_pipelines.pop(self._key(d), None)
self._user_pipelines.pop(pipeline, None)
elif pipeline in self._derived_pipelines:
user_pipeline = self._derived_pipelines.pop(pipeline, None)
if user_pipeline in self._user_pipelines:
try:
self._user_pipelines[user_pipeline].remove(pipeline)
except ValueError:
pass
self._pid_to_pipelines.pop(self._key(pipeline), None)

def clear(self) -> None:
"""Clears the tracker of all user and derived pipelines."""
# Remove all local_tempdir of created pipelines.
for p in self._pid_to_pipelines.values():
shutil.rmtree(p.local_tempdir, ignore_errors=True)
with self._lock:
pipelines = list(self._pid_to_pipelines.values())
self._user_pipelines.clear()
self._derived_pipelines.clear()
self._pid_to_pipelines.clear()
Comment thread
shunping marked this conversation as resolved.

self._user_pipelines.clear()
self._derived_pipelines.clear()
self._pid_to_pipelines.clear()
for p in pipelines:
shutil.rmtree(p.local_tempdir, ignore_errors=True)

def get_pipeline(self, pid: str) -> Optional[beam.Pipeline]:
"""Returns the pipeline corresponding to the given pipeline id."""
return self._pid_to_pipelines.get(pid, None)
with self._lock:
return self._pid_to_pipelines.get(pid, None)

def add_user_pipeline(self, p: beam.Pipeline) -> beam.Pipeline:
"""Adds a user pipeline with an empty set of derived pipelines."""
self._memoize_pipieline(p)
with self._lock:
self._memoize_pipeline(p)

# Create a new node for the user pipeline if it doesn't exist already.
user_pipeline = self.get_user_pipeline(p)
if not user_pipeline:
user_pipeline = p
self._user_pipelines[p] = []
# Create a new node for the user pipeline if it doesn't exist already.
user_pipeline = self.get_user_pipeline(p)
if not user_pipeline:
user_pipeline = p
self._user_pipelines[p] = []

return user_pipeline
return user_pipeline

def _memoize_pipieline(self, p: beam.Pipeline) -> None:
def _memoize_pipeline(self, p: beam.Pipeline) -> None:
"""Memoizes the pid of the pipeline to the pipeline object."""
pid = self._key(p)
if pid not in self._pid_to_pipelines:
self._pid_to_pipelines[pid] = p
with self._lock:
if pid not in self._pid_to_pipelines:
self._pid_to_pipelines[pid] = p

def add_derived_pipeline(
self, maybe_user_pipeline: beam.Pipeline,
Expand All @@ -119,20 +135,21 @@ def add_derived_pipeline(
# Returns p.
ut.get_user_pipeline(derived2)
"""
self._memoize_pipieline(maybe_user_pipeline)
self._memoize_pipieline(derived_pipeline)
with self._lock:
self._memoize_pipeline(maybe_user_pipeline)
self._memoize_pipeline(derived_pipeline)

# Cannot add a derived pipeline twice.
assert derived_pipeline not in self._derived_pipelines
# Cannot add a derived pipeline twice.
assert derived_pipeline not in self._derived_pipelines

# Get the "true" user pipeline. This allows for the user to derive a
# pipeline from another derived pipeline, use both as arguments, and this
# method will still get the correct user pipeline.
user = self.add_user_pipeline(maybe_user_pipeline)
# Get the "true" user pipeline. This allows for the user to derive a
# pipeline from another derived pipeline, use both as arguments, and this
# method will still get the correct user pipeline.
user = self.add_user_pipeline(maybe_user_pipeline)

# Map the derived pipeline to the user pipeline.
self._derived_pipelines[derived_pipeline] = user
self._user_pipelines[user].append(derived_pipeline)
# Map the derived pipeline to the user pipeline.
self._derived_pipelines[derived_pipeline] = user
self._user_pipelines[user].append(derived_pipeline)

def get_user_pipeline(self, p: beam.Pipeline) -> Optional[beam.Pipeline]:
"""Returns the user pipeline of the given pipeline.
Expand All @@ -142,14 +159,14 @@ def get_user_pipeline(self, p: beam.Pipeline) -> Optional[beam.Pipeline]:
returns the same pipeline. If the given pipeline is a derived pipeline then
this returns the user pipeline.
"""
with self._lock:
# If `p` is a user pipeline then return it.
if p in self._user_pipelines:
return p

# If `p` is a user pipeline then return it.
if p in self._user_pipelines:
return p

# If `p` exists then return its user pipeline.
if p in self._derived_pipelines:
return self._derived_pipelines[p]
# If `p` exists then return its user pipeline.
if p in self._derived_pipelines:
return self._derived_pipelines[p]

# Otherwise, `p` is not in this tracker.
return None
# Otherwise, `p` is not in this tracker.
return None
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# limitations under the License.
#

import threading
import unittest
from unittest.mock import patch

import apache_beam as beam
from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker
Expand Down Expand Up @@ -202,6 +204,52 @@ def test_can_evict_user_pipeline(self):
self.assertIs(user2, ut.get_user_pipeline(derived21))
self.assertIs(user2, ut.get_user_pipeline(derived22))

def test_clear_race_condition(self):
ut = UserPipelineTracker()
# Add a pipeline so clear() has at least one element to iterate over.
p1 = beam.Pipeline()
derived1 = beam.Pipeline()
ut.add_derived_pipeline(p1, derived1)

# Set by the mock when clear() enters its loop. Signals the background
# worker to mutate.
in_loop_event = threading.Event()
# Set by the worker when mutation is complete. Signals mock that it can
# safely resume clear().
mutate_done_event = threading.Event()

def mock_rmtree(path, ignore_errors=False):
# Signal the worker that clear() is iterating.
in_loop_event.set()
# Pause here to give the worker thread time to perform the mutation.
mutate_done_event.wait(timeout=5)

def worker():
# Wait for clear() to start iterating.
if in_loop_event.wait(timeout=5):
# Concurrently mutate the tracker dictionary.
p2 = beam.Pipeline()
derived2 = beam.Pipeline()
try:
ut.add_derived_pipeline(p2, derived2)
finally:
# Resume the main thread.
mutate_done_event.set()

thread = threading.Thread(target=worker)
thread.start()

try:
# Intercept shutil.rmtree inside clear() to orchestrate the concurrent
# mutation.
with patch('shutil.rmtree', side_effect=mock_rmtree):
ut.clear()
finally:
# Avoid hanging tests if events are missed.
in_loop_event.set()
mutate_done_event.set()
thread.join()


if __name__ == '__main__':
unittest.main()
Loading