Skip to content
This repository has been archived by the owner on Feb 21, 2023. It is now read-only.

Commit

Permalink
Use ContextVar to track Task-local Lock token
Browse files Browse the repository at this point in the history
The Lock implementation in redis-py uses thread-local storage
so that multiple threads using the same Lock instance can
acquire the Lock from each other.

Thread-local storage ensures that each thread sees a different
token value.

Thread-local storage does not apply in the Task-based concurrency
that asyncio programs use. To achieve a similar effect, we need
to embed a ContextVar instance within each Lock and store the Lock
instance's token within the ContextVar instance. This allows every
Task that uses the same Lock instance to see a different token.

Thus, if both Task A and Task B refer to Lock 1, Task A can "acquire"
Lock 1 and block Task B from acquiring the same Lock until Task A
"releases" the Lock.

NOTE: The Python documentation suggests only storing ContextVar
instances in the top-level module scope due to issues around
garbage collection. That won't work in the current design of
Lock. For lack of a better alternative, and to preserve the
original design of Lock taken from redis-py, we have created
instances of ContextVar within instances of Lock.

Fixes #1040.
  • Loading branch information
Andrew Brookins committed Jul 29, 2021
1 parent b2952d9 commit 028da4f
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 75 deletions.
46 changes: 23 additions & 23 deletions aioredis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ def lock(
sleep: float = 0.1,
blocking_timeout: float = None,
lock_class: Type[Lock] = None,
thread_local=True,
task_local: bool = True,
) -> Lock:
"""
Return a new Lock object using key ``name`` that mimics
Expand All @@ -979,31 +979,31 @@ def lock(
``lock_class`` forces the specified lock implementation.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
``task_local`` indicates whether the lock token is placed in task-local
storage. By default, the token is placed in a contextvar so that a Task
only sees its token, not a token set by another Task. Consider the
following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
time: 0, task-1 acquires `my-lock`, with a timeout of 5 seconds.
task-1 sets the token to "abc"
time: 1, task-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
time: 5, task-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
time: 5, task-2 acquired `my-lock` now that it's available.
task-2 sets the token to "xyz"
time: 6, task-1 finishes its work and calls release(). if the
token is *not* stored in a contextvar, then task-1 would
see the token value as "xyz" and would be able to
successfully release the task-2's lock.
In some use cases it's necessary to disable task-local storage. For
example, if you have code where one Task acquires a lock and passes
that lock instance to another Task to release later. If task-local
storage isn't disabled in this case, the other Task won't see
the token set by the Task that acquired the lock. Our assumption
is that these cases aren't common and as such default to using
thread local storage."""
task-local storage."""
if lock_class is None:
lock_class = Lock
return lock_class(
Expand All @@ -1012,7 +1012,7 @@ def lock(
timeout=timeout,
sleep=sleep,
blocking_timeout=blocking_timeout,
thread_local=thread_local,
task_local=task_local,
)

def pubsub(self, **kwargs) -> "PubSub":
Expand Down
108 changes: 62 additions & 46 deletions aioredis/lock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import threading
import contextvars
import uuid
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, NoReturn, Union

from aioredis.exceptions import LockError, LockNotOwnedError
Expand All @@ -10,6 +9,24 @@
from aioredis import Redis


class SimpleToken:
    """
    Simple storage for a Lock that isn't using task-local storage.

    This class exposes the same ``get()``/``set()`` interface as the
    ContextVar that a Lock uses for task-local storage, so the Lock can
    treat both kinds of storage identically. Unlike a ContextVar, the
    value is shared by every Task that touches the same instance.
    """

    def __init__(self) -> None:
        # ``None`` is the value the Lock compares against (``is None``) to
        # decide that no token is held, so a fresh holder must start there
        # rather than at an empty string.
        self.token: Union[str, bytes, None] = None

    def get(self) -> Union[str, bytes, None]:
        """Return the stored token, or ``None`` when no token is held."""
        return self.token

    def set(self, token: Union[str, bytes, None]) -> None:
        """Store ``token``; pass ``None`` to mark the holder as empty."""
        self.token = token


class Lock:
"""
A shared, distributed Lock. Using Redis for locking allows the Lock
Expand Down Expand Up @@ -83,67 +100,63 @@ def __init__(
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: float = None,
thread_local: bool = True,
task_local: bool = True,
):
"""
Create a new Lock instance named ``name`` using the Redis client
supplied by ``redis``.
Return a new Lock object using key ``name`` that mimics
the behavior of threading.Lock.
``timeout`` indicates a maximum life for the lock.
If specified, ``timeout`` indicates a maximum life for the lock.
By default, it will remain locked until release() is called.
``timeout`` can be specified as a float or integer, both representing
the number of seconds to wait.
``sleep`` indicates the amount of time to sleep per loop iteration
when the lock is in blocking mode and another client is currently
holding the lock.
``blocking`` indicates whether calling ``acquire`` should block until
the lock has been acquired or to fail immediately, causing ``acquire``
to return False and the lock not being acquired. Defaults to True.
Note this value can be overridden by passing a ``blocking``
argument to ``acquire``.
``blocking_timeout`` indicates the maximum amount of time in seconds to
spend trying to acquire the lock. A value of ``None`` indicates
continue trying forever. ``blocking_timeout`` can be specified as a
float or integer, both representing the number of seconds to wait.
``thread_local`` indicates whether the lock token is placed in
thread-local storage. By default, the token is placed in thread local
storage so that a thread only sees its token, not a token set by
another thread. Consider the following timeline:
``lock_class`` forces the specified lock implementation.
``task_local`` indicates whether the lock token is placed in task-local
storage. By default, the token is placed in a contextvar so that a Task
only sees its token, not a token set by another Task. Consider the
following timeline:
time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds.
thread-1 sets the token to "abc"
time: 1, thread-2 blocks trying to acquire `my-lock` using the
time: 0, task-1 acquires `my-lock`, with a timeout of 5 seconds.
task-1 sets the token to "abc"
time: 1, task-2 blocks trying to acquire `my-lock` using the
Lock instance.
time: 5, thread-1 has not yet completed. redis expires the lock
time: 5, task-1 has not yet completed. redis expires the lock
key.
time: 5, thread-2 acquired `my-lock` now that it's available.
thread-2 sets the token to "xyz"
time: 6, thread-1 finishes its work and calls release(). if the
token is *not* stored in thread local storage, then
thread-1 would see the token value as "xyz" and would be
able to successfully release the thread-2's lock.
In some use cases it's necessary to disable thread local storage. For
example, if you have code where one thread acquires a lock and passes
that lock instance to a worker thread to release later. If thread
local storage isn't disabled in this case, the worker thread won't see
the token set by the thread that acquired the lock. Our assumption
time: 5, task-2 acquired `my-lock` now that it's available.
task-2 sets the token to "xyz"
time: 6, task-1 finishes its work and calls release(). if the
token is *not* stored in a contextvar, then task-1 would
see the token value as "xyz" and would be able to
successfully release the task-2's lock.
In some use cases it's necessary to disable task-local storage. For
example, if you have code where one Task acquires a lock and passes
that lock instance to another Task to release later. If task-local
storage isn't disabled in this case, the other Task won't see
the token set by the Task that acquired the lock. Our assumption
is that these cases aren't common and as such default to using
thread local storage.
task-local storage.
"""
self.redis = redis
self.name = name
self.timeout = timeout
self.sleep = sleep
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.thread_local = bool(thread_local)
self.local = threading.local() if self.thread_local else SimpleNamespace()
self.local.token = None
self.task_local = bool(task_local)
self.local_token = (
contextvars.ContextVar("lock") if self.task_local else SimpleToken()
)
self.local_token.set(None)
self.register_scripts()

