Skip to content

Commit

Permalink
Support optional usage of the StarletteAdapter.
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed Mar 18, 2024
1 parent 7dd3246 commit 377fc5a
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 117 deletions.
6 changes: 3 additions & 3 deletions twitchio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
Team,
)
from .payloads import EventErrorPayload
from .web import AiohttpAdapter, WebAdapter
from .web import AiohttpAdapter


if TYPE_CHECKING:
Expand Down Expand Up @@ -83,8 +83,8 @@ def __init__(
session=session,
)

adapter: type[WebAdapter] = options.get("adapter", None) or AiohttpAdapter
self._adapter: WebAdapter = adapter(client=self)
adapter: Any = options.get("adapter", None) or AiohttpAdapter
self._adapter: Any = adapter(client=self)

# Event listeners...
# Cog listeners should be partials with injected self...
Expand Down
23 changes: 22 additions & 1 deletion twitchio/web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,25 @@
SOFTWARE.
"""

from .adapters import *
from __future__ import annotations

import logging

from ..utils import ColorFormatter
from .aio_adapter import AiohttpAdapter as AiohttpAdapter


handler = logging.StreamHandler()
handler.setFormatter(ColorFormatter())
logger = logging.getLogger(__name__)
logger.addHandler(handler)


try:
from .starlette_adapter import StarletteAdapter as StarletteAdapter
except ImportError:
msg = "Please install the required packages: 'pip install twitchio[starlette]' to use the StarletteAdapter."
logger.warning("Starlette or uvicorn is not installed, StarletteAdapter will not be available. %s", msg)


logger.removeHandler(handler)
152 changes: 152 additions & 0 deletions twitchio/web/aio_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""
MIT License
Copyright (c) 2017 - Present PythonistaGuild
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING
from urllib.parse import unquote_plus

from aiohttp import web

from ..authentication import Scopes


if TYPE_CHECKING:
from ..authentication import AuthorizationURLPayload, UserTokenPayload
from ..client import Client


__all__ = ("AiohttpAdapter",)


logger: logging.Logger = logging.getLogger(__name__)


class AiohttpAdapter(web.Application):
def __init__(self, client: Client, *, host: str | None = None, port: int | None = None) -> None:
super().__init__()
self._runner: web.AppRunner | None = None

self.client: Client = client

self._host: str = host or "localhost"
self._port: int = port or 4343

self._runner_task: asyncio.Task[None] | None = None
self._redirect_uri: str = client._http.redirect_uri or f"http://{self._host}:{self._port}/oauth/callback"

self.startup = self.event_startup
self.shutdown = self.event_shutdown

self.router.add_route("GET", "/oauth/callback", self.oauth_callback)
self.router.add_route("GET", "/oauth", self.oauth_redirect)

def __init_subclass__(cls: type[AiohttpAdapter]) -> None:
return

async def event_startup(self) -> None:
logger.info("Starting TwitchIO AiohttpAdapter on http://%s:%s.", self._host, self._port)

async def event_shutdown(self) -> None:
logger.info("Successfully shutdown TwitchIO <%s>.", self.__class__.__qualname__)

async def close(self) -> None:
if self._runner_task is not None:
try:
self._runner_task.cancel()
except Exception as e:
logger.debug(
"Ignoring exception raised while cancelling runner in <%s>: %s.",
self.__class__.__qualname__,
e,
)

if self._runner is not None:
await self._runner.cleanup()

self._runner = None
self._runner_task = None

async def run(self, host: str | None = None, port: int | None = None) -> None:
self._runner = web.AppRunner(self, access_log=None, handle_signals=True)
await self._runner.setup()

site: web.TCPSite = web.TCPSite(self._runner, host or self._host, port or self._port)
self._runner_task = asyncio.create_task(
site.start(), name=f"twitchio-web-adapter:{self.__class__.__qualname__}"
)

async def fetch_token(self, request: web.Request) -> web.Response:
if "code" not in request.query:
return web.Response(status=400)

try:
payload: UserTokenPayload = await self.client._http.user_access_token(
request.query["code"],
redirect_uri=self._redirect_uri,
)
except Exception as e:
logger.error("Exception raised while fetching Token in <%s>: %s", self.__class__.__qualname__, e)
return web.Response(status=500)

await self.client.add_token(payload["access_token"], payload["refresh_token"])
return web.Response(body="Success. You can leave this page.", status=200)

async def oauth_callback(self, request: web.Request) -> web.Response:
logger.debug("Received OAuth callback request in <%s>.", self.oauth_callback.__qualname__)

response: web.Response = await self.fetch_token(request)
return response

async def oauth_redirect(self, request: web.Request) -> web.Response:
scopes: str | None = request.query.get("scopes", None)
force_verify: bool = request.query.get("force_verify", "false").lower() == "true"

if not scopes:
scopes = str(self.client._http.scopes) if self.client._http.scopes else None

if not scopes:
logger.warning(
"No scopes provided in request to <%s>. Scopes are a required parameter that is missing.",
self.oauth_redirect.__qualname__,
)
return web.Response(status=400)

scopes_: Scopes = Scopes(unquote_plus(scopes).split())

try:
payload: AuthorizationURLPayload = self.client._http.get_authorization_url(
scopes=scopes_,
redirect_uri=self._redirect_uri,
force_verify=force_verify,
)
except Exception as e:
logger.error(
"Exception raised while fetching Authorization URL in <%s>: %s", self.__class__.__qualname__, e
)
return web.Response(status=500)

raise web.HTTPPermanentRedirect(payload["url"])
115 changes: 2 additions & 113 deletions twitchio/web/adapters.py → twitchio/web/starlette_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@

import asyncio
import logging
from typing import TYPE_CHECKING, TypeAlias
from typing import TYPE_CHECKING
from urllib.parse import unquote_plus

import uvicorn
from aiohttp import web
from starlette.applications import Starlette
from starlette.responses import RedirectResponse, Response
from starlette.routing import Route
Expand All @@ -45,7 +44,7 @@
from ..client import Client


__all__ = ("WebAdapter", "StarletteAdapter", "AiohttpAdapter")
__all__ = ("StarletteAdapter",)


logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -154,113 +153,3 @@ async def oauth_redirect(self, request: Request) -> Response:
return Response(status_code=500)

return RedirectResponse(url=payload["url"], status_code=307)


class AiohttpAdapter(web.Application):
def __init__(self, client: Client, *, host: str | None = None, port: int | None = None) -> None:
super().__init__()
self._runner: web.AppRunner | None = None

self.client: Client = client

self._host: str = host or "localhost"
self._port: int = port or 4343

self._runner_task: asyncio.Task[None] | None = None
self._redirect_uri: str = client._http.redirect_uri or f"http://{self._host}:{self._port}/oauth/callback"

self.startup = self.event_startup
self.shutdown = self.event_shutdown

self.router.add_route("GET", "/oauth/callback", self.oauth_callback)
self.router.add_route("GET", "/oauth", self.oauth_redirect)

def __init_subclass__(cls: type[AiohttpAdapter]) -> None:
return

async def event_startup(self) -> None:
logger.info("Starting TwitchIO AiohttpAdapter on http://%s:%s.", self._host, self._port)

async def event_shutdown(self) -> None:
logger.info("Successfully shutdown TwitchIO <%s>.", self.__class__.__qualname__)

async def close(self) -> None:
if self._runner_task is not None:
try:
self._runner_task.cancel()
except Exception as e:
logger.debug(
"Ignoring exception raised while cancelling runner in <%s>: %s.",
self.__class__.__qualname__,
e,
)

if self._runner is not None:
await self._runner.cleanup()

self._runner = None
self._runner_task = None

async def run(self, host: str | None = None, port: int | None = None) -> None:
self._runner = web.AppRunner(self, access_log=None, handle_signals=True)
await self._runner.setup()

site: web.TCPSite = web.TCPSite(self._runner, host or self._host, port or self._port)
self._runner_task = asyncio.create_task(
site.start(), name=f"twitchio-web-adapter:{self.__class__.__qualname__}"
)

async def fetch_token(self, request: web.Request) -> web.Response:
if "code" not in request.query:
return web.Response(status=400)

try:
payload: UserTokenPayload = await self.client._http.user_access_token(
request.query["code"],
redirect_uri=self._redirect_uri,
)
except Exception as e:
logger.error("Exception raised while fetching Token in <%s>: %s", self.__class__.__qualname__, e)
return web.Response(status=500)

await self.client.add_token(payload["access_token"], payload["refresh_token"])
return web.Response(body="Success. You can leave this page.", status=200)

async def oauth_callback(self, request: web.Request) -> web.Response:
logger.debug("Received OAuth callback request in <%s>.", self.oauth_callback.__qualname__)

response: web.Response = await self.fetch_token(request)
return response

async def oauth_redirect(self, request: web.Request) -> web.Response:
scopes: str | None = request.query.get("scopes", None)
force_verify: bool = request.query.get("force_verify", "false").lower() == "true"

if not scopes:
scopes = str(self.client._http.scopes) if self.client._http.scopes else None

if not scopes:
logger.warning(
"No scopes provided in request to <%s>. Scopes are a required parameter that is missing.",
self.oauth_redirect.__qualname__,
)
return web.Response(status=400)

scopes_: Scopes = Scopes(unquote_plus(scopes).split())

try:
payload: AuthorizationURLPayload = self.client._http.get_authorization_url(
scopes=scopes_,
redirect_uri=self._redirect_uri,
force_verify=force_verify,
)
except Exception as e:
logger.error(
"Exception raised while fetching Authorization URL in <%s>: %s", self.__class__.__qualname__, e
)
return web.Response(status=500)

raise web.HTTPPermanentRedirect(payload["url"])


WebAdapter: TypeAlias = StarletteAdapter | AiohttpAdapter

0 comments on commit 377fc5a

Please sign in to comment.