Skip to content

Commit

Permalink
Accept int status in process_request and reject.
Browse files Browse the repository at this point in the history
Fix #1309.
  • Loading branch information
aaugustin committed May 18, 2023
1 parent e3abb88 commit 1bf9d1d
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/websockets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ def __init__(
headers: datastructures.HeadersLike,
body: bytes = b"",
) -> None:
self.status = status
# If a user passes an int instead of a HTTPStatus, fix it automatically.
self.status = http.HTTPStatus(status)
self.headers = datastructures.Headers(headers)
self.body = body

Expand Down
2 changes: 2 additions & 0 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def reject(
Response: WebSocket handshake response event to send to the client.
"""
# If a user passes an int instead of a HTTPStatus, fix it automatically.
status = http.HTTPStatus(status)
body = text.encode()
headers = Headers(
[
Expand Down
14 changes: 14 additions & 0 deletions tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ async def process_request(self, path, request_headers):
return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n"


class ProcessRequestReturningIntProtocol(WebSocketServerProtocol):
async def process_request(self, path, request_headers):
if path == "/__health__/":
return 200, [], b"OK\n"


class SlowOpeningHandshakeProtocol(WebSocketServerProtocol):
async def process_request(self, path, request_headers):
await asyncio.sleep(10 * MS)
Expand Down Expand Up @@ -757,6 +763,14 @@ def test_http_request_custom_server_header(self):
with contextlib.closing(response):
self.assertEqual(response.headers["Server"], "websockets")

@with_server(create_protocol=ProcessRequestReturningIntProtocol)
def test_process_request_returns_int_status(self):
response = self.loop.run_until_complete(self.make_http_request("/__health__/"))

with contextlib.closing(response):
self.assertEqual(response.code, 200)
self.assertEqual(response.read(), b"OK\n")

def assert_client_raises_code(self, status_code):
with self.assertRaises(InvalidStatusCode) as raised:
self.start_client()
Expand Down
6 changes: 6 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def test_reject_response(self):
)
self.assertEqual(response.body, b"Sorry folks.\n")

def test_reject_response_supports_int_status(self):
server = ServerProtocol()
response = server.reject(404, "Sorry folks.\n")
self.assertEqual(response.status_code, 404)
self.assertEqual(response.reason_phrase, "Not Found")

def test_basic(self):
server = ServerProtocol()
request = self.make_request()
Expand Down

0 comments on commit 1bf9d1d

Please sign in to comment.