From 7c2277917195d65b5d5e83ac8e8d66c5b056aa21 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Wed, 15 May 2024 00:35:19 +0200 Subject: [PATCH] Add support for general deserialization options and namespaces in particular. --- Cargo.toml | 2 +- src/lib.rs | 302 +++++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 257 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bb4fd28..35e61b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" repository = "https://github.com/adamreichold/serde-roxmltree" documentation = "https://docs.rs/serde-roxmltree" readme = "README.md" -version = "0.7.0" +version = "0.7.1" edition = "2021" [dependencies] diff --git a/src/lib.rs b/src/lib.rs index 31fab3e..b76a68c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,6 +119,42 @@ //! # //! # Ok::<(), Box>(()) //! ``` +//! +//! Support for [namespaces][namespaces] can be enabled via the [`namespaces`][Options::namespaces] option: +//! +//! ``` +//! use serde::Deserialize; +//! use serde_roxmltree::{defaults, from_str, Options}; +//! +//! let text = r#" +//! 23 +//! 42 +//! "#; +//! +//! #[derive(Deserialize)] +//! struct SomeRecord { +//! qux: Vec, +//! } +//! +//! let record = from_str::(text)?; +//! assert_eq!(record.qux, [23, 42]); +//! +//! #[derive(Deserialize)] +//! struct AnotherRecord { +//! #[serde(rename = "{http://foo}qux")] +//! some_qux: i32, +//! #[serde(rename = "{http://bar}qux")] +//! another_qux: i32, +//! } +//! +//! let record = defaults().namespaces().from_str::(text)?; +//! assert_eq!(record.some_qux, 23); +//! assert_eq!(record.another_qux, 42); +//! # +//! # Ok::<(), Box>(()) +//! ``` +//! +//! [namespaces]: https://www.w3.org/TR/REC-xml-names/ #![forbid(unsafe_code)] #![deny( missing_docs, @@ -130,6 +166,7 @@ use std::collections::HashSet; use std::error::Error as StdError; use std::fmt; use std::iter::{once, Peekable}; +use std::marker::PhantomData; use std::num::{ParseFloatError, ParseIntError}; use std::str::{FromStr, ParseBoolError}; @@ -143,8 +180,7 @@ pub fn from_str(text: &str) -> Result where T: de::DeserializeOwned, { - let document = Document::parse(text).map_err(Error::ParseXml)?; - from_doc(&document) + defaults().from_str(text) } /// Deserialize an instance of type `T` from a [`roxmltree::Document`] @@ -152,7 +188,7 @@ pub fn from_doc<'de, 'input, T>(document: &'de Document<'input>) -> Result, { - from_node(document.root_element()) + defaults().from_doc(document) } /// Deserialize an instance of type `T` from a [`roxmltree::Node`] @@ -160,34 +196,146 @@ pub fn from_node<'de, 'input, T>(node: Node<'de, 'input>) -> Result where T: de::Deserialize<'de>, { - let deserializer = Deserializer { - source: Source::Node(node), - visited: &mut HashSet::new(), - }; - T::deserialize(deserializer) + defaults().from_node(node) +} + +/// Types that represent a set of options +/// +/// Provides methods to deserialize values using the given options +/// as well as to change individual options in the set. +pub trait Options: Sized { + /// Deserialize an instance of type `T` directly from XML text using the given options + #[allow(clippy::wrong_self_convention)] + fn from_str(self, text: &str) -> Result + where + T: de::DeserializeOwned, + { + let document = Document::parse(text).map_err(Error::ParseXml)?; + self.from_doc(&document) + } + + /// Deserialize an instance of type `T` from a [`roxmltree::Document`] using the given options + #[allow(clippy::wrong_self_convention)] + fn from_doc<'de, 'input, T>(self, document: &'de Document<'input>) -> Result + where + T: de::Deserialize<'de>, + { + let node = document.root_element(); + self.from_node(node) + } + + /// Deserialize an instance of type `T` from a [`roxmltree::Node`] using the given options + #[allow(clippy::wrong_self_convention)] + fn from_node<'de, 'input, T>(self, node: Node<'de, 'input>) -> Result + where + T: de::Deserialize<'de>, + { + let deserializer = Deserializer { + source: Source::Node(node), + temp: &mut Temp::default(), + options: PhantomData::, + }; + T::deserialize(deserializer) + } + + /// Include namespaces when building identifiers + /// + /// When tags or attributes are part of a namespace, + /// their identifiers will have the form `{namespace}name`. + fn namespaces(self) -> Namespaces { + Namespaces(PhantomData) + } + + #[doc(hidden)] + #[allow(private_interfaces)] + fn name<'de, 'input, 'temp>( + source: &Source<'de, 'input>, + _buffer: &'temp mut String, + ) -> &'temp str + where + 'input: 'temp, + { + match source { + Source::Node(node) => node.tag_name().name(), + Source::Attribute(attr) => attr.name(), + Source::Text(_) => "$text", + } + } } -struct Deserializer<'de, 'input, 'tmp> { +#[doc(hidden)] +#[derive(Clone, Copy, Default, Debug)] +pub struct Defaults; + +/// The default set of options +pub fn defaults() -> Defaults { + Defaults +} + +impl Options for Defaults {} + +#[doc(hidden)] +#[derive(Clone, Copy, Default, Debug)] +pub struct Namespaces(PhantomData); + +impl Options for Namespaces { + #[allow(private_interfaces)] + fn name<'de, 'input, 'temp>( + source: &Source<'de, 'input>, + buffer: &'temp mut String, + ) -> &'temp str + where + 'input: 'temp, + { + fn inner<'temp>( + namespace: Option<&str>, + name: &str, + buffer: &'temp mut String, + ) -> &'temp str { + buffer.clear(); + + if let Some(namespace) = namespace { + buffer.push('{'); + buffer.push_str(namespace); + buffer.push('}'); + } + + buffer.push_str(name); + + &*buffer + } + + match source { + Source::Node(node) => { + let tag_name = node.tag_name(); + inner(tag_name.namespace(), tag_name.name(), buffer) + } + Source::Attribute(attr) => inner(attr.namespace(), attr.name(), buffer), + Source::Text(_) => "$text", + } + } +} + +struct Deserializer<'de, 'input, 'temp, O> { source: Source<'de, 'input>, - visited: &'tmp mut HashSet, + temp: &'temp mut Temp, + options: PhantomData, } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy)] enum Source<'de, 'input> { Node(Node<'de, 'input>), Attribute(Attribute<'de, 'input>), Text(&'de str), } -impl<'de, 'input, 'tmp> Deserializer<'de, 'input, 'tmp> { - fn name(&self) -> &'de str { - match &self.source { - Source::Node(node) => node.tag_name().name(), - Source::Attribute(attr) => attr.name(), - Source::Text(_) => "$text", - } - } +#[derive(Default)] +struct Temp { + visited: HashSet, + buffer: String, +} +impl<'de, 'input, 'temp, O> Deserializer<'de, 'input, 'temp, O> { fn node(&self) -> Result<&Node<'de, 'input>, Error> { match &self.source { Source::Node(node) => Ok(node), @@ -198,12 +346,16 @@ impl<'de, 'input, 'tmp> Deserializer<'de, 'input, 'tmp> { fn children_and_attributes(&self) -> Result>, Error> { let node = self.node()?; - Ok(node + let children = node .children() .filter(|node| node.is_element()) - .map(Source::Node) - .chain(node.attributes().map(Source::Attribute)) - .chain(once(node.text().unwrap_or_default()).map(Source::Text))) + .map(Source::Node); + + let attributes = node.attributes().map(Source::Attribute); + + let text = once(Source::Text(node.text().unwrap_or_default())); + + Ok(children.chain(attributes).chain(text)) } fn siblings(&self) -> Result>, Error> { @@ -232,7 +384,10 @@ impl<'de, 'input, 'tmp> Deserializer<'de, 'input, 'tmp> { } } -impl<'de, 'input, 'tmp> de::Deserializer<'de> for Deserializer<'de, 'input, 'tmp> { +impl<'de, 'input, 'temp, O> de::Deserializer<'de> for Deserializer<'de, 'input, 'temp, O> +where + O: Options, +{ type Error = Error; fn deserialize_bool(self, visitor: V) -> Result @@ -389,7 +544,8 @@ impl<'de, 'input, 'tmp> de::Deserializer<'de> for Deserializer<'de, 'input, 'tmp { visitor.visit_seq(SeqAccess { source: self.siblings()?, - visited: self.visited, + temp: self.temp, + options: PhantomData::, }) } @@ -418,7 +574,8 @@ impl<'de, 'input, 'tmp> de::Deserializer<'de> for Deserializer<'de, 'input, 'tmp { visitor.visit_map(MapAccess { source: self.children_and_attributes()?.peekable(), - visited: self.visited, + temp: self.temp, + options: PhantomData::, }) } @@ -445,7 +602,8 @@ impl<'de, 'input, 'tmp> de::Deserializer<'de> for Deserializer<'de, 'input, 'tmp { visitor.visit_enum(EnumAccess { source: self.children_and_attributes()?, - visited: self.visited, + temp: self.temp, + options: PhantomData::, }) } @@ -453,7 +611,7 @@ impl<'de, 'input, 'tmp> de::Deserializer<'de> for Deserializer<'de, 'input, 'tmp where V: de::Visitor<'de>, { - visitor.visit_borrowed_str(self.name()) + visitor.visit_str(O::name(&self.source, &mut self.temp.buffer)) } fn deserialize_any(self, _visitor: V) -> Result @@ -471,17 +629,19 @@ impl<'de, 'input, 'tmp> de::Deserializer<'de> for Deserializer<'de, 'input, 'tmp } } -struct SeqAccess<'de, 'tmp, I> +struct SeqAccess<'de, 'temp, I, O> where I: Iterator>, { source: I, - visited: &'tmp mut HashSet, + temp: &'temp mut Temp, + options: PhantomData, } -impl<'de, 'tmp, I> de::SeqAccess<'de> for SeqAccess<'de, 'tmp, I> +impl<'de, 'temp, I, O> de::SeqAccess<'de> for SeqAccess<'de, 'temp, I, O> where I: Iterator>, + O: Options, { type Error = Error; @@ -492,11 +652,12 @@ where match self.source.next() { None => Ok(None), Some(node) => { - self.visited.insert(node.id()); + self.temp.visited.insert(node.id()); let deserializer = Deserializer { source: Source::Node(node), - visited: &mut *self.visited, + temp: &mut *self.temp, + options: PhantomData::, }; seed.deserialize(deserializer).map(Some) } @@ -504,17 +665,19 @@ where } } -struct MapAccess<'de, 'input: 'de, 'tmp, I> +struct MapAccess<'de, 'input: 'de, 'temp, I, O> where I: Iterator>, { source: Peekable, - visited: &'tmp mut HashSet, + temp: &'temp mut Temp, + options: PhantomData, } -impl<'de, 'input, 'tmp, I> de::MapAccess<'de> for MapAccess<'de, 'input, 'tmp, I> +impl<'de, 'input, 'temp, I, O> de::MapAccess<'de> for MapAccess<'de, 'input, 'temp, I, O> where I: Iterator>, + O: Options, { type Error = Error; @@ -527,7 +690,7 @@ where None => return Ok(None), Some(source) => { if let Source::Node(node) = source { - if self.visited.contains(&node.id()) { + if self.temp.visited.contains(&node.id()) { self.source.next().unwrap(); continue; } @@ -535,7 +698,8 @@ where let deserailizer = Deserializer { source: *source, - visited: &mut *self.visited, + temp: &mut *self.temp, + options: PhantomData::, }; return seed.deserialize(deserailizer).map(Some); } @@ -551,26 +715,29 @@ where let deserializer = Deserializer { source, - visited: &mut *self.visited, + temp: &mut *self.temp, + options: PhantomData::, }; seed.deserialize(deserializer) } } -struct EnumAccess<'de, 'input: 'de, 'tmp, I> +struct EnumAccess<'de, 'input: 'de, 'temp, I, O> where I: Iterator>, { source: I, - visited: &'tmp mut HashSet, + temp: &'temp mut Temp, + options: PhantomData, } -impl<'de, 'input, 'tmp, I> de::EnumAccess<'de> for EnumAccess<'de, 'input, 'tmp, I> +impl<'de, 'input, 'temp, I, O> de::EnumAccess<'de> for EnumAccess<'de, 'input, 'temp, I, O> where I: Iterator>, + O: Options, { type Error = Error; - type Variant = Deserializer<'de, 'input, 'tmp>; + type Variant = Deserializer<'de, 'input, 'temp, O>; fn variant_seed(mut self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> where @@ -580,19 +747,24 @@ where let deserializer = Deserializer { source, - visited: &mut *self.visited, + temp: &mut *self.temp, + options: PhantomData::, }; let value = seed.deserialize(deserializer)?; let deserializer = Deserializer { source, - visited: &mut *self.visited, + temp: &mut *self.temp, + options: PhantomData::, }; Ok((value, deserializer)) } } -impl<'de, 'input, 'tmp> de::VariantAccess<'de> for Deserializer<'de, 'input, 'tmp> { +impl<'de, 'input, 'temp, O> de::VariantAccess<'de> for Deserializer<'de, 'input, 'temp, O> +where + O: Options, +{ type Error = Error; fn unit_variant(self) -> Result<(), Self::Error> { @@ -895,4 +1067,42 @@ mod tests { from_str::(r#"foobar"#).unwrap(); } + + #[test] + fn children_with_namespaces() { + #[derive(Deserialize)] + struct Root { + #[serde(rename = "{http://name.space}child")] + child: u64, + } + + let val = Defaults + .namespaces() + .from_str::(r#"42"#) + .unwrap(); + assert_eq!(val.child, 42); + + let val = Defaults + .namespaces() + .from_str::(r#"42"#) + .unwrap(); + assert_eq!(val.child, 42); + } + + #[test] + fn attributes_with_namespaces() { + #[derive(Deserialize)] + struct Root { + #[serde(rename = "{http://name.space}attr")] + attr: i32, + } + + let val = Defaults + .namespaces() + .from_str::( + r#""#, + ) + .unwrap(); + assert_eq!(val.attr, 23); + } }