forked from python-trio/trio-websocket
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_impl.py
1572 lines (1359 loc) · 61 KB
/
_impl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import sys
from collections import OrderedDict
from contextlib import asynccontextmanager
from functools import partial
from ipaddress import ip_address
import itertools
import logging
import random
import ssl
import struct
import urllib.parse
from typing import Iterable, List, Optional, Union
import outcome
import trio
import trio.abc
from wsproto import ConnectionType, WSConnection
from wsproto.connection import ConnectionState
import wsproto.frame_protocol as wsframeproto
from wsproto.events import (
AcceptConnection,
BytesMessage,
CloseConnection,
Ping,
Pong,
RejectConnection,
RejectData,
Request,
TextMessage,
)
import wsproto.utilities
# BaseExceptionGroup only became a builtin in Python 3.11; older interpreters
# get it from the ``exceptiongroup`` backport package.
if sys.version_info < (3, 11): # pragma: no cover
    # pylint doesn't care about the version_info check, so need to ignore the warning
    from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin

# True when the installed Trio is older than 0.22. Such releases report
# concurrent failures with ``trio.MultiError`` instead of the standard
# (Base)ExceptionGroup. Only the (major, minor) components of the version
# string are compared.
_IS_TRIO_MULTI_ERROR = tuple(map(int, trio.__version__.split('.')[:2])) < (0, 22)

# The exception-group type that a failing Trio nursery raises on this Trio
# version; ``open_websocket`` catches this type to unwrap groups.
if _IS_TRIO_MULTI_ERROR:
    _TRIO_EXC_GROUP_TYPE = trio.MultiError # type: ignore[attr-defined] # pylint: disable=no-member
else:
    _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment

CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds
# Default for the ``message_queue_size`` argument of the public helpers.
MESSAGE_QUEUE_SIZE = 1
# Default for the ``max_message_size`` argument of the public helpers.
MAX_MESSAGE_SIZE = 2 ** 20 # 1 MiB
# NOTE(review): not referenced in the visible part of this file; presumably
# the number of bytes requested per stream read -- confirm in the reader code.
RECEIVE_BYTES = 4 * 2 ** 10 # 4 KiB
logger = logging.getLogger('trio-websocket')
class TrioWebsocketInternalError(Exception):
    """Last-resort error raised by :func:`open_websocket`.

    It is used when an exception group cannot be reduced to a single
    preferred exception. Seeing it means the library's assumptions about
    its own internal code are wrong, so it should be reported as a bug.
    """
def _ignore_cancel(exc):
    '''
    Filter callback that discards Trio cancellations.

    :param BaseException exc: The exception to inspect.
    :returns: ``None`` if ``exc`` is a :class:`trio.Cancelled`, otherwise
        ``exc`` unchanged.
    '''
    if isinstance(exc, trio.Cancelled):
        return None
    return exc