def register_scripts(self):
Expand Down Expand Up @@ -201,7 +214,7 @@ async def acquire(
stop_trying_at = loop.time() + blocking_timeout
while True:
if await self.do_acquire(token):
self.local.token = token
self.local_token.set(token)
return True
if not blocking:
return False
Expand Down Expand Up @@ -236,14 +249,15 @@ async def owned(self) -> bool:
if stored_token and not isinstance(stored_token, bytes):
encoder = self.redis.connection_pool.get_encoder()
stored_token = encoder.encode(stored_token)
return self.local.token is not None and stored_token == self.local.token
our_token = self.local_token.get()
return our_token is not None and stored_token == our_token

def release(self) -> Awaitable[NoReturn]:
"""Releases the already acquired lock"""
expected_token = self.local.token
expected_token = self.local_token.get()
if expected_token is None:
raise LockError("Cannot release an unlocked lock")
self.local.token = None
self.local_token.set(None)
return self.do_release(expected_token)

async def do_release(self, expected_token: bytes):
Expand All @@ -267,18 +281,19 @@ def extend(
the lock's existing ttl. If True, replace the lock's ttl with
`additional_time`.
"""
if self.local.token is None:
if self.local_token.get() is None:
raise LockError("Cannot extend an unlocked lock")
if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
return self.do_extend(additional_time, replace_ttl)

async def do_extend(self, additional_time, replace_ttl) -> bool:
additional_time = int(additional_time * 1000)
token = self.local_token.get()
if not bool(
await self.lua_extend(
keys=[self.name],
args=[self.local.token, additional_time, replace_ttl and "1" or "0"],
args=[token, additional_time, replace_ttl and "1" or "0"],
client=self.redis,
)
):
Expand All @@ -289,17 +304,18 @@ def reacquire(self) -> Awaitable[bool]:
"""
Resets a TTL of an already acquired lock back to a timeout value.
"""
if self.local.token is None:
if self.local_token.get() is None:
raise LockError("Cannot reacquire an unlocked lock")
if self.timeout is None:
raise LockError("Cannot reacquire a lock with no timeout")
return self.do_reacquire()

async def do_reacquire(self) -> bool:
timeout = int(self.timeout * 1000)
token = self.local_token.get()
if not bool(
await self.lua_reacquire(
keys=[self.name], args=[self.local.token, timeout], client=self.redis
keys=[self.name], args=[token, timeout], client=self.redis
)
):
raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned")
Expand Down
12 changes: 6 additions & 6 deletions tests/test_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_lock(self, redis, *args, **kwargs):
async def test_lock(self, r):
lock = self.get_lock(r, "foo")
assert await lock.acquire(blocking=False)
assert await r.get("foo") == lock.local.token
assert await r.get("foo") == lock.local_token.get()
assert await r.ttl("foo") == -1
await lock.release()
assert await r.get("foo") is None
Expand All @@ -32,17 +32,17 @@ async def test_lock_token(self, r):
await self._test_lock_token(r, lock)

async def test_lock_token_thread_local_false(self, r):
lock = self.get_lock(r, "foo", thread_local=False)
lock = self.get_lock(r, "foo", task_local=False)
await self._test_lock_token(r, lock)

async def _test_lock_token(self, r, lock):
assert await lock.acquire(blocking=False, token="test")
assert await r.get("foo") == b"test"
assert lock.local.token == b"test"
assert lock.local_token.get() == b"test"
assert await r.ttl("foo") == -1
await lock.release()
assert await r.get("foo") is None
assert lock.local.token is None
assert lock.local_token.get() is None

async def test_locked(self, r):
lock = self.get_lock(r, "foo")
Expand Down Expand Up @@ -114,7 +114,7 @@ async def test_context_manager(self, r):
# blocking_timeout prevents a deadlock if the lock can't be acquired
# for some reason
async with self.get_lock(r, "foo", blocking_timeout=0.2) as lock:
assert await r.get("foo") == lock.local.token
assert await r.get("foo") == lock.local_token.get()
assert await r.get("foo") is None

async def test_context_manager_raises_when_locked_not_acquired(self, r):
Expand Down Expand Up @@ -149,7 +149,7 @@ async def test_releasing_lock_no_longer_owned_raises_error(self, r):
with pytest.raises(LockNotOwnedError):
await lock.release()
# even though we errored, the token is still cleared
assert lock.local.token is None
assert lock.local_token.get() is None

async def test_extend_lock(self, r):
lock = self.get_lock(r, "foo", timeout=10)
Expand Down

0 comments on commit 028da4f

Please sign in to comment.