Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions src/typing/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,25 @@ impl TryFrom<FromPyBody> for rquest::Body {
impl FromPyObject<'_> for FromPyBody {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(text) = ob.extract::<String>() {
Ok(Self::Text(Bytes::from(text)))
} else if let Ok(bytes) = ob.downcast::<PyBytes>() {
Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec())))
} else if let Ok(iter) = ob.extract::<PyObject>() {
Ok(Self::Iterator(Arc::new(ArcSwapOption::from_pointee(
SyncStream::new(iter),
))))
} else {
return Ok(Self::Text(Bytes::from(text)));
}

if let Ok(bytes) = ob.downcast::<PyBytes>() {
return Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec())));
}

if ob.hasattr("asend")? {
pyo3_async_runtimes::tokio::into_stream_v2(ob.to_owned())
.map(AsyncStream::new)
.map(ArcSwapOption::from_pointee)
.map(Arc::new)
.map(Self::Stream)
} else {
ob.extract::<PyObject>()
.map(SyncStream::new)
.map(ArcSwapOption::from_pointee)
.map(Arc::new)
.map(Self::Iterator)
}
}
}
28 changes: 18 additions & 10 deletions src/typing/multipart/part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,29 @@ impl Part {
impl FromPyObject<'_> for PartData {
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
if let Ok(text) = ob.extract::<String>() {
Ok(Self::Text(Bytes::from(text)))
} else if let Ok(bytes) = ob.downcast::<PyBytes>() {
Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec())))
} else if let Ok(path) = ob.extract::<PathBuf>() {
Ok(Self::File(path))
} else if let Ok(iter) = ob.extract::<PyObject>() {
Ok(Self::Iterator(Arc::new(ArcSwapOption::from_pointee(
SyncStream::new(iter),
))))
} else {
return Ok(Self::Text(Bytes::from(text)));
}

if let Ok(bytes) = ob.downcast::<PyBytes>() {
return Ok(Self::Bytes(Bytes::from(bytes.as_bytes().to_vec())));
}

if let Ok(path) = ob.extract::<PathBuf>() {
return Ok(Self::File(path));
}

if ob.hasattr("asend")? {
pyo3_async_runtimes::tokio::into_stream_v2(ob.to_owned())
.map(AsyncStream::new)
.map(ArcSwapOption::from_pointee)
.map(Arc::new)
.map(Self::Stream)
} else {
ob.extract::<PyObject>()
.map(SyncStream::new)
.map(ArcSwapOption::from_pointee)
.map(Arc::new)
.map(Self::Iterator)
}
}
}
16 changes: 15 additions & 1 deletion tests/request_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def test_send_bytes():

@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3, reruns_delay=2)
async def test_send_bytes_stream():
async def test_send_async_bytes_stream():
async def file_bytes_stream():
with open("README.md", "rb") as f:
while True:
Expand All @@ -85,3 +85,17 @@ async def file_bytes_stream():
response = await client.post(url, body=file_bytes_stream())
json = await response.json()
assert json["data"] in open("README.md").read()


@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3, reruns_delay=2)
async def test_send_sync_bytes_stream():
def file_to_bytes_stream(file_path):
with open(file_path, "rb") as f:
while chunk := f.read(1024):
yield chunk

url = "https://httpbin.org/post"
response = await client.post(url, body=file_to_bytes_stream("README.md"))
json = await response.json()
assert json["data"] in open("README.md").read()
8 changes: 4 additions & 4 deletions tests/response_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ async def file_to_bytes_stream(file_path):
Part(name="abc", value=b"000", filename="abc.txt", mime="text/plain"),
Part(
name="LICENSE",
value=Path("Cargo.toml"),
value=Path("./LICENSE"),
filename="LICENSE",
mime="text/plain",
),
Part(
name="README",
value=file_to_bytes_stream("./README.md"),
filename="README.md",
name="Cargo.toml",
value=file_to_bytes_stream("./Cargo.toml"),
filename="Cargo.toml",
mime="text/plain",
),
),
Expand Down