2626logger = get_logger ("record-usages" )
2727
2828
29+ async def get_dialect () -> str :
30+ """Get the database dialect name without holding the session open."""
31+ async with GetDB () as db :
32+ return db .bind .dialect .name
33+
34+
35+ def build_node_user_usage_upsert (dialect : str , upsert_params : list [dict ]):
36+ """
37+ Build UPSERT statement for NodeUserUsage based on database dialect.
38+
39+ Args:
40+ dialect: Database dialect name ('postgresql', 'mysql', or 'sqlite')
41+ upsert_params: List of parameter dicts with keys: uid, node_id, created_at, value
42+
43+ Returns:
44+ tuple: (statements_list, params_list) - For SQLite returns 2 statements, others return 1
45+ """
46+ if dialect == "postgresql" :
47+ stmt = pg_insert (NodeUserUsage ).values (
48+ user_id = bindparam ("uid" ),
49+ node_id = bindparam ("node_id" ),
50+ created_at = bindparam ("created_at" ),
51+ used_traffic = bindparam ("value" ),
52+ )
53+ stmt = stmt .on_conflict_do_update (
54+ index_elements = ["created_at" , "user_id" , "node_id" ],
55+ set_ = {"used_traffic" : NodeUserUsage .used_traffic + bindparam ("value" )},
56+ )
57+ return [(stmt , upsert_params )]
58+
59+ elif dialect == "mysql" :
60+ stmt = mysql_insert (NodeUserUsage ).values (
61+ user_id = bindparam ("uid" ),
62+ node_id = bindparam ("node_id" ),
63+ created_at = bindparam ("created_at" ),
64+ used_traffic = bindparam ("value" ),
65+ )
66+ stmt = stmt .on_duplicate_key_update (
67+ used_traffic = NodeUserUsage .used_traffic + stmt .inserted .used_traffic
68+ )
69+ return [(stmt , upsert_params )]
70+
71+ else : # SQLite
72+ # Insert with OR IGNORE
73+ insert_stmt = (
74+ insert (NodeUserUsage )
75+ .values (
76+ user_id = bindparam ("uid" ),
77+ node_id = bindparam ("node_id" ),
78+ created_at = bindparam ("created_at" ),
79+ used_traffic = 0 ,
80+ )
81+ .prefix_with ("OR IGNORE" )
82+ )
83+
84+ # Update with renamed bindparams to avoid conflicts
85+ update_stmt = (
86+ update (NodeUserUsage )
87+ .values (used_traffic = NodeUserUsage .used_traffic + bindparam ("value" ))
88+ .where (
89+ and_ (
90+ NodeUserUsage .user_id == bindparam ("b_uid" ),
91+ NodeUserUsage .node_id == bindparam ("b_node_id" ),
92+ NodeUserUsage .created_at == bindparam ("b_created_at" ),
93+ )
94+ )
95+ )
96+
97+ # Remap params for update statement
98+ update_params = [
99+ {
100+ "value" : p ["value" ],
101+ "b_uid" : p ["uid" ],
102+ "b_node_id" : p ["node_id" ],
103+ "b_created_at" : p ["created_at" ],
104+ }
105+ for p in upsert_params
106+ ]
107+
108+ return [(insert_stmt , upsert_params ), (update_stmt , update_params )]
109+
110+
111+ def build_node_usage_upsert (dialect : str , upsert_param : dict ):
112+ """
113+ Build UPSERT statement for NodeUsage based on database dialect.
114+
115+ Args:
116+ dialect: Database dialect name ('postgresql', 'mysql', or 'sqlite')
117+ upsert_param: Parameter dict with keys: node_id, created_at, up, down
118+
119+ Returns:
120+ tuple: (statements_list, params_list) - For SQLite returns 2 statements, others return 1
121+ """
122+ if dialect == "postgresql" :
123+ stmt = pg_insert (NodeUsage ).values (
124+ node_id = bindparam ("node_id" ),
125+ created_at = bindparam ("created_at" ),
126+ uplink = bindparam ("up" ),
127+ downlink = bindparam ("down" ),
128+ )
129+ stmt = stmt .on_conflict_do_update (
130+ index_elements = ["created_at" , "node_id" ],
131+ set_ = {
132+ "uplink" : NodeUsage .uplink + bindparam ("up" ),
133+ "downlink" : NodeUsage .downlink + bindparam ("down" ),
134+ },
135+ )
136+ return [(stmt , [upsert_param ])]
137+
138+ elif dialect == "mysql" :
139+ stmt = mysql_insert (NodeUsage ).values (
140+ node_id = bindparam ("node_id" ),
141+ created_at = bindparam ("created_at" ),
142+ uplink = bindparam ("up" ),
143+ downlink = bindparam ("down" ),
144+ )
145+ stmt = stmt .on_duplicate_key_update (
146+ uplink = NodeUsage .uplink + stmt .inserted .uplink ,
147+ downlink = NodeUsage .downlink + stmt .inserted .downlink ,
148+ )
149+ return [(stmt , [upsert_param ])]
150+
151+ else : # SQLite
152+ # Insert with OR IGNORE
153+ insert_stmt = (
154+ insert (NodeUsage )
155+ .values (
156+ node_id = bindparam ("node_id" ),
157+ created_at = bindparam ("created_at" ),
158+ uplink = 0 ,
159+ downlink = 0 ,
160+ )
161+ .prefix_with ("OR IGNORE" )
162+ )
163+
164+ # Update with renamed bindparams to avoid conflicts
165+ update_stmt = (
166+ update (NodeUsage )
167+ .values (
168+ uplink = NodeUsage .uplink + bindparam ("up" ),
169+ downlink = NodeUsage .downlink + bindparam ("down" ),
170+ )
171+ .where (
172+ and_ (
173+ NodeUsage .node_id == bindparam ("b_node_id" ),
174+ NodeUsage .created_at == bindparam ("b_created_at" ),
175+ )
176+ )
177+ )
178+
179+ # Remap params for update statement
180+ update_param = {
181+ "up" : upsert_param ["up" ],
182+ "down" : upsert_param ["down" ],
183+ "b_node_id" : upsert_param ["node_id" ],
184+ "b_created_at" : upsert_param ["created_at" ],
185+ }
186+
187+ return [(insert_stmt , [upsert_param ]), (update_stmt , [update_param ])]
188+
189+
29190async def safe_execute (stmt , params = None , max_retries : int = 5 ):
30191 """
31192 Safely execute database operations with deadlock and connection handling.
@@ -96,72 +257,24 @@ async def record_user_stats(params: list[dict], node_id: int, usage_coefficient:
96257
97258 created_at = dt .now (tz .utc ).replace (minute = 0 , second = 0 , microsecond = 0 )
98259
99- async with GetDB () as db :
100- dialect = db . bind . dialect . name
260+ # Get dialect without holding session
261+ dialect = await get_dialect ()
101262
102- # Prepare parameters - ensure uid is converted to int
103- upsert_params = [
104- {
105- "uid" : int (p ["uid" ]),
106- "value" : int (p ["value" ] * usage_coefficient ),
107- "node_id" : node_id ,
108- "created_at" : created_at ,
109- }
110- for p in params
111- ]
263+ # Prepare parameters - ensure uid is converted to int
264+ upsert_params = [
265+ {
266+ "uid" : int (p ["uid" ]),
267+ "value" : int (p ["value" ] * usage_coefficient ),
268+ "node_id" : node_id ,
269+ "created_at" : created_at ,
270+ }
271+ for p in params
272+ ]
112273
113- if dialect == "postgresql" :
114- stmt = pg_insert (NodeUserUsage ).values (
115- user_id = bindparam ("uid" ),
116- node_id = bindparam ("node_id" ),
117- created_at = bindparam ("created_at" ),
118- used_traffic = bindparam ("value" ),
119- )
120- stmt = stmt .on_conflict_do_update (
121- index_elements = ["created_at" , "user_id" , "node_id" ],
122- set_ = {"used_traffic" : NodeUserUsage .used_traffic + bindparam ("value" )},
123- )
124- await safe_execute (stmt , upsert_params )
125-
126- elif dialect == "mysql" :
127- # MySQL: INSERT ... ON DUPLICATE KEY UPDATE
128- stmt = mysql_insert (NodeUserUsage ).values (
129- user_id = bindparam ("uid" ),
130- node_id = bindparam ("node_id" ),
131- created_at = bindparam ("created_at" ),
132- used_traffic = bindparam ("value" ),
133- )
134- # Use stmt.inserted to reference the inserted value (VALUES() in SQL)
135- stmt = stmt .on_duplicate_key_update (used_traffic = NodeUserUsage .used_traffic + stmt .inserted .used_traffic )
136- await safe_execute (stmt , upsert_params )
137-
138- else : # SQLite
139- # SQLite: Use INSERT OR IGNORE + UPDATE pattern
140- insert_stmt = (
141- insert (NodeUserUsage )
142- .values (
143- user_id = bindparam ("uid" ),
144- node_id = bindparam ("node_id" ),
145- created_at = bindparam ("created_at" ),
146- used_traffic = 0 ,
147- )
148- .prefix_with ("OR IGNORE" )
149- )
150- await safe_execute (insert_stmt , upsert_params )
151-
152- # Update all rows (existing + newly inserted)
153- update_stmt = (
154- update (NodeUserUsage )
155- .values (used_traffic = NodeUserUsage .used_traffic + bindparam ("value" ))
156- .where (
157- and_ (
158- NodeUserUsage .user_id == bindparam ("uid" ),
159- NodeUserUsage .node_id == bindparam ("node_id" ),
160- NodeUserUsage .created_at == bindparam ("created_at" ),
161- )
162- )
163- )
164- await safe_execute (update_stmt , upsert_params )
274+ # Build and execute queries for the specific dialect
275+ queries = build_node_user_usage_upsert (dialect , upsert_params )
276+ for stmt , stmt_params in queries :
277+ await safe_execute (stmt , stmt_params )
165278
166279
167280async def record_node_stats (params : list [dict ], node_id : int ):
@@ -181,77 +294,20 @@ async def record_node_stats(params: list[dict], node_id: int):
181294 total_up = sum (p .get ("up" , 0 ) for p in params )
182295 total_down = sum (p .get ("down" , 0 ) for p in params )
183296
184- async with GetDB () as db :
185- dialect = db .bind .dialect .name
186-
187- upsert_param = {
188- "node_id" : node_id ,
189- "created_at" : created_at ,
190- "up" : total_up ,
191- "down" : total_down ,
192- }
297+ # Get dialect without holding session
298+ dialect = await get_dialect ()
193299
194- if dialect == "postgresql" :
195- # PostgreSQL: INSERT ... ON CONFLICT DO UPDATE
196- stmt = pg_insert (NodeUsage ).values (
197- node_id = bindparam ("node_id" ),
198- created_at = bindparam ("created_at" ),
199- uplink = bindparam ("up" ),
200- downlink = bindparam ("down" ),
201- )
202- stmt = stmt .on_conflict_do_update (
203- index_elements = ["created_at" , "node_id" ],
204- set_ = {
205- "uplink" : NodeUsage .uplink + bindparam ("up" ),
206- "downlink" : NodeUsage .downlink + bindparam ("down" ),
207- },
208- )
209- await safe_execute (stmt , [upsert_param ])
300+ upsert_param = {
301+ "node_id" : node_id ,
302+ "created_at" : created_at ,
303+ "up" : total_up ,
304+ "down" : total_down ,
305+ }
210306
211- elif dialect == "mysql" :
212- # MySQL: INSERT ... ON DUPLICATE KEY UPDATE
213- stmt = mysql_insert (NodeUsage ).values (
214- node_id = bindparam ("node_id" ),
215- created_at = bindparam ("created_at" ),
216- uplink = bindparam ("up" ),
217- downlink = bindparam ("down" ),
218- )
219- # Use stmt.inserted to reference the inserted values (VALUES() in SQL)
220- stmt = stmt .on_duplicate_key_update (
221- uplink = NodeUsage .uplink + stmt .inserted .uplink ,
222- downlink = NodeUsage .downlink + stmt .inserted .downlink ,
223- )
224- await safe_execute (stmt , [upsert_param ])
225-
226- else : # SQLite
227- # SQLite: Use INSERT OR IGNORE + UPDATE pattern
228- insert_stmt = (
229- insert (NodeUsage )
230- .values (
231- node_id = bindparam ("node_id" ),
232- created_at = bindparam ("created_at" ),
233- uplink = 0 ,
234- downlink = 0 ,
235- )
236- .prefix_with ("OR IGNORE" )
237- )
238- await safe_execute (insert_stmt , [upsert_param ])
239-
240- # Update the row (existing or newly inserted)
241- update_stmt = (
242- update (NodeUsage )
243- .values (
244- uplink = NodeUsage .uplink + bindparam ("up" ),
245- downlink = NodeUsage .downlink + bindparam ("down" ),
246- )
247- .where (
248- and_ (
249- NodeUsage .node_id == bindparam ("node_id" ),
250- NodeUsage .created_at == bindparam ("created_at" ),
251- )
252- )
253- )
254- await safe_execute (update_stmt , [upsert_param ])
307+ # Build and execute queries for the specific dialect
308+ queries = build_node_usage_upsert (dialect , upsert_param )
309+ for stmt , stmt_params in queries :
310+ await safe_execute (stmt , stmt_params )
255311
256312
257313async def get_users_stats (node : PasarGuardNode ):
0 commit comments