diff --git a/examples/ws.py b/examples/ws.py index 2f5f436a..72b0681e 100644 --- a/examples/ws.py +++ b/examples/ws.py @@ -12,14 +12,20 @@ async def send_message(ws): async def receive_message(ws): - while True: - try: - message = await ws.recv() - print("Received: ", message) - if message.data == b"Message 20": - print("Closing connection...") - break - except asyncio.CancelledError: + # while True: + # try: + # message = await ws.recv() + # print("Received: ", message) + # if message.data == b"Message 20": + # print("Closing connection...") + # break + # except asyncio.CancelledError: + # break + # or + async for message in ws: + print("Received: ", message) + if message.data == b"Message 20": + print("Closing connection...") break diff --git a/src/client/response/ws.rs b/src/client/response/ws.rs index a925382d..833cbd01 100644 --- a/src/client/response/ws.rs +++ b/src/client/response/ws.rs @@ -1,20 +1,20 @@ use std::sync::Arc; use crate::{ - error::wrap_rquest_error, + error::{py_stop_async_iteration_error, wrap_rquest_error}, types::{HeaderMap, Json, SocketAddr, StatusCode, Version}, }; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, TryStreamExt, }; -use pyo3::{prelude::*, IntoPyObjectExt}; +use pyo3::prelude::*; use pyo3_async_runtimes::tokio::future_into_py; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; use tokio::sync::Mutex; -type Sender = Arc>>; -type Receiver = Arc>>; +type Sender = Arc>>>; +type Receiver = Arc>>>; /// A WebSocket response. #[gen_stub_pyclass] @@ -47,8 +47,8 @@ impl WebSocket { remote_addr, headers, protocol, - sender: Arc::new(Mutex::new(sender)), - receiver: Arc::new(Mutex::new(receiver)), + sender: Arc::new(Mutex::new(Some(sender))), + receiver: Arc::new(Mutex::new(Some(receiver))), }) } } @@ -132,9 +132,11 @@ impl WebSocket { pub fn recv<'rt>(&self, py: Python<'rt>) -> PyResult> { let websocket = self.receiver.clone(); future_into_py(py, async move { - let mut ws = websocket.lock().await; - if let Ok(Some(val)) = ws.try_next().await { - return Ok(Some(Message(val))); + let mut lock = websocket.lock().await; + if let Some(recv) = lock.as_mut() { + if let Ok(Some(val)) = recv.try_next().await { + return Ok(Some(Message(val))); + } } Ok(None) }) @@ -154,8 +156,11 @@ impl WebSocket { pub fn send<'rt>(&self, py: Python<'rt>, message: Message) -> PyResult> { let sender = self.sender.clone(); future_into_py(py, async move { - let mut ws = sender.lock().await; - ws.send(message.0).await.map_err(wrap_rquest_error) + let mut lock = sender.lock().await; + if let Some(send) = lock.as_mut() { + return send.send(message.0).await.map_err(wrap_rquest_error); + } + Ok(()) }) } @@ -178,15 +183,28 @@ impl WebSocket { reason: Option, ) -> PyResult> { let sender = self.sender.clone(); + let receiver = self.receiver.clone(); future_into_py(py, async move { - let mut sender = sender.lock().await; - sender - .send(rquest::Message::Close { - code: rquest::CloseCode::from(code.unwrap_or_default()), - reason, - }) - .await - .map_err(wrap_rquest_error)?; + let mut lock = receiver.lock().await; + drop(lock.take()); + drop(lock); + + let mut lock = sender.lock().await; + let send = lock.take(); + drop(lock); + + if let Some(mut send) = send { + if let Some(code) = code { + send.send(rquest::Message::Close { + code: rquest::CloseCode::from(code), + reason, + }) + .await + .map_err(wrap_rquest_error)?; + } + return send.close().await.map_err(wrap_rquest_error); + } + Ok(()) }) } @@ -194,19 +212,53 @@ impl WebSocket { #[pymethods] impl WebSocket { - fn __aenter__<'a>(slf: PyRef<'a, Self>, py: Python<'a>) -> PyResult> { - let slf = slf.into_py_any(py)?; - future_into_py(py, async move { Ok(slf) }) + /// Returns the WebSocket instance itself as an asynchronous iterator. + /// + /// This method is used to make the WebSocket instance iterable in an asynchronous context. + /// + /// # Arguments + /// + /// * `slf` - A reference to the WebSocket instance. + /// + /// # Returns + /// + /// Returns the WebSocket instance itself. + #[inline(always)] + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf } - fn __aexit__<'a>( - &'a mut self, - py: Python<'a>, - _exc_type: &Bound<'a, PyAny>, - _exc_value: &Bound<'a, PyAny>, - _traceback: &Bound<'a, PyAny>, - ) -> PyResult> { - self.close(py, None, None) + /// Returns the next message from the WebSocket. + /// + /// This method is used to retrieve the next message from the WebSocket in an asynchronous iteration. + /// + /// # Arguments + /// + /// * `py` - The Python runtime. + /// + /// # Returns + /// + /// Returns a `PyResult` containing an `Option` with a `Bound` object representing the received message. + /// If no message is received, returns `None`. + fn __anext__<'rt>(&self, py: Python<'rt>) -> PyResult>> { + let recv = self.receiver.clone(); + future_into_py(py, async move { + // Here we lock the mutex to access the data inside + // and call try_next() method to get the next value. + let mut lock = recv.lock().await; + let recv = lock + .as_mut() + .ok_or_else(py_stop_async_iteration_error)? + .try_next() + .await; + + drop(lock); + + recv.map(|val| val.map(Message)) + .map(Some) + .map_err(wrap_rquest_error) + }) + .map(Some) } } diff --git a/src/lib.rs b/src/lib.rs index df8d493e..f551fe2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -264,12 +264,14 @@ fn request( /// ```python /// import rnet /// import asyncio +/// from rnet import Message /// /// async def run(): -/// async with rnet.websocket("wss://echo.websocket.org") as ws: -/// await ws.send("Hello, World!") -/// message = await ws.recv() -/// print(message) +/// ws = await rnet.websocket("wss://echo.websocket.org") +/// await ws.send(Message.from_text("Hello, World!")) +/// message = await ws.recv() +/// print("Received:", message.data) +/// await ws.close() /// /// asyncio.run(run()) /// ``` diff --git a/tests/resp_test.py b/tests/resp_test.py index 09d6dced..12596f89 100644 --- a/tests/resp_test.py +++ b/tests/resp_test.py @@ -1,11 +1,21 @@ import pytest import rnet from pathlib import Path -from rnet import Version, Multipart, Part +from rnet import Version, Multipart, Part, WebSocket, Message client = rnet.Client(tls_info=True) +@pytest.mark.asyncio +async def test_websocket(): + ws: WebSocket = await client.websocket("wss://echo.websocket.org") + await ws.recv() + await ws.send(Message.from_text("Hello, World!")) + message: Message = await ws.recv() + assert message.data == b"Hello, World!" + await ws.close() + + @pytest.mark.asyncio async def test_multiple_requests(): resp = await client.post(