Skip to content

Commit 983d396

Browse files
Implement non-blocking TCP connection (KurimuzonAkuma/kurigram#71)
1 parent c5e9aaa commit 983d396

File tree

11 files changed

+172
-96
lines changed

11 files changed

+172
-96
lines changed

pyrogram/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from io import StringIO, BytesIO
3333
from mimetypes import MimeTypes
3434
from pathlib import Path
35-
from typing import Union, List, Optional, Callable, AsyncGenerator
35+
from typing import Union, List, Optional, Callable, AsyncGenerator, Type
3636

3737
import pyrogram
3838
from pyrogram import __version__, __license__
@@ -52,6 +52,8 @@
5252
from pyrogram.storage import Storage, FileStorage, MemoryStorage
5353
from pyrogram.types import User, TermsOfService
5454
from pyrogram.utils import ainput
55+
from .connection import Connection
56+
from .connection.transport import TCP, TCPAbridged
5557
from .dispatcher import Dispatcher
5658
from .file_id import FileId, FileType, ThumbnailSource
5759
from .mime_types import mime_types
@@ -333,6 +335,9 @@ def __init__(
333335
else:
334336
self.storage = FileStorage(self.name, self.WORKDIR)
335337

338+
self.connection_factory = Connection
339+
self.protocol_factory = TCPAbridged
340+
336341
self.dispatcher = Dispatcher(self)
337342

338343
self.rnd_id = MsgId

pyrogram/connection/connection.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import asyncio
2020
import logging
21-
from typing import Optional
21+
from typing import Optional, Type
2222

2323
from .transport import TCP, TCPAbridged
2424
from ..session.internals import DataCenter
@@ -29,19 +29,28 @@
2929
class Connection:
3030
MAX_CONNECTION_ATTEMPTS = 3
3131

32-
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict, media: bool = False):
32+
def __init__(
33+
self,
34+
dc_id: int,
35+
test_mode: bool,
36+
ipv6: bool,
37+
proxy: dict,
38+
media: bool = False,
39+
protocol_factory: Type[TCP] = TCPAbridged
40+
) -> None:
3341
self.dc_id = dc_id
3442
self.test_mode = test_mode
3543
self.ipv6 = ipv6
3644
self.proxy = proxy
3745
self.media = media
46+
self.protocol_factory = protocol_factory
3847

3948
self.address = DataCenter(dc_id, test_mode, ipv6, media)
40-
self.protocol: TCP = None
49+
self.protocol: Optional[TCP] = None
4150

42-
async def connect(self):
51+
async def connect(self) -> None:
4352
for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
44-
self.protocol = TCPAbridged(self.ipv6, self.proxy)
53+
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
4554

4655
try:
4756
log.info("Connecting...")
@@ -61,11 +70,11 @@ async def connect(self):
6170
log.warning("Connection failed! Trying again...")
6271
raise ConnectionError
6372

64-
async def close(self):
73+
async def close(self) -> None:
6574
await self.protocol.close()
6675
log.info("Disconnected")
6776

68-
async def send(self, data: bytes):
77+
async def send(self, data: bytes) -> None:
6978
await self.protocol.send(data)
7079

7180
async def recv(self) -> Optional[bytes]:

pyrogram/connection/transport/tcp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

19-
from .tcp import TCP
19+
from .tcp import TCP, Proxy
2020
from .tcp_abridged import TCPAbridged
2121
from .tcp_abridged_o import TCPAbridgedO
2222
from .tcp_full import TCPFull

pyrogram/connection/transport/tcp/tcp.py

Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,90 +21,135 @@
2121
import logging
2222
import socket
2323
from concurrent.futures import ThreadPoolExecutor
24+
from typing import Tuple, Dict, TypedDict, Optional
2425

2526
import socks
2627

2728
log = logging.getLogger(__name__)
2829

30+
proxy_type_by_scheme: Dict[str, int] = {
31+
"SOCKS4": socks.SOCKS4,
32+
"SOCKS5": socks.SOCKS5,
33+
"HTTP": socks.HTTP,
34+
}
35+
36+
37+
class Proxy(TypedDict):
38+
scheme: str
39+
hostname: str
40+
port: int
41+
username: Optional[str]
42+
password: Optional[str]
43+
2944

3045
class TCP:
3146
TIMEOUT = 10
3247

33-
def __init__(self, ipv6: bool, proxy: dict):
34-
self.socket = None
48+
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
49+
self.ipv6 = ipv6
50+
self.proxy = proxy
3551

36-
self.reader = None
37-
self.writer = None
52+
self.reader: Optional[asyncio.StreamReader] = None
53+
self.writer: Optional[asyncio.StreamWriter] = None
3854

3955
self.lock = asyncio.Lock()
4056
self.loop = asyncio.get_event_loop()
4157

42-
self.proxy = proxy
58+
async def _connect_via_proxy(
59+
self,
60+
destination: Tuple[str, int]
61+
) -> None:
62+
scheme = self.proxy.get("scheme")
63+
if scheme is None:
64+
raise ValueError("No scheme specified")
4365

44-
if proxy:
45-
hostname = proxy.get("hostname")
46-
47-
try:
48-
ip_address = ipaddress.ip_address(hostname)
49-
except ValueError:
50-
self.socket = socks.socksocket(socket.AF_INET)
51-
else:
52-
if isinstance(ip_address, ipaddress.IPv6Address):
53-
self.socket = socks.socksocket(socket.AF_INET6)
54-
else:
55-
self.socket = socks.socksocket(socket.AF_INET)
66+
proxy_type = proxy_type_by_scheme.get(scheme.upper())
67+
if proxy_type is None:
68+
raise ValueError(f"Unknown proxy type {scheme}")
5669

57-
self.socket.set_proxy(
58-
proxy_type=getattr(socks, proxy.get("scheme").upper()),
59-
addr=hostname,
60-
port=proxy.get("port", None),
61-
username=proxy.get("username", None),
62-
password=proxy.get("password", None)
63-
)
70+
hostname = self.proxy.get("hostname")
71+
port = self.proxy.get("port")
72+
username = self.proxy.get("username")
73+
password = self.proxy.get("password")
6474

65-
self.socket.settimeout(TCP.TIMEOUT)
66-
67-
log.info("Using proxy %s", hostname)
75+
try:
76+
ip_address = ipaddress.ip_address(hostname)
77+
except ValueError:
78+
is_proxy_ipv6 = False
6879
else:
69-
self.socket = socket.socket(
70-
socket.AF_INET6 if ipv6
71-
else socket.AF_INET
72-
)
73-
74-
self.socket.setblocking(False)
75-
76-
async def connect(self, address: tuple):
80+
is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address)
81+
82+
proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET
83+
sock = socks.socksocket(proxy_family)
84+
85+
sock.set_proxy(
86+
proxy_type=proxy_type,
87+
addr=hostname,
88+
port=port,
89+
username=username,
90+
password=password
91+
)
92+
sock.settimeout(TCP.TIMEOUT)
93+
94+
await self.loop.sock_connect(
95+
sock=sock,
96+
address=destination
97+
)
98+
99+
sock.setblocking(False)
100+
101+
self.reader, self.writer = await asyncio.open_connection(
102+
sock=sock
103+
)
104+
105+
async def _connect_via_direct(
106+
self,
107+
destination: Tuple[str, int]
108+
) -> None:
109+
host, port = destination
110+
family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
111+
self.reader, self.writer = await asyncio.open_connection(
112+
host=host,
113+
port=port,
114+
family=family
115+
)
116+
117+
async def _connect(self, destination: Tuple[str, int]) -> None:
77118
if self.proxy:
78-
with ThreadPoolExecutor(1) as executor:
79-
await self.loop.run_in_executor(executor, self.socket.connect, address)
119+
await self._connect_via_proxy(destination)
80120
else:
81-
try:
82-
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
83-
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
84-
raise TimeoutError("Connection timed out")
121+
await self._connect_via_direct(destination)
122+
123+
async def connect(self, address: Tuple[str, int]) -> None:
124+
try:
125+
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
126+
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
127+
raise TimeoutError("Connection timed out")
85128

