Skip to content

Commit

Permalink
Remove Tun and Sock traits
Browse files Browse the repository at this point in the history
With due respect to 2d86264 where these traits were added, they
aren't used and the code is easier to read without them. If these
become needed, it should be straightforward to add them back.
  • Loading branch information
agrover committed Jul 21, 2021
1 parent e78a23a commit 32bde17
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 126 deletions.
17 changes: 7 additions & 10 deletions src/device/api.rs
Expand Up @@ -4,7 +4,7 @@
use super::dev_lock::LockReadGuard;
use super::drop_privileges::get_saved_ids;
use super::{make_array, AllowedIP, Device, Error, SocketAddr, X25519PublicKey, X25519SecretKey};
use crate::device::{Action, Sock, Tun};
use crate::device::Action;
use hex::encode as encode_hex;
use libc::*;
use std::fs::{create_dir, remove_file};
Expand Down Expand Up @@ -32,7 +32,7 @@ fn create_sock_dir() {
}
}

impl<T: Tun, S: Sock> Device<T, S> {
impl Device {
/// Register the api handler for this Device. The api handler receives stream connections on a Unix socket
/// with a known path: /var/run/wireguard/{tun_name}.sock.
pub fn register_api_handler(&mut self) -> Result<(), Error> {
Expand Down Expand Up @@ -118,7 +118,7 @@ impl<T: Tun, S: Sock> Device<T, S> {
}

#[allow(unused_must_use)]
fn api_get<T: Tun, S: Sock>(writer: &mut BufWriter<&UnixStream>, d: &Device<T, S>) -> i32 {
fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 {
// get command requires an empty line, but there is no reason to be religious about it
if let Some(ref k) = d.key_pair {
writeln!(writer, "private_key={}", encode_hex(k.0.as_bytes()));
Expand Down Expand Up @@ -164,10 +164,7 @@ fn api_get<T: Tun, S: Sock>(writer: &mut BufWriter<&UnixStream>, d: &Device<T, S
0
}

fn api_set<T: Tun, S: Sock>(
reader: &mut BufReader<&UnixStream>,
d: &mut LockReadGuard<Device<T, S>>,
) -> i32 {
fn api_set(reader: &mut BufReader<&UnixStream>, d: &mut LockReadGuard<Device>) -> i32 {
d.try_writeable(
|device| device.trigger_yield(),
|device| {
Expand Down Expand Up @@ -229,9 +226,9 @@ fn api_set<T: Tun, S: Sock>(
.unwrap_or(EIO)
}

fn api_set_peer<T: Tun, S: Sock>(
fn api_set_peer(
reader: &mut BufReader<&UnixStream>,
d: &mut Device<T, S>,
d: &mut Device,
pub_key: X25519PublicKey,
) -> i32 {
let mut cmd = String::new();
Expand Down Expand Up @@ -304,7 +301,7 @@ fn api_set_peer<T: Tun, S: Sock>(
preshared_key,
);
match val.parse::<X25519PublicKey>() {
Ok(key) => return api_set_peer::<T, S>(reader, d, key),
Ok(key) => return api_set_peer(reader, d, key),
Err(_) => return EINVAL,
}
}
Expand Down
107 changes: 41 additions & 66 deletions src/device/mod.rs
Expand Up @@ -86,45 +86,10 @@ enum Action {
}

// Event handler function
type Handler<T, S> =
Box<dyn Fn(&mut LockReadGuard<Device<T, S>>, &mut ThreadData<T>) -> Action + Send + Sync>;
type Handler = Box<dyn Fn(&mut LockReadGuard<Device>, &mut ThreadData) -> Action + Send + Sync>;

// The trait satisfied by tunnel device implementations.
pub trait Tun: 'static + AsRawFd + Sized + Send + Sync {
fn new(name: &str) -> Result<Self, Error>;
fn set_non_blocking(self) -> Result<Self, Error>;

fn name(&self) -> Result<String, Error>;
fn mtu(&self) -> Result<usize, Error>;

fn write4(&self, src: &[u8]) -> usize;
fn write6(&self, src: &[u8]) -> usize;
fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error>;
}

// The trait satisfied by UDP socket implementations.
pub trait Sock: 'static + AsRawFd + Sized + Send + Sync {
fn new() -> Result<Self, Error>;
fn new6() -> Result<Self, Error>;

fn bind(self, port: u16) -> Result<Self, Error>;
fn connect(self, dst: &SocketAddr) -> Result<Self, Error>;

fn set_non_blocking(self) -> Result<Self, Error>;
fn set_reuse(self) -> Result<Self, Error>;
fn set_fwmark(&self, mark: u32) -> Result<(), Error>;

fn port(&self) -> Result<u16, Error>;
fn sendto(&self, buf: &[u8], dst: SocketAddr) -> usize;
fn recvfrom<'a>(&self, buf: &'a mut [u8]) -> Result<(SocketAddr, &'a mut [u8]), Error>;
fn write(&self, buf: &[u8]) -> usize;
fn read<'a>(&self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Error>;

fn shutdown(&self);
}

pub struct DeviceHandle<T: Tun = TunSocket, S: Sock = UDPSocket> {
device: Arc<Lock<Device<T, S>>>, // The interface this handle owns
pub struct DeviceHandle {
device: Arc<Lock<Device>>, // The interface this handle owns
threads: Vec<JoinHandle<()>>,
}

Expand All @@ -148,23 +113,23 @@ impl Default for DeviceConfig {
}
}

pub struct Device<T: Tun, S: Sock> {
pub struct Device {
key_pair: Option<(Arc<X25519SecretKey>, Arc<X25519PublicKey>)>,
queue: Arc<EventPoll<Handler<T, S>>>,
queue: Arc<EventPoll<Handler>>,

listen_port: u16,
fwmark: Option<u32>,

iface: Arc<T>,
udp4: Option<Arc<S>>,
udp6: Option<Arc<S>>,
iface: Arc<TunSocket>,
udp4: Option<Arc<UDPSocket>>,
udp6: Option<Arc<UDPSocket>>,

yield_notice: Option<EventRef>,
exit_notice: Option<EventRef>,

peers: HashMap<Arc<X25519PublicKey>, Arc<Peer<S>>>,
peers_by_ip: AllowedIps<Arc<Peer<S>>>,
peers_by_idx: HashMap<u32, Arc<Peer<S>>>,
peers: HashMap<Arc<X25519PublicKey>, Arc<Peer>>,
peers_by_ip: AllowedIps<Arc<Peer>>,
peers_by_idx: HashMap<u32, Arc<Peer>>,
next_index: u32,

config: DeviceConfig,
Expand All @@ -176,16 +141,16 @@ pub struct Device<T: Tun, S: Sock> {
rate_limiter: Option<Arc<RateLimiter>>,
}

struct ThreadData<T: Tun> {
iface: Arc<T>,
struct ThreadData {
iface: Arc<TunSocket>,
src_buf: [u8; MAX_UDP_SIZE],
dst_buf: [u8; MAX_UDP_SIZE],
}

impl<T: Tun, S: Sock> DeviceHandle<T, S> {
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle<T, S>, Error> {
impl DeviceHandle {
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle, Error> {
let n_threads = config.n_threads;
let mut wg_interface = Device::<T, S>::new(name, config)?;
let mut wg_interface = Device::new(name, config)?;
wg_interface.open_listen_socket(0)?; // Start listening on a random port

let interface_lock = Arc::new(Lock::new(wg_interface));
Expand Down Expand Up @@ -218,7 +183,7 @@ impl<T: Tun, S: Sock> DeviceHandle<T, S> {
}
}

fn event_loop(_i: usize, device: &Lock<Device<T, S>>) {
fn event_loop(_i: usize, device: &Lock<Device>) {
#[cfg(target_os = "linux")]
let mut thread_local = ThreadData {
src_buf: [0u8; MAX_UDP_SIZE],
Expand All @@ -229,7 +194,7 @@ impl<T: Tun, S: Sock> DeviceHandle<T, S> {
} else {
// For for the rest create a new iface queue
let iface_local = Arc::new(
T::new(&device.read().iface.name().unwrap())
TunSocket::new(&device.read().iface.name().unwrap())
.unwrap()
.set_non_blocking()
.unwrap(),
Expand Down Expand Up @@ -279,14 +244,14 @@ impl<T: Tun, S: Sock> DeviceHandle<T, S> {
}
}

impl<T: Tun, S: Sock> Drop for DeviceHandle<T, S> {
impl Drop for DeviceHandle {
fn drop(&mut self) {
self.device.read().trigger_exit();
self.clean();
}
}

impl<T: Tun, S: Sock> Device<T, S> {
impl Device {
fn next_index(&mut self) -> u32 {
let next_index = self.next_index;
self.next_index += 1;
Expand All @@ -300,7 +265,7 @@ impl<T: Tun, S: Sock> Device<T, S> {
peer.shutdown_endpoint(); // close open udp socket and free the closure
self.peers_by_idx.remove(&peer.index()); // peers_by_idx
self.peers_by_ip
.remove(&|p: &Arc<Peer<S>>| Arc::ptr_eq(&peer, p)); // peers_by_ip
.remove(&|p: &Arc<Peer>| Arc::ptr_eq(&peer, p)); // peers_by_ip

info!(peer.tunnel.logger, "Peer removed");
}
Expand Down Expand Up @@ -366,11 +331,11 @@ impl<T: Tun, S: Sock> Device<T, S> {
info!(peer.tunnel.logger, "Peer added");
}

pub fn new(name: &str, config: DeviceConfig) -> Result<Device<T, S>, Error> {
let poll = EventPoll::<Handler<T, S>>::new()?;
pub fn new(name: &str, config: DeviceConfig) -> Result<Device, Error> {
let poll = EventPoll::<Handler>::new()?;

// Create a tunnel device
let iface = Arc::new(T::new(name)?.set_non_blocking()?);
let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?);
let mtu = iface.mtu()?;

let mut device = Device {
Expand Down Expand Up @@ -431,14 +396,24 @@ impl<T: Tun, S: Sock> Device<T, S> {
}

// Then open new sockets and bind to the port
let udp_sock4 = Arc::new(S::new()?.set_non_blocking()?.set_reuse()?.bind(port)?);
let udp_sock4 = Arc::new(
UDPSocket::new()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?,
);

if port == 0 {
// Random port was assigned
port = udp_sock4.port()?;
}

let udp_sock6 = Arc::new(S::new6()?.set_non_blocking()?.set_reuse()?.bind(port)?);
let udp_sock6 = Arc::new(
UDPSocket::new6()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?,
);

self.register_udp_handler(Arc::clone(&udp_sock4))?;
self.register_udp_handler(Arc::clone(&udp_sock6))?;
Expand All @@ -460,7 +435,7 @@ impl<T: Tun, S: Sock> Device<T, S> {

for peer in self.peers.values_mut() {
// Taking a pointer should be Ok as long as all other threads are stopped
let mut_ptr = Arc::into_raw(Arc::clone(peer)) as *mut Peer<S>;
let mut_ptr = Arc::into_raw(Arc::clone(peer)) as *mut Peer;

if unsafe {
mut_ptr.as_mut().unwrap().tunnel.set_static_private(
Expand Down Expand Up @@ -595,7 +570,7 @@ impl<T: Tun, S: Sock> Device<T, S> {
.stop_notification(self.yield_notice.as_ref().unwrap())
}

fn register_udp_handler(&self, udp: Arc<S>) -> Result<(), Error> {
fn register_udp_handler(&self, udp: Arc<UDPSocket>) -> Result<(), Error> {
self.queue.new_event(
udp.as_raw_fd(),
Box::new(move |d, t| {
Expand Down Expand Up @@ -693,8 +668,8 @@ impl<T: Tun, S: Sock> Device<T, S> {

fn register_conn_handler(
&self,
peer: Arc<Peer<S>>,
udp: Arc<S>,
peer: Arc<Peer>,
udp: Arc<UDPSocket>,
peer_addr: IpAddr,
) -> Result<(), Error> {
self.queue.new_event(
Expand Down Expand Up @@ -750,7 +725,7 @@ impl<T: Tun, S: Sock> Device<T, S> {
Ok(())
}

fn register_iface_handler(&self, iface: Arc<T>) -> Result<(), Error> {
fn register_iface_handler(&self, iface: Arc<TunSocket>) -> Result<(), Error> {
self.queue.new_event(
iface.as_raw_fd(),
Box::new(move |d, t| {
Expand Down
28 changes: 17 additions & 11 deletions src/device/peer.rs
Expand Up @@ -9,19 +9,21 @@ use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;

use crate::device::{AllowedIps, Error, Sock};
use crate::device::{AllowedIps, Error};
use crate::noise::{Tunn, TunnResult};

use crate::device::udp::UDPSocket;

#[derive(Default, Debug)]
pub struct Endpoint<S: Sock> {
pub struct Endpoint {
pub addr: Option<SocketAddr>,
pub conn: Option<Arc<S>>,
pub conn: Option<Arc<UDPSocket>>,
}

pub struct Peer<S: Sock> {
pub struct Peer {
pub(crate) tunnel: Box<Tunn>, // The associated tunnel struct
index: u32, // The index the tunnel uses
endpoint: RwLock<Endpoint<S>>,
endpoint: RwLock<Endpoint>,
allowed_ips: AllowedIps<()>,
preshared_key: Option<[u8; 32]>,
}
Expand Down Expand Up @@ -50,14 +52,14 @@ impl FromStr for AllowedIP {
}
}

impl<S: Sock> Peer<S> {
impl Peer {
pub fn new(
tunnel: Box<Tunn>,
index: u32,
endpoint: Option<SocketAddr>,
allowed_ips: &[AllowedIP],
preshared_key: Option<[u8; 32]>,
) -> Peer<S> {
) -> Peer {
Peer {
tunnel,
index,
Expand All @@ -74,7 +76,7 @@ impl<S: Sock> Peer<S> {
self.tunnel.update_timers(dst)
}

pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint<S>> {
pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> {
self.endpoint.read()
}

Expand All @@ -100,20 +102,24 @@ impl<S: Sock> Peer<S> {
};
}

pub fn connect_endpoint(&self, port: u16, fwmark: Option<u32>) -> Result<Arc<S>, Error> {
pub fn connect_endpoint(
&self,
port: u16,
fwmark: Option<u32>,
) -> Result<Arc<UDPSocket>, Error> {
let mut endpoint = self.endpoint.write();

if endpoint.conn.is_some() {
return Err(Error::Connect("Connected".to_owned()));
}

let udp_conn = Arc::new(match endpoint.addr {
Some(addr @ SocketAddr::V4(_)) => S::new()?
Some(addr @ SocketAddr::V4(_)) => UDPSocket::new()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?
.connect(&addr)?,
Some(addr @ SocketAddr::V6(_)) => S::new6()?
Some(addr @ SocketAddr::V6(_)) => UDPSocket::new6()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?
Expand Down

0 comments on commit 32bde17

Please sign in to comment.