66from pathlib import Path
77import ssl
88import time
9- from typing import Any , Optional
9+ from typing import Any , Optional , Callable
1010from aiohttp import web
1111from aiohttp .web_urldispatcher import SystemRoute
1212
13- from core import config , units , utils
13+ from core import config , scheduler , units , utils
1414from .logger import logger
1515
1616from cryptography import x509
1717from cryptography .x509 .oid import NameOID
1818
1919
20+ @dataclass
21+ class CheckServer :
22+ port : int
23+ start_handle : Callable
24+ client : Optional [ssl .SSLContext ] = None
25+
2026@dataclass
2127class PrivateSSLServer :
2228 server : asyncio .Server
@@ -158,8 +164,6 @@ async def start_tcp_site():
158164 site = web .TCPSite (runner , '0.0.0.0' , port )
159165 await site .start ()
160166
161- print (site )
162-
163167async def init ():
164168 global runner , site
165169
@@ -172,6 +176,12 @@ async def init():
172176
173177 await start_public_server ()
174178
179+ scheduler .run_repeat_later (
180+ check_server ,
181+ 60 ,
182+ 10
183+ )
184+
175185async def forward_data (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
176186 while not writer .is_closing ():
177187 data = await reader .read (IO_BUFFER )
@@ -354,5 +364,47 @@ def get_certificate_domains(
354364 ]
355365 return results
356366
367+ async def check_server ():
368+ servers : list [CheckServer ] = []
369+ if site is not None :
370+ servers .append (CheckServer (site ._port , start_private_server ))
371+ if public_server is not None :
372+ servers .append (CheckServer (public_server .sockets [0 ].getsockname ()[1 ], start_public_server ))
373+ if privates :
374+ for server in privates .values ():
375+ servers .append (CheckServer (server .server .sockets [0 ].getsockname ()[1 ], start_private_server , server .key ))
376+
377+ logger .tdebug ("web.debug.check_server" , servers = len (servers ))
378+ results = await asyncio .gather (* [asyncio .create_task (_check_server (server )) for server in servers ])
379+ for server , result in zip (servers , results ):
380+ if result :
381+ continue
382+ await server .start_handle ()
383+ logger .twarning ("web.warning.server_down" , port = server .port )
384+
385+
386+ async def _check_server (
387+ server : CheckServer
388+ ):
389+ try :
390+ r , w = await asyncio .wait_for (
391+ asyncio .open_connection (
392+ '127.0.0.1' ,
393+ server .port ,
394+ ssl = server .client
395+ ),
396+ timeout = 5
397+ )
398+ w .close ()
399+ try :
400+ await asyncio .wait_for (w .wait_closed (), timeout = 10 )
401+ except :
402+ ...
403+ return True
404+ except :
405+ logger .ttraceback ("web.traceback.check_server" , port = server .port )
406+ return False
407+
357408async def unload ():
358- ...
409+ await app .cleanup ()
410+ await app .shutdown ()
0 commit comments