Skip to content

Commit deaa44e

Browse files
fix: record_usages test
1 parent 1406c8f commit deaa44e

File tree

1 file changed

+235
-80
lines changed

1 file changed

+235
-80
lines changed

tests/test_record_usages.py

Lines changed: 235 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
from __future__ import annotations
22

3+
import os
4+
from collections import defaultdict
35
from typing import Any
4-
from unittest.mock import AsyncMock, call
6+
from unittest.mock import AsyncMock
57

68
import pytest
9+
from sqlalchemy import select
10+
from sqlalchemy.exc import SQLAlchemyError
11+
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
12+
from sqlalchemy.pool import NullPool, StaticPool
713

14+
from app.db import base
15+
from app.db.models import Admin, Node, NodeUsage, NodeUserUsage, System, User
816
from app.jobs import record_usages
17+
from config import SQLALCHEMY_DATABASE_URL
918

1019

1120
class DummyNode:
@@ -17,140 +26,286 @@ async def get_extra(self) -> dict[str, Any]:
1726
return {"usage_coefficient": self._usage_coefficient}
1827

1928

29+
def _get_test_database_url() -> str:
30+
test_from = os.getenv("TEST_FROM", "local").lower()
31+
if test_from == "local":
32+
return "sqlite+aiosqlite:///:memory:"
33+
return SQLALCHEMY_DATABASE_URL
34+
35+
36+
@pytest.fixture
37+
async def session_factory(monkeypatch: pytest.MonkeyPatch):
38+
database_url = _get_test_database_url()
39+
is_sqlite = database_url.startswith("sqlite")
40+
41+
engine_kwargs = {}
42+
connect_args = {}
43+
if is_sqlite:
44+
connect_args["check_same_thread"] = False
45+
# Keep the in-memory database alive across connections
46+
engine_kwargs["poolclass"] = StaticPool
47+
else:
48+
engine_kwargs["poolclass"] = NullPool
49+
50+
engine = create_async_engine(database_url, connect_args=connect_args, **engine_kwargs)
51+
async with engine.begin() as conn:
52+
await conn.run_sync(base.Base.metadata.drop_all)
53+
await conn.run_sync(base.Base.metadata.create_all)
54+
55+
session_factory = async_sessionmaker(bind=engine, expire_on_commit=False, autoflush=False)
56+
57+
class TestGetDB:
58+
def __init__(self):
59+
self.db = session_factory()
60+
61+
async def __aenter__(self):
62+
return self.db
63+
64+
async def __aexit__(self, exc_type, exc_value, traceback):
65+
if isinstance(exc_value, SQLAlchemyError):
66+
await self.db.rollback()
67+
await self.db.close()
68+
69+
monkeypatch.setattr(record_usages, "engine", engine)
70+
monkeypatch.setattr(record_usages, "GetDB", TestGetDB)
71+
72+
yield session_factory
73+
74+
async with engine.begin() as conn:
75+
await conn.run_sync(base.Base.metadata.drop_all)
76+
await engine.dispose()
77+
78+
2079
@pytest.mark.asyncio
21-
async def test_record_user_usages_updates_users_and_admins(monkeypatch: pytest.MonkeyPatch):
22-
nodes = [(1, DummyNode(1, usage_coefficient=2)), (2, DummyNode(2, usage_coefficient=1))]
80+
async def test_record_user_usages_updates_users_and_admins(monkeypatch: pytest.MonkeyPatch, session_factory):
81+
async with session_factory() as session:
82+
admin = Admin(username="admin", hashed_password="secret")
83+
session.add(admin)
84+
await session.flush()
85+
admin_id = admin.id
86+
87+
user_one = User(username="user1", admin_id=admin_id, proxy_settings={})
88+
user_two = User(username="user2", admin_id=admin_id, proxy_settings={})
89+
session.add_all([user_one, user_two])
90+
await session.flush()
91+
user_one_id, user_two_id = user_one.id, user_two.id
92+
93+
node_one = Node(
94+
name="node-1",
95+
address="10.0.0.1",
96+
port=1000,
97+
server_ca="ca1",
98+
api_key="key1",
99+
core_config_id=None,
100+
)
101+
node_two = Node(
102+
name="node-2",
103+
address="10.0.0.2",
104+
port=1001,
105+
server_ca="ca2",
106+
api_key="key2",
107+
core_config_id=None,
108+
)
109+
session.add_all([node_one, node_two])
110+
await session.flush()
111+
node_one_id, node_two_id = node_one.id, node_two.id
112+
await session.commit()
113+
114+
nodes = [
115+
(node_one_id, DummyNode(node_one_id, usage_coefficient=2)),
116+
(node_two_id, DummyNode(node_two_id, usage_coefficient=1)),
117+
]
23118
monkeypatch.setattr(record_usages.node_manager, "get_healthy_nodes", AsyncMock(return_value=nodes))
24119