class _preserve_current_exception:
"""A context manager which should surround an ``__exit__`` or
``__aexit__`` handler or the contents of a ``finally:``
block. It ensures that any exception that was being handled
upon entry is not masked by a `trio.Cancelled` raised within
the body of the context manager.
https://github.com/python-trio/trio/issues/1559
https://gitter.im/python-trio/general?at=5faf2293d37a1a13d6a582cf
"""
__slots__ = ("_armed",)
def __init__(self):
self._armed = False
def __enter__(self):
self._armed = sys.exc_info()[1] is not None
def __exit__(self, ty, value, tb):
if value is None or not self._armed:
return False
if _IS_TRIO_MULTI_ERROR: # pragma: no cover
filtered_exception = trio.MultiError.filter(_ignore_cancel, value) # pylint: disable=no-member
elif isinstance(value, BaseExceptionGroup): # pylint: disable=possibly-used-before-assignment
filtered_exception = value.subgroup(lambda exc: not isinstance(exc, trio.Cancelled))
else:
filtered_exception = _ignore_cancel(value)
return filtered_exception is None
@asynccontextmanager
async def open_websocket(
    host: str,
    port: int,
    resource: str,
    *,
    use_ssl: Union[bool, ssl.SSLContext],
    subprotocols: Optional[Iterable[str]] = None,
    extra_headers: Optional[list[tuple[bytes,bytes]]] = None,
    message_queue_size: int = MESSAGE_QUEUE_SIZE,
    max_message_size: int = MAX_MESSAGE_SIZE,
    connect_timeout: float = CONN_TIMEOUT,
    disconnect_timeout: float = CONN_TIMEOUT
):
    '''
    Open a WebSocket client connection to a host.

    This async context manager connects when entering the context manager and
    disconnects when exiting. It yields a
    :class:`WebSocketConnection` instance.

    :param str host: The host to connect to.
    :param int port: The port to connect to.
    :param str resource: The resource, i.e. URL path.
    :param Union[bool, ssl.SSLContext] use_ssl: If this is an SSL context, then
        use that context. If this is ``True`` then use default SSL context. If
        this is ``False`` then disable SSL.
    :param subprotocols: An iterable of strings representing preferred
        subprotocols.
    :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing
        HTTP header key/value pairs to send with the connection request. Note
        that headers used by the WebSocket protocol (e.g.
        ``Sec-WebSocket-Accept``) will be overwritten.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :param float connect_timeout: The number of seconds to wait for the
        connection before timing out.
    :param float disconnect_timeout: The number of seconds to wait when closing
        the connection before timing out.
    :raises HandshakeError: for any networking error,
        client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
        or server rejection (:exc:`ConnectionRejected`) during handshakes.
    :raises TrioWebsocketInternalError: if several non-cancellation exceptions
        are raised and none of them came from the user's code.
    '''
    # This context manager tries very very hard not to raise an exceptiongroup
    # in order to be as transparent as possible for the end user.
    # In the trivial case, this means that if user code inside the cm raises
    # we make sure that it doesn't get wrapped.

    # If opening the connection fails, then we will raise that exception. User
    # code is never executed, so we will never have multiple exceptions.

    # After opening the connection, we spawn _reader_task in the background and
    # yield to user code. If only one of those raise a non-cancelled exception
    # we will raise that non-cancelled exception.
    # If we get multiple cancelled, we raise the user's cancelled.
    # If both raise exceptions, we raise the user code's exception with the entire
    # exception group as the __cause__.
    # If we somehow get multiple exceptions, but no user exception, then we raise
    # TrioWebsocketInternalError.

    # If closing the connection fails, then that will be raised as the top
    # exception in the last `finally`. If we encountered exceptions in user code
    # or in reader task then they will be set as the `__cause__`.

    # Connect within ``connect_timeout``; a timeout becomes ConnectionTimeout
    # and an OSError becomes HandshakeError.
    async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection:
        try:
            with trio.fail_after(connect_timeout):
                return await connect_websocket(nursery, host, port,
                    resource, use_ssl=use_ssl, subprotocols=subprotocols,
                    extra_headers=extra_headers,
                    message_queue_size=message_queue_size,
                    max_message_size=max_message_size)
        except trio.TooSlowError:
            raise ConnectionTimeout from None
        except OSError as e:
            raise HandshakeError from e

    # Close within ``disconnect_timeout``; a timeout becomes
    # DisconnectionTimeout.
    async def _close_connection(connection: WebSocketConnection) -> None:
        try:
            with trio.fail_after(disconnect_timeout):
                await connection.aclose()
        except trio.TooSlowError:
            raise DisconnectionTimeout from None

    connection: WebSocketConnection|None=None
    # Outcome of _close_connection, captured so that a close failure is
    # raised only in the outer ``finally`` below.
    close_result: outcome.Maybe[None] | None = None
    # Exception raised by the code inside the ``async with`` body, if any.
    user_error = None

    try:
        async with trio.open_nursery() as new_nursery:
            # Capture (rather than raise) a connection failure so that it is
            # not wrapped in the nursery's exception group.
            result = await outcome.acapture(_open_connection, new_nursery)

            if isinstance(result, outcome.Value):
                connection = result.unwrap()
                try:
                    yield connection
                except BaseException as e:
                    user_error = e
                    raise
                finally:
                    close_result = await outcome.acapture(_close_connection, connection)
    # This exception handler should only be entered if either:
    # 1. The _reader_task started in connect_websocket raises
    # 2. User code raises an exception
    # I.e. open/close_connection are not included
    except _TRIO_EXC_GROUP_TYPE as e:
        # user_error, or exception bubbling up from _reader_task
        if len(e.exceptions) == 1:
            raise e.exceptions[0]

        # contains at most 1 non-cancelled exceptions
        exception_to_raise: BaseException|None = None
        for sub_exc in e.exceptions:
            if not isinstance(sub_exc, trio.Cancelled):
                if exception_to_raise is not None:
                    # multiple non-cancelled
                    break
                exception_to_raise = sub_exc
        else:
            if exception_to_raise is None:
                # all exceptions are cancelled
                # prefer raising the one from the user, for traceback reasons
                if user_error is not None:
                    # no reason to raise from e, just to include a bunch of extra
                    # cancelleds.
                    raise user_error # pylint: disable=raise-missing-from
                # multiple internal Cancelled is not possible afaik
                raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from
            raise exception_to_raise

        # if we have any KeyboardInterrupt in the group, make sure to raise it.
        for sub_exc in e.exceptions:
            if isinstance(sub_exc, KeyboardInterrupt):
                raise sub_exc from e

        # Both user code and internal code raised non-cancelled exceptions.
        # We "hide" the internal exception(s) in the __cause__ and surface
        # the user_error.
        if user_error is not None:
            raise user_error from e

        raise TrioWebsocketInternalError(
            "The trio-websocket API is not expected to raise multiple exceptions. "
            "Please report this as a bug to "
            "https://github.com/python-trio/trio-websocket"
        ) from e # pragma: no cover

    finally:
        if close_result is not None:
            close_result.unwrap()

    # error setting up, unwrap that exception
    # Placed after the try/finally so that it runs only when no other
    # exception is already propagating.
    if connection is None:
        result.unwrap()
