Skip to content

Commit

Permalink
Support custom connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
Lonami committed Oct 29, 2023
1 parent d80c6b3 commit 6e88264
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 51 deletions.
8 changes: 8 additions & 0 deletions client/doc/modules/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,11 @@ Private definitions
.. autoclass:: InFileLike

.. autoclass:: OutFileLike

.. currentmodule:: telethon._impl.mtsender.sender

.. autoclass:: AsyncReader

.. autoclass:: AsyncWriter

.. autoclass:: Connector
4 changes: 2 additions & 2 deletions client/src/telethon/_impl/client/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
try:
await client._storage.save(client._session)
except Exception:
client._logger.exception(
client._config.base_logger.exception(
"failed to save session upon login; you may need to login again in future runs"
)

Expand All @@ -59,7 +59,7 @@ async def complete_login(client: Client, auth: abcs.auth.Authorization) -> User:
async def handle_migrate(client: Client, dc_id: Optional[int]) -> None:
assert dc_id is not None
sender, client._session.dcs = await connect_sender(
client._config, client._session.dcs, DataCenter(id=dc_id), client._logger
client._config, client._session.dcs, DataCenter(id=dc_id)
)
async with client._sender_lock:
client._sender = sender
Expand Down
21 changes: 14 additions & 7 deletions client/src/telethon/_impl/client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
Union,
)

from telethon._impl.session.session import DataCenter

