Skip to content

Commit

Permalink
Adds support for decoding floating-point typed arrays from RFC8746
Browse files Browse the repository at this point in the history
This adds support for decoding arrays of floating point numbers of IEEE
754 formats binary16, binary32, and binary64 in both the big- and
little-endian form.
  • Loading branch information
tgockel committed May 26, 2021
1 parent 6f8311d commit cad2aaf
Show file tree
Hide file tree
Showing 3 changed files with 258 additions and 17 deletions.
33 changes: 33 additions & 0 deletions cbor2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,33 @@ def decode_uuid(self):
from uuid import UUID
return self.set_shareable(UUID(bytes=self._decode()))

def _decode_typed_array_impl(self, name, tag, element_size, format):
"""Helper function for decoding typed arrays described by RFC 8746"""
buf = self.decode()
if not isinstance(buf, bytes):
raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf))
elif len(buf) % element_size != 0:
raise CBORDecodeValueError(
"invalid length for %s typed array -- must be multiple of %i, but is %i"
% (name, element_size, len(buf)))

out = struct.unpack(format % (len(buf) // element_size), buf)

if self._immutable:
return self.set_shareable(out)
else:
return self.set_shareable(list(out))

def _decode_typed_array_func(*args):
return lambda self: self._decode_typed_array_impl(*args)

decode_array_float16_be = _decode_typed_array_func('float16', 80, 2, '>%ie')
decode_array_float32_be = _decode_typed_array_func('float32', 81, 4, '>%if')
decode_array_float64_be = _decode_typed_array_func('float64', 82, 8, '>%id')
decode_array_float16_le = _decode_typed_array_func('float16', 84, 2, '<%ie')
decode_array_float32_le = _decode_typed_array_func('float32', 85, 4, '<%if')
decode_array_float64_le = _decode_typed_array_func('float64', 86, 8, '<%id')

def decode_set(self):
# Semantic tag 258
if self._immutable:
Expand Down Expand Up @@ -535,6 +562,12 @@ def decode_float64(self):
35: CBORDecoder.decode_regexp,
36: CBORDecoder.decode_mime,
37: CBORDecoder.decode_uuid,
80: CBORDecoder.decode_array_float16_be,
81: CBORDecoder.decode_array_float32_be,
82: CBORDecoder.decode_array_float64_be,
84: CBORDecoder.decode_array_float16_le,
85: CBORDecoder.decode_array_float32_le,
86: CBORDecoder.decode_array_float64_le,
258: CBORDecoder.decode_set,
260: CBORDecoder.decode_ipaddress,
261: CBORDecoder.decode_ipnetwork,
Expand Down
219 changes: 202 additions & 17 deletions source/decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,17 @@
#define be16toh(x) OSSwapBigToHostInt16(x)
#define be32toh(x) OSSwapBigToHostInt32(x)
#define be64toh(x) OSSwapBigToHostInt64(x)
#define le16toh(x) OSSwapLittleToHostInt16(x)
#define le32toh(x) OSSwapLittleToHostInt32(x)
#define le64toh(x) OSSwapLittleToHostInt64(x)
#elif _WIN32
// All windows platforms are (currently) little-endian so byteswap is required
#define be16toh(x) _byteswap_ushort(x)
#define be32toh(x) _byteswap_ulong(x)
#define be64toh(x) _byteswap_uint64(x)
#define le16toh(x) (x)
#define le32toh(x) (x)
#define le64toh(x) (x)
#endif

enum DecodeOption {
Expand All @@ -52,6 +58,12 @@ static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float16_be(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float32_be(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float64_be(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float16_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float32_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float64_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_mime(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_positive_bignum(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_negative_bignum(CBORDecoderObject *);
Expand Down Expand Up @@ -891,23 +903,28 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype)

if (decode_length(self, subtype, &tagnum, NULL) == 0) {
switch (tagnum) {
case 0: ret = CBORDecoder_decode_datetime_string(self); break;
case 1: ret = CBORDecoder_decode_epoch_datetime(self); break;
case 2: ret = CBORDecoder_decode_positive_bignum(self); break;
case 3: ret = CBORDecoder_decode_negative_bignum(self); break;
case 4: ret = CBORDecoder_decode_fraction(self); break;
case 5: ret = CBORDecoder_decode_bigfloat(self); break;
case 28: ret = CBORDecoder_decode_shareable(self); break;
case 29: ret = CBORDecoder_decode_sharedref(self); break;
case 30: ret = CBORDecoder_decode_rational(self); break;
case 35: ret = CBORDecoder_decode_regexp(self); break;
case 36: ret = CBORDecoder_decode_mime(self); break;
case 37: ret = CBORDecoder_decode_uuid(self); break;
case 258: ret = CBORDecoder_decode_set(self); break;
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self);
break;
case 0: ret = CBORDecoder_decode_datetime_string(self); break;
case 1: ret = CBORDecoder_decode_epoch_datetime(self); break;
case 2: ret = CBORDecoder_decode_positive_bignum(self); break;
case 3: ret = CBORDecoder_decode_negative_bignum(self); break;
case 4: ret = CBORDecoder_decode_fraction(self); break;
case 5: ret = CBORDecoder_decode_bigfloat(self); break;
case 28: ret = CBORDecoder_decode_shareable(self); break;
case 29: ret = CBORDecoder_decode_sharedref(self); break;
case 30: ret = CBORDecoder_decode_rational(self); break;
case 35: ret = CBORDecoder_decode_regexp(self); break;
case 36: ret = CBORDecoder_decode_mime(self); break;
case 37: ret = CBORDecoder_decode_uuid(self); break;
case 80: ret = CBORDecoder_decode_array_float16_be(self); break;
case 81: ret = CBORDecoder_decode_array_float32_be(self); break;
case 82: ret = CBORDecoder_decode_array_float64_be(self); break;
case 84: ret = CBORDecoder_decode_array_float16_le(self); break;
case 85: ret = CBORDecoder_decode_array_float32_le(self); break;
case 86: ret = CBORDecoder_decode_array_float64_le(self); break;
case 258: ret = CBORDecoder_decode_set(self); break;
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self); break;

default:
tag = CBORTag_New(tagnum);
Expand Down Expand Up @@ -1325,6 +1342,162 @@ CBORDecoder_decode_uuid(CBORDecoderObject *self)
return ret;
}

/**
* Implementation of the decoder for all typed arrays.
*
* \param[in] self The decoder object
* \param[in] type_name The name of the type to use in error messages
* \param[in] element_size The size of the individual elements (e.g.: sizeof(uint64_t))
* \param[in] create_value Create a Python object from its byte representation
*/
static PyObject *
CBORDecoder_decode_array_typed_impl(CBORDecoderObject *self,
const char *type_name,
size_t element_size,
PyObject * (*create_value)(const char *byte_repr))
{
PyObject *bytes, *list, *ret = NULL;
Py_ssize_t bytes_size, element_count, element_idx;
char *bytes_direct;

bytes = decode(self, DECODE_UNSHARED);
if (bytes) {
if (PyBytes_CheckExact(bytes)) {
bytes_size = PyBytes_GET_SIZE(bytes);
if (bytes_size % element_size == 0) {
element_count = bytes_size / element_size;
list = PyList_New(element_count);
if (list) {
set_shareable(self, list);
bytes_direct = PyBytes_AS_STRING(bytes);

for (element_idx = 0; element_idx < element_count; bytes_direct += element_size, ++element_idx) {
PyList_SET_ITEM(list, element_idx, create_value(bytes_direct));
}

if (self->immutable) {
ret = PyList_AsTuple(list);
if (ret) {
Py_DECREF(list);
set_shareable(self, ret);
}
} else {
ret = list;
}
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid length for %s typed array %R -- must be multiple of %llu, but is %zd",
type_name,
element_size,
bytes_size);
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid %s typed array %R", type_name, bytes);
}
Py_DECREF(bytes);
}

return ret;
}


static PyObject *
create_float16_be_from_buffer(const char *src)
{
// unpack_float16 assume big-endian, so just use that
uint16_t tmp;
memcpy(&tmp, src, sizeof tmp);
return PyFloat_FromDouble(unpack_float16(tmp));
}


static PyObject *
create_float16_le_from_buffer(const char *src)
{
uint16_t tmp = (src[0] << 8) | ((uint8_t) src[1]);
return PyFloat_FromDouble(unpack_float16(tmp));
}


#define CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(fn_name_, float_type_, irepr_type_, endian_flip_) \
static PyObject * \
fn_name_(const char *src) \
{ \
irepr_type_ i_repr; \
float_type_ value; \
\
memcpy(&i_repr, src, sizeof i_repr); \
i_repr = endian_flip_(i_repr); \
memcpy(&value, &i_repr, sizeof value); \
\
return PyFloat_FromDouble(value); \
}

CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float32_be_from_buffer, float, uint32_t, be32toh)
CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float64_be_from_buffer, double, uint64_t, be64toh)
CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float32_le_from_buffer, float, uint32_t, le32toh)
CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float64_le_from_buffer, double, uint64_t, le64toh)

#undef CBOR_DECODER_GEN_FLOAT_EXTRACT_FN


// CBORDecoder.decode_array_float16_be
static PyObject *
CBORDecoder_decode_array_float16_be(CBORDecoderObject *self)
{
// semantic type 80
return CBORDecoder_decode_array_typed_impl(self, "float16", 2, create_float16_be_from_buffer);
}


// CBORDecoder.decode_array_float32_be
static PyObject *
CBORDecoder_decode_array_float32_be(CBORDecoderObject *self)
{
// semantic type 81
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, create_float32_be_from_buffer);
}


// CBORDecoder.decode_array_float64_be
static PyObject *
CBORDecoder_decode_array_float64_be(CBORDecoderObject *self)
{
// semantic type 82
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, create_float64_be_from_buffer);
}


// CBORDecoder.decode_array_float16_le
static PyObject *
CBORDecoder_decode_array_float16_le(CBORDecoderObject *self)
{
// semantic type 84
return CBORDecoder_decode_array_typed_impl(self, "float16", 2, create_float16_le_from_buffer);
}


// CBORDecoder.decode_array_float32_le
static PyObject *
CBORDecoder_decode_array_float32_le(CBORDecoderObject *self)
{
// semantic type 85
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, create_float32_le_from_buffer);
}


