Skip to content

Commit 2948b55

Browse files
committed
fix(record usages): insert node usage only for existing users
1 parent 0307259 commit 2948b55

1 file changed

Lines changed: 51 additions & 107 deletions

File tree

app/jobs/record_usages.py

Lines changed: 51 additions & 107 deletions
Original file line number | Diff line number | Diff line change
@@ -9,16 +9,16 @@
99

1010
from PasarGuardNodeBridge import NodeAPIError, PasarGuardNode
1111
from PasarGuardNodeBridge.common.service_pb2 import StatType
12-
from sqlalchemy import and_, bindparam, insert, select, update
12+
from sqlalchemy import and_, bindparam, insert, select, text, update
1313
from sqlalchemy.dialects.mysql import insert as mysql_insert
1414
from sqlalchemy.dialects.postgresql import insert as pg_insert
15-
from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError
15+
from sqlalchemy.exc import DatabaseError, OperationalError
1616
from sqlalchemy.sql.expression import Insert
1717

1818
from app import on_shutdown, scheduler
1919
from app.db import GetDB
2020
from app.db.base import engine
21-
from app.db.models import Admin, Node, NodeUsage, NodeUserUsage, System, User
21+
from app.db.models import Admin, Node, NodeUsage, System, User
2222
from app.node import node_manager
2323
from app.utils.logger import get_logger
2424
from config import job_settings, runtime_settings, usage_settings
@@ -102,69 +102,55 @@ def build_node_user_usage_upsert(dialect: str, upsert_params: list[dict]):
102102
upsert_params: List of parameter dicts with keys: uid, node_id, created_at, value
103103
104104
Returns:
105-
tuple: (statements_list, params_list) - For SQLite returns 2 statements, others return 1
106-
"""
107-
if dialect == "postgresql":
108-
stmt = pg_insert(NodeUserUsage).values(
109-
user_id=bindparam("uid"),
110-
node_id=bindparam("node_id"),
111-
created_at=bindparam("created_at"),
112-
used_traffic=bindparam("value"),
113-
)
114-
stmt = stmt.on_conflict_do_update(
115-
index_elements=["created_at", "user_id", "node_id"],
116-
set_={"used_traffic": NodeUserUsage.used_traffic + bindparam("value")},
105+
list: One SQL statement and its bound parameters.
106+
"""
107+
select_parts = []
108+
stmt_params = {}
109+
for index, param in enumerate(upsert_params):
110+
uid_key = f"uid_{index}"
111+
node_id_key = f"node_id_{index}"
112+
created_at_key = f"created_at_{index}"
113+
value_key = f"value_{index}"
114+
select_parts.append(
115+
f"SELECT :{uid_key} AS uid, :{node_id_key} AS node_id, "
116+
f":{created_at_key} AS created_at, :{value_key} AS value"
117117
)
118-
return [(stmt, upsert_params)]
119-
120-
elif dialect == "mysql":
121-
stmt = mysql_insert(NodeUserUsage).values(
122-
user_id=bindparam("uid"),
123-
node_id=bindparam("node_id"),
124-
created_at=bindparam("created_at"),
125-
used_traffic=bindparam("value"),
126-
)
127-
stmt = stmt.on_duplicate_key_update(used_traffic=NodeUserUsage.used_traffic + stmt.inserted.used_traffic)
128-
return [(stmt, upsert_params)]
129-
130-
else: # SQLite
131-
# Insert with OR IGNORE
132-
insert_stmt = (
133-
insert(NodeUserUsage)
134-
.values(
135-
user_id=bindparam("uid"),
136-
node_id=bindparam("node_id"),
137-
created_at=bindparam("created_at"),
138-
used_traffic=0,
139-
)
140-
.prefix_with("OR IGNORE")
141-
)
142-
143-
# Update with renamed bindparams to avoid conflicts
144-
update_stmt = (
145-
update(NodeUserUsage)
146-
.values(used_traffic=NodeUserUsage.used_traffic + bindparam("value"))
147-
.where(
148-
and_(
149-
NodeUserUsage.user_id == bindparam("b_uid"),
150-
NodeUserUsage.node_id == bindparam("b_node_id"),
151-
NodeUserUsage.created_at == bindparam("b_created_at"),
152-
)
153-
)
154-
)
155-
156-
# Remap params for update statement
157-
update_params = [
158-
{
159-
"value": p["value"],
160-
"b_uid": p["uid"],
161-
"b_node_id": p["node_id"],
162-
"b_created_at": p["created_at"],
163-
}
164-
for p in upsert_params
165-
]
118+
stmt_params[uid_key] = param["uid"]
119+
stmt_params[node_id_key] = param["node_id"]
120+
stmt_params[created_at_key] = param["created_at"]
121+
stmt_params[value_key] = param["value"]
122+
123+
values_source = "\nUNION ALL\n".join(select_parts)
124+
conflict_clause = {
125+
"postgresql": (
126+
"ON CONFLICT (created_at, user_id, node_id) DO UPDATE "
127+
"SET used_traffic = node_user_usages.used_traffic + EXCLUDED.used_traffic"
128+
),
129+
"mysql": (
130+
"ON DUPLICATE KEY UPDATE "
131+
"used_traffic = node_user_usages.used_traffic + VALUES(used_traffic)"
132+
),
133+
}.get(
134+
dialect,
135+
(
136+
"ON CONFLICT(created_at, user_id, node_id) DO UPDATE "
137+
"SET used_traffic = node_user_usages.used_traffic + excluded.used_traffic"
138+
),
139+
)
166140

