Skip to content

Commit

Permalink
feat: validate address when reading the config
Browse files Browse the repository at this point in the history
  • Loading branch information
azzamsa committed Feb 22, 2024
1 parent 256c1a3 commit 02b741d
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 26 deletions.
24 changes: 20 additions & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
ffi::OsStr,
fs,
fs, net,
path::{Path, PathBuf},
};

Expand All @@ -14,9 +14,8 @@ pub struct Config {

#[derive(Debug, Deserialize)]
pub struct Server {
// I can't use `SocketAddr` directly here if I wanted
// to make the port optional.
pub address: String,
#[serde(deserialize_with = "deserialize_address")]
pub address: net::SocketAddr,
pub name: String,
}

Expand Down Expand Up @@ -50,3 +49,20 @@ where
}
}
}

/// Parse address string into `SocketAddr`
fn deserialize_address<'de, D>(deserializer: D) -> Result<net::SocketAddr, D::Error>
where
D: serde::Deserializer<'de>,
{
let address: String = Deserialize::deserialize(deserializer)?;

if address.contains(':') {
address.parse().map_err(serde::de::Error::custom)
} else {
// Use default port
format!("{address}:53")
.parse()
.map_err(serde::de::Error::custom)
}
}
20 changes: 7 additions & 13 deletions src/dns.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::net;
use std::str::FromStr;

use hickory_client::client::{Client, SyncClient};
Expand All @@ -7,21 +8,14 @@ use hickory_client::udp::UdpClientConnection;

use crate::error::Error;

pub fn query(domain: &str, rtype: RecordType, address: &str) -> Result<DnsResponse, Error> {
let socket_address = parse_address(address)?;
let conn = UdpClientConnection::new(socket_address)?;
pub fn query(
domain: &str,
rtype: RecordType,
address: net::SocketAddr,
) -> Result<DnsResponse, Error> {
let conn = UdpClientConnection::new(address)?;
let client = SyncClient::new(conn);

let name = Name::from_str(&format!("{}.", domain))?;
Ok(client.query(&name, DNSClass::IN, rtype)?)
}

/// Parse address string into `SocketAddr`
fn parse_address(address: &str) -> Result<std::net::SocketAddr, Error> {
if address.contains(':') {
Ok(address.parse::<std::net::SocketAddr>()?)
} else {
// Use default port
Ok(format!("{address}:53").parse::<std::net::SocketAddr>()?)
}
}
6 changes: 0 additions & 6 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ impl std::convert::From<std::io::Error> for Error {
}
}

impl std::convert::From<std::net::AddrParseError> for Error {
fn from(err: std::net::AddrParseError) -> Self {
Self::InvalidArgument(err.to_string())
}
}

impl std::convert::From<hickory_client::error::ClientError> for Error {
fn from(err: hickory_client::error::ClientError) -> Self {
Self::InvalidArgument(err.to_string())
Expand Down
2 changes: 1 addition & 1 deletion src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Printer {
}
pub fn print(&self) -> Result<(), crate::Error> {
for server in &self.config.servers {
let response = dns::query(&self.domain, self.record_type, &server.address);
let response = dns::query(&self.domain, self.record_type, server.address);
trace!("Response -> {:?}", response);

stdout(&server.name.to_string());
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ name = "Google"

cmd.arg("example.net").arg("-c").arg(config.to_path_buf());
cmd.assert()
.success()
.stdout(predicate::str::contains("invalid socket address syntax"));
.failure()
.stderr(predicate::str::contains("invalid socket address syntax"));

temp_dir.close()?;
Ok(())
Expand Down

0 comments on commit 02b741d

Please sign in to comment.