Skip to content

Commit

Permalink
fix: synchronized functions that return generators are fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Guibod committed Aug 7, 2023
1 parent a98abb4 commit fe14af1
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 49 deletions.
11 changes: 7 additions & 4 deletions examples/scryfall.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from mightstone.app import Mightstone
from mightstone.ass import aiterator_to_list

scryfall = Mightstone().scryfall
found = aiterator_to_list(scryfall.search_async("boseiju"))
m = Mightstone()
found = list(m.scryfall.search("boseiju"))

print(f"Found {len(found)} instances of Boseiju")
print("boseiju matches:")
for card in found:
print(f" - {card}")

print(f"Found {len(found)} instances of Boseiju")

print(list([x.name for x in m.scryfall.search("thalia")]))
19 changes: 19 additions & 0 deletions examples/scryfall_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
This example demonstrate the behavior of Mightstone in a synchronous context
random() method is an alias for random_async() using asgiref’s async_to_sync feature
search() method is an alias for search_async() using asgiref’s async_to_sync feature
"""

from mightstone import Mightstone
from mightstone.services.scryfall import Card

mightstone = Mightstone()
card = mightstone.scryfall.random()

print(f"The random card is {card.name} ({card.id})")

brushwaggs: list[Card] = mightstone.scryfall.search("brushwagg")

for i, brushwagg in enumerate(brushwaggs):
print(f"Brushwagg {i} is {brushwagg.name} from {brushwagg.set_code}")
42 changes: 37 additions & 5 deletions src/mightstone/ass/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Coroutine, List, TypeVar, Union
from typing import Any, AsyncGenerator, Callable, Coroutine, Generator, List, TypeVar

import asyncstdlib
from asgiref.sync import async_to_sync
Expand All @@ -11,23 +12,25 @@


@async_to_sync
async def aiterator_to_list(ait: AsyncGenerator[T, Any], limit=100) -> List[T]:
async def aiterator_to_list(ait: AsyncGenerator[T], limit=100) -> List[T]:
"""
Transforms an async iterator into a sync list
:param ait: Asynchronous iterator
:param limit: Max item to return
:return: The list of items
"""
return [item async for item in asyncstdlib.islice(ait, limit)]

if limit:
ait = asyncstdlib.islice(ait, limit)
return [item async for item in ait]


R = TypeVar("R")


# TODO: probably bad type returned for AsyncGenerator input
def synchronize(
f: Callable[..., Union[Coroutine[Any, Any, R], AsyncGenerator[R, None]]],
f: Callable[..., Coroutine[Any, Any, R]],
docstring: str = None,
) -> Callable[..., R]:
qname = f"{f.__module__}.{f.__qualname__}"
Expand All @@ -46,3 +49,32 @@ def inner(*args, **kwargs) -> R:
)

return inner


def sync_generator(
f: Callable[..., AsyncGenerator[R, None]],
docstring: str = None,
) -> Callable[..., Generator[R, None, None]]:
qname = f"{f.__module__}.{f.__qualname__}"
loop = asyncio.get_event_loop()

@wraps(f)
def inner(*args, **kwargs) -> Generator[R, None, None]:
async_generator = f(*args, **kwargs)
try:
while True:
yield loop.run_until_complete(async_generator.__anext__())
except StopAsyncIteration as e:
return
except Exception as e:
raise e

if docstring:
inner.__doc__ = docstring
else:
inner.__doc__ = (
f"Sync version of :func:`~{qname}`, same behavior but "
"wrapped by :func:`~asgiref.sync.async_to_sync`."
)

return inner
22 changes: 11 additions & 11 deletions src/mightstone/services/edhrec/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from httpx import HTTPStatusError
from pydantic.error_wrappers import ValidationError

from mightstone.ass import synchronize
from mightstone.ass import sync_generator, synchronize
from mightstone.services import MightstoneHttpClient, ServiceError
from mightstone.services.edhrec.models import (
EdhRecCardItem,
Expand Down Expand Up @@ -203,7 +203,7 @@ async def tribes_async(
):
yield item

tribes = synchronize(tribes_async)
tribes = sync_generator(tribes_async)

async def themes_async(
self, identity: Union[EdhRecIdentity, str] = None, limit: int = None
Expand All @@ -223,7 +223,7 @@ async def themes_async(
):
yield item

themes = synchronize(themes_async)
themes = sync_generator(themes_async)

async def sets_async(
self, limit: int = None
Expand All @@ -233,7 +233,7 @@ async def sets_async(
):
yield item

sets = synchronize(sets_async)
sets = sync_generator(sets_async)

async def salt_async(
self, year: int = None, limit: int = None
Expand Down Expand Up @@ -270,7 +270,7 @@ async def top_cards_async(
async for item in self._page_item_generator(path, limit=limit):
yield item

top_cards = synchronize(top_cards_async)
top_cards = sync_generator(top_cards_async)

async def cards_async(
self,
Expand Down Expand Up @@ -327,7 +327,7 @@ async def cards_async(
async for item in self._page_item_generator(path, category, limit=limit):
yield item

cards = synchronize(cards_async)
cards = sync_generator(cards_async)

async def companions_async(
self, limit: int = None
Expand All @@ -337,7 +337,7 @@ async def companions_async(
):
yield item

companions = synchronize(companions_async)
companions = sync_generator(companions_async)

async def partners_async(
self, identity: Union[EdhRecIdentity, str] = None, limit: int = None
Expand All @@ -349,7 +349,7 @@ async def partners_async(
async for item in self._page_item_generator(path, limit=limit):
yield item

partners = synchronize(partners_async)
partners = sync_generator(partners_async)

async def commanders_async(
self, identity: Union[EdhRecIdentity, str] = None, limit: int = None
Expand All @@ -361,7 +361,7 @@ async def commanders_async(
async for item in self._page_item_generator(path, limit=limit):
yield item

commanders = synchronize(commanders_async)
commanders = sync_generator(commanders_async)

async def combos_async(
self, identity: Union[EdhRecIdentity, str], limit: int = None
Expand All @@ -372,7 +372,7 @@ async def combos_async(
):
yield item

combos = synchronize(combos_async)
combos = sync_generator(combos_async)

async def combo_async(
self, identity: str, identifier: Union[EdhRecIdentity, str], limit: int = None
Expand All @@ -383,7 +383,7 @@ async def combo_async(
):
yield item

combo = synchronize(combo_async)
combo = sync_generator(combo_async)

async def _page_item_generator(
self,
Expand Down
38 changes: 19 additions & 19 deletions src/mightstone/services/mtgjson/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from httpx import HTTPStatusError
from pydantic.error_wrappers import ValidationError

from mightstone.ass import compressor, synchronize
from mightstone.ass import compressor, sync_generator, synchronize
from mightstone.core import MightstoneModel
from mightstone.services import MightstoneHttpClient, ServiceError
from mightstone.services.mtgjson.models import (
Expand Down Expand Up @@ -158,7 +158,7 @@ async def all_printings_async(self) -> AsyncGenerator[Set, None]:
):
yield item

all_printings = synchronize(all_printings_async)
all_printings = sync_generator(all_printings_async)

async def all_identifiers_async(self) -> AsyncGenerator[Card, None]:
"""
Expand All @@ -169,7 +169,7 @@ async def all_identifiers_async(self) -> AsyncGenerator[Card, None]:
async for k, item in self._iterate_model(kind="AllIdentifiers", model=Card):
yield item

