diff --git a/wp_api/src/login.rs b/wp_api/src/login.rs index 70a8fd662..ccacff1f3 100644 --- a/wp_api/src/login.rs +++ b/wp_api/src/login.rs @@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize}; use std::{collections::HashMap, str, sync::Arc}; use wp_localization::{MessageBundle, WpMessages, WpSupportsLocalization}; use wp_localization_macro::WpDeriveLocalizable; -use wp_serde_helper::{deserialize_false_or_string, deserialize_offset}; +use wp_serde_helper::{ + deserialize_empty_array_or_hashmap, deserialize_false_or_string, deserialize_offset, +}; const KEY_APPLICATION_PASSWORDS: &str = "application-passwords"; @@ -225,22 +227,11 @@ impl WpSupportsLocalization for OAuthResponseUrlError { } } -#[derive(Debug, Clone, Serialize, PartialEq)] -pub struct WpApiDetailsAuthenticationMap(HashMap); - -// If the response is `[]`, default to an empty `HashMap` -impl<'de> Deserialize<'de> for WpApiDetailsAuthenticationMap { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer - .deserialize_any(wp_serde_helper::DeserializeEmptyVecOrT::< - HashMap, - >::new(Box::new(HashMap::new))) - .map(Self) - } -} +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct WpApiDetailsAuthenticationMap( + #[serde(deserialize_with = "deserialize_empty_array_or_hashmap")] + HashMap, +); impl WpApiDetailsAuthenticationMap { pub fn has_application_passwords_authentication_url(&self) -> bool { diff --git a/wp_serde_helper/src/lib.rs b/wp_serde_helper/src/lib.rs index 7f5269eee..f310cc6c1 100644 --- a/wp_serde_helper/src/lib.rs +++ b/wp_serde_helper/src/lib.rs @@ -253,11 +253,57 @@ where deserializer.deserialize_any(DeserializeFalseOrStringVisitor) } +struct DeserializeEmptyArrayOrHashMapVisitor(PhantomData<(K, V)>); + +impl<'de, K, V> de::Visitor<'de> for DeserializeEmptyArrayOrHashMapVisitor +where + K: DeserializeOwned + std::hash::Hash + Eq, + V: DeserializeOwned, +{ + type Value = std::collections::HashMap; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("empty array or a HashMap") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + if seq.next_element::()?.is_none() { + // It's an empty array + Ok(std::collections::HashMap::new()) + } else { + // not an empty array + Err(serde::de::Error::invalid_type(Unexpected::Seq, &self)) + } + } + + fn visit_map(self, map: A) -> Result + where + A: de::MapAccess<'de>, + { + std::collections::HashMap::deserialize(de::value::MapAccessDeserializer::new(map)) + } +} + +pub fn deserialize_empty_array_or_hashmap<'de, D, K, V>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, + K: DeserializeOwned + std::hash::Hash + Eq, + V: DeserializeOwned, +{ + deserializer.deserialize_any(DeserializeEmptyArrayOrHashMapVisitor::(PhantomData)) +} + #[cfg(test)] mod tests { use super::*; use rstest::*; use serde::Deserialize; + use std::collections::HashMap; #[derive(Debug, Deserialize)] pub struct Foo { @@ -314,4 +360,32 @@ mod tests { serde_json::from_str(test_case).expect("Test case should be a valid JSON"); assert_eq!(expected_result, string_or_bool.value); } + + #[derive(Debug, Deserialize)] + pub struct HashMapWrapper { + #[serde(deserialize_with = "deserialize_empty_array_or_hashmap")] + pub map: HashMap, + } + + #[rstest] + #[case(r#"{"map": []}"#, HashMap::new())] + #[case(r#"{"map": {"key": "value"}}"#, { + let mut map = HashMap::new(); + map.insert("key".to_string(), "value".to_string()); + map + })] + #[case(r#"{"map": {"foo": "bar", "hello": "world"}}"#, { + let mut map = HashMap::new(); + map.insert("foo".to_string(), "bar".to_string()); + map.insert("hello".to_string(), "world".to_string()); + map + })] + fn test_deserialize_empty_array_or_hashmap( + #[case] test_case: &str, + #[case] expected_result: HashMap, + ) { + let wrapper: HashMapWrapper = + serde_json::from_str(test_case).expect("Test case should be a valid JSON"); + assert_eq!(expected_result, wrapper.map); + } }