diff --git a/rmp-serde/src/decode.rs b/rmp-serde/src/decode.rs index f7fd39e3..f3ca458f 100644 --- a/rmp-serde/src/decode.rs +++ b/rmp-serde/src/decode.rs @@ -30,6 +30,7 @@ pub enum Error { /// Uncategorized error. Uncategorized(String), Syntax(String), + DepthLimitExceeded, } impl ::std::error::Error for Error { @@ -44,6 +45,7 @@ impl ::std::error::Error for Error { LengthMismatch(_) => None, Uncategorized(_) => None, Syntax(_) => None, + DepthLimitExceeded => None, } } } @@ -171,8 +173,23 @@ pub struct Deserializer { rd: R, buf: Vec, decoding_option: bool, + depth: usize, } +macro_rules! depth_count( + ( $counter:expr, $expr:expr ) => { + { + $counter -= 1; + if $counter == 0 { + return Err(Error::DepthLimitExceeded) + } + let res = $expr; + $counter += 1; + res + } + } +); + impl Deserializer { // TODO: Docs. pub fn new(rd: R) -> Deserializer { @@ -180,9 +197,15 @@ impl Deserializer { rd: rd, buf: Vec::new(), decoding_option: false, + depth: 1000, } } + /// Changes the maximum nesting depth that is allowed + pub fn set_max_depth(&mut self, depth: usize) { + self.depth = depth; + } + /// Gets a reference to the underlying reader in this decoder. pub fn get_ref(&self) -> &R { &self.rd @@ -209,21 +232,21 @@ impl Deserializer { fn read_array(&mut self, len: u32, mut visitor: V) -> Result where V: serde::de::Visitor { - visitor.visit_seq(SeqVisitor { + depth_count!(self.depth, visitor.visit_seq(SeqVisitor { deserializer: self, len: len, actual: len, - }) + })) } fn read_map(&mut self, len: u32, mut visitor: V) -> Result where V: serde::de::Visitor { - visitor.visit_map(MapVisitor { + depth_count!(self.depth, visitor.visit_map(MapVisitor { deserializer: self, len: len, actual: len, - }) + })) } fn read_bin_data(&mut self, len: usize, mut visitor: V) -> Result @@ -339,7 +362,7 @@ impl serde::Deserializer for Deserializer { { // Primarily try to read optimisticly. self.decoding_option = true; - let res = match visitor.visit_some(self) { + let res = match depth_count!(self.depth, visitor.visit_some(self)) { Ok(val) => Ok(val), Err(Error::TypeMismatch(Marker::Null)) => visitor.visit_none(), Err(err) => Err(err) @@ -355,7 +378,7 @@ impl serde::Deserializer for Deserializer { let len = try!(read_array_size(&mut self.rd)); match len { - 2 => visitor.visit(VariantVisitor::new(self)), + 2 => depth_count!(self.depth, visitor.visit(VariantVisitor::new(self))), n => Err(Error::LengthMismatch(n as u32)), } } diff --git a/rmp-serde/src/encode.rs b/rmp-serde/src/encode.rs index 40387e9a..f19d5a8f 100644 --- a/rmp-serde/src/encode.rs +++ b/rmp-serde/src/encode.rs @@ -28,6 +28,9 @@ pub enum Error { /// Failed to serialize struct, sequence or map, because its length is unknown. UnknownLength, + + /// Depth limit exceeded + DepthLimitExceeded } impl ::std::error::Error for Error { @@ -36,6 +39,7 @@ impl ::std::error::Error for Error { Error::InvalidFixedValueWrite(..) => "invalid fixed value write", Error::InvalidValueWrite(..) => "invalid value write", Error::UnknownLength => "attempt to serialize struct, sequence or map with unknown length", + Error::DepthLimitExceeded => "depth limit exceeded", } } @@ -44,6 +48,7 @@ impl ::std::error::Error for Error { Error::InvalidFixedValueWrite(ref err) => Some(err), Error::InvalidValueWrite(ref err) => Some(err), Error::UnknownLength => None, + Error::DepthLimitExceeded => None, } } } @@ -108,14 +113,37 @@ impl VariantWriter for StructArrayWriter { pub struct Serializer<'a, W: VariantWriter> { wr: &'a mut Write, vw: W, + depth: usize, +} + +impl<'a, W: VariantWriter> Serializer<'a, W> { + /// Changes the maximum nesting depth that is allowed + pub fn set_max_depth(&mut self, depth: usize) { + self.depth = depth; + } } +macro_rules! depth_count( + ( $counter:expr, $expr:expr ) => { + { + $counter -= 1; + if $counter == 0 { + return Err(Error::DepthLimitExceeded) + } + let res = $expr; + $counter += 1; + res + } + } +); + impl<'a> Serializer<'a, StructArrayWriter> { /// Creates a new MessagePack encoder whose output will be written to the writer specified. pub fn new(wr: &'a mut Write) -> Serializer<'a, StructArrayWriter> { Serializer { wr: wr, vw: StructArrayWriter, + depth: 1000, } } } @@ -126,6 +154,7 @@ impl<'a, W: VariantWriter> Serializer<'a, W> { Serializer { wr: wr, vw: vw, + depth: 1000, } } } @@ -245,7 +274,7 @@ impl<'a, W: VariantWriter> serde::Serializer for Serializer<'a, W> { // ... and its arguments length. try!(write_array_len(&mut self.wr, len as u32)); - while let Some(()) = try!(visitor.visit(self)) { } + while let Some(()) = try!(depth_count!(self.depth, visitor.visit(self))) { } Ok(()) } @@ -267,7 +296,7 @@ impl<'a, W: VariantWriter> serde::Serializer for Serializer<'a, W> { fn visit_some(&mut self, v: T) -> Result<(), Error> where T: serde::Serialize, { - v.serialize(self) + depth_count!(self.depth, v.serialize(self)) } // TODO: Check len, overflow is possible. @@ -281,7 +310,7 @@ impl<'a, W: VariantWriter> serde::Serializer for Serializer<'a, W> { try!(write_array_len(&mut self.wr, len as u32)); - while let Some(()) = try!(visitor.visit(self)) { } + while let Some(()) = try!(depth_count!(self.depth, visitor.visit(self))) { } Ok(()) } @@ -302,7 +331,7 @@ impl<'a, W: VariantWriter> serde::Serializer for Serializer<'a, W> { try!(write_map_len(&mut self.wr, len as u32)); - while let Some(()) = try!(visitor.visit(self)) { } + while let Some(()) = try!(depth_count!(self.depth, visitor.visit(self))) { } Ok(()) } @@ -331,7 +360,7 @@ impl<'a, W: VariantWriter> serde::Serializer for Serializer<'a, W> { try!(self.vw.write_struct_len(&mut self.wr, len as u32)); - while let Some(()) = try!(visitor.visit(self)) { } + while let Some(()) = try!(depth_count!(self.depth, visitor.visit(self))) { } Ok(()) }