diff --git a/kms/kms.toml b/kms/kms.toml index bd5834cd..1dd28407 100644 --- a/kms/kms.toml +++ b/kms/kms.toml @@ -38,6 +38,12 @@ cache_dir = "/usr/share/dstack/images" download_url = "http://localhost:8000/{OS_IMAGE_HASH}.tar.gz" download_timeout = "2m" +[core.metrics] +# Expose unauthenticated Prometheus metrics on /metrics. +# Disable this if the KMS RPC listener is not protected by network policy +# or another access-control layer. +enabled = true + [core.auth_api] type = "webhook" diff --git a/kms/src/config.rs b/kms/src/config.rs index 2bf9bba2..7065b748 100644 --- a/kms/src/config.rs +++ b/kms/src/config.rs @@ -46,6 +46,13 @@ pub(crate) struct KmsConfig { /// agent socket. #[serde(default = "default_true")] pub enforce_self_authorization: bool, + pub metrics: MetricsConfig, +} + +#[derive(Debug, Clone, Deserialize)] +pub(crate) struct MetricsConfig { + /// Whether to expose the unauthenticated Prometheus `/metrics` endpoint. + pub enabled: bool, } fn default_true() -> bool { diff --git a/kms/src/main.rs b/kms/src/main.rs index eddfbdc9..1ab9b568 100644 --- a/kms/src/main.rs +++ b/kms/src/main.rs @@ -10,8 +10,8 @@ use ra_rpc::rocket_helper::QuoteVerifier; use rocket::{ fairing::AdHoc, figment::{providers::Serialized, Figment}, - response::content::RawHtml, - Shutdown, + response::content::{RawHtml, RawText}, + Shutdown, State, }; use tracing::{info, warn}; @@ -77,6 +77,34 @@ async fn run_onboard_service(kms_config: KmsConfig, figment: Figment) -> Result< Ok(()) } +#[rocket::get("/metrics")] +fn metrics(state: &State) -> RawText { + RawText(state.metrics().render_prometheus()) +} + +// Count only RPCs whose primary job is to verify caller/app attestation. +// Recording in a response fairing also catches failures that happen before +// RpcHandler is constructed, such as malformed RA-TLS attestation. +fn is_attestation_rpc_path(path: &str) -> bool { + let Some(method) = path.strip_prefix("/prpc/") else { + return false; + }; + let method = method.trim_start_matches("KMS."); + matches!(method, "GetAppKey" | "GetKmsKey" | "SignCert") +} + +fn record_attestation_metrics(req: &rocket::Request<'_>, res: &rocket::Response<'_>) { + if !is_attestation_rpc_path(req.uri().path().as_str()) { + return; + } + let Some(state) = req.rocket().state::() else { + return; + }; + state + .metrics() + .record_attestation_request(res.status().code >= 400); +} + #[rocket::main] async fn main() -> Result<()> { { @@ -109,6 +137,7 @@ async fn main() -> Result<()> { } let pccs_url = config.pccs_url.clone(); + let metrics_enabled = config.metrics.enabled; let state = main_service::KmsState::new(config).context("Failed to initialize KMS state")?; let figment = figment .clone() @@ -125,6 +154,16 @@ async fn main() -> Result<()> { ) .manage(state); + if metrics_enabled { + info!("Prometheus metrics endpoint enabled at /metrics"); + rocket = rocket + .attach(AdHoc::on_response( + "Record KMS attestation metrics", + |req, res| Box::pin(async move { record_attestation_metrics(req, res) }), + )) + .mount("/", rocket::routes![metrics]); + } + let verifier = QuoteVerifier::new(pccs_url); rocket = rocket.manage(verifier); diff --git a/kms/src/main_service.rs b/kms/src/main_service.rs index 7340f71e..00723566 100644 --- a/kms/src/main_service.rs +++ b/kms/src/main_service.rs @@ -4,7 +4,10 @@ use std::{ path::{Path, PathBuf}, - sync::Arc, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, }; use anyhow::{bail, Context, Result}; @@ -57,6 +60,38 @@ pub struct KmsStateInner { temp_ca_key: String, verifier: CvmVerifier, self_boot_info: OnceCell, + metrics: KmsMetrics, +} + +#[derive(Default)] +pub(crate) struct KmsMetrics { + attestation_requests_total: AtomicU64, + attestation_failures_total: AtomicU64, +} + +impl KmsMetrics { + pub(crate) fn record_attestation_request(&self, failed: bool) { + self.attestation_requests_total + .fetch_add(1, Ordering::Relaxed); + if failed { + self.attestation_failures_total + .fetch_add(1, Ordering::Relaxed); + } + } + + pub(crate) fn render_prometheus(&self) -> String { + let attestation_requests_total = self.attestation_requests_total.load(Ordering::Relaxed); + let attestation_failures_total = self.attestation_failures_total.load(Ordering::Relaxed); + + format!( + "# HELP dstack_kms_attestation_requests_total Total number of KMS attestation requests.\n\ + # TYPE dstack_kms_attestation_requests_total counter\n\ + dstack_kms_attestation_requests_total {attestation_requests_total}\n\ + # HELP dstack_kms_attestation_failures_total Total number of failed KMS attestation requests.\n\ + # TYPE dstack_kms_attestation_failures_total counter\n\ + dstack_kms_attestation_failures_total {attestation_failures_total}\n" + ) + } } impl KmsState { @@ -77,7 +112,9 @@ impl KmsState { config.pccs_url.clone(), ); if !config.enforce_self_authorization { - warn!("self-authorization is disabled; trusted RPCs will not be gated by KMS self-attestation - do not use in production TEE deployments"); + warn!( + "self-authorization is disabled; trusted RPCs will not be gated by KMS self-attestation - do not use in production TEE deployments" + ); } Ok(Self { inner: Arc::new(KmsStateInner { @@ -88,9 +125,14 @@ impl KmsState { temp_ca_key, verifier, self_boot_info: OnceCell::new(), + metrics: KmsMetrics::default(), }), }) } + + pub(crate) fn metrics(&self) -> &KmsMetrics { + &self.inner.metrics + } } pub struct RpcHandler {