Skip to content

Commit

Permalink
Adds support for decoding floating-point typed arrays from RFC8746
Browse files Browse the repository at this point in the history
This adds support for decoding arrays of floating point numbers of IEEE
754 formats binary16, binary32, and binary64 in both the big- and
little-endian form.
  • Loading branch information
tgockel committed Jun 4, 2021
1 parent 9f30439 commit 7361d6e
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 21 deletions.
59 changes: 59 additions & 0 deletions cbor2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,59 @@ def decode_stringref_namespace(self):
self._stringref_namespace = old_namespace
return value

def _decode_typed_array_half_float_impl(self, name, tag, element_size, format):
"""Helper function for decoding typed arrays of half-precision floats"""
buf = self.decode()
if not isinstance(buf, bytes):
raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf))
elif len(buf) % element_size != 0:
raise CBORDecodeValueError(
"invalid length for %s typed array -- must be multiple of %i, but is %i"
% (name, element_size, len(buf)))

out = struct.unpack(format % (len(buf) // element_size), buf)

if self._immutable:
return self.set_shareable(out)
else:
return self.set_shareable(list(out))

def _decode_typed_array_half_fload_func(*args):
return lambda self: self._decode_typed_array_half_float_impl(*args)

decode_array_float16_be = _decode_typed_array_half_fload_func('float16', 80, 2, '>%ie')
decode_array_float16_le = _decode_typed_array_half_fload_func('float16', 84, 2, '<%ie')

def _decode_typed_array_impl(self, name, tag, element_size, typecode, endianness):
"""Helper function for decoding typed arrays described by RFC 8746"""
import array
import sys

buf = self.decode()
if not isinstance(buf, bytes):
raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf))
elif len(buf) % element_size != 0:
raise CBORDecodeValueError(
"invalid length for %s typed array -- must be multiple of %i, but is %i"
% (name, element_size, len(buf)))

out = array.array(typecode, buf)
if sys.byteorder != endianness:
out.byteswap()

if self._immutable:
return self.set_shareable(tuple(out))
else:
return self.set_shareable(out)

def _decode_typed_array_func(*args):
return lambda self: self._decode_typed_array_impl(*args)

decode_array_float32_be = _decode_typed_array_func('float32', 81, 4, 'f', 'big')
decode_array_float64_be = _decode_typed_array_func('float64', 82, 8, 'd', 'big')
decode_array_float32_le = _decode_typed_array_func('float32', 85, 4, 'f', 'little')
decode_array_float64_le = _decode_typed_array_func('float64', 86, 8, 'd', 'little')

