Skip to content

Commit

Permalink
quic: add QuicConnector to replace connect fn
Browse files Browse the repository at this point in the history
  • Loading branch information
yayanyang committed Mar 12, 2024
1 parent 78c0f18 commit 94e258b
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 88 deletions.
145 changes: 82 additions & 63 deletions crates/ext/src/net/quic/conn.rs
Expand Up @@ -146,7 +146,7 @@ impl QuicConnState {
/// Creates a new client-side connection.
///
/// `server_name` parameter is used to verify the peer's certificate.
pub fn connect(
pub fn new_client(
server_name: Option<&str>,
laddr: SocketAddr,
raddr: SocketAddr,
Expand Down Expand Up @@ -726,6 +726,12 @@ impl QuicConnState {
raw.quiche_conn.is_established()
}

/// Returns true if the connection is closed.
pub async fn is_closed(&self) -> bool {
let raw = self.raw.lock().await;
raw.quiche_conn.is_closed()
}

/// Get reference of inner [`quiche::Connection`] type.
pub async fn to_inner_conn(&self) -> impl ops::Deref<Target = quiche::Connection> + '_ {
self.raw.lock().await.deref_map(|state| &state.quiche_conn)
Expand Down Expand Up @@ -778,86 +784,62 @@ impl Drop for QuicConnFinalizer {
}
}

/// A Quic connection between a local and a remote socket.
///
/// A `QuicConn` can either be created by connecting to an endpoint, via the [`connect`](Self::connect) method,
/// or by [accepting] a connection from a [`listener`](super::QuicListener).
///
/// You can either open a stream via the [`open_stream`](Self::open_stream) function,
/// or accept a inbound stream via the [`stream_accept`](Self::stream_accept) function
pub struct QuicConn {
inner: Arc<QuicConnFinalizer>,
}

impl Display for QuicConn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.inner.0)
}
/// A builder for client side [`QuicConn`].
pub struct QuicConnector {
udp_group: UdpGroup,
conn_state: QuicConnState,
max_send_udp_payload_size: usize,
}

