Skip to content

Commit

Permalink
[CHIA-710] Add the concept of 'action scopes' (#18124)
Browse files Browse the repository at this point in the history
* Add the concept of 'action scopes'

* pylint and test coverage

* add try/finally

* Address comments by @altendky

* Address comments by @altendky

* 86 memos

* Only one callback

* pylint

* Address more comments by @altendky

* remove unused variable

* add comment
  • Loading branch information
Quexington committed Jun 17, 2024
1 parent a6fca99 commit a36c0b8
Show file tree
Hide file tree
Showing 2 changed files with 292 additions and 0 deletions.
133 changes: 133 additions & 0 deletions chia/_tests/util/test_action_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import AsyncIterator, final

import pytest

from chia.util.action_scope import ActionScope, StateInterface


@final
@dataclass
class TestSideEffects:
buf: bytes = b""

def __bytes__(self) -> bytes:
return self.buf

@classmethod
def from_bytes(cls, blob: bytes) -> TestSideEffects:
return cls(blob)


async def default_async_callback(interface: StateInterface[TestSideEffects]) -> None:
return None # pragma: no cover


# Test adding a callback
def test_set_callback() -> None:
state_interface = StateInterface(TestSideEffects(), True)
state_interface.set_callback(default_async_callback)
assert state_interface._callback == default_async_callback
state_interface_no_callbacks = StateInterface(TestSideEffects(), False)
with pytest.raises(RuntimeError, match="Callback cannot be edited from inside itself"):
state_interface_no_callbacks.set_callback(None)


@pytest.fixture(name="action_scope")
async def action_scope_fixture() -> AsyncIterator[ActionScope[TestSideEffects]]:
async with ActionScope.new_scope(TestSideEffects) as scope:
yield scope


@pytest.mark.anyio
async def test_new_action_scope(action_scope: ActionScope[TestSideEffects]) -> None:
"""
Assert we can immediately check out some initial state
"""
async with action_scope.use() as interface:
assert interface == StateInterface(TestSideEffects(), True)


@pytest.mark.anyio
async def test_scope_persistence(action_scope: ActionScope[TestSideEffects]) -> None:
async with action_scope.use() as interface:
interface.side_effects.buf = b"baz"

async with action_scope.use() as interface:
assert interface.side_effects.buf == b"baz"


@pytest.mark.anyio
async def test_transactionality(action_scope: ActionScope[TestSideEffects]) -> None:
async with action_scope.use() as interface:
interface.side_effects.buf = b"baz"

with pytest.raises(Exception, match="Going to be caught"):
async with action_scope.use() as interface:
interface.side_effects.buf = b"qat"
raise RuntimeError("Going to be caught")

async with action_scope.use() as interface:
assert interface.side_effects.buf == b"baz"


@pytest.mark.anyio
async def test_callbacks() -> None:
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with action_scope.use() as interface:

async def callback(interface: StateInterface[TestSideEffects]) -> None:
interface.side_effects.buf = b"bar"

interface.set_callback(callback)

assert action_scope.side_effects.buf == b"bar"


@pytest.mark.anyio
async def test_callback_in_callback_error() -> None:
with pytest.raises(RuntimeError, match="Callback"):
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with action_scope.use() as interface:

async def callback(interface: StateInterface[TestSideEffects]) -> None:
interface.set_callback(default_async_callback)

interface.set_callback(callback)


@pytest.mark.anyio
async def test_no_callbacks_if_error() -> None:
with pytest.raises(Exception, match="This should prevent the callbacks from being called"):
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with action_scope.use() as interface:

async def callback(interface: StateInterface[TestSideEffects]) -> None:
raise NotImplementedError("Should not get here") # pragma: no cover

interface.set_callback(callback)

async with action_scope.use() as interface:
raise RuntimeError("This should prevent the callbacks from being called")

with pytest.raises(Exception, match="This should prevent the callbacks from being called"):
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with action_scope.use() as interface:

async def callback2(interface: StateInterface[TestSideEffects]) -> None:
raise NotImplementedError("Should not get here") # pragma: no cover

interface.set_callback(callback2)

raise RuntimeError("This should prevent the callbacks from being called")


# TODO: add suport, change this test to test it and add a test for nested transactionality
@pytest.mark.anyio
async def test_nested_use_banned(action_scope: ActionScope[TestSideEffects]) -> None:
async with action_scope.use():
with pytest.raises(RuntimeError, match="cannot currently support nested transactions"):
async with action_scope.use():
pass
159 changes: 159 additions & 0 deletions chia/util/action_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import annotations

import contextlib
from dataclasses import dataclass, field
from typing import AsyncIterator, Awaitable, Callable, Generic, Optional, Protocol, Type, TypeVar, final

import aiosqlite

from chia.util.db_wrapper import DBWrapper2, execute_fetchone


class ResourceManager(Protocol):
@classmethod
@contextlib.asynccontextmanager
async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: # pragma: no cover
# yield included to make this a generator as expected by @contextlib.asynccontextmanager
yield # type: ignore[misc]

@contextlib.asynccontextmanager
async def use(self) -> AsyncIterator[None]: # pragma: no cover
# yield included to make this a generator as expected by @contextlib.asynccontextmanager
yield

async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: ...

async def save_resource(self, resource: SideEffects) -> None: ...


@dataclass
class SQLiteResourceManager:

_db: DBWrapper2
_active_writer: Optional[aiosqlite.Connection] = field(init=False, default=None)

def get_active_writer(self) -> aiosqlite.Connection:
if self._active_writer is None:
raise RuntimeError("Can only access resources while under `use()` context manager")

return self._active_writer

@classmethod
@contextlib.asynccontextmanager
async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]:
async with DBWrapper2.managed(":memory:", reader_count=0) as db:
self = cls(db)
async with self._db.writer() as conn:
await conn.execute("CREATE TABLE side_effects(total blob)")
await conn.execute(
"INSERT INTO side_effects VALUES(?)",
(bytes(initial_resource),),
)
yield self

