Skip to content

Commit 4a832cb

Browse files
committed
refactor(subscription): improve IP generation logic and update tests for dynamic allocation
1 parent 68c50d5 commit 4a832cb

File tree

3 files changed

+106
-65
lines changed

3 files changed

+106
-65
lines changed

app/subscription/base.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
import hashlib
3+
import ipaddress
34
import json
45
import re
56
from enum import Enum
@@ -193,39 +194,59 @@ def _get_wireguard_peer_ips(settings: dict, inbound: SubscriptionInboundData) ->
193194
return peer_ips
194195

195196
user_id = settings.get("_user_id")
196-
if not user_id:
197+
if user_id is None:
197198
return []
198199

199200
local_addresses = inbound.wireguard_local_address or []
200-
if not local_addresses:
201-
return []
202-
203-
import ipaddress
204-
205-
generated_ips = []
206-
for addr in local_addresses:
207-
try:
208-
network = ipaddress.ip_network(addr, strict=False)
209-
if network.version == 4:
210-
network_addr = int(network.network_address)
211-
usable_hosts = network.num_addresses - 2
212-
if usable_hosts <= 0:
213-
continue
214-
offset = (user_id % usable_hosts) + 1
215-
ip = ipaddress.IPv4Address(network_addr + offset)
216-
generated_ips.append(f"{ip}/32")
217-
elif network.version == 6:
218-
network_addr = int(network.network_address)
219-
usable_hosts = network.num_addresses - 2
220-
if usable_hosts <= 0:
221-
continue
222-
offset = (user_id % usable_hosts) + 1
223-
ip = ipaddress.IPv6Address(network_addr + offset)
224-
generated_ips.append(f"{ip}/128")
225-
except (ValueError, ipaddress.AddressValueError):
226-
continue
227-
228-
return generated_ips
201+
if local_addresses:
202+
generated_ips = []
203+
for addr in local_addresses:
204+
try:
205+
network = ipaddress.ip_network(addr, strict=False)
206+
if network.version == 4:
207+
network_addr = int(network.network_address)
208+
usable_hosts = network.num_addresses - 2
209+
if usable_hosts <= 0:
210+
continue
211+
offset = (user_id % usable_hosts) + 1
212+
ip = ipaddress.IPv4Address(network_addr + offset)
213+
generated_ips.append(f"{ip}/32")
214+
elif network.version == 6:
215+
network_addr = int(network.network_address)
216+
usable_hosts = network.num_addresses - 2
217+
if usable_hosts <= 0:
218+
continue
219+
offset = (user_id % usable_hosts) + 1
220+
ip = ipaddress.IPv6Address(network_addr + offset)
221+
generated_ips.append(f"{ip}/128")
222+
except (ValueError, ipaddress.AddressValueError):
223+
continue
224+
225+
if generated_ips:
226+
return generated_ips
227+
228+
# Fallback to global pool based on user_id (deterministic)
229+
# Use 10.0.0.0/8 pool, skipping reserved IPs
230+
global_pool = ipaddress.ip_network("10.0.0.0/8")
231+
reserved = {ipaddress.ip_address("10.0.0.0"), ipaddress.ip_address("10.0.0.1")}
232+
233+
# Calculate a deterministic IP for this user
234+
# Use user_id to select an IP from the pool
235+
start = int(global_pool.network_address)
236+
end = int(global_pool.broadcast_address)
237+
238+
# Skip reserved IPs and broadcast
239+
offset = user_id
240+
while True:
241+
candidate_int = start + 2 + (offset % (end - start - 2))
242+
if candidate_int > end:
243+
break
244+
candidate = ipaddress.ip_address(candidate_int)
245+
if candidate not in reserved and candidate != global_pool.broadcast_address:
246+
return [f"{candidate}/32"]
247+
offset += 1
248+
249+
return []
229250

