Skip to content

Commit

Permalink
Handle exceptions when parsing opening handshake.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Dec 18, 2022
1 parent 23a2d3f commit e4fcab1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 11 deletions.
27 changes: 17 additions & 10 deletions src/websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,17 @@ def send_request(self, request: Request) -> None:

def parse(self) -> Generator[None, None, None]:
if self.state is CONNECTING:
response = yield from Response.parse(
self.reader.read_line,
self.reader.read_exact,
self.reader.read_to_eof,
)
try:
response = yield from Response.parse(
self.reader.read_line,
self.reader.read_exact,
self.reader.read_to_eof,
)
except Exception as exc:
self.handshake_exc = exc
self.parser = self.discard()
next(self.parser) # start coroutine
yield

if self.debug:
code, phrase = response.status_code, response.reason_phrase
Expand All @@ -334,14 +340,15 @@ def parse(self) -> Generator[None, None, None]:
self.process_response(response)
except InvalidHandshake as exc:
response._exception = exc
self.events.append(response)
self.handshake_exc = exc
self.parser = self.discard()
next(self.parser) # start coroutine
else:
assert self.state is CONNECTING
self.state = OPEN
finally:
self.events.append(response)
yield

assert self.state is CONNECTING
self.state = OPEN
self.events.append(response)

yield from super().parse()

Expand Down
11 changes: 10 additions & 1 deletion src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,16 @@ def send_response(self, response: Response) -> None:

def parse(self) -> Generator[None, None, None]:
if self.state is CONNECTING:
request = yield from Request.parse(self.reader.read_line)
try:
request = yield from Request.parse(
self.reader.read_line,
)
except Exception as exc:
self.handshake_exc = exc
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
yield

if self.debug:
self.logger.debug("< GET %s HTTP/1.1", request.path)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,32 @@ def test_reject_response(self):
)
self.assertEqual(response.body, b"Sorry folks.\n")

def test_no_response(self):
with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
client = ClientProtocol(parse_uri("ws://example.com/test"))
client.connect()
client.receive_eof()
self.assertEqual(client.events_received(), [])

def test_partial_response(self):
with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
client = ClientProtocol(parse_uri("ws://example.com/test"))
client.connect()
client.receive_data(b"HTTP/1.1 101 Switching Protocols\r\n")
client.receive_eof()
self.assertEqual(client.events_received(), [])

def test_random_response(self):
with unittest.mock.patch("websockets.client.generate_key", return_value=KEY):
client = ClientProtocol(parse_uri("ws://example.com/test"))
client.connect()
client.receive_data(b"220 smtp.invalid\r\n")
client.receive_data(b"250 Hello relay.invalid\r\n")
client.receive_data(b"250 Ok\r\n")
client.receive_data(b"250 Ok\r\n")
client.receive_eof()
self.assertEqual(client.events_received(), [])

def make_accept_response(self, client):
request = client.connect()
return Response(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,24 @@ def test_connect_request(self):
),
)

def test_no_request(self):
server = ServerProtocol()
server.receive_eof()
self.assertEqual(server.events_received(), [])

def test_partial_request(self):
server = ServerProtocol()
server.receive_data(b"GET /test HTTP/1.1\r\n")
server.receive_eof()
self.assertEqual(server.events_received(), [])

def test_random_request(self):
server = ServerProtocol()
server.receive_data(b"HELO relay.invalid\r\n")
server.receive_data(b"MAIL FROM: <alice@invalid>\r\n")
server.receive_data(b"RCPT TO: <bob@invalid>\r\n")
self.assertEqual(server.events_received(), [])


class AcceptRejectTests(unittest.TestCase):
def make_request(self):
Expand Down

0 comments on commit e4fcab1

Please sign in to comment.