async def connect_websocket(nursery, host, port, resource, *, use_ssl,
        subprotocols=None, extra_headers=None,
        message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE
        ) -> WebSocketConnection:
    '''
    Return an open WebSocket client connection to a host.

    This function is used to specify a custom nursery to run connection
    background tasks in. The caller is responsible for closing the connection.

    If you don't need a custom nursery, you should probably use
    :func:`open_websocket` instead.

    :param nursery: A Trio nursery to run background tasks in.
    :param str host: The host to connect to.
    :param int port: The port to connect to.
    :param str resource: The resource, i.e. URL path.
    :param Union[bool, ssl.SSLContext] use_ssl: If this is an SSL context, then
        use that context. If this is ``True`` then use default SSL context. If
        this is ``False`` then disable SSL.
    :param subprotocols: An iterable of strings representing preferred
        subprotocols.
    :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing
        HTTP header key/value pairs to send with the connection request. Note
        that headers used by the WebSocket protocol (e.g.
        ``Sec-WebSocket-Accept``) will be overwritten.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :raises TypeError: if ``use_ssl`` is neither a bool nor an SSL context.
    :rtype: WebSocketConnection
    '''
    if use_ssl is True:
        ssl_context = ssl.create_default_context()
    elif use_ssl is False:
        ssl_context = None
    elif isinstance(use_ssl, ssl.SSLContext):
        ssl_context = use_ssl
    else:
        raise TypeError('`use_ssl` argument must be bool or ssl.SSLContext')

    logger.debug('Connecting to ws%s://%s:%d%s',
        '' if ssl_context is None else 's', host, port, resource)
    stream: trio.SSLStream[trio.SocketStream] | trio.SocketStream
    if ssl_context is None:
        stream = await trio.open_tcp_stream(host, port)
    else:
        stream = await trio.open_ssl_over_tcp_stream(host, port,
            ssl_context=ssl_context, https_compatible=True)
    host_header = _format_host_header(host, port, ssl_context is not None)
    connection = WebSocketConnection(stream,
        WSConnection(ConnectionType.CLIENT),
        host=host_header,
        path=resource,
        client_subprotocols=subprotocols, client_extra_headers=extra_headers,
        message_queue_size=message_queue_size,
        max_message_size=max_message_size)
    nursery.start_soon(connection._reader_task)
    # Do not hand the connection back until the opening handshake completes.
    await connection._open_handshake.wait()
    return connection


def _format_host_header(host, port, is_ssl):
    '''
    Build the value of the HTTP ``Host:`` header for a client handshake.

    RFC 6455 section 4.1 allows the port to be omitted only when it is the
    default port for the scheme (80 for ``ws:``, 443 for ``wss:``). RFC 3986
    requires IPv6 literals to be enclosed in square brackets.

    :param str host: Host name or IP address literal (without brackets).
    :param int port: The port being connected to.
    :param bool is_ssl: Whether the connection uses TLS (``wss:``).
    :rtype: str
    '''
    # Only an IPv6 literal can contain ':' in a host; bracket it once.
    if isinstance(host, str) and ':' in host and not host.startswith('['):
        host = f'[{host}]'
    default_port = 443 if is_ssl else 80
    if port == default_port:
        return host
    return f'{host}:{port}'
def open_websocket_url(url, ssl_context=None, *, subprotocols=None,
        extra_headers=None,
        message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE,
        connect_timeout=CONN_TIMEOUT, disconnect_timeout=CONN_TIMEOUT):
    '''
    Open a WebSocket client connection to a URL.

    The returned async context manager connects on entry, disconnects on exit,
    and yields a :class:`WebSocketConnection` instance. The URL is split into
    host, port and resource, and the work is delegated to
    :func:`open_websocket`.

    :param str url: A WebSocket URL, i.e. ``ws:`` or ``wss:`` URL scheme.
    :param ssl_context: Optional SSL context used for ``wss:`` URLs. A default
        SSL context is used for ``wss:`` if this argument is ``None``.
    :type ssl_context: ssl.SSLContext or None
    :param subprotocols: An iterable of strings representing preferred
        subprotocols.
    :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing
        HTTP header key/value pairs to send with the connection request. Note
        that headers used by the WebSocket protocol (e.g.
        ``Sec-WebSocket-Accept``) will be overwritten.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :param float connect_timeout: The number of seconds to wait for the
        connection before timing out.
    :param float disconnect_timeout: The number of seconds to wait when closing
        the connection before timing out.
    :raises HandshakeError: for any networking error,
        client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
        or server rejection (:exc:`ConnectionRejected`) during handshakes.
    '''
    host, port, resource, use_ssl = _url_to_host(url, ssl_context)
    return open_websocket(
        host,
        port,
        resource,
        use_ssl=use_ssl,
        subprotocols=subprotocols,
        extra_headers=extra_headers,
        message_queue_size=message_queue_size,
        max_message_size=max_message_size,
        connect_timeout=connect_timeout,
        disconnect_timeout=disconnect_timeout,
    )
