66import anyio .streams .tls
77import fastapi
88import uvicorn
9+ import tianxiu2b2t .anyio .streams as streams
910
1011from . import utils , abc
1112from .logger import logger
@@ -40,7 +41,7 @@ def __exit__(self, *args):
4041)
4142http_port = - 1
4243certificates : list [abc .Certificate ] = []
43- tls_ports : dict [ str , int ] = {}
44+ tls_listener : streams . AutoTLSListener | None = None
4445forwards : dict [tuple [str , int ], tuple [str , int ]] = {}
4546forwards_count : defaultdict [tuple [str , int ], int ] = defaultdict (int )
4647
@@ -49,8 +50,10 @@ async def get_free_port():
4950 port = listener .extra (anyio .abc .SocketAttribute .local_port )
5051 return port
5152
52- async def pub_listener ():
53- global pub_port
53+ async def pub_listener (
54+ task_group : anyio .abc .TaskGroup
55+ ):
56+ global pub_port , tls_listener
5457 pub_port = cfg .web_port
5558 if pub_port == - 1 :
5659 pub_port = cfg .web_public_port
@@ -59,82 +62,47 @@ async def pub_listener():
5962 listener = await anyio .create_tcp_listener (
6063 local_port = pub_port ,
6164 )
65+
66+ tls_listener = streams .AutoTLSListener (
67+ listener ,
68+ )
69+ task_group .start_soon (serve , tls_listener )
70+
71+ async def serve (
72+ listener : streams .AutoTLSListener ,
73+ ):
6274 async with listener :
6375 logger .tinfo ("web.forward.pub_port" , port = pub_port )
6476 await listener .serve (pub_handler )
6577
6678async def pub_handler (
67- sock : anyio .abc .SocketStream
79+ sock : streams .BufferedByteStream ,
80+ extra : streams .TLSExtraData
6881):
6982 try :
7083 async with sock :
71- # first read 16384 bytes of tls
72- buf = await sock .receive (16384 )
73- handshake = utils .parse_tls_handshake (buf )
74- port = None
75- if handshake is None :
76- port = http_port
77- else :
78- if handshake .sni in tls_ports :
79- port = tls_ports [handshake .sni ]
80- elif tls_ports :
81- port = list (tls_ports .values ())[0 ]
82- if port is None :
83- return
84- # then forward to port
85- await forward (sock , port , buf )
86- except (
87- anyio .EndOfStream ,
88- anyio .BrokenResourceError
89- ):
90- ...
91- except Exception as e :
92- logger .debug_traceback ()
93-
94- async def tls_listener (
95- cert : abc .Certificate
96- ):
97- context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
98- context .check_hostname = False
99- context .hostname_checks_common_name = False
100- context .load_cert_chain (cert .cert , cert .key )
101- listener = await anyio .create_tcp_listener (
102- local_host = "127.0.0.1" ,
103- )
104- tls_listener = anyio .streams .tls .TLSListener (listener , context )
105- async with tls_listener :
106- logger .tdebug ("web.forward.tls_port" , port = listener .extra (anyio .abc .SocketAttribute .local_port ))
107- for domain in cert .domains :
108- tls_ports [domain ] = listener .extra (anyio .abc .SocketAttribute .local_port )
109- await tls_listener .serve (tls_handler )
110-
111- async def tls_handler (
112- sock : anyio .streams .tls .TLSStream
113- ):
114- try :
115- async with sock :
116- # first read 16384 bytes of tls
117- # then forward to port
118- await forward (sock , http_port )
84+ await forward (sock , http_port , b'' )
11985 except (
12086 anyio .EndOfStream ,
12187 anyio .BrokenResourceError ,
122- ssl .SSLError ,
88+ ssl .SSLError
12389 ):
12490 ...
12591 except Exception as e :
12692 logger .debug_traceback ()
12793
94+
12895async def forward (
129- sock : anyio . abc . SocketStream | anyio . streams .tls . TLSStream ,
96+ sock : streams .BufferedByteStream ,
13097 port : int ,
13198 buffer : bytes = b''
13299):
133100 try :
134- async with await anyio .connect_tcp (
101+ async with streams .BufferedByteStream (
102+ await anyio .connect_tcp (
135103 "127.0.0.1" ,
136104 port
137- ) as conn :
105+ )) as conn :
138106 with ForwardAddress (
139107 get_sockname (conn ),
140108 get_peername (sock )
@@ -148,12 +116,12 @@ async def forward(
148116 raise
149117
150118def get_sockname (
151- sock : anyio . abc . SocketStream | anyio . streams .tls . TLSStream
119+ sock : streams .BufferedByteStream
152120) -> tuple [str , int ]:
153121 return sock .extra (anyio .abc .SocketAttribute .local_address ) # type: ignore
154122
155123def get_peername (
156- sock : anyio . abc . SocketStream | anyio . streams .tls . TLSStream
124+ sock : streams .BufferedByteStream
157125) -> tuple [str , int ]:
158126 return sock .extra (anyio .abc .SocketAttribute .remote_address ) # type: ignore
159127
@@ -165,8 +133,8 @@ def get_origin_address(
165133 return name
166134
167135async def forward_data (
168- sock : anyio . abc . SocketStream | anyio . streams .tls . TLSStream ,
169- conn : anyio . abc . SocketStream | anyio . streams .tls . TLSStream
136+ sock : streams .BufferedByteStream ,
137+ conn : streams .BufferedByteStream
170138):
171139 try :
172140 while 1 :
@@ -195,7 +163,7 @@ async def setup(
195163
196164 logger .tdebug ("web.uvicorn.port" , port = config .port )
197165
198- task_group . start_soon ( pub_listener )
166+ await pub_listener ( task_group )
199167
200168 cert_type = utils .get_certificate_type ()
201169
@@ -220,5 +188,18 @@ async def setup(
220188 if len (certificates ) == 0 :
221189 raise RuntimeError (t ("error.web.certificates" ))
222190
191+ if tls_listener is None :
192+ raise RuntimeError (t ("error.web.tls_listener" ))
193+
223194 for cert in certificates :
224- task_group .start_soon (tls_listener , cert )
195+ context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
196+ context .load_cert_chain (cert .cert , cert .key )
197+ context .check_hostname = False
198+ context .hostname_checks_common_name = False
199+ context .verify_mode = ssl .CERT_NONE
200+
201+ for domain in cert .domains :
202+ tls_listener .add_context (
203+ domain ,
204+ context
205+ )
0 commit comments