25120
stats_map = {
26-
1: [{"uid": "1", "value": 100}, {"uid": "2", "value": 50}],
27-
2: [{"uid": "1", "value": 75}],
121+
node_one_id: [{"uid": str(user_one_id), "value": 100}, {"uid": str(user_two_id), "value": 50}],
122+
node_two_id: [{"uid": str(user_one_id), "value": 75}],
28123
}
29124

30125
async def fake_get_users_stats(node: DummyNode):
31126
return stats_map[node.node_id]
32127

33128
monkeypatch.setattr(record_usages, "get_users_stats", fake_get_users_stats)
34-
35-
safe_execute_mock = AsyncMock()
36-
monkeypatch.setattr(record_usages, "safe_execute", safe_execute_mock)
37-
38-
admin_usage = {99: 555}
39-
calculate_admin_usage_mock = AsyncMock(return_value=admin_usage)
40-
monkeypatch.setattr(record_usages, "calculate_admin_usage", calculate_admin_usage_mock)
41-
42-
record_user_stats_mock = AsyncMock()
43-
monkeypatch.setattr(record_usages, "record_user_stats", record_user_stats_mock)
44129
monkeypatch.setattr(record_usages, "DISABLE_RECORDING_NODE_USAGE", False)
45130

46131
await record_usages.record_user_usages()
47132

48-
expected_users_usage = [
49-
{"uid": 1, "value": 275},
50-
{"uid": 2, "value": 100},
51-
]
52-
calculate_admin_usage_mock.assert_awaited_once()
53-
assert calculate_admin_usage_mock.await_args.args[0] == expected_users_usage
54-
55-
assert safe_execute_mock.await_count == 2
56-
user_call = safe_execute_mock.await_args_list[0]
57-
assert user_call.args[1] == expected_users_usage
58-
59-
admin_call = safe_execute_mock.await_args_list[1]
60-
assert admin_call.args[1] == [{"admin_id": 99, "value": 555}]
61-
62-
assert record_user_stats_mock.await_count == 2
63-
expected_record_calls = [
64-
call(params=stats_map[1], node_id=1, usage_coefficient=2),
65-
call(params=stats_map[2], node_id=2, usage_coefficient=1),
66-
]
67-
record_user_stats_mock.assert_has_awaits(expected_record_calls, any_order=False)
133+
async with session_factory() as session:
134+
users_result = await session.execute(
135+
select(User.id, User.used_traffic, User.online_at).where(User.id.in_([user_one_id, user_two_id]))
136+
)
137+
user_rows = users_result.all()
138+
user_totals = {row.id: (row.used_traffic, row.online_at) for row in user_rows}
139+
140+
assert user_totals[user_one_id][0] > user_totals[user_two_id][0]
141+
assert all(total > 0 for total, _ in user_totals.values())
142+
assert all(online_at is not None for _, online_at in user_totals.values())
143+
144+
admin_total = await session.execute(select(Admin.used_traffic).where(Admin.id == admin_id))
145+
admin_used = admin_total.scalar_one()
146+
assert admin_used == sum(total for total, _ in user_totals.values())
147+
148+
node_usage_rows = await session.execute(
149+
select(NodeUserUsage.node_id, NodeUserUsage.user_id, NodeUserUsage.used_traffic)
150+
)
151+
node_usage_records = node_usage_rows.all()
152+
usage_pairs = {(row.node_id, row.user_id) for row in node_usage_records}
153+
assert usage_pairs == {
154+
(node_one_id, user_one_id),
155+
(node_one_id, user_two_id),
156+
(node_two_id, user_one_id),
157+
}
158+
159+
aggregated_usage = defaultdict(int)
160+
for record in node_usage_records:
161+
assert record.used_traffic > 0
162+
aggregated_usage[record.user_id] += record.used_traffic
163+
164+
for user_id, (total_usage, _) in user_totals.items():
165+
assert aggregated_usage[user_id] == total_usage
68166

69167

