Skip to content

Commit d13f1c7

Browse files
committed
chore(jobs): cleanup record usage
1 parent c6620c9 commit d13f1c7

File tree

1 file changed

+189
-133
lines changed

1 file changed

+189
-133
lines changed

app/jobs/record_usages.py

Lines changed: 189 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,167 @@
2626
logger = get_logger("record-usages")
2727

2828

29+
async def get_dialect() -> str:
30+
"""Get the database dialect name without holding the session open."""
31+
async with GetDB() as db:
32+
return db.bind.dialect.name
33+
34+
35+
def build_node_user_usage_upsert(dialect: str, upsert_params: list[dict]):
36+
"""
37+
Build UPSERT statement for NodeUserUsage based on database dialect.
38+
39+
Args:
40+
dialect: Database dialect name ('postgresql', 'mysql', or 'sqlite')
41+
upsert_params: List of parameter dicts with keys: uid, node_id, created_at, value
42+
43+
Returns:
44+
tuple: (statements_list, params_list) - For SQLite returns 2 statements, others return 1
45+
"""
46+
if dialect == "postgresql":
47+
stmt = pg_insert(NodeUserUsage).values(
48+
user_id=bindparam("uid"),
49+
node_id=bindparam("node_id"),
50+
created_at=bindparam("created_at"),
51+
used_traffic=bindparam("value"),
52+
)
53+
stmt = stmt.on_conflict_do_update(
54+
index_elements=["created_at", "user_id", "node_id"],
55+
set_={"used_traffic": NodeUserUsage.used_traffic + bindparam("value")},
56+
)
57+
return [(stmt, upsert_params)]
58+
59+
elif dialect == "mysql":
60+
stmt = mysql_insert(NodeUserUsage).values(
61+
user_id=bindparam("uid"),
62+
node_id=bindparam("node_id"),
63+
created_at=bindparam("created_at"),
64+
used_traffic=bindparam("value"),
65+
)
66+
stmt = stmt.on_duplicate_key_update(
67+
used_traffic=NodeUserUsage.used_traffic + stmt.inserted.used_traffic
68+
)
69+
return [(stmt, upsert_params)]
70+
71+
else: # SQLite
72+
# Insert with OR IGNORE
73+
insert_stmt = (
74+
insert(NodeUserUsage)
75+
.values(
76+
user_id=bindparam("uid"),
77+
node_id=bindparam("node_id"),
78+
created_at=bindparam("created_at"),
79+
used_traffic=0,
80+
)
81+
.prefix_with("OR IGNORE")
82+
)
83+
84+
# Update with renamed bindparams to avoid conflicts
85+
update_stmt = (
86+
update(NodeUserUsage)
87+
.values(used_traffic=NodeUserUsage.used_traffic + bindparam("value"))
88+
.where(
89+
and_(
90+
NodeUserUsage.user_id == bindparam("b_uid"),
91+
NodeUserUsage.node_id == bindparam("b_node_id"),
92+
NodeUserUsage.created_at == bindparam("b_created_at"),
93+
)
94+
)
95+
)
96+
97+
# Remap params for update statement
98+
update_params = [
99+
{
100+
"value": p["value"],
101+
"b_uid": p["uid"],
102+
"b_node_id": p["node_id"],
103+
"b_created_at": p["created_at"],
104+
}
105+
for p in upsert_params
106+
]
107+
108+
return [(insert_stmt, upsert_params), (update_stmt, update_params)]
109+
110+
111+
def build_node_usage_upsert(dialect: str, upsert_param: dict):
112+
"""
113+
Build UPSERT statement for NodeUsage based on database dialect.
114+
115+
Args:
116+
dialect: Database dialect name ('postgresql', 'mysql', or 'sqlite')
117+
upsert_param: Parameter dict with keys: node_id, created_at, up, down
118+
119+
Returns:
120+
tuple: (statements_list, params_list) - For SQLite returns 2 statements, others return 1
121+
"""
122+
if dialect == "postgresql":
123+
stmt = pg_insert(NodeUsage).values(
124+
node_id=bindparam("node_id"),
125+
created_at=bindparam("created_at"),
126+
uplink=bindparam("up"),
127+
downlink=bindparam("down"),
128+
)
129+
stmt = stmt.on_conflict_do_update(
130+
index_elements=["created_at", "node_id"],
131+
set_={
132+
"uplink": NodeUsage.uplink + bindparam("up"),
133+
"downlink": NodeUsage.downlink + bindparam("down"),
134+
},
135+
)
136+
return [(stmt, [upsert_param])]
137+
138+
elif dialect == "mysql":
139+
stmt = mysql_insert(NodeUsage).values(
140+
node_id=bindparam("node_id"),
141+
created_at=bindparam("created_at"),
142+
uplink=bindparam("up"),
143+
downlink=bindparam("down"),
144+
)
145+
stmt = stmt.on_duplicate_key_update(
146+
uplink=NodeUsage.uplink + stmt.inserted.uplink,
147+
downlink=NodeUsage.downlink + stmt.inserted.downlink,
148+
)
149+
return [(stmt, [upsert_param])]
150+
151+
else: # SQLite
152+
# Insert with OR IGNORE
153+
insert_stmt = (
154+
insert(NodeUsage)
155+
.values(
156+
node_id=bindparam("node_id"),
157+
created_at=bindparam("created_at"),
158+
uplink=0,
159+
downlink=0,
160+
)
161+
.prefix_with("OR IGNORE")
162+
)
163+
164+
# Update with renamed bindparams to avoid conflicts
165+
update_stmt = (
166+
update(NodeUsage)
167+
.values(
168+
uplink=NodeUsage.uplink + bindparam("up"),
169+
downlink=NodeUsage.downlink + bindparam("down"),
170+
)
171+
.where(
172+
and_(
173+
NodeUsage.node_id == bindparam("b_node_id"),
174+
NodeUsage.created_at == bindparam("b_created_at"),
175+
)
176+
)
177+
)
178+
179+
# Remap params for update statement
180+
update_param = {
181+
"up": upsert_param["up"],
182+
"down": upsert_param["down"],
183+
"b_node_id": upsert_param["node_id"],
184+
"b_created_at": upsert_param["created_at"],
185+
}
186+
187+
return [(insert_stmt, [upsert_param]), (update_stmt, [update_param])]
188+
189+
29190
async def safe_execute(stmt, params=None, max_retries: int = 5):
30191
"""
31192
Safely execute database operations with deadlock and connection handling.
@@ -96,72 +257,24 @@ async def record_user_stats(params: list[dict], node_id: int, usage_coefficient:
96257

97258
created_at = dt.now(tz.utc).replace(minute=0, second=0, microsecond=0)
98259

99-
async with GetDB() as db:
100-
dialect = db.bind.dialect.name
260+
# Get dialect without holding session
261+
dialect = await get_dialect()
101262

102-
# Prepare parameters - ensure uid is converted to int
103-
upsert_params = [
104-
{
105-
"uid": int(p["uid"]),
106-
"value": int(p["value"] * usage_coefficient),
107-
"node_id": node_id,
108-
"created_at": created_at,
109-
}
110-
for p in params
111-
]
263+
# Prepare parameters - ensure uid is converted to int
264+
upsert_params = [
265+
{
266+
"uid": int(p["uid"]),
267+
"value": int(p["value"] * usage_coefficient),
268+
"node_id": node_id,
269+
"created_at": created_at,
270+
}
271+
for p in params
272+
]
112273

113-
if dialect == "postgresql":
114-
stmt = pg_insert(NodeUserUsage).values(
115-
user_id=bindparam("uid"),
116-
node_id=bindparam("node_id"),
117-
created_at=bindparam("created_at"),
118-
used_traffic=bindparam("value"),
119-
)
120-
stmt = stmt.on_conflict_do_update(
121-
index_elements=["created_at", "user_id", "node_id"],
122-
set_={"used_traffic": NodeUserUsage.used_traffic + bindparam("value")},
123-
)
124-
await safe_execute(stmt, upsert_params)
125-
126-
elif dialect == "mysql":
127-
# MySQL: INSERT ... ON DUPLICATE KEY UPDATE
128-
stmt = mysql_insert(NodeUserUsage).values(
129-
user_id=bindparam("uid"),
130-
node_id=bindparam("node_id"),
131-
created_at=bindparam("created_at"),
132-
used_traffic=bindparam("value"),
133-
)
134-
# Use stmt.inserted to reference the inserted value (VALUES() in SQL)
135-
stmt = stmt.on_duplicate_key_update(used_traffic=NodeUserUsage.used_traffic + stmt.inserted.used_traffic)
136-
await safe_execute(stmt, upsert_params)
137-
138-
else: # SQLite
139-
# SQLite: Use INSERT OR IGNORE + UPDATE pattern
140-
insert_stmt = (
141-
insert(NodeUserUsage)
142-
.values(
143-
user_id=bindparam("uid"),
144-
node_id=bindparam("node_id"),
145-
created_at=bindparam("created_at"),
146-
used_traffic=0,
147-
)
148-
.prefix_with("OR IGNORE")
149-
)
150-
await safe_execute(insert_stmt, upsert_params)
151-
152-
# Update all rows (existing + newly inserted)
153-
update_stmt = (
154-
update(NodeUserUsage)
155-
.values(used_traffic=NodeUserUsage.used_traffic + bindparam("value"))
156-
.where(
157-
and_(
158-
NodeUserUsage.user_id == bindparam("uid"),
159-
NodeUserUsage.node_id == bindparam("node_id"),
160-
NodeUserUsage.created_at == bindparam("created_at"),
161-
)
162-
)
163-
)
164-
await safe_execute(update_stmt, upsert_params)
274+
# Build and execute queries for the specific dialect
275+
queries = build_node_user_usage_upsert(dialect, upsert_params)
276+
for stmt, stmt_params in queries:
277+
await safe_execute(stmt, stmt_params)
165278

166279

167280
async def record_node_stats(params: list[dict], node_id: int):
@@ -181,77 +294,20 @@ async def record_node_stats(params: list[dict], node_id: int):
181294
total_up = sum(p.get("up", 0) for p in params)
182295
total_down = sum(p.get("down", 0) for p in params)
183296

184-
async with GetDB() as db:
185-
dialect = db.bind.dialect.name
186-
187-
upsert_param = {
188-
"node_id": node_id,
189-
"created_at": created_at,
190-
"up": total_up,
191-
"down": total_down,
192-
}
297+
# Get dialect without holding session
298+
dialect = await get_dialect()
193299

194-
if dialect == "postgresql":
195-
# PostgreSQL: INSERT ... ON CONFLICT DO UPDATE
196-
stmt = pg_insert(NodeUsage).values(
197-
node_id=bindparam("node_id"),
198-
created_at=bindparam("created_at"),
199-
uplink=bindparam("up"),
200-
downlink=bindparam("down"),
201-
)
202-
stmt = stmt.on_conflict_do_update(
203-
index_elements=["created_at", "node_id"],
204-
set_={
205-
"uplink": NodeUsage.uplink + bindparam("up"),
206-
"downlink": NodeUsage.downlink + bindparam("down"),
207-
},
208-
)
209-
await safe_execute(stmt, [upsert_param])
300+
upsert_param = {
301+
"node_id": node_id,
302+
"created_at": created_at,
303+
"up": total_up,
304+
"down": total_down,
305+
}
210306

211-
elif dialect == "mysql":
212-
# MySQL: INSERT ... ON DUPLICATE KEY UPDATE
213-
stmt = mysql_insert(NodeUsage).values(
214-
node_id=bindparam("node_id"),
215-
created_at=bindparam("created_at"),
216-
uplink=bindparam("up"),
217-
downlink=bindparam("down"),
218-
)
219-
# Use stmt.inserted to reference the inserted values (VALUES() in SQL)
220-
stmt = stmt.on_duplicate_key_update(
221-
uplink=NodeUsage.uplink + stmt.inserted.uplink,
222-
downlink=NodeUsage.downlink + stmt.inserted.downlink,
223-
)
224-
await safe_execute(stmt, [upsert_param])
225-
226-
else: # SQLite
227-
# SQLite: Use INSERT OR IGNORE + UPDATE pattern
228-
insert_stmt = (
229-
insert(NodeUsage)
230-
.values(
231-
node_id=bindparam("node_id"),
232-
created_at=bindparam("created_at"),
233-
uplink=0,
234-
downlink=0,
235-
)
236-
.prefix_with("OR IGNORE")
237-
)
238-
await safe_execute(insert_stmt, [upsert_param])
239-
240-
# Update the row (existing or newly inserted)
241-
update_stmt = (
242-
update(NodeUsage)
243-
.values(
244-
uplink=NodeUsage.uplink + bindparam("up"),
245-
downlink=NodeUsage.downlink + bindparam("down"),
246-
)
247-
.where(
248-
and_(
249-
NodeUsage.node_id == bindparam("node_id"),
250-
NodeUsage.created_at == bindparam("created_at"),
251-
)
252-
)
253-
)
254-
await safe_execute(update_stmt, [upsert_param])
307+
# Build and execute queries for the specific dialect
308+
queries = build_node_usage_upsert(dialect, upsert_param)
309+
for stmt, stmt_params in queries:
310+
await safe_execute(stmt, stmt_params)
255311

256312

257313
async def get_users_stats(node: PasarGuardNode):

0 commit comments

Comments
 (0)