Skip to content

Commit

Permalink
Enabled Event and CapacityLimiter to be instantiated outside an event…
Browse files Browse the repository at this point in the history
… loop (#651)
  • Loading branch information
agronholm committed Dec 14, 2023
1 parent 44ca5ea commit 28516e2
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``,
``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by
Lura Skye)
- Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an
event loop thread
- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing
to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help
from Egor Blagov)
Expand Down
4 changes: 3 additions & 1 deletion src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,9 @@ def set(self) -> None:


class CapacityLimiter(BaseCapacityLimiter):
def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter:
def __new__(
cls, *args: Any, original: trio.CapacityLimiter | None = None
) -> CapacityLimiter:
return object.__new__(cls)

def __init__(
Expand Down
135 changes: 133 additions & 2 deletions src/anyio/_core/_synchronization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import math
from collections import deque
from dataclasses import dataclass
from types import TracebackType

from sniffio import AsyncLibraryNotFoundError

from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
from ._eventloop import get_async_backend
from ._exceptions import BusyResourceError, WouldBlock
Expand Down Expand Up @@ -76,7 +79,10 @@ class SemaphoreStatistics:

class Event:
def __new__(cls) -> Event:
return get_async_backend().create_event()
try:
return get_async_backend().create_event()
except AsyncLibraryNotFoundError:
return EventAdapter()

def set(self) -> None:
"""Set the flag, notifying all listeners."""
Expand All @@ -101,6 +107,35 @@ def statistics(self) -> EventStatistics:
raise NotImplementedError


class EventAdapter(Event):
_internal_event: Event | None = None

def __new__(cls) -> EventAdapter:
return object.__new__(cls)

@property
def _event(self) -> Event:
if self._internal_event is None:
self._internal_event = get_async_backend().create_event()

return self._internal_event

def set(self) -> None:
self._event.set()

def is_set(self) -> bool:
return self._internal_event is not None and self._internal_event.is_set()

async def wait(self) -> None:
await self._event.wait()

def statistics(self) -> EventStatistics:
if self._internal_event is None:
return EventStatistics(tasks_waiting=0)

return self._internal_event.statistics()


class Lock:
_owner_task: TaskInfo | None = None

Expand Down Expand Up @@ -373,7 +408,10 @@ def statistics(self) -> SemaphoreStatistics:

class CapacityLimiter:
def __new__(cls, total_tokens: float) -> CapacityLimiter:
return get_async_backend().create_capacity_limiter(total_tokens)
try:
return get_async_backend().create_capacity_limiter(total_tokens)
except AsyncLibraryNotFoundError:
return CapacityLimiterAdapter(total_tokens)

async def __aenter__(self) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -482,6 +520,99 @@ def statistics(self) -> CapacityLimiterStatistics:
raise NotImplementedError


class CapacityLimiterAdapter(CapacityLimiter):
_internal_limiter: CapacityLimiter | None = None

def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter:
return object.__new__(cls)

def __init__(self, total_tokens: float) -> None:
self.total_tokens = total_tokens

@property
def _limiter(self) -> CapacityLimiter:
if self._internal_limiter is None:
self._internal_limiter = get_async_backend().create_capacity_limiter(
self._total_tokens
)

return self._internal_limiter

async def __aenter__(self) -> None:
await self._limiter.__aenter__()

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
return await self._limiter.__aexit__(exc_type, exc_val, exc_tb)

@property
def total_tokens(self) -> float:
if self._internal_limiter is None:
return self._total_tokens

return self._internal_limiter.total_tokens

@total_tokens.setter
def total_tokens(self, value: float) -> None:
if not isinstance(value, int) and value is not math.inf:
raise TypeError("total_tokens must be an int or math.inf")
elif value < 1:
raise ValueError("total_tokens must be >= 1")

if self._internal_limiter is None:
self._total_tokens = value
return

self._limiter.total_tokens = value

@property
def borrowed_tokens(self) -> int:
if self._internal_limiter is None:
return 0

return self._internal_limiter.borrowed_tokens

@property
def available_tokens(self) -> float:
if self._internal_limiter is None:
return self._total_tokens

return self._internal_limiter.available_tokens

def acquire_nowait(self) -> None:
self._limiter.acquire_nowait()

def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
self._limiter.acquire_on_behalf_of_nowait(borrower)

async def acquire(self) -> None:
await self._limiter.acquire()

async def acquire_on_behalf_of(self, borrower: object) -> None:
await self._limiter.acquire_on_behalf_of(borrower)

def release(self) -> None:
self._limiter.release()

def release_on_behalf_of(self, borrower: object) -> None:
self._limiter.release_on_behalf_of(borrower)

def statistics(self) -> CapacityLimiterStatistics:
if self._internal_limiter is None:
return CapacityLimiterStatistics(
borrowed_tokens=0,
total_tokens=self.total_tokens,
borrowers=(),
tasks_waiting=0,
)

return self._internal_limiter.statistics()


class ResourceGuard:
"""
A context manager for ensuring that a resource is only used by a single task at a
Expand Down
94 changes: 94 additions & 0 deletions tests/test_synchronization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from typing import Any

import pytest

Expand All @@ -13,6 +14,7 @@
WouldBlock,
create_task_group,
fail_after,
run,
to_thread,
wait_all_tasks_blocked,
)
Expand Down Expand Up @@ -141,6 +143,21 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

def test_instantiate_outside_event_loop(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
async def use_lock() -> None:
async with lock:
pass

lock = Lock()
statistics = lock.statistics()
assert not statistics.locked
assert statistics.owner is None
assert statistics.tasks_waiting == 0

run(use_lock, backend=anyio_backend_name, backend_options=anyio_backend_options)


class TestEvent:
async def test_event(self) -> None:
Expand Down Expand Up @@ -208,6 +225,21 @@ async def waiter() -> None:

assert event.statistics().tasks_waiting == 0

def test_instantiate_outside_event_loop(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
async def use_event() -> None:
event.set()
await event.wait()

event = Event()
assert not event.is_set()
assert event.statistics().tasks_waiting == 0

run(
use_event, backend=anyio_backend_name, backend_options=anyio_backend_options
)


class TestCondition:
async def test_contextmanager(self) -> None:
Expand Down Expand Up @@ -304,6 +336,22 @@ async def waiter() -> None:
assert not condition.statistics().lock_statistics.locked
assert condition.statistics().tasks_waiting == 0

def test_instantiate_outside_event_loop(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
async def use_condition() -> None:
async with condition:
pass

condition = Condition()
assert condition.statistics().tasks_waiting == 0

run(
use_condition,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)


class TestSemaphore:
async def test_contextmanager(self) -> None:
Expand Down Expand Up @@ -426,6 +474,22 @@ async def acquire() -> None:
task1.cancel()
await asyncio.wait_for(task2, 1)

def test_instantiate_outside_event_loop(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
async def use_semaphore() -> None:
async with semaphore:
pass

semaphore = Semaphore(1)
assert semaphore.statistics().tasks_waiting == 0

run(
use_semaphore,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)


class TestCapacityLimiter:
async def test_bad_init_type(self) -> None:
Expand Down Expand Up @@ -595,3 +659,33 @@ async def worker(entered_event: Event) -> None:

# Allow all tasks to exit
continue_event.set()

def test_instantiate_outside_event_loop(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
async def use_limiter() -> None:
async with limiter:
pass

limiter = CapacityLimiter(1)
limiter.total_tokens = 2

with pytest.raises(TypeError):
limiter.total_tokens = "2" # type: ignore[assignment]

with pytest.raises(TypeError):
limiter.total_tokens = 3.0

assert limiter.total_tokens == 2
assert limiter.borrowed_tokens == 0
statistics = limiter.statistics()
assert statistics.total_tokens == 2
assert statistics.borrowed_tokens == 0
assert statistics.borrowers == ()
assert statistics.tasks_waiting == 0

run(
use_limiter,
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)

0 comments on commit 28516e2

Please sign in to comment.