86-
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
129+
async def close(self) -> None:
130+
if self.writer is None:
131+
return None
87132

88-
async def close(self):
89133
try:
90-
if self.writer is not None:
91-
self.writer.close()
92-
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
134+
self.writer.close()
135+
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
93136
except Exception as e:
94137
log.info("Close exception: %s %s", type(e).__name__, e)
95138

96-
async def send(self, data: bytes):
139+
async def send(self, data: bytes) -> None:
140+
if self.writer is None:
141+
return None
142+
97143
async with self.lock:
98144
try:
99-
if self.writer is not None:
100-
self.writer.write(data)
101-
await self.writer.drain()
145+
self.writer.write(data)
146+
await self.writer.drain()
102147
except Exception as e:
103148
# error coming somewhere here
104149
log.exception("Send exception: %s %s", type(e).__name__, e)
105150
raise OSError(e)
106151

107-
async def recv(self, length: int = 0):
152+
async def recv(self, length: int = 0) -> Optional[bytes]:
108153
data = b""
109154

110155
while len(data) < length:

pyrogram/connection/transport/tcp/tcp_abridged.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,22 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import logging
20-
from typing import Optional
20+
from typing import Optional, Tuple
2121

22-
from .tcp import TCP
22+
from .tcp import TCP, Proxy
2323

