Skip to content
This repository has been archived by the owner on Nov 26, 2022. It is now read-only.

Commit

Permalink
feat: bulk update #158
Browse files Browse the repository at this point in the history
  • Loading branch information
Ehco1996 committed Apr 9, 2021
1 parent d153718 commit 2ec1723
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 103 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ keywords = ["shadowsocsk", "asyncio", "grpc"]
license = "GPLV3"
name = "aioshadowsocks"
readme = "README.md"
version = "0.1.8"

packages = [
{include = "shadowsocks"},
Expand Down
60 changes: 29 additions & 31 deletions shadowsocks/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ async def logging_grpc_request(event: RecvRequest) -> None:

class App:
def __init__(self) -> None:
    """Create the application and run one-time initialisation.

    Config/logger setup now lives inside ``_prepare`` (see the calls at
    the top of that method); calling them here as well would initialise
    everything twice.
    """
    # Guard flag read by _prepare so repeated calls are no-ops.
    self._prepared = False
    self._prepare()

def _init_config(self):
self.config = {
Expand Down Expand Up @@ -71,10 +70,13 @@ def _init_logger(self):
"DEBUG": 10,
}
level = log_levels[self.log_level.upper()]
logging.basicConfig(
format="[%(levelname)s]%(asctime)s %(funcName)s line:%(lineno)d %(message)s",
level=level,
)
if level == 20:
format = "[%(levelname)s]%(asctime)s --- %(message)s"
else:
format = (
"[%(levelname)s]%(asctime)s %(funcName)s line:%(lineno)d --- %(message)s"
)
logging.basicConfig(format=format, level=level)

def _init_memory_db(self):

def _prepare(self):
    """One-time setup: event loop, config, logging, DB, sentry, proxyman,
    and POSIX signal handlers. Idempotent via ``self._prepared``."""
    if self._prepared:
        return
    self.loop = asyncio.get_event_loop()
    self._init_config()
    self._init_logger()
    self._init_memory_db()
    self._init_sentry()
    self.proxyman = ProxyMan(
        self.use_json, self.sync_time, self.listen_host, self.api_endpoint
    )

    # _shutdown is a coroutine, so each signal schedules it as a task.
    # NOTE: a stale duplicate `add_signal_handler(SIGTERM, self._shutdown)`
    # (plain callback form) was removed — it conflicted with this loop.
    signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
    for s in signals:
        self.loop.add_signal_handler(
            s, lambda s=s: asyncio.create_task(self._shutdown())
        )

    self._prepared = True

async def _shutdown(self):
    """Cancel outstanding tasks, close every server, then stop the loop."""
    logging.info("正在关闭所有ss server")
    # Cancel everything except the task running this coroutine.
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for task in tasks:
        task.cancel()
    self.proxyman.close_server()
    if self.grpc_server:
        self.grpc_server.close()
        logging.info("grpc server closed!")
    if self.metrics_server:
        # Await the stop directly; additionally scheduling it with
        # create_task (stale pre-refactor line) would stop it twice.
        await self.metrics_server.stop()
        logging.info("metrics server closed!")
    self.loop.stop()

def _run_loop(self):
    """Block on the event loop; shut down cleanly on Ctrl-C."""
    try:
        self.loop.run_forever()
    except KeyboardInterrupt:
        # _shutdown is a coroutine; a bare call would only create an
        # un-awaited coroutine object. Drive it to completion.
        self.loop.run_until_complete(self._shutdown())

async def _start_grpc_server(self):

self.grpc_server = Server([AioShadowsocksServicer()], loop=self.loop)
Expand All @@ -137,26 +142,19 @@ async def _start_metrics_server(self):
f"Start Metrics Server At: http://0.0.0.0:{self.metrics_port}/metrics"
)

async def _start_ss_server(self):
    """Start all configured long-running services on the current loop.

    Optional services (metrics, grpc) come up first, then the ss proxy
    servers themselves via the proxyman.
    """
    if self.metrics_port:
        await self._start_metrics_server()

    if self.grpc_host and self.grpc_port:
        await self._start_grpc_server()

    await self.proxyman.start_and_check_ss_server()

def run_ss_server(self):
    """Public entry point: schedule startup and block on the event loop."""
    self.loop.create_task(self._start_ss_server())
    self.loop.run_forever()

