From be2f92b688d0df597b939d7f9b0f4b308dc2b157 Mon Sep 17 00:00:00 2001 From: tlonny Date: Wed, 13 Dec 2023 20:45:03 +0000 Subject: [PATCH 1/2] Allow dominate to work in async contexts Using ContextVars to allow dominate to work within async contexts. Added unit tests to ensure code works as expected. --- dominate/dom_tag.py | 37 ++++++++++++++++-- tests/test_dom_tag_async.py | 75 +++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) create mode 100644 tests/test_dom_tag_async.py diff --git a/dominate/dom_tag.py b/dominate/dom_tag.py index 532d916..774357a 100644 --- a/dominate/dom_tag.py +++ b/dominate/dom_tag.py @@ -23,6 +23,8 @@ from collections import defaultdict, namedtuple from functools import wraps import threading +from asyncio import get_event_loop +from contextvars import ContextVar try: # Python 3 @@ -37,20 +39,49 @@ basestring = str unicode = str - try: import greenlet except ImportError: greenlet = None +# We want dominate to work in async contexts - however, the problem is +# when we bind a tag using "with", we set what is essentially a global variable. +# If we are processing multiple documents at the same time, one context +# can "overwrite" the "bound tag" of another - this can cause documents to +# sort of bleed into one another... + +# The solution is to use a ContextVar - which provides async context local storage. +# We use this to store a unique ID for each async context. We then use thie ID to +# form the key (in _get_thread_context) that is used to index the _with_context defaultdict. +# The presense of this key ensures that each async context has its own stack and doesn't conflict. +async_context_id_counter = 5 +async_context_id = ContextVar('async_context_id', default = None) + +def _get_async_context_id(): + global async_context_id_counter + if async_context_id.get() is None: + async_context_id.set(async_context_id_counter) + async_context_id_counter += 1 + print(async_context_id.get()) + return async_context_id.get() def _get_thread_context(): context = [threading.current_thread()] + # Tag extra content information with a name to make sure + # a greenlet.getcurrent() == 1 doesn't get confused with a + # a _get_thread_context() == 1. if greenlet: - context.append(greenlet.getcurrent()) + context.append(("greenlet", greenlet.getcurrent())) + + try: + if get_event_loop().is_running(): + # Only add this extra information if we are actually in a running event loop + context.append(("async", _get_async_context_id())) + # A runtime error is raised if there is no async loop... + except RuntimeError: + pass return hash(tuple(context)) - class dom_tag(object): is_single = False # Tag does not require matching end tag (ex.
) is_pretty = True # Text inside the tag should be left as-is (ex.
)
diff --git a/tests/test_dom_tag_async.py b/tests/test_dom_tag_async.py
new file mode 100644
index 0000000..648f9c8
--- /dev/null
+++ b/tests/test_dom_tag_async.py
@@ -0,0 +1,75 @@
+from asyncio import gather, run, Semaphore
+from dominate.dom_tag import async_context_id
+from textwrap import dedent
+
+from dominate import tags
+
+# To simulate sleep without making the tests take a hella long time to complete
+# lets use a pair of semaphores to explicitly control when our coroutines run.
+# The order of execution will be marked as comments below:
+def test_async_bleed():
+    async def tag_routine_1(sem_1, sem_2):
+        root = tags.div(id = 1) # [1]
+        with root: # [2]
+            sem_2.release() # [3]
+            await sem_1.acquire() # [4]
+            tags.div(id = 2) # [11]
+        return str(root) # [12]
+
+    async def tag_routine_2(sem_1, sem_2):
+        await sem_2.acquire() # [5]
+        root = tags.div(id = 3) # [6]
+        with root: # [7]
+            tags.div(id = 4) # [8]
+        sem_1.release() # [9]
+        return str(root) # [10]
+
+    async def merge():
+        sem_1 = Semaphore(0)
+        sem_2 = Semaphore(0)
+        return await gather(
+            tag_routine_1(sem_1, sem_2), 
+            tag_routine_2(sem_1, sem_2)
+        )
+
+    # Set this test up for failure - pre-set the context to a non-None value.
+    # As it is already set, _get_async_context_id will not set it to a new, unique value
+    # and thus we won't be able to differentiate between the two contexts. This essentially simulates
+    # the behavior before our async fix was implemented (the bleed):
+    async_context_id.set(0)
+    tag_1, tag_2 = run(merge())
+
+    # This looks wrong - but its what we would expect if we don't
+    # properly handle async...
+    assert tag_1 == dedent("""\
+        
+
+
+
+
+
+ """).strip() + + assert tag_2 == dedent("""\ +
+
+
+ """).strip() + + # Okay, now lets do it right - lets clear the context. Now when each async function + # calls _get_async_context_id, it will get a unique value and we can differentiate. + async_context_id.set(None) + tag_1, tag_2 = run(merge()) + + # Ah, much better... + assert tag_1 == dedent("""\ +
+
+
+ """).strip() + + assert tag_2 == dedent("""\ +
+
+
+ """).strip() From 5746232b1f704e705b7f068899f16791d7fd3e38 Mon Sep 17 00:00:00 2001 From: tlonny Date: Thu, 14 Dec 2023 21:19:20 +0000 Subject: [PATCH 2/2] Small Fixes - Added .venv and .envrc to .gitignore (I use direnv and venv to keep my python environments isolated - I hope this is okay!) - Removed print statements I left in dom_tag during debugging - Replaced global incrementing int with UUID for contextvar ID generation - this zeroes the risk of race-hazards/collisions - _get_thread_context now returns a tuple vs. a hash of a tuple. Functionally not much changes - the underlying dictionary will still use the same hashing function but the only difference is that _if_ there is a collision, the dictionary will still be able to return the correct element --- .gitignore | 2 ++ dominate/dom_tag.py | 9 +++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 218ba63..eb93365 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,5 @@ nosetests.xml .idea .idea/ +.venv/ +.envrc diff --git a/dominate/dom_tag.py b/dominate/dom_tag.py index 774357a..44ca8ba 100644 --- a/dominate/dom_tag.py +++ b/dominate/dom_tag.py @@ -24,6 +24,7 @@ from functools import wraps import threading from asyncio import get_event_loop +from uuid import uuid4 from contextvars import ContextVar try: @@ -54,15 +55,11 @@ # We use this to store a unique ID for each async context. We then use thie ID to # form the key (in _get_thread_context) that is used to index the _with_context defaultdict. # The presense of this key ensures that each async context has its own stack and doesn't conflict. -async_context_id_counter = 5 async_context_id = ContextVar('async_context_id', default = None) def _get_async_context_id(): - global async_context_id_counter if async_context_id.get() is None: - async_context_id.set(async_context_id_counter) - async_context_id_counter += 1 - print(async_context_id.get()) + async_context_id.set(uuid4().hex) return async_context_id.get() def _get_thread_context(): @@ -80,7 +77,7 @@ def _get_thread_context(): # A runtime error is raised if there is no async loop... except RuntimeError: pass - return hash(tuple(context)) + return tuple(context) class dom_tag(object): is_single = False # Tag does not require matching end tag (ex.
)