async def connect_websocket_url(nursery, url, ssl_context=None, *,
        subprotocols=None, extra_headers=None,
        message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE):
    '''
    Return an open WebSocket client connection to a URL.

    Use this variant when background tasks must run in a caller-supplied
    nursery. The caller is responsible for closing the returned connection.
    Without that requirement, :func:`open_websocket_url` is usually the better
    choice.

    :param nursery: A nursery to run background tasks in.
    :param str url: A WebSocket URL.
    :param ssl_context: Optional SSL context used for ``wss:`` URLs.
    :type ssl_context: ssl.SSLContext or None
    :param subprotocols: An iterable of strings representing preferred
        subprotocols.
    :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing
        HTTP header key/value pairs to send with the connection request. Note
        that headers used by the WebSocket protocol (e.g.
        ``Sec-WebSocket-Accept``) will be overwritten.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :rtype: WebSocketConnection
    '''
    host, port, resource, use_ssl = _url_to_host(url, ssl_context)
    connection = await connect_websocket(
        nursery,
        host,
        port,
        resource,
        use_ssl=use_ssl,
        subprotocols=subprotocols,
        extra_headers=extra_headers,
        message_queue_size=message_queue_size,
        max_message_size=max_message_size,
    )
    return connection
def _url_to_host(url, ssl_context):
'''
Convert a WebSocket URL to a (host,port,resource) tuple.
The returned ``ssl_context`` is either the same object that was passed in,
or if ``ssl_context`` is None, then a bool indicating if a default SSL
context needs to be created.
:param str url: A WebSocket URL.
:type ssl_context: ssl.SSLContext or None
:returns: A tuple of ``(host, port, resource, ssl_context)``.
'''
url = str(url) # For backward compat with isinstance(url, yarl.URL).
parts = urllib.parse.urlsplit(url)
if parts.scheme not in ('ws', 'wss'):
raise ValueError('WebSocket URL scheme must be "ws:" or "wss:"')
if ssl_context is None:
ssl_context = parts.scheme == 'wss'
elif parts.scheme == 'ws':
raise ValueError('SSL context must be None for ws: URL scheme')
host = parts.hostname
if parts.port is not None:
port = parts.port
else:
port = 443 if ssl_context else 80
path_qs = parts.path
# RFC 7230, Section 5.3.1:
# If the target URI's path component is empty, the client MUST
# send "/" as the path within the origin-form of request-target.
if not path_qs:
path_qs = '/'
if '?' in url:
path_qs += '?' + parts.query
return host, port, path_qs, ssl_context
async def wrap_client_stream(nursery, stream, host, resource, *,
        subprotocols=None, extra_headers=None,
        message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE):
    '''
    Wrap an arbitrary stream in a client-side WebSocket connection.

    This is a low-level entry point that is rarely needed. Prefer
    :func:`open_websocket` or :func:`open_websocket_url` in most cases.

    :param nursery: A Trio nursery to run background tasks in.
    :param stream: A Trio stream to be wrapped.
    :type stream: trio.abc.Stream
    :param str host: A host string that will be sent in the ``Host:`` header.
    :param str resource: A resource string, i.e. the path component to be
        accessed on the server.
    :param subprotocols: An iterable of strings representing preferred
        subprotocols.
    :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples containing
        HTTP header key/value pairs to send with the connection request. Note
        that headers used by the WebSocket protocol (e.g.
        ``Sec-WebSocket-Accept``) will be overwritten.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :rtype: WebSocketConnection
    '''
    protocol = WSConnection(ConnectionType.CLIENT)
    connection = WebSocketConnection(
        stream,
        protocol,
        host=host,
        path=resource,
        client_subprotocols=subprotocols,
        client_extra_headers=extra_headers,
        message_queue_size=message_queue_size,
        max_message_size=max_message_size,
    )
    nursery.start_soon(connection._reader_task)
    # Block until the opening handshake has finished.
    await connection._open_handshake.wait()
    return connection
