Skip to content

Commit

Permalink
examples: update uds example to use tokio UnixListenerStream
Browse files Browse the repository at this point in the history
tokio-stream packages a UnixListenerStream that implements
futures_core::Stream. Using this cuts down on consumer boilerplate
when using UnixStreams with a tonic server.

Refs: hyperium#856

Signed-off-by: Anthony Green <agreen@starry.com>
  • Loading branch information
agreen17 committed Dec 8, 2021
1 parent 5bd23d6 commit 0746070
Showing 1 changed file with 15 additions and 78 deletions.
93 changes: 15 additions & 78 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
#![cfg_attr(not(unix), allow(unused_imports))]

use futures::TryFutureExt;
use std::path::Path;
#[cfg(unix)]
use tokio::net::UnixListener;
use tonic::{transport::Server, Request, Response, Status};
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{
transport::{server::UdsConnectInfo, Server},
Request, Response, Status,
};

pub mod hello_world {
tonic::include_proto!("helloworld");
Expand All @@ -26,8 +30,13 @@ impl Greeter for MyGreeter {
) -> Result<Response<HelloReply>, Status> {
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
let conn_info = request.extensions().get::<UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);

// Client-side unix sockets are unnamed.
assert!(conn_info.peer_addr.as_ref().unwrap().is_unnamed());
// This should contain process credentials for the client socket.
assert!(conn_info.peer_cred.as_ref().is_some());
}

let reply = hello_world::HelloReply {
Expand All @@ -46,89 +55,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let greeter = MyGreeter::default();

let incoming = {
let uds = UnixListener::bind(path)?;

async_stream::stream! {
loop {
let item = uds.accept().map_ok(|(st, _)| unix::UnixStream(st)).await;

yield item;
}
}
};
let uds = UnixListener::bind(path)?;
let uds_stream = UnixListenerStream::new(uds);

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(incoming)
.serve_with_incoming(uds_stream)
.await?;

Ok(())
}

#[cfg(unix)]
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tonic::transport::server::Connected;

#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl AsyncWrite for UnixStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
}

#[cfg(not(unix))]
fn main() {
panic!("The `uds` example only works on unix systems!");
Expand Down

0 comments on commit 0746070

Please sign in to comment.