From a067193fe18c93d88324b0a9bd426af5ffe3d256 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 10 Sep 2024 10:38:43 -0700 Subject: [PATCH 1/2] Return the reference to the buffer during deserialization. --- serdect/src/array.rs | 2 +- serdect/src/common.rs | 12 +++++------- serdect/src/slice.rs | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/serdect/src/array.rs b/serdect/src/array.rs index 551f05611..97b963fe1 100644 --- a/serdect/src/array.rs +++ b/serdect/src/array.rs @@ -57,7 +57,7 @@ impl LengthCheck for ExactLength { /// Deserialize from hex when using human-readable formats or binary if the /// format is binary. Fails if the `buffer` isn't the exact same size as the /// resulting array. -pub fn deserialize_hex_or_bin<'de, D>(buffer: &mut [u8], deserializer: D) -> Result<(), D::Error> +pub fn deserialize_hex_or_bin<'de, D>(buffer: &mut [u8], deserializer: D) -> Result<&[u8], D::Error> where D: Deserializer<'de>, { diff --git a/serdect/src/common.rs b/serdect/src/common.rs index 52eb172d5..50a02f596 100644 --- a/serdect/src/common.rs +++ b/serdect/src/common.rs @@ -74,7 +74,7 @@ pub(crate) trait LengthCheck { pub(crate) struct StrIntoBufVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData); impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> { - type Value = (); + type Value = &'b [u8]; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { T::expecting(formatter, "a string", self.0.len() * 2) @@ -88,9 +88,7 @@ impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> { return Err(Error::invalid_length(v.len(), &self)); } // TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`. - base16ct::mixed::decode(v, self.0) - .map(|_| ()) - .map_err(E::custom) + base16ct::mixed::decode(v, self.0).map_err(E::custom) } } @@ -116,7 +114,7 @@ impl<'de> Visitor<'de> for StrIntoVecVisitor { pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData); impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> { - type Value = (); + type Value = &'b [u8]; fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { T::expecting(formatter, "an array", self.0.len()) @@ -131,7 +129,7 @@ impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> { if T::length_check(self.0.len(), v.len()) { let buffer = &mut self.0[..v.len()]; buffer.copy_from_slice(v); - return Ok(()); + return Ok(buffer); } Err(E::invalid_length(v.len(), &self)) @@ -147,7 +145,7 @@ impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> { if T::length_check(self.0.len(), v.len()) { let buffer = &mut self.0[..v.len()]; buffer.swap_with_slice(&mut v); - return Ok(()); + return Ok(buffer); } Err(E::invalid_length(v.len(), &self)) diff --git a/serdect/src/slice.rs b/serdect/src/slice.rs index 9104d177a..ce0798123 100644 --- a/serdect/src/slice.rs +++ b/serdect/src/slice.rs @@ -61,7 +61,7 @@ impl LengthCheck for UpperBound { /// Deserialize from hex when using human-readable formats or binary if the /// format is binary. Fails if the `buffer` is smaller then the resulting /// slice. -pub fn deserialize_hex_or_bin<'de, D>(buffer: &mut [u8], deserializer: D) -> Result<(), D::Error> +pub fn deserialize_hex_or_bin<'de, D>(buffer: &mut [u8], deserializer: D) -> Result<&[u8], D::Error> where D: Deserializer<'de>, { From a75ff1be980c72d8235f16e0f9360de124a608a4 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 10 Sep 2024 10:52:40 -0700 Subject: [PATCH 2/2] Defer length check to base16ct --- serdect/src/common.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/serdect/src/common.rs b/serdect/src/common.rs index 50a02f596..a5228b566 100644 --- a/serdect/src/common.rs +++ b/serdect/src/common.rs @@ -2,7 +2,7 @@ use core::fmt; use core::marker::PhantomData; use serde::{ - de::{Error, Visitor}, + de::{Error, Unexpected, Visitor}, Serializer, }; @@ -84,11 +84,15 @@ impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> { where E: Error, { - if !T::length_check(self.0.len() * 2, v.len()) { - return Err(Error::invalid_length(v.len(), &self)); - } - // TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`. - base16ct::mixed::decode(v, self.0).map_err(E::custom) + base16ct::mixed::decode(v, self.0).map_err(|err| match err { + base16ct::Error::InvalidLength => { + Error::invalid_length(v.len(), &"an even number of hex digits") + } + base16ct::Error::InvalidEncoding => Error::invalid_value( + Unexpected::Other(""), + &"a sequence of hex digits (0-9,a-f,A-F)", + ), + }) } }