diff --git a/Cargo.toml b/Cargo.toml index ac293223..d23b1b78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,8 @@ mime = "0.3.17" indexmap = { version = "2.7.0", features = ["serde"] } cookie = "0.18.0" arc-swap = "1.7.1" -rquest = { version = "2.1.0", features = ["full"] } +rquest = { version = "2.1.0", features = ["full", "websocket"] } +futures-util = { version = "0.3.0", default-features = false } [profile.release] lto = true diff --git a/README.md b/README.md index 107f6f19..1b3494af 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Asynchronous Python HTTP Client with Black Magic, powered by FFI from [rquest](h - Redirect Policy - Cookie Store - HTTP Proxies +- WebSocket Upgrade - HTTPS via BoringSSL - Perfectly Chrome, Safari, and Firefox diff --git a/examples/ws.py b/examples/ws.py new file mode 100644 index 00000000..af7669a7 --- /dev/null +++ b/examples/ws.py @@ -0,0 +1,51 @@ +import asyncio +import signal +import rnet +from rnet import Message + + +async def send_message(ws): + for i in range(20): + print(f"Sending: Message {i + 1}") + await ws.send(Message.from_text(f"Message {i + 1}")) + await asyncio.sleep(1) + + +async def receive_message(ws): + while True: + try: + message = await ws.recv() + print("Received: ", message) + if message.data == b"Message 20": + print("Closing connection...") + break + except asyncio.CancelledError: + break + + +async def main(): + resp = await rnet.websocket("wss://echo.websocket.org") + print("Status Code: ", resp.status) + print("Version: ", resp.version) + print("Headers: ", resp.headers.to_dict()) + print("Remote Address: ", resp.remote_addr) + + ws = await resp.into_websocket() + + send_task = asyncio.create_task(send_message(ws)) + receive_task = asyncio.create_task(receive_message(ws)) + + async def close_ws(): + await ws.close() + send_task.cancel() + receive_task.cancel() + + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, lambda: asyncio.create_task(close_ws())) + + await asyncio.gather(send_task, receive_task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rnet.pyi b/rnet.pyi index 4a5c7a78..96aec836 100644 --- a/rnet.pyi +++ b/rnet.pyi @@ -272,6 +272,12 @@ class Client: """ ... + def websocket(self, url:builtins.str, **kwds) -> typing.Any: + r""" + Sends a WebSocket request. + """ + ... + class ClientParams: r""" @@ -413,6 +419,109 @@ class ImpersonateOS: ... +class Message: + r""" + A WebSocket message. + """ + text: typing.Optional[builtins.str] + close: typing.Optional[tuple[builtins.int, typing.Optional[builtins.str]]] + def __str__(self) -> builtins.str: + r""" + Returns a string representation of the message. + + # Returns + + A string representing the message. + """ + ... + + def __repr__(self) -> builtins.str: + r""" + Returns a string representation of the message. + + # Returns + + A string representing the message. + """ + ... + + @staticmethod + def from_text(text:builtins.str) -> Message: + r""" + Creates a new text message. + + # Arguments + + * `text` - The text content of the message. + + # Returns + + A new `Message` instance containing the text message. + """ + ... + + @staticmethod + def from_binary(data:typing.Sequence[builtins.int]) -> Message: + r""" + Creates a new binary message. + + # Arguments + + * `data` - The binary data of the message. + + # Returns + + A new `Message` instance containing the binary message. + """ + ... + + @staticmethod + def from_ping(data:typing.Sequence[builtins.int]) -> Message: + r""" + Creates a new ping message. + + # Arguments + + * `data` - The ping data of the message. + + # Returns + + A new `Message` instance containing the ping message. + """ + ... + + @staticmethod + def from_pong(data:typing.Sequence[builtins.int]) -> Message: + r""" + Creates a new pong message. + + # Arguments + + * `data` - The pong data of the message. + + # Returns + + A new `Message` instance containing the pong message. + """ + ... + + @staticmethod + def from_close(code:builtins.int, reason:typing.Optional[builtins.str]=None) -> Message: + r""" + Creates a new close message. + + # Arguments + + * `code` - The close code. + * `reason` - An optional reason for closing. + + # Returns + + A new `Message` instance containing the close message. + """ + ... + + class Method: r""" A HTTP method. @@ -822,6 +931,80 @@ class Version: ... +class WebSocket: + r""" + A WebSocket connection. + """ + ... + +class WebSocketParams: + r""" + The parameters for a WebSocket request. + + # Examples + + ```python + import rnet + from rnet import Impersonate, Version + + params = rnet.WebSocketParams( + proxy="http://proxy.example.com", + local_address="192.168.1.1", + interface="eth0", + headers={"Content-Type": "application/json"}, + auth="Basic dXNlcjpwYXNzd29yZA==", + bearer_auth="Bearer token", + basic_auth=("user", "password"), + query=[("key1", "value1"), ("key2", "value2")] + ) + + async with rnet.websocket("wss://echo.websocket.org") as ws: + await ws.send("Hello, World!") + message = await ws.recv() + print(message) + + asyncio.run(run()) + ``` + """ + proxy: typing.Optional[builtins.str] + interface: typing.Optional[builtins.str] + auth: typing.Optional[builtins.str] + bearer_auth: typing.Optional[builtins.str] + basic_auth: typing.Optional[tuple[builtins.str, typing.Optional[builtins.str]]] + query: typing.Optional[builtins.list[tuple[builtins.str, builtins.str]]] + +class WebSocketResponse: + r""" + A WebSocket response. + """ + ok: builtins.bool + status: builtins.int + version: Version + headers: HeaderMap + remote_addr: typing.Optional[SocketAddr] + def peer_certificate(self) -> typing.Optional[builtins.list[builtins.int]]: + r""" + Returns the TLS peer certificate of the response. + + # Returns + + A Python object representing the TLS peer certificate of the response. + """ + ... + + def into_websocket(self) -> typing.Any: + r""" + Returns the WebSocket of the response. + """ + ... + + def close(self) -> None: + r""" + Closes the response connection. + """ + ... + + def delete(url:builtins.str, **kwds) -> typing.Any: r""" Shortcut method to quickly make a `DELETE` request. @@ -1018,3 +1201,26 @@ def trace(url:builtins.str, **kwds) -> typing.Any: """ ... +def websocket(url:builtins.str, **kwds) -> typing.Any: + r""" + Make a WebSocket connection with the given parameters. + + This function allows you to make a WebSocket connection with the specified parameters encapsulated in a `WebSocket` object. + + # Examples + + ```python + import rnet + import asyncio + + async def run(): + async with rnet.websocket("wss://echo.websocket.org") as ws: + await ws.send("Hello, World!") + message = await ws.recv() + print(message) + + asyncio.run(run()) + ``` + """ + ... + diff --git a/src/client.rs b/src/client.rs index 181f7742..ce31d1b6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,7 @@ use crate::{ error::{wrap_invali_header_name_error, wrap_rquest_error}, - param::{ClientParams, RequestParams}, - response::Response, + param::{ClientParams, RequestParams, WebSocketParams}, + response::{Response, WebSocketResponse}, types::Method, Result, }; @@ -581,6 +581,18 @@ impl Client { let client = self.0.clone(); pyo3_async_runtimes::tokio::future_into_py(py, execute_request(client, method, url, kwds)) } + + /// Sends a WebSocket request. + #[pyo3(signature = (url, **kwds))] + pub fn websocket<'rt>( + &self, + py: Python<'rt>, + url: String, + kwds: Option, + ) -> PyResult> { + let client = self.0.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, execute_websocket_request(client, url, kwds)) + } } /// Executes an HTTP request. @@ -672,3 +684,55 @@ async fn execute_request( .map(Response::from) .map_err(wrap_rquest_error) } + +/// Executes a WebSocket request. +async fn execute_websocket_request( + client: rquest::Client, + url: String, + mut params: Option, +) -> Result { + let params = params.get_or_insert_default(); + let mut builder = client.websocket(url); + + // The protocols to use for the request. + apply_option!(apply_if_some, builder, params.protocols, protocols); + + // The origin to use for the request. + builder = builder.with_builder(|mut builder| { + // Network options. + apply_option!(apply_if_some, builder, params.proxy, proxy); + apply_option!(apply_if_some, builder, params.local_address, local_address); + rquest::cfg_bindable_device!( + apply_option!(apply_if_some, builder, params.interface, interface); + ); + + // Authentication options. + apply_option!(apply_if_some, builder, params.auth, auth); + + // Bearer authentication options. + apply_option!(apply_if_some, builder, params.bearer_auth, bearer_auth); + + // Basic authentication options. + if let Some(basic_auth) = params.basic_auth.take() { + builder = builder.basic_auth(basic_auth.0, basic_auth.1); + } + + // Headers options. + if let Some(headers) = params.headers.take() { + for (key, value) in headers { + builder = builder.header(key, value); + } + } + + // Query options. + apply_option!(apply_if_some_ref, builder, params.query, query); + + builder + }); + + builder + .send() + .await + .map(WebSocketResponse::from) + .map_err(wrap_rquest_error) +} diff --git a/src/lib.rs b/src/lib.rs index acf5f776..1262c969 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,10 @@ mod response; mod types; use client::Client; -use param::{ClientParams, RequestParams}; +use param::{ClientParams, RequestParams, WebSocketParams}; use pyo3::prelude::*; use pyo3_stub_gen::{define_stub_info_gatherer, derive::*}; -use response::{Response, Streamer}; +use response::{Message, Response, Streamer, WebSocket, WebSocketResponse}; use types::{ HeaderMap, Impersonate, ImpersonateOS, Method, Proxy, SocketAddr, StatusCode, Version, }; @@ -243,6 +243,36 @@ fn request( Client::default().request(py, method, url, kwds) } +/// Make a WebSocket connection with the given parameters. +/// +/// This function allows you to make a WebSocket connection with the specified parameters encapsulated in a `WebSocket` object. +/// +/// # Examples +/// +/// ```python +/// import rnet +/// import asyncio +/// +/// async def run(): +/// async with rnet.websocket("wss://echo.websocket.org") as ws: +/// await ws.send("Hello, World!") +/// message = await ws.recv() +/// print(message) +/// +/// asyncio.run(run()) +/// ``` +#[gen_stub_pyfunction] +#[pyfunction] +#[pyo3(signature = (url, **kwds))] +#[inline(always)] +fn websocket( + py: Python<'_>, + url: String, + kwds: Option, +) -> PyResult> { + Client::default().websocket(py, url, kwds) +} + #[pymodule(gil_used = false)] fn rnet(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; @@ -254,11 +284,15 @@ fn rnet(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_function(wrap_pyfunction!(request, m)?)?; + m.add_function(wrap_pyfunction!(get, m)?)?; m.add_function(wrap_pyfunction!(post, m)?)?; m.add_function(wrap_pyfunction!(put, m)?)?; @@ -267,6 +301,8 @@ fn rnet(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(head, m)?)?; m.add_function(wrap_pyfunction!(options, m)?)?; m.add_function(wrap_pyfunction!(trace, m)?)?; + m.add_function(wrap_pyfunction!(request, m)?)?; + m.add_function(wrap_pyfunction!(websocket, m)?)?; Ok(()) } diff --git a/src/param/mod.rs b/src/param/mod.rs index 934c5975..04a734ee 100644 --- a/src/param/mod.rs +++ b/src/param/mod.rs @@ -1,5 +1,7 @@ mod client; mod request; +mod websocket; pub use self::client::ClientParams; pub use self::request::RequestParams; +pub use self::websocket::WebSocketParams; diff --git a/src/param/websocket.rs b/src/param/websocket.rs new file mode 100644 index 00000000..335cb466 --- /dev/null +++ b/src/param/websocket.rs @@ -0,0 +1,94 @@ +use std::net::IpAddr; + +use indexmap::IndexMap; +use pyo3::prelude::*; +use pyo3_stub_gen::derive::gen_stub_pyclass; + +/// The parameters for a WebSocket request. +/// +/// # Examples +/// +/// ```python +/// import rnet +/// from rnet import Impersonate, Version +/// +/// params = rnet.WebSocketParams( +/// proxy="http://proxy.example.com", +/// local_address="192.168.1.1", +/// interface="eth0", +/// headers={"Content-Type": "application/json"}, +/// auth="Basic dXNlcjpwYXNzd29yZA==", +/// bearer_auth="Bearer token", +/// basic_auth=("user", "password"), +/// query=[("key1", "value1"), ("key2", "value2")] +/// ) +/// +/// async with rnet.websocket("wss://echo.websocket.org") as ws: +/// await ws.send("Hello, World!") +/// message = await ws.recv() +/// print(message) +/// +/// asyncio.run(run()) +/// ``` +#[gen_stub_pyclass] +#[pyclass] +#[derive(Default, Debug)] +pub struct WebSocketParams { + /// The proxy to use for the request. + #[pyo3(get)] + pub proxy: Option, + + /// Bind to a local IP Address. + pub local_address: Option, + + /// Bind to an interface by `SO_BINDTODEVICE`. + #[pyo3(get)] + pub interface: Option, + + /// The headers to use for the request. + pub headers: Option>, + + /// The protocols to use for the request. + pub protocols: Option>, + + /// The authentication to use for the request. + #[pyo3(get)] + pub auth: Option, + + /// The bearer authentication to use for the request. + #[pyo3(get)] + pub bearer_auth: Option, + + /// The basic authentication to use for the request. + #[pyo3(get)] + pub basic_auth: Option<(String, Option)>, + + /// The query parameters to use for the request. + #[pyo3(get)] + pub query: Option>, +} + +macro_rules! extract_option { + ($ob:expr, $params:expr, $field:ident) => { + if let Ok(value) = $ob.get_item(stringify!($field)) { + $params.$field = value.extract()?; + } + }; +} + +impl<'py> FromPyObject<'py> for WebSocketParams { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let mut params = Self::default(); + extract_option!(ob, params, proxy); + extract_option!(ob, params, local_address); + extract_option!(ob, params, interface); + + extract_option!(ob, params, headers); + extract_option!(ob, params, protocols); + extract_option!(ob, params, auth); + extract_option!(ob, params, bearer_auth); + extract_option!(ob, params, basic_auth); + extract_option!(ob, params, query); + Ok(params) + } +} diff --git a/src/response.rs b/src/response/http.rs similarity index 100% rename from src/response.rs rename to src/response/http.rs diff --git a/src/response/mod.rs b/src/response/mod.rs new file mode 100644 index 00000000..fc1051f2 --- /dev/null +++ b/src/response/mod.rs @@ -0,0 +1,5 @@ +mod http; +mod ws; + +pub use http::{Response, Streamer}; +pub use ws::{Message, WebSocket, WebSocketResponse}; diff --git a/src/response/ws.rs b/src/response/ws.rs new file mode 100644 index 00000000..83b65e6a --- /dev/null +++ b/src/response/ws.rs @@ -0,0 +1,503 @@ +use std::sync::Arc; + +use crate::{ + error::{memory_error, wrap_rquest_error}, + types::{HeaderMap, Json, SocketAddr, StatusCode, Version}, +}; +use arc_swap::ArcSwapOption; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, TryStreamExt, +}; +use pyo3::{prelude::*, IntoPyObjectExt}; +use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; +use rquest::TlsInfo; +use tokio::sync::Mutex; + +/// A WebSocket response. +#[gen_stub_pyclass] +#[pyclass] +pub struct WebSocketResponse { + version: Version, + status_code: StatusCode, + remote_addr: Option, + headers: HeaderMap, + response: ArcSwapOption, +} + +impl From for WebSocketResponse { + fn from(response: rquest::WebSocketResponse) -> Self { + WebSocketResponse { + version: Version::from(response.version()), + status_code: StatusCode::from(response.status()), + remote_addr: response.remote_addr().map(SocketAddr::from), + headers: HeaderMap::from(response.headers().clone()), + response: ArcSwapOption::from_pointee(response), + } + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl WebSocketResponse { + /// Returns whether the response is successful. + /// + /// # Returns + /// + /// A boolean indicating whether the response is successful. + #[getter] + #[inline(always)] + pub fn ok(&self) -> bool { + self.status_code.as_int() == rquest::StatusCode::SWITCHING_PROTOCOLS + } + + /// Returns the status code as integer of the response. + /// + /// # Returns + /// + /// An integer representing the HTTP status code. + #[getter] + #[inline(always)] + pub fn status(&self) -> u16 { + self.status_code.as_int() + } + + /// Returns the HTTP version of the response. + /// + /// # Returns + /// + /// A `Version` object representing the HTTP version of the response. + #[getter] + #[inline(always)] + pub fn version(&self) -> Version { + self.version + } + + /// Returns the headers of the response. + /// + /// # Returns + /// + /// A `HeaderMap` object representing the headers of the response. + #[getter] + #[inline(always)] + pub fn headers(&self) -> HeaderMap { + self.headers.clone() + } + + /// Returns the remote address of the response. + /// + /// # Returns + /// + /// An `IpAddr` object representing the remote address of the response. + #[getter] + #[inline(always)] + pub fn remote_addr(&self) -> Option { + self.remote_addr + } + + /// Returns the TLS peer certificate of the response. + /// + /// # Returns + /// + /// A Python object representing the TLS peer certificate of the response. + pub fn peer_certificate(&self) -> PyResult>> { + let resp_ref = self.response.load(); + let resp = resp_ref.as_ref().ok_or_else(memory_error)?; + if let Some(val) = resp.extensions().get::() { + return Ok(val.peer_certificate().map(ToOwned::to_owned)); + } + + Ok(None) + } + + /// Returns the WebSocket of the response. + pub fn into_websocket<'rt>(&self, py: Python<'rt>) -> PyResult> { + let response = self.into_inner()?; + pyo3_async_runtimes::tokio::future_into_py(py, async move { + response + .into_websocket() + .await + .map(WebSocket::from) + .map_err(wrap_rquest_error) + }) + } + + /// Closes the response connection. + pub fn close(&self) { + let _ = self.into_inner().map(drop); + } +} + +impl WebSocketResponse { + /// Consumes the `WebSocketResponse` and returns the inner `rquest::RespWebSocketResponseonse`. + /// + /// # Returns + /// + /// A `PyResult` containing the inner `rquest::WebSocketResponse` if successful, or an error if the + /// response has already been taken or cannot be unwrapped. + /// + /// # Errors + /// + /// Returns a memory error if the response has already been taken or if the `Arc` cannot be unwrapped. + #[inline(always)] + #[allow(clippy::wrong_self_convention)] + fn into_inner(&self) -> PyResult { + self.response + .swap(None) + .and_then(Arc::into_inner) + .ok_or_else(memory_error) + } +} + +type Sender = Arc>>; +type Receiver = Arc>>; + +/// A WebSocket connection. +#[gen_stub_pyclass] +#[pyclass] +pub struct WebSocket { + protocol: Option, + sender: Sender, + receiver: Receiver, +} + +impl From for WebSocket { + fn from(ws: rquest::WebSocket) -> Self { + let protocol = ws.protocol().map(ToOwned::to_owned); + let (sender, receiver) = ws.split(); + WebSocket { + protocol, + sender: Arc::new(Mutex::new(sender)), + receiver: Arc::new(Mutex::new(receiver)), + } + } +} + +#[pymethods] +impl WebSocket { + /// Returns the WebSocket protocol. + /// + /// # Returns + /// + /// An optional string representing the WebSocket protocol. + pub fn protocol(&self) -> Option<&str> { + self.protocol.as_deref() + } + + /// Receives a message from the WebSocket. + /// + /// # Arguments + /// + /// * `py` - The Python runtime. + /// + /// # Returns + /// + /// A `PyResult` containing a `Bound` object with the received message, or `None` if no message is received. + pub fn recv<'rt>(&self, py: Python<'rt>) -> PyResult> { + let websocket = self.receiver.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut ws = websocket.lock().await; + if let Ok(Some(val)) = ws.try_next().await { + return Ok(Some(Message(val))); + } + Ok(None) + }) + } + + /// Sends a message to the WebSocket. + /// + /// # Arguments + /// + /// * `py` - The Python runtime. + /// * `message` - The message to send. + /// + /// # Returns + /// + /// A `PyResult` containing a `Bound` object. + #[pyo3(signature = (message))] + pub fn send<'rt>(&self, py: Python<'rt>, message: Message) -> PyResult> { + let sender = self.sender.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut ws = sender.lock().await; + ws.send(message.0).await.map_err(wrap_rquest_error) + }) + } + + /// Closes the WebSocket connection. + /// + /// # Arguments + /// + /// * `py` - The Python runtime. + /// * `code` - An optional close code. + /// * `reason` - An optional reason for closing. + /// + /// # Returns + /// + /// A `PyResult` containing a `Bound` object. + #[pyo3(signature = (code=None, reason=None))] + pub fn close<'rt>( + &self, + py: Python<'rt>, + code: Option, + reason: Option, + ) -> PyResult> { + let sender = self.sender.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut sender = sender.lock().await; + sender + .send(rquest::Message::Close { + code: rquest::CloseCode::from(code.unwrap_or_default()), + reason, + }) + .await + .map_err(wrap_rquest_error)?; + Ok(()) + }) + } +} + +#[pymethods] +impl WebSocket { + fn __aenter__<'a>(slf: PyRef<'a, Self>, py: Python<'a>) -> PyResult> { + let slf = slf.into_py_any(py)?; + pyo3_async_runtimes::tokio::future_into_py(py, async move { Ok(slf) }) + } + + fn __aexit__<'a>( + &'a mut self, + py: Python<'a>, + _exc_type: &Bound<'a, PyAny>, + _exc_value: &Bound<'a, PyAny>, + _traceback: &Bound<'a, PyAny>, + ) -> PyResult> { + self.close(py, None, None) + } +} + +/// A WebSocket message. +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct Message(rquest::Message); + +#[pymethods] +impl Message { + /// Returns the JSON representation of the message. + /// + /// # Returns + /// + /// A `PyResult` containing the JSON representation of the message. + pub fn json(&self) -> PyResult { + self.0.json::().map_err(wrap_rquest_error) + } + + /// Returns the data of the message as bytes. + /// + /// # Returns + /// + /// A byte slice representing the data of the message. + #[getter] + pub fn data(&self) -> &[u8] { + match &self.0 { + rquest::Message::Text(text) => text.as_bytes(), + rquest::Message::Binary(data) => data, + rquest::Message::Ping(data) => data, + rquest::Message::Pong(data) => data, + _ => &[], + } + } + + /// Returns the binary data of the message if it is a binary message. + /// + /// # Returns + /// + /// An optional byte slice representing the binary data of the message. + #[getter] + pub fn binary(&self) -> Option<&[u8]> { + match &self.0 { + rquest::Message::Binary(data) => Some(data), + _ => None, + } + } + + /// Returns the ping data of the message if it is a ping message. + /// + /// # Returns + /// + /// An optional byte slice representing the ping data of the message. + #[getter] + pub fn ping(&self) -> Option<&[u8]> { + match &self.0 { + rquest::Message::Ping(data) => Some(data), + _ => None, + } + } + + /// Returns the pong data of the message if it is a pong message. + /// + /// # Returns + /// + /// An optional byte slice representing the pong data of the message. + #[getter] + pub fn pong(&self) -> Option<&[u8]> { + match &self.0 { + rquest::Message::Pong(data) => Some(data), + _ => None, + } + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl Message { + /// Returns a string representation of the message. + /// + /// # Returns + /// + /// A string representing the message. + fn __str__(&self) -> String { + format!("{:?}", self.0) + } + + /// Returns a string representation of the message. + /// + /// # Returns + /// + /// A string representing the message. + fn __repr__(&self) -> String { + self.__str__() + } + + /// Returns the text of the message if it is a text message. + /// + /// # Returns + /// + /// An optional string representing the text of the message. + #[getter] + pub fn text(&self) -> Option<&str> { + match &self.0 { + rquest::Message::Text(text) => Some(text), + _ => None, + } + } + + /// Returns the close code and reason of the message if it is a close message. + /// + /// # Returns + /// + /// An optional tuple containing the close code and reason. + #[getter] + pub fn close(&self) -> Option<(u16, Option<&str>)> { + match &self.0 { + rquest::Message::Close { code, reason } => Some(( + match *code { + rquest::CloseCode::Normal => 1000, + rquest::CloseCode::Away => 1001, + rquest::CloseCode::Protocol => 1002, + rquest::CloseCode::Unsupported => 1003, + rquest::CloseCode::Status => 1005, + rquest::CloseCode::Abnormal => 1006, + rquest::CloseCode::Invalid => 1007, + rquest::CloseCode::Policy => 1008, + rquest::CloseCode::Size => 1009, + rquest::CloseCode::Extension => 1010, + rquest::CloseCode::Error => 1011, + rquest::CloseCode::Restart => 1012, + rquest::CloseCode::Again => 1013, + rquest::CloseCode::Tls => 1015, + rquest::CloseCode::Reserved(v) + | rquest::CloseCode::Iana(v) + | rquest::CloseCode::Library(v) + | rquest::CloseCode::Bad(v) => v, + _ => return None, + }, + reason.as_deref(), + )), + _ => None, + } + } + + /// Creates a new text message. + /// + /// # Arguments + /// + /// * `text` - The text content of the message. + /// + /// # Returns + /// + /// A new `Message` instance containing the text message. + #[staticmethod] + #[pyo3(signature = (text))] + #[inline] + pub fn from_text(text: &str) -> Self { + Message(rquest::Message::Text(text.to_owned())) + } + + /// Creates a new binary message. + /// + /// # Arguments + /// + /// * `data` - The binary data of the message. + /// + /// # Returns + /// + /// A new `Message` instance containing the binary message. + #[staticmethod] + #[pyo3(signature = (data))] + #[inline] + pub fn from_binary(data: Vec) -> Self { + Message(rquest::Message::Binary(data)) + } + + /// Creates a new ping message. + /// + /// # Arguments + /// + /// * `data` - The ping data of the message. + /// + /// # Returns + /// + /// A new `Message` instance containing the ping message. + #[staticmethod] + #[pyo3(signature = (data))] + #[inline] + pub fn from_ping(data: Vec) -> Self { + Message(rquest::Message::Ping(data)) + } + + /// Creates a new pong message. + /// + /// # Arguments + /// + /// * `data` - The pong data of the message. + /// + /// # Returns + /// + /// A new `Message` instance containing the pong message. + #[staticmethod] + #[pyo3(signature = (data))] + #[inline] + pub fn from_pong(data: Vec) -> Self { + Message(rquest::Message::Pong(data)) + } + + /// Creates a new close message. + /// + /// # Arguments + /// + /// * `code` - The close code. + /// * `reason` - An optional reason for closing. + /// + /// # Returns + /// + /// A new `Message` instance containing the close message. + #[staticmethod] + #[pyo3(signature = (code, reason=None))] + #[inline] + pub fn from_close(code: u16, reason: Option) -> Self { + Message(rquest::Message::Close { + code: rquest::CloseCode::from(code), + reason, + }) + } +}