all_identifiers = synchronize(all_identifiers_async)
all_identifiers = sync_generator(all_identifiers_async)

async def all_prices_async(self) -> AsyncGenerator[CardPrices, None]:
"""
Expand All @@ -180,7 +180,7 @@ async def all_prices_async(self) -> AsyncGenerator[CardPrices, None]:
async for k, item in self._iterate_model(kind="AllPrices"):
yield CardPrices(uuid=k, **item)

all_prices = synchronize(all_prices_async)
all_prices = sync_generator(all_prices_async)

async def atomic_cards_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -191,7 +191,7 @@ async def atomic_cards_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="AtomicCards"):
yield item

atomic_cards = synchronize(atomic_cards_async)
atomic_cards = sync_generator(atomic_cards_async)

async def card_types_async(self) -> CardTypes:
"""
Expand Down Expand Up @@ -224,7 +224,7 @@ async def deck_list_async(self) -> AsyncGenerator[DeckList, None]:
):
yield item

deck_list = synchronize(deck_list_async)
deck_list = sync_generator(deck_list_async)

async def deck_async(self, file_name: str) -> Deck:
"""
Expand Down Expand Up @@ -267,7 +267,7 @@ async def legacy_async(self) -> AsyncGenerator[Set, None]:
async for k, item in self._iterate_model(kind="Legacy", model=Set):
yield item

legacy = synchronize(legacy_async)
legacy = sync_generator(legacy_async)

async def legacy_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -279,7 +279,7 @@ async def legacy_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="LegacyAtomic"):
yield item

legacy_atomic = synchronize(legacy_atomic_async)
legacy_atomic = sync_generator(legacy_atomic_async)

async def meta_async(self) -> Meta:
"""
Expand All @@ -302,7 +302,7 @@ async def modern_async(self) -> AsyncGenerator[Set, None]:
async for k, item in self._iterate_model(kind="Modern", model=Set):
yield item