async def wrap_server_stream(nursery, stream,
        message_queue_size=MESSAGE_QUEUE_SIZE, max_message_size=MAX_MESSAGE_SIZE):
    '''
    Wrap an arbitrary stream in a server-side WebSocket.

    This is a low-level entry point that is rarely needed. Prefer
    :func:`serve_websocket` in most cases.

    :param nursery: A nursery to run background tasks in.
    :param stream: A stream to be wrapped.
    :type stream: trio.abc.Stream
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :returns: The client's handshake request, awaiting accept or reject.
    :rtype: WebSocketRequest
    '''
    connection = WebSocketConnection(
        stream,
        WSConnection(ConnectionType.SERVER),
        message_queue_size=message_queue_size,
        max_message_size=max_message_size,
    )
    nursery.start_soon(connection._reader_task)
    return await connection._get_request()
async def serve_websocket(handler, host, port, ssl_context, *,
        handler_nursery=None, message_queue_size=MESSAGE_QUEUE_SIZE,
        max_message_size=MAX_MESSAGE_SIZE, connect_timeout=CONN_TIMEOUT,
        disconnect_timeout=CONN_TIMEOUT, task_status=trio.TASK_STATUS_IGNORED):
    '''
    Serve a WebSocket over TCP.

    Supports the Trio nursery start protocol: ``server = await
    nursery.start(serve_websocket, …)`` blocks until the server is accepting
    connections and then returns a :class:`WebSocketServer` object.

    Beware: with ``host`` set to ``None`` and ``port`` set to zero you may get
    multiple listeners that have *different port numbers!*

    :param handler: An async function that is invoked with a request
        for each new connection.
    :param host: The host interface to bind. This can be an address of an
        interface, a name that resolves to an interface address (e.g.
        ``localhost``), or a wildcard address like ``0.0.0.0`` for IPv4 or
        ``::`` for IPv6. If ``None``, then all local interfaces are bound.
    :type host: str, bytes, or None
    :param int port: The port to bind to.
    :param ssl_context: The SSL context to use for encrypted connections, or
        ``None`` for unencrypted connection.
    :type ssl_context: ssl.SSLContext or None
    :param handler_nursery: An optional nursery to spawn handlers and background
        tasks in. If not specified, a new nursery will be created internally.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    :param float connect_timeout: The number of seconds to wait for a client
        to finish connection handshake before timing out.
    :param float disconnect_timeout: The number of seconds to wait for a client
        to finish the closing handshake before timing out.
    :param task_status: Part of Trio nursery start protocol.
    :returns: This function runs until cancelled.
    '''
    if ssl_context is not None:
        listeners = await trio.open_ssl_over_tcp_listeners(port, ssl_context,
            host=host, https_compatible=True)
    else:
        listeners = await trio.open_tcp_listeners(port, host=host)
    server = WebSocketServer(
        handler,
        listeners,
        handler_nursery=handler_nursery,
        message_queue_size=message_queue_size,
        max_message_size=max_message_size,
        connect_timeout=connect_timeout,
        disconnect_timeout=disconnect_timeout,
    )
    await server.run(task_status=task_status)
class HandshakeError(Exception):
    '''
    Base class for failures while connecting to or disconnecting from a
    websocket server, e.g. networking errors, handshake timeouts, or a
    rejected connection attempt.
    '''
class ConnectionTimeout(HandshakeError):
    '''Raised when connecting to the websocket server does not finish in time.'''
class DisconnectionTimeout(HandshakeError):
    '''Raised when disconnecting from the websocket server does not finish in time.'''
class ConnectionClosed(Exception):
    '''
    Raised when a WebSocket operation cannot complete because the connection
    is already closed or is in the middle of closing.
    '''

    def __init__(self, reason):
        '''
        Record why the connection was closed.

        :param reason: Details of the closure.
        :type reason: CloseReason
        '''
        super().__init__()
        self.reason = reason

    def __repr__(self):
        ''' Return representation. '''
        return '{}<{}>'.format(type(self).__name__, self.reason)
class ConnectionRejected(HandshakeError):
    '''
    A WebSocket connection could not be established because the server rejected
    the connection attempt.
    '''
    def __init__(self, status_code, headers, body):
        '''
        Constructor.

        :param int status_code: The 3 digit HTTP status code sent by the
            server in its rejection response.
        :param headers: The response headers, as a tuple of 2-tuples
            containing header key/value pairs.
        :param body: The response body, if any.
        :type body: bytes or None
        '''
        super().__init__()
        #: a 3 digit HTTP status code
        self.status_code = status_code
        #: a tuple of 2-tuples containing header key/value pairs
        self.headers = headers
        #: an optional ``bytes`` response body
        self.body = body

    def __repr__(self):
        ''' Return representation. '''
        return f'{self.__class__.__name__}<status_code={self.status_code}>'
