11import asyncio
2+ import random
23from collections import defaultdict
3- from datetime import datetime as dt , timezone as tz
4+ from datetime import datetime as dt , timezone as tz , timedelta as td
45from operator import attrgetter
56
67from PasarGuardNodeBridge import PasarGuardNode , NodeAPIError
1213from sqlalchemy .dialects .mysql import insert as mysql_insert
1314
1415from app import scheduler
15- from app .db import AsyncSession , GetDB
16+ from app .db import GetDB
1617from app .db .models import Admin , NodeUsage , NodeUserUsage , System , User
1718from app .node import node_manager as node_manager
1819from app .utils .logger import get_logger
2526logger = get_logger ("record-usages" )
2627
2728
async def safe_execute(stmt, params=None, max_retries: int = 5):
    """
    Safely execute a database statement with deadlock/lock retry handling.

    A fresh DB session is created for each attempt so that any locks held by
    a failed attempt are released before retrying.

    Args:
        stmt: SQLAlchemy statement to execute.
        params (list[dict], optional): Parameters for the statement.
        max_retries (int, optional): Maximum number of attempts (default: 5).

    Raises:
        OperationalError | DatabaseError: re-raised when the error is not a
            retriable deadlock/lock error, or when all retries are exhausted.
    """
    # Guard so the MySQL "IGNORE" prefix is applied at most once. Without it,
    # a retried INSERT would re-enter the prefix branch and accumulate one
    # extra "IGNORE" per attempt (prefix_with appends, it does not replace).
    ignore_prefix_applied = False

    for attempt in range(max_retries):
        try:
            # Fresh session per attempt: releases any locks from previous attempts.
            async with GetDB() as db:
                dialect = db.bind.dialect.name

                # MySQL-specific IGNORE prefix - but skip if the statement
                # already carries ON DUPLICATE KEY UPDATE.
                if dialect == "mysql" and isinstance(stmt, Insert) and not ignore_prefix_applied:
                    if getattr(stmt, "_post_values_clause", None) is None:
                        stmt = stmt.prefix_with("IGNORE")
                    ignore_prefix_applied = True

                # Use a raw connection to avoid ORM bulk-update requirements.
                await (await db.connection()).execute(stmt, params)
                await db.commit()
                return  # Success - exit function

        except (OperationalError, DatabaseError) as err:
            # Session auto-closed by the context manager, locks released.

            # Classify the error for retry logic; be defensive about the
            # shape of err.orig (driver-specific exception object).
            orig = getattr(err, "orig", None)
            orig_args = getattr(orig, "args", ())
            is_mysql_deadlock = len(orig_args) > 0 and orig_args[0] == 1213  # MySQL deadlock
            is_pg_deadlock = getattr(orig, "code", None) == "40P01"  # PostgreSQL deadlock
            is_sqlite_locked = "database is locked" in str(err)

            # Retry with backoff if this is a retriable error and we have
            # attempts left.
            if attempt < max_retries - 1:
                if is_mysql_deadlock or is_pg_deadlock:
                    # Exponential backoff with jitter:
                    # 50-75ms, 100-150ms, 200-300ms, 400-600ms, 800-1200ms
                    base_delay = 0.05 * (2**attempt)
                    jitter = random.uniform(0, base_delay * 0.5)
                    await asyncio.sleep(base_delay + jitter)
                    continue
                elif is_sqlite_locked:
                    await asyncio.sleep(0.1 * (attempt + 1))  # Linear backoff
                    continue

            # If we've exhausted retries or it's not a retriable error, raise
            raise
