Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 13 additions & 20 deletions client/lib/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ pub async fn tun_read_handle(peers: &Arc<RwLock<Peers>>, udp4: &UdpSocket, udp6:
} else {
tracing::error!("No endpoint");
}
//TODO: get tcp socket from peers and send
}
_ => panic!("Unexpected result from encapsulate"),
};
Expand Down Expand Up @@ -342,6 +341,7 @@ pub async fn tcp_peers_timer(
) {
let mut interval = time::interval(Duration::from_millis(250));
let mut dst_buf: Vec<u8>= vec![0; MAX_UDP_SIZE];
let error_recovery_duration = Duration::from_secs(10);

loop {
interval.tick().await;
Expand All @@ -353,7 +353,10 @@ pub async fn tcp_peers_timer(
None => continue,
};
match &mut p.endpoint.tcp_conn {
TcpConnection::Nothing | TcpConnection::ConnectedFailure(_) => {
TcpConnection::ConnectedFailure(_, time) if time.elapsed().map(|x| x < error_recovery_duration).unwrap_or(false) => {
continue;
}
TcpConnection::Nothing | TcpConnection::ConnectedFailure(..) => {
if node_type == NodeType::NodeClient || ip < &p.ip {
p.endpoint.tcp_conn = TcpConnection::Connecting(SystemTime::now());
match TcpStream::connect(&endpoint_addr).await {
Expand All @@ -364,7 +367,7 @@ pub async fn tcp_peers_timer(
},
Err(error) => {
tracing::debug!("connect {endpoint_addr:?} failure, error: {error:?}");
p.endpoint.tcp_conn = TcpConnection::ConnectedFailure(error)
p.endpoint.tcp_conn = TcpConnection::ConnectedFailure(error, SystemTime::now());
}
};
}
Expand All @@ -384,10 +387,7 @@ pub async fn tcp_peers_timer(
}
TunnResult::Err(e) => tracing::error!(message = "Timer error", error = ?e),
TunnResult::WriteToNetwork(packet) => {
if let TcpConnection::Connected(connection) = &mut p.endpoint.tcp_conn {
let _ = connection.write_all(packet).await;
}

p.endpoint.tcp_write(packet).await;
}
_ => tracing::warn!("Unexpected result from update_timers"),
};
Expand Down Expand Up @@ -571,11 +571,7 @@ pub fn tcp_handler(
},
WriterState::PeerWriter(peer)=> {
let mut p = peer.lock().await;
if let TcpConnection::Connected(w) = &mut p.endpoint.tcp_conn {
let _ = w.write_all(cookie).await;
}else {
tracing::warn!("should not come here");
}
p.endpoint.tcp_write(cookie).await;
}
}
continue;
Expand All @@ -601,7 +597,7 @@ pub fn tcp_handler(
};

let mut p = peer.lock().await;
if let TcpConnection::Nothing | TcpConnection::ConnectedFailure(_) = p.endpoint.tcp_conn {
if let TcpConnection::Nothing | TcpConnection::ConnectedFailure(..) = p.endpoint.tcp_conn {
if let WriterState::PureWriter(_) = &mut writer {
let pure_writer = mem::replace(&mut writer,WriterState::PeerWriter(peer.clone()));
if let WriterState::PureWriter(_writer) = pure_writer {
Expand All @@ -619,10 +615,7 @@ pub fn tcp_handler(
TunnResult::Err(_) => continue,
TunnResult::WriteToNetwork(packet) => {
flush = true;

if let TcpConnection::Connected(conn) = &mut p.endpoint.tcp_conn {
let _ = conn.write_all(packet).await;
}
p.endpoint.tcp_write(packet).await;
}
TunnResult::WriteToTunnelV4(packet, addr) => {
// tracing::debug!("{addr:?}");
Expand Down Expand Up @@ -680,15 +673,15 @@ pub fn tcp_handler(
while let TunnResult::WriteToNetwork(packet) =
p.tunnel.decapsulate(None, &[], &mut dst_buf[..])
{
if let TcpConnection::Connected(conn) = &mut p.endpoint.tcp_conn {
let _ = conn.write_all(packet).await;
}
p.endpoint.tcp_write(packet).await;
}
}
}
}
tracing::info!("tcp: {addr:?} close");
});


}


Expand Down
22 changes: 19 additions & 3 deletions client/lib/src/device/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::sync::Arc;
use std::time::SystemTime;

use boringtun::noise::{Tunn, TunnResult};
use tokio::io::AsyncWriteExt;
use tokio::net::{UdpSocket};
use tokio::net::tcp::OwnedWriteHalf;
use crate::device::allowed_ips::AllowedIps;
Expand All @@ -20,7 +21,7 @@ pub enum TcpConnection {
Nothing,
Connecting(SystemTime),
Connected(OwnedWriteHalf),
ConnectedFailure(std::io::Error)
ConnectedFailure(std::io::Error, SystemTime),
}
#[derive(Debug)]
pub struct Endpoint {
Expand All @@ -29,6 +30,22 @@ pub struct Endpoint {
pub tcp_conn: TcpConnection,
}

impl Endpoint {
pub async fn tcp_write(&mut self, bytes:&[u8]) {
if let TcpConnection::Connected(conn) = &mut self.tcp_conn {
match conn.write_all(bytes).await {
Ok(_) => {
// do nothing
},
Err(e) => {
tracing::error!("tcp conn of {:?} fail, error: {}", conn.peer_addr(), e);
self.tcp_conn = TcpConnection::ConnectedFailure(e, SystemTime::now());
}
};
}
}
}

pub struct Peer {
/// The associated tunnel struct
pub(crate) tunnel: Tunn,
Expand Down Expand Up @@ -104,8 +121,7 @@ impl Peer {
pub fn shutdown_endpoint(&mut self) {
if let Some(_) = &mut self.endpoint.udp_conn.take() {
tracing::info!("disconnecting from endpoint");
}
if let TcpConnection::Connected(_) = &mut self.endpoint.tcp_conn {
} else if let TcpConnection::Connected(_) = &mut self.endpoint.tcp_conn {
tracing::info!("disconnecting tcp connection");
}
self.endpoint.tcp_conn = TcpConnection::Nothing;
Expand Down