Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use new fields exclusively #283

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fcm = "0.9"
ed25519-dalek = "2.0.0-rc.2"

# JWT Authentication
relay_rpc = { git = "https://github.com/WalletConnect/WalletConnectRust.git", rev = "ced99e7"}
relay_rpc = { git = "https://github.com/WalletConnect/WalletConnectRust.git", rev = "4ee9007"}
jsonwebtoken = "8.1"
data-encoding = "2.3"

Expand Down
1 change: 1 addition & 0 deletions src/analytics/message_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct MessageInfo {
pub client_id: Arc<str>,
pub topic: Arc<str>,
pub push_provider: Arc<str>,
pub always_encrypted: bool,
pub encrypted: bool,
pub flags: u32,
pub status: u16,
Expand Down
2 changes: 1 addition & 1 deletion src/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl DecryptedPayloadBlob {
Ok(serde_json::from_str(&blob_string)?)
}

pub fn from_base64_encoded(blob_string: String) -> Result<DecryptedPayloadBlob> {
pub fn from_base64_encoded(blob_string: &str) -> Result<DecryptedPayloadBlob> {
let blob_decoded = base64::engine::general_purpose::STANDARD.decode(blob_string)?;
Ok(serde_json::from_slice(&blob_decoded)?)
}
Expand Down
5 changes: 1 addition & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ pub enum Error {
InternalServerError,

#[error(transparent)]
JwtError(#[from] relay_rpc::auth::JwtVerificationError),
JwtError(#[from] relay_rpc::jwt::JwtError),

#[error("the provided authentication does not authenticate the request")]
InvalidAuthentication,
Expand Down Expand Up @@ -180,9 +180,6 @@ pub enum Error {

#[error("tenant suspended due to invalid configuration")]
TenantSuspended,

#[error("Bad payload provided: {0}")]
BadPayload(String),
}

impl IntoResponse for Error {
Expand Down
17 changes: 13 additions & 4 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use {
Json,
},
hyper::StatusCode,
relay_rpc::{auth::Jwt, domain::ClientId},
relay_rpc::{
domain::ClientId,
jwt::{JwtBasicClaims, VerifyableClaims},
},
serde_json::{json, Value},
std::{collections::HashSet, string::ToString},
tracing::info,
Expand Down Expand Up @@ -53,12 +56,18 @@ where
{
return if let Some(auth_header) = headers.get(axum::http::header::AUTHORIZATION) {
let header_str = auth_header.to_str()?;
let client_id = Jwt(header_str.to_string())
.decode(&HashSet::from([aud.to_string()]))

let claims = JwtBasicClaims::try_from_str(header_str).map_err(|e| {
info!("Invalid claims: {:?}", e);
e
})?;
claims
.verify_basic(&HashSet::from([aud.to_string()]), None)
.map_err(|e| {
info!("Invalid claims: {:?}", e);
info!("Failed to verify_basic: {:?}", e);
e
})?;
let client_id: ClientId = claims.iss.into();
Ok(check(Some(client_id)))
} else {
// Note: Authentication is not required right now to ensure that this is a
Expand Down
130 changes: 67 additions & 63 deletions src/handlers/push_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use axum_client_ip::SecureClientIp;
use {
crate::{
analytics::message_info::MessageInfo,
blob::ENCRYPTED_FLAG,
error::{
Error,
Error::{ClientNotFound, Store},
Expand All @@ -12,7 +11,7 @@ use {
increment_counter,
log::prelude::*,
middleware::validate_signature::RequireValidSignature,
providers::{Provider, PushProvider},
providers::{NewPushMessage, OldPushMessage, Provider, PushMessage, PushProvider},
request_id::get_req_id,
state::AppState,
stores::StoreError,
Expand All @@ -28,40 +27,14 @@ use {
tracing::instrument,
};

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct MessagePayload {
pub topic: String,
pub flags: u32,
pub blob: String,
}

impl MessagePayload {
pub fn is_encrypted(&self) -> bool {
(self.flags & ENCRYPTED_FLAG) == ENCRYPTED_FLAG
}
}

/// Encrypted notify message payload
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct RawMessagePayload {
pub topic: String,
pub tag: usize,
pub message: String,
}

#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct PushMessageBody {
/// Topic is used by the SDKs to decrypt
/// encrypted payloads on the client side
pub topic: Option<String>,
/// Filtering tag
pub tag: Option<usize>,
/// The payload message
pub message: Option<String>,
#[serde(flatten)]
pub new: Option<NewPushMessage>,

// Legacy (deprecating) fields
pub id: String,
pub payload: MessagePayload,
#[serde(flatten)]
pub old: Option<OldPushMessage>,
}

pub async fn handler(
Expand Down Expand Up @@ -150,19 +123,13 @@ pub async fn handler(
Ok(response)
}

#[instrument(name = "push_message_internal", skip_all, fields(tenant_id = tenant_id, client_id = client_id, id = body.id))]
#[instrument(name = "push_message_internal", skip_all, fields(tenant_id = tenant_id, client_id = client_id))]
pub async fn handler_internal(
Path((tenant_id, client_id)): Path<(String, String)>,
StateExtractor(state): StateExtractor<Arc<AppState>>,
headers: HeaderMap,
RequireValidSignature(Json(body)): RequireValidSignature<Json<PushMessageBody>>,
) -> Result<(axum::response::Response, Option<MessageInfo>), (Error, Option<MessageInfo>)> {
#[cfg(feature = "analytics")]
let topic: Arc<str> = body.payload.clone().topic.into();

#[cfg(feature = "analytics")]
let (flags, encrypted) = (body.payload.clone().flags, body.payload.is_encrypted());

let client = match state.client_store.get_client(&tenant_id, &client_id).await {
Ok(c) => Ok(c),
Err(StoreError::NotFound(_, _)) => Err(ClientNotFound),
Expand All @@ -173,16 +140,35 @@ pub async fn handler_internal(
e,
#[cfg(feature = "analytics")]
Some(MessageInfo {
msg_id: body.id.clone().into(),
msg_id: body
.new
.as_ref()
.map(|msg| relay_rpc::rpc::msg_id::get_message_id(&msg.message).into())
.unwrap_or(
body.old
.as_ref()
.map(|msg| msg.id.clone())
.unwrap_or("error: no message id".to_owned().into()),
),
region: None,
country: None,
continent: None,
project_id: tenant_id.clone().into(),
client_id: client_id.clone().into(),
topic: topic.clone(),
topic: body.new.as_ref().map(|m| m.topic.clone()).unwrap_or(
body.old
.as_ref()
.map(|m| m.payload.topic.clone())
.unwrap_or("error: no topic".to_owned().into()),
),
push_provider: "unknown".into(),
encrypted,
flags,
always_encrypted: body.new.is_some(),
encrypted: body
.old
.as_ref()
.map(|m| m.payload.is_encrypted())
.unwrap_or(false),
flags: body.old.as_ref().map(|m| m.payload.flags).unwrap_or(0),
status: 0,
response_message: None,
received_at: wc::analytics::time::now(),
Expand All @@ -192,31 +178,52 @@ pub async fn handler_internal(
)
})?;

// Check for required fields if the client has `always_raw = true`
if client.always_raw {
if body.topic.is_none() {
return Err((Error::EmptyField("topic".to_string()), None));
}
if body.message.is_none() {
return Err((Error::EmptyField("message".to_string()), None));
let cloned_body = body.clone();
let push_message = if client.always_raw {
if let Some(body) = body.new {
PushMessage::NewPushMessage(body)
} else {
return Err((
Error::EmptyField("missing topic, tag, or message field".to_string()),
None,
));
}
if body.tag.is_none() {
return Err((Error::EmptyField("tag".to_string()), None));
} else {
#[allow(clippy::collapsible_else_if)]
if let Some(body) = body.old {
PushMessage::OldPushMessage(body)
} else {
return Err((
Error::EmptyField("missing id or payload field".to_string()),
None,
));
}
}
};

let message_id = push_message.message_id();

#[cfg(feature = "analytics")]
let mut analytics = Some(MessageInfo {
msg_id: body.id.clone().into(),
msg_id: message_id.clone(),
region: None,
country: None,
continent: None,
project_id: tenant_id.clone().into(),
client_id: client_id.clone().into(),
topic,
topic: push_message.topic(),
push_provider: client.push_type.as_str().into(),
encrypted,
flags,
always_encrypted: match push_message {
PushMessage::NewPushMessage(_) => true,
PushMessage::OldPushMessage(_) => false,
},
encrypted: match push_message {
PushMessage::NewPushMessage(_) => false,
PushMessage::OldPushMessage(ref msg) => msg.payload.is_encrypted(),
},
flags: match push_message {
PushMessage::NewPushMessage(_) => 0,
PushMessage::OldPushMessage(ref msg) => msg.payload.flags,
},
status: 0,
response_message: None,
received_at: wc::analytics::time::now(),
Expand Down Expand Up @@ -296,7 +303,7 @@ pub async fn handler_internal(

if let Ok(notification) = state
.notification_store
.get_notification(&body.id, &client_id, &tenant_id)
.get_notification(&message_id, &client_id, &tenant_id)
.await
{
warn!(
Expand Down Expand Up @@ -324,7 +331,7 @@ pub async fn handler_internal(

let notification = state
.notification_store
.create_or_update_notification(&body.id, &tenant_id, &client_id, &body.payload)
.create_or_update_notification(&message_id, &tenant_id, &client_id, &cloned_body)
.await
.tap_err(|e| warn!("error create_or_update_notification: {e:?}"))
.map_err(|e| (Error::Store(e), analytics.clone()))?;
Expand Down Expand Up @@ -395,10 +402,7 @@ pub async fn handler_internal(
"fetched provider"
);

match provider
.send_notification(client.token, body, client.always_raw)
.await
{
match provider.send_notification(client.token, push_message).await {
Ok(()) => Ok(()),
Err(error) => {
warn!("error sending notification: {error:?}");
Expand Down
Loading
Loading