From bb57e249ce0e6162ddca8d30b0887ab01b595dac Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Sun, 23 Feb 2025 13:12:24 +0800 Subject: [PATCH 1/4] feat(ws): Ensure websocket connection is released when closed --- examples/ws.py | 22 +++++--- src/client/response/ws.rs | 103 +++++++++++++++++++++++++++++++------- 2 files changed, 99 insertions(+), 26 deletions(-) diff --git a/examples/ws.py b/examples/ws.py index 2f5f436a..72b0681e 100644 --- a/examples/ws.py +++ b/examples/ws.py @@ -12,14 +12,20 @@ async def send_message(ws): async def receive_message(ws): - while True: - try: - message = await ws.recv() - print("Received: ", message) - if message.data == b"Message 20": - print("Closing connection...") - break - except asyncio.CancelledError: + # while True: + # try: + # message = await ws.recv() + # print("Received: ", message) + # if message.data == b"Message 20": + # print("Closing connection...") + # break + # except asyncio.CancelledError: + # break + # or + async for message in ws: + print("Received: ", message) + if message.data == b"Message 20": + print("Closing connection...") break diff --git a/src/client/response/ws.rs b/src/client/response/ws.rs index a925382d..a48a8fdd 100644 --- a/src/client/response/ws.rs +++ b/src/client/response/ws.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - error::wrap_rquest_error, + error::{py_stop_async_iteration_error, wrap_rquest_error}, types::{HeaderMap, Json, SocketAddr, StatusCode, Version}, }; use futures_util::{ @@ -13,8 +13,8 @@ use pyo3_async_runtimes::tokio::future_into_py; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use tokio::sync::Mutex; -type Sender = Arc>>; -type Receiver = Arc>>; +type Sender = Arc>>>; +type Receiver = Arc>>>; /// A WebSocket response. #[gen_stub_pyclass] @@ -47,8 +47,8 @@ impl WebSocket { remote_addr, headers, protocol, - sender: Arc::new(Mutex::new(sender)), - receiver: Arc::new(Mutex::new(receiver)), + sender: Arc::new(Mutex::new(Some(sender))), + receiver: Arc::new(Mutex::new(Some(receiver))), }) } } @@ -132,9 +132,11 @@ impl WebSocket { pub fn recv<'rt>(&self, py: Python<'rt>) -> PyResult> { let websocket = self.receiver.clone(); future_into_py(py, async move { - let mut ws = websocket.lock().await; - if let Ok(Some(val)) = ws.try_next().await { - return Ok(Some(Message(val))); + let mut lock = websocket.lock().await; + if let Some(recv) = lock.as_mut() { + if let Ok(Some(val)) = recv.try_next().await { + return Ok(Some(Message(val))); + } } Ok(None) }) @@ -154,8 +156,11 @@ impl WebSocket { pub fn send<'rt>(&self, py: Python<'rt>, message: Message) -> PyResult> { let sender = self.sender.clone(); future_into_py(py, async move { - let mut ws = sender.lock().await; - ws.send(message.0).await.map_err(wrap_rquest_error) + let mut lock = sender.lock().await; + if let Some(send) = lock.as_mut() { + return send.send(message.0).await.map_err(wrap_rquest_error); + } + Ok(()) }) } @@ -178,15 +183,28 @@ impl WebSocket { reason: Option, ) -> PyResult> { let sender = self.sender.clone(); + let receiver = self.receiver.clone(); future_into_py(py, async move { - let mut sender = sender.lock().await; - sender - .send(rquest::Message::Close { - code: rquest::CloseCode::from(code.unwrap_or_default()), - reason, - }) - .await - .map_err(wrap_rquest_error)?; + let mut lock = receiver.lock().await; + drop(lock.take()); + drop(lock); + + let mut lock = sender.lock().await; + let send = lock.take(); + drop(lock); + + if let Some(mut send) = send { + if let Some(code) = code { + send.send(rquest::Message::Close { + code: rquest::CloseCode::from(code), + reason, + }) + .await + .map_err(wrap_rquest_error)?; + } + return send.close().await.map_err(wrap_rquest_error); + } + Ok(()) }) } @@ -194,6 +212,55 @@ impl WebSocket { #[pymethods] impl WebSocket { + /// Returns the WebSocket instance itself as an asynchronous iterator. + /// + /// This method is used to make the WebSocket instance iterable in an asynchronous context. + /// + /// # Arguments + /// + /// * `slf` - A reference to the WebSocket instance. + /// + /// # Returns + /// + /// Returns the WebSocket instance itself. + #[inline(always)] + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + /// Returns the next message from the WebSocket. + /// + /// This method is used to retrieve the next message from the WebSocket in an asynchronous iteration. + /// + /// # Arguments + /// + /// * `py` - The Python runtime. + /// + /// # Returns + /// + /// Returns a `PyResult` containing an `Option` with a `Bound` object representing the received message. + /// If no message is received, returns `None`. + fn __anext__<'rt>(&self, py: Python<'rt>) -> PyResult>> { + let recv = self.receiver.clone(); + future_into_py(py, async move { + // Here we lock the mutex to access the data inside + // and call try_next() method to get the next value. + let mut lock = recv.lock().await; + let recv = lock + .as_mut() + .ok_or_else(py_stop_async_iteration_error)? + .try_next() + .await; + + drop(lock); + + recv.map(|val| val.map(Message)) + .map(Some) + .map_err(wrap_rquest_error) + }) + .map(Some) + } + fn __aenter__<'a>(slf: PyRef<'a, Self>, py: Python<'a>) -> PyResult> { let slf = slf.into_py_any(py)?; future_into_py(py, async move { Ok(slf) }) From 13d2d338fbd9de8a418a7d4a64fbdb506011bc0b Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Sun, 23 Feb 2025 13:24:45 +0800 Subject: [PATCH 2/4] Add test --- tests/resp_test.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/resp_test.py b/tests/resp_test.py index 09d6dced..12596f89 100644 --- a/tests/resp_test.py +++ b/tests/resp_test.py @@ -1,11 +1,21 @@ import pytest import rnet from pathlib import Path -from rnet import Version, Multipart, Part +from rnet import Version, Multipart, Part, WebSocket, Message client = rnet.Client(tls_info=True) +@pytest.mark.asyncio +async def test_websocket(): + ws: WebSocket = await client.websocket("wss://echo.websocket.org") + await ws.recv() + await ws.send(Message.from_text("Hello, World!")) + message: Message = await ws.recv() + assert message.data == b"Hello, World!" + await ws.close() + + @pytest.mark.asyncio async def test_multiple_requests(): resp = await client.post( From 3612eddec49949ee36f272690c06e9f7e0fcef5d Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Sun, 23 Feb 2025 13:32:07 +0800 Subject: [PATCH 3/4] remove dead code --- src/client/response/ws.rs | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/client/response/ws.rs b/src/client/response/ws.rs index a48a8fdd..833cbd01 100644 --- a/src/client/response/ws.rs +++ b/src/client/response/ws.rs @@ -8,7 +8,7 @@ use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, TryStreamExt, }; -use pyo3::{prelude::*, IntoPyObjectExt}; +use pyo3::prelude::*; use pyo3_async_runtimes::tokio::future_into_py; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use tokio::sync::Mutex; @@ -260,21 +260,6 @@ impl WebSocket { }) .map(Some) } - - fn __aenter__<'a>(slf: PyRef<'a, Self>, py: Python<'a>) -> PyResult> { - let slf = slf.into_py_any(py)?; - future_into_py(py, async move { Ok(slf) }) - } - - fn __aexit__<'a>( - &'a mut self, - py: Python<'a>, - _exc_type: &Bound<'a, PyAny>, - _exc_value: &Bound<'a, PyAny>, - _traceback: &Bound<'a, PyAny>, - ) -> PyResult> { - self.close(py, None, None) - } } /// A WebSocket message. From ef14f4ad7e493a3c7b4203f2f4b58feeff82e213 Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Sun, 23 Feb 2025 13:35:31 +0800 Subject: [PATCH 4/4] update docs --- src/lib.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index df8d493e..f551fe2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -264,12 +264,14 @@ fn request( /// ```python /// import rnet /// import asyncio +/// from rnet import Message /// /// async def run(): -/// async with rnet.websocket("wss://echo.websocket.org") as ws: -/// await ws.send("Hello, World!") -/// message = await ws.recv() -/// print(message) +/// ws = await rnet.websocket("wss://echo.websocket.org") +/// await ws.send(Message.from_text("Hello, World!")) +/// message = await ws.recv() +/// print("Received:", message.data) +/// await ws.close() /// /// asyncio.run(run()) /// ```