class CloseReason:
    ''' Describes why a WebSocket was closed. '''

    def __init__(self, code, reason):
        '''
        Constructor.

        :param int code: The numeric close code. Codes known to
            ``wsproto.frame_protocol.CloseReason`` get that enum member's
            name. Any other code is labelled by the range it falls in
            (RFC, IANA or private reserved), or ``INVALID_CODE`` if it
            falls in none of them.
        :param Optional[str] reason: An arbitrary reason string.
        '''
        self._code = code
        try:
            label = wsframeproto.CloseReason(code).name
        except ValueError:
            for low, high, range_label in (
                    (1000, 2999, 'RFC_RESERVED'),
                    (3000, 3999, 'IANA_RESERVED'),
                    (4000, 4999, 'PRIVATE_RESERVED')):
                if low <= code <= high:
                    label = range_label
                    break
            else:
                label = 'INVALID_CODE'
        self._name = label
        self._reason = reason

    @property
    def code(self):
        ''' (Read-only) The numeric close code. '''
        return self._code

    @property
    def name(self):
        ''' (Read-only) The human-readable close code. '''
        return self._name

    @property
    def reason(self):
        ''' (Read-only) An arbitrary reason string. '''
        return self._reason

    def __repr__(self):
        ''' Show close code, name, and reason. '''
        return (f'{type(self).__name__}'
                f'<code={self.code}, name={self.name}, reason={self.reason}>')
class Future:
    ''' A placeholder for a value that is produced later. '''

    def __init__(self):
        ''' Create an empty future; its value starts out as ``None``. '''
        self._value = None
        self._value_event = trio.Event()

    def set_value(self, value):
        '''
        Store ``value`` and wake everything waiting in ``wait_value()``.

        :param value: The value to publish.
        '''
        self._value = value
        self._value_event.set()

    async def wait_value(self):
        '''
        Block until ``set_value()`` has been called, then return the value.

        :returns: The value passed to ``set_value()``.
        '''
        event = self._value_event
        await event.wait()
        return self._value
class WebSocketRequest:
    '''
    A client's opening handshake, as seen by the server.

    The server can inspect the handshake and then either ``accept()`` it to
    obtain a connection object or ``reject()`` it.
    '''

    def __init__(self, connection, event):
        '''
        Constructor.

        :param WebSocketConnection connection: The connection that received
            the handshake.
        :param event: The handshake request event.
        :type event: wsproto.events.Request
        '''
        self._connection = connection
        self._event = event

    @property
    def headers(self):
        '''
        HTTP headers represented as a list of (name, value) pairs.

        :rtype: list[tuple]
        '''
        event = self._event
        return event.extra_headers

    @property
    def path(self):
        '''
        The requested URL path.

        :rtype: str
        '''
        event = self._event
        return event.target

    @property
    def proposed_subprotocols(self):
        '''
        The subprotocols proposed by the client, as a tuple.

        :rtype: tuple[str]
        '''
        proposed = self._event.subprotocols
        return tuple(proposed)

    @property
    def local(self):
        '''
        The local endpoint of the underlying connection.

        :rtype: Endpoint or str
        '''
        connection = self._connection
        return connection.local

    @property
    def remote(self):
        '''
        The remote endpoint of the underlying connection.

        :rtype: Endpoint or str
        '''
        connection = self._connection
        return connection.remote

    async def accept(self, *, subprotocol=None, extra_headers=None):
        '''
        Accept the handshake and return the connection object.

        :param subprotocol: The selected subprotocol for this connection.
        :type subprotocol: str or None
        :param extra_headers: A list of 2-tuples containing key/value pairs to
            send as HTTP headers. ``None`` means no extra headers.
        :type extra_headers: list[tuple[bytes,bytes]] or None
        :rtype: WebSocketConnection
        '''
        headers = [] if extra_headers is None else extra_headers
        connection = self._connection
        await connection._accept(self._event, subprotocol, headers)
        return connection

    async def reject(self, status_code, *, extra_headers=None, body=None):
        '''
        Reject the handshake.

        :param int status_code: The 3 digit HTTP status code. In order to be
            RFC-compliant, this should NOT be 101, and would ideally be an
            appropriate code in the range 300-599.
        :param list[tuple[bytes,bytes]] extra_headers: A list of 2-tuples
            containing key/value pairs to send as HTTP headers. Any falsy
            value is treated as "no headers".
        :param body: If provided, this data will be sent in the response
            body, otherwise no response body will be sent.
        :type body: bytes or None
        '''
        if not extra_headers:
            extra_headers = []
        if not body:
            body = b''
        await self._connection._reject(status_code, extra_headers, body)
