Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions wp_api/src/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize};
use std::{collections::HashMap, str, sync::Arc};
use wp_localization::{MessageBundle, WpMessages, WpSupportsLocalization};
use wp_localization_macro::WpDeriveLocalizable;
use wp_serde_helper::{deserialize_false_or_string, deserialize_offset};
use wp_serde_helper::{
deserialize_empty_array_or_hashmap, deserialize_false_or_string, deserialize_offset,
};

const KEY_APPLICATION_PASSWORDS: &str = "application-passwords";

Expand Down Expand Up @@ -225,22 +227,11 @@ impl WpSupportsLocalization for OAuthResponseUrlError {
}
}

#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct WpApiDetailsAuthenticationMap(HashMap<String, WpRestApiAuthenticationScheme>);

// If the response is `[]`, default to an empty `HashMap`
impl<'de> Deserialize<'de> for WpApiDetailsAuthenticationMap {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer
.deserialize_any(wp_serde_helper::DeserializeEmptyVecOrT::<
HashMap<String, WpRestApiAuthenticationScheme>,
>::new(Box::new(HashMap::new)))
.map(Self)
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WpApiDetailsAuthenticationMap(
#[serde(deserialize_with = "deserialize_empty_array_or_hashmap")]
HashMap<String, WpRestApiAuthenticationScheme>,
);

impl WpApiDetailsAuthenticationMap {
pub fn has_application_passwords_authentication_url(&self) -> bool {
Expand Down
74 changes: 74 additions & 0 deletions wp_serde_helper/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,57 @@ where
deserializer.deserialize_any(DeserializeFalseOrStringVisitor)
}

struct DeserializeEmptyArrayOrHashMapVisitor<K, V>(PhantomData<(K, V)>);

impl<'de, K, V> de::Visitor<'de> for DeserializeEmptyArrayOrHashMapVisitor<K, V>
where
K: DeserializeOwned + std::hash::Hash + Eq,
V: DeserializeOwned,
{
type Value = std::collections::HashMap<K, V>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("empty array or a HashMap")
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
if seq.next_element::<Self::Value>()?.is_none() {
// It's an empty array
Ok(std::collections::HashMap::new())
} else {
// not an empty array
Err(serde::de::Error::invalid_type(Unexpected::Seq, &self))
}
}

fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
std::collections::HashMap::deserialize(de::value::MapAccessDeserializer::new(map))
}
}

pub fn deserialize_empty_array_or_hashmap<'de, D, K, V>(
deserializer: D,
) -> Result<std::collections::HashMap<K, V>, D::Error>
where
D: Deserializer<'de>,
K: DeserializeOwned + std::hash::Hash + Eq,
V: DeserializeOwned,
{
deserializer.deserialize_any(DeserializeEmptyArrayOrHashMapVisitor::<K, V>(PhantomData))
}

#[cfg(test)]
mod tests {
use super::*;
use rstest::*;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Debug, Deserialize)]
pub struct Foo {
Expand Down Expand Up @@ -314,4 +360,32 @@ mod tests {
serde_json::from_str(test_case).expect("Test case should be a valid JSON");
assert_eq!(expected_result, string_or_bool.value);
}

#[derive(Debug, Deserialize)]
pub struct HashMapWrapper {
#[serde(deserialize_with = "deserialize_empty_array_or_hashmap")]
pub map: HashMap<String, String>,
}

#[rstest]
#[case(r#"{"map": []}"#, HashMap::new())]
#[case(r#"{"map": {"key": "value"}}"#, {
let mut map = HashMap::new();
map.insert("key".to_string(), "value".to_string());
map
})]
#[case(r#"{"map": {"foo": "bar", "hello": "world"}}"#, {
let mut map = HashMap::new();
map.insert("foo".to_string(), "bar".to_string());
map.insert("hello".to_string(), "world".to_string());
map
})]
fn test_deserialize_empty_array_or_hashmap(
#[case] test_case: &str,
#[case] expected_result: HashMap<String, String>,
) {
let wrapper: HashMapWrapper =
serde_json::from_str(test_case).expect("Test case should be a valid JSON");
assert_eq!(expected_result, wrapper.map);
}
}