70168
@pytest.mark.asyncio
71-
async def test_record_user_usages_returns_when_no_usage(monkeypatch: pytest.MonkeyPatch):
72-
nodes = [(1, DummyNode(1))]
169+
async def test_record_user_usages_returns_when_no_usage(monkeypatch: pytest.MonkeyPatch, session_factory):
170+
async with session_factory() as session:
171+
admin = Admin(username="admin", hashed_password="secret")
172+
session.add(admin)
173+
await session.flush()
174+
admin_id = admin.id
175+
176+
user = User(username="user", admin_id=admin_id, proxy_settings={})
177+
node = Node(
178+
name="node-1",
179+
address="10.0.0.1",
180+
port=1000,
181+
server_ca="ca1",
182+
api_key="key1",
183+
core_config_id=None,
184+
)
185+
session.add_all([user, node])
186+
await session.flush()
187+
user_id, node_id = user.id, node.id
188+
await session.commit()
189+
190+
nodes = [(node_id, DummyNode(node_id))]
73191
monkeypatch.setattr(record_usages.node_manager, "get_healthy_nodes", AsyncMock(return_value=nodes))
74192

75193
async def fake_get_users_stats(_: DummyNode):
76194
return []
77195

78196
monkeypatch.setattr(record_usages, "get_users_stats", fake_get_users_stats)
197+
monkeypatch.setattr(record_usages, "DISABLE_RECORDING_NODE_USAGE", False)
79198

80-
safe_execute_mock = AsyncMock()
81-
monkeypatch.setattr(record_usages, "safe_execute", safe_execute_mock)
82-
83-
record_user_stats_mock = AsyncMock()
84-
monkeypatch.setattr(record_usages, "record_user_stats", record_user_stats_mock)
199+
await record_usages.record_user_usages()
85200

86-
calculate_admin_usage_mock = AsyncMock()
87-
monkeypatch.setattr(record_usages, "calculate_admin_usage", calculate_admin_usage_mock)
201+
async with session_factory() as session:
202+
user_total = await session.execute(select(User.used_traffic).where(User.id == user_id))
203+
assert user_total.scalar_one() == 0
88204

89-
await record_usages.record_user_usages()
205+
admin_total = await session.execute(select(Admin.used_traffic).where(Admin.id == admin_id))
206+
assert admin_total.scalar_one() == 0
90207

91-
safe_execute_mock.assert_not_awaited()
92-
record_user_stats_mock.assert_not_awaited()
93-
calculate_admin_usage_mock.assert_not_awaited()
208+
node_user_usage = await session.execute(select(NodeUserUsage.id))
209+
assert node_user_usage.first() is None
94210

95211

96212
@pytest.mark.asyncio
97-
async def test_record_node_usages_updates_totals(monkeypatch: pytest.MonkeyPatch):
98-
nodes = [(1, DummyNode(1)), (2, DummyNode(2))]
213+
async def test_record_node_usages_updates_totals(monkeypatch: pytest.MonkeyPatch, session_factory):
214+
async with session_factory() as session:
215+
node_one = Node(
216+
name="node-1",
217+
address="10.0.0.1",
218+
port=1000,
219+
server_ca="ca1",
220+
api_key="key1",
221+
core_config_id=None,
222+
)
223+
node_two = Node(
224+
name="node-2",
225+
address="10.0.0.2",
226+
port=1001,
227+
server_ca="ca2",
228+
api_key="key2",
229+
core_config_id=None,
230+
)
231+
system = System(uplink=0, downlink=0)
232+
session.add_all([node_one, node_two, system])
233+
await session.flush()
234+
node_one_id, node_two_id, system_id = node_one.id, node_two.id, system.id
235+
await session.commit()
236+
237+
nodes = [(node_one_id, DummyNode(node_one_id)), (node_two_id, DummyNode(node_two_id))]
99238
monkeypatch.setattr(record_usages.node_manager, "get_healthy_nodes", AsyncMock(return_value=nodes))
100239

101240
stats_map = {
102-
1: [{"up": 10, "down": 4}, {"up": 0, "down": 3}],
103-
2: [{"up": 1, "down": 1}],
241+
node_one_id: [{"up": 10, "down": 4}, {"up": 0, "down": 3}],
242+
node_two_id: [{"up": 1, "down": 1}],
104243
}
105244

106245
async def fake_get_outbounds_stats(node: DummyNode):
107246
return stats_map[node.node_id]
108247

