Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions bin/icehutd/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,22 @@ pub struct IceHutOpts {
)]
pub port: Option<u16>,

#[arg(
long,
env = "CORS_ENABLED",
help = "Enable CORS",
default_value = "false"
)]
pub cors_enabled: Option<bool>,

#[arg(
long,
env = "CORS_ALLOW_ORIGIN",
required_if_eq("cors_enabled", "true"),
help = "CORS Allow Origin"
)]
pub cors_allow_origin: Option<String>,

#[arg(long, default_value = "true")]
use_fs: Option<bool>,
}
Expand Down
9 changes: 8 additions & 1 deletion bin/icehutd/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ async fn main() {
let slatedb_prefix = opts.slatedb_prefix.clone();
let host = opts.host.clone().unwrap();
let port = opts.port.unwrap();
let allow_origin = if opts.cors_enabled.unwrap_or(false) {
opts.cors_allow_origin.clone()
} else {
None
};
let object_store = opts.object_store_backend();

match object_store {
Expand All @@ -31,7 +36,9 @@ async fn main() {
Ok(object_store) => {
tracing::info!("Starting ❄️🏠 IceHut...");

if let Err(e) = nexus::run_icehut(object_store, slatedb_prefix, host, port).await {
if let Err(e) =
nexus::run_icehut(object_store, slatedb_prefix, host, port, allow_origin).await
{
tracing::error!("Failed to start IceHut: {:?}", e);
}
}
Expand Down
39 changes: 39 additions & 0 deletions crates/nexus/src/http/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use axum::response::IntoResponse;
use http::header::InvalidHeaderValue;
use snafu::prelude::*;

#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum NexusHttpError {
#[snafu(display("Error parsing Allow-Origin header: {}", source))]
AllowOriginHeaderParse { source: InvalidHeaderValue },

#[snafu(display("Session load error: {msg}"))]
SessionLoad { msg: String },
#[snafu(display("Unable to persist session"))]
SessionPersist {
source: tower_sessions::session::Error,
},
}

impl IntoResponse for NexusHttpError {
fn into_response(self) -> axum::response::Response {
match self {
Self::AllowOriginHeaderParse { .. } => (
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"Allow-Origin header parse error",
)
.into_response(),
Self::SessionLoad { .. } => (
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"Session load error",
)
.into_response(),
Self::SessionPersist { .. } => (
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"Session persist error",
)
.into_response(),
}
}
}
23 changes: 23 additions & 0 deletions crates/nexus/src/http/layers.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
#![allow(dead_code)]
use axum::http::HeaderMap;
use axum::{middleware::Next, response::Response};
use http::header::{AUTHORIZATION, CONTENT_TYPE};
use http::{HeaderValue, Method};
use snafu::ResultExt;
use std::str::FromStr;
use tower_http::cors::CorsLayer;
use uuid::Uuid;

use super::error;

#[derive(Clone)]
struct RequestMetadata {
request_id: Uuid,
Expand Down Expand Up @@ -41,3 +47,20 @@ pub async fn add_request_metadata(
.insert("x-request-id", request_id.to_string().parse().unwrap());
response
}

#[allow(clippy::needless_pass_by_value)]
pub fn make_cors_middleware(origin: String) -> Result<CorsLayer, error::NexusHttpError> {
let origin_value = origin
.parse::<HeaderValue>()
.context(error::AllowOriginHeaderParseSnafu)?;
Ok(CorsLayer::new()
.allow_origin(origin_value)
.allow_methods(vec![
Method::GET,
Method::POST,
Method::DELETE,
Method::HEAD,
])
.allow_headers(vec![AUTHORIZATION, CONTENT_TYPE])
.allow_credentials(true))
}
1 change: 1 addition & 0 deletions crates/nexus/src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod router;
pub mod catalog;
pub mod control;
pub mod dbt;
pub mod error;
pub mod layers;
pub mod session;
pub mod ui;
Expand Down
8 changes: 1 addition & 7 deletions crates/nexus/src/http/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use axum::routing::{get, post};
use axum::{Json, Router};
use std::fs;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::cors::{Any, CorsLayer};
use utoipa::openapi::{self};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
Expand Down Expand Up @@ -69,12 +68,7 @@ pub fn create_app(state: AppState) -> Router {
.route("/telemetry/send", post(|| async { Json("OK") }))
.layer(TimeoutLayer::new(std::time::Duration::from_secs(1200)))
.layer(CatchPanicLayer::new())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
)
//.layer(super::layers::make_cors_middleware(allow_origin.unwrap_or("*".to_string())))
.with_state(state)
}

Expand Down
37 changes: 10 additions & 27 deletions crates/nexus/src/http/session.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use axum::{extract::FromRequestParts, response::IntoResponse};
use axum::extract::FromRequestParts;
use control_plane::service::ControlService;
use http::request::Parts;
use snafu::prelude::*;
use snafu::ResultExt;
use std::{collections::HashMap, sync::Arc};
use time::OffsetDateTime;
Expand All @@ -11,6 +10,8 @@ use tower_sessions::{
session_store, ExpiredDeletion, Session, SessionStore,
};

use super::error::{self as nexus_http_error, NexusHttpError};

pub type RequestSessionMemory = Arc<Mutex<HashMap<Id, Record>>>;

#[derive(Clone)]
Expand Down Expand Up @@ -126,40 +127,19 @@ impl std::fmt::Debug for RequestSessionStore {
}
}

#[derive(Snafu, Debug)]
pub enum SessionError {
#[snafu(display("Session load error: {msg}"))]
SessionLoad { msg: String },
#[snafu(display("Unable to persist session"))]
SessionPersist {
source: tower_sessions::session::Error,
},
}

impl IntoResponse for SessionError {
fn into_response(self) -> axum::response::Response {
tracing::error!("Session error: {}", self);
(
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
"Session error",
)
.into_response()
}
}

#[derive(Debug)]
pub struct DFSessionId(pub String);

impl<S> FromRequestParts<S> for DFSessionId
where
S: Send + Sync,
{
type Rejection = SessionError;
type Rejection = NexusHttpError;

async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let session = Session::from_request_parts(req, state).await.map_err(|e| {
tracing::error!("Failed to get session: {}", e.1);
SessionError::SessionLoad {
NexusHttpError::SessionLoad {
msg: e.1.to_string(),
}
})?;
Expand All @@ -172,8 +152,11 @@ where
session
.insert("DF_SESSION_ID", id.clone())
.await
.context(SessionPersistSnafu)?;
session.save().await.context(SessionPersistSnafu)?;
.context(nexus_http_error::SessionPersistSnafu)?;
session
.save()
.await
.context(nexus_http_error::SessionPersistSnafu)?;
id
};
Ok(Self(session_id))
Expand Down
8 changes: 7 additions & 1 deletion crates/nexus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use tokio::signal;
use tower_http::trace::TraceLayer;
use tower_sessions::{Expiry, SessionManagerLayer};

use http::layers::make_cors_middleware;
use http::session::{RequestSessionMemory, RequestSessionStore};
use utils::Db;

Expand All @@ -32,6 +33,7 @@ pub async fn run_icehut(
slatedb_prefix: String,
host: String,
port: u16,
allow_origin: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
let db = {
let options = DbOptions::default();
Expand Down Expand Up @@ -72,11 +74,15 @@ pub async fn run_icehut(
// Create the application state
let app_state = state::AppState::new(control_svc.clone(), Arc::new(catalog_svc));

let app = http::router::create_app(app_state)
let mut app = http::router::create_app(app_state)
.layer(session_layer)
.layer(TraceLayer::new_for_http())
.layer(middleware::from_fn(print_request_response));

if let Some(allow_origin) = allow_origin {
app = app.layer(make_cors_middleware(allow_origin)?);
}

let listener = tokio::net::TcpListener::bind(format!("{host}:{port}")).await?;
tracing::info!("Listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app)
Expand Down
Loading