From 5c648aead60299f383e5dbd7397d84a7d87006ae Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Sun, 16 Mar 2025 11:09:39 +0800 Subject: [PATCH] perf(header): Use buffer protocol in `HeaderMap` to reduce copying --- src/buffer.rs | 58 +++++++++++++++++++++++++++++++++++++++++++ src/typing/headers.rs | 51 ++++++++++++++++++++++--------------- tests/client_test.py | 2 +- 3 files changed, 90 insertions(+), 21 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 6dbf8077..e2ef9568 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -19,6 +19,8 @@ use bytes::Bytes; use pyo3::IntoPyObjectExt; use pyo3::ffi; use pyo3::prelude::*; +use rquest::header::HeaderName; +use rquest::header::HeaderValue; use std::os::raw::c_int; /// A trait to define common buffer behavior @@ -98,6 +100,62 @@ impl BytesBuffer { } } +#[pyclass] +pub struct HeaderValueBuffer { + inner: HeaderValue, +} + +impl HeaderValueBuffer { + pub fn new(inner: HeaderValue) -> Self { + HeaderValueBuffer { inner } + } +} + +impl PyBufferProtocol<'_> for HeaderValueBuffer { + fn as_slice(&self) -> &[u8] { + self.inner.as_bytes() + } +} + +#[pymethods] +impl HeaderValueBuffer { + unsafe fn __getbuffer__( + slf: PyRefMut, + view: *mut ffi::Py_buffer, + flags: c_int, + ) -> PyResult<()> { + unsafe { fill_buffer_info(slf.as_slice(), slf.as_ptr(), view, flags, slf.py()) } + } +} + +#[pyclass] +pub struct HeaderNameBuffer { + inner: HeaderName, +} + +impl HeaderNameBuffer { + pub fn new(inner: HeaderName) -> Self { + HeaderNameBuffer { inner } + } +} + +impl PyBufferProtocol<'_> for HeaderNameBuffer { + fn as_slice(&self) -> &[u8] { + self.inner.as_ref() + } +} + +#[pymethods] +impl HeaderNameBuffer { + unsafe fn __getbuffer__( + slf: PyRefMut, + view: *mut ffi::Py_buffer, + flags: c_int, + ) -> PyResult<()> { + unsafe { fill_buffer_info(slf.as_slice(), slf.as_ptr(), view, flags, slf.py()) } + } +} + /// A helper function to fill buffer info unsafe fn fill_buffer_info( bytes: &[u8], diff --git a/src/typing/headers.rs b/src/typing/headers.rs index 99425435..0dc9ed8e 100644 --- a/src/typing/headers.rs +++ b/src/typing/headers.rs @@ -1,5 +1,5 @@ use crate::{ - buffer::{Buffer, PyBufferProtocol}, + buffer::{HeaderNameBuffer, HeaderValueBuffer, PyBufferProtocol}, error::{wrap_invali_header_name_error, wrap_invali_header_value_error}, }; use pyo3::{ @@ -22,28 +22,32 @@ impl HeaderMap { #[inline] fn __getitem__<'py>(&self, py: Python<'py>, key: PyBackedStr) -> Option> { let value = self.0.get(key.as_ref() as &str)?; - let buffer = Buffer::new(value.as_bytes().to_vec()); + let buffer = HeaderValueBuffer::new(value.clone()); buffer.into_bytes_ref(py).ok() } #[inline] - fn __setitem__(&mut self, key: PyBackedStr, value: PyBackedStr) { - if let (Ok(name), Ok(value)) = ( - HeaderName::from_bytes(key.as_bytes()), - HeaderValue::from_bytes(value.as_bytes()), - ) { - self.0.insert(name, value); - } + fn __setitem__(&mut self, py: Python, key: PyBackedStr, value: PyBackedStr) { + py.allow_threads(|| { + if let (Ok(name), Ok(value)) = ( + HeaderName::from_bytes(key.as_bytes()), + HeaderValue::from_bytes(value.as_bytes()), + ) { + self.0.insert(name, value); + } + }) } #[inline] - fn __delitem__(&mut self, key: PyBackedStr) { - self.0.remove(key.as_ref() as &str); + fn __delitem__(&mut self, py: Python, key: PyBackedStr) { + py.allow_threads(|| { + self.0.remove(key.as_ref() as &str); + }) } #[inline] - fn __contains__(&self, key: PyBackedStr) -> bool { - self.0.contains_key(key.as_ref() as &str) + fn __contains__(&self, py: Python, key: PyBackedStr) -> bool { + py.allow_threads(|| self.0.contains_key(key.as_ref() as &str)) } #[inline] @@ -87,13 +91,15 @@ pub struct HeaderMapKeysIter { #[pymethods] impl HeaderMapKeysIter { #[inline] - fn __iter__(slf: PyRefMut) -> PyRefMut { + fn __iter__<'py>(slf: PyRefMut<'py, Self>) -> PyRefMut<'py, Self> { slf } #[inline] - fn __next__(mut slf: PyRefMut) -> Option { - slf.inner.pop().map(|k| k.to_string()) + fn __next__(mut slf: PyRefMut) -> Option> { + slf.inner + .pop() + .and_then(|k| HeaderNameBuffer::new(k).into_bytes_ref(slf.py()).ok()) } } @@ -113,10 +119,15 @@ impl HeaderMapItemsIter { } #[inline] - fn __next__(mut slf: PyRefMut) -> Option<(String, Option)> { - slf.inner - .pop() - .map(|(k, v)| (k.to_string(), v.to_str().ok().map(String::from))) + fn __next__<'py>( + mut slf: PyRefMut<'py, Self>, + ) -> Option<(Bound<'py, PyAny>, Option>)> { + if let Some((k, v)) = slf.inner.pop() { + let key = HeaderNameBuffer::new(k).into_bytes_ref(slf.py()).ok()?; + let value = HeaderValueBuffer::new(v).into_bytes_ref(slf.py()).ok(); + return Some((key, value)); + } + None } } diff --git a/tests/client_test.py b/tests/client_test.py index afdafb86..24310750 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -9,7 +9,7 @@ async def test_update_headers(): client = rnet.Client() headers = {"user-agent": "rnet"} client.update(headers=headers) - assert client.headers == {"user-agent": b"rnet"} + assert client.headers["user-agent"] == b"rnet" @pytest.mark.asyncio