109248
monkeypatch.setattr(record_usages, "get_outbounds_stats", fake_get_outbounds_stats)
110-
111-
safe_execute_mock = AsyncMock()
112-
monkeypatch.setattr(record_usages, "safe_execute", safe_execute_mock)
113-
114-
record_node_stats_mock = AsyncMock()
115-
monkeypatch.setattr(record_usages, "record_node_stats", record_node_stats_mock)
116249
monkeypatch.setattr(record_usages, "DISABLE_RECORDING_NODE_USAGE", False)
117250

118251
await record_usages.record_node_usages()
119252

120-
assert safe_execute_mock.await_count == 2
121-
node_call = safe_execute_mock.await_args_list[0]
122-
assert node_call.args[1] == [
123-
{"node_id": 1, "up": 10, "down": 7},
124-
{"node_id": 2, "up": 1, "down": 1},
125-
]
253+
async with session_factory() as session:
254+
nodes_result = await session.execute(select(Node.id, Node.uplink, Node.downlink))
255+
node_totals = {row.id: (row.uplink, row.downlink) for row in nodes_result.all()}
256+
assert node_totals[node_one_id][0] > node_totals[node_two_id][0]
257+
assert node_totals[node_two_id][1] > 0
126258

127-
system_call = safe_execute_mock.await_args_list[1]
128-
assert len(system_call.args) == 1 # system totals baked into the statement
259+
node_usage_rows = await session.execute(
260+
select(NodeUsage.node_id, NodeUsage.uplink, NodeUsage.downlink)
261+
)
262+
node_usage_totals = {row.node_id: (row.uplink, row.downlink) for row in node_usage_rows.all()}
263+
assert set(node_usage_totals.keys()) == {node_one_id, node_two_id}
264+
assert node_usage_totals == node_totals
129265

130-
expected_node_calls = [
131-
call(stats_map[1], 1),
132-
call(stats_map[2], 2),
133-
]
134-
assert record_node_stats_mock.await_args_list == expected_node_calls
266+
system_totals = await session.execute(select(System.uplink, System.downlink).where(System.id == system_id))
267+
system_row = system_totals.one()
268+
assert system_row.uplink == sum(values[0] for values in node_totals.values())
269+
assert system_row.downlink == sum(values[1] for values in node_totals.values())
135270

136271

137272
@pytest.mark.asyncio
138-
async def test_record_node_usages_returns_when_totals_zero(monkeypatch: pytest.MonkeyPatch):
139-
nodes = [(1, DummyNode(1))]
273+
async def test_record_node_usages_returns_when_totals_zero(monkeypatch: pytest.MonkeyPatch, session_factory):
274+
async with session_factory() as session:
275+
node = Node(
276+
name="node-1",
277+
address="10.0.0.1",
278+
port=1000,
279+
server_ca="ca1",
280+
api_key="key1",
281+
core_config_id=None,
282+
)
283+
system = System(uplink=0, downlink=0)
284+
session.add_all([node, system])
285+
await session.flush()
286+
node_id, system_id = node.id, system.id
287+
await session.commit()
288+
289+
nodes = [(node_id, DummyNode(node_id))]
140290
monkeypatch.setattr(record_usages.node_manager, "get_healthy_nodes", AsyncMock(return_value=nodes))
141291

142292
async def fake_get_outbounds_stats(_: DummyNode):
143293
return [{"up": 0, "down": 0}]
144294

145295
monkeypatch.setattr(record_usages, "get_outbounds_stats", fake_get_outbounds_stats)
146296

147-
safe_execute_mock = AsyncMock()
148-
monkeypatch.setattr(record_usages, "safe_execute", safe_execute_mock)
297+
await record_usages.record_node_usages()
149298

150-
record_node_stats_mock = AsyncMock()
151-
monkeypatch.setattr(record_usages, "record_node_stats", record_node_stats_mock)
299+
async with session_factory() as session:
300+
node_row = await session.execute(select(Node.uplink, Node.downlink).where(Node.id == node_id))
301+
node_totals = node_row.one()
302+
assert node_totals.uplink == 0
303+
assert node_totals.downlink == 0
152304

153-
await record_usages.record_node_usages()
305+
system_row = await session.execute(select(System.uplink, System.downlink).where(System.id == system_id))
306+
system_totals = system_row.one()
307+
assert system_totals.uplink == 0
308+
assert system_totals.downlink == 0
154309

155-
safe_execute_mock.assert_not_awaited()
156-
record_node_stats_mock.assert_not_awaited()
310+
node_usage_rows = await session.execute(select(NodeUsage.id))
311+
assert node_usage_rows.first() is None

0 commit comments

Comments
 (0)