11from __future__ import annotations
22
3+ import os
4+ from collections import defaultdict
35from typing import Any
4- from unittest .mock import AsyncMock , call
6+ from unittest .mock import AsyncMock
57
68import pytest
9+ from sqlalchemy import select
10+ from sqlalchemy .exc import SQLAlchemyError
11+ from sqlalchemy .ext .asyncio import async_sessionmaker , create_async_engine
12+ from sqlalchemy .pool import NullPool , StaticPool
713
14+ from app .db import base
15+ from app .db .models import Admin , Node , NodeUsage , NodeUserUsage , System , User
816from app .jobs import record_usages
17+ from config import SQLALCHEMY_DATABASE_URL
918
1019
1120class DummyNode :
@@ -17,140 +26,286 @@ async def get_extra(self) -> dict[str, Any]:
1726 return {"usage_coefficient" : self ._usage_coefficient }
1827
1928
29+ def _get_test_database_url () -> str :
30+ test_from = os .getenv ("TEST_FROM" , "local" ).lower ()
31+ if test_from == "local" :
32+ return "sqlite+aiosqlite:///:memory:"
33+ return SQLALCHEMY_DATABASE_URL
34+
35+
36+ @pytest .fixture
37+ async def session_factory (monkeypatch : pytest .MonkeyPatch ):
38+ database_url = _get_test_database_url ()
39+ is_sqlite = database_url .startswith ("sqlite" )
40+
41+ engine_kwargs = {}
42+ connect_args = {}
43+ if is_sqlite :
44+ connect_args ["check_same_thread" ] = False
45+ # Keep the in-memory database alive across connections
46+ engine_kwargs ["poolclass" ] = StaticPool
47+ else :
48+ engine_kwargs ["poolclass" ] = NullPool
49+
50+ engine = create_async_engine (database_url , connect_args = connect_args , ** engine_kwargs )
51+ async with engine .begin () as conn :
52+ await conn .run_sync (base .Base .metadata .drop_all )
53+ await conn .run_sync (base .Base .metadata .create_all )
54+
55+ session_factory = async_sessionmaker (bind = engine , expire_on_commit = False , autoflush = False )
56+
57+ class TestGetDB :
58+ def __init__ (self ):
59+ self .db = session_factory ()
60+
61+ async def __aenter__ (self ):
62+ return self .db
63+
64+ async def __aexit__ (self , exc_type , exc_value , traceback ):
65+ if isinstance (exc_value , SQLAlchemyError ):
66+ await self .db .rollback ()
67+ await self .db .close ()
68+
69+ monkeypatch .setattr (record_usages , "engine" , engine )
70+ monkeypatch .setattr (record_usages , "GetDB" , TestGetDB )
71+
72+ yield session_factory
73+
74+ async with engine .begin () as conn :
75+ await conn .run_sync (base .Base .metadata .drop_all )
76+ await engine .dispose ()
77+
78+
2079@pytest .mark .asyncio
21- async def test_record_user_usages_updates_users_and_admins (monkeypatch : pytest .MonkeyPatch ):
22- nodes = [(1 , DummyNode (1 , usage_coefficient = 2 )), (2 , DummyNode (2 , usage_coefficient = 1 ))]
80+ async def test_record_user_usages_updates_users_and_admins (monkeypatch : pytest .MonkeyPatch , session_factory ):
81+ async with session_factory () as session :
82+ admin = Admin (username = "admin" , hashed_password = "secret" )
83+ session .add (admin )
84+ await session .flush ()
85+ admin_id = admin .id
86+
87+ user_one = User (username = "user1" , admin_id = admin_id , proxy_settings = {})
88+ user_two = User (username = "user2" , admin_id = admin_id , proxy_settings = {})
89+ session .add_all ([user_one , user_two ])
90+ await session .flush ()
91+ user_one_id , user_two_id = user_one .id , user_two .id
92+
93+ node_one = Node (
94+ name = "node-1" ,
95+ address = "10.0.0.1" ,
96+ port = 1000 ,
97+ server_ca = "ca1" ,
98+ api_key = "key1" ,
99+ core_config_id = None ,
100+ )
101+ node_two = Node (
102+ name = "node-2" ,
103+ address = "10.0.0.2" ,
104+ port = 1001 ,
105+ server_ca = "ca2" ,
106+ api_key = "key2" ,
107+ core_config_id = None ,
108+ )
109+ session .add_all ([node_one , node_two ])
110+ await session .flush ()
111+ node_one_id , node_two_id = node_one .id , node_two .id
112+ await session .commit ()
113+
114+ nodes = [
115+ (node_one_id , DummyNode (node_one_id , usage_coefficient = 2 )),
116+ (node_two_id , DummyNode (node_two_id , usage_coefficient = 1 )),
117+ ]
23118 monkeypatch .setattr (record_usages .node_manager , "get_healthy_nodes" , AsyncMock (return_value = nodes ))
24119
25120 stats_map = {
26- 1 : [{"uid" : "1" , "value" : 100 }, {"uid" : "2" , "value" : 50 }],
27- 2 : [{"uid" : "1" , "value" : 75 }],
121+ node_one_id : [{"uid" : str ( user_one_id ) , "value" : 100 }, {"uid" : str ( user_two_id ) , "value" : 50 }],
122+ node_two_id : [{"uid" : str ( user_one_id ) , "value" : 75 }],
28123 }
29124
30125 async def fake_get_users_stats (node : DummyNode ):
31126 return stats_map [node .node_id ]
32127
33128 monkeypatch .setattr (record_usages , "get_users_stats" , fake_get_users_stats )
34-
35- safe_execute_mock = AsyncMock ()
36- monkeypatch .setattr (record_usages , "safe_execute" , safe_execute_mock )
37-
38- admin_usage = {99 : 555 }
39- calculate_admin_usage_mock = AsyncMock (return_value = admin_usage )
40- monkeypatch .setattr (record_usages , "calculate_admin_usage" , calculate_admin_usage_mock )
41-
42- record_user_stats_mock = AsyncMock ()
43- monkeypatch .setattr (record_usages , "record_user_stats" , record_user_stats_mock )
44129 monkeypatch .setattr (record_usages , "DISABLE_RECORDING_NODE_USAGE" , False )
45130
46131 await record_usages .record_user_usages ()
47132
48- expected_users_usage = [
49- {"uid" : 1 , "value" : 275 },
50- {"uid" : 2 , "value" : 100 },
51- ]
52- calculate_admin_usage_mock .assert_awaited_once ()
53- assert calculate_admin_usage_mock .await_args .args [0 ] == expected_users_usage
54-
55- assert safe_execute_mock .await_count == 2
56- user_call = safe_execute_mock .await_args_list [0 ]
57- assert user_call .args [1 ] == expected_users_usage
58-
59- admin_call = safe_execute_mock .await_args_list [1 ]
60- assert admin_call .args [1 ] == [{"admin_id" : 99 , "value" : 555 }]
61-
62- assert record_user_stats_mock .await_count == 2
63- expected_record_calls = [
64- call (params = stats_map [1 ], node_id = 1 , usage_coefficient = 2 ),
65- call (params = stats_map [2 ], node_id = 2 , usage_coefficient = 1 ),
66- ]
67- record_user_stats_mock .assert_has_awaits (expected_record_calls , any_order = False )
133+ async with session_factory () as session :
134+ users_result = await session .execute (
135+ select (User .id , User .used_traffic , User .online_at ).where (User .id .in_ ([user_one_id , user_two_id ]))
136+ )
137+ user_rows = users_result .all ()
138+ user_totals = {row .id : (row .used_traffic , row .online_at ) for row in user_rows }
139+
140+ assert user_totals [user_one_id ][0 ] > user_totals [user_two_id ][0 ]
141+ assert all (total > 0 for total , _ in user_totals .values ())
142+ assert all (online_at is not None for _ , online_at in user_totals .values ())
143+
144+ admin_total = await session .execute (select (Admin .used_traffic ).where (Admin .id == admin_id ))
145+ admin_used = admin_total .scalar_one ()
146+ assert admin_used == sum (total for total , _ in user_totals .values ())
147+
148+ node_usage_rows = await session .execute (
149+ select (NodeUserUsage .node_id , NodeUserUsage .user_id , NodeUserUsage .used_traffic )
150+ )
151+ node_usage_records = node_usage_rows .all ()
152+ usage_pairs = {(row .node_id , row .user_id ) for row in node_usage_records }
153+ assert usage_pairs == {
154+ (node_one_id , user_one_id ),
155+ (node_one_id , user_two_id ),
156+ (node_two_id , user_one_id ),
157+ }
158+
159+ aggregated_usage = defaultdict (int )
160+ for record in node_usage_records :
161+ assert record .used_traffic > 0
162+ aggregated_usage [record .user_id ] += record .used_traffic
163+
164+ for user_id , (total_usage , _ ) in user_totals .items ():
165+ assert aggregated_usage [user_id ] == total_usage
68166
69167
70168@pytest .mark .asyncio
71- async def test_record_user_usages_returns_when_no_usage (monkeypatch : pytest .MonkeyPatch ):
72- nodes = [(1 , DummyNode (1 ))]
169+ async def test_record_user_usages_returns_when_no_usage (monkeypatch : pytest .MonkeyPatch , session_factory ):
170+ async with session_factory () as session :
171+ admin = Admin (username = "admin" , hashed_password = "secret" )
172+ session .add (admin )
173+ await session .flush ()
174+ admin_id = admin .id
175+
176+ user = User (username = "user" , admin_id = admin_id , proxy_settings = {})
177+ node = Node (
178+ name = "node-1" ,
179+ address = "10.0.0.1" ,
180+ port = 1000 ,
181+ server_ca = "ca1" ,
182+ api_key = "key1" ,
183+ core_config_id = None ,
184+ )
185+ session .add_all ([user , node ])
186+ await session .flush ()
187+ user_id , node_id = user .id , node .id
188+ await session .commit ()
189+
190+ nodes = [(node_id , DummyNode (node_id ))]
73191 monkeypatch .setattr (record_usages .node_manager , "get_healthy_nodes" , AsyncMock (return_value = nodes ))
74192
75193 async def fake_get_users_stats (_ : DummyNode ):
76194 return []
77195
78196 monkeypatch .setattr (record_usages , "get_users_stats" , fake_get_users_stats )
197+ monkeypatch .setattr (record_usages , "DISABLE_RECORDING_NODE_USAGE" , False )
79198
80- safe_execute_mock = AsyncMock ()
81- monkeypatch .setattr (record_usages , "safe_execute" , safe_execute_mock )
82-
83- record_user_stats_mock = AsyncMock ()
84- monkeypatch .setattr (record_usages , "record_user_stats" , record_user_stats_mock )
199+ await record_usages .record_user_usages ()
85200
86- calculate_admin_usage_mock = AsyncMock ()
87- monkeypatch .setattr (record_usages , "calculate_admin_usage" , calculate_admin_usage_mock )
201+ async with session_factory () as session :
202+ user_total = await session .execute (select (User .used_traffic ).where (User .id == user_id ))
203+ assert user_total .scalar_one () == 0
88204
89- await record_usages .record_user_usages ()
205+ admin_total = await session .execute (select (Admin .used_traffic ).where (Admin .id == admin_id ))
206+ assert admin_total .scalar_one () == 0
90207
91- safe_execute_mock .assert_not_awaited ()
92- record_user_stats_mock .assert_not_awaited ()
93- calculate_admin_usage_mock .assert_not_awaited ()
208+ node_user_usage = await session .execute (select (NodeUserUsage .id ))
209+ assert node_user_usage .first () is None
94210
95211
96212@pytest .mark .asyncio
97- async def test_record_node_usages_updates_totals (monkeypatch : pytest .MonkeyPatch ):
98- nodes = [(1 , DummyNode (1 )), (2 , DummyNode (2 ))]
213+ async def test_record_node_usages_updates_totals (monkeypatch : pytest .MonkeyPatch , session_factory ):
214+ async with session_factory () as session :
215+ node_one = Node (
216+ name = "node-1" ,
217+ address = "10.0.0.1" ,
218+ port = 1000 ,
219+ server_ca = "ca1" ,
220+ api_key = "key1" ,
221+ core_config_id = None ,
222+ )
223+ node_two = Node (
224+ name = "node-2" ,
225+ address = "10.0.0.2" ,
226+ port = 1001 ,
227+ server_ca = "ca2" ,
228+ api_key = "key2" ,
229+ core_config_id = None ,
230+ )
231+ system = System (uplink = 0 , downlink = 0 )
232+ session .add_all ([node_one , node_two , system ])
233+ await session .flush ()
234+ node_one_id , node_two_id , system_id = node_one .id , node_two .id , system .id
235+ await session .commit ()
236+
237+ nodes = [(node_one_id , DummyNode (node_one_id )), (node_two_id , DummyNode (node_two_id ))]
99238 monkeypatch .setattr (record_usages .node_manager , "get_healthy_nodes" , AsyncMock (return_value = nodes ))
100239
101240 stats_map = {
102- 1 : [{"up" : 10 , "down" : 4 }, {"up" : 0 , "down" : 3 }],
103- 2 : [{"up" : 1 , "down" : 1 }],
241+ node_one_id : [{"up" : 10 , "down" : 4 }, {"up" : 0 , "down" : 3 }],
242+ node_two_id : [{"up" : 1 , "down" : 1 }],
104243 }
105244
106245 async def fake_get_outbounds_stats (node : DummyNode ):
107246 return stats_map [node .node_id ]
108247
109248 monkeypatch .setattr (record_usages , "get_outbounds_stats" , fake_get_outbounds_stats )
110-
111- safe_execute_mock = AsyncMock ()
112- monkeypatch .setattr (record_usages , "safe_execute" , safe_execute_mock )
113-
114- record_node_stats_mock = AsyncMock ()
115- monkeypatch .setattr (record_usages , "record_node_stats" , record_node_stats_mock )
116249 monkeypatch .setattr (record_usages , "DISABLE_RECORDING_NODE_USAGE" , False )
117250
118251 await record_usages .record_node_usages ()
119252
120- assert safe_execute_mock .await_count == 2
121- node_call = safe_execute_mock .await_args_list [0 ]
122- assert node_call .args [1 ] == [
123- {"node_id" : 1 , "up" : 10 , "down" : 7 },
124- {"node_id" : 2 , "up" : 1 , "down" : 1 },
125- ]
253+ async with session_factory () as session :
254+ nodes_result = await session .execute (select (Node .id , Node .uplink , Node .downlink ))
255+ node_totals = {row .id : (row .uplink , row .downlink ) for row in nodes_result .all ()}
256+ assert node_totals [node_one_id ][0 ] > node_totals [node_two_id ][0 ]
257+ assert node_totals [node_two_id ][1 ] > 0
126258
127- system_call = safe_execute_mock .await_args_list [1 ]
128- assert len (system_call .args ) == 1 # system totals baked into the statement
259+ node_usage_rows = await session .execute (
260+ select (NodeUsage .node_id , NodeUsage .uplink , NodeUsage .downlink )
261+ )
262+ node_usage_totals = {row .node_id : (row .uplink , row .downlink ) for row in node_usage_rows .all ()}
263+ assert set (node_usage_totals .keys ()) == {node_one_id , node_two_id }
264+ assert node_usage_totals == node_totals
129265
130- expected_node_calls = [
131- call (stats_map [1 ], 1 ),
132- call (stats_map [2 ], 2 ),
133- ]
134- assert record_node_stats_mock .await_args_list == expected_node_calls
266+ system_totals = await session .execute (select (System .uplink , System .downlink ).where (System .id == system_id ))
267+ system_row = system_totals .one ()
268+ assert system_row .uplink == sum (values [0 ] for values in node_totals .values ())
269+ assert system_row .downlink == sum (values [1 ] for values in node_totals .values ())
135270
136271
137272@pytest .mark .asyncio
138- async def test_record_node_usages_returns_when_totals_zero (monkeypatch : pytest .MonkeyPatch ):
139- nodes = [(1 , DummyNode (1 ))]
273+ async def test_record_node_usages_returns_when_totals_zero (monkeypatch : pytest .MonkeyPatch , session_factory ):
274+ async with session_factory () as session :
275+ node = Node (
276+ name = "node-1" ,
277+ address = "10.0.0.1" ,
278+ port = 1000 ,
279+ server_ca = "ca1" ,
280+ api_key = "key1" ,
281+ core_config_id = None ,
282+ )
283+ system = System (uplink = 0 , downlink = 0 )
284+ session .add_all ([node , system ])
285+ await session .flush ()
286+ node_id , system_id = node .id , system .id
287+ await session .commit ()
288+
289+ nodes = [(node_id , DummyNode (node_id ))]
140290 monkeypatch .setattr (record_usages .node_manager , "get_healthy_nodes" , AsyncMock (return_value = nodes ))
141291
142292 async def fake_get_outbounds_stats (_ : DummyNode ):
143293 return [{"up" : 0 , "down" : 0 }]
144294
145295 monkeypatch .setattr (record_usages , "get_outbounds_stats" , fake_get_outbounds_stats )
146296
147- safe_execute_mock = AsyncMock ()
148- monkeypatch .setattr (record_usages , "safe_execute" , safe_execute_mock )
297+ await record_usages .record_node_usages ()
149298
150- record_node_stats_mock = AsyncMock ()
151- monkeypatch .setattr (record_usages , "record_node_stats" , record_node_stats_mock )
299+ async with session_factory () as session :
300+ node_row = await session .execute (select (Node .uplink , Node .downlink ).where (Node .id == node_id ))
301+ node_totals = node_row .one ()
302+ assert node_totals .uplink == 0
303+ assert node_totals .downlink == 0
152304
153- await record_usages .record_node_usages ()
305+ system_row = await session .execute (select (System .uplink , System .downlink ).where (System .id == system_id ))
306+ system_totals = system_row .one ()
307+ assert system_totals .uplink == 0
308+ assert system_totals .downlink == 0
154309
155- safe_execute_mock . assert_not_awaited ( )
156- record_node_stats_mock . assert_not_awaited ()
310+ node_usage_rows = await session . execute ( select ( NodeUsage . id ) )
311+ assert node_usage_rows . first () is None
0 commit comments