Skip to content

Commit

Permalink
Ensure websocket transport is closed when client does not close it (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Mar 28, 2024
1 parent 8f23712 commit 6ec4747
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 8 deletions.
6 changes: 6 additions & 0 deletions CHANGES/8200.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Ensure websocket transport is closed when client does not close it
-- by :user:`bdraco`.

The transport could remain open if the client did not close it. This
change ensures the transport is closed when the client does not close
it.
21 changes: 16 additions & 5 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ async def close(
return True

if self._closing:
self._close_transport()
return True

reader = self._reader
Expand All @@ -440,9 +441,18 @@ async def close(
self._exception = asyncio.TimeoutError()
return True

def _set_closing(self, code: WSCloseCode) -> None:
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code

def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
self._close_transport()

def _close_transport(self) -> None:
"""Close the transport."""
if self._req is not None and self._req.transport is not None:
self._req.transport.close()

Expand Down Expand Up @@ -487,14 +497,12 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._closing = True
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)

if msg.type == WSMsgType.CLOSE:
self._closing = True
self._close_code = msg.data
self._set_closing(msg.data)
# Could be closed while awaiting reader.
if not self._closed and self._autoclose: # type: ignore[redundant-expr]
# The client is likely going to close the
Expand All @@ -503,7 +511,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage:
# likely result writing to a broken pipe.
await self.close(drain=False)
elif msg.type == WSMsgType.CLOSING:
self._closing = True
self._set_closing(WSCloseCode.OK)
elif msg.type == WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
Expand Down Expand Up @@ -547,5 +555,8 @@ async def __anext__(self) -> WSMessage:
return msg

def _cancel(self, exc: BaseException) -> None:
# web_protocol calls this from connection_lost
# or when the server is shutting down.
self._closing = True
if self._reader is not None:
set_exception(self._reader, exc)
89 changes: 88 additions & 1 deletion tests/test_web_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
from aiohttp.web import HTTPBadRequest, WebSocketResponse
from aiohttp.web_ws import WS_CLOSED_MESSAGE, WebSocketReady
from aiohttp.web_ws import WS_CLOSED_MESSAGE, WebSocketReady, WSMessage


@pytest.fixture
Expand Down Expand Up @@ -344,6 +344,93 @@ async def test_receive_eofstream_in_reader(make_request: Any, loop: Any) -> None
assert ws.closed


async def test_receive_exception_in_reader(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)

ws._reader = mock.Mock()
exc = Exception()
res = loop.create_future()
res.set_exception(exc)
ws._reader.read = make_mocked_coro(res)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.ERROR
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_close_but_left_open(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
close_message = WSMessage(WSMsgType.CLOSE, 1000, "close")

ws._reader = mock.Mock()
ws._reader.read = mock.AsyncMock(return_value=close_message)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_closing(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing")

ws._reader = mock.Mock()
read_mock = mock.AsyncMock(return_value=closing_message)
ws._reader.read = read_mock
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING
assert not ws.closed

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING
assert not ws.closed

ws._cancel(ConnectionResetError("Connection lost"))

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING


async def test_close_after_closing(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
await ws.prepare(req)
closing_message = WSMessage(WSMsgType.CLOSING, 1000, "closing")

ws._reader = mock.Mock()
ws._reader.read = mock.AsyncMock(return_value=closing_message)
ws._payload_writer.drain = mock.Mock()
ws._payload_writer.drain.return_value = loop.create_future()
ws._payload_writer.drain.return_value.set_result(True)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSING
assert not ws.closed
assert len(ws._req.transport.close.mock_calls) == 0

await ws.close()
assert ws.closed
assert len(ws._req.transport.close.mock_calls) == 1


async def test_receive_timeouterror(make_request: Any, loop: Any) -> None:
req = make_request("GET", "/")
ws = WebSocketResponse()
Expand Down
33 changes: 31 additions & 2 deletions tests/test_web_websocket_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# HTTP websocket server functional tests

import asyncio
from typing import Any
from typing import Any, Optional

import pytest

Expand Down Expand Up @@ -258,7 +258,7 @@ async def handler(request):
assert "reply" == (await ws.receive_str())

# The server closes here. Then the client sends bogus messages with an
# internval shorter than server-side close timeout, to make the server
# interval shorter than server-side close timeout, to make the server
# hanging indefinitely.
await asyncio.sleep(0.08)
msg = await ws._reader.read()
Expand Down Expand Up @@ -309,6 +309,35 @@ async def handler(request):
assert msg.type == WSMsgType.CLOSED


async def test_close_op_code_from_client(loop: Any, aiohttp_client: Any) -> None:
srv_ws: Optional[web.WebSocketResponse] = None

async def handler(request):
nonlocal srv_ws
ws = srv_ws = web.WebSocketResponse(protocols=("foo", "bar"))
await ws.prepare(request)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE
await asyncio.sleep(0)
return ws

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

ws: web.WebSocketResponse = await client.ws_connect("/", protocols=("eggs", "bar"))

await ws._writer._send_frame(b"", WSMsgType.CLOSE)

msg = await ws.receive()
assert msg.type == WSMsgType.CLOSE

await asyncio.sleep(0)
msg = await ws.receive()
assert msg.type == WSMsgType.CLOSED


async def test_auto_pong_with_closing_by_peer(loop: Any, aiohttp_client: Any) -> None:
closed = loop.create_future()

Expand Down

0 comments on commit 6ec4747

Please sign in to comment.