@contextlib.asynccontextmanager
async def use(self) -> AsyncIterator[None]:
if self._active_writer is not None:
raise RuntimeError("SQLiteResourceManager cannot currently support nested transactions")
async with self._db.writer() as conn:
self._active_writer = conn
try:
yield
finally:
self._active_writer = None

async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects:
row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects")
assert row is not None
side_effects = resource_type.from_bytes(row[0])
return side_effects

async def save_resource(self, resource: SideEffects) -> None:
# This sets all rows (there's only one) to the new serialization
await self.get_active_writer().execute(
"UPDATE side_effects SET total=?",
(bytes(resource),),
)


class SideEffects(Protocol):
def __bytes__(self) -> bytes: ...

@classmethod
def from_bytes(cls: Type[_T_SideEffects], blob: bytes) -> _T_SideEffects: ...


_T_SideEffects = TypeVar("_T_SideEffects", bound=SideEffects)


@final
@dataclass
class ActionScope(Generic[_T_SideEffects]):
"""
The idea of an "action" is to map a single client input to many potentially distributed functions and side
effects. The action holds on to a temporary state that the many callers modify at will but only one at a time.
When the action is closed, the state is still available and can be committed elsewhere or discarded.
Utilizes a "resource manager" to hold the state in order to take advantage of rollbacks and prevent concurrent tasks
from interferring with each other.
"""

_resource_manager: ResourceManager
_side_effects_format: Type[_T_SideEffects]
_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = None
_final_side_effects: Optional[_T_SideEffects] = field(init=False, default=None)

@property
def side_effects(self) -> _T_SideEffects:
if self._final_side_effects is None:
raise RuntimeError(
"Can only request ActionScope.side_effects after exiting context manager. "
"While in context manager, use ActionScope.use()."
)

return self._final_side_effects

@classmethod
@contextlib.asynccontextmanager
async def new_scope(
cls,
side_effects_format: Type[_T_SideEffects],
resource_manager_backend: Type[ResourceManager] = SQLiteResourceManager,
) -> AsyncIterator[ActionScope[_T_SideEffects]]:
async with resource_manager_backend.managed(side_effects_format()) as resource_manager:
self = cls(_resource_manager=resource_manager, _side_effects_format=side_effects_format)

yield self

async with self.use(_callbacks_allowed=False) as interface:
if self._callback is not None:
await self._callback(interface)
self._final_side_effects = interface.side_effects

@contextlib.asynccontextmanager
async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInterface[_T_SideEffects]]:
async with self._resource_manager.use():
side_effects = await self._resource_manager.get_resource(self._side_effects_format)
interface = StateInterface(side_effects, _callbacks_allowed)

yield interface

await self._resource_manager.save_resource(interface.side_effects)
self._callback = interface.callback


@dataclass
class StateInterface(Generic[_T_SideEffects]):
side_effects: _T_SideEffects
_callbacks_allowed: bool
_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = None

@property
def callback(self) -> Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]]:
return self._callback

def set_callback(self, new_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]]) -> None:
if not self._callbacks_allowed:
raise RuntimeError("Callback cannot be edited from inside itself")

self._callback = new_callback

0 comments on commit a36c0b8

Please sign in to comment.