Skip to content
This repository has been archived by the owner on Feb 11, 2024. It is now read-only.

Commit

Permalink
fix: return 4xx instead of 5xx on invalid JWT (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xavier Basty committed Sep 12, 2023
1 parent 7cf38e5 commit ffd6bb3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 27 deletions.
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ impl IntoResponse for Error {
fn into_response(self) -> Response {
error!("responding with error ({:?})", self);
match self {
Error::JwtError(e) => crate::handlers::Response::new_failure(StatusCode::UNAUTHORIZED, vec![
ResponseError {
name: "jwt".to_string(),
message: e.to_string(),
}
], vec![]),
Error::Database(e) => crate::handlers::Response::new_failure(StatusCode::INTERNAL_SERVER_ERROR, vec![
ResponseError {
name: "mongodb".to_string(),
Expand Down
21 changes: 0 additions & 21 deletions src/handlers/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ async fn overwrite_registration(
tags: HashSet<Arc<str>>,
relay_url: Arc<str>,
) -> error::Result<Response> {
info!(
"DBG:{}:overwrite_registration: upsert registration",
client_id
);
state
.registration_store
.upsert_registration(
Expand All @@ -83,10 +79,6 @@ async fn overwrite_registration(
)
.await?;

info!(
"DBG:{}:overwrite_registration: cache registration",
client_id
);
state
.registration_cache
.insert(client_id.into_value(), CachedRegistration {
Expand All @@ -108,27 +100,15 @@ async fn update_registration(
let append_tags = append_tags.unwrap_or_default();
let remove_tags = remove_tags.unwrap_or_default();

info!(
"DBG:{}:update_registration: process intersection of <append_tags> and <remove_tags>...",
client_id
);
if remove_tags.intersection(&append_tags).count() > 0 {
return Err(Error::InvalidUpdateRequest);
}

info!(
"DBG:{}:update_registration: get current registration",
client_id
);
let registration = state
.registration_store
.get_registration(client_id.as_ref())
.await?;

info!(
"DBG:{}:update_registration: get current registration",
client_id
);
let tags = registration
.tags
.into_iter()
Expand All @@ -140,6 +120,5 @@ async fn update_registration(
.cloned()
.collect();

info!("DBG:{}: overwrite registration", client_id);
overwrite_registration(state, client_id, tags, relay_url).await
}
30 changes: 25 additions & 5 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
extern crate core;

use relay_rpc::{
auth::{
ed25519_dalek::Keypair,
rand::{rngs::StdRng, SeedableRng},
use {
chrono::{Duration, Utc},
relay_rpc::{
auth::{
ed25519_dalek::Keypair,
rand::{rngs::StdRng, SeedableRng},
},
domain::{ClientId, DecodedClientId},
},
domain::{ClientId, DecodedClientId},
};

mod context;
Expand Down Expand Up @@ -43,3 +46,20 @@ fn get_client_jwt() -> (String, ClientId) {

(jwt, client_id)
}

fn get_invalid_client_jwt() -> (String, ClientId) {
let mut rng = StdRng::from_entropy();
let keypair = Keypair::generate(&mut rng);

let random_client_id = DecodedClientId(*keypair.public_key().as_bytes());
let client_id = ClientId::from(random_client_id);

let jwt = relay_rpc::auth::AuthToken::new(client_id.to_string())
.aud(TEST_RELAY_URL.to_string())
.iat(Utc::now() + Duration::days(1))
.as_jwt(&keypair)
.unwrap()
.to_string();

(jwt, client_id)
}
66 changes: 65 additions & 1 deletion tests/registration/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
use {
crate::{context::ServerContext, get_client_jwt, TEST_RELAY_URL},
crate::{context::ServerContext, get_client_jwt, get_invalid_client_jwt, TEST_RELAY_URL},
axum::http,
gilgamesh::{handlers::register::RegisterPayload, store::registrations::Registration},
std::sync::Arc,
test_context::test_context,
};

#[test_context(ServerContext)]
#[tokio::test]
async fn test_register_invalid_jwt(ctx: &mut ServerContext) {
let (jwt, _) = get_invalid_client_jwt();

let payload = RegisterPayload {
tags: Some(vec![Arc::from("4000"), Arc::from("5***")]),
append_tags: None,
remove_tags: None,
relay_url: Arc::from(TEST_RELAY_URL),
};

let client = reqwest::Client::new();
let response = client
.post(format!("http://{}/register", ctx.server.public_addr))
.json(&payload)
.header(http::header::AUTHORIZATION, format!("Bearer {jwt}"))
.send()
.await
.expect("Call failed");

assert!(
response.status().is_client_error(),
"Response was not successful: {:?} - {:?}",
response.status(),
response.text().await
);
}

#[test_context(ServerContext)]
#[tokio::test]
async fn test_register_new(ctx: &mut ServerContext) {
Expand Down Expand Up @@ -316,3 +345,38 @@ async fn test_get_registration(ctx: &mut ServerContext) {
assert_eq!(payload.tags.unwrap(), tags);
assert_eq!(payload.relay_url.as_ref(), TEST_RELAY_URL);
}

#[test_context(ServerContext)]
#[tokio::test]
async fn test_get_registration_invalid_jwt(ctx: &mut ServerContext) {
let (jwt, client_id) = get_invalid_client_jwt();

let tags = vec![Arc::from("4000"), Arc::from("5***")];
let registration = Registration {
id: None,
client_id: client_id.clone().into_value(),
tags: tags.clone(),
relay_url: Arc::from(TEST_RELAY_URL),
};

ctx.server
.registration_store
.registrations
.insert(client_id.to_string(), registration)
.await;

let client = reqwest::Client::new();
let response = client
.get(format!("http://{}/register", ctx.server.public_addr))
.header(http::header::AUTHORIZATION, format!("Bearer {jwt}"))
.send()
.await
.expect("Call failed");

assert!(
response.status().is_client_error(),
"Response was not successful: {:?} - {:?}",
response.status(),
response.text().await
);
}

0 comments on commit ffd6bb3

Please sign in to comment.