def _get_stream_endpoint(stream, *, local):
    '''
    Construct an endpoint from a stream.

    Only TCP-style socket addresses, i.e. ``(host, port, ...)`` tuples, can be
    represented as an :class:`Endpoint`. Streams that are not backed by a
    socket fall back to ``repr(stream)``. This covers an SSL stream over a
    non-socket transport and a socket whose address is not a tuple, such as
    an AF_UNIX path.

    :param trio.abc.Stream stream: The stream to describe.
    :param bool local: If true, return local endpoint. Otherwise return remote.
    :returns: An endpoint instance or ``repr()`` for streams that cannot be
        represented as an endpoint.
    :rtype: Endpoint or str
    '''
    sock, is_ssl = None, False
    if isinstance(stream, trio.SocketStream):
        sock = stream.socket
    elif isinstance(stream, trio.SSLStream):
        is_ssl = True
        transport = stream.transport_stream
        # An SSLStream may wrap any stream, not only a SocketStream.
        if isinstance(transport, trio.SocketStream):
            sock = transport.socket
    if sock is not None:
        sockaddr = sock.getsockname() if local else sock.getpeername()
        # IP families report (host, port, ...) tuples; AF_UNIX reports a path
        # (str or bytes), which must not be unpacked as host/port.
        if isinstance(sockaddr, tuple) and len(sockaddr) >= 2:
            addr, port, *_ = sockaddr
            return Endpoint(addr, port, is_ssl)
    return repr(stream)
class WebSocketConnection(trio.abc.AsyncResource):
''' A WebSocket connection. '''
CONNECTION_ID = itertools.count()
def __init__(
    self,
    stream: trio.SocketStream | trio.SSLStream[trio.SocketStream],
    ws_connection: wsproto.WSConnection,
    *,
    host=None,
    path=None,
    client_subprotocols=None, client_extra_headers=None,
    message_queue_size=MESSAGE_QUEUE_SIZE,
    max_message_size=MAX_MESSAGE_SIZE
):
    '''
    Constructor.

    Generally speaking, users are discouraged from directly instantiating a
    ``WebSocketConnection`` and should instead use one of the convenience
    functions in this module, e.g. ``open_websocket()`` or
    ``serve_websocket()``. This class has some tricky internal logic and
    timing that depends on whether it is an instance of a client connection
    or a server connection. The convenience functions handle this complexity
    for you.

    :param stream: The underlying byte stream (plain socket or TLS).
    :type stream: trio.SocketStream or trio.SSLStream
    :param wsproto.WSConnection ws_connection: The wsproto state machine for
        this connection. Its ``client`` attribute decides whether this
        instance acts as a client or as a server.
    :param str host: The hostname to send in the HTTP request headers. Only
        used for client connections.
    :param str path: The URL path for this connection.
    :param list client_subprotocols: A list of desired subprotocols. Only
        used for client connections.
    :param list[tuple[bytes,bytes]] client_extra_headers: Extra headers to
        send with the connection request. Only used for client connections.
    :param int message_queue_size: The maximum number of messages that will be
        buffered in the library's internal message queue.
    :param int max_message_size: The maximum message size as measured by
        ``len()``. If a message is received that is larger than this size,
        then the connection is closed with code 1009 (Message Too Big).
    '''
    # NOTE: The implementation uses _close_reason for more than an advisory
    # purpose. It's critical internal state, indicating when the
    # connection is closed or closing.
    self._close_reason: Optional[CloseReason] = None
    # Per-instance id drawn from a counter shared by all instances.
    self._id = next(self.__class__.CONNECTION_ID)
    self._stream = stream
    self._stream_lock = trio.StrictFIFOLock()
    self._wsproto = ws_connection
    # Presumably state for reassembling an incoming message, bounded by
    # max_message_size (the reassembly code is not visible here).
    self._message_size = 0
    self._message_parts: List[Union[bytes, str]] = []
    self._max_message_size = max_message_size
    self._reader_running = True
    if ws_connection.client:
        # Only clients send an opening request, so only they build one here.
        self._initial_request: Optional[Request] = Request(
            host=host,
            target=path,
            subprotocols=client_subprotocols,
            extra_headers=client_extra_headers or [],
        )
    else:
        self._initial_request = None
    self._path = path
    self._subprotocol: Optional[str] = None
    self._handshake_headers: tuple[tuple[str, str], ...] = ()
    self._reject_status = 0
    self._reject_headers: tuple[tuple[str, str], ...] = ()
    self._reject_body = b''
    # The internal message queue, bounded by message_queue_size.
    self._send_channel, self._recv_channel = trio.open_memory_channel[
        Union[bytes, str]
    ](message_queue_size)
    # Presumably outstanding pings keyed by payload, kept in send order.
    self._pings: OrderedDict[bytes, trio.Event] = OrderedDict()
    # Set when the server has received a connection request event. This
    # future is never set on client connections.
    self._connection_proposal = Future()
    # Set once the WebSocket open handshake takes place, i.e. when the server
    # receives the connection request or the client's request is accepted.
    self._open_handshake = trio.Event()
    # Set once a WebSocket closed handshake takes place, i.e after a close
    # frame has been sent and a close frame has been received.
    self._close_handshake = trio.Event()
    # Set upon receiving CloseConnection from peer.
    # Used to test close race conditions between client and server.
    self._for_testing_peer_closed_connection = trio.Event()
