From 4add4d62196328f64956cb7f168bb7e1b2de6a77 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sat, 21 Dec 2024 07:49:14 +0000 Subject: [PATCH 1/3] prpc: Support for reading args from url query --- Cargo.lock | 20 ++++++++++++++++---- Cargo.toml | 4 ++-- kms/src/web_routes.rs | 4 +++- ra-rpc/src/lib.rs | 13 ++++++++++--- ra-rpc/src/rocket_helper.rs | 13 ++++++++++--- tappd/src/guest_api_routes.rs | 4 +++- tappd/src/http_routes.rs | 5 +++++ teepod/src/guest_api_routes.rs | 4 +++- teepod/src/host_api_routes.rs | 4 +++- teepod/src/main_routes.rs | 4 +++- tproxy/src/web_routes.rs | 4 +++- 11 files changed, 61 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c3d5f2a65..c000ab68a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3420,9 +3420,9 @@ dependencies = [ [[package]] name = "prpc" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e83ca4bb539b92a92f3320a41482346d66bc591acd41c52b4ca22aa6960f3259" +checksum = "4bf27f5c46f289f99f68086d3a24b58c9f4f66c3143a1780d5c0589c79b17c81" dependencies = [ "anyhow", "async-trait", @@ -3434,13 +3434,14 @@ dependencies = [ "prpc-serde-bytes", "serde", "serde_json", + "serde_qs", ] [[package]] name = "prpc-build" -version = "0.3.6" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4155bf9e4977e408df247a4dafb118add3ab153957109ffc971e4fb597b72c2" +checksum = "afd0da11e53c652b1cd4828f7ca4f6bfd607bd666a48203ebad6b30bd0b2c9ec" dependencies = [ "either", "fs-err", @@ -4545,6 +4546,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd34f36fe4c5ba9654417139a9b3a20d2e1de6012ee678ad14d240c22c78d8d6" +dependencies = [ + "percent-encoding", + "serde", + "thiserror 1.0.65", +] + [[package]] name = "serde_repr" version = "0.1.19" diff --git a/Cargo.toml b/Cargo.toml index d054e6b20..fba61351c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,8 +131,8 @@ rcgen = { version = "0.13.1", features = ["pem"] } x509-parser = "0.16.0" # RPC/Protocol -prpc = "0.3.0" -prpc-build = "0.3.6" +prpc = "0.4.0" +prpc-build = "0.4.0" # Development/Testing bindgen = "0.70.1" diff --git a/kms/src/web_routes.rs b/kms/src/web_routes.rs index d7adc52b2..3613c62a3 100644 --- a/kms/src/web_routes.rs +++ b/kms/src/web_routes.rs @@ -3,7 +3,7 @@ use ra_rpc::rocket_helper::{PrpcHandler, QuoteVerifier}; use rocket::{ data::{Data, Limits}, get, - http::ContentType, + http::{uri::Origin, ContentType}, mtls::Certificate, post, response::status::Custom, @@ -47,6 +47,7 @@ async fn prpc_get( quote_verifier: Option<&State>, cert: Option>, method: &str, + origin: &Origin<'_>, limits: &Limits, content_type: Option<&ContentType>, ) -> Custom> { @@ -57,6 +58,7 @@ async fn prpc_get( .method(method) .limits(limits) .maybe_content_type(content_type) + .maybe_query(origin.query()) .json(true) .build() .handle::() diff --git a/ra-rpc/src/lib.rs b/ra-rpc/src/lib.rs index a37174722..98e21e9c7 100644 --- a/ra-rpc/src/lib.rs +++ b/ra-rpc/src/lib.rs @@ -40,11 +40,17 @@ pub trait RpcCall { where Self: Sized; fn into_prpc_service(self) -> Self::PrpcService; - async fn call(self, method: String, payload: Vec, is_json: bool) -> (u16, Vec) + async fn call( + self, + method: String, + payload: Vec, + is_json: bool, + is_query: bool, + ) -> (u16, Vec) where Self: Sized, { - dispatch_prpc(method, payload, is_json, self.into_prpc_service()).await + dispatch_prpc(method, payload, is_json, is_query, self.into_prpc_service()).await } } @@ -52,12 +58,13 @@ async fn dispatch_prpc( path: String, data: Vec, json: bool, + query: bool, server: impl PrpcService + Send + 'static, ) -> (u16, Vec) { use prpc::server::Error; info!("dispatching request: {}", path); - let result = server.dispatch_request(&path, data, json).await; + let result = server.dispatch_request(&path, data, json, query).await; let (code, data) = match result { Ok(data) => (200, data), Err(err) => { diff --git a/ra-rpc/src/rocket_helper.rs b/ra-rpc/src/rocket_helper.rs index c3ba69648..34a5bb3b7 100644 --- a/ra-rpc/src/rocket_helper.rs +++ b/ra-rpc/src/rocket_helper.rs @@ -7,7 +7,7 @@ use ra_tls::{ }; use rocket::{ data::{ByteUnit, Limits, ToByteUnit}, - http::{ContentType, Status}, + http::{uri::Query, ContentType, Status}, listener::Endpoint, mtls::{oid::Oid, Certificate}, response::status::Custom, @@ -81,6 +81,7 @@ pub struct PrpcHandler<'a, 'b, 'c, 'd, 'e, 'f, 'g, S> { pub quote_verifier: Option<&'c QuoteVerifier>, pub method: &'d str, pub data: Option>, + pub query: Option>, pub limits: &'f Limits, pub content_type: Option<&'g ContentType>, pub json: bool, @@ -131,6 +132,7 @@ pub async fn handle_prpc_impl>( quote_verifier, method, data, + query, limits, content_type, json, @@ -164,8 +166,13 @@ pub async fn handle_prpc_impl>( remote_endpoint: remote_addr.map(RemoteEndpoint::from), }; let call = Call::construct(context).context("failed to construct call")?; - let data = data.to_vec(); - let (status_code, output) = call.call(method.to_string(), data, json).await; + let data = match query { + Some(query) => query.as_bytes().to_vec(), + None => data.to_vec(), + }; + let (status_code, output) = call + .call(method.to_string(), data, json, query.is_some()) + .await; Ok(Custom(Status::new(status_code), output)) } diff --git a/tappd/src/guest_api_routes.rs b/tappd/src/guest_api_routes.rs index b65455afe..003ffaed3 100644 --- a/tappd/src/guest_api_routes.rs +++ b/tappd/src/guest_api_routes.rs @@ -4,7 +4,7 @@ use ra_rpc::rocket_helper::PrpcHandler; use rocket::{ data::{Data, Limits}, get, - http::ContentType, + http::{uri::Origin, ContentType}, mtls::Certificate, post, response::status::Custom, @@ -39,6 +39,7 @@ async fn prpc_post( async fn prpc_get( state: &State, method: &str, + origin: &Origin<'_>, limits: &Limits, content_type: Option<&ContentType>, ) -> Custom> { @@ -48,6 +49,7 @@ async fn prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await diff --git a/tappd/src/http_routes.rs b/tappd/src/http_routes.rs index bc00b68df..57c57c8c6 100644 --- a/tappd/src/http_routes.rs +++ b/tappd/src/http_routes.rs @@ -8,6 +8,7 @@ use ra_rpc::rocket_helper::PrpcHandler; use ra_rpc::{CallContext, RpcCall}; use rinja::Template; use rocket::futures::StreamExt; +use rocket::http::uri::Origin; use rocket::response::stream::TextStream; use rocket::{ data::{Data, Limits}, @@ -46,6 +47,7 @@ async fn prpc_get( method: &str, limits: &Limits, content_type: Option<&ContentType>, + origin: &Origin<'_>, ) -> Custom> { PrpcHandler::builder() .state(&**state) @@ -53,6 +55,7 @@ async fn prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await @@ -128,6 +131,7 @@ async fn external_prpc_get( method: &str, limits: &Limits, content_type: Option<&ContentType>, + origin: &Origin<'_>, ) -> Custom> { PrpcHandler::builder() .state(&**state) @@ -135,6 +139,7 @@ async fn external_prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await diff --git a/teepod/src/guest_api_routes.rs b/teepod/src/guest_api_routes.rs index f6496f46a..6a0cd4b26 100644 --- a/teepod/src/guest_api_routes.rs +++ b/teepod/src/guest_api_routes.rs @@ -4,7 +4,7 @@ use ra_rpc::rocket_helper::PrpcHandler; use rocket::{ data::{Data, Limits}, get, - http::ContentType, + http::{uri::Origin, ContentType}, mtls::Certificate, post, response::status::Custom, @@ -41,6 +41,7 @@ async fn prpc_get( method: &str, limits: &Limits, content_type: Option<&ContentType>, + origin: &Origin<'_>, ) -> Custom> { PrpcHandler::builder() .state(&**state) @@ -48,6 +49,7 @@ async fn prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await diff --git a/teepod/src/host_api_routes.rs b/teepod/src/host_api_routes.rs index 9e9109d00..7313f55ce 100644 --- a/teepod/src/host_api_routes.rs +++ b/teepod/src/host_api_routes.rs @@ -4,7 +4,7 @@ use ra_rpc::rocket_helper::PrpcHandler; use rocket::{ data::{Data, Limits}, get, - http::ContentType, + http::{uri::Origin, ContentType}, listener::Endpoint, mtls::Certificate, post, @@ -45,6 +45,7 @@ async fn prpc_get( method: &str, limits: &Limits, content_type: Option<&ContentType>, + origin: &Origin<'_>, ) -> Custom> { PrpcHandler::builder() .state(&**state) @@ -53,6 +54,7 @@ async fn prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await diff --git a/teepod/src/main_routes.rs b/teepod/src/main_routes.rs index be2910415..8cd7b86d5 100644 --- a/teepod/src/main_routes.rs +++ b/teepod/src/main_routes.rs @@ -6,7 +6,7 @@ use ra_rpc::rocket_helper::PrpcHandler; use rocket::{ data::{Data, Limits}, get, - http::ContentType, + http::{uri::Origin, ContentType}, mtls::Certificate, post, response::{status::Custom, stream::TextStream}, @@ -76,6 +76,7 @@ async fn prpc_get( method: &str, limits: &Limits, content_type: Option<&ContentType>, + origin: &Origin<'_>, ) -> Custom> { PrpcHandler::builder() .state(&**state) @@ -83,6 +84,7 @@ async fn prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await diff --git a/tproxy/src/web_routes.rs b/tproxy/src/web_routes.rs index 0297ec805..57816fbd4 100644 --- a/tproxy/src/web_routes.rs +++ b/tproxy/src/web_routes.rs @@ -4,7 +4,7 @@ use ra_rpc::rocket_helper::{PrpcHandler, QuoteVerifier}; use rocket::{ data::{Data, Limits}, get, - http::ContentType, + http::{uri::Origin, ContentType}, mtls::Certificate, post, response::{content::RawHtml, status::Custom}, @@ -52,6 +52,7 @@ async fn prpc_get( method: &str, limits: &Limits, content_type: Option<&ContentType>, + origin: &Origin<'_>, ) -> Custom> { PrpcHandler::builder() .state(&**state) @@ -61,6 +62,7 @@ async fn prpc_get( .limits(limits) .maybe_content_type(content_type) .json(true) + .maybe_query(origin.query()) .build() .handle::() .await From 4ac154539a8af1c61fd8f068be6ba72cb571916f Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Sat, 21 Dec 2024 09:10:22 +0000 Subject: [PATCH 2/3] Make the ra_rpc more easier to use --- Cargo.toml | 2 +- kms/src/web_routes.rs | 61 +----------- ra-rpc/src/rocket_helper.rs | 169 +++++++++++++++++++++++++-------- tappd/src/guest_api_routes.rs | 56 +---------- tappd/src/http_routes.rs | 116 ++++------------------ teepod/src/guest_api_routes.rs | 56 +---------- teepod/src/host_api_routes.rs | 60 +----------- teepod/src/main_routes.rs | 52 +--------- tproxy/src/web_routes.rs | 61 +----------- 9 files changed, 160 insertions(+), 473 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fba61351c..596342bf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,7 +132,7 @@ x509-parser = "0.16.0" # RPC/Protocol prpc = "0.4.0" -prpc-build = "0.4.0" +prpc-build = "0.4.1" # Development/Testing bindgen = "0.70.1" diff --git a/kms/src/web_routes.rs b/kms/src/web_routes.rs index 3613c62a3..f90d0cddb 100644 --- a/kms/src/web_routes.rs +++ b/kms/src/web_routes.rs @@ -1,69 +1,12 @@ use crate::{main_service::KmsState, main_service::RpcHandler}; -use ra_rpc::rocket_helper::{PrpcHandler, QuoteVerifier}; -use rocket::{ - data::{Data, Limits}, - get, - http::{uri::Origin, ContentType}, - mtls::Certificate, - post, - response::status::Custom, - routes, Route, State, -}; +use rocket::{get, routes, Route}; #[get("/")] async fn index() -> String { "KMS Server is running!\n".to_string() } -#[post("/prpc/?", data = "")] -#[allow(clippy::too_many_arguments)] -async fn prpc_post( - state: &State, - quote_verifier: Option<&State>, - cert: Option>, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .maybe_quote_verifier(quote_verifier.map(|v| &**v)) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/prpc/")] -async fn prpc_get( - state: &State, - quote_verifier: Option<&State>, - cert: Option>, - method: &str, - origin: &Origin<'_>, - limits: &Limits, - content_type: Option<&ContentType>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .maybe_quote_verifier(quote_verifier.map(|v| &**v)) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .maybe_query(origin.query()) - .json(true) - .build() - .handle::() - .await -} +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, KmsState, RpcHandler); pub fn routes() -> Vec { routes![index, prpc_post, prpc_get] diff --git a/ra-rpc/src/rocket_helper.rs b/ra-rpc/src/rocket_helper.rs index 34a5bb3b7..37d3bae3d 100644 --- a/ra-rpc/src/rocket_helper.rs +++ b/ra-rpc/src/rocket_helper.rs @@ -1,4 +1,7 @@ -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::{ + convert::Infallible, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; use anyhow::{Context, Result}; use ra_tls::{ @@ -6,12 +9,13 @@ use ra_tls::{ qvl::{self, verify::VerifiedReport}, }; use rocket::{ - data::{ByteUnit, Limits, ToByteUnit}, - http::{uri::Query, ContentType, Status}, + data::{ByteUnit, Data, Limits, ToByteUnit}, + http::{uri::Origin, ContentType, Status}, listener::Endpoint, mtls::{oid::Oid, Certificate}, + request::{FromRequest, Outcome}, response::status::Custom, - Data, + Request, }; use rocket_vsock_listener::VsockEndpoint; use tracing::{info, warn}; @@ -24,6 +28,78 @@ pub struct QuoteVerifier { timeout: Duration, } +pub mod deps { + pub use super::{PrpcHandler, RpcRequest}; + pub use rocket::response::status::Custom; + pub use rocket::{Data, State}; +} + +#[macro_export] +macro_rules! declare_prpc_routes { + ($post:ident, $get:ident, $state:ty, $handler:ty) => { + $crate::declare_prpc_routes!(path: "/prpc/?", "/prpc/", $post, $get, $state, $handler); + }; + (bare: $post:ident, $get:ident, $state:ty, $handler:ty) => { + $crate::declare_prpc_routes!(path: "/?", "/", $post, $get, $state, $handler); + }; + (path: $post_path: literal, $get_path: literal, $post:ident, $get:ident, $state:ty, $handler:ty) => { + #[rocket::post($post_path, data = "")] + async fn $post<'a: 'd, 'd>( + state: &'a $crate::rocket_helper::deps::State<$state>, + method: &'a str, + data: $crate::rocket_helper::deps::Data<'d>, + json: bool, + rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>, + ) -> $crate::rocket_helper::deps::Custom> { + $crate::rocket_helper::deps::PrpcHandler::builder() + .state(&**state) + .request(rpc_request) + .method(method) + .data(data) + .json(json) + .build() + .handle::<$handler>() + .await + } + + #[rocket::get("/")] + async fn $get<'a: 'd, 'd>( + state: &'a $crate::rocket_helper::deps::State<$state>, + rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>, + method: &'a str, + ) -> $crate::rocket_helper::deps::Custom> { + $crate::rocket_helper::deps::PrpcHandler::builder() + .state(&**state) + .request(rpc_request) + .method(method) + .json(true) + .build() + .handle::<$handler>() + .await + } + }; +} + +macro_rules! from_request { + ($request:expr) => { + match FromRequest::from_request($request).await { + Outcome::Success(v) => v, + Outcome::Error(e) => return Outcome::Error(e), + Outcome::Forward(f) => return Outcome::Forward(f), + } + }; +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for &'r QuoteVerifier { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let state: &rocket::State = from_request!(request); + Outcome::Success(state) + } +} + impl QuoteVerifier { pub fn new(pccs_url: String) -> Self { Self { @@ -74,20 +150,40 @@ fn limit_for_method(method: &str, limits: &Limits) -> ByteUnit { } #[derive(bon::Builder)] -pub struct PrpcHandler<'a, 'b, 'c, 'd, 'e, 'f, 'g, S> { - pub state: &'a S, - pub remote_addr: Option, - pub certificate: Option>, - pub quote_verifier: Option<&'c QuoteVerifier>, - pub method: &'d str, - pub data: Option>, - pub query: Option>, - pub limits: &'f Limits, - pub content_type: Option<&'g ContentType>, - pub json: bool, +pub struct PrpcHandler<'s, 'r, S> { + state: &'s S, + request: RpcRequest<'r>, + method: &'r str, + json: bool, + data: Option>, +} + +pub struct RpcRequest<'r> { + remote_addr: Option<&'r Endpoint>, + certificate: Option>, + quote_verifier: Option<&'r QuoteVerifier>, + orgin: &'r Origin<'r>, + limits: &'r Limits, + content_type: Option<&'r ContentType>, } -impl<'a, 'b, 'c, 'd, 'e, 'f, 'g, S> PrpcHandler<'a, 'b, 'c, 'd, 'e, 'f, 'g, S> { +#[rocket::async_trait] +impl<'r> FromRequest<'r> for RpcRequest<'r> { + type Error = Infallible; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + Outcome::Success(Self { + remote_addr: from_request!(request), + certificate: from_request!(request), + quote_verifier: from_request!(request), + orgin: from_request!(request), + limits: from_request!(request), + content_type: from_request!(request), + }) + } +} + +impl<'s, 'r, S> PrpcHandler<'s, 'r, S> { pub async fn handle>(self) -> Custom> { let json = self.json; let result = handle_prpc_impl::(self).await; @@ -124,23 +220,22 @@ impl From for RemoteEndpoint { } pub async fn handle_prpc_impl>( - args: PrpcHandler<'_, '_, '_, '_, '_, '_, '_, S>, + args: PrpcHandler<'_, '_, S>, ) -> Result>> { let PrpcHandler { state, - certificate, - quote_verifier, + request, method, - data, - query, - limits, - content_type, json, - remote_addr, + data, } = args; - let mut attestation = certificate.map(extract_attestation).transpose()?.flatten(); + let mut attestation = request + .certificate + .map(extract_attestation) + .transpose()? + .flatten(); let todo = "verified attestation needs to be a distinct type"; - if let (Some(quote_verifier), Some(attestation)) = (quote_verifier, &mut attestation) { + if let (Some(quote_verifier), Some(attestation)) = (request.quote_verifier, &mut attestation) { let verified_report = quote_verifier .verify_quote(attestation) .await @@ -149,30 +244,28 @@ pub async fn handle_prpc_impl>( } else if attestation.is_some() { info!("the ra quote is not verified"); } - let data = match data { + let is_get = data.is_none(); + let payload = match data { Some(data) => { - let limit = limit_for_method(method, limits); + let limit = limit_for_method(method, request.limits); let todo = "confirm this would not truncate the data"; read_data(data, limit) .await .context("failed to read data")? } - None => vec![], + None => request + .orgin + .query() + .map_or(vec![], |q| q.as_bytes().to_vec()), }; - let json = json || content_type.map(|t| t.is_json()).unwrap_or(false); + let json = json || request.content_type.map(|t| t.is_json()).unwrap_or(false); let context = CallContext { state, attestation, - remote_endpoint: remote_addr.map(RemoteEndpoint::from), + remote_endpoint: request.remote_addr.cloned().map(RemoteEndpoint::from), }; let call = Call::construct(context).context("failed to construct call")?; - let data = match query { - Some(query) => query.as_bytes().to_vec(), - None => data.to_vec(), - }; - let (status_code, output) = call - .call(method.to_string(), data, json, query.is_some()) - .await; + let (status_code, output) = call.call(method.to_string(), payload, json, is_get).await; Ok(Custom(Status::new(status_code), output)) } diff --git a/tappd/src/guest_api_routes.rs b/tappd/src/guest_api_routes.rs index 003ffaed3..9c4e2f84f 100644 --- a/tappd/src/guest_api_routes.rs +++ b/tappd/src/guest_api_routes.rs @@ -1,59 +1,7 @@ use crate::{guest_api_service::GuestApiHandler, AppState}; -use ra_rpc::rocket_helper::PrpcHandler; +use rocket::{routes, Route}; -use rocket::{ - data::{Data, Limits}, - get, - http::{uri::Origin, ContentType}, - mtls::Certificate, - post, - response::status::Custom, - routes, Route, State, -}; - -#[post("/?", data = "")] -#[allow(clippy::too_many_arguments)] -async fn prpc_post( - state: &State, - cert: Option>, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/")] -async fn prpc_get( - state: &State, - method: &str, - origin: &Origin<'_>, - limits: &Limits, - content_type: Option<&ContentType>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, AppState, GuestApiHandler); pub fn routes() -> Vec { routes![prpc_post, prpc_get] diff --git a/tappd/src/http_routes.rs b/tappd/src/http_routes.rs index 57c57c8c6..b94334248 100644 --- a/tappd/src/http_routes.rs +++ b/tappd/src/http_routes.rs @@ -4,67 +4,32 @@ use crate::rpc_service::{AppState, ExternalRpcHandler, InternalRpcHandler}; use anyhow::Result; use docker_logs::parse_duration; use guest_api::guest_api_server::GuestApiRpc; -use ra_rpc::rocket_helper::PrpcHandler; use ra_rpc::{CallContext, RpcCall}; use rinja::Template; use rocket::futures::StreamExt; -use rocket::http::uri::Origin; use rocket::response::stream::TextStream; -use rocket::{ - data::{Data, Limits}, - get, - http::ContentType, - post, - response::{content::RawHtml, status::Custom}, - routes, Route, State, -}; +use rocket::{get, response::content::RawHtml, routes, Route, State}; use tappd_rpc::{worker_server::WorkerRpc, WorkerInfo}; -#[post("/prpc/?", data = "")] -async fn prpc_post( - state: &State, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/prpc/")] -async fn prpc_get( - state: &State, - method: &str, - limits: &Limits, - content_type: Option<&ContentType>, - origin: &Origin<'_>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} - +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, AppState, InternalRpcHandler); pub fn internal_routes() -> Vec { routes![prpc_post, prpc_get] } +ra_rpc::declare_prpc_routes!( + external_prpc_post, + external_prpc_get, + AppState, + ExternalRpcHandler +); +pub fn external_routes(config: &Config) -> Vec { + let mut routes = routes![index, external_prpc_post, external_prpc_get]; + if config.public_logs { + routes.extend(routes![get_logs]); + } + routes +} + #[get("/")] async fn index(state: &State) -> Result, String> { let context = CallContext::builder().state(&**state).build(); @@ -104,47 +69,6 @@ async fn index(state: &State) -> Result, String> { } } -#[post("/prpc/?", data = "")] -async fn external_prpc_post( - state: &State, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/prpc/")] -async fn external_prpc_get( - state: &State, - method: &str, - limits: &Limits, - content_type: Option<&ContentType>, - origin: &Origin<'_>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} - #[get("/logs/?&&&&&&")] #[allow(clippy::too_many_arguments)] fn get_logs( @@ -195,14 +119,6 @@ fn get_logs( } } -pub fn external_routes(config: &Config) -> Vec { - let mut routes = routes![index, external_prpc_post, external_prpc_get]; - if config.public_logs { - routes.extend(routes![get_logs]); - } - routes -} - mod docker_logs { use std::time::{SystemTime, UNIX_EPOCH}; diff --git a/teepod/src/guest_api_routes.rs b/teepod/src/guest_api_routes.rs index 6a0cd4b26..213b68cba 100644 --- a/teepod/src/guest_api_routes.rs +++ b/teepod/src/guest_api_routes.rs @@ -1,59 +1,7 @@ use crate::{guest_api_service::GuestApiHandler, App}; -use ra_rpc::rocket_helper::PrpcHandler; +use rocket::{routes, Route}; -use rocket::{ - data::{Data, Limits}, - get, - http::{uri::Origin, ContentType}, - mtls::Certificate, - post, - response::status::Custom, - routes, Route, State, -}; - -#[post("/?", data = "")] -#[allow(clippy::too_many_arguments)] -async fn prpc_post( - state: &State, - cert: Option>, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/")] -async fn prpc_get( - state: &State, - method: &str, - limits: &Limits, - content_type: Option<&ContentType>, - origin: &Origin<'_>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, App, GuestApiHandler); pub fn routes() -> Vec { routes![prpc_post, prpc_get] diff --git a/teepod/src/host_api_routes.rs b/teepod/src/host_api_routes.rs index 7313f55ce..3fd3cb300 100644 --- a/teepod/src/host_api_routes.rs +++ b/teepod/src/host_api_routes.rs @@ -1,64 +1,8 @@ use crate::app::App; use crate::host_api_service::HostApiHandler; -use ra_rpc::rocket_helper::PrpcHandler; -use rocket::{ - data::{Data, Limits}, - get, - http::{uri::Origin, ContentType}, - listener::Endpoint, - mtls::Certificate, - post, - response::status::Custom, - routes, Route, State, -}; +use rocket::{routes, Route}; -#[post("/?", data = "")] -#[allow(clippy::too_many_arguments)] -async fn prpc_post( - endpoint: &Endpoint, - state: &State, - cert: Option>, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .remote_addr(endpoint.clone()) - .maybe_certificate(cert) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/")] -async fn prpc_get( - endpoint: &Endpoint, - state: &State, - method: &str, - limits: &Limits, - content_type: Option<&ContentType>, - origin: &Origin<'_>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .remote_addr(endpoint.clone()) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, App, HostApiHandler); pub fn routes() -> Vec { routes![prpc_post, prpc_get] diff --git a/teepod/src/main_routes.rs b/teepod/src/main_routes.rs index 8cd7b86d5..711b02a8f 100644 --- a/teepod/src/main_routes.rs +++ b/teepod/src/main_routes.rs @@ -2,13 +2,9 @@ use crate::app::App; use crate::main_service::RpcHandler; use anyhow::Result; use fs_err as fs; -use ra_rpc::rocket_helper::PrpcHandler; use rocket::{ - data::{Data, Limits}, get, - http::{uri::Origin, ContentType}, - mtls::Certificate, - post, + http::ContentType, response::{status::Custom, stream::TextStream}, routes, Route, State, }; @@ -44,51 +40,7 @@ async fn res(path: &str) -> Result<(ContentType, String), Custom> { } } -#[post("/prpc/?", data = "")] -#[allow(clippy::too_many_arguments)] -async fn prpc_post( - _auth: Authorized, - state: &State, - cert: Option>, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/prpc/")] -async fn prpc_get( - _auth: Authorized, - state: &State, - method: &str, - limits: &Limits, - content_type: Option<&ContentType>, - origin: &Origin<'_>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, App, RpcHandler); static STREAM_CREATED_COUNTER: AtomicUsize = AtomicUsize::new(0); static STREAM_DROPPED_COUNTER: AtomicUsize = AtomicUsize::new(0); diff --git a/tproxy/src/web_routes.rs b/tproxy/src/web_routes.rs index 57816fbd4..e16bc71a1 100644 --- a/tproxy/src/web_routes.rs +++ b/tproxy/src/web_routes.rs @@ -1,15 +1,6 @@ use crate::main_service::{Proxy, RpcHandler}; use anyhow::Result; -use ra_rpc::rocket_helper::{PrpcHandler, QuoteVerifier}; -use rocket::{ - data::{Data, Limits}, - get, - http::{uri::Origin, ContentType}, - mtls::Certificate, - post, - response::{content::RawHtml, status::Custom}, - routes, Route, State, -}; +use rocket::{get, response::content::RawHtml, routes, Route, State}; mod route_index; @@ -18,55 +9,7 @@ async fn index(state: &State) -> Result, String> { route_index::index(state).await.map_err(|e| format!("{e}")) } -#[post("/prpc/?", data = "")] -#[allow(clippy::too_many_arguments)] -async fn prpc_post( - state: &State, - cert: Option>, - quote_verifier: Option<&State>, - method: &str, - data: Data<'_>, - limits: &Limits, - content_type: Option<&ContentType>, - json: bool, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .maybe_quote_verifier(quote_verifier.map(|v| &**v)) - .method(method) - .data(data) - .limits(limits) - .maybe_content_type(content_type) - .json(json) - .build() - .handle::() - .await -} - -#[get("/prpc/")] -async fn prpc_get( - state: &State, - cert: Option>, - quote_verifier: Option<&State>, - method: &str, - limits: &Limits, - content_type: Option<&ContentType>, - origin: &Origin<'_>, -) -> Custom> { - PrpcHandler::builder() - .state(&**state) - .maybe_certificate(cert) - .maybe_quote_verifier(quote_verifier.map(|v| &**v)) - .method(method) - .limits(limits) - .maybe_content_type(content_type) - .json(true) - .maybe_query(origin.query()) - .build() - .handle::() - .await -} +ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, Proxy, RpcHandler); pub fn routes() -> Vec { routes![index, prpc_post, prpc_get] From 43de41b16c83710d77361b3626bc0cff75e69be1 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 23 Dec 2024 07:22:55 +0000 Subject: [PATCH 3/3] Eliminates prpc routes file --- Cargo.lock | 4 +- kms/src/main.rs | 4 +- kms/src/main_service.rs | 9 +--- kms/src/web_routes.rs | 13 ----- ra-rpc/src/lib.rs | 24 +++++---- ra-rpc/src/rocket_helper.rs | 89 ++++++++++++++++++++++++--------- tappd/src/guest_api_routes.rs | 8 --- tappd/src/guest_api_service.rs | 9 +--- tappd/src/http_routes.rs | 21 ++++---- tappd/src/main.rs | 7 +-- tappd/src/rpc_service.rs | 18 +------ teepod/src/guest_api_routes.rs | 8 --- teepod/src/guest_api_service.rs | 9 +--- teepod/src/host_api_routes.rs | 9 ---- teepod/src/host_api_service.rs | 9 +--- teepod/src/main.rs | 11 ++-- teepod/src/main_routes.rs | 5 +- teepod/src/main_service.rs | 9 +--- tproxy/src/main.rs | 2 + tproxy/src/main_service.rs | 9 +--- tproxy/src/web_routes.rs | 6 +-- 21 files changed, 119 insertions(+), 164 deletions(-) delete mode 100644 kms/src/web_routes.rs delete mode 100644 tappd/src/guest_api_routes.rs delete mode 100644 teepod/src/guest_api_routes.rs delete mode 100644 teepod/src/host_api_routes.rs diff --git a/Cargo.lock b/Cargo.lock index c000ab68a..3be1e594e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3439,9 +3439,9 @@ dependencies = [ [[package]] name = "prpc-build" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afd0da11e53c652b1cd4828f7ca4f6bfd607bd666a48203ebad6b30bd0b2c9ec" +checksum = "caf018aca01f6c5e7b0c312e484b29ebc63cf42404e4c35d105b1703e6350bfb" dependencies = [ "either", "fs-err", diff --git a/kms/src/main.rs b/kms/src/main.rs index 7e49496e5..641872ca0 100644 --- a/kms/src/main.rs +++ b/kms/src/main.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Context, Result}; use clap::Parser; use config::KmsConfig; +use main_service::{KmsState, RpcHandler}; use ra_rpc::rocket_helper::QuoteVerifier; use rocket::fairing::AdHoc; use tracing::info; @@ -8,7 +9,6 @@ use tracing::info; mod config; mod ct_log; mod main_service; -mod web_routes; fn app_version() -> String { const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -53,7 +53,7 @@ async fn main() -> Result<()> { res.set_raw_header("X-App-Version", app_version()); }) })) - .mount("/", web_routes::routes()) + .mount("/prpc", ra_rpc::prpc_routes!(KmsState, RpcHandler)) .manage(state); if !pccs_url.is_empty() { diff --git a/kms/src/main_service.rs b/kms/src/main_service.rs index 415304cf7..2262c082d 100644 --- a/kms/src/main_service.rs +++ b/kms/src/main_service.rs @@ -198,14 +198,7 @@ impl KmsRpc for RpcHandler { impl RpcCall for RpcHandler { type PrpcService = KmsServer; - fn into_prpc_service(self) -> Self::PrpcService { - KmsServer::new(self) - } - - fn construct(context: CallContext<'_, KmsState>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, KmsState>) -> Result { Ok(RpcHandler { state: context.state.clone(), attestation: context.attestation, diff --git a/kms/src/web_routes.rs b/kms/src/web_routes.rs deleted file mode 100644 index f90d0cddb..000000000 --- a/kms/src/web_routes.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::{main_service::KmsState, main_service::RpcHandler}; -use rocket::{get, routes, Route}; - -#[get("/")] -async fn index() -> String { - "KMS Server is running!\n".to_string() -} - -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, KmsState, RpcHandler); - -pub fn routes() -> Vec { - routes![index, prpc_post, prpc_get] -} diff --git a/ra-rpc/src/lib.rs b/ra-rpc/src/lib.rs index 98e21e9c7..b602dc40d 100644 --- a/ra-rpc/src/lib.rs +++ b/ra-rpc/src/lib.rs @@ -33,24 +33,26 @@ pub struct CallContext<'a, State> { pub remote_endpoint: Option, } -pub trait RpcCall { - type PrpcService: PrpcService + Send + 'static; +pub trait RpcCall: Sized { + type PrpcService: PrpcService + From + Send + 'static; + + fn construct(context: CallContext<'_, State>) -> Result; - fn construct(context: CallContext<'_, State>) -> Result - where - Self: Sized; - fn into_prpc_service(self) -> Self::PrpcService; async fn call( self, method: String, payload: Vec, is_json: bool, is_query: bool, - ) -> (u16, Vec) - where - Self: Sized, - { - dispatch_prpc(method, payload, is_json, is_query, self.into_prpc_service()).await + ) -> (u16, Vec) { + dispatch_prpc( + method, + payload, + is_json, + is_query, + >::from(self), + ) + .await } } diff --git a/ra-rpc/src/rocket_helper.rs b/ra-rpc/src/rocket_helper.rs index 37d3bae3d..223d6045c 100644 --- a/ra-rpc/src/rocket_helper.rs +++ b/ra-rpc/src/rocket_helper.rs @@ -10,7 +10,7 @@ use ra_tls::{ }; use rocket::{ data::{ByteUnit, Data, Limits, ToByteUnit}, - http::{uri::Origin, ContentType, Status}, + http::{uri::Origin, ContentType, Method, Status}, listener::Endpoint, mtls::{oid::Oid, Certificate}, request::{FromRequest, Outcome}, @@ -34,45 +34,65 @@ pub mod deps { pub use rocket::{Data, State}; } +fn query_field_get_raw<'r>(req: &'r Request<'_>, field_name: &str) -> Option<&'r str> { + for field in req.query_fields() { + let raw = (field.name.source().as_str(), field.value); + let key = field.name.key_lossy().as_str(); + if key == field_name { + return Some(field.value); + } + } + None +} + +fn query_field_get_bool(req: &Request<'_>, field_name: &str) -> bool { + matches!( + query_field_get_raw(req, field_name), + Some("true" | "1" | "") + ) +} + +#[macro_export] +macro_rules! prpc_routes { + ($state:ty, $handler:ty) => {{ + ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, $state, $handler); + rocket::routes![prpc_post, prpc_get] + }}; +} + #[macro_export] macro_rules! declare_prpc_routes { ($post:ident, $get:ident, $state:ty, $handler:ty) => { - $crate::declare_prpc_routes!(path: "/prpc/?", "/prpc/", $post, $get, $state, $handler); - }; - (bare: $post:ident, $get:ident, $state:ty, $handler:ty) => { - $crate::declare_prpc_routes!(path: "/?", "/", $post, $get, $state, $handler); + $crate::declare_prpc_routes!(path: "/", $post, $get, $state, $handler); }; - (path: $post_path: literal, $get_path: literal, $post:ident, $get:ident, $state:ty, $handler:ty) => { - #[rocket::post($post_path, data = "")] + (path: $path: literal, $post:ident, $get:ident, $state:ty, $handler:ty) => { + #[rocket::post($path, data = "")] async fn $post<'a: 'd, 'd>( state: &'a $crate::rocket_helper::deps::State<$state>, method: &'a str, - data: $crate::rocket_helper::deps::Data<'d>, - json: bool, rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>, + data: $crate::rocket_helper::deps::Data<'d>, ) -> $crate::rocket_helper::deps::Custom> { $crate::rocket_helper::deps::PrpcHandler::builder() .state(&**state) .request(rpc_request) .method(method) .data(data) - .json(json) .build() .handle::<$handler>() .await } - #[rocket::get("/")] - async fn $get<'a: 'd, 'd>( - state: &'a $crate::rocket_helper::deps::State<$state>, - rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>, - method: &'a str, + #[rocket::get($path)] + async fn $get( + state: &$crate::rocket_helper::deps::State<$state>, + method: &str, + rpc_request: $crate::rocket_helper::deps::RpcRequest<'_>, ) -> $crate::rocket_helper::deps::Custom> { $crate::rocket_helper::deps::PrpcHandler::builder() .state(&**state) .request(rpc_request) .method(method) - .json(true) .build() .handle::<$handler>() .await @@ -80,6 +100,29 @@ macro_rules! declare_prpc_routes { }; } +#[macro_export] +macro_rules! prpc_alias { + (get: $name:ident, $alias:literal -> $prpc:ident($method:literal, $state:ty)) => { + #[rocket::get($alias)] + async fn $name( + state: &$crate::rocket_helper::deps::State<$state>, + rpc_request: $crate::rocket_helper::deps::RpcRequest<'_>, + ) -> $crate::rocket_helper::deps::Custom> { + $prpc(state, $method, rpc_request).await + } + }; + (post: $name:ident, $alias:literal -> $prpc:ident($method:literal, $state:ty)) => { + #[rocket::post($alias, data = "")] + async fn $name<'a: 'd, 'd>( + state: &'a $crate::rocket_helper::deps::State<$state>, + rpc_request: $crate::rocket_helper::deps::RpcRequest<'a>, + data: $crate::rocket_helper::deps::Data<'d>, + ) -> $crate::rocket_helper::deps::Custom> { + $prpc(state, $method, rpc_request, data).await + } + }; +} + macro_rules! from_request { ($request:expr) => { match FromRequest::from_request($request).await { @@ -154,7 +197,6 @@ pub struct PrpcHandler<'s, 'r, S> { state: &'s S, request: RpcRequest<'r>, method: &'r str, - json: bool, data: Option>, } @@ -162,9 +204,10 @@ pub struct RpcRequest<'r> { remote_addr: Option<&'r Endpoint>, certificate: Option>, quote_verifier: Option<&'r QuoteVerifier>, - orgin: &'r Origin<'r>, + origin: &'r Origin<'r>, limits: &'r Limits, content_type: Option<&'r ContentType>, + json: bool, } #[rocket::async_trait] @@ -176,16 +219,17 @@ impl<'r> FromRequest<'r> for RpcRequest<'r> { remote_addr: from_request!(request), certificate: from_request!(request), quote_verifier: from_request!(request), - orgin: from_request!(request), + origin: from_request!(request), limits: from_request!(request), content_type: from_request!(request), + json: request.method() == Method::Get || query_field_get_bool(request, "json"), }) } } impl<'s, 'r, S> PrpcHandler<'s, 'r, S> { pub async fn handle>(self) -> Custom> { - let json = self.json; + let json = self.request.json; let result = handle_prpc_impl::(self).await; match result { Ok(output) => output, @@ -226,7 +270,6 @@ pub async fn handle_prpc_impl>( state, request, method, - json, data, } = args; let mut attestation = request @@ -254,11 +297,11 @@ pub async fn handle_prpc_impl>( .context("failed to read data")? } None => request - .orgin + .origin .query() .map_or(vec![], |q| q.as_bytes().to_vec()), }; - let json = json || request.content_type.map(|t| t.is_json()).unwrap_or(false); + let json = request.json || request.content_type.map(|t| t.is_json()).unwrap_or(false); let context = CallContext { state, attestation, diff --git a/tappd/src/guest_api_routes.rs b/tappd/src/guest_api_routes.rs deleted file mode 100644 index 9c4e2f84f..000000000 --- a/tappd/src/guest_api_routes.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::{guest_api_service::GuestApiHandler, AppState}; -use rocket::{routes, Route}; - -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, AppState, GuestApiHandler); - -pub fn routes() -> Vec { - routes![prpc_post, prpc_get] -} diff --git a/tappd/src/guest_api_service.rs b/tappd/src/guest_api_service.rs index 40ac16741..4f790e7f1 100644 --- a/tappd/src/guest_api_service.rs +++ b/tappd/src/guest_api_service.rs @@ -29,14 +29,7 @@ pub struct GuestApiHandler { impl RpcCall for GuestApiHandler { type PrpcService = GuestApiServer; - fn into_prpc_service(self) -> Self::PrpcService { - GuestApiServer::new(self) - } - - fn construct(context: CallContext<'_, AppState>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, AppState>) -> Result { Ok(Self { state: context.state.clone(), }) diff --git a/tappd/src/http_routes.rs b/tappd/src/http_routes.rs index b94334248..15ccd891e 100644 --- a/tappd/src/http_routes.rs +++ b/tappd/src/http_routes.rs @@ -11,19 +11,22 @@ use rocket::response::stream::TextStream; use rocket::{get, response::content::RawHtml, routes, Route, State}; use tappd_rpc::{worker_server::WorkerRpc, WorkerInfo}; -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, AppState, InternalRpcHandler); -pub fn internal_routes() -> Vec { - routes![prpc_post, prpc_get] -} - ra_rpc::declare_prpc_routes!( - external_prpc_post, - external_prpc_get, + path: "/prpc/", + prpc_post, + prpc_get, AppState, - ExternalRpcHandler + InternalRpcHandler ); + +ra_rpc::prpc_alias!(get: quote_get, "/quote" -> prpc_get("Tappd.TdxQuote", AppState)); + +pub fn internal_routes() -> Vec { + routes![prpc_post, prpc_get, quote_get] +} + pub fn external_routes(config: &Config) -> Vec { - let mut routes = routes![index, external_prpc_post, external_prpc_get]; + let mut routes = routes![index]; if config.public_logs { routes.extend(routes![get_logs]); } diff --git a/tappd/src/main.rs b/tappd/src/main.rs index 673172e73..adda96870 100644 --- a/tappd/src/main.rs +++ b/tappd/src/main.rs @@ -2,19 +2,19 @@ use std::{fs::Permissions, future::pending, os::unix::fs::PermissionsExt}; use anyhow::{anyhow, Context, Result}; use clap::Parser; +use guest_api_service::GuestApiHandler; use rocket::{ fairing::AdHoc, figment::Figment, listener::{Bind, DefaultListener}, }; use rocket_vsock_listener::VsockListener; -use rpc_service::AppState; +use rpc_service::{AppState, ExternalRpcHandler}; use sd_notify::{notify as sd_notify, NotifyState}; use std::time::Duration; use tracing::{error, info}; mod config; -mod guest_api_routes; mod guest_api_service; mod http_routes; mod models; @@ -69,6 +69,7 @@ async fn run_internal(state: AppState, figment: Figment) -> Result<()> { async fn run_external(state: AppState, figment: Figment) -> Result<()> { let rocket = rocket::custom(figment) .mount("/", http_routes::external_routes(state.config())) + .mount("/prpc", ra_rpc::prpc_routes!(AppState, ExternalRpcHandler)) .attach(AdHoc::on_response("Add app version header", |_req, res| { Box::pin(async move { res.set_raw_header("X-App-Version", app_version()); @@ -84,7 +85,7 @@ async fn run_external(state: AppState, figment: Figment) -> Result<()> { async fn run_guest_api(state: AppState, figment: Figment) -> Result<()> { let rocket = rocket::custom(figment) - .mount("/api", guest_api_routes::routes()) + .mount("/api", ra_rpc::prpc_routes!(AppState, GuestApiHandler)) .manage(state); let ignite = rocket diff --git a/tappd/src/rpc_service.rs b/tappd/src/rpc_service.rs index 1654ac4ad..3b2e3fbbe 100644 --- a/tappd/src/rpc_service.rs +++ b/tappd/src/rpc_service.rs @@ -88,14 +88,7 @@ impl TappdRpc for InternalRpcHandler { impl RpcCall for InternalRpcHandler { type PrpcService = TappdServer; - fn into_prpc_service(self) -> Self::PrpcService { - TappdServer::new(self) - } - - fn construct(context: CallContext<'_, AppState>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, AppState>) -> Result { Ok(InternalRpcHandler { state: context.state.clone(), }) @@ -172,14 +165,7 @@ impl WorkerRpc for ExternalRpcHandler { impl RpcCall for ExternalRpcHandler { type PrpcService = WorkerServer; - fn into_prpc_service(self) -> Self::PrpcService { - WorkerServer::new(self) - } - - fn construct(context: CallContext<'_, AppState>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, AppState>) -> Result { Ok(ExternalRpcHandler { state: context.state.clone(), }) diff --git a/teepod/src/guest_api_routes.rs b/teepod/src/guest_api_routes.rs deleted file mode 100644 index 213b68cba..000000000 --- a/teepod/src/guest_api_routes.rs +++ /dev/null @@ -1,8 +0,0 @@ -use crate::{guest_api_service::GuestApiHandler, App}; -use rocket::{routes, Route}; - -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, App, GuestApiHandler); - -pub fn routes() -> Vec { - routes![prpc_post, prpc_get] -} diff --git a/teepod/src/guest_api_service.rs b/teepod/src/guest_api_service.rs index 74836039a..baa3cc9dc 100644 --- a/teepod/src/guest_api_service.rs +++ b/teepod/src/guest_api_service.rs @@ -22,14 +22,7 @@ impl Deref for GuestApiHandler { impl RpcCall for GuestApiHandler { type PrpcService = ProxiedGuestApiServer; - fn into_prpc_service(self) -> Self::PrpcService { - ProxiedGuestApiServer::new(self) - } - - fn construct(context: CallContext<'_, AppState>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, AppState>) -> Result { Ok(Self { state: context.state.clone(), }) diff --git a/teepod/src/host_api_routes.rs b/teepod/src/host_api_routes.rs deleted file mode 100644 index 3fd3cb300..000000000 --- a/teepod/src/host_api_routes.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::app::App; -use crate::host_api_service::HostApiHandler; -use rocket::{routes, Route}; - -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, App, HostApiHandler); - -pub fn routes() -> Vec { - routes![prpc_post, prpc_get] -} diff --git a/teepod/src/host_api_service.rs b/teepod/src/host_api_service.rs index dcfff1498..02359c694 100644 --- a/teepod/src/host_api_service.rs +++ b/teepod/src/host_api_service.rs @@ -16,14 +16,7 @@ pub struct HostApiHandler { impl RpcCall for HostApiHandler { type PrpcService = HostApiServer; - fn into_prpc_service(self) -> Self::PrpcService { - HostApiServer::new(self) - } - - fn construct(context: CallContext<'_, App>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, App>) -> Result { let Some(RemoteEndpoint::Vsock { cid, port }) = context.remote_endpoint else { bail!("invalid remote endpoint: {:?}", context.remote_endpoint); }; diff --git a/teepod/src/main.rs b/teepod/src/main.rs index 953ed6e5b..7997fcd80 100644 --- a/teepod/src/main.rs +++ b/teepod/src/main.rs @@ -4,6 +4,9 @@ use anyhow::{anyhow, Context, Result}; use app::App; use clap::Parser; use config::Config; +use guest_api_service::GuestApiHandler; +use host_api_service::HostApiHandler; +use main_service::RpcHandler; use path_absolutize::Absolutize; use rocket::{ fairing::AdHoc, @@ -15,9 +18,7 @@ use supervisor_client::SupervisorClient; mod app; mod config; -mod guest_api_routes; mod guest_api_service; -mod host_api_routes; mod host_api_service; mod main_routes; mod main_service; @@ -44,7 +45,9 @@ struct Args { async fn run_external_api(app: App, figment: Figment, api_auth: ApiToken) -> Result<()> { let external_api = rocket::custom(figment) .mount("/", main_routes::routes()) - .mount("/guest", guest_api_routes::routes()) + .mount("/guest", ra_rpc::prpc_routes!(App, GuestApiHandler)) + .mount("/api", ra_rpc::prpc_routes!(App, HostApiHandler)) + .mount("/prpc", ra_rpc::prpc_routes!(App, RpcHandler)) .manage(app) .manage(api_auth) .attach(AdHoc::on_response("Add app rev header", |_req, res| { @@ -70,7 +73,7 @@ async fn run_host_api(app: App, figment: Figment) -> Result<()> { .clone() .merge(Serialized::defaults(figment.find_value("host_api")?)); let rocket = rocket::custom(figment) - .mount("/api", host_api_routes::routes()) + .mount("/api", ra_rpc::prpc_routes!(App, HostApiHandler)) .manage(app); let ignite = rocket .ignite() diff --git a/teepod/src/main_routes.rs b/teepod/src/main_routes.rs index 711b02a8f..dbfd4159a 100644 --- a/teepod/src/main_routes.rs +++ b/teepod/src/main_routes.rs @@ -1,5 +1,4 @@ use crate::app::App; -use crate::main_service::RpcHandler; use anyhow::Result; use fs_err as fs; use rocket::{ @@ -40,8 +39,6 @@ async fn res(path: &str) -> Result<(ContentType, String), Custom> { } } -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, App, RpcHandler); - static STREAM_CREATED_COUNTER: AtomicUsize = AtomicUsize::new(0); static STREAM_DROPPED_COUNTER: AtomicUsize = AtomicUsize::new(0); @@ -152,5 +149,5 @@ fn vm_logs( } pub fn routes() -> Vec { - routes![index, res, prpc_post, prpc_get, vm_logs] + routes![index, res, vm_logs] } diff --git a/teepod/src/main_service.rs b/teepod/src/main_service.rs index cccc9644a..e75e20d56 100644 --- a/teepod/src/main_service.rs +++ b/teepod/src/main_service.rs @@ -313,14 +313,7 @@ impl TeepodRpc for RpcHandler { impl RpcCall for RpcHandler { type PrpcService = TeepodServer; - fn into_prpc_service(self) -> Self::PrpcService { - TeepodServer::new(self) - } - - fn construct(context: CallContext<'_, App>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, App>) -> Result { Ok(RpcHandler { app: context.state.clone(), }) diff --git a/tproxy/src/main.rs b/tproxy/src/main.rs index 4974b0f69..36f24d35d 100644 --- a/tproxy/src/main.rs +++ b/tproxy/src/main.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Result}; use clap::Parser; use config::Config; +use main_service::{Proxy, RpcHandler}; use ra_rpc::rocket_helper::QuoteVerifier; use rocket::fairing::AdHoc; @@ -67,6 +68,7 @@ async fn main() -> Result<()> { let mut rocket = rocket::custom(figment) .mount("/", web_routes::routes()) + .mount("/prpc", ra_rpc::prpc_routes!(Proxy, RpcHandler)) .attach(AdHoc::on_response("Add app version header", |_req, res| { Box::pin(async move { res.set_raw_header("X-App-Version", app_version()); diff --git a/tproxy/src/main_service.rs b/tproxy/src/main_service.rs index 888741dab..9821748aa 100644 --- a/tproxy/src/main_service.rs +++ b/tproxy/src/main_service.rs @@ -451,14 +451,7 @@ impl TproxyRpc for RpcHandler { impl RpcCall for RpcHandler { type PrpcService = TproxyServer; - fn into_prpc_service(self) -> Self::PrpcService { - TproxyServer::new(self) - } - - fn construct(context: CallContext<'_, Proxy>) -> Result - where - Self: Sized, - { + fn construct(context: CallContext<'_, Proxy>) -> Result { Ok(RpcHandler { attestation: context.attestation, state: context.state.clone(), diff --git a/tproxy/src/web_routes.rs b/tproxy/src/web_routes.rs index e16bc71a1..026f3b64a 100644 --- a/tproxy/src/web_routes.rs +++ b/tproxy/src/web_routes.rs @@ -1,4 +1,4 @@ -use crate::main_service::{Proxy, RpcHandler}; +use crate::main_service::Proxy; use anyhow::Result; use rocket::{get, response::content::RawHtml, routes, Route, State}; @@ -9,8 +9,6 @@ async fn index(state: &State) -> Result, String> { route_index::index(state).await.map_err(|e| format!("{e}")) } -ra_rpc::declare_prpc_routes!(prpc_post, prpc_get, Proxy, RpcHandler); - pub fn routes() -> Vec { - routes![index, prpc_post, prpc_get] + routes![index] }