2424
log = logging.getLogger(__name__)
2525

2626

2727
class TCPAbridged(TCP):
28-
def __init__(self, ipv6: bool, proxy: dict):
28+
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
2929
super().__init__(ipv6, proxy)
3030

31-
async def connect(self, address: tuple):
31+
async def connect(self, address: Tuple[str, int]) -> None:
3232
await super().connect(address)
3333
await super().send(b"\xef")
3434

35-
async def send(self, data: bytes, *args):
35+
async def send(self, data: bytes, *args) -> None:
3636
length = len(data) // 4
3737

3838
await super().send(

pyrogram/connection/transport/tcp/tcp_abridged_o.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,25 @@
1818

1919
import logging
2020
import os
21-
from typing import Optional
21+
from typing import Optional, Tuple
2222

2323
import pyrogram
2424
from pyrogram.crypto import aes
25-
from .tcp import TCP
25+
from .tcp import TCP, Proxy
2626

2727
log = logging.getLogger(__name__)
2828

2929

3030
class TCPAbridgedO(TCP):
3131
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
3232

33-
def __init__(self, ipv6: bool, proxy: dict):
33+
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
3434
super().__init__(ipv6, proxy)
3535

3636
self.encrypt = None
3737
self.decrypt = None
3838

39-
async def connect(self, address: tuple):
39+
async def connect(self, address: Tuple[str, int]) -> None:
4040
await super().connect(address)
4141

4242
while True:
@@ -55,7 +55,7 @@ async def connect(self, address: tuple):
5555

5656
await super().send(nonce)
5757

58-
async def send(self, data: bytes, *args):
58+
async def send(self, data: bytes, *args) -> None:
5959
length = len(data) // 4
6060
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
6161
payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt)

pyrogram/connection/transport/tcp/tcp_full.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,24 @@
1919
import logging
2020
from binascii import crc32
2121
from struct import pack, unpack
22-
from typing import Optional
22+
from typing import Optional, Tuple
2323

24-
from .tcp import TCP
24+
from .tcp import TCP, Proxy
2525

2626
log = logging.getLogger(__name__)
2727

2828

2929
class TCPFull(TCP):
30-
def __init__(self, ipv6: bool, proxy: dict):
30+
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
3131
super().__init__(ipv6, proxy)
3232

33-
self.seq_no = None
33+
self.seq_no: Optional[int] = None
3434

35-
async def connect(self, address: tuple):
35+
async def connect(self, address: Tuple[str, int]) -> None:
3636
await super().connect(address)
3737
self.seq_no = 0
3838

39-
async def send(self, data: bytes, *args):
39+
async def send(self, data: bytes, *args) -> None:
4040
data = pack("<II", len(data) + 12, self.seq_no) + data
4141
data += pack("<I", crc32(data))
4242
self.seq_no += 1

0 commit comments

Comments
 (0)