Skip to content

Commit

Permalink
Stub out tunnel device with generics. (cloudflare#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
Brendan McMillion committed Nov 3, 2020
1 parent 4380f35 commit 2d86264
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 270 deletions.
17 changes: 10 additions & 7 deletions src/device/api.rs
Expand Up @@ -4,7 +4,7 @@
use super::dev_lock::LockReadGuard;
use super::drop_privileges::*;
use super::{make_array, AllowedIP, Device, Error, SocketAddr, X25519PublicKey, X25519SecretKey};
use crate::device::Action;
use crate::device::{Action, Sock, Tun};
use hex::encode as encode_hex;
use libc::*;
use std::fs::{create_dir, remove_file};
Expand Down Expand Up @@ -32,7 +32,7 @@ fn create_sock_dir() {
}
}

impl Device {
impl<T: Tun, S: Sock> Device<T, S> {
/// Register the api handler for this Device. The api handler receives stream connections on a Unix socket
/// with a known path: /var/run/wireguard/{tun_name}.sock.
pub fn register_api_handler(&mut self) -> Result<(), Error> {
Expand Down Expand Up @@ -118,7 +118,7 @@ impl Device {
}

#[allow(unused_must_use)]
fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 {
fn api_get<T: Tun, S: Sock>(writer: &mut BufWriter<&UnixStream>, d: &Device<T, S>) -> i32 {
// get command requires an empty line, but there is no reason to be religious about it
if let Some(ref k) = d.key_pair {
writeln!(writer, "private_key={}", encode_hex(k.0.as_bytes()));
Expand Down Expand Up @@ -164,7 +164,10 @@ fn api_get(writer: &mut BufWriter<&UnixStream>, d: &Device) -> i32 {
0
}

fn api_set<'a>(reader: &mut BufReader<&UnixStream>, d: &mut LockReadGuard<Device>) -> i32 {
fn api_set<'a, T: Tun, S: Sock>(
reader: &mut BufReader<&UnixStream>,
d: &mut LockReadGuard<Device<T, S>>,
) -> i32 {
d.try_writeable(
|device| device.trigger_yield(),
|device| {
Expand Down Expand Up @@ -226,9 +229,9 @@ fn api_set<'a>(reader: &mut BufReader<&UnixStream>, d: &mut LockReadGuard<Device
.unwrap_or(EIO)
}

fn api_set_peer(
fn api_set_peer<T: Tun, S: Sock>(
reader: &mut BufReader<&UnixStream>,
d: &mut Device,
d: &mut Device<T, S>,
pub_key: X25519PublicKey,
) -> i32 {
let mut cmd = String::new();
Expand Down Expand Up @@ -301,7 +304,7 @@ fn api_set_peer(
preshared_key,
);
match val.parse::<X25519PublicKey>() {
Ok(key) => return api_set_peer(reader, d, key),
Ok(key) => return api_set_peer::<T, S>(reader, d, key),
Err(_) => return EINVAL,
}
}
Expand Down
107 changes: 66 additions & 41 deletions src/device/mod.rs
Expand Up @@ -86,10 +86,45 @@ enum Action {
}

// Event handler function
type Handler = Box<dyn Fn(&mut LockReadGuard<Device>, &mut ThreadData) -> Action + Send + Sync>;
type Handler<T, S> =
Box<dyn Fn(&mut LockReadGuard<Device<T, S>>, &mut ThreadData<T>) -> Action + Send + Sync>;

pub struct DeviceHandle {
device: Arc<Lock<Device>>, // The interface this handle owns
// The trait satisfied by tunnel device implementations.
pub trait Tun: 'static + AsRawFd + Sized + Send + Sync {
fn new(name: &str) -> Result<Self, Error>;
fn set_non_blocking(self) -> Result<Self, Error>;

fn name(&self) -> Result<String, Error>;
fn mtu(&self) -> Result<usize, Error>;

fn write4(&self, src: &[u8]) -> usize;
fn write6(&self, src: &[u8]) -> usize;
fn read<'a>(&self, dst: &'a mut [u8]) -> Result<&'a mut [u8], Error>;
}

// The trait satisfied by UDP socket implementations.
pub trait Sock: 'static + AsRawFd + Sized + Send + Sync {
fn new() -> Result<Self, Error>;
fn new6() -> Result<Self, Error>;

fn bind(self, port: u16) -> Result<Self, Error>;
fn connect(self, dst: &SocketAddr) -> Result<Self, Error>;

fn set_non_blocking(self) -> Result<Self, Error>;
fn set_reuse(self) -> Result<Self, Error>;
fn set_fwmark(&self, mark: u32) -> Result<(), Error>;

fn port(&self) -> Result<u16, Error>;
fn sendto(&self, buf: &[u8], dst: SocketAddr) -> usize;
fn recvfrom<'a>(&self, buf: &'a mut [u8]) -> Result<(SocketAddr, &'a mut [u8]), Error>;
fn write(&self, buf: &[u8]) -> usize;
fn read<'a>(&self, buf: &'a mut [u8]) -> Result<&'a mut [u8], Error>;

fn shutdown(&self);
}

pub struct DeviceHandle<T: Tun = TunSocket, S: Sock = UDPSocket> {
device: Arc<Lock<Device<T, S>>>, // The interface this handle owns
threads: Vec<JoinHandle<()>>,
}

Expand All @@ -113,23 +148,23 @@ impl Default for DeviceConfig {
}
}

pub struct Device {
pub struct Device<T: Tun, S: Sock> {
key_pair: Option<(Arc<X25519SecretKey>, Arc<X25519PublicKey>)>,
queue: Arc<EventPoll<Handler>>,
queue: Arc<EventPoll<Handler<T, S>>>,

listen_port: u16,
fwmark: Option<u32>,

iface: Arc<TunSocket>,
udp4: Option<Arc<UDPSocket>>,
udp6: Option<Arc<UDPSocket>>,
iface: Arc<T>,
udp4: Option<Arc<S>>,
udp6: Option<Arc<S>>,

yield_notice: Option<EventRef>,
exit_notice: Option<EventRef>,

peers: HashMap<Arc<X25519PublicKey>, Arc<Peer>>,
peers_by_ip: AllowedIps<Arc<Peer>>,
peers_by_idx: HashMap<u32, Arc<Peer>>,
peers: HashMap<Arc<X25519PublicKey>, Arc<Peer<S>>>,
peers_by_ip: AllowedIps<Arc<Peer<S>>>,
peers_by_idx: HashMap<u32, Arc<Peer<S>>>,
next_index: u32,

config: DeviceConfig,
Expand All @@ -141,16 +176,16 @@ pub struct Device {
rate_limiter: Option<Arc<RateLimiter>>,
}

struct ThreadData {
iface: Arc<TunSocket>,
struct ThreadData<T: Tun> {
iface: Arc<T>,
src_buf: [u8; MAX_UDP_SIZE],
dst_buf: [u8; MAX_UDP_SIZE],
}

impl DeviceHandle {
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle, Error> {
impl<T: Tun, S: Sock> DeviceHandle<T, S> {
pub fn new(name: &str, config: DeviceConfig) -> Result<DeviceHandle<T, S>, Error> {
let n_threads = config.n_threads;
let mut wg_interface = Device::new(name, config)?;
let mut wg_interface = Device::<T, S>::new(name, config)?;
wg_interface.open_listen_socket(0)?; // Start listening on a random port

let interface_lock = Arc::new(Lock::new(wg_interface));
Expand Down Expand Up @@ -183,7 +218,7 @@ impl DeviceHandle {
}
}

fn event_loop(_i: usize, device: &Lock<Device>) {
fn event_loop(_i: usize, device: &Lock<Device<T, S>>) {
#[cfg(target_os = "linux")]
let mut thread_local = ThreadData {
src_buf: [0u8; MAX_UDP_SIZE],
Expand All @@ -194,7 +229,7 @@ impl DeviceHandle {
} else {
// For for the rest create a new iface queue
let iface_local = Arc::new(
TunSocket::new(&device.read().iface.name().unwrap())
T::new(&device.read().iface.name().unwrap())
.unwrap()
.set_non_blocking()
.unwrap(),
Expand Down Expand Up @@ -244,14 +279,14 @@ impl DeviceHandle {
}
}

impl Drop for DeviceHandle {
impl<T: Tun, S: Sock> Drop for DeviceHandle<T, S> {
fn drop(&mut self) {
self.device.read().trigger_exit();
self.clean();
}
}

impl Device {
impl<T: Tun, S: Sock> Device<T, S> {
fn next_index(&mut self) -> u32 {
let next_index = self.next_index;
self.next_index += 1;
Expand All @@ -265,7 +300,7 @@ impl Device {
peer.shutdown_endpoint(); // close open udp socket and free the closure
self.peers_by_idx.remove(&peer.index()); // peers_by_idx
self.peers_by_ip
.remove(&|p: &Arc<Peer>| Arc::ptr_eq(&peer, p)); // peers_by_ip
.remove(&|p: &Arc<Peer<S>>| Arc::ptr_eq(&peer, p)); // peers_by_ip

info!(peer.tunnel.logger, "Peer removed");
}
Expand Down Expand Up @@ -331,11 +366,11 @@ impl Device {
info!(peer.tunnel.logger, "Peer added");
}

pub fn new(name: &str, config: DeviceConfig) -> Result<Device, Error> {
let poll = EventPoll::<Handler>::new()?;
pub fn new(name: &str, config: DeviceConfig) -> Result<Device<T, S>, Error> {
let poll = EventPoll::<Handler<T, S>>::new()?;

// Create a tunnel device
let iface = Arc::new(TunSocket::new(name)?.set_non_blocking()?);
let iface = Arc::new(T::new(name)?.set_non_blocking()?);
let mtu = iface.mtu()?;

let mut device = Device {
Expand Down Expand Up @@ -396,24 +431,14 @@ impl Device {
}

// Then open new sockets and bind to the port
let udp_sock4 = Arc::new(
UDPSocket::new()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?,
);
let udp_sock4 = Arc::new(S::new()?.set_non_blocking()?.set_reuse()?.bind(port)?);

if port == 0 {
// Random port was assigned
port = udp_sock4.port()?;
}

let udp_sock6 = Arc::new(
UDPSocket::new6()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?,
);
let udp_sock6 = Arc::new(S::new6()?.set_non_blocking()?.set_reuse()?.bind(port)?);

self.register_udp_handler(Arc::clone(&udp_sock4))?;
self.register_udp_handler(Arc::clone(&udp_sock6))?;
Expand All @@ -435,7 +460,7 @@ impl Device {

for peer in self.peers.values_mut() {
// Taking a pointer should be Ok as long as all other threads are stopped
let mut_ptr = Arc::into_raw(Arc::clone(peer)) as *mut Peer;
let mut_ptr = Arc::into_raw(Arc::clone(peer)) as *mut Peer<S>;

if unsafe {
mut_ptr.as_mut().unwrap().tunnel.set_static_private(
Expand Down Expand Up @@ -570,7 +595,7 @@ impl Device {
.stop_notification(self.yield_notice.as_ref().unwrap())
}

fn register_udp_handler(&self, udp: Arc<UDPSocket>) -> Result<(), Error> {
fn register_udp_handler(&self, udp: Arc<S>) -> Result<(), Error> {
self.queue.new_event(
udp.as_raw_fd(),
Box::new(move |d, t| {
Expand Down Expand Up @@ -668,8 +693,8 @@ impl Device {

fn register_conn_handler(
&self,
peer: Arc<Peer>,
udp: Arc<UDPSocket>,
peer: Arc<Peer<S>>,
udp: Arc<S>,
peer_addr: IpAddr,
) -> Result<(), Error> {
self.queue.new_event(
Expand Down Expand Up @@ -725,7 +750,7 @@ impl Device {
Ok(())
}

fn register_iface_handler(&self, iface: Arc<TunSocket>) -> Result<(), Error> {
fn register_iface_handler(&self, iface: Arc<T>) -> Result<(), Error> {
self.queue.new_event(
iface.as_raw_fd(),
Box::new(move |d, t| {
Expand Down
25 changes: 10 additions & 15 deletions src/device/peer.rs
@@ -1,23 +1,22 @@
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use crate::device::udp::UDPSocket;
use crate::device::*;
use parking_lot::RwLock;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::str::FromStr;

#[derive(Default, Debug)]
pub struct Endpoint {
pub struct Endpoint<S: Sock> {
pub addr: Option<SocketAddr>,
pub conn: Option<Arc<UDPSocket>>,
pub conn: Option<Arc<S>>,
}

pub struct Peer {
pub struct Peer<S: Sock> {
pub(crate) tunnel: Box<Tunn>, // The associated tunnel struct
index: u32, // The index the tunnel uses
endpoint: RwLock<Endpoint>,
endpoint: RwLock<Endpoint<S>>,
allowed_ips: AllowedIps<()>,
preshared_key: Option<[u8; 32]>,
}
Expand Down Expand Up @@ -46,14 +45,14 @@ impl FromStr for AllowedIP {
}
}

impl Peer {
impl<S: Sock> Peer<S> {
pub fn new(
tunnel: Box<Tunn>,
index: u32,
endpoint: Option<SocketAddr>,
allowed_ips: &[AllowedIP],
preshared_key: Option<[u8; 32]>,
) -> Peer {
) -> Peer<S> {
Peer {
tunnel,
index,
Expand All @@ -70,7 +69,7 @@ impl Peer {
self.tunnel.update_timers(dst)
}

pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint> {
pub fn endpoint(&self) -> parking_lot::RwLockReadGuard<'_, Endpoint<S>> {
self.endpoint.read()
}

Expand All @@ -96,24 +95,20 @@ impl Peer {
};
}

pub fn connect_endpoint(
&self,
port: u16,
fwmark: Option<u32>,
) -> Result<Arc<UDPSocket>, Error> {
pub fn connect_endpoint(&self, port: u16, fwmark: Option<u32>) -> Result<Arc<S>, Error> {
let mut endpoint = self.endpoint.write();

if endpoint.conn.is_some() {
return Err(Error::Connect("Connected".to_owned()));
}

let udp_conn = Arc::new(match endpoint.addr {
Some(addr @ SocketAddr::V4(_)) => UDPSocket::new()?
Some(addr @ SocketAddr::V4(_)) => S::new()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?
.connect(&addr)?,
Some(addr @ SocketAddr::V6(_)) => UDPSocket::new6()?
Some(addr @ SocketAddr::V6(_)) => S::new6()?
.set_non_blocking()?
.set_reuse()?
.bind(port)?
Expand Down

0 comments on commit 2d86264

Please sign in to comment.