Skip to content

Commit

Permalink
Add dictionary support for C data interface (#1407)
Browse files Browse the repository at this point in the history
* initial commit

* add integration tests for python

* address comments
  • Loading branch information
sunchao committed Mar 9, 2022
1 parent b4481b8 commit f19d1ed
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 64 deletions.
19 changes: 11 additions & 8 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def assert_pyarrow_leak():
pa.field("c", pa.string()),
]
),
pa.dictionary(pa.int8(), pa.string()),
]

_unsupported_pyarrow_types = [
Expand Down Expand Up @@ -122,14 +123,6 @@ def test_type_roundtrip_raises(pyarrow_type):
with pytest.raises(pa.ArrowException):
rust.round_trip_type(pyarrow_type)


def test_dictionary_type_roundtrip():
    """
    Round-trip a dictionary type through Rust and compare it with its index type.

    The dictionary type conversion is incomplete, so the type that comes back
    is expected to equal the plain ``int32`` index type rather than the
    original dictionary type.
    """
    dictionary_type = pa.dictionary(pa.int32(), pa.string())
    round_tripped = rust.round_trip_type(dictionary_type)
    assert round_tripped == pa.int32()


@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
def test_field_roundtrip(pyarrow_type):
pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
Expand Down Expand Up @@ -263,3 +256,13 @@ def test_decimal_python():
assert a == b
del a
del b

def test_dictionary_python():
    """
    Python -> Rust -> Python

    Builds a dictionary-encoded string array (int8 indices, with nulls),
    passes it through ``rust.round_trip_array`` and checks that the array
    that comes back compares equal to the one that went in.
    """
    dictionary_type = pa.dictionary(pa.int8(), pa.string())
    original = pa.array(["a", None, "b", None, "a"], type=dictionary_type)
    round_tripped = rust.round_trip_array(original)
    assert original == round_tripped
    # NOTE(review): references are dropped explicitly, presumably so a
    # leak check elsewhere in the file sees no live arrays -- confirm.
    del original
    del round_tripped
22 changes: 22 additions & 0 deletions arrow/src/array/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ impl TryFrom<ArrayData> for ffi::ArrowArray {

#[cfg(test)]
mod tests {
use crate::array::{DictionaryArray, Int32Array, StringArray};
use crate::error::Result;
use crate::{
array::{
Expand Down Expand Up @@ -127,4 +128,25 @@ mod tests {
let data = array.data();
test_round_trip(data)
}

#[test]
fn test_dictionary() -> Result<()> {
    // Dictionary values; the last entry is null.
    let dict_values = StringArray::from(vec![Some("foo"), Some("bar"), None]);

    // Keys into `dict_values`, with `None` marking null slots.
    let raw_keys: Vec<Option<i32>> = vec![
        Some(0),
        Some(1),
        None,
        Some(1),
        Some(1),
        None,
        Some(1),
        Some(2),
        Some(1),
        None,
    ];
    let dict_keys = Int32Array::from(raw_keys);

    let dictionary = DictionaryArray::try_new(&dict_keys, &dict_values)?;

    // Hand the array's data to the shared round-trip helper.
    test_round_trip(dictionary.data())
}
}
104 changes: 59 additions & 45 deletions arrow/src/datatypes/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {

/// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self> {
let dtype = match c_schema.format() {
let mut dtype = match c_schema.format() {
"n" => DataType::Null,
"b" => DataType::Boolean,
"c" => DataType::Int8,
Expand Down Expand Up @@ -134,6 +134,12 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
}
}
};

if let Some(dict_schema) = c_schema.dictionary() {
let value_type = Self::try_from(dict_schema)?;
dtype = DataType::Dictionary(Box::new(dtype), Box::new(value_type));
}

Ok(dtype)
}
}
Expand Down Expand Up @@ -169,49 +175,7 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {

/// See [CDataInterface docs](https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings)
fn try_from(dtype: &DataType) -> Result<Self> {
let format = match dtype {
DataType::Null => "n".to_string(),
DataType::Boolean => "b".to_string(),
DataType::Int8 => "c".to_string(),
DataType::UInt8 => "C".to_string(),
DataType::Int16 => "s".to_string(),
DataType::UInt16 => "S".to_string(),
DataType::Int32 => "i".to_string(),
DataType::UInt32 => "I".to_string(),
DataType::Int64 => "l".to_string(),
DataType::UInt64 => "L".to_string(),
DataType::Float16 => "e".to_string(),
DataType::Float32 => "f".to_string(),
DataType::Float64 => "g".to_string(),
DataType::Binary => "z".to_string(),
DataType::LargeBinary => "Z".to_string(),
DataType::Utf8 => "u".to_string(),
DataType::LargeUtf8 => "U".to_string(),
DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale),
DataType::Date32 => "tdD".to_string(),
DataType::Date64 => "tdm".to_string(),
DataType::Time32(TimeUnit::Second) => "tts".to_string(),
DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(),
DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(),
DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(),
DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(),
DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(),
DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(),
DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(),
DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz),
DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz),
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz),
DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz),
DataType::List(_) => "+l".to_string(),
DataType::LargeList(_) => "+L".to_string(),
DataType::Struct(_) => "+s".to_string(),
other => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust implementation",
other
)))
}
};
let format = get_format_string(dtype)?;
// allocate and hold the children
let children = match dtype {
DataType::List(child) | DataType::LargeList(child) => {
Expand All @@ -223,7 +187,57 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
.collect::<Result<Vec<_>>>()?,
_ => vec![],
};
FFI_ArrowSchema::try_new(&format, children)
let dictionary = if let DataType::Dictionary(_, value_data_type) = dtype {
Some(Self::try_from(value_data_type.as_ref())?)
} else {
None
};
FFI_ArrowSchema::try_new(&format, children, dictionary)
}
}

