Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions examples/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,20 @@ async def send_message(ws):


async def receive_message(ws):
while True:
try:
message = await ws.recv()
print("Received: ", message)
if message.data == b"Message 20":
print("Closing connection...")
break
except asyncio.CancelledError:
# while True:
# try:
# message = await ws.recv()
# print("Received: ", message)
# if message.data == b"Message 20":
# print("Closing connection...")
# break
# except asyncio.CancelledError:
# break
# or
async for message in ws:
print("Received: ", message)
if message.data == b"Message 20":
print("Closing connection...")
break


Expand Down
112 changes: 82 additions & 30 deletions src/client/response/ws.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use std::sync::Arc;

use crate::{
error::wrap_rquest_error,
error::{py_stop_async_iteration_error, wrap_rquest_error},
types::{HeaderMap, Json, SocketAddr, StatusCode, Version},
};
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt, TryStreamExt,
};
use pyo3::{prelude::*, IntoPyObjectExt};
use pyo3::prelude::*;
use pyo3_async_runtimes::tokio::future_into_py;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use tokio::sync::Mutex;

type Sender = Arc<Mutex<SplitSink<rquest::WebSocket, rquest::Message>>>;
type Receiver = Arc<Mutex<SplitStream<rquest::WebSocket>>>;
type Sender = Arc<Mutex<Option<SplitSink<rquest::WebSocket, rquest::Message>>>>;
type Receiver = Arc<Mutex<Option<SplitStream<rquest::WebSocket>>>>;

/// A WebSocket response.
#[gen_stub_pyclass]
Expand Down Expand Up @@ -47,8 +47,8 @@ impl WebSocket {
remote_addr,
headers,
protocol,
sender: Arc::new(Mutex::new(sender)),
receiver: Arc::new(Mutex::new(receiver)),
sender: Arc::new(Mutex::new(Some(sender))),
receiver: Arc::new(Mutex::new(Some(receiver))),
})
}
}
Expand Down Expand Up @@ -132,9 +132,11 @@ impl WebSocket {
pub fn recv<'rt>(&self, py: Python<'rt>) -> PyResult<Bound<'rt, PyAny>> {
let websocket = self.receiver.clone();
future_into_py(py, async move {
let mut ws = websocket.lock().await;
if let Ok(Some(val)) = ws.try_next().await {
return Ok(Some(Message(val)));
let mut lock = websocket.lock().await;
if let Some(recv) = lock.as_mut() {
if let Ok(Some(val)) = recv.try_next().await {
return Ok(Some(Message(val)));
}
}
Ok(None)
})
Expand All @@ -154,8 +156,11 @@ impl WebSocket {
pub fn send<'rt>(&self, py: Python<'rt>, message: Message) -> PyResult<Bound<'rt, PyAny>> {
let sender = self.sender.clone();
future_into_py(py, async move {
let mut ws = sender.lock().await;
ws.send(message.0).await.map_err(wrap_rquest_error)
let mut lock = sender.lock().await;
if let Some(send) = lock.as_mut() {
return send.send(message.0).await.map_err(wrap_rquest_error);
}
Ok(())
})
}

Expand All @@ -178,35 +183,82 @@ impl WebSocket {
reason: Option<String>,
) -> PyResult<Bound<'rt, PyAny>> {
let sender = self.sender.clone();
let receiver = self.receiver.clone();
future_into_py(py, async move {
let mut sender = sender.lock().await;
sender
.send(rquest::Message::Close {
code: rquest::CloseCode::from(code.unwrap_or_default()),
reason,
})
.await
.map_err(wrap_rquest_error)?;
let mut lock = receiver.lock().await;
drop(lock.take());
drop(lock);

let mut lock = sender.lock().await;
let send = lock.take();
drop(lock);

if let Some(mut send) = send {
if let Some(code) = code {
send.send(rquest::Message::Close {
code: rquest::CloseCode::from(code),
reason,
})
.await
.map_err(wrap_rquest_error)?;
}
return send.close().await.map_err(wrap_rquest_error);
}

Ok(())
})
}
}

#[pymethods]
impl WebSocket {
fn __aenter__<'a>(slf: PyRef<'a, Self>, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
let slf = slf.into_py_any(py)?;
future_into_py(py, async move { Ok(slf) })
/// Returns the WebSocket instance itself as an asynchronous iterator.
///
/// This method is used to make the WebSocket instance iterable in an asynchronous context.
///
/// # Arguments
///
/// * `slf` - A reference to the WebSocket instance.
///
/// # Returns
///
/// Returns the WebSocket instance itself.
#[inline(always)]
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

fn __aexit__<'a>(
&'a mut self,
py: Python<'a>,
_exc_type: &Bound<'a, PyAny>,
_exc_value: &Bound<'a, PyAny>,
_traceback: &Bound<'a, PyAny>,
) -> PyResult<Bound<'a, PyAny>> {
self.close(py, None, None)
/// Returns the next message from the WebSocket.
///
/// This method is used to retrieve the next message from the WebSocket in an asynchronous iteration.
///
/// # Arguments
///
/// * `py` - The Python runtime.
///
/// # Returns
///
/// Returns a `PyResult` containing an `Option` with a `Bound` object representing the received message.
/// If no message is received, returns `None`.
fn __anext__<'rt>(&self, py: Python<'rt>) -> PyResult<Option<Bound<'rt, PyAny>>> {
let recv = self.receiver.clone();
future_into_py(py, async move {
// Here we lock the mutex to access the data inside
// and call try_next() method to get the next value.
let mut lock = recv.lock().await;
let recv = lock
.as_mut()
.ok_or_else(py_stop_async_iteration_error)?
.try_next()
.await;

drop(lock);

recv.map(|val| val.map(Message))
.map(Some)
.map_err(wrap_rquest_error)
})
.map(Some)
}
}

Expand Down
10 changes: 6 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,14 @@ fn request(
/// ```python
/// import rnet
/// import asyncio
/// from rnet import Message
///
/// async def run():
/// async with rnet.websocket("wss://echo.websocket.org") as ws:
/// await ws.send("Hello, World!")
/// message = await ws.recv()
/// print(message)
/// ws = await rnet.websocket("wss://echo.websocket.org")
/// await ws.send(Message.from_text("Hello, World!"))
/// message = await ws.recv()
/// print("Received:", message.data)
/// await ws.close()
///
/// asyncio.run(run())
/// ```
Expand Down
12 changes: 11 additions & 1 deletion tests/resp_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import pytest
import rnet
from pathlib import Path
from rnet import Version, Multipart, Part
from rnet import Version, Multipart, Part, WebSocket, Message

client = rnet.Client(tls_info=True)


@pytest.mark.asyncio
async def test_websocket():
ws: WebSocket = await client.websocket("wss://echo.websocket.org")
await ws.recv()
await ws.send(Message.from_text("Hello, World!"))
message: Message = await ws.recv()
assert message.data == b"Hello, World!"
await ws.close()


@pytest.mark.asyncio
async def test_multiple_requests():
resp = await client.post(
Expand Down