167-
return [(insert_stmt, upsert_params), (update_stmt, update_params)]
141+
stmt = text(
142+
f"""
143+
INSERT INTO node_user_usages (created_at, user_id, node_id, used_traffic)
144+
SELECT source.created_at, source.uid, source.node_id, SUM(source.value)
145+
FROM (
146+
{values_source}
147+
) AS source
148+
JOIN users ON users.id = source.uid
149+
GROUP BY source.created_at, source.uid, source.node_id
150+
{conflict_clause}
151+
"""
152+
)
153+
return [(stmt, stmt_params)]
168154

169155

170156
def build_node_usage_upsert(dialect: str, upsert_param: dict):
@@ -313,31 +299,6 @@ async def safe_execute(stmt, params=None, max_retries: int = 2):
313299
raise
314300

315301

316-
async def _filter_existing_user_usage_params(upsert_params: list[dict]) -> list[dict]:
317-
"""Drop chart usage rows for users that no longer exist."""
318-
if not upsert_params:
319-
return []
320-
321-
user_ids = {param["uid"] for param in upsert_params}
322-
323-
async with GetDB() as db:
324-
result = await db.execute(select(User.id).where(User.id.in_(user_ids)))
325-
existing_user_ids = set(result.scalars().all())
326-
327-
if len(existing_user_ids) == len(user_ids):
328-
return upsert_params
329-
330-
filtered_params = [param for param in upsert_params if param["uid"] in existing_user_ids]
331-
skipped = len(upsert_params) - len(filtered_params)
332-
missing_users = len(user_ids) - len(existing_user_ids)
333-
logger.warning(
334-
"Skipped %s node user usage records for %s missing users",
335-
skipped,
336-
missing_users,
337-
)
338-
return filtered_params
339-
340-
341302
def _get_time_bucket(now: dt = None) -> dt:
342303
"""
343304
Get 10-minute time bucket instead of hourly to reduce hot row contention.
@@ -390,28 +351,11 @@ async def record_user_stats_batched(all_node_params: dict, usage_coefficients: d
390351
if not upsert_params:
391352
return
392353

393-
upsert_params = await _filter_existing_user_usage_params(upsert_params)
394-
if not upsert_params:
395-
return
396-
397354
# Execute single batched UPSERT with concurrency control
398355
async with JOB_SEM:
399356
queries = build_node_user_usage_upsert(dialect, upsert_params)
400357
for stmt, stmt_params in queries:
401-
try:
402-
await safe_execute(stmt, stmt_params)
403-
except IntegrityError:
404-
# A user can be deleted after the pre-insert filter but before
405-
# the write reaches the database. Re-check once so one stale
406-
# node stat does not drop the whole batch.
407-
retry_params = await _filter_existing_user_usage_params(upsert_params)
408-
if len(retry_params) == len(upsert_params):
409-
raise
410-
if not retry_params:
411-
return
412-
retry_queries = build_node_user_usage_upsert(dialect, retry_params)
413-
for retry_stmt, retry_stmt_params in retry_queries:
414-
await safe_execute(retry_stmt, retry_stmt_params)
358+
await safe_execute(stmt, stmt_params)
415359

416360

417361
async def record_node_stats_batched(all_node_params: dict):

0 commit comments

Comments
 (0)