Skip to content

Commit

Permalink
use socket.AF_UNIX by default for unix_socket
Browse files Browse the repository at this point in the history
  • Loading branch information
cocolato committed Apr 16, 2024
1 parent 00b6997 commit f432bb4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 43 deletions.
48 changes: 32 additions & 16 deletions thriftpy2/contrib/aio/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
cafile=None, ssl_context=None,
certfile=None, keyfile=None,
validate=True, url='',
socket_timeout=None, socket_family=socket.AF_INET):
socket_timeout=None, socket_family=None):
if socket_timeout is not None:
warnings.warn(
"The 'socket_timeout' argument is deprecated. "
Expand All @@ -31,24 +31,34 @@ async def make_client(service, host='localhost', port=9090, unix_socket=None,
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TAsyncSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout,
socket_family=socket_family)
socket_family = socket.AF_UNIX if not socket_family else socket_family
client_socket = TAsyncSocket(
unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout,
socket_family=socket_family,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket = TAsyncSocket(
host, port,
socket_timeout=timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate,
socket_family=socket_family)
socket_family = socket.AF_INET if not socket_family else socket_family
client_socket = TAsyncSocket(
host,
port,
socket_timeout=timeout,
connect_timeout=connect_timeout,
cafile=cafile,
ssl_context=ssl_context,
certfile=certfile,
keyfile=keyfile,
validate=validate,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
await transport.open()
return TAsyncClient(service, protocol)
Expand All @@ -60,21 +70,27 @@ def make_server(service, handler,
trans_factory=TAsyncBufferedTransportFactory(),
client_timeout=3000, certfile=None,
keyfile=None, ssl_context=None, loop=None,
socket_family=socket.AF_INET):
socket_family=None):
processor = TAsyncProcessor(service, handler)

if unix_socket:
socket_family = socket.AF_UNIX if not socket_family else socket_family
server_socket = TAsyncServerSocket(
unix_socket=unix_socket,
socket_family=socket_family)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket_family = socket.AF_INET if not socket_family else socket_family
server_socket = TAsyncServerSocket(
host=host, port=port,
host=host,
port=port,
client_timeout=client_timeout,
certfile=certfile, keyfile=keyfile, ssl_context=ssl_context,
socket_family=socket_family)
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket must be provided.")

Expand Down
76 changes: 49 additions & 27 deletions thriftpy2/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,38 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
timeout=3000, cafile=None, ssl_context=None, certfile=None,
keyfile=None, url="", socket_family=socket.AF_INET):
keyfile=None, url="", socket_family=None):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout)
socket_family = socket.AF_UNIX if not socket_family else socket_family
client_socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket_family = socket.AF_INET if not socket_family else socket_family
if cafile or ssl_context:
socket = TSSLSocket(host, port, socket_timeout=timeout,
socket_family=socket_family, cafile=cafile,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context)
client_socket = TSSLSocket(
host,
port,
socket_timeout=timeout,
socket_family=socket_family,
cafile=cafile,
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
)
else:
socket = TSocket(host, port, socket_family=socket_family,
socket_timeout=timeout)
client_socket = TSocket(
host, port, socket_family=socket_family, socket_timeout=timeout
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
transport.open()
return TClient(service, protocol)
Expand All @@ -53,16 +62,18 @@ def make_server(service, handler,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
client_timeout=3000, certfile=None,
socket_family=socket.AF_INET):
socket_family=None):
processor = TProcessor(service, handler)

if unix_socket:
socket_family = socket.AF_UNIX if not socket_family else socket_family
server_socket = TServerSocket(
unix_socket=unix_socket,
socket_family=socket_family)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket_family = socket.AF_INET if not socket_family else socket_family
if certfile:
server_socket = TSSLServerSocket(
host=host, port=port, client_timeout=client_timeout,
Expand All @@ -86,7 +97,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
trans_factory=TBufferedTransportFactory(),
timeout=None, socket_timeout=3000, connect_timeout=3000,
cafile=None, ssl_context=None, certfile=None, keyfile=None,
url="", socket_family=socket.AF_INET):
url="", socket_family=None):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
Expand All @@ -98,32 +109,43 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
socket_timeout = connect_timeout = timeout

if unix_socket:
socket = TSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
socket_family=socket_family)
socket_family = socket.AF_UNIX if not socket_family else socket_family
client_socket = TSocket(
unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
socket_family=socket_family,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket_family = socket.AF_INET if not socket_family else socket_family
if cafile or ssl_context:
socket = TSSLSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
cafile=cafile,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context,
socket_family=socket_family)
client_socket = TSSLSocket(
host,
port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
cafile=cafile,
certfile=certfile,
keyfile=keyfile,
ssl_context=ssl_context,
socket_family=socket_family,
)
else:
socket = TSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
socket_family=socket_family)
client_socket = TSocket(
host,
port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout,
socket_family=socket_family,
)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")

try:
transport = trans_factory.get_transport(socket)
transport = trans_factory.get_transport(client_socket)
protocol = proto_factory.get_protocol(transport)
transport.open()
yield TClient(service, protocol)
Expand Down

0 comments on commit f432bb4

Please sign in to comment.