Skip to content

Commit 30bb4b1

Browse files
committed
fix(jobs): use on_conflict_do_update to avoid unnecessary select query
1 parent 76f14d6 commit 30bb4b1

File tree

7 files changed

+162
-84
lines changed

7 files changed

+162
-84
lines changed

app/db/base.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@
2828

2929
SessionLocal = async_sessionmaker(autocommit=False, autoflush=False, bind=engine)
3030

31-
# Determine dialect once at startup based on connection URL
32-
if SQLALCHEMY_DATABASE_URL.startswith("sqlite"):
33-
DATABASE_DIALECT = "sqlite"
34-
elif SQLALCHEMY_DATABASE_URL.startswith("postgresql"):
35-
DATABASE_DIALECT = "postgresql"
36-
elif SQLALCHEMY_DATABASE_URL.startswith("mysql"):
37-
DATABASE_DIALECT = "mysql"
38-
else:
39-
raise ValueError("Unsupported database URL")
40-
4131

4232
class Base(DeclarativeBase, MappedAsDataclass, AsyncAttrs):
4333
pass

app/db/crud/bulk.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from sqlalchemy.dialects.postgresql import JSONB
66
from sqlalchemy.ext.asyncio import AsyncSession
77

8-
from app.db.base import DATABASE_DIALECT
98
from app.db.models import (
109
Admin,
1110
Group,
@@ -266,7 +265,7 @@ async def update_users_expire(db: AsyncSession, bulk_model: BulkUser) -> tuple[l
266265
await db.execute(select(func.count(User.id)).where(and_(final_filter, User.expire.isnot(None))))
267266
).scalar_one_or_none() or 0
268267
# Get database-specific datetime addition expression
269-
new_expire = get_datetime_add_expression(User.expire, bulk_model.amount)
268+
new_expire = get_datetime_add_expression(db, User.expire, bulk_model.amount)
270269
current_time = dt.now(tz.utc)
271270

272271
# First, get the users that will have status changes BEFORE updating
@@ -377,8 +376,10 @@ async def update_users_proxy_settings(
377376
if not users_to_update:
378377
return [], count_effctive_users
379378

379+
dialect = db.bind.dialect.name
380+
380381
# Prepare the update statement
381-
if DATABASE_DIALECT == "postgresql":
382+
if dialect == "postgresql":
382383
proxy_settings_expr = cast(User.proxy_settings, JSONB)
383384
if bulk_model.flow is not None:
384385
proxy_settings_expr = func.jsonb_set(

app/db/crud/general.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from sqlalchemy import String, func, or_, select, text
22
from sqlalchemy.ext.asyncio import AsyncSession
33

4-
from app.db.base import DATABASE_DIALECT
54
from app.db.models import JWT, System
65
from app.models.stats import Period
76

@@ -19,39 +18,43 @@
1918
}
2019

2120

22-
def _build_trunc_expression(period: Period, column):
23-
"""Builds the appropriate truncation SQL expression based on DATABASE_DIALECT and period."""
24-
if DATABASE_DIALECT == "postgresql":
21+
def _build_trunc_expression(db: AsyncSession, period: Period, column):
22+
dialect = db.bind.dialect.name
23+
24+
"""Builds the appropriate truncation SQL expression based on dialect and period."""
25+
if dialect == "postgresql":
2526
return func.date_trunc(period.value, column)
26-
elif DATABASE_DIALECT == "mysql":
27+
elif dialect == "mysql":
2728
return func.date_format(column, MYSQL_FORMATS[period.value])
28-
elif DATABASE_DIALECT == "sqlite":
29+
elif dialect == "sqlite":
2930
return func.strftime(SQLITE_FORMATS[period.value], column)
3031

31-
raise ValueError(f"Unsupported dialect: {DATABASE_DIALECT}")
32+
raise ValueError(f"Unsupported dialect: {dialect}")
3233

3334

34-
def get_datetime_add_expression(datetime_column, seconds: int):
35+
def get_datetime_add_expression(db: AsyncSession, datetime_column, seconds: int):
3536
"""
3637
Get database-specific datetime addition expression
3738
"""
38-
if DATABASE_DIALECT == "mysql":
39+
dialect = db.bind.dialect.name
40+
if dialect == "mysql":
3941
return func.date_add(datetime_column, text("INTERVAL :seconds SECOND").bindparams(seconds=seconds))
40-
elif DATABASE_DIALECT == "postgresql":
42+
elif dialect == "postgresql":
4143
return datetime_column + func.make_interval(0, 0, 0, 0, 0, 0, seconds)
42-
elif DATABASE_DIALECT == "sqlite":
44+
elif dialect == "sqlite":
4345
return func.datetime(func.strftime("%s", datetime_column) + seconds, "unixepoch")
4446

45-
raise ValueError(f"Unsupported dialect: {DATABASE_DIALECT}")
47+
raise ValueError(f"Unsupported dialect: {dialect}")
4648

4749

48-
def json_extract(column, path: str):
50+
def json_extract(db: AsyncSession, column, path: str):
4951
"""
5052
Args:
5153
column: The JSON column in your model
5254
path: JSON path (e.g., '$.theme')
5355
"""
54-
match DATABASE_DIALECT:
56+
dialect = db.bind.dialect.name
57+
match dialect:
5558
case "postgresql":
5659
keys = path.replace("$.", "").split(".")
5760
expr = column
@@ -64,14 +67,14 @@ def json_extract(column, path: str):
6467
return func.json_extract(column, path).cast(String)
6568

6669

67-
def build_json_proxy_settings_search_condition(column, value: str):
70+
def build_json_proxy_settings_search_condition(db: AsyncSession, column, value: str):
6871
"""
6972
Builds a condition to search JSON column for UUIDs or passwords.
7073
Supports PostgresSQL, MySQL, SQLite.
7174
"""
7275
return or_(
7376
*[
74-
json_extract(column, field) == value
77+
json_extract(db, column, field) == value
7578
for field in ("$.vmess.id", "$.vless.id", "$.trojan.password", "$.shadowsocks.password")
7679
],
7780
)

app/db/crud/node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def get_nodes_usage(
9999
Returns:
100100
NodeUsageStatsList: A NodeUsageStatsList contain list of NodeUsageResponse objects containing usage data.
101101
"""
102-
trunc_expr = _build_trunc_expression(period, NodeUsage.created_at)
102+
trunc_expr = _build_trunc_expression(db, period, NodeUsage.created_at)
103103

104104
conditions = [NodeUsage.created_at >= start, NodeUsage.created_at <= end]
105105

@@ -147,7 +147,7 @@ async def get_nodes_usage(
147147
async def get_node_stats(
148148
db: AsyncSession, node_id: int, start: datetime, end: datetime, period: Period
149149
) -> NodeStatsList:
150-
trunc_expr = _build_trunc_expression(period, NodeStat.created_at)
150+
trunc_expr = _build_trunc_expression(db, period, NodeStat.created_at)
151151
conditions = [NodeStat.created_at >= start, NodeStat.created_at <= end, NodeStat.node_id == node_id]
152152

153153
stmt = (

app/db/crud/user.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ async def get_users(
157157
if group_ids:
158158
filters.append(User.groups.any(Group.id.in_(group_ids)))
159159
if proxy_id:
160-
filters.append(build_json_proxy_settings_search_condition(User.proxy_settings, proxy_id))
160+
filters.append(build_json_proxy_settings_search_condition(db, User.proxy_settings, proxy_id))
161161

162162
if filters:
163163
stmt = stmt.where(and_(*filters))
@@ -337,7 +337,7 @@ async def get_user_usages(
337337
"""
338338

339339
# Build the appropriate truncation expression
340-
trunc_expr = _build_trunc_expression(period, NodeUserUsage.created_at)
340+
trunc_expr = _build_trunc_expression(db, period, NodeUserUsage.created_at)
341341

342342
conditions = [
343343
NodeUserUsage.created_at >= start,
@@ -862,7 +862,7 @@ async def get_all_users_usages(
862862
admin_users = {user.id for user in await get_users(db=db, admins=admin)}
863863

864864
# Build the appropriate truncation expression
865-
trunc_expr = _build_trunc_expression(period, NodeUserUsage.created_at)
865+
trunc_expr = _build_trunc_expression(db, period, NodeUserUsage.created_at)
866866

867867
conditions = [
868868
NodeUserUsage.created_at >= start,

app/jobs/cleanup_subscription_updates.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from app import scheduler
44
from app.db import GetDB
5-
from app.db.base import DATABASE_DIALECT
65
from app.db.models import UserSubscriptionUpdate
76
from app.utils.logger import get_logger
87
from config import USER_SUBSCRIPTION_CLIENTS_LIMIT, JOB_CLEANUP_SUBSCRIPTION_UPDATES_INTERVAL
@@ -26,8 +25,10 @@ async def cleanup_user_subscription_updates():
2625
logger.info("No users with excess subscription updates")
2726
return
2827

28+
dialect = db.bind.dialect.name
29+
2930
# Second query: Use different approaches based on database type
30-
if DATABASE_DIALECT == "mysql":
31+
if dialect == "mysql":
3132
# MySQL/MariaDB: Use correlated subquery without LIMIT
3233
total_deleted = 0
3334
for user_id in user_ids:

0 commit comments

Comments
 (0)