Skip to content

Commit d35c7d8

Browse files
committed
refactor(nodes): use less sessions to avoid overflow
1 parent 5449ecd commit d35c7d8

File tree

10 files changed

+300
-135
lines changed

10 files changed

+300
-135
lines changed

app/db/crud/node.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime, timezone
22
from typing import Optional, Union
33

4-
from sqlalchemy import and_, delete, func, select, update
4+
from sqlalchemy import and_, delete, func, select, update, bindparam
55
from sqlalchemy.ext.asyncio import AsyncSession
66

77
from app.db.models import Node, NodeStat, NodeStatus, NodeUsage, NodeUserUsage
@@ -278,6 +278,49 @@ def _table_model(table: UsageTable):
278278
raise ValueError("Invalid table enum")
279279

280280

281+
async def bulk_update_node_status(
282+
db: AsyncSession,
283+
updates: list[dict],
284+
) -> None:
285+
"""
286+
Update multiple node statuses in a single query using bindparam.
287+
288+
Args:
289+
db (AsyncSession): The database session.
290+
updates (list[dict]): List of updates with keys: node_id, status, message, xray_version, node_version.
291+
292+
Example:
293+
updates = [
294+
{"node_id": 1, "status": NodeStatus.connected, "message": "", "xray_version": "1.8.0", "node_version": "0.1.0"},
295+
{"node_id": 2, "status": NodeStatus.error, "message": "Connection failed", "xray_version": "", "node_version": ""},
296+
]
297+
"""
298+
if not updates:
299+
return
300+
301+
stmt = (
302+
update(Node)
303+
.where(Node.id == bindparam("node_id"))
304+
.values(
305+
status=bindparam("status"),
306+
message=bindparam("message"),
307+
xray_version=bindparam("xray_version"),
308+
node_version=bindparam("node_version"),
309+
last_status_change=bindparam("now"),
310+
)
311+
)
312+
313+
# Add timestamp to each update
314+
now = datetime.now(timezone.utc)
315+
for upd in updates:
316+
upd["now"] = now
317+
318+
# Execute using connection-level execute (bypasses ORM, allows bindparam with WHERE)
319+
conn = await db.connection()
320+
await conn.execute(stmt, updates)
321+
await db.commit()
322+
323+
281324
async def clear_usage_data(
282325
db: AsyncSession, table: UsageTable, start: datetime | None = None, end: datetime | None = None
283326
):