impl QuicConn {
/// Returns the source connection ID.
impl QuicConnector {
/// Create new `QuicConnector` instance with global [`syscall`](rasi::syscall::Network),
/// to create a new Quic connection connected to the specified addresses.
///
/// When there are multiple IDs, and if there is an active path, the ID used
/// on that path is returned. Otherwise the oldest ID is returned.
///
/// Note that the value returned can change throughout the connection's
/// lifetime.
pub fn source_id(&self) -> &ConnectionId<'static> {
&self.inner.0.scid
}

/// Create `QuicConn` instance from [`state`](QuicConnState).
pub(super) fn new(state: QuicConnState) -> QuicConn {
Self {
inner: Arc::new(QuicConnFinalizer(state)),
}
}

// Using global [`syscall`](rasi_syscall::Network) interface to create a new Quic
/// connection connected to the specified addresses.
///
/// see [`connect_with`](Self::connect_with) for more informations.
pub async fn connect<L: ToSocketAddrs, R: ToSocketAddrs>(
/// see [`new_with`](Self::new_with) for more informations.
pub async fn new<L: ToSocketAddrs, R: ToSocketAddrs>(
server_name: Option<&str>,
laddrs: L,
raddrs: R,
config: &mut Config,
) -> io::Result<Self> {
Self::connect_with(server_name, laddrs, raddrs, config, global_network()).await
Self::new_with(server_name, laddrs, raddrs, config, global_network()).await
}

/// Using custom [`syscall`](rasi_syscall::Network) interface to create a new Quic
/// connection connected to the specified addresses.
///
/// This method will create a new Quic socket and attempt to connect it to the `raddrs`
/// provided. The [returned future] will be resolved once the connection has successfully
/// connected, or it will return an error if one occurs.
pub async fn connect_with<L: ToSocketAddrs, R: ToSocketAddrs>(
/// Create new `QuicConnector` instance with custom [`syscall`](rasi::syscall::Network),
/// to create a new Quic connection connected to the specified addresses.
pub async fn new_with<L: ToSocketAddrs, R: ToSocketAddrs>(
server_name: Option<&str>,
laddrs: L,
raddrs: R,
config: &mut Config,
syscall: &'static dyn rasi::syscall::Network,
) -> io::Result<Self> {
let socket = UdpGroup::bind_with(laddrs, syscall).await?;
let udp_group = UdpGroup::bind_with(laddrs, syscall).await?;

let raddr = raddrs.to_socket_addrs()?.choose(&mut thread_rng()).unwrap();

let laddr = socket
let laddr = udp_group
.local_addrs()
.filter(|addr| raddr.is_ipv4() == addr.is_ipv4())
.choose(&mut thread_rng())
.unwrap();

let mut conn_state = QuicConnState::connect(server_name, *laddr, raddr, config)?;
let conn_state = QuicConnState::new_client(server_name, *laddr, raddr, config)?;

let (sender, mut receiver) = socket.split();
Ok(Self {
conn_state,
udp_group,
max_send_udp_payload_size: config.max_send_udp_payload_size,
})
}

/// Performs a real connection process.
pub async fn connect(mut self) -> io::Result<QuicConn> {
let (sender, mut receiver) = self.udp_group.split();

loop {
let mut read_buf = ReadBuf::with_capacity(config.max_send_udp_payload_size);
let mut read_buf = ReadBuf::with_capacity(self.max_send_udp_payload_size);

let (read_size, send_info) = conn_state.send(read_buf.chunk_mut()).await?;
let (read_size, send_info) = self.conn_state.send(read_buf.chunk_mut()).await?;

let send_size = sender
.send_to_on_path(
Expand All @@ -872,7 +854,7 @@ impl QuicConn {
log::trace!("Quic connection, {:?}, send data {}", send_info, send_size);

let (mut buf, path_info) =
if let Some(timeout_at) = conn_state.to_inner_conn().await.timeout_instant() {
if let Some(timeout_at) = self.conn_state.to_inner_conn().await.timeout_instant() {
match receiver.try_next().timeout_at(timeout_at).await {
Some(Ok(r)) => r.ok_or(io::Error::new(
io::ErrorKind::BrokenPipe,
Expand All @@ -894,7 +876,7 @@ impl QuicConn {

log::trace!("Quic connection, {:?}, recv data {}", path_info, buf.len());

conn_state
self.conn_state
.recv(
&mut buf,
RecvInfo {
Expand All @@ -904,20 +886,57 @@ impl QuicConn {
)
.await?;

if conn_state.is_established().await {
conn_state.update_dcid().await;
if self.conn_state.is_established().await {
self.conn_state.update_dcid().await;
break;
}
}

spawn(Self::recv_loop(conn_state.clone(), receiver));
spawn(Self::send_loop(
conn_state.clone(),
spawn(QuicConn::recv_loop(self.conn_state.clone(), receiver));
spawn(QuicConn::send_loop(
self.conn_state.clone(),
sender,
config.max_send_udp_payload_size,
self.max_send_udp_payload_size,
));

Ok(Self::new(conn_state))
Ok(QuicConn::new(self.conn_state))
}
}

/// A Quic connection between a local and a remote socket.
///
/// A `QuicConn` can either be created by connecting to an endpoint, via the [`connect`](Self::connect) method,
/// or by [accepting] a connection from a [`listener`](super::QuicListener).
///
/// You can either open a stream via the [`open_stream`](Self::open_stream) function,
/// or accept a inbound stream via the [`stream_accept`](Self::stream_accept) function
pub struct QuicConn {
inner: Arc<QuicConnFinalizer>,
}

impl Display for QuicConn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.inner.0)
}
}

impl QuicConn {
/// Returns the source connection ID.
///
/// When there are multiple IDs, and if there is an active path, the ID used
/// on that path is returned. Otherwise the oldest ID is returned.
///
/// Note that the value returned can change throughout the connection's
/// lifetime.
pub fn source_id(&self) -> &ConnectionId<'static> {
&self.inner.0.scid
}

/// Create `QuicConn` instance from [`state`](QuicConnState).
pub(super) fn new(state: QuicConnState) -> QuicConn {
Self {
inner: Arc::new(QuicConnFinalizer(state)),
}
}

/// Accepts a new incoming stream via this connection.
Expand Down
23 changes: 22 additions & 1 deletion crates/ext/src/net/quic/listener.rs
Expand Up @@ -107,6 +107,13 @@ impl RawQuicListenerState {
}
}

/// remove connection from pool.
fn remove_conn<'a>(&mut self, id: &ConnectionId<'a>) {
let id = id.clone().into_owned();
self.handshaking_conns.remove(&id);
self.established_conns.remove(&id);
}

/// Process Initial packet.
fn handshake<'a>(
&mut self,
Expand Down Expand Up @@ -375,7 +382,21 @@ impl QuicListenerState {
// release the lock before call [QuicConnState::recv] function.
drop(raw);

let recv_size = conn.recv(buf, recv_info).await?;
let recv_size = match conn.recv(buf, recv_info).await {
Ok(recv_size) => recv_size,
Err(err) => {
if conn.is_closed().await {
// relock the state.
raw = self.raw.lock().await;

raw.remove_conn(&header.dcid);

log::info!("{}, removed from server pool.", conn);
}

return Err(err);
}
};

if !is_established && conn.is_established().await {
// relock the state.
Expand Down
55 changes: 38 additions & 17 deletions crates/ext/src/net/quic/pool.rs
Expand Up @@ -12,7 +12,12 @@ use rasi::syscall::{global_network, Network};

use crate::utils::AsyncSpinMutex;

use super::{Config, QuicConn, QuicStream};
use super::{Config, QuicConn, QuicConnector, QuicStream};

enum OpenStream {
Stream(QuicStream),
Connector(QuicConnector),
}

struct RawQuicConnPool {
config: Config,
Expand All @@ -26,7 +31,7 @@ impl RawQuicConnPool {
max_conns: usize,
raddrs: &[SocketAddr],
syscall: &'static dyn Network,
) -> io::Result<QuicStream> {
) -> io::Result<OpenStream> {
let mut conns = self.conns.values().collect::<Vec<_>>();

conns.shuffle(&mut thread_rng());
Expand All @@ -36,7 +41,7 @@ impl RawQuicConnPool {
for conn in conns {
match conn.stream_open(true).await {
Ok(stream) => {
return Ok(stream);
return Ok(OpenStream::Stream(stream));
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
continue;
Expand All @@ -63,7 +68,7 @@ impl RawQuicConnPool {
));
}

let conn = QuicConn::connect_with(
let connector = QuicConnector::new_with(
server_name,
["[::]:0".parse().unwrap(), "0.0.0.0:0".parse().unwrap()].as_slice(),
raddrs,
Expand All @@ -72,12 +77,7 @@ impl RawQuicConnPool {
)
.await?;

let stream = conn.stream_open(true).await?;

// insert newly connection into pool.
self.conns.insert(conn.source_id().clone(), conn);

Ok(stream)
Ok(OpenStream::Connector(connector))
}
}

Expand Down Expand Up @@ -146,16 +146,37 @@ impl QuicConnPool {
pub async fn stream_open(&self) -> io::Result<QuicStream> {
use crate::utils::AsyncLockable;

let connector = {
let mut inner = self.inner.lock().await;

match inner
.open_stream(
self.server_name.as_ref().map(String::as_str),
self.max_conns,
&self.raddrs,
self.syscall,
)
.await?
{
OpenStream::Stream(stream) => return Ok(stream),
OpenStream::Connector(connector) => connector,
}
};

// performs real connecting process.

let connection = connector.connect().await?;

let stream = connection.stream_open(true).await?;

// relock inner.
let mut inner = self.inner.lock().await;

inner
.open_stream(
self.server_name.as_ref().map(String::as_str),
self.max_conns,
&self.raddrs,
self.syscall,
)
.await
.conns
.insert(connection.source_id().clone(), connection);

Ok(stream)
}

/// Set `max_conns` parameter.
Expand Down

0 comments on commit 94e258b

Please sign in to comment.