Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: added host header check (to protect against DNS rebinding attacks) #250

Merged
merged 5 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions aw-server/src/endpoints/hostcheck.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
//! Host header check needs to be performed to protect against DNS poisoning
//! attacks[1].
//!
//! Uses a Request Fairing to intercept the request before it's handled.
//! If the Host header is not valid, the request will be rerouted to a
//! BadRequest
//!
//! [1]: https://github.com/ActivityWatch/activitywatch/security/advisories/GHSA-v9fg-6g9j-h4x4
use rocket::fairing::Fairing;
use rocket::handler::Outcome;
use rocket::http::uri::Origin;
use rocket::http::{Method, Status};
use rocket::{Data, Request, Rocket, Route};

use crate::config::AWConfig;

static FAIRING_ROUTE_BASE: &str = "/checkheader_fairing";

pub struct HostCheck {
validate: bool,
}

impl HostCheck {
pub fn new(config: &AWConfig) -> HostCheck {
// We only validate requests if the server binds a local address
let validate = config.address == "127.0.0.1" || config.address == "localhost";
HostCheck { validate }
}
}

/// Route for HostCheck Fairing error
fn fairing_error_route<'r>(_request: &'r Request<'_>, _: Data) -> Outcome<'r> {
Outcome::Failure(Status::BadRequest)
}
johan-bjareholt marked this conversation as resolved.
Show resolved Hide resolved

/// Create a new `Route` for Fairing handling
fn fairing_route() -> Route {
Route::ranked(1, Method::Get, "/", fairing_error_route)
}

fn redirect_bad_request(request: &mut Request) {
let uri = FAIRING_ROUTE_BASE.to_string();
let origin = Origin::parse_owned(uri).unwrap();
request.set_method(Method::Get);
request.set_uri(origin);
}

impl Fairing for HostCheck {
fn info(&self) -> rocket::fairing::Info {
rocket::fairing::Info {
name: "HostCheck",
kind: rocket::fairing::Kind::Attach | rocket::fairing::Kind::Request,
}
}

fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
match self.validate {
true => Ok(rocket.mount(FAIRING_ROUTE_BASE, vec![fairing_route()])),
false => {
warn!("Host header validation is turned off, this is a security risk");
Ok(rocket)
}
}
}

fn on_request(&self, request: &mut Request, _: &Data) {
if !self.validate {
// host header check is disabled
return;
}

// Fetch header
let hostheader_opt = request.headers().get_one("host");
if hostheader_opt.is_none() {
info!("Missing 'Host' header, denying request");
redirect_bad_request(request);
return;
}

// Parse hostname from host header
// hostname contains port, which we don't care about and filter out
let hostheader = hostheader_opt.unwrap();
let host_opt = hostheader.split(":").next();
if host_opt.is_none() {
info!("Host header '{}' not allowed, denying request", hostheader);
redirect_bad_request(request);
return;
}

// Deny requests from hosts that are not localhost
ErikBjare marked this conversation as resolved.
Show resolved Hide resolved
let valid_hosts: Vec<&str> = vec!["127.0.0.1", "localhost"];
let host = host_opt.unwrap();
if !valid_hosts.contains(&host) {
info!("Host header '{}' not allowed, denying request", hostheader);
redirect_bad_request(request);
}

// host is verified, proceed with request
ErikBjare marked this conversation as resolved.
Show resolved Hide resolved
}
}

#[cfg(test)]
mod tests {
use std::path::PathBuf;
use std::sync::Mutex;

use rocket::http::{ContentType, Header, Status};
use rocket::Rocket;

use crate::config::AWConfig;
use crate::endpoints;

fn setup_testserver(address: String) -> Rocket {
let state = endpoints::ServerState {
datastore: Mutex::new(aw_datastore::Datastore::new_in_memory(false)),
asset_path: PathBuf::from("aw-webui/dist"),
device_id: "test_id".to_string(),
};
let mut aw_config = AWConfig::default();
aw_config.address = address;
endpoints::build_rocket(state, aw_config)
}

#[test]
fn test_public_address() {
let server = setup_testserver("0.0.0.0".to_string());
let client = rocket::local::Client::new(server).expect("valid instance");

// When a public address is used, request should always pass, regardless
// if the Host header is missing
let res = client
.get("/api/0/info")
.header(ContentType::JSON)
.dispatch();
assert_eq!(res.status(), Status::Ok);
}

#[test]
fn test_localhost_address() {
let server = setup_testserver("127.0.0.1".to_string());
let client = rocket::local::Client::new(server).expect("valid instance");

// If Host header is missing we should get a BadRequest
let res = client
.get("/api/0/info")
.header(ContentType::JSON)
.dispatch();
assert_eq!(res.status(), Status::BadRequest);

// If Host header is not 127.0.0.1 or localhost we should get BadRequest
let res = client
.get("/api/0/info")
.header(ContentType::JSON)
.header(Header::new("Host", "192.168.0.1:1234"))
.dispatch();
assert_eq!(res.status(), Status::BadRequest);

// If Host header is 127.0.0.1:5600 we should get OK
let res = client
.get("/api/0/info")
.header(ContentType::JSON)
.header(Header::new("Host", "127.0.0.1:5600"))
.dispatch();
assert_eq!(res.status(), Status::Ok);

// If Host header is localhost:5600 we should get OK
let res = client
.get("/api/0/info")
.header(ContentType::JSON)
.header(Header::new("Host", "localhost:5600"))
.dispatch();
assert_eq!(res.status(), Status::Ok);

// If Host header is missing port, we should still get OK
let res = client
.get("/api/0/info")
.header(ContentType::JSON)
.header(Header::new("Host", "localhost"))
.dispatch();
assert_eq!(res.status(), Status::Ok);
}
}
4 changes: 2 additions & 2 deletions aw-server/src/endpoints/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ pub fn bucket_import_form(
.params()
.find(|&(k, _)| k == "boundary")
.ok_or_else(|| {
return HttpErrorJson::new(
HttpErrorJson::new(
Status::BadRequest,
"`Content-Type: multipart/form-data` boundary param not provided".to_string(),
);
)
})?;

let string = process_multipart_packets(boundary, data);
Expand Down
3 changes: 3 additions & 0 deletions aw-server/src/endpoints/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod util;
mod bucket;
mod cors;
mod export;
mod hostcheck;
mod import;
mod query;
mod settings;
Expand Down Expand Up @@ -78,8 +79,10 @@ pub fn build_rocket(server_state: ServerState, config: AWConfig) -> rocket::Rock
config.address, config.port
);
let cors = cors::cors(&config);
let hostcheck = hostcheck::HostCheck::new(&config);
rocket::custom(config.to_rocket_config())
.attach(cors.clone())
.attach(hostcheck)
.manage(cors)
.manage(server_state)
.manage(config)
Expand Down
15 changes: 12 additions & 3 deletions aw-server/src/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ pub fn setup_logger(testing: bool) -> Result<(), fern::InitError> {
Ok(())
}

#[test]
fn test_setup_logger() {
setup_logger(true).unwrap();
#[cfg(test)]
mod tests {
use super::setup_logger;

/* disable this test.
* This is due to it failing in GitHub actions, claiming that the logger
* has been initialized twice which is not allowed */
#[ignore]
#[test]
fn test_setup_logger() {
setup_logger(true).unwrap();
}
}
Loading