Skip to content

Commit

Permalink
Merge 5746232 into 4a66756
Browse files Browse the repository at this point in the history
  • Loading branch information
tlonny committed Dec 14, 2023
2 parents 4a66756 + 5746232 commit 9adeac7
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -39,3 +39,5 @@ nosetests.xml

.idea
.idea/
.venv/
.envrc
36 changes: 32 additions & 4 deletions dominate/dom_tag.py
Expand Up @@ -23,6 +23,9 @@
from collections import defaultdict, namedtuple
from functools import wraps
import threading
from asyncio import get_event_loop
from uuid import uuid4
from contextvars import ContextVar

try:
# Python 3
Expand All @@ -37,19 +40,44 @@
basestring = str
unicode = str


try:
import greenlet
except ImportError:
greenlet = None

# We want dominate to work in async contexts - however, the problem is
# when we bind a tag using "with", we set what is essentially a global variable.
# If we are processing multiple documents at the same time, one context
# can "overwrite" the "bound tag" of another - this can cause documents to
# sort of bleed into one another...

# The solution is to use a ContextVar - which provides async context local storage.
# We use this to store a unique ID for each async context. We then use thie ID to
# form the key (in _get_thread_context) that is used to index the _with_context defaultdict.
# The presense of this key ensures that each async context has its own stack and doesn't conflict.
async_context_id = ContextVar('async_context_id', default = None)

def _get_async_context_id():
if async_context_id.get() is None:
async_context_id.set(uuid4().hex)
return async_context_id.get()

def _get_thread_context():
context = [threading.current_thread()]
# Tag extra content information with a name to make sure
# a greenlet.getcurrent() == 1 doesn't get confused with a
# a _get_thread_context() == 1.
if greenlet:
context.append(greenlet.getcurrent())
return hash(tuple(context))

context.append(("greenlet", greenlet.getcurrent()))

try:
if get_event_loop().is_running():
# Only add this extra information if we are actually in a running event loop
context.append(("async", _get_async_context_id()))
# A runtime error is raised if there is no async loop...
except RuntimeError:
pass
return tuple(context)

class dom_tag(object):
is_single = False # Tag does not require matching end tag (ex. <hr/>)
Expand Down
75 changes: 75 additions & 0 deletions tests/test_dom_tag_async.py
@@ -0,0 +1,75 @@
from asyncio import gather, run, Semaphore
from dominate.dom_tag import async_context_id
from textwrap import dedent

from dominate import tags

# To simulate sleep without making the tests take a hella long time to complete
# lets use a pair of semaphores to explicitly control when our coroutines run.
# The order of execution will be marked as comments below:
def test_async_bleed():
async def tag_routine_1(sem_1, sem_2):
root = tags.div(id = 1) # [1]
with root: # [2]
sem_2.release() # [3]
await sem_1.acquire() # [4]
tags.div(id = 2) # [11]
return str(root) # [12]

async def tag_routine_2(sem_1, sem_2):
await sem_2.acquire() # [5]
root = tags.div(id = 3) # [6]
with root: # [7]
tags.div(id = 4) # [8]
sem_1.release() # [9]
return str(root) # [10]

async def merge():
sem_1 = Semaphore(0)
sem_2 = Semaphore(0)
return await gather(
tag_routine_1(sem_1, sem_2),
tag_routine_2(sem_1, sem_2)
)

# Set this test up for failure - pre-set the context to a non-None value.
# As it is already set, _get_async_context_id will not set it to a new, unique value
# and thus we won't be able to differentiate between the two contexts. This essentially simulates
# the behavior before our async fix was implemented (the bleed):
async_context_id.set(0)
tag_1, tag_2 = run(merge())

# This looks wrong - but its what we would expect if we don't
# properly handle async...
assert tag_1 == dedent("""\
<div id="1">
<div id="3">
<div id="4"></div>
</div>
<div id="2"></div>
</div>
""").strip()

assert tag_2 == dedent("""\
<div id="3">
<div id="4"></div>
</div>
""").strip()

# Okay, now lets do it right - lets clear the context. Now when each async function
# calls _get_async_context_id, it will get a unique value and we can differentiate.
async_context_id.set(None)
tag_1, tag_2 = run(merge())

# Ah, much better...
assert tag_1 == dedent("""\
<div id="1">
<div id="2"></div>
</div>
""").strip()

assert tag_2 == dedent("""\
<div id="3">
<div id="4"></div>
</div>
""").strip()

0 comments on commit 9adeac7

Please sign in to comment.