230251
def _build_wireguard_components(
231252
self, remark: str, address: str, inbound: SubscriptionInboundData, settings: dict

app/utils/wireguard.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from app.db.crud.user import get_users_with_proxy_settings
1111
from app.models.proxy import ProxyTable
1212
from app.utils.crypto import generate_wireguard_keypair, get_wireguard_public_key
13-
from app.utils.ip_pool import allocate_from_global_pool, validate_peer_ips_globally
13+
from app.utils.ip_pool import validate_peer_ips_globally
1414

1515

1616
async def get_wireguard_tags(tags: Iterable[str]) -> list[str]:
@@ -72,8 +72,6 @@ async def prepare_wireguard_proxy_settings(
7272

7373
if peer_ips:
7474
await validate_peer_ips_globally(db, peer_ips, exclude_user_id=exclude_user_id)
75-
else:
76-
peer_ips = []
7775

7876
proxy_settings.wireguard.peer_ips = peer_ips
7977
return proxy_settings

tests/api/test_user.py

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -667,15 +667,11 @@ def test_user_can_be_assigned_to_multiple_wireguard_interfaces(access_token):
667667
# Get the auto-allocated peer IPs
668668
peer_ips = user["proxy_settings"]["wireguard"]["peer_ips"]
669669

670-
# Verify that peer_ips is a non-empty list (auto-allocated from global pool)
670+
# peer_ips should be empty in user settings (dynamically generated during subscription)
671671
assert isinstance(peer_ips, list)
672-
assert len(peer_ips) > 0
672+
assert len(peer_ips) == 0
673673

674-
# All peer IPs should be /32 (single hosts)
675-
for peer_ip in peer_ips:
676-
assert peer_ip.endswith("/32")
677-
678-
# Verify that the same peer_ips are used for all WireGuard interfaces
674+
# Verify that peer IPs are dynamically generated during subscription
679675
links_response = client.get(f"{user['subscription_url']}/links")
680676
assert links_response.status_code == status.HTTP_200_OK
681677

@@ -686,37 +682,48 @@ def test_user_can_be_assigned_to_multiple_wireguard_interfaces(access_token):
686682
parsed = urlsplit(line.strip())
687683
links_by_endpoint[f"{parsed.hostname}:{parsed.port}"] = parse_qs(parsed.query)
688684

689-
# Both endpoints should have the same peer IPs
690-
expected_address = ",".join(peer_ips)
691-
assert links_by_endpoint[f"{first_endpoint}:51820"]["address"] == [expected_address]
692-
assert links_by_endpoint[f"{second_endpoint}:51821"]["address"] == [expected_address]
685+
# Both endpoints should have peer IPs generated from interface addresses
686+
# User ID 17 with first_interface address "10.30.10.1/24" should get "10.30.10.18/32"
687+
# User ID 17 with second_interface address "10.40.10.1/24" should get "10.40.10.18/32"
688+
first_address = links_by_endpoint[f"{first_endpoint}:51820"]["address"][0]
689+
second_address = links_by_endpoint[f"{second_endpoint}:51821"]["address"][0]
690+
691+
# Verify IPs are from correct ranges
692+
assert first_address.startswith("10.30.10.")
693+
assert second_address.startswith("10.40.10.")
694+
assert first_address.endswith("/32")
695+
assert second_address.endswith("/32")
693696

694697
# Verify WireGuard subscription contains the peer IPs
695698
wireguard_response = client.get(f"{user['subscription_url']}/wireguard")
696699
assert wireguard_response.status_code == status.HTTP_200_OK
697700
config_bodies = extract_wireguard_config_bodies(wireguard_response)
698701
assert len(config_bodies) == 2
699702

700-
expected_address = f"Address = {', '.join(peer_ips)}"
703+
# Verify each config has correct Address from respective interface
704+
for body in config_bodies:
705+
# Should have Address from one of the interfaces
706+
assert "Address = 10.30.10." in body or "Address = 10.40.10." in body
707+
assert "/32" in body
708+
701709
expected_endpoints = {f"Endpoint = {first_endpoint}:51820", f"Endpoint = {second_endpoint}:51821"}
702710
actual_endpoints = set()
703711

704712
for body in config_bodies:
705-
assert expected_address in body
706713
for endpoint in expected_endpoints:
707714
if endpoint in body:
708715
actual_endpoints.add(endpoint)
709716

710717
assert actual_endpoints == expected_endpoints
711718

712-
# Test no-op update preserves peer_ips
719+
# Test no-op update preserves empty peer_ips
713720
update_response = client.put(
714721
f"/api/user/{user['username']}",
715722
headers=auth_headers(access_token),
716723
json={"note": "keep existing wireguard allocations"},
717724
)
718725
assert update_response.status_code == status.HTTP_200_OK
719-
assert update_response.json()["proxy_settings"]["wireguard"]["peer_ips"] == peer_ips
726+
assert update_response.json()["proxy_settings"]["wireguard"]["peer_ips"] == []
720727
finally:
721728
delete_user(access_token, user["username"])
722729
delete_group(access_token, group["id"])
@@ -1579,16 +1586,29 @@ def test_wireguard_peer_ip_global_pool_and_validation(access_token):
15791586
assert response.status_code == status.HTTP_400_BAD_REQUEST
15801587
assert "reserved for the server" in response.json()["detail"]
15811588

1582-
# Test 2: Create user without specifying peer IPs - should get IP from global pool
1589+
# Test 2: Create user without specifying peer IPs - should get IP dynamically during subscription
15831590
user1 = create_user(
15841591
access_token,
15851592
group_ids=[group["id"]],
15861593
payload={"username": unique_name("wg_auto_ip_user1")},
15871594
)
1588-
assert user1["proxy_settings"]["wireguard"]["peer_ips"]
1589-
peer_ip1 = user1["proxy_settings"]["wireguard"]["peer_ips"][0]
1595+
# peer_ips should be empty in user settings
1596+
assert user1["proxy_settings"]["wireguard"]["peer_ips"] == []
1597+
1598+
# But subscription should work with dynamically allocated IP from global pool
1599+
links_response = client.get(f"{user1['subscription_url']}/links")
1600+
assert links_response.status_code == status.HTTP_200_OK
1601+
1602+
# Should have a wireguard link with an IP from global pool (10.0.0.0/8)
1603+
link = links_response.text.strip()
1604+
assert link.startswith("wireguard://")
1605+
parsed = urlsplit(link)
1606+
query = parse_qs(parsed.query)
1607+
peer_ip1 = query.get("address", [""])[0]
1608+
# Should be an IP from 10.0.0.0/8 pool
15901609
assert peer_ip1.startswith("10.")
1591-
assert peer_ip1 != "10.0.0.1/32"
1610+
assert peer_ip1.endswith("/32")
1611+
assert peer_ip1 != "10.0.0.1/32" # Should not be the reserved server IP
15921612

15931613
# Test 3: Try to create another user with the same IP - should fail
15941614
response = client.post(
@@ -1604,29 +1624,31 @@ def test_wireguard_peer_ip_global_pool_and_validation(access_token):
16041624
"group_ids": [group["id"]],
16051625
},
16061626
)
1627+
# Since peer_ips are dynamically generated, manually specifying a duplicate should be rejected
16071628
assert response.status_code == status.HTTP_400_BAD_REQUEST
16081629
assert "already in use" in response.json()["detail"]
16091630

1610-
# Test 4: Create another user without specifying peer IPs - should get different IP
1631+
# Test 4: Create another user without specifying peer IPs - should get different IP dynamically
16111632
user2 = create_user(
16121633
access_token,
16131634
group_ids=[group["id"]],
16141635
payload={"username": unique_name("wg_auto_ip_user2")},
16151636
)
1616-
assert user2["proxy_settings"]["wireguard"]["peer_ips"]
1617-
peer_ip2 = user2["proxy_settings"]["wireguard"]["peer_ips"][0]
1637+
# peer_ips should be empty in user settings
1638+
assert user2["proxy_settings"]["wireguard"]["peer_ips"] == []
1639+
1640+
# Get dynamically allocated IP from subscription
1641+
links_response2 = client.get(f"{user2['subscription_url']}/links")
1642+
assert links_response2.status_code == status.HTTP_200_OK
1643+
link2 = links_response2.text.strip()
1644+
assert link2.startswith("wireguard://")
1645+
parsed2 = urlsplit(link2)
1646+
query2 = parse_qs(parsed2.query)
1647+
peer_ip2 = query2.get("address", [""])[0]
16181648
assert peer_ip2.startswith("10.")
1649+
assert peer_ip2.endswith("/32")
1650+
# Different users should get different IPs
16191651
assert peer_ip2 != peer_ip1
1620-
assert peer_ip2 != "10.0.0.1/32"
1621-
1622-
# Test 5: Verify subscription links work with auto-allocated IPs
1623-
links_response = client.get(f"{user1['subscription_url']}/links")
1624-
assert links_response.status_code == status.HTTP_200_OK
1625-
link = links_response.text.strip()
1626-
assert link.startswith("wireguard://")
1627-
parsed = urlsplit(link)
1628-
query = parse_qs(parsed.query)
1629-
assert query["address"] == [peer_ip1]
16301652

16311653
finally:
16321654
if user1:

0 commit comments

Comments
 (0)