-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathlib.rs
More file actions
128 lines (115 loc) · 4.01 KB
/
lib.rs
File metadata and controls
128 lines (115 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
pub mod args;
mod functions;
mod state;
mod types;
mod user;
use std::net::SocketAddr;
use axum_server::tls_rustls::RustlsConfig;
use eyre::OptionExt;
pub use state::{AppState, SharedState};
use thiserror::Error;
use tower_http::trace::TraceLayer;
pub use types::*;
use args::Args;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
/// Build the axum [`Router`] that serves the API.
///
/// Every endpoint accepts `POST` and is dispatched to the handler of the same
/// name in the `functions` module. A [`TraceLayer`] is applied to all routes,
/// and `shared_state` is the state axum passes to each handler.
// TODO: use methods of a single object instead of separate functions?
pub fn router(shared_state: SharedState) -> Router {
    // Routes are declared first while the router still expects its state...
    let endpoints: Router<SharedState> = Router::new()
        .route("/challenge", post(functions::challenge))
        .route("/login", post(functions::login))
        .route("/logout", post(functions::logout))
        .route("/create_new_session", post(functions::create_new_session))
        .route("/list_sessions", post(functions::list_sessions))
        .route("/get_session_info", post(functions::get_session_info))
        .route("/send", post(functions::send))
        .route("/receive", post(functions::receive))
        .route("/close_session", post(functions::close_session));

    // ...then tracing is layered over all of them and the state is supplied.
    let traced = endpoints.layer(TraceLayer::new_for_http());
    traced.with_state(shared_state)
}
/// Run the server with the specified arguments.
///
/// Creates the application state and router, then serves on the address built
/// from `args.ip()` and `args.port`. When `args.no_tls_very_insecure` is set
/// the server speaks plain HTTP; otherwise it serves HTTPS using the PEM
/// files named by `args.tls_cert` and `args.tls_key`.
///
/// # Errors
///
/// Returns an error if the application state cannot be created, the address
/// does not parse, a required TLS argument is missing, the TLS files cannot be
/// loaded, or binding/serving fails.
pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
    let shared_state = AppState::new().await?;
    let app = router(shared_state.clone());
    let addr: SocketAddr = format!("{}:{}", args.ip(), args.port).parse()?;

    if args.no_tls_very_insecure {
        tracing::warn!(
            "starting an INSECURE HTTP server at {}. This should be done only \
            for testing or if you are providing TLS/HTTPS with a separate \
            mechanism (e.g. reverse proxy such as nginx)",
            addr,
        );
        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, app).await?;
    } else {
        // `install_default` only succeeds once per process and returns `Err`
        // when a default provider is already set. That is not a failure for
        // us (a provider is available either way), so do not panic: this
        // keeps `run` usable when called more than once in the same process
        // or when the embedding application installed a provider itself.
        if rustls::crypto::ring::default_provider()
            .install_default()
            .is_err()
        {
            tracing::debug!("rustls crypto provider already installed; reusing it");
        }

        let config = RustlsConfig::from_pem_file(
            args.tls_cert
                .clone()
                .ok_or_eyre("tls-cert argument is required")?,
            args.tls_key
                .clone()
                .ok_or_eyre("tls-key argument is required")?,
        )
        .await?;

        tracing::info!("starting HTTPS server at {}", addr);
        axum_server::bind_rustls(addr, config)
            .serve(app.into_make_service())
            .await?;
    }

    Ok(())
}
/// Errors that the server reports to a client when an API call fails.
///
/// Each variant maps to a numeric code via `AppError::error_code`, and the
/// `#[error(...)]` text is the message produced by `to_string()`.
#[derive(Debug, Error)]
pub(crate) enum AppError {
/// A request argument was invalid or missing; the string is appended to
/// the error message.
#[error("invalid or missing argument: {0}")]
InvalidArgument(String),
/// The client did not provide proper authorization credentials.
#[error("client did not provide proper authorization credentials")]
Unauthorized,
/// The session was not found.
#[error("session was not found")]
SessionNotFound,
/// The user is not the coordinator.
#[error("user is not the coordinator")]
NotCoordinator,
}
// Numeric error codes returned by `AppError::error_code`. These make it
// easier for clients to tell which error happened.
/// Code for `AppError::InvalidArgument`.
pub const INVALID_ARGUMENT: usize = 1;
/// Code for `AppError::Unauthorized`.
pub const UNAUTHORIZED: usize = 2;
/// Code for `AppError::SessionNotFound`.
pub const SESSION_NOT_FOUND: usize = 3;
/// Code for `AppError::NotCoordinator`.
pub const NOT_COORDINATOR: usize = 4;
impl AppError {
    /// Return the numeric code that identifies this error variant.
    ///
    /// The match is exhaustive on purpose, so adding a variant forces a code
    /// to be chosen for it.
    pub fn error_code(&self) -> usize {
        match self {
            Self::InvalidArgument(_) => INVALID_ARGUMENT,
            Self::Unauthorized => UNAUTHORIZED,
            Self::SessionNotFound => SESSION_NOT_FOUND,
            Self::NotCoordinator => NOT_COORDINATOR,
        }
    }
}
impl From<AppError> for types::Error {
    /// Convert an [`AppError`] into a `types::Error`, carrying the variant's
    /// numeric code and its display message.
    fn from(err: AppError) -> Self {
        let code = err.error_code();
        let msg = err.to_string();
        Self { code, msg }
    }
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(Into::<types::Error>::into(self)),
)
.into_response()
}
}