// CBORDecoder.decode_array_float64_le
static PyObject *
CBORDecoder_decode_array_float64_le(CBORDecoderObject *self)
{
// semantic type 86
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, create_float64_le_from_buffer);
}


// CBORDecoder.decode_set(self)
static PyObject *
Expand Down Expand Up @@ -1738,6 +1911,18 @@ static PyMethodDef CBORDecoder_methods[] = {
"decode a shareable value from the input"},
{"decode_sharedref", (PyCFunction) CBORDecoder_decode_sharedref, METH_NOARGS,
"decode a shared reference from the input"},
{"decode_array_float16_be", (PyCFunction) CBORDecoder_decode_array_float16_be, METH_NOARGS,
"decode a typed array of big-endian half-precision floating-point values"},
{"decode_array_float32_be", (PyCFunction) CBORDecoder_decode_array_float32_be, METH_NOARGS,
"decode a typed array of big-endian single-precision floating-point values"},
{"decode_array_float64_be", (PyCFunction) CBORDecoder_decode_array_float64_be, METH_NOARGS,
"decode a typed array of big-endian double-precision floating-point values"},
{"decode_array_float16_le", (PyCFunction) CBORDecoder_decode_array_float16_le, METH_NOARGS,
"decode a typed array of little-endian half-precision floating-point values"},
{"decode_array_float32_le", (PyCFunction) CBORDecoder_decode_array_float32_le, METH_NOARGS,
"decode a typed array of little-endian single-precision floating-point values"},
{"decode_array_float64_le", (PyCFunction) CBORDecoder_decode_array_float64_le, METH_NOARGS,
"decode a typed array of little-endian double-precision floating-point values"},
{"decode_set", (PyCFunction) CBORDecoder_decode_set, METH_NOARGS,
"decode a set or frozenset from the input"},
{"decode_ipaddress", (PyCFunction) CBORDecoder_decode_ipaddress, METH_NOARGS,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,26 @@ def test_huge_truncated_bytes(impl):
def test_huge_truncated_string(impl):
with pytest.raises((impl.CBORDecodeEOF, MemoryError)):
impl.loads(unhexlify('7B37388519251ae9ca'))


@pytest.mark.parametrize('payload, expected', [
# 16-bit Big Endian
('d850483e00410048c0c500', [1.5, 2.5, 9.5, -5.0]),
# 16-bit Little Endian
('d85448003e0041c04800c5', [1.5, 2.5, 9.5, -5.0]),
# 32-bit Big Endian
('d851503fc000004020000041180000c0a00000', [1.5, 2.5, 9.5, -5.0]),
# 32-bit Little Endian
('d855500000c03f00002040000018410000a0c0', [1.5, 2.5, 9.5, -5.0]),
# 64-bit Big Endian
('d85258203ff800000000000040040000000000004023000000000000c014000000000000',
[1.5, 2.5, 9.5, -5.0]),
# 64-bit Little Endian
('d8565820000000000000f83f0000000000000440000000000000234000000000000014c0',
[1.5, 2.5, 9.5, -5.0]),
('d85640', []),
])
def test_typed_array_floats(impl, payload, expected):
src_bytes = unhexlify(payload)
decoded = impl.loads(src_bytes)
assert decoded == expected

0 comments on commit cad2aaf

Please sign in to comment.