def get_user(self, user_id):
c = SSClient(f"{self.grpc_host}:{self.grpc_port}")
Expand Down
2 changes: 1 addition & 1 deletion shadowsocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def handle_data_received(self, data):
data = self.cipher.decrypt(data)
except Exception as e:
logging.warning(
f"decrypt data error:{e} remote:{self._peername},type:{self._transport_protocol_human} closing..."
f"decrypt data error:{e} remote:{self._peername},type:{self._transport_protocol_human}"
)
self.close()
return
Expand Down
2 changes: 1 addition & 1 deletion shadowsocks/mdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_or_create(cls, **kw):
def update_from_dict(self, data, ignore_unknown=False, use_whitelist=False):
    """Apply *data* to this instance in memory only — nothing is written to
    the database; the caller must ``save()`` explicitly."""
    cls = type(self)
    # Forward use_whitelist so whitelist filtering actually takes effect;
    # the stale variant dropped the argument and silently ignored it.
    clean_data = cls._filter_attrs(data, use_whitelist)
    return shortcuts.update_model_from_dict(self, clean_data, ignore_unknown)

def to_dict(self, **kw):
Expand Down
111 changes: 71 additions & 40 deletions shadowsocks/mdb/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from typing import List

import peewee as pw
from cryptography.exceptions import InvalidTag
Expand All @@ -21,7 +22,6 @@ class User(BaseModel):
method = pw.CharField()
password = pw.CharField(unique=True)
enable = pw.BooleanField(default=True)
speed_limit = pw.IntegerField(default=0)
access_order = pw.BigIntegerField(
index=True, default=0
) # NOTE find_access_user order
Expand All @@ -32,6 +32,9 @@ class User(BaseModel):
upload_traffic = pw.BigIntegerField(default=0)
download_traffic = pw.BigIntegerField(default=0)

def __str__(self):
    """Return the compact tag ``[<user_id>-<access_order>]``."""
    uid = self.user_id
    order = self.access_order
    return f"[{uid}-{order}]"

@classmethod
def _create_or_update_user_from_data(cls, data):
user_id = data.pop("user_id")
Expand All @@ -44,24 +47,12 @@ def _create_or_update_user_from_data(cls, data):

@classmethod
def list_by_port(cls, port):
    """Return all users bound to *port*, highest ``access_order`` first
    (the most recently successful user is tried first by callers)."""
    # Select full rows; the stale projected-fields variant was removed.
    return cls.select().where(cls.port == port).order_by(cls.access_order.desc())

