From 4394d80bf315caccbd838e938d0722cbc45487e5 Mon Sep 17 00:00:00 2001 From: 0x676e67 Date: Thu, 6 Mar 2025 12:37:11 +0800 Subject: [PATCH] fix(stream): fix asynchronous stream sending --- src/typing/body.rs | 22 ++++++++++++++-------- src/typing/multipart/part.rs | 28 ++++++++++++++++++---------- tests/request_test.py | 16 +++++++++++++++- tests/response_test.py | 8 ++++---- 4 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/typing/body.rs b/src/typing/body.rs index 9c6c2503..03bd70df 100644 --- a/src/typing/body.rs +++ b/src/typing/body.rs @@ -38,19 +38,25 @@ impl TryFrom for rquest::Body { impl FromPyObject<'_> for FromPyBody { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { if let Ok(text) = ob.extract::() { - Ok(Self::Text(Bytes::from(text))) - } else if let Ok(bytes) = ob.downcast::() { - Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec()))) - } else if let Ok(iter) = ob.extract::() { - Ok(Self::Iterator(Arc::new(ArcSwapOption::from_pointee( - SyncStream::new(iter), - )))) - } else { + return Ok(Self::Text(Bytes::from(text))); + } + + if let Ok(bytes) = ob.downcast::() { + return Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec()))); + } + + if ob.hasattr("asend")? { pyo3_async_runtimes::tokio::into_stream_v2(ob.to_owned()) .map(AsyncStream::new) .map(ArcSwapOption::from_pointee) .map(Arc::new) .map(Self::Stream) + } else { + ob.extract::() + .map(SyncStream::new) + .map(ArcSwapOption::from_pointee) + .map(Arc::new) + .map(Self::Iterator) } } } diff --git a/src/typing/multipart/part.rs b/src/typing/multipart/part.rs index 3a7cb779..6bfc07cc 100644 --- a/src/typing/multipart/part.rs +++ b/src/typing/multipart/part.rs @@ -103,21 +103,29 @@ impl Part { impl FromPyObject<'_> for PartData { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { if let Ok(text) = ob.extract::() { - Ok(Self::Text(Bytes::from(text))) - } else if let Ok(bytes) = ob.downcast::() { - Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec()))) - } else if let Ok(path) = ob.extract::() { - Ok(Self::File(path)) - } else if let Ok(iter) = ob.extract::() { - Ok(Self::Iterator(Arc::new(ArcSwapOption::from_pointee( - SyncStream::new(iter), - )))) - } else { + return Ok(Self::Text(Bytes::from(text))); + } + + if let Ok(bytes) = ob.downcast::() { + return Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec()))); + } + + if let Ok(path) = ob.extract::() { + return Ok(Self::File(path)); + } + + if ob.hasattr("asend")? { pyo3_async_runtimes::tokio::into_stream_v2(ob.to_owned()) .map(AsyncStream::new) .map(ArcSwapOption::from_pointee) .map(Arc::new) .map(Self::Stream) + } else { + ob.extract::() + .map(SyncStream::new) + .map(ArcSwapOption::from_pointee) + .map(Arc::new) + .map(Self::Iterator) } } } diff --git a/tests/request_test.py b/tests/request_test.py index 3195d705..0d7a74fd 100644 --- a/tests/request_test.py +++ b/tests/request_test.py @@ -72,7 +72,7 @@ async def test_send_bytes(): @pytest.mark.asyncio @pytest.mark.flaky(reruns=3, reruns_delay=2) -async def test_send_bytes_stream(): +async def test_send_async_bytes_stream(): async def file_bytes_stream(): with open("README.md", "rb") as f: while True: @@ -85,3 +85,17 @@ async def file_bytes_stream(): response = await client.post(url, body=file_bytes_stream()) json = await response.json() assert json["data"] in open("README.md").read() + + +@pytest.mark.asyncio +@pytest.mark.flaky(reruns=3, reruns_delay=2) +async def test_send_sync_bytes_stream(): + def file_to_bytes_stream(file_path): + with open(file_path, "rb") as f: + while chunk := f.read(1024): + yield chunk + + url = "https://httpbin.org/post" + response = await client.post(url, body=file_to_bytes_stream("README.md")) + json = await response.json() + assert json["data"] in open("README.md").read() diff --git a/tests/response_test.py b/tests/response_test.py index 63faf711..a09d81d7 100644 --- a/tests/response_test.py +++ b/tests/response_test.py @@ -32,14 +32,14 @@ async def file_to_bytes_stream(file_path): Part(name="abc", value=b"000", filename="abc.txt", mime="text/plain"), Part( name="LICENSE", - value=Path("Cargo.toml"), + value=Path("./LICENSE"), filename="LICENSE", mime="text/plain", ), Part( - name="README", - value=file_to_bytes_stream("./README.md"), - filename="README.md", + name="Cargo.toml", + value=file_to_bytes_stream("./Cargo.toml"), + filename="Cargo.toml", mime="text/plain", ), ),