Skip to content

Commit

Permalink
Add AsyncIterator and query_builder.
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed Feb 11, 2024
1 parent 84cad0b commit 36b643f
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 38 deletions.
2 changes: 1 addition & 1 deletion twitchio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@

from . import authentication as authentication
from .exceptions import *
from .http import HTTPClient as HTTPClient
from .http import HTTPAsyncIterator as HTTPAsyncIterator, HTTPClient as HTTPClient
4 changes: 2 additions & 2 deletions twitchio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import TYPE_CHECKING


__all__ = ("TwitchioException", "TwitchioHTTPException")
__all__ = ("TwitchioException", "HTTPException")


if TYPE_CHECKING:
Expand All @@ -39,7 +39,7 @@ class TwitchioException(Exception):
# TODO: Document this class.


class TwitchioHTTPException(TwitchioException):
class HTTPException(TwitchioException):
"""Exception raised when an HTTP request fails."""

# TODO: Document this class.
Expand Down
187 changes: 155 additions & 32 deletions twitchio/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,37 @@
"""
from __future__ import annotations

import copy
import logging
import sys
from typing import TYPE_CHECKING, Any, ClassVar
import urllib.parse
from collections import deque
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar

import aiohttp

from . import __version__
from .exceptions import TwitchioHTTPException
from .exceptions import HTTPException
from .utils import _from_json # type: ignore


if TYPE_CHECKING:
from typing_extensions import Unpack
from collections.abc import Generator

from .types_.requests import APIRequest, APIRequestKwargs, HTTPMethod
from typing_extensions import Self, Unpack

from .types_.requests import APIRequest, APIRequestKwargs, HTTPMethod, ParamMapping
from .types_.responses import RawResponse


logger: logging.Logger = logging.getLogger(__name__)


T = TypeVar("T")
PaginatedConverter: TypeAlias = Callable[[Any], Awaitable[T]] | None


async def json_or_text(resp: aiohttp.ClientResponse) -> dict[str, Any] | str:
text: str = await resp.text()

Expand All @@ -56,57 +67,163 @@ async def json_or_text(resp: aiohttp.ClientResponse) -> dict[str, Any] | str:


class Route:
# TODO: Document this class.
__slots__ = ("params", "data", "json", "headers", "use_id", "method", "path", "packed", "_base_url", "_url")

BASE: ClassVar[str] = "https://api.twitch.tv/helix/"
ID_BASE: ClassVar[str] = "https://id.twitch.tv/"

def __init__(
self, method: HTTPMethod, path: str, *, use_id: bool = False, **kwargs: Unpack[APIRequestKwargs]
) -> None:
params: dict[str, str] = kwargs.pop("params", {})
self._url = self.build_url(path, use_id=use_id, params=params)
self.params: ParamMapping = kwargs.pop("params", {})
self.data: dict[str, Any] = kwargs.get("data", {})
self.json: dict[str, Any] = kwargs.get("json", {})
self.headers: dict[str, str] = kwargs.get("headers", {})

self.use_id = use_id
self.method = method
self.path = path

self.params: dict[str, str] = params
self.data: dict[str, Any] = kwargs.get("data", {})
self.json: dict[str, Any] = kwargs.get("json", {})
self.headers: dict[str, str] = kwargs.get("headers", {})

self.packed: APIRequest = kwargs

self._base_url: str = ""
self._url: str = self.build_url()

def __str__(self) -> str:
return str(self._url)

def __repr__(self) -> str:
return f"{self.method}({self.path})"
return f"{self.method}[{self.base_url}]"

@classmethod
def build_url(cls, path: str, use_id: bool = False, params: dict[str, str] | None = None) -> str:
if params is None:
params = {}
path_: str = path.lstrip("/")
def build_url(self, *, remove_none: bool = True) -> str:
base = self.ID_BASE if self.use_id else self.BASE
self.path = self.path.lstrip("/").rstrip("/")

url: str = f"{cls.ID_BASE if use_id else cls.BASE}{path_}{cls.build_query(params)}"
return url
url: str = f"{base}{self.path}"
self._base_url = url

def update_query(self, params: dict[str, str]) -> str:
self.params.update(params)
self._url = self.build_url(self.path, use_id=self.use_id, params=self.params)
if not self.params:
return url

return self._url
url += "?"

# We expect a dict so keys should be unique...
for key, value in copy.copy(self.params).items():
if value is None:
if remove_none:
del self.params[key]
continue

if isinstance(value, (str, int)):
url += f'{key}={self.encode(str(value), safe="+", plus=True)}&'
else:
# At this point we should assume it's a list or tuple...
# If it's not that's ultimately on us...
joined: str = "+".join([self.encode(str(v), safe="+") for v in value])
url += f"{key}={joined}&"

return url.rstrip("&")

@classmethod
def encode(cls, value: str, /, safe: str = "", plus: bool = False) -> str:
method = urllib.parse.quote_plus if plus else urllib.parse.quote
unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote

if unquote(value) == value:
return method(value, safe=safe)

return value

@property
def url(self) -> str:
return self._url

@classmethod
def build_query(cls, params: dict[str, str]) -> str:
joined: str = "&".join(f"{key}={value}" for key, value in params.items())
return f"?{joined}" if joined else ""
@property
def base_url(self) -> str:
return self._base_url

def update_params(self, params: ParamMapping, *, remove_none: bool = True) -> str:
self.params.update(params)
self._url = self.build_url(remove_none=remove_none)

return self.url


class HTTPAsyncIterator(Generic[T]):
__slots__ = ("_http", "_route", "_cursor", "_first", "_max_results", "_converter", "_buffer")

def __init__(
self, http: HTTPClient, route: Route, max_results: int | None = None, converter: PaginatedConverter[T] = None
) -> None:
self._http = http
self._route = route

self._cursor: str | None | bool = None
self._first: int = int(route.params.get("first", 20)) # 20 is twitch default
self._max_results: int | None = max_results

if self._max_results is not None and self._max_results < self._first:
self._first = self._max_results

self._converter = converter or self._base_converter
self._buffer: deque[T] = deque()

async def _base_converter(self, data: Any) -> T:
return data

async def _call_next(self) -> None:
if self._cursor is False:
raise StopAsyncIteration

if self._max_results is not None and self._max_results < 0:
raise StopAsyncIteration

self._route.update_params({"after": self._cursor})
data: RawResponse = await self._http.request_json(self._route)
self._cursor = data.get("pagination", {}).get("cursor", False)

try:
inner: list[RawResponse] = data["data"]
except KeyError:
# TODO: Proper exception...
raise ValueError('Expected "data" key not found.')

for value in inner:
if self._max_results is None:
self._buffer.append(await self._do_conversion(value))
continue

self._max_results -= 1 # If this is causing issues, it's just pylance bugged/desynced...
if self._max_results < 0:
return

self._buffer.append(await self._do_conversion(value))

async def _do_conversion(self, data: RawResponse) -> T:
return await self._converter(data)

async def _flatten(self) -> list[T]:
if not self._buffer:
await self._call_next()

return list(self._buffer)

def __await__(self) -> Generator[Any, None, list[T]]:
return self._flatten().__await__()

def __aiter__(self) -> Self:
return self

async def __anext__(self) -> T:
if not self._buffer:
await self._call_next()

try:
data = self._buffer.popleft()
except IndexError:
raise StopAsyncIteration

return data


class HTTPClient:
Expand Down Expand Up @@ -148,18 +265,18 @@ async def close(self) -> None:
self.clear()
logger.debug("%s session closed successfully.", self.__class__.__qualname__)

async def request(self, route: Route) -> Any:
async def request(self, route: Route) -> RawResponse | str:
await self._init_session()
assert self.__session is not None

logger.debug("Attempting a request to %r with %s.", route, self.__class__.__qualname__)

async with self.__session.request(route.method, route.url, **route.packed) as resp:
data: dict[str, Any] | str = await json_or_text(resp)
data: RawResponse | str = await json_or_text(resp)

if resp.status >= 400:
logger.error("Request %r failed with status %s: %s", route, resp.status, data)
raise TwitchioHTTPException(
raise HTTPException(
f"Request {route} failed with status {resp.status}: {data}", route=route, status=resp.status
)

Expand All @@ -170,7 +287,13 @@ async def request_json(self, route: Route) -> Any:
data = await self.request(route)

if isinstance(data, str):
# TODO: Add a TwitchioHTTPException here.
# TODO: Add a HTTPException here.
raise TypeError("Expected JSON data, but received text data.")

return data

def request_paginated(
self, route: Route, max_results: int | None = None, *, converter: PaginatedConverter[T] | None = None
) -> HTTPAsyncIterator[T]:
iterator: HTTPAsyncIterator[T] = HTTPAsyncIterator(self, route, max_results, converter=converter)
return iterator
6 changes: 4 additions & 2 deletions twitchio/types_/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@
SOFTWARE.
"""

