Skip to content

Commit 19efe42

Browse files
committed
chore: 优化证书监听
1 parent d069fd2 commit 19efe42

File tree

7 files changed

+66
-74
lines changed

7 files changed

+66
-74
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ bmclapi
66
config
77
cache
88

9-
109
start.sh
1110

11+
**/__pycache__
1212

13-
**/__pycache__
13+
# pipy plugins
14+
tianxiu2b2t

core/cluster.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,9 @@ async def _sync(self):
623623
await self.sync()
624624

625625
async def serve(self):
626-
for cluster in self.clusters:
627-
await cluster.start_serve()
626+
async with anyio.create_task_group() as task_group:
627+
for cluster in self.clusters:
628+
task_group.start_soon(cluster.start_serve)
628629

629630
async def stop(self):
630631
for cluster in self.clusters:
@@ -653,7 +654,7 @@ async def get_measure_file(self, size: int) -> ResponseFile:
653654
return await storage.get_file(f"measure/{size}")
654655

655656

656-
if cfg.concurreny_enable_cluster:
657-
sem = anyio.Semaphore(1)
657+
if cfg.concurrency_enable_cluster:
658+
sem = contextlib.nullcontext()
658659
else:
659-
sem = contextlib.nullcontext()
660+
sem = anyio.Semaphore(1)

core/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def storage_measure(self) -> bool:
7979
return self.get("advanced.storage_measure") or False
8080

8181
@property
82-
def concurreny_enable_cluster(self) -> bool:
83-
return self.get("advanced.concurreny_enable_cluster") or False
82+
def concurrency_enable_cluster(self) -> bool:
83+
return self.get("advanced.concurrency_enable_cluster") or False
8484

8585

8686
API_VERSION = "1.13.1"
@@ -94,6 +94,7 @@ def concurreny_enable_cluster(self) -> bool:
9494
"advanced.debug": False,
9595
"advanced.access_log": False,
9696
"advanced.host": "",
97+
"advanced.concurrency_enable_cluster": False,
9798
"web.port": 6543,
9899
"web.public_port": 6543,
99100
"web.proxy": False,

core/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import anyio
99
import anyio.abc
1010
from tqdm import tqdm
11+
from functools import lru_cache
12+
13+
from .logger import logger
1114

1215
from .abc import CertificateType
1316
from .config import cfg
@@ -304,16 +307,19 @@ def get_hash_obj(
304307
return hashlib.md5()
305308
return hashlib.sha1()
306309

310+
@lru_cache(maxsize=1024)
307311
def get_certificate_type() -> CertificateType:
312+
ret = CertificateType.CLUSTER
308313
if cfg.get("web.proxy"):
309-
return CertificateType.PROXY
314+
ret = CertificateType.PROXY
310315
else:
311316
key, cert = cfg.get("cert.key"), cfg.get("cert.cert")
312317
if key and cert:
313318
key_file, cert_file = Path(key), Path(cert)
314319
if key_file.exists() and cert_file.exists() and cert_file.stat().st_size > 0 and key_file.stat().st_size > 0:
315-
return CertificateType.BYOC
316-
return CertificateType.CLUSTER
320+
ret = CertificateType.BYOC
321+
logger.tinfo(f"web.byoc", type=ret)
322+
return ret
317323

318324
def get_range_size(range: str, size: Optional[int] = None):
319325
if range.startswith("bytes="):

core/web.py

Lines changed: 43 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import anyio.streams.tls
77
import fastapi
88
import uvicorn
9+
import tianxiu2b2t.anyio.streams as streams
910

1011
from . import utils, abc
1112
from .logger import logger
@@ -40,7 +41,7 @@ def __exit__(self, *args):
4041
)
4142
http_port = -1
4243
certificates: list[abc.Certificate] = []
43-
tls_ports: dict[str, int] = {}
44+
tls_listener: streams.AutoTLSListener | None = None
4445
forwards: dict[tuple[str, int], tuple[str, int]] = {}
4546
forwards_count: defaultdict[tuple[str, int], int] = defaultdict(int)
4647

@@ -49,8 +50,10 @@ async def get_free_port():
4950
port = listener.extra(anyio.abc.SocketAttribute.local_port)
5051
return port
5152

52-
async def pub_listener():
53-
global pub_port
53+
async def pub_listener(
54+
task_group: anyio.abc.TaskGroup
55+
):
56+
global pub_port, tls_listener
5457
pub_port = cfg.web_port
5558
if pub_port == -1:
5659
pub_port = cfg.web_public_port
@@ -59,82 +62,47 @@ async def pub_listener():
5962
listener = await anyio.create_tcp_listener(
6063
local_port=pub_port,
6164
)
65+
66+
tls_listener = streams.AutoTLSListener(
67+
listener,
68+
)
69+
task_group.start_soon(serve, tls_listener)
70+
71+
async def serve(
72+
listener: streams.AutoTLSListener,
73+
):
6274
async with listener:
6375
logger.tinfo("web.forward.pub_port", port=pub_port)
6476
await listener.serve(pub_handler)
6577