@property
def closed(self):
    '''
    (Read-only) Why the connection was (or is being) closed.

    ``None`` while no close reason has been recorded.

    :rtype: Optional[CloseReason]
    '''
    reason = self._close_reason
    return reason
@property
def is_client(self):
    ''' (Read-only) Whether this connection plays the client role. '''
    protocol = self._wsproto
    return protocol.client
@property
def is_server(self):
    ''' (Read-only) Whether this connection plays the server role. '''
    client_side = self._wsproto.client
    return not client_side
@property
def local(self):
    '''
    This side's endpoint of the connection.

    :rtype: Endpoint or str
    '''
    stream = self._stream
    return _get_stream_endpoint(stream, local=True)
@property
def remote(self):
    '''
    The peer's endpoint of the connection.

    :rtype: Endpoint or str
    '''
    stream = self._stream
    return _get_stream_endpoint(stream, local=False)
@property
def path(self):
    '''
    The requested URL path.

    Client connections know it from construction time; server connections
    learn it once the handshake has completed.

    :rtype: str
    '''
    requested = self._path
    return requested
@property
def subprotocol(self):
    '''
    (Read-only) The subprotocol that was negotiated, or ``None`` when no
    subprotocol is in use.

    Only meaningful once the opening handshake has finished.

    :rtype: str or None
    '''
    negotiated = self._subprotocol
    return negotiated
@property
def handshake_headers(self):
    '''
    HTTP headers the remote side sent during the handshake, as key/value
    2-tuples. Header keys are always lower case.

    :rtype: tuple[tuple[str,str]]
    '''
    headers = self._handshake_headers
    return headers
async def aclose(self, code=1000, reason=None):  # pylint: disable=arguments-differ
    '''
    Close the WebSocket connection.

    Sends a close frame to the peer when the connection is open, then
    suspends until the connection is closed. Once this has been called, any
    further I/O on this WebSocket (such as ``get_message()`` or
    ``send_message()``) raises ``ConnectionClosed``.

    Safe to call repeatedly: later calls on the same connection succeed
    without raising.

    :param int code: A 4-digit close code indicating the type of closure.
    :param str reason: An optional string describing the closure.
    '''
    # NOTE(review): _preserve_current_exception is defined elsewhere; it
    # presumably keeps an exception currently being handled from being
    # replaced during cleanup -- confirm.
    guard = _preserve_current_exception()
    with guard:
        await self._aclose(code, reason)
async def _aclose(self, code, reason):
    '''
    Implementation of ``aclose()``.

    Starts or finishes the WebSocket closing handshake as the current
    wsproto state allows. The underlying stream is always closed on exit,
    including on cancellation.

    :param int code: Close code sent to the peer in the ``CloseConnection``
        event. It is only used when the connection is ``OPEN``.
    :param str reason: Close reason sent to the peer together with ``code``.
    '''
    if self._close_reason:
        # A close reason is already recorded, so the connection is closed
        # or closing. Per AsyncResource interface, calling aclose() on a
        # closed resource should succeed.
        return
    try:
        if self._wsproto.state == ConnectionState.OPEN:
            # Our side is initiating the close, so send a close connection
            # event to peer, while setting the local close reason to normal.
            # The local reason is always (1000, None), not the caller's
            # code/reason. It is set *before* awaiting _send(), so a
            # concurrent aclose() already returns early via the check above.
            self._close_reason = CloseReason(1000, None)
            await self._send(CloseConnection(code=code, reason=reason))
        elif self._wsproto.state in (ConnectionState.CONNECTING,
                                     ConnectionState.REJECTING):
            # The opening handshake never completed, so there is no closing
            # handshake to wait for; mark it done so the wait below returns.
            self._close_handshake.set()
        # TODO: shouldn't the receive channel be closed earlier, so that
        # get_message() during send of the CloseConnection event fails?
        await self._recv_channel.aclose()
        # Wait until a close frame has been both sent and received.
        await self._close_handshake.wait()
    except ConnectionClosed:
        # If _send() raised ConnectionClosed, then we can bail out.
        pass
    finally:
        # If cancelled during WebSocket close, make sure that the underlying
        # stream is closed.
        await self._close_stream()
async def get_message(self):