|
9 | 9 |
|
10 | 10 | from PasarGuardNodeBridge import NodeAPIError, PasarGuardNode |
11 | 11 | from PasarGuardNodeBridge.common.service_pb2 import StatType |
12 | | -from sqlalchemy import and_, bindparam, insert, select, update |
| 12 | +from sqlalchemy import and_, bindparam, insert, select, text, update |
13 | 13 | from sqlalchemy.dialects.mysql import insert as mysql_insert |
14 | 14 | from sqlalchemy.dialects.postgresql import insert as pg_insert |
15 | | -from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError |
| 15 | +from sqlalchemy.exc import DatabaseError, OperationalError |
16 | 16 | from sqlalchemy.sql.expression import Insert |
17 | 17 |
|
18 | 18 | from app import on_shutdown, scheduler |
19 | 19 | from app.db import GetDB |
20 | 20 | from app.db.base import engine |
21 | | -from app.db.models import Admin, Node, NodeUsage, NodeUserUsage, System, User |
| 21 | +from app.db.models import Admin, Node, NodeUsage, System, User |
22 | 22 | from app.node import node_manager |
23 | 23 | from app.utils.logger import get_logger |
24 | 24 | from config import job_settings, runtime_settings, usage_settings |
@@ -102,69 +102,55 @@ def build_node_user_usage_upsert(dialect: str, upsert_params: list[dict]): |
102 | 102 | upsert_params: List of parameter dicts with keys: uid, node_id, created_at, value |
103 | 103 |
|
104 | 104 | Returns: |
105 | | - tuple: (statements_list, params_list) - For SQLite returns 2 statements, others return 1 |
106 | | - """ |
107 | | - if dialect == "postgresql": |
108 | | - stmt = pg_insert(NodeUserUsage).values( |
109 | | - user_id=bindparam("uid"), |
110 | | - node_id=bindparam("node_id"), |
111 | | - created_at=bindparam("created_at"), |
112 | | - used_traffic=bindparam("value"), |
113 | | - ) |
114 | | - stmt = stmt.on_conflict_do_update( |
115 | | - index_elements=["created_at", "user_id", "node_id"], |
116 | | - set_={"used_traffic": NodeUserUsage.used_traffic + bindparam("value")}, |
| 105 | + list: One SQL statement and its bound parameters. |
| 106 | + """ |
| 107 | + select_parts = [] |
| 108 | + stmt_params = {} |
| 109 | + for index, param in enumerate(upsert_params): |
| 110 | + uid_key = f"uid_{index}" |
| 111 | + node_id_key = f"node_id_{index}" |
| 112 | + created_at_key = f"created_at_{index}" |
| 113 | + value_key = f"value_{index}" |
| 114 | + select_parts.append( |
| 115 | + f"SELECT :{uid_key} AS uid, :{node_id_key} AS node_id, " |
| 116 | + f":{created_at_key} AS created_at, :{value_key} AS value" |
117 | 117 | ) |
118 | | - return [(stmt, upsert_params)] |
119 | | - |
120 | | - elif dialect == "mysql": |
121 | | - stmt = mysql_insert(NodeUserUsage).values( |
122 | | - user_id=bindparam("uid"), |
123 | | - node_id=bindparam("node_id"), |
124 | | - created_at=bindparam("created_at"), |
125 | | - used_traffic=bindparam("value"), |
126 | | - ) |
127 | | - stmt = stmt.on_duplicate_key_update(used_traffic=NodeUserUsage.used_traffic + stmt.inserted.used_traffic) |
128 | | - return [(stmt, upsert_params)] |
129 | | - |
130 | | - else: # SQLite |
131 | | - # Insert with OR IGNORE |
132 | | - insert_stmt = ( |
133 | | - insert(NodeUserUsage) |
134 | | - .values( |
135 | | - user_id=bindparam("uid"), |
136 | | - node_id=bindparam("node_id"), |
137 | | - created_at=bindparam("created_at"), |
138 | | - used_traffic=0, |
139 | | - ) |
140 | | - .prefix_with("OR IGNORE") |
141 | | - ) |
142 | | - |
143 | | - # Update with renamed bindparams to avoid conflicts |
144 | | - update_stmt = ( |
145 | | - update(NodeUserUsage) |
146 | | - .values(used_traffic=NodeUserUsage.used_traffic + bindparam("value")) |
147 | | - .where( |
148 | | - and_( |
149 | | - NodeUserUsage.user_id == bindparam("b_uid"), |
150 | | - NodeUserUsage.node_id == bindparam("b_node_id"), |
151 | | - NodeUserUsage.created_at == bindparam("b_created_at"), |
152 | | - ) |
153 | | - ) |
154 | | - ) |
155 | | - |
156 | | - # Remap params for update statement |
157 | | - update_params = [ |
158 | | - { |
159 | | - "value": p["value"], |
160 | | - "b_uid": p["uid"], |
161 | | - "b_node_id": p["node_id"], |
162 | | - "b_created_at": p["created_at"], |
163 | | - } |
164 | | - for p in upsert_params |
165 | | - ] |
| 118 | + stmt_params[uid_key] = param["uid"] |
| 119 | + stmt_params[node_id_key] = param["node_id"] |
| 120 | + stmt_params[created_at_key] = param["created_at"] |
| 121 | + stmt_params[value_key] = param["value"] |
| 122 | + |
| 123 | + values_source = "\nUNION ALL\n".join(select_parts) |
| 124 | + conflict_clause = { |
| 125 | + "postgresql": ( |
| 126 | + "ON CONFLICT (created_at, user_id, node_id) DO UPDATE " |
| 127 | + "SET used_traffic = node_user_usages.used_traffic + EXCLUDED.used_traffic" |
| 128 | + ), |
| 129 | + "mysql": ( |
| 130 | + "ON DUPLICATE KEY UPDATE " |
| 131 | + "used_traffic = node_user_usages.used_traffic + VALUES(used_traffic)" |
| 132 | + ), |
| 133 | + }.get( |
| 134 | + dialect, |
| 135 | + ( |
| 136 | + "ON CONFLICT(created_at, user_id, node_id) DO UPDATE " |
| 137 | + "SET used_traffic = node_user_usages.used_traffic + excluded.used_traffic" |
| 138 | + ), |
| 139 | + ) |
166 | 140 |
|
167 | | - return [(insert_stmt, upsert_params), (update_stmt, update_params)] |
| 141 | + stmt = text( |
| 142 | + f""" |
| 143 | + INSERT INTO node_user_usages (created_at, user_id, node_id, used_traffic) |
| 144 | + SELECT source.created_at, source.uid, source.node_id, SUM(source.value) |
| 145 | + FROM ( |
| 146 | + {values_source} |
| 147 | + ) AS source |
| 148 | + JOIN users ON users.id = source.uid |
| 149 | + GROUP BY source.created_at, source.uid, source.node_id |
| 150 | + {conflict_clause} |
| 151 | + """ |
| 152 | + ) |
| 153 | + return [(stmt, stmt_params)] |
168 | 154 |
|
169 | 155 |
|
170 | 156 | def build_node_usage_upsert(dialect: str, upsert_param: dict): |
@@ -313,31 +299,6 @@ async def safe_execute(stmt, params=None, max_retries: int = 2): |
313 | 299 | raise |
314 | 300 |
|
315 | 301 |
|
316 | | -async def _filter_existing_user_usage_params(upsert_params: list[dict]) -> list[dict]: |
317 | | - """Drop chart usage rows for users that no longer exist.""" |
318 | | - if not upsert_params: |
319 | | - return [] |
320 | | - |
321 | | - user_ids = {param["uid"] for param in upsert_params} |
322 | | - |
323 | | - async with GetDB() as db: |
324 | | - result = await db.execute(select(User.id).where(User.id.in_(user_ids))) |
325 | | - existing_user_ids = set(result.scalars().all()) |
326 | | - |
327 | | - if len(existing_user_ids) == len(user_ids): |
328 | | - return upsert_params |
329 | | - |
330 | | - filtered_params = [param for param in upsert_params if param["uid"] in existing_user_ids] |
331 | | - skipped = len(upsert_params) - len(filtered_params) |
332 | | - missing_users = len(user_ids) - len(existing_user_ids) |
333 | | - logger.warning( |
334 | | - "Skipped %s node user usage records for %s missing users", |
335 | | - skipped, |
336 | | - missing_users, |
337 | | - ) |
338 | | - return filtered_params |
339 | | - |
340 | | - |
341 | 302 | def _get_time_bucket(now: dt = None) -> dt: |
342 | 303 | """ |
343 | 304 | Get 10-minute time bucket instead of hourly to reduce hot row contention. |
@@ -390,28 +351,11 @@ async def record_user_stats_batched(all_node_params: dict, usage_coefficients: d |
390 | 351 | if not upsert_params: |
391 | 352 | return |
392 | 353 |
|
393 | | - upsert_params = await _filter_existing_user_usage_params(upsert_params) |
394 | | - if not upsert_params: |
395 | | - return |
396 | | - |
397 | 354 | # Execute single batched UPSERT with concurrency control |
398 | 355 | async with JOB_SEM: |
399 | 356 | queries = build_node_user_usage_upsert(dialect, upsert_params) |
400 | 357 | for stmt, stmt_params in queries: |
401 | | - try: |
402 | | - await safe_execute(stmt, stmt_params) |
403 | | - except IntegrityError: |
404 | | - # A user can be deleted after the pre-insert filter but before |
405 | | - # the write reaches the database. Re-check once so one stale |
406 | | - # node stat does not drop the whole batch. |
407 | | - retry_params = await _filter_existing_user_usage_params(upsert_params) |
408 | | - if len(retry_params) == len(upsert_params): |
409 | | - raise |
410 | | - if not retry_params: |
411 | | - return |
412 | | - retry_queries = build_node_user_usage_upsert(dialect, retry_params) |
413 | | - for retry_stmt, retry_stmt_params in retry_queries: |
414 | | - await safe_execute(retry_stmt, retry_stmt_params) |
| 358 | + await safe_execute(stmt, stmt_params) |
415 | 359 |
|
416 | 360 |
|
417 | 361 | async def record_node_stats_batched(all_node_params: dict): |
|
0 commit comments