app/jobs/node_checker.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,21 @@ async def update_node_connection_status(node_id: int, node: PasarGuardNode):
5555
"""
5656
try:
5757
await node.get_backend_stats(timeout=8)
58-
await node_operator.update_node_status(
59-
node_id, NodeStatus.connected, await node.core_version(), await node.node_version()
60-
)
58+
async with GetDB() as db:
59+
await NodeOperation._update_single_node_status(
60+
db,
61+
node_id,
62+
NodeStatus.connected,
63+
xray_version=await node.core_version(),
64+
node_version=await node.node_version(),
65+
)
6166
except NodeAPIError as e:
6267
if e.code > -3:
63-
await node_operator.update_node_status(node_id, NodeStatus.error, err=e.detail)
68+
async with GetDB() as db:
69+
await NodeOperation._update_single_node_status(db, node_id, NodeStatus.error, message=e.detail)
6470
if e.code > 0:
65-
await node_operator.connect_node(node_id=node_id)
71+
async with GetDB() as db:
72+
await node_operator.connect_node_wrapper(db, node_id)
6673

6774

6875
async def process_node_health_check(db_node: Node, node: PasarGuardNode):
@@ -77,16 +84,23 @@ async def process_node_health_check(db_node: Node, node: PasarGuardNode):
7784
return
7885

7986
if node.requires_hard_reset():
80-
await node_operator.connect_node(db_node.id)
87+
async with GetDB() as db:
88+
await node_operator.connect_node_wrapper(db, db_node.id)
8189
return
8290

8391
try:
8492
health = await asyncio.wait_for(verify_node_backend_health(node, db_node.name), timeout=15)
8593
except asyncio.TimeoutError:
86-
await node_operator.update_node_status(db_node.id, NodeStatus.error, err="Health check timeout")
94+
async with GetDB() as db:
95+
await NodeOperation._update_single_node_status(
96+
db, db_node.id, NodeStatus.error, message="Health check timeout"
97+
)
8798
return
8899
except NodeAPIError:
89-
await node_operator.update_node_status(db_node.id, NodeStatus.error, err="Get health failed")
100+
async with GetDB() as db:
101+
await NodeOperation._update_single_node_status(
102+
db, db_node.id, NodeStatus.error, message="Get health failed"
103+
)
90104
return
91105

92106
# Skip nodes that are already healthy and connected
@@ -95,12 +109,14 @@ async def process_node_health_check(db_node: Node, node: PasarGuardNode):
95109

96110
# Update status for recovering nodes
97111
if db_node.status in (NodeStatus.connecting, NodeStatus.error) and health == Health.HEALTHY:
98-
await node_operator.update_node_status(
99-
db_node.id,
100-
NodeStatus.connected,
101-
core_version=await node.core_version(),
102-
node_version=await node.node_version(),
103-
)
112+
async with GetDB() as db:
113+
await NodeOperation._update_single_node_status(
114+
db,
115+
db_node.id,
116+
NodeStatus.connected,
117+
xray_version=await node.core_version(),
118+
node_version=await node.node_version(),
119+
)
104120
return
105121

106122
# For all other cases, update connection status
@@ -126,27 +142,10 @@ async def initialize_nodes():
126142
async with GetDB() as db:
127143
db_nodes = await get_nodes(db=db, enabled=True)
128144

129-
# Semaphore to limit concurrent node startups to 3 at a time
130-
semaphore = asyncio.Semaphore(3)
131-
132-
async def start_node(node: Node):
133-
try:
134-
await node_manager.update_node(node)
135-
except NodeAPIError as e:
136-
await node_operator.update_node_status(node.id, NodeStatus.error, err=e.detail)
137-
return
138-
139-
await node_operator.connect_node(node_id=node.id)
140-
141-
async def start_node_with_limit(node: Node):
142-
async with semaphore:
143-
await start_node(node)
144-
145145
if not db_nodes:
146146
logger.warning("Attention: You have no node, you need to have at least one node")
147147
else:
148-
start_tasks = [start_node_with_limit(node=db_node) for db_node in db_nodes]
149-
await asyncio.gather(*start_tasks)
148+
await node_operator.connect_nodes_bulk(db, db_nodes)
150149
logger.info("All nodes' cores have been started.")
151150

152151
scheduler.add_job(

app/models/node.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,15 @@ class NodeResponse(Node):
161161
message: str | None
162162

163163
model_config = ConfigDict(from_attributes=True)
164+
165+
166+
class NodeNotification(BaseModel):
167+
"""Lightweight node model for sending notifications without database fetch."""
168+
169+
id: int
170+
name: str
171+
xray_version: str | None = None
172+
node_version: str | None = None
173+
message: str | None = None
174+
175+
model_config = ConfigDict(from_attributes=True)

app/notification/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from . import webhook as wh
66
from app.models.host import BaseHost
77
from app.models.user_template import UserTemplateResponse
8-
from app.models.node import NodeResponse
8+
from app.models.node import NodeNotification, NodeResponse
99
from app.models.group import GroupResponse
1010
from app.models.core import CoreResponse
1111
from app.models.admin import AdminDetails
@@ -63,12 +63,12 @@ async def remove_node(node: NodeResponse, by: str):
6363
await asyncio.gather(ds.remove_node(node, by), tg.remove_node(node, by))
6464

6565

66-
async def connect_node(node: NodeResponse):
66+
async def connect_node(node: NodeNotification):
6767
if (await notification_enable()).node.connect:
6868
await asyncio.gather(ds.connect_node(node), tg.connect_node(node))
6969

7070

71-
async def error_node(node: NodeResponse):
71+
async def error_node(node: NodeNotification):
7272
if (await notification_enable()).node.error:
7373
await asyncio.gather(ds.error_node(node), tg.error_node(node))
7474

app/notification/discord/node.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22

33
from app.notification.client import send_discord_webhook
4-
from app.models.node import NodeResponse
4+
from app.models.node import NodeNotification, NodeResponse
55
from app.models.settings import NotificationSettings
66
from app.settings import notification_settings
77
from app.utils.helpers import escape_ds_markdown_list, escape_ds_markdown
@@ -54,7 +54,7 @@ async def remove_node(node: NodeResponse, by: str):
5454
await send_discord_webhook(data, settings.discord_webhook_url)
5555

5656

57-
async def connect_node(node: NodeResponse):
57+
async def connect_node(node: NodeNotification):
5858
name = escape_ds_markdown(node.name)
5959
message = copy.deepcopy(messages.CONNECT_NODE)
6060
message["description"] = message["description"].format(
@@ -71,7 +71,7 @@ async def connect_node(node: NodeResponse):
7171
await send_discord_webhook(data, settings.discord_webhook_url)
7272

7373

74-
async def error_node(node: NodeResponse):
74+
async def error_node(node: NodeNotification):
7575
name, node_message = escape_ds_markdown_list((node.name, node.message))
7676
message = copy.deepcopy(messages.ERROR_NODE)
7777
message["description"] = message["description"].format(name=name, error=node_message)

app/notification/telegram/node.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from html import escape
22

33
from app.notification.client import send_telegram_message
4-
from app.models.node import NodeResponse
4+
from app.models.node import NodeNotification, NodeResponse
55
from app.models.settings import NotificationSettings
66
from app.settings import notification_settings
77
from app.utils.helpers import escape_tg_html
@@ -38,7 +38,7 @@ async def remove_node(node: NodeResponse, by: str):
3838
)
3939

4040

41-
async def connect_node(node: NodeResponse):
41+
async def connect_node(node: NodeNotification):
4242
data = messages.CONNECT_NODE.format(
4343
name=escape(node.name), node_version=node.node_version, core_version=node.xray_version, id=node.id
4444
)
@@ -49,7 +49,7 @@ async def connect_node(node: NodeResponse):
4949
)
5050

5151

52-
async def error_node(node: NodeResponse):
52+
async def error_node(node: NodeNotification):
5353
name, message = escape_tg_html((node.name, node.message))
5454
data = messages.ERROR_NODE.format(name=name, error=message, id=node.id)
5555
settings: NotificationSettings = await notification_settings()

0 commit comments

Comments
 (0)