diff --git a/gateway/src/proxy/tls_passthough.rs b/gateway/src/proxy/tls_passthough.rs index 67bed7e2..fa107373 100644 --- a/gateway/src/proxy/tls_passthough.rs +++ b/gateway/src/proxy/tls_passthough.rs @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: BUSL-1.1 +use std::collections::{BTreeMap, BTreeSet}; use std::fmt::Debug; use std::sync::atomic::Ordering; @@ -34,8 +35,23 @@ impl AppAddress { } } +fn select_app_address(items: &[Box<[u8]>], known_apps: &BTreeMap>) -> Result { + let mut fallback = None; + for data in items { + if let Ok(addr) = AppAddress::parse(data) { + if known_apps.contains_key(&addr.app_id) { + return Ok(addr); + } + if fallback.is_none() { + fallback = Some(addr); + } + } + } + fallback.context("no app address found in txt record") +} + /// resolve app address by sni -async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result { +async fn resolve_app_address(prefix: &str, sni: &str, compat: bool, state: &Proxy) -> Result { let txt_domain = format!("{prefix}.{sni}"); let resolver = hickory_resolver::AsyncResolver::tokio_from_system_conf() .context("failed to create dns resolver")?; @@ -53,6 +69,7 @@ async fn resolve_app_address(prefix: &str, sni: &str, compat: bool) -> Result Result>>>>>> 8e3d6c27 (prefer locally-known app_id when TXT record contains multiple app addresses) anyhow::bail!("failed to resolve app address for {sni}"); } @@ -102,7 +155,7 @@ pub(crate) async fn proxy_with_sni( let ns_prefix = &state.config.proxy.app_address_ns_prefix; let compat = state.config.proxy.app_address_ns_compat; let dns_timeout = state.config.proxy.timeouts.dns_resolve; - let addr = timeout(dns_timeout, resolve_app_address(ns_prefix, sni, compat)) + let addr = timeout(dns_timeout, resolve_app_address(ns_prefix, sni, compat, &state)) .await .with_context(|| format!("DNS TXT resolve timeout for {sni}"))? .with_context(|| format!("failed to resolve app address for {sni}"))?; @@ -197,17 +250,65 @@ pub(crate) async fn proxy_to_app( #[cfg(test)] mod tests { use super::*; + use crate::{ + config::{load_config_figment, Config, MutualConfig, TlsConfig}, + main_service::ProxyOptions, + }; + use tempfile::TempDir; + + fn boxed(s: &[u8]) -> Box<[u8]> { + s.to_vec().into_boxed_slice() + } + + async fn create_test_proxy() -> (Proxy, TempDir) { + let figment = load_config_figment(None); + let mut config = figment.focus("core").extract::().unwrap(); + let temp_dir = TempDir::new().expect("failed to create temp dir"); + config.sync.data_dir = temp_dir.path().to_string_lossy().to_string(); + let proxy = Proxy::new(ProxyOptions { + config, + my_app_id: None, + tls_config: TlsConfig { + certs: "".to_string(), + key: "".to_string(), + mutual: MutualConfig { ca_certs: "".to_string() }, + }, + }) + .await + .expect("failed to create proxy"); + (proxy, temp_dir) + } #[tokio::test] async fn test_resolve_app_address() { + let (state, _dir) = create_test_proxy().await; let app_addr = resolve_app_address( "_dstack-app-address", "3327603e03f5bd1f830812ca4a789277fc31f577.app.dstack.org", false, + &state, ) .await .unwrap(); assert_eq!(app_addr.app_id, "3327603e03f5bd1f830812ca4a789277fc31f577"); assert_eq!(app_addr.port, 8090); } + + #[test] + fn test_select_app_address_prefers_local() { + let items = vec![boxed(b"aaaaaa:443"), boxed(b"bbbbbb:8080")]; + let mut apps = BTreeMap::new(); + apps.insert("bbbbbb".to_string(), BTreeSet::new()); + let addr = select_app_address(&items, &apps).unwrap(); + assert_eq!(addr.app_id, "bbbbbb"); + assert_eq!(addr.port, 8080); + } + + #[test] + fn test_select_app_address_fallback_when_none_local() { + let items = vec![boxed(b"aaaaaa:443"), boxed(b"bbbbbb:8080")]; + let addr = select_app_address(&items, &BTreeMap::new()).unwrap(); + assert_eq!(addr.app_id, "aaaaaa"); + assert_eq!(addr.port, 443); + } }