from collections.abc import MutableMapping
from typing import Any, Literal, TypeAlias, TypedDict


__all__ = ("HTTPMethod", "APIRequestKwargs", "APIRequest")
__all__ = ("HTTPMethod", "APIRequestKwargs", "APIRequest", "ParamMapping")


HTTPMethod: TypeAlias = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD", "CONNECT", "TRACE"]
ParamMapping: TypeAlias = MutableMapping[str, Any]


class APIRequestKwargs(TypedDict, total=False):
headers: dict[str, str]
data: dict[str, Any]
params: dict[str, str]
params: ParamMapping
json: dict[str, Any]


Expand Down
4 changes: 3 additions & 1 deletion twitchio/types_/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import TypeAlias, TypedDict
from typing import Any, TypeAlias, TypedDict


__all__ = (
Expand All @@ -30,6 +30,7 @@
"ClientCredentialsResponse",
"OAuthResponses",
"UserTokenResponse",
"RawResponse",
)


Expand Down Expand Up @@ -60,3 +61,4 @@ class ClientCredentialsResponse(TypedDict):


OAuthResponses: TypeAlias = RefreshTokenResponse | ValidateTokenResponse | ClientCredentialsResponse | UserTokenResponse
RawResponse: TypeAlias = dict[str, Any]

0 comments on commit 36b643f

Please sign in to comment.