diff --git a/Cargo.toml b/Cargo.toml index 4d56637..70bd4c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,10 +16,8 @@ crate-type = ["rlib", "cdylib"] pyo3 = { version = "0.27.1", features = ["generate-import-lib", "anyhow"] } python3-dll-a = "0.2.14" anyhow = "1.0.100" -libipld = { version = "0.16.0", features = ["dag-cbor"] } -multibase = "0.9.2" -byteorder = "1.5.0" -multihash = "0.18.1" +cid = "0.11.1" +cbor4ii = { version = "1.2.0", features = ["use_alloc"] } [workspace] members = [ "profiling" ] @@ -30,3 +28,6 @@ debug = false incremental = false lto = true opt-level = 3 + +[patch.crates-io] +cbor4ii = { git = "https://github.com/quininer/cbor4ii" } diff --git a/pytests/test_dag_cbor.py b/pytests/test_dag_cbor.py index c48c2c6..d2b5046 100644 --- a/pytests/test_dag_cbor.py +++ b/pytests/test_dag_cbor.py @@ -165,7 +165,7 @@ def test_dag_cbor_decode_invalid_utf8() -> None: libipld.decode_dag_cbor(bytes.fromhex('62c328')) - assert 'Invalid UTF-8 string' in str(exc_info.value) + assert 'utf-8' in str(exc_info.value) def test_dab_cbor_decode_map_int_key() -> None: diff --git a/src/lib.rs b/src/lib.rs index 4ee7406..dde1938 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,77 @@ -use std::io::{BufReader, BufWriter, Cursor, Read, Seek, Write}; -use std::os::raw::c_char; +use std::io::{self, Write}; -use ::libipld::cbor::error::{LengthOutOfRange, NumberOutOfRange, UnknownTag}; -use ::libipld::cbor::{cbor, cbor::MajorKind, decode, encode}; -use ::libipld::cid::{Cid, Error as CidError, Result as CidResult, Version}; use anyhow::{anyhow, Result}; -use byteorder::{BigEndian, ByteOrder}; -use multihash::Multihash; -use pyo3::{ffi, prelude::*, types::*, BoundObject, Python}; +use cbor4ii::core::{ + dec::{self, Decode, Read}, + enc::{self, Encode}, + major, types, +}; +use cid::{multibase, Cid}; use pyo3::pybacked::PyBackedStr; +use pyo3::{ffi, prelude::*, types::*, BoundObject, Python}; + +// Copy from cbor4ii/src/core.rs. +mod marker { + pub const FALSE: u8 = 0xf4; // simple(20) + pub const TRUE: u8 = 0xf5; // simple(21) + pub const NULL: u8 = 0xf6; // simple(22) + pub const F32: u8 = 0xfa; + pub const F64: u8 = 0xfb; +} + +struct BufWriter(io::BufWriter); + +impl BufWriter { + pub fn new(inner: W) -> Self { + BufWriter(io::BufWriter::new(inner)) + } + + pub fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } + + pub fn get_ref(&self) -> &W { + self.0.get_ref() + } +} + +impl enc::Write for BufWriter { + type Error = io::Error; + + #[inline] + fn push(&mut self, input: &[u8]) -> Result<(), Self::Error> { + self.0.write_all(input)?; + Ok(()) + } +} + +// Based on cbor4ii/src/utils.rs. +/// An in-memory reader. +struct SliceReader<'a> { + buf: &'a [u8], +} + +impl SliceReader<'_> { + fn new(buf: &[u8]) -> SliceReader<'_> { + SliceReader { buf } + } +} + +impl<'de> dec::Read<'de> for SliceReader<'de> { + type Error = core::convert::Infallible; + + #[inline] + fn fill<'b>(&'b mut self, want: usize) -> Result, Self::Error> { + let len = core::cmp::min(self.buf.len(), want); + Ok(dec::Reference::Long(&self.buf[..len])) + } + + #[inline] + fn advance(&mut self, n: usize) { + let len = core::cmp::min(self.buf.len(), n); + self.buf = &self.buf[len..]; + } +} fn cid_hash_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> Bound<'py, PyDict> { let hash = cid.hash(); @@ -35,11 +98,7 @@ fn cid_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> Bound<'py, PyDict> { dict_obj } -fn decode_len(len: u64) -> Result { - Ok(usize::try_from(len).map_err(|_| LengthOutOfRange::new::())?) -} - -fn map_key_cmp(a: &Vec, b: &Vec) -> std::cmp::Ordering { +fn map_key_cmp(a: &[u8], b: &[u8]) -> std::cmp::Ordering { /* The keys in every map must be sorted length-first by the byte representation of the string keys, where: - If two keys have different lengths, the shorter one sorts earlier; - If two keys have the same length, the one with the lower value in (byte-wise) lexical order sorts earlier. @@ -102,71 +161,84 @@ fn get_bytes_from_py_any<'py>(obj: &'py Bound<'py, PyAny>) -> PyResult<&'py [u8] } } -fn string_new_bound<'py>(py: Python<'py>, s: &[u8]) -> Result> { - std::str::from_utf8(s).map_err(|e| anyhow!("Invalid UTF-8 string: {}", e))?; - - let ptr = s.as_ptr() as *const c_char; - let len = s.len() as ffi::Py_ssize_t; - unsafe { - Ok(Bound::from_owned_ptr(py, ffi::PyUnicode_FromStringAndSize(ptr, len)).cast_into_unchecked()) - } +// Based on cbor4ii code. +fn peek_one<'de, R: dec::Read<'de>>(r: &mut R) -> Result +where + R::Error: Send + Sync, +{ + r.fill(1)? + .as_ref() + .first() + .copied() + .ok_or_else(|| anyhow!("end of data")) } -fn decode_dag_cbor_to_pyobject( +fn decode_dag_cbor_to_pyobject<'de, R: dec::Read<'de>>( py: Python, r: &mut R, depth: usize, -) -> Result> { +) -> Result> +where + R::Error: Send + Sync, +{ unsafe { if depth > ffi::Py_GetRecursionLimit() as usize { PyErr::new::( "RecursionError: maximum recursion depth exceeded in DAG-CBOR decoding", - ).restore(py); + ) + .restore(py); return Err(anyhow!("Maximum recursion depth exceeded")); } } - let major = decode::read_major(r)?; - Ok(match major.kind() { - MajorKind::UnsignedInt => decode::read_uint(r, major)?.into_pyobject(py)?.into(), - MajorKind::NegativeInt => (-1 - decode::read_uint(r, major)? as i128).into_pyobject(py)?.into(), - MajorKind::ByteString => { - let len = decode::read_uint(r, major)?; - PyBytes::new(py, &decode::read_bytes(r, len)?).into_pyobject(py)?.into() - } - MajorKind::TextString => { - let len = decode::read_uint(r, major)?; - string_new_bound(py, &decode::read_bytes(r, len)?)?.into_pyobject(py)?.into() + let byte = peek_one(r)?; + return Ok(match dec::if_major(byte) { + major::UNSIGNED => u64::decode(r)?.into_pyobject(py)?.into(), + major::NEGATIVE => i128::decode(r)?.into_pyobject(py)?.into(), + major::BYTES => PyBytes::new(py, >::decode(r)?.0) + .into_pyobject(py)? + .into(), + major::STRING => { + // The UTF-8 validation is done when it's converted into a Python string + PyString::from_bytes( + py, + >::decode(r) + .map_err(|_| anyhow!("Cannot decode as bytes"))? + .0, + )? + .into() } - MajorKind::Array => { - let len: ffi::Py_ssize_t = decode_len(decode::read_uint(r, major)?)?.try_into()?; + major::ARRAY => { + let len: ffi::Py_ssize_t = + types::Array::len(r)?.expect("contains length").try_into()?; unsafe { let ptr = ffi::PyList_New(len); for i in 0..len { - ffi::PyList_SET_ITEM(ptr, i, decode_dag_cbor_to_pyobject(py, r, depth + 1)?.into_ptr()); + ffi::PyList_SET_ITEM( + ptr, + i, + decode_dag_cbor_to_pyobject(py, r, depth + 1)?.into_ptr(), + ); } let list: Bound<'_, PyList> = Bound::from_owned_ptr(py, ptr).cast_into_unchecked(); list.into_pyobject(py)?.into() } } - MajorKind::Map => { - let len = decode_len(decode::read_uint(r, major)?)?; + major::MAP => { + let len = types::Map::len(r)?.expect("contains length"); let dict = PyDict::new(py); - let mut prev_key: Option> = None; + let mut prev_key: Option<&[u8]> = None; for _ in 0..len { - // DAG-CBOR keys are always strings - let key_major = decode::read_major(r)?; - if key_major.kind() != MajorKind::TextString { - return Err(anyhow!("Map keys must be strings")); - } - - let key_len = decode::read_uint(r, key_major)?; - let key = decode::read_bytes(r, key_len)?; + // DAG-CBOR keys are always strings. Python does the UTF-8 validation when creating + // the string. + let key = >::decode(r) + .map_err(|_| anyhow!("Map keys must be strings"))? + .0; if let Some(prev_key) = prev_key { // it cares about duplicated keys too thanks to Ordering::Equal @@ -175,7 +247,7 @@ fn decode_dag_cbor_to_pyobject( } } - let key_py = string_new_bound(py, key.as_slice())?.into_pyobject(py)?; + let key_py = PyString::from_bytes(py, key)?; prev_key = Some(key); let value_py = decode_dag_cbor_to_pyobject(py, r, depth + 1)?; @@ -184,45 +256,67 @@ fn decode_dag_cbor_to_pyobject( dict.into_pyobject(py)?.into() } - MajorKind::Tag => { - let value = decode::read_uint(r, major)?; + major::TAG => { + let value = types::Tag::tag(r)?; if value != 42 { return Err(anyhow!("Non-42 tags are not supported")); } - // FIXME(MarshalX): to_bytes allocates - let cid = decode::read_link(r)?.to_bytes(); + let cid = >::decode(r)?.0; + + // Parse the CID for validation. They have a zero byte at the front, strip it off. + if let Err(_) = Cid::try_from(&cid[1..]) { + return Err(anyhow!("Invalid CID")); + } + PyBytes::new(py, &cid).into_pyobject(py)?.into() } - MajorKind::Other => match major { + major::SIMPLE => match byte { // FIXME(MarshalX): should be more clear for bool? - cbor::FALSE => false.into_pyobject(py)?.into_any().unbind(), - cbor::TRUE => true.into_pyobject(py)?.into_any().unbind(), - cbor::NULL => py.None(), - cbor::F32 => { - let value = decode::read_f32(r)?; + marker::FALSE => { + r.advance(1); + false.into_pyobject(py)?.into_any().unbind() + } + marker::TRUE => { + r.advance(1); + true.into_pyobject(py)?.into_any().unbind() + } + marker::NULL => { + r.advance(1); + py.None() + } + marker::F32 => { + let value = f32::decode(r)?; if !value.is_finite() { - return Err(anyhow!("Number out of range for f32 (NaNs are forbidden)".to_string())); + return Err(anyhow!( + "Number out of range for f32 (NaNs are forbidden)".to_string() + )); } value.into_pyobject(py)?.into() - }, - cbor::F64 => { - let value = decode::read_f64(r)?; + } + marker::F64 => { + let value = f64::decode(r)?; if !value.is_finite() { - return Err(anyhow!("Number out of range for f64 (NaNs are forbidden)".to_string())); + return Err(anyhow!( + "Number out of range for f64 (NaNs are forbidden)".to_string() + )); } value.into_pyobject(py)?.into() - }, + } _ => return Err(anyhow!("Unsupported major type".to_string())), }, - }) + _ => return Err(anyhow!("Invalid major type".to_string())), + }); } -fn encode_dag_cbor_from_pyobject<'py, W: Write>( +fn encode_dag_cbor_from_pyobject<'py, W: enc::Write>( _py: Python<'py>, obj: &Bound<'py, PyAny>, w: &mut W, -) -> Result<()> { +) -> Result<()> +where + W::Error: Send + Sync, +{ /* Order is important for performance! Fast checks go first: @@ -238,16 +332,11 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>( */ if obj.is_none() { - encode::write_null(w)?; + types::Null.encode(w)?; Ok(()) } else if obj.is_instance_of::() { - let buf = if obj.is_truthy()? { - [cbor::TRUE.into()] - } else { - [cbor::FALSE.into()] - }; - w.write_all(&buf)?; + obj.is_truthy()?.encode(w)?; Ok(()) } else if obj.is_instance_of::() { @@ -258,20 +347,20 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>( return Err(anyhow!("Number out of range")); } - encode::write_u64(w, MajorKind::NegativeInt, -(i + 1) as u64)? + types::Negative(-(i + 1) as u64).encode(w)?; } else { if i > u64::MAX as i128 { return Err(anyhow!("Number out of range")); } - encode::write_u64(w, MajorKind::UnsignedInt, i as u64)? + (i as u64).encode(w)?; } Ok(()) } else if let Ok(l) = obj.cast::() { let len = l.len(); - encode::write_u64(w, MajorKind::Array, len as u64)?; + types::Array::bounded(len, w)?; for i in 0..len { encode_dag_cbor_from_pyobject(_py, &l.get_item(i)?, w)?; @@ -283,13 +372,12 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>( let keys = sort_map_keys(&map.keys(), len)?; let values = map.values(); - encode::write_u64(w, MajorKind::Map, len as u64)?; + types::Map::bounded(len, w)?; for (key, i) in keys { - let key_buf = key.as_bytes(); - encode::write_u64(w, MajorKind::TextString, key_buf.len() as u64)?; - w.write_all(key_buf)?; - + key.get(..) + .expect("whole range is a valid string") + .encode(w)?; encode_dag_cbor_from_pyobject(_py, &values.get_item(i)?, w)?; } @@ -297,48 +385,34 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>( } else if let Ok(f) = obj.cast::() { let v = f.value(); if !v.is_finite() { - return Err(NumberOutOfRange::new::().into()); + return Err(anyhow!("Number out of range")); } - let mut buf = [0xfb, 0, 0, 0, 0, 0, 0, 0, 0]; - BigEndian::write_f64(&mut buf[1..], v); - w.write_all(&buf)?; + v.encode(w)?; Ok(()) } else if let Ok(b) = obj.cast::() { // FIXME (MarshalX): it's not efficient to try to parse it as CID let cid = Cid::try_from(b.as_bytes()); if cid.is_ok() { - let buf = b.as_bytes(); - let len = buf.len(); - - encode::write_tag(w, 42)?; - encode::write_u64(w, MajorKind::ByteString, len as u64 + 1)?; - w.write_all(&[0])?; - w.write_all(&buf[..len])?; + types::Tag(42, b.as_bytes()).encode(w)?; } else { - let l: u64 = b.len()? as u64; - - encode::write_u64(w, MajorKind::ByteString, l)?; - w.write_all(b.as_bytes())?; + types::Bytes(b.as_bytes()).encode(w)?; } Ok(()) } else if let Ok(s) = obj.cast::() { - let buf = s.to_str()?.as_bytes(); - - encode::write_u64(w, MajorKind::TextString, buf.len() as u64)?; - w.write_all(buf)?; + s.to_str()?.encode(w)?; Ok(()) } else { - Err(UnknownTag(0).into()) + Err(anyhow!("Unknown tag")) } } #[pyfunction] fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult> { - let mut reader = BufReader::new(Cursor::new(data)); + let mut reader = SliceReader::new(data); let decoded_parts = PyList::empty(py); loop { @@ -354,56 +428,32 @@ fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult(r: &mut R) -> Result { - let mut result = 0; +fn read_u64_leb128<'de, R: dec::Read<'de>>(r: &mut R) -> Result +where + R::Error: Send + Sync, +{ + let mut result: u64 = 0; let mut shift = 0; loop { - let mut buf = [0]; - if r.read_exact(&mut buf).is_err() { - return Err(anyhow!("Unexpected EOF while reading ULEB128 number.")); - } + let byte = + peek_one(r).map_err(|_| anyhow!("Unexpected EOF while reading ULEB128 number."))?; + r.advance(1); - let byte = buf[0] as u64; if (byte & 0x80) == 0 { - result |= (byte) << shift; + result |= (byte as u64) << shift; return Ok(result); } else { - result |= (byte & 0x7F) << shift; + result |= (byte as u64 & 0x7F) << shift; } shift += 7; } } -fn read_cid_from_bytes(r: &mut R) -> CidResult { - let Ok(version) = read_u64_leb128(r) else { - return Err(CidError::VarIntDecodeError); - }; - let Ok(codec) = read_u64_leb128(r) else { - return Err(CidError::VarIntDecodeError); - }; - - if [version, codec] == [0x12, 0x20] { - let mut digest = [0u8; 32]; - r.read_exact(&mut digest)?; - let mh = Multihash::wrap(version, &digest).expect("Digest is always 32 bytes."); - return Cid::new_v0(mh); - } - - let version = Version::try_from(version)?; - match version { - Version::V0 => Err(CidError::InvalidCidVersion), - Version::V1 => { - let mh = Multihash::read(r)?; - Cid::new(version, codec, mh) - } - } -} - #[pyfunction] pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(Py, Bound<'py, PyDict>)> { - let buf = &mut BufReader::new(Cursor::new(data)); + let buf = &mut SliceReader::new(data); if read_u64_leb128(buf).is_err() { return Err(get_err( @@ -456,7 +506,7 @@ pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(Py, Bou break; } - let cid_result = read_cid_from_bytes(buf); + let cid_result = Cid::read_bytes(&mut buf.buf); let Ok(cid) = cid_result else { return Err(get_err( "Failed to read CID of block", @@ -489,18 +539,23 @@ pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(Py, Bou #[pyfunction] pub fn decode_dag_cbor(py: Python, data: &[u8]) -> PyResult> { - let mut reader = BufReader::new(Cursor::new(data)); + let mut reader = SliceReader::new(data); let py_object = decode_dag_cbor_to_pyobject(py, &mut reader, 0); if let Ok(py_object) = py_object { // check for any remaining data in the reader - let mut buf = [0u8; 1]; - match reader.read(&mut buf) { - Ok(0) => Ok(py_object), // EOF - Err(_) => Ok(py_object), // EOF - Ok(_) => Err(get_err( - "Failed to decode DAG-CBOR", - "Invalid DAG-CBOR: contains multiple objects (CBOR sequence)".to_string() - )), + if reader + .fill(1) + .expect("SliceReader never fails") + .as_ref() + .len() + == 0 + { + Ok(py_object) + } else { + Err(get_err( + "Failed to decode DAG-CBOR", + "Invalid DAG-CBOR: contains multiple objects (CBOR sequence)".to_string(), + )) } } else { let err = get_err( @@ -536,12 +591,11 @@ pub fn encode_dag_cbor<'py>( } fn get_cid_from_py_any(data: &Bound) -> PyResult { - let cid: CidResult; - if let Ok(s) = data.cast::() { - cid = Cid::try_from(s.to_str()?); + let cid = if let Ok(s) = data.cast::() { + Cid::try_from(s.to_str()?) } else { - cid = Cid::try_from(get_bytes_from_py_any(data)?); - } + Cid::try_from(get_bytes_from_py_any(data)?) + }; if let Ok(cid) = cid { Ok(cid) @@ -560,7 +614,10 @@ fn decode_cid<'py>(py: Python<'py>, data: &Bound) -> PyResult(py: Python<'py>, data: &Bound) -> PyResult> { - Ok(PyString::new(py, get_cid_from_py_any(data)?.to_string().as_str())) + Ok(PyString::new( + py, + get_cid_from_py_any(data)?.to_string().as_str(), + )) } #[pyfunction]