Skip to content

Commit

Permalink
Some more messy conduit stuff.
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed Apr 17, 2024
1 parent b9a12d5 commit 3aaa333
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
18 changes: 15 additions & 3 deletions twitchio/conduits/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from ..utils import parse_timestamp
from .enums import ShardStatus, TransportMethod
from .websockets import Websocket


if TYPE_CHECKING:
Expand All @@ -42,7 +43,6 @@
from ..client import Client
from ..ext.commands import Bot
from ..types_.conduits import ConduitData, ShardData, ShardTransport
from .websockets import Websocket


logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -150,7 +150,7 @@ async def create_conduit(self, shard_count: int, buffer: bool = False) -> list[C

return conduits

async def fetch_conduits(self) -> dict[str, Conduit]:
async def fetch_conduits(self) -> MappingProxyType[str, Conduit]:
data = await self._client._http.get_conduits()
mapping: dict[str, Conduit] = {}

Expand All @@ -159,4 +159,16 @@ async def fetch_conduits(self) -> dict[str, Conduit]:
mapping[conduit.id] = conduit

self._conduits = mapping
return mapping
return MappingProxyType(mapping)

async def test(self) -> None:
await self.fetch_conduits()

for id_, conduit in self._conduits.items():
shards: list[Shard] = await (await self._client._http.get_conduit_shards(id_))
start: int = len(shards)

for n in range(start, conduit._shard_count):
websocket: Websocket = Websocket(id=n)
await websocket.connect()
break
46 changes: 36 additions & 10 deletions twitchio/conduits/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@

import asyncio
import logging
from typing import Any, cast

import aiohttp

from ..utils import _from_json # type: ignore


logger: logging.Logger = logging.getLogger(__name__)

Expand All @@ -35,17 +38,16 @@


class Websocket:
def __init__(self, *, keep_alive_timeout: float = 60, session: aiohttp.ClientSession | None = None) -> None:
def __init__(
self, *, keep_alive_timeout: float = 60, session: aiohttp.ClientSession | None = None, id: int
) -> None:
self._keep_alive_timeout: int = max(10, min(int(keep_alive_timeout), 600))
self._session: aiohttp.ClientSession | None = session
self._reconnect: bool = True
self._id: int = id

self._socket: aiohttp.ClientWebSocketResponse | None = None

self._listen_task: asyncio.Task[None] | None = None

self._id: str | None = None

@property
def keep_alive_timeout(self) -> int:
return self._keep_alive_timeout
Expand All @@ -54,6 +56,10 @@ def keep_alive_timeout(self) -> int:
def connected(self) -> bool:
return bool(self._socket and not self._socket.closed)

@property
def id(self) -> int:
return self._id

async def connect(self) -> None:
url: str = f"{WSS}?keepalive_timeout_seconds={self._keep_alive_timeout}"

Expand All @@ -64,9 +70,7 @@ async def connect(self) -> None:
if not self._session:
self._session = aiohttp.ClientSession()

async with self._session as session:
# TODO: Error handling...
self._socket = await session.ws_connect(url)
self._socket = await self._session.ws_connect(url)

logger.debug("Successfully connected to conduit websocket... Preparing to assign to shard.")
self._listen_task = asyncio.create_task(self._listen())
Expand All @@ -76,12 +80,28 @@ async def _listen(self) -> None:

while True:
try:
message = await self._socket.receive()
message: aiohttp.WSMessage = await self._socket.receive()
except Exception:
# TODO: Proper error handling...
return await self.close()

print(message)
type_: aiohttp.WSMsgType = message.type
if type_ in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
logger.info("Received close message on conduit websocket: %s", self._id)
return await self.close()

if type_ is not aiohttp.WSMsgType.TEXT:
logger.info("Received unknown message from conduit websocket: %s", self._id)
continue

try:
data: dict[str, Any] = cast(dict[str, Any], _from_json(message.data))
except Exception:
logger.warning("Unable to parse JSON in conduit websocket: %s", self._id)
continue

# TODO: Remove print...
print(data)

async def close(self) -> None:
if self._socket:
Expand All @@ -95,3 +115,9 @@ async def close(self) -> None:
await self._session.close()
except Exception:
...

if self._listen_task:
try:
self._listen_task.cancel()
except Exception:
...

0 comments on commit 3aaa333

Please sign in to comment.