modern = synchronize(modern_async)
modern = sync_generator(modern_async)

async def modern_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -313,7 +313,7 @@ async def modern_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="ModernAtomic"):
yield item

modern_atomic = synchronize(modern_atomic_async)
modern_atomic = sync_generator(modern_atomic_async)

async def pauper_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -324,7 +324,7 @@ async def pauper_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="PauperAtomic"):
yield item

pauper_atomic = synchronize(pauper_atomic_async)
pauper_atomic = sync_generator(pauper_atomic_async)

async def pioneer_async(self) -> AsyncGenerator[Set, None]:
"""
Expand All @@ -336,7 +336,7 @@ async def pioneer_async(self) -> AsyncGenerator[Set, None]:
async for k, item in self._iterate_model(kind="Pioneer", model=Set):
yield item

pioneer = synchronize(pioneer_async)
pioneer = sync_generator(pioneer_async)

async def pioneer_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -347,7 +347,7 @@ async def pioneer_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="PioneerAtomic"):
yield item

pioneer_atomic = synchronize(pioneer_atomic_async)
pioneer_atomic = sync_generator(pioneer_atomic_async)

async def set_list_async(self) -> AsyncGenerator[SetList, None]:
"""
Expand All @@ -360,7 +360,7 @@ async def set_list_async(self) -> AsyncGenerator[SetList, None]:
):
yield item

set_list = synchronize(set_list_async)
set_list = sync_generator(set_list_async)

async def set_async(self, code: str) -> SetList:
"""
Expand All @@ -385,7 +385,7 @@ async def standard_async(self) -> AsyncGenerator[Set, None]:
async for k, item in self._iterate_model(kind="Standard", model=Set):
yield item

standard = synchronize(standard_async)
standard = sync_generator(standard_async)

async def standard_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -396,7 +396,7 @@ async def standard_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="StandardAtomic"):
yield item

standard_atomic = synchronize(standard_atomic_async)
standard_atomic = sync_generator(standard_atomic_async)

async def tcg_player_skus_async(self) -> AsyncGenerator[TcgPlayerSKUs, None]:
"""
Expand All @@ -418,7 +418,7 @@ async def tcg_player_skus_async(self) -> AsyncGenerator[TcgPlayerSKUs, None]:

yield group

tcg_player_skus = synchronize(tcg_player_skus_async)
tcg_player_skus = sync_generator(tcg_player_skus_async)

async def vintage_async(self) -> AsyncGenerator[Set, None]:
"""
Expand All @@ -430,7 +430,7 @@ async def vintage_async(self) -> AsyncGenerator[Set, None]:
async for k, item in self._iterate_model(kind="Vintage", model=Set):
yield item

vintage = synchronize(vintage_async)
vintage = sync_generator(vintage_async)

async def vintage_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
"""
Expand All @@ -441,7 +441,7 @@ async def vintage_atomic_async(self) -> AsyncGenerator[CardAtomic, None]:
async for item in self._atomic(kind="VintageAtomic"):
yield item

vintage_atomic = synchronize(vintage_atomic_async)
vintage_atomic = sync_generator(vintage_atomic_async)

async def _atomic(self, kind: str) -> AsyncGenerator[CardAtomic, None]:
card: Optional[CardAtomic] = None
Expand Down

0 comments on commit fe14af1

Please sign in to comment.