Skip to content

Commit

Permalink
Add Tokio halves support
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-smoktal committed May 13, 2024
1 parent aace5b3 commit 4d15a43
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "async-send-fd"
version = "1.0.1"
version = "1.1.0"
edition = "2021"
authors = ["Alexander Smoktal [https://github.com/alexander-smoktal]"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Smol [UnixStream](https://docs.rs/smol/2.0.0/smol/net/unix/struct.UnixStream.htm
## Transfering socket pair ownership
Sending a descriptor doesn't close the local copy, which leads to having the socket being opened by the sender until it shuts down.
If you want socket pair receivers to detect peer shutdown, you have to close local sockets after sending them.
Use [close](https://docs.rs/nix/latest/nix/unistd/fn.close.html) Posix call for Tokio streams, or [UnixStream::shutdown()](https://docs.rs/smol/2.0.0/smol/net/unix/struct.UnixStream.html#method.shutdown) for Smol.
Use [UnixStream::poll_shutdown()](https://docs.rs/tokio/latest/tokio/net/struct.UnixStream.html#method.poll_shutdown) for Tokio streams, or [UnixStream::shutdown()](https://docs.rs/smol/2.0.0/smol/net/unix/struct.UnixStream.html#method.shutdown) for Smol.

## Features
- `tokio` - for Tokio support
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//! opened by the sender until it shuts down.
//! If you want socket pair receivers to detect peer shutdown, you have to close local sockets after sending them.
//!
//! Use [close](https://docs.rs/nix/latest/nix/unistd/fn.close.html) Posix call for Tokio streams, or [UnixStream::shutdown()](smol::net::unix::UnixStream::shutdown) for Smol.
//! Use [UnixStream::poll_shutdown()](https://docs.rs/tokio/latest/tokio/net/struct.UnixStream.html#method.poll_shutdown) for Tokio streams, or [UnixStream::shutdown()](smol::net::unix::UnixStream::shutdown) for Smol.
//!
//! ## Features
//! - `tokio` - for Tokio support
Expand Down
56 changes: 55 additions & 1 deletion src/tokio_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ use std::{
},
};

use tokio::{io::Interest, net::UnixStream};
use tokio::{
io::Interest,
net::{
unix::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf},
UnixStream,
},
};

use passfd::FdPassingExt;

Expand Down Expand Up @@ -72,3 +78,51 @@ impl AsyncRecvTokioStream for UnixStream {
UnixStream::from_std(os_stream)
}
}

impl AsyncRecvFd for ReadHalf<'_> {
async fn recv_fd(&self) -> Result<RawFd, Error> {
self.as_ref().recv_fd().await
}
}

impl AsyncRecvTokioStream for ReadHalf<'_> {
async fn recv_stream(&self) -> Result<UnixStream, Error> {
self.as_ref().recv_stream().await
}
}

impl AsyncSendFd for WriteHalf<'_> {
async fn send_fd(&self, fd: RawFd) -> Result<(), Error> {
self.as_ref().send_fd(fd).await
}
}

impl AsyncSendTokioStream for WriteHalf<'_> {
async fn send_stream(&self, stream: UnixStream) -> Result<(), Error> {
self.as_ref().send_stream(stream).await
}
}

impl AsyncRecvFd for OwnedReadHalf {
async fn recv_fd(&self) -> Result<RawFd, Error> {
self.as_ref().recv_fd().await
}
}

impl AsyncRecvTokioStream for OwnedReadHalf {
async fn recv_stream(&self) -> Result<UnixStream, Error> {
self.as_ref().recv_stream().await
}
}

impl AsyncSendFd for OwnedWriteHalf {
async fn send_fd(&self, fd: RawFd) -> Result<(), Error> {
self.as_ref().send_fd(fd).await
}
}

impl AsyncSendTokioStream for OwnedWriteHalf {
async fn send_stream(&self, stream: UnixStream) -> Result<(), Error> {
self.as_ref().send_stream(stream).await
}
}
58 changes: 58 additions & 0 deletions tests/test_tokio_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,61 @@ async fn send_tokio_stream_test() {

let _ = std::fs::remove_dir(sock_path);
}

#[tokio::test]
async fn tokio_halves_test() {
let tmp_dir = TempDir::new("tokio-send-fd").unwrap();

let sock_path = tmp_dir.path().join(SOCKET_NAME);
let sock_path1 = sock_path.clone();
let sock_path2 = sock_path.clone();

println!("Start listening at: {:?}", sock_path1);
let listener = UnixListener::bind(sock_path1).unwrap();

let j1 = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();

println!("Incoming peer connection");
let (left, right) = OsUnixStream::pair().unwrap();

let (_, sender) = stream.into_split();
println!("Sending peer fd");
sender.send_fd(left.as_raw_fd()).await.unwrap();
println!("Succesfullt sent peer fd");

right.set_nonblocking(true).unwrap();
let mut peer_stream = UnixStream::from_std(right).unwrap();
let mut buffer = [0u8; 4];

println!("Reading data from the peer");
assert!(peer_stream.read(&mut buffer).await.unwrap() == 4);

println!("Message sent through a socket: {:?}", buffer);
});

let j2 = tokio::spawn(async move {
println!("Connection to the sender");
let stream = UnixStream::connect(sock_path2).await.unwrap();

let (receiver, _) = stream.into_split();
println!("Succesfully connected to the sender. Reading file descriptor");
let fd = receiver.recv_fd().await.unwrap();
println!("Succesfully read file descriptor");

let os_stream = unsafe { OsUnixStream::from_raw_fd(fd) };
// XXX: Don't forget to make this non-blocking. This gonna save you several days of debugging
os_stream.set_nonblocking(true).unwrap();

let mut peer_stream = UnixStream::from_std(os_stream).unwrap();

println!("Sending data to the peer");
let buffer: [u8; 4] = [0, 0, 0, 42];
peer_stream.write(&buffer).await.unwrap();
println!("Succesfully sent data to the peer");
});

tokio::try_join!(j1, j2).unwrap();

let _ = std::fs::remove_dir(sock_path);
}

0 comments on commit 4d15a43

Please sign in to comment.