Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions sdks/python/apache_beam/utils/subprocess_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class _SharedCache:
def __init__(self, constructor, destructor):
self._constructor = constructor
self._destructor = destructor
self._live_owners = set()
self._live_owners = {}
self._cache = {}
self._lock = threading.Lock()
self._counter = 0
Expand All @@ -82,10 +82,10 @@ def _next_id(self):
self._counter += 1
return self._counter

def register(self):
def register(self, is_context=False):
with self._lock:
owner = self._next_id()
self._live_owners.add(owner)
self._live_owners[owner] = is_context
return owner

def purge(self, owner):
Expand All @@ -97,7 +97,7 @@ def purge(self, owner):
"shutdown, the subprocess was already cleaned up earlier.",
owner)
return
self._live_owners.remove(owner)
del self._live_owners[owner]
for key, entry in list(self._cache.items()):
if owner in entry.owners:
entry.owners.remove(owner)
Expand All @@ -108,14 +108,22 @@ def purge(self, owner):
for value in to_delete:
self._destructor(value)

def get(self, *key):
def get(self, *key, owner=None):
if not self._live_owners:
raise RuntimeError("At least one owner must be registered.")
with self._lock:
if key not in self._cache:
self._cache[key] = _SharedCacheEntry(self._constructor(*key), set())
for owner in self._live_owners:
if owner is not None:
if owner not in self._live_owners:
raise RuntimeError("The requesting owner must be registered.")
self._cache[key].owners.add(owner)
for live_owner, is_context in self._live_owners.items():
if is_context:
self._cache[key].owners.add(live_owner)
else:
for live_owner in self._live_owners:
self._cache[key].owners.add(live_owner)
return self._cache[key].obj

def force_remove(self, *key):
Expand Down Expand Up @@ -180,7 +188,7 @@ def cache_subprocesses(cls):
These subprocesses may be shared with other contexts as well.
"""
try:
unique_id = cls._cache.register()
unique_id = cls._cache.register(is_context=True)
yield
finally:
cls._cache.purge(unique_id)
Expand Down Expand Up @@ -214,7 +222,7 @@ def start(self):
channel_ready = grpc.channel_ready_future(self._grpc_channel)
while True:
if process is not None and process.poll() is not None:
_LOGGER.error("Started job service with %s", process.args)
_LOGGER.error("Failed to start job service with %s", process.args)
raise RuntimeError(
'Service failed to start up with error %s' % process.poll())
try:
Expand All @@ -235,15 +243,16 @@ def start(self):
def start_process(self):
if self._owner_id is not None:
self._cache.purge(self._owner_id)
self._owner_id = self._cache.register()
return self._cache.get(tuple(self._cmd), self._port, self._logger)
self._owner_id = self._cache.register(is_context=False)
return self._cache.get(
tuple(self._cmd), self._port, self._logger, owner=self._owner_id)

def _really_start_process(cmd, port, logger):
if not port:
port, = pick_port(None)
cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable
endpoint = 'localhost:%s' % port
_LOGGER.info("Starting service with %s", str(cmd).replace("',", "'"))
_LOGGER.warning("Really starting service at %s with cmd: %s", endpoint, cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

Expand Down Expand Up @@ -295,9 +304,11 @@ def stop_force(self):
self._grpc_channel = None

def _really_stop_process(process_and_endpoint):
process, _ = process_and_endpoint # pylint: disable=unpacking-non-sequence
process, endpoint = process_and_endpoint # pylint: disable=unpacking-non-sequence
if not process:
return
_LOGGER.warning(
"Really destroying service at %s with cmd: %s", endpoint, process.args)
for _ in range(5):
if process.poll() is not None:
break
Expand Down
57 changes: 52 additions & 5 deletions sdks/python/apache_beam/utils/subprocess_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,16 @@ def mock_unregister(cb):
self.assertEqual(len(registered_callbacks), 1)

def test_concurrent_purge_race_condition(self):
# Concurrent threads attempting to check memebership and call purge for the same owner.
# Here we explicitly define a synchronized set to mimic the behavior of _live_owners.
# This set will block two threads on __contains__, allowing us to test the race condition.
# Two concurrent threads check membership and then call purge for the same owner.
# Here we explicitly define a synchronized dict to mimic the behavior of _live_owners.
# This dict will block two threads on __contains__, allowing us to test the race condition.
cache = subprocess_server._SharedCache(lambda x: "obj", lambda x: None)
owner = cache.register()

barrier = threading.Barrier(2)
exceptions = []

class SynchronizedSet(set):
class SynchronizedDict(dict):
def __contains__(self, item):
res = super().__contains__(item)
try:
Expand All @@ -421,7 +421,7 @@ def __contains__(self, item):
pass
return res

cache._live_owners = SynchronizedSet(cache._live_owners)
cache._live_owners = SynchronizedDict(cache._live_owners)

def purge_worker():
try:
Expand Down Expand Up @@ -551,6 +551,53 @@ def __init__(self):
# Clean up the other owner
cache.purge(other_owner)

def test_non_context_owners_do_not_share_keys(self):
  """Non-context owners own only the keys they requested themselves.

  Two owners registered with ``is_context=False`` each request a different
  key. Neither may be recorded as an owner of the other's key, and purging
  one owner must drop only that owner's key from the cache.
  """
  cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None)

  # owner1 is a non-context owner (e.g., prism).
  owner1 = cache.register(is_context=False)
  cache.get('a', owner=owner1)

  # owner2 is another non-context owner (e.g., a short-lived expansion
  # service).
  owner2 = cache.register(is_context=False)
  cache.get('b', owner=owner2)

  # Each requesting owner owns the key it asked for ...
  self.assertIn(owner1, cache._cache[('a', )].owners)
  self.assertIn(owner2, cache._cache[('b', )].owners)

  # ... and does not own the key requested by the other owner.
  self.assertNotIn(owner1, cache._cache[('b', )].owners)
  self.assertNotIn(owner2, cache._cache[('a', )].owners)

  # Purging owner2 should immediately destroy/remove 'b'.
  cache.purge(owner2)
  self.assertNotIn(('b', ), cache._cache)

  # 'a' is still alive because owner1 is still registered.
  self.assertIn(('a', ), cache._cache)

  # Purging owner1 should destroy/remove 'a'.
  cache.purge(owner1)
  self.assertNotIn(('a', ), cache._cache)

def test_context_owner_owns_all_keys(self):
  """A context owner is attached to keys requested by other owners.

  With one context owner and two non-context owners registered, a key
  requested by one non-context owner must be owned by the requester and by
  the context owner, but not by the unrelated non-context owner.
  """
  cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None)

  # owner1 is a non-context owner (e.g., prism).
  owner1 = cache.register(is_context=False)

  # owner2 is a context owner (e.g., cache_subprocesses).
  owner2 = cache.register(is_context=True)

  # owner3 is another non-context owner (e.g., a short-lived service).
  owner3 = cache.register(is_context=False)

  # owner3 requests 'b'.
  cache.get('b', owner=owner3)
  owners_of_b = cache._cache[('b', )].owners

  # The requesting owner owns 'b'.
  self.assertIn(owner3, owners_of_b)

  # owner2 (context) should own 'b'.
  self.assertIn(owner2, owners_of_b)

  # owner1 (non-context) should NOT own 'b'.
  self.assertNotIn(owner1, owners_of_b)


if __name__ == '__main__':
unittest.main()
Loading