6678
async def pub_handler(
67-
sock: anyio.abc.SocketStream
79+
sock: streams.BufferedByteStream,
80+
extra: streams.TLSExtraData
6881
):
6982
try:
7083
async with sock:
71-
# first read 16384 bytes of tls
72-
buf = await sock.receive(16384)
73-
handshake = utils.parse_tls_handshake(buf)
74-
port = None
75-
if handshake is None:
76-
port = http_port
77-
else:
78-
if handshake.sni in tls_ports:
79-
port = tls_ports[handshake.sni]
80-
elif tls_ports:
81-
port = list(tls_ports.values())[0]
82-
if port is None:
83-
return
84-
# then forward to port
85-
await forward(sock, port, buf)
86-
except (
87-
anyio.EndOfStream,
88-
anyio.BrokenResourceError
89-
):
90-
...
91-
except Exception as e:
92-
logger.debug_traceback()
93-
94-
async def tls_listener(
95-
cert: abc.Certificate
96-
):
97-
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
98-
context.check_hostname = False
99-
context.hostname_checks_common_name = False
100-
context.load_cert_chain(cert.cert, cert.key)
101-
listener = await anyio.create_tcp_listener(
102-
local_host="127.0.0.1",
103-
)
104-
tls_listener = anyio.streams.tls.TLSListener(listener, context)
105-
async with tls_listener:
106-
logger.tdebug("web.forward.tls_port", port=listener.extra(anyio.abc.SocketAttribute.local_port))
107-
for domain in cert.domains:
108-
tls_ports[domain] = listener.extra(anyio.abc.SocketAttribute.local_port)
109-
await tls_listener.serve(tls_handler)
110-
111-
async def tls_handler(
112-
sock: anyio.streams.tls.TLSStream
113-
):
114-
try:
115-
async with sock:
116-
# first read 16384 bytes of tls
117-
# then forward to port
118-
await forward(sock, http_port)
84+
await forward(sock, http_port, b'')
11985
except (
12086
anyio.EndOfStream,
12187
anyio.BrokenResourceError,
122-
ssl.SSLError,
88+
ssl.SSLError
12389
):
12490
...
12591
except Exception as e:
12692
logger.debug_traceback()
12793

94+
12895
async def forward(
129-
sock: anyio.abc.SocketStream | anyio.streams.tls.TLSStream,
96+
sock: streams.BufferedByteStream,
13097
port: int,
13198
buffer: bytes = b''
13299
):
133100
try:
134-
async with await anyio.connect_tcp(
101+
async with streams.BufferedByteStream(
102+
await anyio.connect_tcp(
135103
"127.0.0.1",
136104
port
137-
) as conn:
105+
)) as conn:
138106
with ForwardAddress(
139107
get_sockname(conn),
140108
get_peername(sock)
@@ -148,12 +116,12 @@ async def forward(
148116
raise
149117

150118
def get_sockname(
151-
sock: anyio.abc.SocketStream | anyio.streams.tls.TLSStream
119+
sock: streams.BufferedByteStream
152120
) -> tuple[str, int]:
153121
return sock.extra(anyio.abc.SocketAttribute.local_address) # type: ignore
154122

155123
def get_peername(
156-
sock: anyio.abc.SocketStream | anyio.streams.tls.TLSStream
124+
sock: streams.BufferedByteStream
157125
) -> tuple[str, int]:
158126
return sock.extra(anyio.abc.SocketAttribute.remote_address) # type: ignore
159127

@@ -165,8 +133,8 @@ def get_origin_address(
165133
return name
166134

167135
async def forward_data(
168-
sock: anyio.abc.SocketStream | anyio.streams.tls.TLSStream,
169-
conn: anyio.abc.SocketStream | anyio.streams.tls.TLSStream
136+
sock: streams.BufferedByteStream,
137+
conn: streams.BufferedByteStream
170138
):
171139
try:
172140
while 1:
@@ -195,7 +163,7 @@ async def setup(
195163

196164
logger.tdebug("web.uvicorn.port", port=config.port)
197165

198-
task_group.start_soon(pub_listener)
166+
await pub_listener(task_group)
199167

200168
cert_type = utils.get_certificate_type()
201169

@@ -220,5 +188,18 @@ async def setup(
220188
if len(certificates) == 0:
221189
raise RuntimeError(t("error.web.certificates"))
222190

191+
if tls_listener is None:
192+
raise RuntimeError(t("error.web.tls_listener"))
193+
223194
for cert in certificates:
224-
task_group.start_soon(tls_listener, cert)
195+
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
196+
context.load_cert_chain(cert.cert, cert.key)
197+
context.check_hostname = False
198+
context.hostname_checks_common_name = False
199+
context.verify_mode = ssl.CERT_NONE
200+
201+
for domain in cert.domains:
202+
tls_listener.add_context(
203+
domain,
204+
context
205+
)

locale/zh_cn.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"info.cluster.enable": "节点 [${id}] 已上线",
2020
"info.cluster.message": "节点 [${id}] 消息 [${msg}]",
2121
"info.cluster.retry": "节点 [${id}] 将在 [${time}s] 后尝试重新上线",
22+
"info.web.byoc": "证书类型 [${type}]",
2223
"success.cluster.keepalive": "节点 [${id}] 服务了 [${hits}] 个文件 总计 [${bytes}iB] 延迟 [${delay}ms]",
2324
"error.core.initialize.storages": "无法初始化存储",
2425
"error.core.initialize.missing": "当前加载节点数 [${clusters}] 存储数 [${storages}]",
@@ -29,6 +30,7 @@
2930
"error.cluster.enable": "节点 [${id}] 上线失败 [${err}]",
3031
"error.cluster.enable.timeout": "节点 [${id}] 注册超时",
3132
"error.web.certificates": "没有可用证书",
33+
"error.web.tls_listener": "TLS 监听器错误",
3234
"error.cluster.kicked": "节点 [${id}] 被主控踢出下线",
3335
"warning.cluster.warden": "节点 [${id}] 巡检:[${msg}]",
3436
"warning.cluster.keepalive": "节点 [${id}] 保活失败 (${failed}/3)",

requirements.txt

52 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)