Skip to content

Commit

Permalink
Add base of ManagedHTTPClient (Needs review)
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed Mar 1, 2024
1 parent 944321e commit ee1d65b
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 24 deletions.
1 change: 0 additions & 1 deletion twitchio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,3 @@

from . import authentication as authentication
from .exceptions import *
from .http import HTTPAsyncIterator as HTTPAsyncIterator, HTTPClient as HTTPClient
1 change: 1 addition & 0 deletions twitchio/authentication/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .oauth import OAuth as OAuth
from .payloads import *
from .scopes import Scopes as Scopes
from .tokens import ManagedHTTPClient as ManagedHTTPClient
13 changes: 11 additions & 2 deletions twitchio/authentication/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

import secrets
Expand All @@ -32,6 +33,8 @@


if TYPE_CHECKING:
import aiohttp

from ..types_.responses import (
AuthorizationURLResponse,
ClientCredentialsResponse,
Expand All @@ -46,9 +49,15 @@ class OAuth(HTTPClient):
CONTENT_TYPE_HEADER: ClassVar[dict[str, str]] = {"Content-Type": "application/x-www-form-urlencoded"}

def __init__(
self, *, client_id: str, client_secret: str, redirect_uri: str | None = None, scopes: Scopes | None = None
self,
*,
client_id: str,
client_secret: str,
redirect_uri: str | None = None,
scopes: Scopes | None = None,
session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__()
super().__init__(session=session, client_id=client_id)

self.client_id = client_id
self.client_secret = client_secret
Expand Down
1 change: 1 addition & 0 deletions twitchio/authentication/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

from collections.abc import Iterator, Mapping
Expand Down
1 change: 1 addition & 0 deletions twitchio/authentication/scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

import urllib.parse
Expand Down
221 changes: 221 additions & 0 deletions twitchio/authentication/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,224 @@
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import asyncio
import datetime
import logging
from typing import TYPE_CHECKING, TypeVar

import aiohttp

from twitchio.http import Route
from twitchio.types_.responses import RawResponse

from ..exceptions import HTTPException, InvalidTokenException
from ..http import HTTPAsyncIterator, PaginatedConverter
from ..types_.tokens import TokenMappingData
from .oauth import OAuth
from .payloads import ClientCredentialsPayload, ValidateTokenPayload
from .scopes import Scopes


if TYPE_CHECKING:
from ..types_.tokens import TokenMapping
from .payloads import RefreshTokenPayload


logger: logging.Logger = logging.getLogger(__name__)


T = TypeVar("T")


class ManagedHTTPClient(OAuth):
def __init__(
self,
*,
client_id: str,
client_secret: str,
app_token: str | None = None,
redirect_uri: str | None = None,
scopes: Scopes | None = None,
session: aiohttp.ClientSession | None = None,
) -> None:
super().__init__(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
scopes=scopes,
session=session,
)

self._tokens: TokenMapping = {}
self._app_token = app_token
self._validate_task: asyncio.Task[None] | None = None

async def _attempt_refresh_on_add(self, token: str, refresh: str) -> ValidateTokenPayload:
logger.debug("Token was invalid when attempting to add it to the token manager. Attempting to refresh.")

try:
resp: RefreshTokenPayload = await self.refresh_token(refresh)
except HTTPException as e:
msg: str = f'Token was invalid and cannot be refreshed. Please re-authenticate user with token: "{token}"'
raise InvalidTokenException(msg, token=token, refresh=refresh, type_="refresh", original=e)

try:
valid_resp: ValidateTokenPayload = await self.validate_token(resp["access_token"])
except HTTPException as e:
msg: str = f'Refreshed token was invalid. Please re-authenticate user with token: "{token}"'
raise InvalidTokenException(msg, token=token, refresh=refresh, type_="token", original=e)

self._tokens[valid_resp.user_id] = {
"user_id": valid_resp.user_id,
"token": resp.access_token,
"refresh": resp.refresh_token,
"last_validated": datetime.datetime.now(),
}

logger.info('Token successfully added to TokenManager after refresh: "%s"', valid_resp.user_id)
return valid_resp

async def add_token(self, token: str, refresh: str) -> ValidateTokenPayload:
try:
resp: ValidateTokenPayload = await self.validate_token(token)
except HTTPException as e:
if e.status != 401:
msg: str = "Token was invalid. Please check the token or re-authenticate user with a new token."
raise InvalidTokenException(msg, token=token, refresh=refresh, type_="token", original=e)

return await self._attempt_refresh_on_add(token, refresh)

self._tokens[resp.user_id] = {
"user_id": resp.user_id,
"token": token,
"refresh": refresh,
"last_validated": datetime.datetime.now(),
}

logger.info('Token successfully added to TokenManager: "%s"', resp.user_id)
return resp

def _find_token(self, route: Route) -> TokenMappingData | None | str:
token: str | None = route.headers.get("Authorization")
if token:
token = token.removeprefix("Bearer ").removeprefix("OAuth ")

if token == self._app_token:
return token

for data in self._tokens.values():
if data["token"] == token:
return data

async def request(self, route: Route) -> RawResponse | str:
if not self._session:
await self._init_session()

old: TokenMappingData | None | str = self._find_token(route)

try:
data: RawResponse | str = await super().request(route)
except HTTPException as e:
if not old or e.status != 401:
raise e

if e.extra.get("message", "").lower() != "invalid access token":
raise e

if isinstance(old, str):
payload: ClientCredentialsPayload = await self.client_credentials_token()
self._app_token = payload.access_token
route.update_headers({"Authorization": f"Bearer {payload.access_token}"})

return await self.request(route)

logger.debug('Token for "%s" was invalid or expired. Attempting to refresh token.', old["user_id"])
refresh: RefreshTokenPayload = await self.refresh_token(old["refresh"])
logger.debug('Token for "%s" was successfully refreshed.', old["user_id"])

self._tokens[old["user_id"]] = {
"user_id": old["user_id"],
"token": refresh.access_token,
"refresh": refresh.refresh_token,
"last_validated": datetime.datetime.now(),
}

route.update_headers({"Authorization": f"Bearer {refresh.access_token}"})
return await self.request(route)

return data

def request_paginated(
self,
route: Route,
max_results: int | None = None,
*,
converter: PaginatedConverter[T] | None = None,
) -> HTTPAsyncIterator[T]:
iterator: HTTPAsyncIterator[T] = HTTPAsyncIterator(self, route, max_results, converter=converter)
return iterator

async def __validate_loop(self) -> None:
logger.debug("Started the token validation loop on %s.", self.__class__.__qualname__)

if not self._session:
await self._init_session()

while self._session and not self._session.closed:
for data in self._tokens.copy().values():
if data["last_validated"] + datetime.timedelta(minutes=60) > datetime.datetime.now():
continue

try:
await self.validate_token(data["token"])
logger.debug('Token for "%s" was successfully re-validated.', data["user_id"])
except HTTPException as e:
if e.status >= 500:
logger.warning("Received invalid response from Twitch when re-validating token.")

# backoff for 60 seconds to and try again...
# There's really not much else we can do here...
await asyncio.sleep(60)
continue

logger.debug('Token for "%s" was invalid or expired. Attempting to refresh token.', data["user_id"])

try:
refresh: RefreshTokenPayload = await self.refresh_token(data["refresh"])
except HTTPException as e:
self._tokens.pop(data["user_id"], None)
logger.warning('Token for "%s" was invalid and could not be refreshed.', data["user_id"])
continue

logger.debug('Token for "%s" was successfully refreshed.', data["user_id"])

self._tokens[data["user_id"]] = {
"user_id": data["user_id"],
"token": refresh.access_token,
"refresh": refresh.refresh_token,
"last_validated": datetime.datetime.now(),
}

continue

await asyncio.sleep(60)

async def _init_session(self) -> None:
await super()._init_session()

if not self._validate_task:
self._validate_task = asyncio.create_task(self.__validate_loop())

async def close(self) -> None:
self._tokens.clear()

if self._validate_task:
try:
self._validate_task.cancel()
except Exception:
pass

self._validate_task = None

await super().close()

0 comments on commit ee1d65b

Please sign in to comment.