Skip to content

Commit

Permalink
Merge pull request #100 from osobiehl/feature/jb/improved-dtls-constr…
Browse files Browse the repository at this point in the history
…uction

improved dtls construction, fix resource leak on recv_loop, update version
  • Loading branch information
Covertness committed Apr 9, 2024
2 parents 8825bc3 + 51ed9dd commit c120146
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "coap"
version = "0.16.0"
version = "0.17.0"
description = "A CoAP library"
readme = "README.md"
documentation = "https://docs.rs/coap/"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -24,7 +24,7 @@ First add this to your `Cargo.toml`:

```toml
[dependencies]
coap = "0.16"
coap = "0.17"
coap-lite = "0.11.3"
tokio = {version = "^1.32", features = ["full"]}
```
Expand Down
6 changes: 3 additions & 3 deletions examples/echo_with_dtls.rs
Expand Up @@ -2,7 +2,7 @@
/// a look at the test in dtls.rs
extern crate coap;
use coap::client::CoAPClient;
use coap::dtls::DtlsConfig;
use coap::dtls::UdpDtlsConfig;
use coap::request::RequestBuilder;
use coap::Server;
use coap_lite::{CoapRequest, RequestType as Method};
Expand Down Expand Up @@ -60,7 +60,7 @@ async fn main() {
.await
.unwrap();

let dtls_config = DtlsConfig {
let dtls_config = UdpDtlsConfig {
config,
dest_addr: ("127.0.0.1", server_port)
.to_socket_addrs()
Expand All @@ -69,7 +69,7 @@ async fn main() {
.unwrap(),
};

let client = CoAPClient::from_dtls_config(dtls_config)
let client = CoAPClient::from_udp_dtls_config(dtls_config)
.await
.expect("could not create client");
let domain = format!("127.0.0.1:{}", server_port);
Expand Down
90 changes: 54 additions & 36 deletions src/client.rs
@@ -1,5 +1,5 @@
#[cfg(feature = "dtls")]
use crate::dtls::{DtlsConfig, DtlsConnection};
use crate::dtls::{DtlsConnection, UdpDtlsConfig};
use crate::request::RequestBuilder;
use alloc::string::String;
use alloc::vec::Vec;
Expand All @@ -18,7 +18,7 @@ use regex::Regex;
use std::{
collections::BTreeMap,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
sync::atomic::AtomicU16,
sync::{atomic::AtomicU16, Weak},
};
use std::{
io::{Error, ErrorKind, Result as IoResult},
Expand All @@ -44,24 +44,24 @@ pub enum ObserveMessage {
use async_trait::async_trait;

#[async_trait]
/// A basic interface for a transport on both the client and transport
/// A basic interface for a transport on the client
/// representing a one-to-one connection between a client and server
/// timeouts and retries do not need to be implemented by the transport
/// if confirmable messages are sent
pub trait Transport: Send + Sync {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>;
pub trait ClientTransport: Send + Sync {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize>;
async fn send(&self, buf: &[u8]) -> std::io::Result<usize>;
}

trait TransportExt {
async fn receive_packet(&self) -> IoResult<Option<(Packet, SocketAddr)>>;
async fn receive_packet(&self) -> IoResult<Option<Packet>>;
}

impl<T: Transport> TransportExt for T {
async fn receive_packet(&self) -> IoResult<Option<(Packet, SocketAddr)>> {
impl<T: ClientTransport> TransportExt for T {
async fn receive_packet(&self) -> IoResult<Option<Packet>> {
let mut buf = [0; 1500];
let (nread, src) = self.recv(&mut buf).await?;
let parse_opt = Packet::from_bytes(&buf[..nread]).ok().map(|p| (p, src));
let nread = self.recv(&mut buf).await?;
let parse_opt = Packet::from_bytes(&buf[..nread]).ok();
return Ok(parse_opt);
}
}
Expand Down Expand Up @@ -133,22 +133,34 @@ impl TransportSynchronizer {
}
}

async fn receive_loop<T: Transport + 'static>(
transport: Arc<T>,
async fn receive_loop<T: ClientTransport + 'static>(
transport: Weak<T>,
transport_sync: TransportSynchronizer,
) -> std::io::Result<()> {
let err = loop {
let recv_res = transport.receive_packet().await;
let Some(transport_instance) = transport.upgrade() else {
// nobody else is listening so we can drop our reference
return Ok(());
};
// we do a timeout here to ensure that we do not block forever
let Ok(recv_res) = timeout(
Duration::from_millis(300),
transport_instance.receive_packet(),
)
.await
else {
continue;
};
let option_packet = match recv_res {
Err(e) => break e,
Ok(o) => o,
};
let Some((packet, _src)) = option_packet else {
let Some(packet) = option_packet else {
trace!("unexpected malformed packet received");
continue;
};
if let Some(ack) = parse_for_ack(&packet) {
transport.send(&ack).await?;
transport_instance.send(&ack).await?;
}

let MessageClass::Response(_) = packet.header.code else {
Expand All @@ -168,7 +180,7 @@ async fn receive_loop<T: Transport + 'static>(
}
};
let Ok(_) = sender.send(Ok(packet)) else {
debug!("unexpected drop of oneshot sender");
debug!("unexpected drop of sender");
continue;
};
};
Expand All @@ -194,14 +206,14 @@ pub fn make_ack(packet: &Packet) -> Vec<u8> {
}

/// a wrapper for transports responsible for retries and timeouts
struct ClientTransport<T: Transport> {
struct CoapClientTransport<T: ClientTransport> {
pub(crate) transport: Arc<T>,
pub(crate) synchronizer: TransportSynchronizer,
pub(crate) retries: usize,
pub(crate) timeout: Duration,
}

impl<T: Transport> Clone for ClientTransport<T> {
impl<T: ClientTransport> Clone for CoapClientTransport<T> {
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
Expand All @@ -212,7 +224,7 @@ impl<T: Transport> Clone for ClientTransport<T> {
}
}

impl<T: Transport> ClientTransport<T> {
impl<T: ClientTransport> CoapClientTransport<T> {
pub const DEFAULT_NUM_RETRIES: usize = 5;
async fn establish_receiver_for(&self, msg: &Packet) -> UnboundedReceiver<IoResult<Packet>> {
let (tx, rx) = unbounded_channel();
Expand Down Expand Up @@ -296,9 +308,12 @@ pub struct UdpTransport {
pub peer_addr: SocketAddr,
}
#[async_trait]
impl Transport for UdpTransport {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
self.socket.recv_from(buf).await
impl ClientTransport for UdpTransport {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
self.socket
.recv_from(buf)
.await
.map(|(recv_size, _addr)| recv_size)
}
async fn send(&self, buf: &[u8]) -> std::io::Result<usize> {
self.socket.send_to(buf, self.peer_addr).await
Expand All @@ -308,13 +323,13 @@ impl Transport for UdpTransport {
/// A CoAP client over UDP. This client can send multicast and broadcasts
pub type UdpCoAPClient = CoAPClient<UdpTransport>;

pub struct CoAPClient<T: Transport> {
transport: ClientTransport<T>,
pub struct CoAPClient<T: ClientTransport> {
transport: CoapClientTransport<T>,
block1_size: usize,
message_id: Arc<AtomicU16>,
}

impl<T: Transport> Clone for CoAPClient<T> {
impl<T: ClientTransport> Clone for CoAPClient<T> {
fn clone(&self) -> Self {
Self {
transport: self.transport.clone(),
Expand Down Expand Up @@ -485,14 +500,14 @@ impl UdpCoAPClient {

#[cfg(feature = "dtls")]
impl CoAPClient<DtlsConnection> {
pub async fn from_dtls_config(config: DtlsConfig) -> IoResult<Self> {
pub async fn from_udp_dtls_config(config: UdpDtlsConfig) -> IoResult<Self> {
Ok(CoAPClient::from_transport(
DtlsConnection::try_new(config).await?,
))
}
}

impl<T: Transport + 'static> CoAPClient<T> {
impl<T: ClientTransport + 'static> CoAPClient<T> {
const MAX_PAYLOAD_BLOCK: usize = 1024;
/// Create a CoAP client with a chosen transport type

Expand All @@ -501,9 +516,12 @@ impl<T: Transport + 'static> CoAPClient<T> {
let transport_arc = Arc::new(transport);
let message_id: u16 = rand::random();
// spawn receive loop to handle responses
tokio::spawn(receive_loop(transport_arc.clone(), synchronizer.clone()));
tokio::spawn(receive_loop(
Arc::downgrade(&transport_arc),
synchronizer.clone(),
));
CoAPClient {
transport: ClientTransport::from_transport(transport_arc.clone(), synchronizer),
transport: CoapClientTransport::from_transport(transport_arc.clone(), synchronizer),
block1_size: Self::MAX_PAYLOAD_BLOCK,
message_id: Arc::new(AtomicU16::new(message_id)),
}
Expand Down Expand Up @@ -1195,8 +1213,8 @@ mod test {
}

#[async_trait]
impl Transport for FaultyUdp {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
impl ClientTransport for FaultyUdp {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
self.udp.recv(buf).await
}

Expand Down Expand Up @@ -1247,7 +1265,7 @@ mod test {
let server_addr = format!("127.0.0.1:{}", server_port);
let mut client = get_faulty_client(
&server_addr,
ClientTransport::<FaultyUdp>::DEFAULT_NUM_RETRIES as u32 + 1,
CoapClientTransport::<FaultyUdp>::DEFAULT_NUM_RETRIES as u32 + 1,
)
.await;
let request_gen = || {
Expand All @@ -1260,7 +1278,7 @@ mod test {
//this request will work, we do this to reset the state of the faulty udp
client.send(request_gen()).await.unwrap();

client.set_transport_retries(ClientTransport::<UdpTransport>::DEFAULT_NUM_RETRIES + 2);
client.set_transport_retries(CoapClientTransport::<UdpTransport>::DEFAULT_NUM_RETRIES + 2);
let resp = client.send(request_gen()).await.unwrap();

assert_eq!(resp.message.payload, b"Rust".to_vec());
Expand All @@ -1287,7 +1305,7 @@ mod test {
assert!(req.is_err());
}

async fn do_wait_request<T: Transport + 'static>(
async fn do_wait_request<T: ClientTransport + 'static>(
client: Arc<CoAPClient<T>>,
path: &str,
token: Vec<u8>,
Expand Down Expand Up @@ -1362,8 +1380,8 @@ mod test {
pub should_fail: Mutex<oneshot::Receiver<std::io::Error>>,
}
#[async_trait]
impl Transport for FaultyReceiver {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
impl ClientTransport for FaultyReceiver {
async fn recv(&self, buf: &mut [u8]) -> std::io::Result<usize> {
let mut mutex = self.should_fail.lock().await;
tokio::select! {
e = mutex.deref_mut() => {
Expand Down

0 comments on commit c120146

Please sign in to comment.