Skip to content

Commit

Permalink
Rewrite interactive client with synchronous API.
Browse files Browse the repository at this point in the history
Fix #1312.
  • Loading branch information
aaugustin committed Apr 1, 2023
1 parent 25a5252 commit ce06dd6
Showing 1 changed file with 34 additions and 105 deletions.
139 changes: 34 additions & 105 deletions src/websockets/__main__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import argparse
import asyncio
import os
import signal
import sys
import threading
from typing import Any, Set

from .exceptions import ConnectionClosed
from .frames import Close
from .legacy.client import connect

try:
import readline # noqa
except ImportError: # Windows has no `readline` normally
pass

from .sync.client import ClientConnection, connect
from .version import version as websockets_version


Expand Down Expand Up @@ -46,21 +48,6 @@ def win_enable_vt100() -> None:
raise RuntimeError("unable to set console mode")


def exit_from_event_loop_thread(
loop: asyncio.AbstractEventLoop,
stop: asyncio.Future[None],
) -> None:
loop.stop()
if not stop.done():
# When exiting the thread that runs the event loop, raise
# KeyboardInterrupt in the main thread to exit the program.
if sys.platform == "win32":
ctrl_c = signal.CTRL_C_EVENT
else:
ctrl_c = signal.SIGINT
os.kill(os.getpid(), ctrl_c)


def print_during_input(string: str) -> None:
sys.stdout.write(
# Save cursor position
Expand Down Expand Up @@ -93,63 +80,20 @@ def print_over_input(string: str) -> None:
sys.stdout.flush()


async def run_client(
uri: str,
loop: asyncio.AbstractEventLoop,
inputs: asyncio.Queue[str],
stop: asyncio.Future[None],
) -> None:
try:
websocket = await connect(uri)
except Exception as exc:
print_over_input(f"Failed to connect to {uri}: {exc}.")
exit_from_event_loop_thread(loop, stop)
return
else:
print_during_input(f"Connected to {uri}.")

try:
while True:
incoming: asyncio.Future[Any] = asyncio.create_task(websocket.recv())
outgoing: asyncio.Future[Any] = asyncio.create_task(inputs.get())
done: Set[asyncio.Future[Any]]
pending: Set[asyncio.Future[Any]]
done, pending = await asyncio.wait(
[incoming, outgoing, stop], return_when=asyncio.FIRST_COMPLETED
)

# Cancel pending tasks to avoid leaking them.
if incoming in pending:
incoming.cancel()
if outgoing in pending:
outgoing.cancel()

if incoming in done:
try:
message = incoming.result()
except ConnectionClosed:
break
else:
if isinstance(message, str):
print_during_input("< " + message)
else:
print_during_input("< (binary) " + message.hex())

if outgoing in done:
message = outgoing.result()
await websocket.send(message)

if stop in done:
break

finally:
await websocket.close()
assert websocket.close_code is not None and websocket.close_reason is not None
close_status = Close(websocket.close_code, websocket.close_reason)

print_over_input(f"Connection closed: {close_status}.")

exit_from_event_loop_thread(loop, stop)
def print_incoming_messages(websocket: ClientConnection, stop: threading.Event) -> None:
for message in websocket:
if isinstance(message, str):
print_during_input("< " + message)
else:
print_during_input("< (binary) " + message.hex())
if not stop.is_set():
# When the server closes the connection, raise KeyboardInterrupt
# in the main thread to exit the program.
if sys.platform == "win32":
ctrl_c = signal.CTRL_C_EVENT
else:
ctrl_c = signal.SIGINT
os.kill(os.getpid(), ctrl_c)


def main() -> None:
Expand Down Expand Up @@ -184,47 +128,32 @@ def main() -> None:
sys.stderr.flush()

try:
import readline # noqa
except ImportError: # Windows has no `readline` normally
pass

# Create an event loop that will run in a background thread.
loop = asyncio.new_event_loop()

# Due to zealous removal of the loop parameter in the Queue constructor,
# we need a factory coroutine to run in the freshly created event loop.
async def queue_factory() -> asyncio.Queue[str]:
return asyncio.Queue()

# Create a queue of user inputs. There's no need to limit its size.
inputs: asyncio.Queue[str] = loop.run_until_complete(queue_factory())

# Create a stop condition when receiving SIGINT or SIGTERM.
stop: asyncio.Future[None] = loop.create_future()
websocket = connect(args.uri)
except Exception as exc:
print(f"Failed to connect to {args.uri}: {exc}.")
sys.exit(1)
else:
print(f"Connected to {args.uri}.")

# Schedule the task that will manage the connection.
loop.create_task(run_client(args.uri, loop, inputs, stop))
stop = threading.Event()

# Start the event loop in a background thread.
thread = threading.Thread(target=loop.run_forever)
# Start the thread that reads messages from the connection.
thread = threading.Thread(target=print_incoming_messages, args=(websocket, stop))
thread.start()

# Read from stdin in the main thread in order to receive signals.
try:
while True:
# Since there's no size limit, put_nowait is identical to put.
message = input("> ")
loop.call_soon_threadsafe(inputs.put_nowait, message)
websocket.send(message)
except (KeyboardInterrupt, EOFError): # ^C, ^D
loop.call_soon_threadsafe(stop.set_result, None)
stop.set()
websocket.close()
print_over_input("Connection closed.")

# Wait for the event loop to terminate.
thread.join()

# For reasons unclear, even though the loop is closed in the thread,
# it still thinks it's running here.
loop.close()


if __name__ == "__main__":
main()

0 comments on commit ce06dd6

Please sign in to comment.