diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py index bfb1a7f11905..b243d20ff857 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py @@ -365,11 +365,17 @@ def set_cache_manager(self, cache_manager, pipeline): if self.get_cache_manager(pipeline) is cache_manager: # NOOP if setting to the same cache_manager. return + # Check if the pipeline is already tracked as a user pipeline before cleanup. + is_user_pipeline = self._tracked_user_pipelines.get_user_pipeline( + pipeline) is pipeline if self.get_cache_manager(pipeline): # Invoke cleanup routine when a new cache_manager is forcefully set and # current cache_manager is not None. self.cleanup(pipeline) self._cache_managers[str(id(pipeline))] = cache_manager + if is_user_pipeline: + # Re-track the user pipeline because the self.cleanup() call above evicts it. + self.add_user_pipeline(pipeline) def get_cache_manager(self, pipeline, create_if_absent=False): """Gets the cache manager held by current Interactive Environment for the @@ -468,8 +474,8 @@ def evict_recording_manager(self, pipeline): def describe_all_recordings(self): """Returns a description of the recording for all watched pipelnes.""" return { - self.pipeline_id_to_pipeline(pid): rm.describe() - for pid, rm in self._recording_managers.items() + rm.user_pipeline: rm.describe() + for rm in self._recording_managers.values() } def set_pipeline_result(self, pipeline, result): diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py index 53ee54ac8a35..4c7871c02bed 100644 --- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py +++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker.py @@ -25,6 +25,7 @@ """ import shutil +import threading from typing import Iterator from typing import Optional @@ -39,13 +40,16 @@ class UserPipelineTracker: derived pipelines. """ def __init__(self): + self._lock = threading.RLock() self._user_pipelines: dict[beam.Pipeline, list[beam.Pipeline]] = {} - self._derived_pipelines: dict[beam.Pipeline] = {} - self._pid_to_pipelines: dict[beam.Pipeline] = {} + self._derived_pipelines: dict[beam.Pipeline, beam.Pipeline] = {} + self._pid_to_pipelines: dict[str, beam.Pipeline] = {} def __iter__(self) -> Iterator[beam.Pipeline]: """Iterates through all the user pipelines.""" - for p in self._user_pipelines: + with self._lock: + pipelines = list(self._user_pipelines.keys()) + for p in pipelines: yield p def _key(self, pipeline: beam.Pipeline) -> str: @@ -57,45 +61,57 @@ def evict(self, pipeline: beam.Pipeline) -> None: Removes the given pipeline and derived pipelines if a user pipeline. Otherwise, removes the given derived pipeline. """ - user_pipeline = self.get_user_pipeline(pipeline) - if user_pipeline: - for d in self._user_pipelines[user_pipeline]: - del self._derived_pipelines[d] - del self._user_pipelines[user_pipeline] - elif pipeline in self._derived_pipelines: - del self._derived_pipelines[pipeline] + with self._lock: + if pipeline in self._user_pipelines: + for d in self._user_pipelines[pipeline]: + self._derived_pipelines.pop(d, None) + self._pid_to_pipelines.pop(self._key(d), None) + self._user_pipelines.pop(pipeline, None) + elif pipeline in self._derived_pipelines: + user_pipeline = self._derived_pipelines.pop(pipeline, None) + if user_pipeline in self._user_pipelines: + try: + self._user_pipelines[user_pipeline].remove(pipeline) + except ValueError: + pass + self._pid_to_pipelines.pop(self._key(pipeline), None) def clear(self) -> None: """Clears the tracker of all user and derived pipelines.""" # Remove all local_tempdir of created pipelines. - for p in self._pid_to_pipelines.values(): - shutil.rmtree(p.local_tempdir, ignore_errors=True) + with self._lock: + pipelines = list(self._pid_to_pipelines.values()) + self._user_pipelines.clear() + self._derived_pipelines.clear() + self._pid_to_pipelines.clear() - self._user_pipelines.clear() - self._derived_pipelines.clear() - self._pid_to_pipelines.clear() + for p in pipelines: + shutil.rmtree(p.local_tempdir, ignore_errors=True) def get_pipeline(self, pid: str) -> Optional[beam.Pipeline]: """Returns the pipeline corresponding to the given pipeline id.""" - return self._pid_to_pipelines.get(pid, None) + with self._lock: + return self._pid_to_pipelines.get(pid, None) def add_user_pipeline(self, p: beam.Pipeline) -> beam.Pipeline: """Adds a user pipeline with an empty set of derived pipelines.""" - self._memoize_pipieline(p) + with self._lock: + self._memoize_pipeline(p) - # Create a new node for the user pipeline if it doesn't exist already. - user_pipeline = self.get_user_pipeline(p) - if not user_pipeline: - user_pipeline = p - self._user_pipelines[p] = [] + # Create a new node for the user pipeline if it doesn't exist already. + user_pipeline = self.get_user_pipeline(p) + if not user_pipeline: + user_pipeline = p + self._user_pipelines[p] = [] - return user_pipeline + return user_pipeline - def _memoize_pipieline(self, p: beam.Pipeline) -> None: + def _memoize_pipeline(self, p: beam.Pipeline) -> None: """Memoizes the pid of the pipeline to the pipeline object.""" pid = self._key(p) - if pid not in self._pid_to_pipelines: - self._pid_to_pipelines[pid] = p + with self._lock: + if pid not in self._pid_to_pipelines: + self._pid_to_pipelines[pid] = p def add_derived_pipeline( self, maybe_user_pipeline: beam.Pipeline, @@ -119,20 +135,21 @@ def add_derived_pipeline( # Returns p. ut.get_user_pipeline(derived2) """ - self._memoize_pipieline(maybe_user_pipeline) - self._memoize_pipieline(derived_pipeline) + with self._lock: + self._memoize_pipeline(maybe_user_pipeline) + self._memoize_pipeline(derived_pipeline) - # Cannot add a derived pipeline twice. - assert derived_pipeline not in self._derived_pipelines + # Cannot add a derived pipeline twice. + assert derived_pipeline not in self._derived_pipelines - # Get the "true" user pipeline. This allows for the user to derive a - # pipeline from another derived pipeline, use both as arguments, and this - # method will still get the correct user pipeline. - user = self.add_user_pipeline(maybe_user_pipeline) + # Get the "true" user pipeline. This allows for the user to derive a + # pipeline from another derived pipeline, use both as arguments, and this + # method will still get the correct user pipeline. + user = self.add_user_pipeline(maybe_user_pipeline) - # Map the derived pipeline to the user pipeline. - self._derived_pipelines[derived_pipeline] = user - self._user_pipelines[user].append(derived_pipeline) + # Map the derived pipeline to the user pipeline. + self._derived_pipelines[derived_pipeline] = user + self._user_pipelines[user].append(derived_pipeline) def get_user_pipeline(self, p: beam.Pipeline) -> Optional[beam.Pipeline]: """Returns the user pipeline of the given pipeline. @@ -142,14 +159,14 @@ def get_user_pipeline(self, p: beam.Pipeline) -> Optional[beam.Pipeline]: returns the same pipeline. If the given pipeline is a derived pipeline then this returns the user pipeline. """ + with self._lock: + # If `p` is a user pipeline then return it. + if p in self._user_pipelines: + return p - # If `p` is a user pipeline then return it. - if p in self._user_pipelines: - return p - - # If `p` exists then return its user pipeline. - if p in self._derived_pipelines: - return self._derived_pipelines[p] + # If `p` exists then return its user pipeline. + if p in self._derived_pipelines: + return self._derived_pipelines[p] - # Otherwise, `p` is not in this tracker. - return None + # Otherwise, `p` is not in this tracker. + return None diff --git a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py index f7025b8b75bf..6fb8e4dbad99 100644 --- a/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py +++ b/sdks/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py @@ -15,7 +15,9 @@ # limitations under the License. # +import threading import unittest +from unittest.mock import patch import apache_beam as beam from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker @@ -202,6 +204,52 @@ def test_can_evict_user_pipeline(self): self.assertIs(user2, ut.get_user_pipeline(derived21)) self.assertIs(user2, ut.get_user_pipeline(derived22)) + def test_clear_race_condition(self): + ut = UserPipelineTracker() + # Add a pipeline so clear() has at least one element to iterate over. + p1 = beam.Pipeline() + derived1 = beam.Pipeline() + ut.add_derived_pipeline(p1, derived1) + + # Set by the mock when clear() enters its loop. Signals the background + # worker to mutate. + in_loop_event = threading.Event() + # Set by the worker when mutation is complete. Signals mock that it can + # safely resume clear(). + mutate_done_event = threading.Event() + + def mock_rmtree(path, ignore_errors=False): + # Signal the worker that clear() is iterating. + in_loop_event.set() + # Pause here to give the worker thread time to perform the mutation. + mutate_done_event.wait(timeout=5) + + def worker(): + # Wait for clear() to start iterating. + if in_loop_event.wait(timeout=5): + # Concurrently mutate the tracker dictionary. + p2 = beam.Pipeline() + derived2 = beam.Pipeline() + try: + ut.add_derived_pipeline(p2, derived2) + finally: + # Resume the main thread. + mutate_done_event.set() + + thread = threading.Thread(target=worker) + thread.start() + + try: + # Intercept shutil.rmtree inside clear() to orchestrate the concurrent + # mutation. + with patch('shutil.rmtree', side_effect=mock_rmtree): + ut.clear() + finally: + # Avoid hanging tests if events are missed. + in_loop_event.set() + mutate_done_event.set() + thread.join() + if __name__ == '__main__': unittest.main()