Skip to content

Commit

Permalink
fix: use new fields exclusively
Browse files Browse the repository at this point in the history
  • Loading branch information
chris13524 committed Nov 17, 2023
1 parent 28f331e commit 8f51a17
Show file tree
Hide file tree
Showing 15 changed files with 272 additions and 248 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fcm = "0.9"
ed25519-dalek = "2.0.0-rc.2"

# JWT Authentication
relay_rpc = { git = "https://github.com/WalletConnect/WalletConnectRust.git", rev = "ced99e7"}
relay_rpc = { git = "https://github.com/WalletConnect/WalletConnectRust.git", rev = "4ee9007"}
jsonwebtoken = "8.1"
data-encoding = "2.3"

Expand Down
1 change: 1 addition & 0 deletions src/analytics/message_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct MessageInfo {
pub client_id: Arc<str>,
pub topic: Arc<str>,
pub push_provider: Arc<str>,
pub always_encrypted: bool,
pub encrypted: bool,
pub flags: u32,
pub status: u16,
Expand Down
2 changes: 1 addition & 1 deletion src/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl DecryptedPayloadBlob {
Ok(serde_json::from_str(&blob_string)?)
}

pub fn from_base64_encoded(blob_string: String) -> Result<DecryptedPayloadBlob> {
pub fn from_base64_encoded(blob_string: &str) -> Result<DecryptedPayloadBlob> {
let blob_decoded = base64::engine::general_purpose::STANDARD.decode(blob_string)?;
Ok(serde_json::from_slice(&blob_decoded)?)
}
Expand Down
5 changes: 1 addition & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub enum Error {
InternalServerError,

#[error(transparent)]
JwtError(#[from] relay_rpc::auth::JwtVerificationError),
JwtError(#[from] relay_rpc::jwt::JwtError),

#[error("the provided authentication does not authenticate the request")]
InvalidAuthentication,
Expand Down Expand Up @@ -180,9 +180,6 @@ pub enum Error {

#[error("tenant suspended due to invalid configuration")]
TenantSuspended,

#[error("Bad payload provided: {0}")]
BadPayload(String),
}

impl IntoResponse for Error {
Expand Down
17 changes: 13 additions & 4 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use {
Json,
},
hyper::StatusCode,
relay_rpc::{auth::Jwt, domain::ClientId},
relay_rpc::{
domain::ClientId,
jwt::{JwtBasicClaims, VerifyableClaims},
},
serde_json::{json, Value},
std::{collections::HashSet, string::ToString},
tracing::info,
Expand Down Expand Up @@ -53,12 +56,18 @@ where
{
return if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
let header_str = auth_header.to_str()?;
let client_id = Jwt(header_str.to_string())
.decode(&HashSet::from([aud.to_string()]))

let claims = JwtBasicClaims::try_from_str(header_str).map_err(|e| {
info!("Invalid claims: {:?}", e);
e
})?;
claims
.verify_basic(&HashSet::from([aud.to_string()]), None)
.map_err(|e| {
info!("Invalid claims: {:?}", e);
info!("Failed to verify_basic: {:?}", e);
e
})?;
let client_id: ClientId = claims.iss.into();
Ok(check(Some(client_id)))
} else {
// Note: Authentication is not required right now to ensure that this is a
Expand Down
130 changes: 67 additions & 63 deletions src/handlers/push_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use axum_client_ip::SecureClientIp;
use {
crate::{
analytics::message_info::MessageInfo,
blob::ENCRYPTED_FLAG,
error::{
Error,
Error::{ClientNotFound, Store},
Expand All @@ -12,7 +11,7 @@ use {
increment_counter,
log::prelude::*,
middleware::validate_signature::RequireValidSignature,
providers::{Provider, PushProvider},
providers::{NewPushMessage, OldPushMessage, Provider, PushMessage, PushProvider},
request_id::get_req_id,
state::AppState,
stores::StoreError,
Expand All @@ -28,40 +27,14 @@ use {
tracing::instrument,
};

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct MessagePayload {
pub topic: String,
pub flags: u32,
pub blob: String,
}

impl MessagePayload {
pub fn is_encrypted(&self) -> bool {
(self.flags & ENCRYPTED_FLAG) == ENCRYPTED_FLAG
}
}

/// Encrypted notify message payload
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct RawMessagePayload {
pub topic: String,
pub tag: usize,
pub message: String,
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct PushMessageBody {
/// Topic is used by the SDKs to decrypt
/// encrypted payloads on the client side
pub topic: Option<String>,
/// Filtering tag
pub tag: Option<usize>,
/// The payload message
pub message: Option<String>,
#[serde(flatten)]
pub new: Option<NewPushMessage>,

// Legacy (deprecating) fields
pub id: String,
pub payload: MessagePayload,
#[serde(flatten)]
pub old: Option<OldPushMessage>,
}

pub async fn handler(
Expand Down Expand Up @@ -150,19 +123,13 @@ pub async fn handler(
Ok(response)
}

#[instrument(name = "push_message_internal", skip_all, fields(tenant_id = tenant_id, client_id = client_id, id = body.id))]
#[instrument(name = "push_message_internal", skip_all, fields(tenant_id = tenant_id, client_id = client_id))]
pub async fn handler_internal(
Path((tenant_id, client_id)): Path<(String, String)>,
StateExtractor(state): StateExtractor<Arc<AppState>>,
headers: HeaderMap,
RequireValidSignature(Json(body)): RequireValidSignature<Json<PushMessageBody>>,
) -> Result<(axum::response::Response, Option<MessageInfo>), (Error, Option<MessageInfo>)> {
#[cfg(feature = "analytics")]
let topic: Arc<str> = body.payload.clone().topic.into();

#[cfg(feature = "analytics")]
let (flags, encrypted) = (body.payload.clone().flags, body.payload.is_encrypted());

let client = match state.client_store.get_client(&tenant_id, &client_id).await {
Ok(c) => Ok(c),
Err(StoreError::NotFound(_, _)) => Err(ClientNotFound),
Expand All @@ -173,16 +140,35 @@ pub async fn handler_internal(
e,
#[cfg(feature = "analytics")]
Some(MessageInfo {
msg_id: body.id.clone().into(),
msg_id: body
.new
.as_ref()
.map(|msg| relay_rpc::rpc::msg_id::get_message_id(&msg.message).into())
.unwrap_or(
body.old
.as_ref()
.map(|msg| msg.id.clone())
.unwrap_or("error: no message id".to_owned().into()),
),
region: None,
country: None,
continent: None,
project_id: tenant_id.clone().into(),
client_id: client_id.clone().into(),
topic: topic.clone(),
topic: body.new.as_ref().map(|m| m.topic.clone()).unwrap_or(
body.old
.as_ref()
.map(|m| m.payload.topic.clone())
.unwrap_or("error: no topic".to_owned().into()),
),
push_provider: "unknown".into(),
encrypted,
flags,
always_encrypted: body.new.is_some(),
encrypted: body
.old
.as_ref()
.map(|m| m.payload.is_encrypted())
.unwrap_or(false),
flags: body.old.as_ref().map(|m| m.payload.flags).unwrap_or(0),
status: 0,
response_message: None,
received_at: wc::analytics::time::now(),
Expand All @@ -192,31 +178,52 @@ pub async fn handler_internal(
)
})?;

// Check for required fields if the client has `always_raw = true`
if client.always_raw {
if body.topic.is_none() {
return Err((Error::EmptyField("topic".to_string()), None));
}
if body.message.is_none() {
return Err((Error::EmptyField("message".to_string()), None));
let cloned_body = body.clone();
let push_message = if client.always_raw {
if let Some(body) = body.new {
PushMessage::NewPushMessage(body)
} else {
return Err((
Error::EmptyField("missing topic, tag, or message field".to_string()),
None,
));
}
if body.tag.is_none() {
return Err((Error::EmptyField("tag".to_string()), None));
} else {
#[allow(clippy::collapsible_else_if)]
if let Some(body) = body.old {
PushMessage::OldPushMessage(body)
} else {
return Err((
Error::EmptyField("missing id or payload field".to_string()),
None,
));
}
}
};

let message_id = push_message.message_id();

#[cfg(feature = "analytics")]
let mut analytics = Some(MessageInfo {
msg_id: body.id.clone().into(),
msg_id: message_id.clone(),
region: None,
country: None,
continent: None,
project_id: tenant_id.clone().into(),
client_id: client_id.clone().into(),
topic,
topic: push_message.topic(),
push_provider: client.push_type.as_str().into(),
encrypted,
flags,
always_encrypted: match push_message {
PushMessage::NewPushMessage(_) => true,
PushMessage::OldPushMessage(_) => false,
},
encrypted: match push_message {
PushMessage::NewPushMessage(_) => false,
PushMessage::OldPushMessage(ref msg) => msg.payload.is_encrypted(),
},
flags: match push_message {
PushMessage::NewPushMessage(_) => 0,
PushMessage::OldPushMessage(ref msg) => msg.payload.flags,
},
status: 0,
response_message: None,
received_at: wc::analytics::time::now(),
Expand Down Expand Up @@ -296,7 +303,7 @@ pub async fn handler_internal(

if let Ok(notification) = state
.notification_store
.get_notification(&body.id, &client_id, &tenant_id)
.get_notification(&message_id, &client_id, &tenant_id)
.await
{
warn!(
Expand Down Expand Up @@ -324,7 +331,7 @@ pub async fn handler_internal(

let notification = state
.notification_store
.create_or_update_notification(&body.id, &tenant_id, &client_id, &body.payload)
.create_or_update_notification(&message_id, &tenant_id, &client_id, &cloned_body)
.await
.tap_err(|e| warn!("error create_or_update_notification: {e:?}"))
.map_err(|e| (Error::Store(e), analytics.clone()))?;
Expand Down Expand Up @@ -395,10 +402,7 @@ pub async fn handler_internal(
"fetched provider"
);

match provider
.send_notification(client.token, body, client.always_raw)
.await
{
match provider.send_notification(client.token, push_message).await {
Ok(()) => Ok(()),
Err(error) => {
warn!("error sending notification: {error:?}");
Expand Down
Loading

0 comments on commit 8f51a17

Please sign in to comment.