def decode_set(self):
# Semantic tag 258
if self._immutable:
Expand Down Expand Up @@ -591,6 +644,12 @@ def decode_float64(self):
35: CBORDecoder.decode_regexp,
36: CBORDecoder.decode_mime,
37: CBORDecoder.decode_uuid,
80: CBORDecoder.decode_array_float16_be,
81: CBORDecoder.decode_array_float32_be,
82: CBORDecoder.decode_array_float64_be,
84: CBORDecoder.decode_array_float16_le,
85: CBORDecoder.decode_array_float32_le,
86: CBORDecoder.decode_array_float64_le,
256: CBORDecoder.decode_stringref_namespace,
258: CBORDecoder.decode_set,
260: CBORDecoder.decode_ipaddress,
Expand Down
283 changes: 264 additions & 19 deletions source/decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ static PyObject * CBORDecoder_decode_bigfloat(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_rational(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_regexp(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_uuid(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float16_be(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float32_be(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float64_be(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float16_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float32_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_array_float64_le(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_mime(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_positive_bignum(CBORDecoderObject *);
static PyObject * CBORDecoder_decode_negative_bignum(CBORDecoderObject *);
Expand Down Expand Up @@ -954,25 +960,30 @@ decode_semantic(CBORDecoderObject *self, uint8_t subtype)

if (decode_length(self, subtype, &tagnum, NULL) == 0) {
switch (tagnum) {
case 0: ret = CBORDecoder_decode_datetime_string(self); break;
case 1: ret = CBORDecoder_decode_epoch_datetime(self); break;
case 2: ret = CBORDecoder_decode_positive_bignum(self); break;
case 3: ret = CBORDecoder_decode_negative_bignum(self); break;
case 4: ret = CBORDecoder_decode_fraction(self); break;
case 5: ret = CBORDecoder_decode_bigfloat(self); break;
case 25: ret = CBORDecoder_decode_stringref(self); break;
case 28: ret = CBORDecoder_decode_shareable(self); break;
case 29: ret = CBORDecoder_decode_sharedref(self); break;
case 30: ret = CBORDecoder_decode_rational(self); break;
case 35: ret = CBORDecoder_decode_regexp(self); break;
case 36: ret = CBORDecoder_decode_mime(self); break;
case 37: ret = CBORDecoder_decode_uuid(self); break;
case 256: ret = CBORDecoder_decode_stringref_ns(self); break;
case 258: ret = CBORDecoder_decode_set(self); break;
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self);
break;
case 0: ret = CBORDecoder_decode_datetime_string(self); break;
case 1: ret = CBORDecoder_decode_epoch_datetime(self); break;
case 2: ret = CBORDecoder_decode_positive_bignum(self); break;
case 3: ret = CBORDecoder_decode_negative_bignum(self); break;
case 4: ret = CBORDecoder_decode_fraction(self); break;
case 5: ret = CBORDecoder_decode_bigfloat(self); break;
case 25: ret = CBORDecoder_decode_stringref(self); break;
case 28: ret = CBORDecoder_decode_shareable(self); break;
case 29: ret = CBORDecoder_decode_sharedref(self); break;
case 30: ret = CBORDecoder_decode_rational(self); break;
case 35: ret = CBORDecoder_decode_regexp(self); break;
case 36: ret = CBORDecoder_decode_mime(self); break;
case 37: ret = CBORDecoder_decode_uuid(self); break;
case 80: ret = CBORDecoder_decode_array_float16_be(self); break;
case 81: ret = CBORDecoder_decode_array_float32_be(self); break;
case 82: ret = CBORDecoder_decode_array_float64_be(self); break;
case 84: ret = CBORDecoder_decode_array_float16_le(self); break;
case 85: ret = CBORDecoder_decode_array_float32_le(self); break;
case 86: ret = CBORDecoder_decode_array_float64_le(self); break;
case 256: ret = CBORDecoder_decode_stringref_ns(self); break;
case 258: ret = CBORDecoder_decode_set(self); break;
case 260: ret = CBORDecoder_decode_ipaddress(self); break;
case 261: ret = CBORDecoder_decode_ipnetwork(self); break;
case 55799: ret = CBORDecoder_decode_self_describe_cbor(self); break;

default:
tag = CBORTag_New(tagnum);
Expand Down Expand Up @@ -1444,6 +1455,228 @@ CBORDecoder_decode_stringref_ns(CBORDecoderObject *self)
self->stringref_namespace = old_namespace;
return ret;
}
/**
* Implementation of the decoder for typed arrays of half-precision floats.
*
* \param[in] self The decoder object
* \param[in] type_name The name of the type to use in error messages
* \param[in] element_size The size of the individual elements (e.g.: sizeof(uint64_t))
* \param[in] create_value Create a Python object from its byte representation
*/
static PyObject *
CBORDecoder_decode_array_typed_half_float_impl(CBORDecoderObject *self,
const char *type_name,
size_t element_size,
PyObject * (*create_value)(const char *byte_repr))
{
PyObject *bytes, *list, *ret = NULL;
Py_ssize_t bytes_size, element_count, element_idx;
char *bytes_direct;

bytes = decode(self, DECODE_UNSHARED);
if (bytes) {
if (PyBytes_CheckExact(bytes)) {
bytes_size = PyBytes_GET_SIZE(bytes);
if (bytes_size % element_size == 0) {
element_count = bytes_size / element_size;
list = PyList_New(element_count);
if (list) {
set_shareable(self, list);
bytes_direct = PyBytes_AS_STRING(bytes);

for (element_idx = 0; element_idx < element_count; bytes_direct += element_size, ++element_idx) {
PyList_SET_ITEM(list, element_idx, create_value(bytes_direct));
}

if (self->immutable) {
ret = PyList_AsTuple(list);
if (ret) {
Py_DECREF(list);
set_shareable(self, ret);
}
} else {
ret = list;
}
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid length for %s typed array %R -- must be multiple of %llu, but is %zd",
type_name,
element_size,
bytes_size);
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid %s typed array %R", type_name, bytes);
}
Py_DECREF(bytes);
}

return ret;
}


static PyObject *
create_float16_be_from_buffer(const char *src)
{
// unpack_float16 assume big-endian, so just use that
uint16_t tmp;
memcpy(&tmp, src, sizeof tmp);
return PyFloat_FromDouble(unpack_float16(tmp));
}


// CBORDecoder.decode_array_float16_be
static PyObject *
CBORDecoder_decode_array_float16_be(CBORDecoderObject *self)
{
// semantic type 80
return CBORDecoder_decode_array_typed_half_float_impl(self, "float16", 2, create_float16_be_from_buffer);
}


static PyObject *
create_float16_le_from_buffer(const char *src)
{
uint16_t tmp = (src[0] << 8) | ((uint8_t) src[1]);
return PyFloat_FromDouble(unpack_float16(tmp));
}


// CBORDecoder.decode_array_float16_le
static PyObject *
CBORDecoder_decode_array_float16_le(CBORDecoderObject *self)
{
// semantic type 84
return CBORDecoder_decode_array_typed_half_float_impl(self, "float16", 2, create_float16_le_from_buffer);
}


static bool
plaform_is_big_endian(void)
{
#if defined __BIG_ENDIAN__ \
|| defined __ARMEB__ \
|| defined __THUMBEB__ \
|| defined __AARCH64EB__ \
|| defined _MIBSEB \
|| defined __MIBSEB \
|| defined __MIBSEB__
// Big endian known at compile-time
return true;
#elif defined __LITTLE_ENDIAN__ \
|| defined __ARMEL__ \
|| defined __THUMBEL__ \
|| defined __AARCH64EL__ \
|| defined _MIPSEL \
|| defined __MIPSEL \
|| defined __MIPSEL__
// Little endian known at compile-time
return false;
#else
// Fall back to checking at run-time
char c;
size_t num = 1;
memcpy(&c, &num, 1);
return c == '\0';
#endif
}


static PyObject *
CBORDecoder_decode_array_typed_impl(CBORDecoderObject *self,
const char *type_name,
size_t element_size,
PyObject *array_typecode,
bool big_endian)
{
PyObject *bytes, *array, *byteswap_result, *ret = NULL;

if (!_CBOR2_array && _CBOR2_init_array() == -1)
return NULL;

bytes = decode(self, DECODE_UNSHARED);
if (bytes) {
if (PyBytes_CheckExact(bytes)) {
array = PyObject_CallFunctionObjArgs(_CBOR2_array, array_typecode, bytes, NULL);
if (array) {
set_shareable(self, array);

if (plaform_is_big_endian() == big_endian) {
// Byte order of source matches platform -- no need to modify array
Py_INCREF(array);
ret = array;
} else {
// Byte order is revered -- flip the elements
byteswap_result = PyObject_CallMethodObjArgs(array, _CBOR2_str_byteswap, NULL);
if (byteswap_result) {
Py_DECREF(byteswap_result);
Py_INCREF(array);
ret = array;
} else {
// byteswap failed -- error is set
array = NULL;
}
}

if (array && self->immutable) {
ret = PyObject_CallFunctionObjArgs((PyObject*) &PyTuple_Type, array, NULL);
}

Py_DECREF(array);
} else {
// array creation failed -- error is set
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid %s typed array %R", type_name, bytes);
}

Py_DECREF(bytes);
}

return ret;
}


// CBORDecoder.decode_array_float32_be
static PyObject *
CBORDecoder_decode_array_float32_be(CBORDecoderObject *self)
{
// semantic type 81
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, _CBOR2_str_f, true);
}


// CBORDecoder.decode_array_float64_be
static PyObject *
CBORDecoder_decode_array_float64_be(CBORDecoderObject *self)
{
// semantic type 82
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, _CBOR2_str_d, true);
}


// CBORDecoder.decode_array_float32_le
static PyObject *
CBORDecoder_decode_array_float32_le(CBORDecoderObject *self)
{
// semantic type 85
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, _CBOR2_str_f, false);
}


// CBORDecoder.decode_array_float64_le
static PyObject *
CBORDecoder_decode_array_float64_le(CBORDecoderObject *self)
{
// semantic type 86
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, _CBOR2_str_d, false);
}


// CBORDecoder.decode_set(self)
static PyObject *
Expand Down Expand Up @@ -1862,6 +2095,18 @@ static PyMethodDef CBORDecoder_methods[] = {
"decode a string reference from the input"},
{"decode_stringref_namespace", (PyCFunction) CBORDecoder_decode_stringref_ns, METH_NOARGS,
"decode a string reference namespace from the input"},
{"decode_array_float16_be", (PyCFunction) CBORDecoder_decode_array_float16_be, METH_NOARGS,
"decode a typed array of big-endian half-precision floating-point values"},
{"decode_array_float32_be", (PyCFunction) CBORDecoder_decode_array_float32_be, METH_NOARGS,
"decode a typed array of big-endian single-precision floating-point values"},
{"decode_array_float64_be", (PyCFunction) CBORDecoder_decode_array_float64_be, METH_NOARGS,
"decode a typed array of big-endian double-precision floating-point values"},
{"decode_array_float16_le", (PyCFunction) CBORDecoder_decode_array_float16_le, METH_NOARGS,
"decode a typed array of little-endian half-precision floating-point values"},
{"decode_array_float32_le", (PyCFunction) CBORDecoder_decode_array_float32_le, METH_NOARGS,
"decode a typed array of little-endian single-precision floating-point values"},
{"decode_array_float64_le", (PyCFunction) CBORDecoder_decode_array_float64_le, METH_NOARGS,
"decode a typed array of little-endian double-precision floating-point values"},
{"decode_set", (PyCFunction) CBORDecoder_decode_set, METH_NOARGS,
"decode a set or frozenset from the input"},
{"decode_ipaddress", (PyCFunction) CBORDecoder_decode_ipaddress, METH_NOARGS,
Expand Down
Loading

0 comments on commit 7361d6e

Please sign in to comment.