Skip to content
This repository has been archived by the owner on Nov 26, 2022. It is now read-only.

Commit

Permalink
opt core (#78)
Browse files Browse the repository at this point in the history
* opt core

* remove timeout mixin
  • Loading branch information
Ehco1996 committed Sep 10, 2020
1 parent 3e4cca4 commit 5113b42
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 135 deletions.
1 change: 1 addition & 0 deletions shadowsocks/cipherman.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def find_access_user_by_data(self, data):
self.access_user.save()

def _record_user_traffic(self, ut_data_len: int, dt_data_len: int):
# TODO 写db的地方挪到队列里去做
self.access_user and self.access_user.record_traffic(ut_data_len, dt_data_len)
NETWORK_TRANSMIT_BYTES.inc(ut_data_len + dt_data_len)

Expand Down
201 changes: 67 additions & 134 deletions shadowsocks/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,7 @@
from shadowsocks.utils import parse_header


class TimeoutMixin:
TIMEOUT = 10

def __init__(self):
self.loop = asyncio.get_running_loop()
self.timeout_handle = self.loop.call_later(self.TIMEOUT, self._timeout)

self._need_clean = False

def close(self):
raise NotImplementedError

def _timeout(self):
self._need_clean = True
self.close()

def keep_alive(self):
self.timeout_handle.cancel()
self.timeout_handle = self.loop.call_later(self.TIMEOUT, self._timeout)

@property
def need_clean(self):
return self._need_clean


class LocalHandler(TimeoutMixin):
class LocalHandler:
"""
事件循环一共处理五个状态
Expand Down Expand Up @@ -66,7 +41,7 @@ def __init__(self, port):
self._is_closing = False
self._connect_buffer = bytearray()

def _init_transport(self, transport, peername, protocol):
def _init_transport(self, transport: asyncio.Transport, peername, protocol):
self._stage = self.STAGE_INIT
self._transport = transport
self._peername = peername
Expand All @@ -80,28 +55,22 @@ def _init_cipher(self):
self.cipher = CipherMan.get_cipher_by_port(self.port, self._transport_protocol)

def close(self):
self._stage = self.STAGE_DESTROY
if self._is_closing:
return
self._stage = self.STAGE_DESTROY
self._is_closing = True

if self._transport_protocol == flag.TRANSPORT_TCP:
self._transport and self._transport.close()
self._remote and self._remote.close()
self.cipher and self.cipher.incr_user_tcp_num(-1)
elif self._transport_protocol == flag.TRANSPORT_UDP:
pass
self._remote and self._remote.close()
ACTIVE_CONNECTION_COUNT.inc(-1)

def write(self, data):
if self._transport_protocol == flag.TRANSPORT_TCP:
self._transport and not self._transport.is_closing() and self._transport.write(
data
)
self._transport.write(data)
else:
self._transport and not self._transport.is_closing() and self._transport.sendto(
data, self._peername
)
self._transport.sendto(data, self._peername)

def handle_connection_made(self, transport_protocol, transport, peername):
self._init_transport(transport, peername, transport_protocol)
Expand Down Expand Up @@ -129,8 +98,6 @@ def handle_data_received(self, data):
if not data:
return

self.keep_alive()

if self._stage == self.STAGE_INIT:
asyncio.create_task(self._handle_stage_init(data))
elif self._stage == self.STAGE_CONNECT:
Expand All @@ -145,63 +112,50 @@ def handle_data_received(self, data):
logging.warning(f"unknown stage:{self._stage}")

async def _handle_stage_init(self, data):
if self._transport_protocol == flag.TRANSPORT_TCP:
self._stage = self.STAGE_CONNECT
addr_type, dst_addr, dst_port, header_length = parse_header(data)
if not all([addr_type, dst_addr, dst_port, header_length]):
logging.warning(f"parse error addr_type: {addr_type} port: {self.port}")
self.close()
return
else:
payload = data[header_length:]

logging.debug(
f"HEADER: {addr_type} - {dst_addr} - {dst_port} - {self._transport_protocol}"
)

loop = asyncio.get_running_loop()
if self._transport_protocol == flag.TRANSPORT_TCP:
self._stage = self.STAGE_CONNECT
tcp_coro = self.loop.create_connection(
lambda: RemoteTCP(dst_addr, dst_port, payload, self), dst_addr, dst_port
)
self._handle_stage_connect(payload)
try:
_, remote_tcp = await tcp_coro
except (IOError, OSError) as e:
self.close()
self._stage = self.STAGE_DESTROY
logging.debug(f"connection failed , {type(e)} e: {e}")
_, remote_tcp = await loop.create_connection(
lambda: RemoteTCP(self), dst_addr, dst_port
)
except Exception as e:
self._stage = self.STAGE_ERROR
self.close()
logging.warning(f"connection failed, {type(e)} e: {e}")
else:
self._remote = remote_tcp
self._stage = self.STAGE_STREAM
self._remote.write(self._connect_buffer)
logging.debug(f"connection ok buffer lens:{len(self._connect_buffer)}")
self.cipher.record_user_ip(self._peername)

elif self._transport_protocol == flag.TRANSPORT_UDP:
udp_coro = self.loop.create_datagram_endpoint(
lambda: RemoteUDP(dst_addr, dst_port, payload, self),
remote_addr=(dst_addr, dst_port),
)
else:
try:
await udp_coro
except (IOError, OSError) as e:
self.close()
self._stage = self.STAGE_DESTROY
logging.debug(f"connection failed , {type(e)} e: {e}")
await self.create_datagram_endpoint(
lambda: RemoteUDP(dst_addr, dst_port, payload, self),
remote_addr=(dst_addr, dst_port),
)
except Exception as e:
self._stage = self.STAGE_ERROR
self.close()
logging.warning(f"connection failed, {type(e)} e: {e}")
else:
raise NotImplementedError

def _handle_stage_connect(self, data):
# 在握手之后,会耗费一定时间来来和remote建立连接,但是ss-client并不会等这个时间
self._connect_buffer.extend(data)
if not self._remote or self._remote.ready == False:
self._connect_buffer.extend(data)
else:
self._stage = self.STAGE_STREAM
self._handle_stage_stream(data)

def _handle_stage_stream(self, data):
self._remote.write(data)
Expand All @@ -226,16 +180,10 @@ def __call__(self):
return local

def pause_writing(self):
try:
self._handler._remote._transport.pause_reading()
except AttributeError:
pass
self._handler._remote._transport.pause_reading()

def resume_writing(self):
try:
self._handler._remote._transport.resume_reading()
except AttributeError:
pass
self._handler._remote._transport.resume_reading()

def connection_made(self, transport):
self._transport = transport
Expand All @@ -252,99 +200,85 @@ def connection_lost(self, exc):
self._handler.handle_connection_lost(exc)


class LocalUDP(asyncio.DatagramProtocol):
"""
Local Udp Factory
"""

def __init__(self, port):
self.port = port
self._protocols = {}
self._transport = None

def __call__(self):
local = LocalUDP(self.port)
return local

def connection_made(self, transport):
self._transport = transport

def datagram_received(self, data, peername):
if peername in self._protocols:
handler = self._protocols[peername]
else:
handler = LocalHandler(self.port)
self._protocols[peername] = handler
handler.handle_connection_made(
flag.TRANSPORT_UDP, self._transport, peername
)

handler.handle_data_received(data)
self._clear_closed_handlers()

def error_received(self, exc):
pass

def _clear_closed_handlers(self):
logging.debug(f"now udp handler {len(self._protocols)}")
need_clear_peers = []
for peername, handler in self._protocols.items():
if handler.need_clean:
need_clear_peers.append(peername)
for peer in need_clear_peers:
del self._protocols[peer]
logging.debug(f"after clear {len(self._protocols)}")


class RemoteTCP(asyncio.Protocol, TimeoutMixin):
def __init__(self, addr, port, data, local_handler):
class RemoteTCP(asyncio.Protocol):
def __init__(self, local_handler):
super().__init__()

self.data = data
self.local = local_handler
self.peername = None
self._transport = None
self.cipher = CipherMan(access_user=local_handler.cipher.access_user)
self.ready = False

self._is_closing = False

def write(self, data):
self._transport and not self._transport.is_closing() and self._transport.write(
data
)
self._transport.write(data)

def close(self):
if self._is_closing:
return
self._is_closing = True

self._transport and self._transport.close()
del self.local
self.local.close()

def connection_made(self, transport):
def connection_made(self, transport: asyncio.Transport):
self._transport = transport
self.peername = self._transport.get_extra_info("peername")
self.write(self.data)
transport.write(self.local._connect_buffer)
self.ready = True

def data_received(self, data):
self.keep_alive()
self.local.write(self.cipher.encrypt(data))

def pause_reading(self):
self._transport and self._transport.pause_reading()
self.local._transport.pause_reading()

def resume_reading(self):
self._transport and self._transport.resume_reading()
self.local._transport.resume_reading()

def eof_received(self):
self.local and self.local.handle_eof_received()
self.close()

def connection_lost(self, exc):
self.close()


class RemoteUDP(asyncio.DatagramProtocol, TimeoutMixin):
class LocalUDP(asyncio.DatagramProtocol):
"""
Local Udp Factory
"""

def __init__(self, port):
self.port = port
self._protocols = {}
self._transport = None

def __call__(self):
local = LocalUDP(self.port)
return local

def connection_made(self, transport):
self._transport = transport

def datagram_received(self, data, peername):
if peername in self._protocols:
handler = self._protocols[peername]
else:
handler = LocalHandler(self.port)
self._protocols[peername] = handler
handler.handle_connection_made(
flag.TRANSPORT_UDP, self._transport, peername
)
handler.handle_data_received(data)

def error_received(self, exc):
# TODO clean udp conn
pass


class RemoteUDP(asyncio.DatagramProtocol):
def __init__(self, addr, port, data, local_hander):
super().__init__()
self.data = data
Expand Down Expand Up @@ -375,7 +309,6 @@ def connection_made(self, transport):
self.write(self.data)

def datagram_received(self, data, peername, *arg):
self.keep_alive()

assert self.peername == peername
# 源地址和端口
Expand Down
2 changes: 1 addition & 1 deletion userconfigs.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"user_id": 1,
"port": 2345,
"method": "aes-256-cfb",
"method": "none",
"password": "hellotheworld1",
"transfer": 104857600,
"speed_limit": 0
Expand Down

0 comments on commit 5113b42

Please sign in to comment.