Skip to content

Commit 138eae8

Browse files
committed
feat: introduce ProxyProtocol enum and update related classes to utilize it
1 parent ca4ac7e commit 138eae8

6 files changed

Lines changed: 112 additions & 33 deletions

File tree

app/core/abstract_core.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from abc import ABC, abstractmethod
22

3+
from app.models.protocol import ProxyProtocol
4+
35

46
class AbstractCore(ABC):
57
@abstractmethod
@@ -35,3 +37,8 @@ def from_json(cls, data: dict) -> "AbstractCore":
3537
@abstractmethod
3638
def inbounds(self) -> list[str]:
3739
raise NotImplementedError
40+
41+
@property
42+
@abstractmethod
43+
def protocols(self) -> frozenset[ProxyProtocol]:
44+
raise NotImplementedError

app/core/wireguard.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
import commentjson
1010

1111
from app.models.core import CoreType
12+
from app.models.protocol import ProxyProtocol
1213
from app.utils.crypto import get_wireguard_public_key, validate_wireguard_key
1314

15+
_WIREGUARD_PROTOCOLS = frozenset((ProxyProtocol.wireguard,))
16+
1417

1518
class WireGuardConfig(dict):
1619
def __init__(
@@ -115,6 +118,10 @@ def inbounds_by_tag(self) -> dict:
115118
def inbounds(self) -> list[str]:
116119
return self._inbounds
117120

121+
@property
122+
def protocols(self) -> frozenset[ProxyProtocol]:
123+
return _WIREGUARD_PROTOCOLS
124+
118125
def to_json(self) -> dict:
119126
return {
120127
"type": self.type,

app/core/xray.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,18 @@
99
import commentjson
1010

1111
from app.models.core import CoreType
12+
from app.models.protocol import ProxyProtocol
1213
from app.utils.crypto import get_cert_SANs, get_x25519_public_key
1314

1415

16+
def _protocols_from_inbounds_by_tag(inbounds_by_tag: dict[str, dict]) -> frozenset[ProxyProtocol]:
17+
return frozenset(
18+
protocol
19+
for inbound in inbounds_by_tag.values()
20+
if (protocol := ProxyProtocol.from_value(inbound["protocol"])) is not None
21+
)
22+
23+
1524
class XRayConfig(dict):
1625
def __init__(
1726
self,
@@ -45,6 +54,7 @@ def __init__(
4554
self._inbounds = []
4655
self._inbounds_by_tag = {}
4756
self._fallbacks_inbound = []
57+
self._protocols: frozenset[ProxyProtocol] = frozenset()
4858

4959
# Registery pattern for network handlers, making it easy to add support for new network types in the future
5060
self.network_handlers = {
@@ -360,6 +370,7 @@ def _resolve_inbounds(self):
360370
"""Resolve all inbounds and their settings."""
361371
for inbound in self["inbounds"]:
362372
self._read_inbound(inbound)
373+
self._protocols = _protocols_from_inbounds_by_tag(self._inbounds_by_tag)
363374

364375
def _read_inbound(self, inbound: dict):
365376
"""Read an inbound and its settings."""
@@ -471,6 +482,10 @@ def inbounds(self) -> list[str]:
471482
"""Get inbounds by tag."""
472483
return self._inbounds
473484

485+
@property
486+
def protocols(self) -> frozenset[ProxyProtocol]:
487+
return self._protocols
488+
474489
@property
475490
def type(self) -> str:
476491
return self._type
@@ -503,6 +518,7 @@ def from_json(cls, data: dict) -> "XRayConfig":
503518
instance._inbounds = data["inbounds"]
504519
if "inbounds_by_tag" in data:
505520
instance._inbounds_by_tag = data["inbounds_by_tag"]
521+
instance._protocols = _protocols_from_inbounds_by_tag(instance._inbounds_by_tag)
506522
return instance
507523

508524
def copy(self):

app/models/protocol.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from enum import IntEnum
2+
3+
4+
class ProxyProtocol(IntEnum):
5+
vmess = 1
6+
vless = 2
7+
trojan = 3
8+
shadowsocks = 4
9+
wireguard = 5
10+
hysteria = 6
11+
12+
@classmethod
13+
def from_value(cls, value: str) -> "ProxyProtocol" | None:
14+
try:
15+
return _PROXY_PROTOCOL_BY_NAME[value]
16+
except KeyError:
17+
return None
18+
19+
20+
_PROXY_PROTOCOL_BY_NAME = {protocol.name: protocol for protocol in ProxyProtocol}

app/node/user.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import inspect
2-
31
from PasarGuardNodeBridge import create_proxy, create_user
42
from PasarGuardNodeBridge.common.service_pb2 import User as ProtoUser
53
from sqlalchemy import and_, func, select
64

75
from app.db import AsyncSession
86
from app.db.models import Group, ProxyInbound, User, UserStatus, inbounds_groups_association, users_groups_association
7+
from app.models.protocol import ProxyProtocol
98

10-
_CREATE_PROXY_PARAMS = set(inspect.signature(create_proxy).parameters)
9+
_ALL_PROXY_PROTOCOLS = frozenset(ProxyProtocol)
1110

1211

1312
def _inbounds_from_loaded_groups(user: User) -> list[str] | None:
@@ -30,7 +29,7 @@ def _inbounds_from_loaded_groups(user: User) -> list[str] | None:
3029
return list(tags)
3130

3231

33-
async def serialize_user(user: User) -> ProtoUser:
32+
async def serialize_user(user: User, allowed_protocols: frozenset[ProxyProtocol] | None = None) -> ProtoUser:
3433
user_settings = user.proxy_settings
3534
inbounds = None
3635
status = user.__dict__.get("status")
@@ -42,38 +41,52 @@ async def serialize_user(user: User) -> ProtoUser:
4241
if inbounds is None:
4342
inbounds = await user.inbounds()
4443

45-
return _serialize_user_for_node(user.id, user.username, user_settings, inbounds)
46-
47-
48-
def _serialize_user_for_node(id: int, username: str, user_settings: dict, inbounds: list[str] = None) -> ProtoUser:
49-
vmess_settings = user_settings.get("vmess", {})
50-
vless_settings = user_settings.get("vless", {})
51-
if vless_settings.get("flow") == "xtls-rprx-vision-udp443":
52-
vless_settings["flow"] = "xtls-rprx-vision"
53-
trojan_settings = user_settings.get("trojan", {})
54-
shadowsocks_settings = user_settings.get("shadowsocks", {})
55-
wireguard_settings = user_settings.get("wireguard", {})
56-
hysteria_settings = user_settings.get("hysteria", {})
57-
proxy_kwargs = {
58-
"vmess_id": vmess_settings.get("id"),
59-
"vless_id": vless_settings.get("id"),
60-
"vless_flow": vless_settings.get("flow"),
61-
"trojan_password": trojan_settings.get("password"),
62-
"shadowsocks_password": shadowsocks_settings.get("password"),
63-
"shadowsocks_method": shadowsocks_settings.get("method"),
64-
"wireguard_public_key": wireguard_settings.get("public_key"),
65-
"wireguard_peer_ips": wireguard_settings.get("peer_ips") or [],
66-
"hysteria_auth": hysteria_settings.get("auth"),
67-
}
44+
return _serialize_user_for_node(user.id, user.username, user_settings, inbounds, allowed_protocols)
45+
46+
47+
def _serialize_user_for_node(
48+
id: int,
49+
username: str,
50+
user_settings: dict,
51+
inbounds: list[str] = None,
52+
allowed_protocols: frozenset[ProxyProtocol] | None = None,
53+
) -> ProtoUser:
54+
allowed_protocols = allowed_protocols or _ALL_PROXY_PROTOCOLS
55+
56+
proxy_kwargs = {}
57+
if ProxyProtocol.vmess in allowed_protocols:
58+
proxy_kwargs["vmess_id"] = user_settings.get("vmess", {}).get("id")
59+
if ProxyProtocol.vless in allowed_protocols:
60+
vless_settings = dict(user_settings.get("vless", {}))
61+
if vless_settings.get("flow") == "xtls-rprx-vision-udp443":
62+
vless_settings["flow"] = "xtls-rprx-vision"
63+
proxy_kwargs["vless_id"] = vless_settings.get("id")
64+
proxy_kwargs["vless_flow"] = vless_settings.get("flow")
65+
if ProxyProtocol.trojan in allowed_protocols:
66+
proxy_kwargs["trojan_password"] = user_settings.get("trojan", {}).get("password")
67+
if ProxyProtocol.shadowsocks in allowed_protocols:
68+
shadowsocks_settings = user_settings.get("shadowsocks", {})
69+
proxy_kwargs["shadowsocks_password"] = shadowsocks_settings.get("password")
70+
proxy_kwargs["shadowsocks_method"] = shadowsocks_settings.get("method")
71+
if ProxyProtocol.wireguard in allowed_protocols:
72+
wireguard_settings = user_settings.get("wireguard", {})
73+
proxy_kwargs["wireguard_public_key"] = wireguard_settings.get("public_key")
74+
proxy_kwargs["wireguard_peer_ips"] = wireguard_settings.get("peer_ips") or []
75+
if ProxyProtocol.hysteria in allowed_protocols:
76+
proxy_kwargs["hysteria_auth"] = user_settings.get("hysteria", {}).get("auth")
6877

6978
return create_user(
7079
f"{id}.{username}",
71-
create_proxy(**{key: value for key, value in proxy_kwargs.items() if key in _CREATE_PROXY_PARAMS}),
80+
create_proxy(**proxy_kwargs),
7281
inbounds,
7382
)
7483

7584

76-
async def core_users(db: AsyncSession, inbound_tags: list[str] | set[str] | None = None):
85+
async def core_users(
86+
db: AsyncSession,
87+
inbound_tags: list[str] | set[str] | None = None,
88+
allowed_protocols: frozenset[ProxyProtocol] | None = None,
89+
):
7790
dialect = db.bind.dialect.name
7891
inbound_tags = list(dict.fromkeys(inbound_tags or []))
7992

@@ -117,11 +130,21 @@ async def core_users(db: AsyncSession, inbound_tags: list[str] | set[str] | None
117130
for row in results:
118131
inbound_tags = row.inbound_tags.split(",") if row.inbound_tags else []
119132
if inbound_tags:
120-
bridge_users.append(_serialize_user_for_node(row.id, row.username, row.proxy_settings, inbound_tags))
133+
bridge_users.append(
134+
_serialize_user_for_node(
135+
row.id,
136+
row.username,
137+
row.proxy_settings,
138+
inbound_tags,
139+
allowed_protocols,
140+
)
141+
)
121142
return bridge_users
122143

123144

124-
async def serialize_users_for_node(users: list[User]) -> list[ProtoUser]:
145+
async def serialize_users_for_node(
146+
users: list[User], allowed_protocols: frozenset[ProxyProtocol] | None = None
147+
) -> list[ProtoUser]:
125148
bridge_users: list = []
126149

127150
for user in users:
@@ -133,6 +156,8 @@ async def serialize_users_for_node(users: list[User]) -> list[ProtoUser]:
133156
else:
134157
inbounds_list = loaded_inbounds
135158

136-
bridge_users.append(_serialize_user_for_node(user.id, user.username, user.proxy_settings, inbounds_list))
159+
bridge_users.append(
160+
_serialize_user_for_node(user.id, user.username, user.proxy_settings, inbounds_list, allowed_protocols)
161+
)
137162

138163
return bridge_users

app/operation/node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,11 @@ async def _get_core_users_map(
209209
users_by_core[core_id] = []
210210
continue
211211

212-
users_by_core[core_id] = await core_users(db=db, inbound_tags=core.inbounds)
212+
users_by_core[core_id] = await core_users(
213+
db=db,
214+
inbound_tags=core.inbounds,
215+
allowed_protocols=core.protocols,
216+
)
213217

214218
return cores_by_id, users_by_core
215219

0 commit comments

Comments
 (0)