fn get_format_string(dtype: &DataType) -> Result<String> {
match dtype {
DataType::Null => Ok("n".to_string()),
DataType::Boolean => Ok("b".to_string()),
DataType::Int8 => Ok("c".to_string()),
DataType::UInt8 => Ok("C".to_string()),
DataType::Int16 => Ok("s".to_string()),
DataType::UInt16 => Ok("S".to_string()),
DataType::Int32 => Ok("i".to_string()),
DataType::UInt32 => Ok("I".to_string()),
DataType::Int64 => Ok("l".to_string()),
DataType::UInt64 => Ok("L".to_string()),
DataType::Float16 => Ok("e".to_string()),
DataType::Float32 => Ok("f".to_string()),
DataType::Float64 => Ok("g".to_string()),
DataType::Binary => Ok("z".to_string()),
DataType::LargeBinary => Ok("Z".to_string()),
DataType::Utf8 => Ok("u".to_string()),
DataType::LargeUtf8 => Ok("U".to_string()),
DataType::Decimal(precision, scale) => Ok(format!("d:{},{}", precision, scale)),
DataType::Date32 => Ok("tdD".to_string()),
DataType::Date64 => Ok("tdm".to_string()),
DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()),
DataType::Time32(TimeUnit::Millisecond) => Ok("ttm".to_string()),
DataType::Time64(TimeUnit::Microsecond) => Ok("ttu".to_string()),
DataType::Time64(TimeUnit::Nanosecond) => Ok("ttn".to_string()),
DataType::Timestamp(TimeUnit::Second, None) => Ok("tss:".to_string()),
DataType::Timestamp(TimeUnit::Millisecond, None) => Ok("tsm:".to_string()),
DataType::Timestamp(TimeUnit::Microsecond, None) => Ok("tsu:".to_string()),
DataType::Timestamp(TimeUnit::Nanosecond, None) => Ok("tsn:".to_string()),
DataType::Timestamp(TimeUnit::Second, Some(tz)) => Ok(format!("tss:{}", tz)),
DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => Ok(format!("tsm:{}", tz)),
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => Ok(format!("tsu:{}", tz)),
DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => Ok(format!("tsn:{}", tz)),
DataType::List(_) => Ok("+l".to_string()),
DataType::LargeList(_) => Ok("+L".to_string()),
DataType::Struct(_) => Ok("+s".to_string()),
DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
other => Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust implementation",
other
))),
}
}

Expand Down

0 comments on commit f19d1ed

Please sign in to comment.