|
1 | 1 | import asyncio |
2 | 2 | import random |
3 | 3 | from collections import defaultdict |
4 | | -from datetime import datetime as dt, timezone as tz, timedelta as td |
| 4 | +from datetime import datetime as dt, timedelta as td, timezone as tz |
5 | 5 | from operator import attrgetter |
6 | 6 |
|
7 | | -from PasarGuardNodeBridge import PasarGuardNode, NodeAPIError |
| 7 | +from PasarGuardNodeBridge import NodeAPIError, PasarGuardNode |
8 | 8 | from PasarGuardNodeBridge.common.service_pb2 import StatType |
9 | 9 | from sqlalchemy import and_, bindparam, insert, select, update |
| 10 | +from sqlalchemy.dialects.mysql import insert as mysql_insert |
| 11 | +from sqlalchemy.dialects.postgresql import insert as pg_insert |
10 | 12 | from sqlalchemy.exc import DatabaseError, OperationalError |
11 | 13 | from sqlalchemy.sql.expression import Insert |
12 | | -from sqlalchemy.dialects.postgresql import insert as pg_insert |
13 | | -from sqlalchemy.dialects.mysql import insert as mysql_insert |
14 | 14 |
|
15 | 15 | from app import scheduler |
16 | 16 | from app.db import GetDB |
| 17 | +from app.db.base import engine |
17 | 18 | from app.db.models import Admin, Node, NodeUsage, NodeUserUsage, System, User |
18 | 19 | from app.node import node_manager as node_manager |
19 | 20 | from app.utils.logger import get_logger |
@@ -195,22 +196,20 @@ async def safe_execute(stmt, params=None, max_retries: int = 5): |
195 | 196 | params (list[dict], optional): Parameters for the statement |
196 | 197 | max_retries (int, optional): Maximum number of retry attempts (default: 5) |
197 | 198 | """ |
| 199 | + statement = stmt |
| 200 | + |
| 201 | + if await get_dialect() == "mysql" and isinstance(stmt, Insert): |
| 202 | + # MySQL-specific IGNORE prefix - but skip if using ON DUPLICATE KEY UPDATE |
| 203 | + if not hasattr(stmt, "_post_values_clause") or stmt._post_values_clause is None: |
| 204 | + statement = stmt.prefix_with("IGNORE") |
198 | 205 | for attempt in range(max_retries): |
199 | 206 | try: |
200 | | - # Create fresh session for each attempt to release any locks from previous attempts |
201 | | - async with GetDB() as db: |
202 | | - dialect = db.bind.dialect.name |
203 | | - |
204 | | - # MySQL-specific IGNORE prefix - but skip if using ON DUPLICATE KEY UPDATE |
205 | | - if dialect == "mysql" and isinstance(stmt, Insert): |
206 | | - # Check if statement already has ON DUPLICATE KEY UPDATE |
207 | | - if not hasattr(stmt, "_post_values_clause") or stmt._post_values_clause is None: |
208 | | - stmt = stmt.prefix_with("IGNORE") |
209 | | - |
210 | | - # Use raw connection to avoid ORM bulk update requirements |
211 | | - await (await db.connection()).execute(stmt, params) |
212 | | - await db.commit() |
213 | | - return # Success - exit function |
| 207 | + # engine.begin() ensures commit/rollback + connection return on exit |
| 208 | + async with engine.begin() as conn: |
| 209 | + if params is None: |
| 210 | + await conn.execute(statement) |
| 211 | + else: |
| 212 | + await conn.execute(statement, params) |
214 | 213 |
|
215 | 214 | except (OperationalError, DatabaseError) as err: |
216 | 215 | # Session auto-closed by context manager, locks released |
|
0 commit comments