@@ -112,7 +121,7 @@ async def record_user_stats(params: list[dict], node_id: int, usage_coefficient:
112121 index_elements = ["created_at" , "user_id" , "node_id" ],
113122 set_ = {"used_traffic" : NodeUserUsage .used_traffic + bindparam ("value" )},
114123 )
115- await safe_execute (db , stmt , upsert_params )
124+ await safe_execute (stmt , upsert_params )
116125
117126 elif dialect == "mysql" :
118127 # MySQL: INSERT ... ON DUPLICATE KEY UPDATE
@@ -123,20 +132,22 @@ async def record_user_stats(params: list[dict], node_id: int, usage_coefficient:
123132 used_traffic = bindparam ("value" ),
124133 )
125134 # Use stmt.inserted to reference the inserted value (VALUES() in SQL)
126- stmt = stmt .on_duplicate_key_update (
127- used_traffic = NodeUserUsage .used_traffic + stmt .inserted .used_traffic
128- )
129- await safe_execute (db , stmt , upsert_params )
135+ stmt = stmt .on_duplicate_key_update (used_traffic = NodeUserUsage .used_traffic + stmt .inserted .used_traffic )
136+ await safe_execute (stmt , upsert_params )
130137
131138 else : # SQLite
132139 # SQLite: Use INSERT OR IGNORE + UPDATE pattern
133- insert_stmt = insert (NodeUserUsage ).values (
134- user_id = bindparam ("uid" ),
135- node_id = bindparam ("node_id" ),
136- created_at = bindparam ("created_at" ),
137- used_traffic = 0 ,
138- ).prefix_with ("OR IGNORE" )
139- await safe_execute (db , insert_stmt , upsert_params )
140+ insert_stmt = (
141+ insert (NodeUserUsage )
142+ .values (
143+ user_id = bindparam ("uid" ),
144+ node_id = bindparam ("node_id" ),
145+ created_at = bindparam ("created_at" ),
146+ used_traffic = 0 ,
147+ )
148+ .prefix_with ("OR IGNORE" )
149+ )
150+ await safe_execute (insert_stmt , upsert_params )
140151
141152 # Update all rows (existing + newly inserted)
142153 update_stmt = (
@@ -150,7 +161,7 @@ async def record_user_stats(params: list[dict], node_id: int, usage_coefficient:
150161 )
151162 )
152163 )
153- await safe_execute (db , update_stmt , upsert_params )
164+ await safe_execute (update_stmt , upsert_params )
154165
155166
156167async def record_node_stats (params : list [dict ], node_id : int ):
@@ -195,7 +206,7 @@ async def record_node_stats(params: list[dict], node_id: int):
195206 "downlink" : NodeUsage .downlink + bindparam ("down" ),
196207 },
197208 )
198- await safe_execute (db , stmt , [upsert_param ])
209+ await safe_execute (stmt , [upsert_param ])
199210
200211 elif dialect == "mysql" :
201212 # MySQL: INSERT ... ON DUPLICATE KEY UPDATE
@@ -210,17 +221,21 @@ async def record_node_stats(params: list[dict], node_id: int):
210221 uplink = NodeUsage .uplink + stmt .inserted .uplink ,
211222 downlink = NodeUsage .downlink + stmt .inserted .downlink ,
212223 )
213- await safe_execute (db , stmt , [upsert_param ])
224+ await safe_execute (stmt , [upsert_param ])
214225
215226 else : # SQLite
216227 # SQLite: Use INSERT OR IGNORE + UPDATE pattern
217- insert_stmt = insert (NodeUsage ).values (
218- node_id = bindparam ("node_id" ),
219- created_at = bindparam ("created_at" ),
220- uplink = 0 ,
221- downlink = 0 ,
222- ).prefix_with ("OR IGNORE" )
223- await safe_execute (db , insert_stmt , [upsert_param ])
228+ insert_stmt = (
229+ insert (NodeUsage )
230+ .values (
231+ node_id = bindparam ("node_id" ),
232+ created_at = bindparam ("created_at" ),
233+ uplink = 0 ,
234+ downlink = 0 ,
235+ )
236+ .prefix_with ("OR IGNORE" )
237+ )
238+ await safe_execute (insert_stmt , [upsert_param ])
224239
225240 # Update the row (existing or newly inserted)
226241 update_stmt = (
@@ -236,7 +251,7 @@ async def record_node_stats(params: list[dict], node_id: int):
236251 )
237252 )
238253 )
239- await safe_execute (db , update_stmt , [upsert_param ])
254+ await safe_execute (update_stmt , [upsert_param ])
240255
241256
242257async def get_users_stats (node : PasarGuardNode ):
@@ -336,25 +351,24 @@ async def record_user_usages():
336351 if not users_usage :
337352 return
338353
339- async with GetDB () as db :
340- user_stmt = (
341- update (User )
342- .where (User .id == bindparam ("uid" ))
343- .values (used_traffic = User .used_traffic + bindparam ("value" ), online_at = dt .now (tz .utc ))
354+ user_stmt = (
355+ update (User )
356+ .where (User .id == bindparam ("uid" ))
357+ .values (used_traffic = User .used_traffic + bindparam ("value" ), online_at = dt .now (tz .utc ))
358+ .execution_options (synchronize_session = False )
359+ )
360+ await safe_execute (user_stmt , users_usage )
361+
362+ admin_usage = await calculate_admin_usage (users_usage )
363+ if admin_usage :
364+ admin_data = [{"admin_id" : aid , "value" : val } for aid , val in admin_usage .items ()]
365+ admin_stmt = (
366+ update (Admin )
367+ .where (Admin .id == bindparam ("admin_id" ))
368+ .values (used_traffic = Admin .used_traffic + bindparam ("value" ))
344369 .execution_options (synchronize_session = False )
345370 )
346- await safe_execute (db , user_stmt , users_usage )
347-
348- admin_usage = await calculate_admin_usage (users_usage )
349- if admin_usage :
350- admin_data = [{"admin_id" : aid , "value" : val } for aid , val in admin_usage .items ()]
351- admin_stmt = (
352- update (Admin )
353- .where (Admin .id == bindparam ("admin_id" ))
354- .values (used_traffic = Admin .used_traffic + bindparam ("value" ))
355- .execution_options (synchronize_session = False )
356- )
357- await safe_execute (db , admin_stmt , admin_data )
371+ await safe_execute (admin_stmt , admin_data )
358372
359373 if DISABLE_RECORDING_NODE_USAGE :
360374 return
@@ -385,11 +399,8 @@ async def record_node_usages():
385399 if not (total_up or total_down ):
386400 return
387401
388- async with GetDB () as db :
389- system_update_stmt = update (System ).values (
390- uplink = System .uplink + total_up , downlink = System .downlink + total_down
391- )
392- await safe_execute (db , system_update_stmt )
402+ system_update_stmt = update (System ).values (uplink = System .uplink + total_up , downlink = System .downlink + total_down )
403+ await safe_execute (system_update_stmt )
393404
394405 if DISABLE_RECORDING_NODE_USAGE :
395406 return
@@ -399,8 +410,19 @@ async def record_node_usages():
399410
400411
# Register the periodic usage-recording jobs. Each job gets a staggered
# start_date so the two recorders do not fire at the same moment after boot;
# coalesce=True with max_instances=1 prevents overlapping/backlogged runs.
for _job_fn, _interval_seconds, _startup_delay in (
    (record_user_usages, JOB_RECORD_USER_USAGES_INTERVAL, 30),
    (record_node_usages, JOB_RECORD_NODE_USAGES_INTERVAL, 15),
):
    scheduler.add_job(
        _job_fn,
        "interval",
        seconds=_interval_seconds,
        coalesce=True,
        start_date=dt.now(tz.utc) + td(seconds=_startup_delay),
        max_instances=1,
    )
0 commit comments