from ....version import __version__ as default_version
from ...mtsender import Sender
from ...mtsender import Connector, Sender
from ...session import (
ChatHashCache,
DataCenter,
MemorySession,
MessageBox,
PackedChat,
Expand Down Expand Up @@ -193,6 +192,13 @@ class Client:
:param datacenter:
Override the datacenter to connect to.
Useful to connect to one of Telegram's test servers.
:param connector:
Asynchronous function called to connect to a remote address.
By default, this is :func:`asyncio.open_connection`.
In order to use proxies, you can set a custom connector.
See :class:`~telethon._impl.mtsender.sender.Connector` for more details.
"""

def __init__(
Expand All @@ -212,10 +218,9 @@ def __init__(
system_lang_code: Optional[str] = None,
lang_code: Optional[str] = None,
datacenter: Optional[DataCenter] = None,
connector: Optional[Connector] = None,
) -> None:
self._logger = logger or logging.getLogger(
__package__[: __package__.index(".")]
)
base_logger = logger or logging.getLogger(__package__[: __package__.index(".")])

self._sender: Optional[Sender] = None
self._sender_lock = asyncio.Lock()
Expand All @@ -240,11 +245,13 @@ def __init__(
if flood_sleep_threshold is None
else flood_sleep_threshold,
update_queue_limit=update_queue_limit,
base_logger=base_logger,
connector=connector or (lambda ip, port: asyncio.open_connection(ip, port)),
)

self._session = Session()

self._message_box = MessageBox(base_logger=self._logger)
self._message_box = MessageBox(base_logger=base_logger)
self._chat_hashes = ChatHashCache(None)
self._last_update_limit_warn: Optional[float] = None
self._updates: asyncio.Queue[
Expand Down
35 changes: 19 additions & 16 deletions client/src/telethon/_impl/client/client/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from ....version import __version__
from ...mtproto import BadStatus, Full, RpcError
from ...mtsender import Sender
from ...mtsender import connect as connect_without_auth
from ...mtsender import connect_with_auth
from ...mtsender import Connector, Sender
from ...mtsender import connect as do_connect_sender
from ...session import DataCenter
from ...session import User as SessionUser
from ...tl import LAYER, Request, abcs, functions, types
Expand Down Expand Up @@ -45,6 +44,8 @@ def default_system_version() -> str:
class Config:
api_id: int
api_hash: str
base_logger: logging.Logger
connector: Connector
device_model: str = field(default_factory=default_device_model)
system_version: str = field(default_factory=default_system_version)
app_version: str = __version__
Expand Down Expand Up @@ -76,7 +77,6 @@ async def connect_sender(
config: Config,
known_dcs: List[DataCenter],
dc: DataCenter,
base_logger: logging.Logger,
force_auth_gen: bool = False,
) -> Tuple[Sender, List[DataCenter]]:
# Only the ID of the input DC may be known.
Expand All @@ -93,11 +93,14 @@ async def connect_sender(
or (next((d.auth for d in known_dcs if d.id == dc.id and d.auth), None))
)

transport = Full()
if auth:
sender = await connect_with_auth(transport, dc.id, addr, auth, base_logger)
else:
sender = await connect_without_auth(transport, dc.id, addr, base_logger)
sender = await do_connect_sender(
Full(),
dc.id,
addr,
auth_key=auth,
base_logger=config.base_logger,
connector=config.connector,
)

try:
remote_config_data = await sender.invoke(
Expand All @@ -122,13 +125,11 @@ async def connect_sender(
dc = DataCenter(
id=dc.id, ipv4_addr=dc.ipv4_addr, ipv6_addr=dc.ipv6_addr, auth=None
)
base_logger.warning(
config.base_logger.warning(
"datacenter could not find stored auth; will retry generating a new one: %s",
dc,
)
return await connect_sender(
config, known_dcs, dc, base_logger, force_auth_gen=True
)
return await connect_sender(config, known_dcs, dc, force_auth_gen=True)
else:
raise

Expand Down Expand Up @@ -177,7 +178,7 @@ async def connect(self: Client) -> None:
id=self._session.user.dc if self._session.user else DEFAULT_DC
)
self._sender, self._session.dcs = await connect_sender(
self._config, self._session.dcs, datacenter, self._logger
self._config, self._session.dcs, datacenter
)

if self._message_box.is_empty() and self._session.user:
Expand Down Expand Up @@ -216,7 +217,7 @@ async def disconnect(self: Client) -> None:
except asyncio.CancelledError:
pass
except Exception:
self._logger.exception(
self._config.base_logger.exception(
"unhandled exception when cancelling dispatcher; this is a bug"
)
finally:
Expand All @@ -225,7 +226,9 @@ async def disconnect(self: Client) -> None:
try:
await sender.disconnect()
except Exception:
self._logger.exception("unhandled exception during disconnect; this is a bug")
self._config.base_logger.exception(
"unhandled exception during disconnect; this is a bug"
)

self._session.state = self._message_box.session_state()
await self._storage.save(self._session)
Expand Down
6 changes: 3 additions & 3 deletions client/src/telethon/_impl/client/client/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def extend_update_queue(
now - client._last_update_limit_warn
> UPDATE_LIMIT_EXCEEDED_LOG_COOLDOWN
):
client._logger.warning(
client._config.base_logger.warning(
"updates are being dropped because limit=%d has been reached",
client._updates.maxsize,
)
Expand All @@ -134,13 +134,13 @@ async def dispatcher(client: Client) -> None:
except Exception as e:
if isinstance(e, RuntimeError) and loop.is_closed():
# User probably forgot to call disconnect.
client._logger.warning(
client._config.base_logger.warning(
"client was not closed cleanly, make sure to call client.disconnect()! %s",
e,
)
return
else:
client._logger.exception(
client._config.base_logger.exception(
"unhandled exception in event handler; this is probably a bug in your code, not telethon"
)
raise
Expand Down
8 changes: 6 additions & 2 deletions client/src/telethon/_impl/mtsender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@
MAXIMUM_DATA,
NO_PING_DISCONNECT,
PING_DELAY,
AsyncReader,
AsyncWriter,
Connector,
Sender,
connect,
connect_with_auth,
)

__all__ = [
"MAXIMUM_DATA",
"NO_PING_DISCONNECT",
"PING_DELAY",
"AsyncReader",
"AsyncWriter",
"Connector",
"Sender",
"connect",
"connect_with_auth",
]
122 changes: 102 additions & 20 deletions client/src/telethon/_impl/mtsender/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import struct
import time
from abc import ABC
from asyncio import FIRST_COMPLETED, Event, Future, StreamReader, StreamWriter
from asyncio import FIRST_COMPLETED, Event, Future
from dataclasses import dataclass
from typing import Generic, List, Optional, Self, TypeVar
from typing import Generic, List, Optional, Protocol, Self, Tuple, TypeVar

from ..crypto import AuthKey
from ..mtproto import (
Expand Down Expand Up @@ -43,6 +43,74 @@ def generate_random_id() -> int:
return _last_id


class AsyncReader(Protocol):
"""
A :class:`asyncio.StreamReader`-like class.
"""

async def read(self, n: int) -> bytes:
"""
Must behave like :meth:`asyncio.StreamReader.read`.
:param n:
Amount of bytes to read at most.
"""


class AsyncWriter(Protocol):
"""
A :class:`asyncio.StreamWriter`-like class.
"""

def write(self, data: bytes) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.write`.
:param data:
Data that must be entirely written or buffered until :meth:`drain` is called.
"""

async def drain(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.drain`.
"""

def close(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.close`.
"""

async def wait_closed(self) -> None:
"""
Must behave like :meth:`asyncio.StreamWriter.wait_closed`.
"""


class Connector(Protocol):
"""
A *Connector* is any function that takes in the following two positional parameters as input:
* The ``ip`` address as a :class:`str`. This might be either a IPv4 or IPv6.
* The ``port`` as a :class:`int`. This will be a number below 2¹⁶, often 443.
and returns a :class:`tuple`\ [:class:`AsyncReader`, :class:`AsyncWriter`].
You can use a custom connector to connect to Telegram through proxies.
The library will only ever open remote connections through this function.
The default connector is :func:`asyncio.open_connection`, defined as:
.. code-block:: python
default_connector = lambda ip, port: asyncio.open_connection(ip, port)
If your connector needs additional parameters, you can use either the :keyword:`lambda` syntax or :func:`functools.partial`.
"""

async def __call__(self, ip: str, port: int) -> Tuple[AsyncReader, AsyncWriter]:
pass


class RequestState(ABC):
pass

Expand Down Expand Up @@ -80,8 +148,8 @@ class Sender:
dc_id: int
addr: str
_logger: logging.Logger
_reader: StreamReader
_writer: StreamWriter
_reader: AsyncReader
_writer: AsyncWriter
_transport: Transport
_mtp: Mtp
_mtp_buffer: bytearray
Expand All @@ -98,9 +166,12 @@ async def connect(
mtp: Mtp,
dc_id: int,
addr: str,
*,
connector: Connector,
base_logger: logging.Logger,
) -> Self:
reader, writer = await asyncio.open_connection(*addr.split(":"))
ip, port = addr.split(":")
reader, writer = await connector(ip, int(port))

return cls(
dc_id=dc_id,
Expand Down Expand Up @@ -299,10 +370,33 @@ def auth_key(self) -> Optional[bytes]:


async def connect(
transport: Transport, dc_id: int, addr: str, base_logger: logging.Logger
transport: Transport,
dc_id: int,
addr: str,
*,
auth_key: Optional[bytes],
base_logger: logging.Logger,
connector: Connector,
) -> Sender:
sender = await Sender.connect(transport, Plain(), dc_id, addr, base_logger)
return await generate_auth_key(sender)
if auth_key is None:
sender = await Sender.connect(
transport,
Plain(),
dc_id,
addr,
connector=connector,
base_logger=base_logger,
)
return await generate_auth_key(sender)
else:
return await Sender.connect(
transport,
Encrypted(AuthKey.from_bytes(auth_key)),
dc_id,
addr,
connector=connector,
base_logger=base_logger,
)


async def generate_auth_key(sender: Sender) -> Sender:
Expand All @@ -320,15 +414,3 @@ async def generate_auth_key(sender: Sender) -> Sender:
sender._mtp = Encrypted(auth_key, time_offset=time_offset, first_salt=first_salt)
sender._next_ping = asyncio.get_running_loop().time() + PING_DELAY
return sender


async def connect_with_auth(
transport: Transport,
dc_id: int,
addr: str,
auth_key: bytes,
base_logger: logging.Logger,
) -> Sender:
return await Sender.connect(
transport, Encrypted(AuthKey.from_bytes(auth_key)), dc_id, addr, base_logger
)

0 comments on commit 6e88264

Please sign in to comment.