@@ -214,24 +214,49 @@ async def ssl_handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter)
214214 finally :
215215 writer .close ()
216216
217- async def check_server ():
218- global public_server
219- if public_server is None :
217+ async def _check_server (ip : str , port : int , ssl : Optional [ssl .SSLContext ] = None ):
218+ try :
219+ r , w = await asyncio .wait_for (asyncio .open_connection (ip , port , ssl = ssl ), 5 )
220+ w .close ()
221+ await w .wait_closed ()
222+ return True
223+ except :
220224 return False
221- async def _check ():
222- try :
223- r , w = await asyncio .wait_for (asyncio .open_connection ('127.0.0.1' , public_server .sockets [0 ].getsockname ()[1 ]), 5 ) # type: ignore
224- w .write (CHECK_PORT_SECRET )
225- await w .drain ()
226- data = await r .read (REQUEST_BUFFER )
227- return data == CHECK_PORT_SECRET
228- except :
229- return False
230- if await _check ():
231- return
232- await start_public_server ()
233- logger .twarning ("web.warning.public_port" , port = config .const .public_port )
234225
226+ async def check_server ():
227+ global site , public_server , private_ssl_server
228+ servers = [
229+ ]
230+ if site is not None :
231+ servers .append (
232+ ("127.0.0.1" , site ._port , start_tcp_site )
233+ )
234+ if public_server is not None :
235+ servers .append (
236+ ("127.0.0.1" , config .const .public_port , start_public_server )
237+ )
238+ if private_ssl_server is not None :
239+ servers .append (
240+ ("127.0.0.1" , private_ssl_server .sockets [0 ].getsockname ()[1 ], _start_ssl_server )
241+ )
242+ result = await asyncio .gather (* (
243+ _check_server (* server ) for server in servers
244+ ))
245+ for i , r in enumerate (result ):
246+ if not r :
247+ await servers [i ][2 ]()
248+ logger .twarning ("web.warning.server_down" , server = servers [i ][0 ], port = servers [i ][1 ])
249+
250+ async def start_tcp_site ():
251+ global site , runner
252+ if runner is None :
253+ return
254+ port = await get_free_port ()
255+ if site is not None :
256+ await site .stop ()
257+ site = web .TCPSite (runner , '127.0.0.1' , port )
258+ await site .start ()
259+ logger .tdebug ("web.debug.local_port" , port = site ._port )
235260
236261async def init ():
237262 global runner , site , public_server , routes , app
@@ -241,12 +266,7 @@ async def init():
241266 runner = web .AppRunner (app )
242267 await runner .setup ()
243268
244- port = await get_free_port ()
245-
246- site = web .TCPSite (runner , '127.0.0.1' , port )
247- await site .start ()
248-
249- logger .tdebug ("web.debug.local_port" , port = site ._port )
269+ await start_tcp_site ()
250270
251271 await start_public_server ()
252272
@@ -282,6 +302,12 @@ async def start_ssl_server(cert: Path, key: Path):
282302 client
283303 )
284304
305+ await _start_ssl_server ()
306+
307+ async def _start_ssl_server ():
308+ global private_ssl_server , private_ssl
309+ if not private_ssl :
310+ return
285311 if private_ssl_server is not None and private_ssl_server .is_serving ():
286312 private_ssl_server .close ()
287313 await private_ssl_server .wait_closed ()
@@ -290,12 +316,11 @@ async def start_ssl_server(cert: Path, key: Path):
290316 ssl_handle ,
291317 '127.0.0.1' ,
292318 port ,
293- ssl = context
319+ ssl = private_ssl [ 0 ]
294320 )
295321 logger .tdebug ("web.debug.ssl_port" , port = private_ssl_server .sockets [0 ].getsockname ()[1 ])
296322
297323
298-
299324async def unload ():
300325 global app
301326 await app .cleanup ()
0 commit comments