|
21 | 21 | import logging |
22 | 22 | import socket |
23 | 23 | from concurrent.futures import ThreadPoolExecutor |
| 24 | +from typing import Tuple, Dict, TypedDict, Optional |
24 | 25 |
|
25 | 26 | import socks |
26 | 27 |
|
27 | 28 | log = logging.getLogger(__name__) |
28 | 29 |
|
| 30 | +proxy_type_by_scheme: Dict[str, int] = { |
| 31 | + "SOCKS4": socks.SOCKS4, |
| 32 | + "SOCKS5": socks.SOCKS5, |
| 33 | + "HTTP": socks.HTTP, |
| 34 | +} |
| 35 | + |
| 36 | + |
| 37 | +class Proxy(TypedDict): |
| 38 | + scheme: str |
| 39 | + hostname: str |
| 40 | + port: int |
| 41 | + username: Optional[str] |
| 42 | + password: Optional[str] |
| 43 | + |
29 | 44 |
|
30 | 45 | class TCP: |
31 | 46 | TIMEOUT = 10 |
32 | 47 |
|
33 | | - def __init__(self, ipv6: bool, proxy: dict): |
34 | | - self.socket = None |
| 48 | + def __init__(self, ipv6: bool, proxy: Proxy) -> None: |
| 49 | + self.ipv6 = ipv6 |
| 50 | + self.proxy = proxy |
35 | 51 |
|
36 | | - self.reader = None |
37 | | - self.writer = None |
| 52 | + self.reader: Optional[asyncio.StreamReader] = None |
| 53 | + self.writer: Optional[asyncio.StreamWriter] = None |
38 | 54 |
|
39 | 55 | self.lock = asyncio.Lock() |
40 | 56 | self.loop = asyncio.get_event_loop() |
41 | 57 |
|
42 | | - self.proxy = proxy |
| 58 | + async def _connect_via_proxy( |
| 59 | + self, |
| 60 | + destination: Tuple[str, int] |
| 61 | + ) -> None: |
| 62 | + scheme = self.proxy.get("scheme") |
| 63 | + if scheme is None: |
| 64 | + raise ValueError("No scheme specified") |
43 | 65 |
|
44 | | - if proxy: |
45 | | - hostname = proxy.get("hostname") |
46 | | - |
47 | | - try: |
48 | | - ip_address = ipaddress.ip_address(hostname) |
49 | | - except ValueError: |
50 | | - self.socket = socks.socksocket(socket.AF_INET) |
51 | | - else: |
52 | | - if isinstance(ip_address, ipaddress.IPv6Address): |
53 | | - self.socket = socks.socksocket(socket.AF_INET6) |
54 | | - else: |
55 | | - self.socket = socks.socksocket(socket.AF_INET) |
| 66 | + proxy_type = proxy_type_by_scheme.get(scheme.upper()) |
| 67 | + if proxy_type is None: |
| 68 | + raise ValueError(f"Unknown proxy type {scheme}") |
56 | 69 |
|
57 | | - self.socket.set_proxy( |
58 | | - proxy_type=getattr(socks, proxy.get("scheme").upper()), |
59 | | - addr=hostname, |
60 | | - port=proxy.get("port", None), |
61 | | - username=proxy.get("username", None), |
62 | | - password=proxy.get("password", None) |
63 | | - ) |
| 70 | + hostname = self.proxy.get("hostname") |
| 71 | + port = self.proxy.get("port") |
| 72 | + username = self.proxy.get("username") |
| 73 | + password = self.proxy.get("password") |
64 | 74 |
|
65 | | - self.socket.settimeout(TCP.TIMEOUT) |
66 | | - |
67 | | - log.info("Using proxy %s", hostname) |
| 75 | + try: |
| 76 | + ip_address = ipaddress.ip_address(hostname) |
| 77 | + except ValueError: |
| 78 | + is_proxy_ipv6 = False |
68 | 79 | else: |
69 | | - self.socket = socket.socket( |
70 | | - socket.AF_INET6 if ipv6 |
71 | | - else socket.AF_INET |
72 | | - ) |
73 | | - |
74 | | - self.socket.setblocking(False) |
75 | | - |
76 | | - async def connect(self, address: tuple): |
| 80 | + is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address) |
| 81 | + |
| 82 | + proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET |
| 83 | + sock = socks.socksocket(proxy_family) |
| 84 | + |
| 85 | + sock.set_proxy( |
| 86 | + proxy_type=proxy_type, |
| 87 | + addr=hostname, |
| 88 | + port=port, |
| 89 | + username=username, |
| 90 | + password=password |
| 91 | + ) |
| 92 | + sock.settimeout(TCP.TIMEOUT) |
| 93 | + |
| 94 | + await self.loop.sock_connect( |
| 95 | + sock=sock, |
| 96 | + address=destination |
| 97 | + ) |
| 98 | + |
| 99 | + sock.setblocking(False) |
| 100 | + |
| 101 | + self.reader, self.writer = await asyncio.open_connection( |
| 102 | + sock=sock |
| 103 | + ) |
| 104 | + |
| 105 | + async def _connect_via_direct( |
| 106 | + self, |
| 107 | + destination: Tuple[str, int] |
| 108 | + ) -> None: |
| 109 | + host, port = destination |
| 110 | + family = socket.AF_INET6 if self.ipv6 else socket.AF_INET |
| 111 | + self.reader, self.writer = await asyncio.open_connection( |
| 112 | + host=host, |
| 113 | + port=port, |
| 114 | + family=family |
| 115 | + ) |
| 116 | + |
| 117 | + async def _connect(self, destination: Tuple[str, int]) -> None: |
77 | 118 | if self.proxy: |
78 | | - with ThreadPoolExecutor(1) as executor: |
79 | | - await self.loop.run_in_executor(executor, self.socket.connect, address) |
| 119 | + await self._connect_via_proxy(destination) |
80 | 120 | else: |
81 | | - try: |
82 | | - await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT) |
83 | | - except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 |
84 | | - raise TimeoutError("Connection timed out") |
| 121 | + await self._connect_via_direct(destination) |
| 122 | + |
| 123 | + async def connect(self, address: Tuple[str, int]) -> None: |
| 124 | + try: |
| 125 | + await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) |
| 126 | + except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 |
| 127 | + raise TimeoutError("Connection timed out") |
85 | 128 |
|
86 | | - self.reader, self.writer = await asyncio.open_connection(sock=self.socket) |
| 129 | + async def close(self) -> None: |
| 130 | + if self.writer is None: |
| 131 | + return None |
87 | 132 |
|
88 | | - async def close(self): |
89 | 133 | try: |
90 | | - if self.writer is not None: |
91 | | - self.writer.close() |
92 | | - await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) |
| 134 | + self.writer.close() |
| 135 | + await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) |
93 | 136 | except Exception as e: |
94 | 137 | log.info("Close exception: %s %s", type(e).__name__, e) |
95 | 138 |
|
96 | | - async def send(self, data: bytes): |
| 139 | + async def send(self, data: bytes) -> None: |
| 140 | + if self.writer is None: |
| 141 | + return None |
| 142 | + |
97 | 143 | async with self.lock: |
98 | 144 | try: |
99 | | - if self.writer is not None: |
100 | | - self.writer.write(data) |
101 | | - await self.writer.drain() |
| 145 | + self.writer.write(data) |
| 146 | + await self.writer.drain() |
102 | 147 | except Exception as e: |
103 | 148 | # error coming somewhere here |
104 | 149 | log.exception("Send exception: %s %s", type(e).__name__, e) |
105 | 150 | raise OSError(e) |
106 | 151 |
|
107 | | - async def recv(self, length: int = 0): |
| 152 | + async def recv(self, length: int = 0) -> Optional[bytes]: |
108 | 153 | data = b"" |
109 | 154 |
|
110 | 155 | while len(data) < length: |
|
0 commit comments