@classmethod
@db.atomic("EXCLUSIVE")
def create_or_update_by_user_data_list(cls, user_data_list):
if cls.select().count() == 0:
if not cls.select().first():
# bulk create
users = [
cls(
Expand All @@ -75,35 +66,52 @@ def create_or_update_by_user_data_list(cls, user_data_list):
]
cls.bulk_create(users, batch_size=len(users))
else:
user_ids = []
db_user_dict = {
u.user_id: u
for u in cls.select(
cls.user_id, cls.enable, cls.method, cls.password, cls.port
)
}
enable_user_ids = []
need_update_or_create_users = []
for user_data in user_data_list:
user_ids.append(user_data["user_id"])
user_id = user_data["user_id"]
enable_user_ids.append(user_id)
db_user = db_user_dict[user_id]
# 找到配置变化了的用户
if (
not db_user
or db_user.port != user_data["port"]
or db_user.enable != user_data["enable"]
or db_user.method != user_data["method"]
or db_user.password != user_data["password"]
):
need_update_or_create_users.append(user_data)
for user_data in need_update_or_create_users:
cls._create_or_update_user_from_data(user_data)
cnt = cls.delete().where(cls.user_id.not_in(user_ids)).execute()
cnt and logging.info(f"delete out of traffic user cnt: {cnt}")

@db.atomic("EXCLUSIVE")
def record_ip(self, peername):
if not peername:
return
self.ip_list.add(peername[0])
User.update(ip_list=self.ip_list, need_sync=True).where(
User.user_id == self.user_id
).execute()
sync_msg = "sync users: enable_user_cnt: {} updated_user_cnt: {},deleted_user_cnt:{}".format(
len(enable_user_ids),
len(need_update_or_create_users),
cls.delete().where(cls.user_id.not_in(enable_user_ids)).execute(),
)
logging.info(sync_msg)

@classmethod
@db.atomic("EXCLUSIVE")
def get_and_reset_need_sync_user_metrics(cls) -> List[User]:
    """Snapshot every user flagged ``need_sync``, zero their metric
    counters atomically, and return the snapshot for upload.

    Runs inside an EXCLUSIVE transaction so no traffic update can land
    between the read and the reset.
    """
    fields = [
        User.user_id,
        User.ip_list,
        User.tcp_conn_num,
        User.upload_traffic,
        User.download_traffic,
    ]
    # peewee needs `== True` (not `is True`) to build the SQL predicate.
    users = list(User.select(*fields).where(User.need_sync == True))
    empty_set = set()  # fixed typo: was `empyt_set`
    User.update(
        ip_list=empty_set, upload_traffic=0, download_traffic=0, need_sync=False
    ).where(User.need_sync == True).execute()
    return users

@classmethod
@FIND_ACCESS_USER_TIME.time()
Expand All @@ -124,5 +132,28 @@ def find_access_user(cls, port, method, ts_protocol, first_data) -> User:
if access_user:
# NOTE 记下成功访问的用户,下次优先找到他
access_user.access_order += 1
access_user.save()
access_user.save(only=[cls.access_order])
return access_user

@db.atomic("EXCLUSIVE")
def record_ip(self, peername):
    """Add the client address from *peername* to this user's ip_list and
    flag the row for the next metrics sync. No-op for empty peernames."""
    if not peername:
        return
    client_ip = peername[0]
    self.ip_list.add(client_ip)
    query = User.update(ip_list=self.ip_list, need_sync=True).where(
        User.user_id == self.user_id
    )
    query.execute()

@db.atomic("EXCLUSIVE")
def record_traffic(self, used_u, used_d):
    """Accumulate upload/download byte counters for this user and flag
    the row for the next metrics sync."""
    changes = {
        "download_traffic": User.download_traffic + used_d,
        "upload_traffic": User.upload_traffic + used_u,
        "need_sync": True,
    }
    User.update(**changes).where(User.user_id == self.user_id).execute()

@db.atomic("EXCLUSIVE")
def incr_tcp_conn_num(self, num):
    """Adjust this user's live TCP connection counter by *num*
    (presumably negative on disconnect — confirm with callers) and flag
    the row for sync."""
    query = User.update(
        tcp_conn_num=User.tcp_conn_num + num, need_sync=True
    ).where(User.user_id == self.user_id)
    query.execute()
42 changes: 13 additions & 29 deletions shadowsocks/proxyman.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import httpx

from shadowsocks.core import LocalTCP, LocalUDP
from shadowsocks.mdb.models import User, db
from shadowsocks.mdb.models import User


class ProxyMan:
Expand Down Expand Up @@ -45,41 +45,25 @@ async def get_user_from_remote(url):

@staticmethod
async def flush_metrics_to_remote(url):
    """POST accumulated per-user metrics to *url* as ``{"data": [...]}``.

    Snapshot-and-reset is delegated to
    ``User.get_and_reset_need_sync_user_metrics`` (atomic in the model
    layer); the stale inline ``db.atomic`` duplicate was removed.
    """
    data = [
        {
            "user_id": user.user_id,
            "ip_list": list(user.ip_list),
            "tcp_conn_num": user.tcp_conn_num,
            "upload_traffic": user.upload_traffic,
            "download_traffic": user.download_traffic,
        }
        for user in User.get_and_reset_need_sync_user_metrics()
    ]
    async with httpx.AsyncClient() as client:
        await client.post(url, json={"data": data})

async def sync_from_remote_cron(self):
    """Periodic job: push metrics, then pull user config from the API.

    Best-effort: a failed round logs a warning instead of killing the
    cron schedule. The stale unguarded duplicate calls (which would have
    synced twice per tick) were removed.
    """
    try:
        await self.flush_metrics_to_remote(self.api_endpoint)
        await self.get_user_from_remote(self.api_endpoint)
    except Exception as e:
        logging.warning(f"sync user from remote error {e}")

async def sync_from_json_cron(self):
    """Periodic job: reload user config from ``userconfigs.json``.

    Best-effort like the remote cron; the stale unguarded duplicate call
    was removed.
    """
    try:
        self.create_or_update_from_json("userconfigs.json")
    except Exception as e:
        logging.warning(f"sync user from json error {e}")

def get_server_by_port(self, port):
    """Return the running server registered for *port*, or None."""
    running = self.__running_servers__
    return running.get(port)
Expand Down

0 comments on commit 2ec1723

Please sign in to comment.