Skip to content

Commit

Permalink
impl parse_header_payload func
Browse files Browse the repository at this point in the history
  • Loading branch information
laruh committed May 16, 2024
1 parent 7e87c8f commit e8f296d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 27 deletions.
5 changes: 3 additions & 2 deletions src/net/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,15 @@ pub(crate) async fn http_handler(

let (req, payload) = match generate_payload_from_req(req, &proxy_route.proxy_type).await {
Ok(t) => t,
Err(_) => {
Err(e) => {
log::warn!(
"{}",
log_format!(
remote_addr.ip(),
String::from("-"),
req_uri,
"Received invalid http payload, returning 401."
"Received invalid http payload: {}, returning 401.",
e
)
);
return response_by_status(StatusCode::UNAUTHORIZED);
Expand Down
39 changes: 24 additions & 15 deletions src/proxy/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use crate::ctx::{AppConfig, GenericResult, ProxyRoute};
use crate::sign::SignedMessage;
use hyper::header::HeaderValue;
use hyper::{Body, Method, Request, Response, StatusCode, Uri};
use hyper::{Body, Request, Response, StatusCode, Uri};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
mod moralis;
use moralis::{proxy_moralis, validation_middleware_moralis, MoralisPayload};
mod quicknode;
pub(crate) use quicknode::{proxy_quicknode, validation_middleware_quicknode, QuicknodePayload};

const X_AUTH_PAYLOAD: &str = "X-Auth-Payload";

/// Enumerates different proxy types supported by the application, focusing on separating feature logic.
/// This allows for differentiated handling based on what the proxy should do with the request,
/// directing each to the appropriate service or API based on its designated proxy type.
Expand Down Expand Up @@ -46,11 +49,11 @@ pub(crate) async fn generate_payload_from_req(
) -> GenericResult<(Request<Body>, PayloadData)> {
match proxy_type {
ProxyType::Quicknode => {
let (req, payload) = parse_payload::<QuicknodePayload>(req, false).await?;
let (req, payload) = parse_body_payload::<QuicknodePayload>(req).await?;
Ok((req, PayloadData::Quicknode(payload)))
}
ProxyType::Moralis => {
let (req, payload) = parse_payload::<MoralisPayload>(req, true).await?;
let (req, payload) = parse_header_payload::<MoralisPayload>(req).await?;
Ok((req, PayloadData::Moralis(payload)))
}
}
Expand Down Expand Up @@ -94,25 +97,31 @@ pub(crate) async fn validation_middleware(
/// Asynchronously parses an HTTP request's body into a specified type `T`. If the request method is `GET`,
/// the function modifies the request to have an empty body. For other methods, it retains the original body.
/// The function ensures that the body is not empty before attempting deserialization into the non-optional type `T`.
async fn parse_payload<T>(req: Request<Body>, get_req: bool) -> GenericResult<(Request<Body>, T)>
async fn parse_body_payload<T>(req: Request<Body>) -> GenericResult<(Request<Body>, T)>
where
T: serde::de::DeserializeOwned,
T: DeserializeOwned,
{
let (mut parts, body) = req.into_parts();
let (parts, body) = req.into_parts();
let body_bytes = hyper::body::to_bytes(body).await?;

if body_bytes.is_empty() {
return Err("Empty body cannot be deserialized into non-optional type T".into());
}

let payload: T = serde_json::from_slice(&body_bytes)?;
let new_req = Request::from_parts(parts, Body::from(body_bytes));
Ok((new_req, payload))
}

let new_req = if get_req {
parts.method = Method::GET;
Request::from_parts(parts, Body::empty())
} else {
Request::from_parts(parts, Body::from(body_bytes))
};

async fn parse_header_payload<T>(req: Request<Body>) -> GenericResult<(Request<Body>, T)>
where
T: DeserializeOwned,
{
let (parts, body) = req.into_parts();
let header_value = parts
.headers
.get(X_AUTH_PAYLOAD)
.ok_or("Missing X-Auth-Payload header")?
.to_str()?;
let payload: T = serde_json::from_str(header_value)?;
let new_req = Request::from_parts(parts, body);
Ok((new_req, payload))
}
21 changes: 13 additions & 8 deletions src/proxy/moralis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ pub(crate) async fn validation_middleware_moralis(

#[tokio::test]
async fn test_parse_moralis_payload() {
use super::parse_payload;
use super::{parse_header_payload, X_AUTH_PAYLOAD};
use hyper::Method;

let serialized_payload = serde_json::json!({
"uri": "https://example.com/test-path",
Expand All @@ -258,21 +259,25 @@ async fn test_parse_moralis_payload() {
})
.to_string();

let mut req = Request::new(Body::from(serialized_payload));
req.headers_mut().insert(
HeaderName::from_static("accept"),
APPLICATION_JSON.parse().unwrap(),
);
let req = Request::builder()
.method(Method::GET)
.header(header::ACCEPT, HeaderValue::from_static(APPLICATION_JSON))
.header(
X_AUTH_PAYLOAD,
HeaderValue::from_str(&serialized_payload).unwrap(),
)
.body(Body::empty())
.unwrap();

let (mut req, payload) = parse_payload::<MoralisPayload>(req, true).await.unwrap();
let (mut req, payload) = parse_header_payload::<MoralisPayload>(req).await.unwrap();

let body_bytes = hyper::body::to_bytes(req.body_mut()).await.unwrap();
assert!(
body_bytes.is_empty(),
"Body should be empty for GET methods"
);

let header_value = req.headers().get("accept").unwrap();
let header_value = req.headers().get(header::ACCEPT).unwrap();

let expected_payload = MoralisPayload {
uri: Url::from_str("https://example.com/test-path").unwrap(),
Expand Down
4 changes: 2 additions & 2 deletions src/proxy/quicknode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ fn test_quicknode_payload_serialzation_and_deserialization() {

#[tokio::test]
async fn test_parse_quicknode_payload() {
use super::parse_payload;
use super::parse_body_payload;
use hyper::Method;

let serialized_payload = json!({
Expand All @@ -365,7 +365,7 @@ async fn test_parse_quicknode_payload() {
);

let (mut req, payload): (Request<Body>, QuicknodePayload) =
parse_payload::<QuicknodePayload>(req, false).await.unwrap();
parse_body_payload::<QuicknodePayload>(req).await.unwrap();

let body_bytes = hyper::body::to_bytes(req.body_mut()).await.unwrap();
assert!(
Expand Down

0 comments on commit e8f296d

Please sign in to comment.