1010from typing import Any , Coroutine , Optional , Callable
1111from aiohttp import web
1212from aiohttp .web_urldispatcher import SystemRoute
13+ from tqdm import tqdm
1314
1415from core import config , scheduler , units , utils
1516from .logger import logger
@@ -26,6 +27,9 @@ class CheckServer:
2627 client : Optional [ssl .SSLContext ] = None
2728 args : tuple [Any , ...] = ()
2829
30+ def __hash__ (self ) -> int :
31+ return hash (self .object )
32+
2933@dataclass
3034class PrivateSSLServer :
3135 server : asyncio .Server
@@ -289,16 +293,20 @@ async def start_public_server(count: int = config.const.web_sockets):
289293 removes .append (server )
290294 for server in removes :
291295 public_servers .remove (server )
292- for _ in range (len (public_servers ), count ):
293- port = get_public_port ()
294- if port == 0 :
295- port = await get_free_port ()
296- server = await create_server (public_handle , '0.0.0.0' , port )
296+ with tqdm (total = count - len (public_servers )) as pbar :
297+ for _ in range (len (public_servers ), count ):
298+ port = get_public_port ()
299+ if port == 0 :
300+ port = await get_free_port ()
301+ server = await create_server (public_handle , '0.0.0.0' , port )
302+
303+ await server .start_serving ()
304+ public_servers .append (server )
305+ pbar .update (1 )
306+ pbar .set_postfix_str (f"Port [{ port } ]" )
297307
298- await server .start_serving ()
299- public_servers .append (server )
308+ logger .tsuccess ("web.success.public_port" , port = port , current = len (public_servers ), total = count )
300309
301- logger .tsuccess ("web.success.public_port" , port = server .sockets [0 ].getsockname ()[1 ], current = len (public_servers ), total = count )
302310
303311def get_public_port ():
304312 port = int (config .const .port )
@@ -380,22 +388,24 @@ async def check_server():
380388 servers : list [CheckServer ] = []
381389 if site is not None :
382390 servers .append (CheckServer (site , site ._port , start_tcp_site ))
383- for server in public_servers :
391+ if public_servers :
392+ server = public_servers [0 ]
384393 servers .append (CheckServer (server , get_server_port (server ), start_public_server ))
385394 if privates :
386395 for hash , server in privates .items ():
387396 servers .append (CheckServer (server .server , get_server_port (server .server ), start_private_server , server .key , (
388397 Path (hash [0 ]),
389398 Path (hash [1 ])
390399 )))
400+ servers = list (set (servers ))
391401
392402 #logger.tdebug("web.debug.check_server", servers=len(servers))
393403 results = await asyncio .gather (* [asyncio .create_task (_check_server (server )) for server in servers ])
394404 for server , result in zip (servers , results ):
395405 if result :
396406 continue
397- await server .start_handle ()
398407 logger .twarning ("web.warning.server_down" , port = server .port )
408+ await server .start_handle ()
399409
400410def get_server_port (server : Optional [asyncio .Server ]):
401411 if server is None :
0 commit comments