Skip to content

Commit

Permalink
Improve GetIndexedFieldExpr adding utf8 key based access for struct v… (
Browse files Browse the repository at this point in the history
#1204)

* Improve GetIndexedFieldExpr adding utf8 key based access for struct values

* fix clippies
  • Loading branch information
Igosuki committed Nov 2, 2021
1 parent a4f4de8 commit 6a7dbbb
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 14 deletions.
21 changes: 20 additions & 1 deletion datafusion/src/field_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field};
use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;

/// Returns the field access indexed by `key` from a [`DataType::List`]
/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`]
/// # Error
/// Errors if
/// * the `data_type` is not a Struct or,
Expand All @@ -39,6 +39,25 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Fiel
Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
}
}
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
if s.is_empty() {
Err(DataFusionError::Plan(
"Struct based indexed access requires a non empty string".to_string(),
))
} else {
let field = fields.iter().find(|f| f.name() == s);
match field {
None => Err(DataFusionError::Plan(format!(
"Field {} not found in struct",
s
))),
Some(f) => Ok(f.clone()),
}
}
}
(DataType::Struct(_), _) => Err(DataFusionError::Plan(
"Only utf8 strings are valid as an indexed field in a struct".to_string(),
)),
(DataType::List(_), _) => Err(DataFusionError::Plan(
"Only ints are valid as an indexed field in a list".to_string(),
)),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ pub enum Expr {
IsNull(Box<Expr>),
/// arithmetic negation of an expression, the operand must be of a signed numeric data type
Negative(Box<Expr>),
/// Returns the field of a [`ListArray`] by key
/// Returns the field of a [`ListArray`] or [`StructArray`] by key
GetIndexedField {
/// the expression to take the field from
expr: Box<Expr>,
Expand Down
167 changes: 155 additions & 12 deletions datafusion/src/physical_plan/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
field_util::get_indexed_field as get_data_type_field,
physical_plan::{ColumnarValue, PhysicalExpr},
};
use arrow::array::ListArray;
use arrow::array::{ListArray, StructArray};
use std::fmt::Debug;

/// expression to get a field of a struct array.
Expand Down Expand Up @@ -81,7 +81,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => match (array.data_type(), &self.key) {
(DataType::List(_), _) if self.key.is_null() => {
(DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => {
let scalar_null: ScalarValue = array.data_type().try_into()?;
Ok(ColumnarValue::Scalar(scalar_null))
}
Expand All @@ -100,6 +100,13 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let iter = concat(vec.as_slice()).unwrap();
Ok(ColumnarValue::Array(iter))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
match as_struct_array.column_by_name(k) {
None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))),
Some(col) => Ok(ColumnarValue::Array(col.clone()))
}
}
(dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))),
},
ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
Expand All @@ -112,18 +119,16 @@ impl PhysicalExpr for GetIndexedFieldExpr {
#[cfg(test)]
mod tests {
use super::*;
use crate::arrow::array::GenericListArray;
use crate::error::Result;
use crate::physical_plan::expressions::{col, lit};
use arrow::array::{ListBuilder, StringBuilder};
use arrow::array::{
Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder,
};
use arrow::{array::StringArray, datatypes::Field};

fn get_indexed_field_test(
list_of_lists: Vec<Vec<Option<&str>>>,
index: i64,
expected: Vec<Option<&str>>,
) -> Result<()> {
let schema = list_schema("l");
let builder = StringBuilder::new(3);
fn build_utf8_lists(list_of_lists: Vec<Vec<Option<&str>>>) -> GenericListArray<i32> {
let builder = StringBuilder::new(list_of_lists.len());
let mut lb = ListBuilder::new(builder);
for values in list_of_lists {
let builder = lb.values();
Expand All @@ -137,9 +142,18 @@ mod tests {
lb.append(true).unwrap();
}

let expr = col("l", &schema).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
lb.finish()
}

fn get_indexed_field_test(
list_of_lists: Vec<Vec<Option<&str>>>,
index: i64,
expected: Vec<Option<&str>>,
) -> Result<()> {
let schema = list_schema("l");
let list_col = build_utf8_lists(list_of_lists);
let expr = col("l", &schema).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?;
let key = ScalarValue::Int64(Some(index));
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
Expand Down Expand Up @@ -222,4 +236,133 @@ mod tests {
let expr = col("l", &schema).unwrap();
get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index")
}

fn build_struct(
fields: Vec<Field>,
list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
) -> StructArray {
let foo_builder = Int64Array::builder(list_of_tuples.len());
let str_builder = StringBuilder::new(list_of_tuples.len());
let bar_builder = ListBuilder::new(str_builder);
let mut builder = StructBuilder::new(
fields,
vec![Box::new(foo_builder), Box::new(bar_builder)],
);
for (int_value, list_value) in list_of_tuples {
let fb = builder.field_builder::<Int64Builder>(0).unwrap();
match int_value {
None => fb.append_null(),
Some(v) => fb.append_value(v),
}
.unwrap();
builder.append(true).unwrap();
let lb = builder
.field_builder::<ListBuilder<StringBuilder>>(1)
.unwrap();
for str_value in list_value {
match str_value {
None => lb.values().append_null(),
Some(v) => lb.values().append_value(v),
}
.unwrap();
}
lb.append(true).unwrap();
}
builder.finish()
}

/// Drives `GetIndexedFieldExpr` against a struct column named `s` that has an Int64
/// field `foo` and a list-of-Utf8 field `bar`.
///
/// Three accesses are checked against a batch built from `list_of_tuples`:
/// * `s["foo"]` must equal `expected_ints`;
/// * `s["bar"]` must equal the list array built from the string vectors of
///   `list_of_tuples`;
/// * for each position `i` in `expected_strings`, `s["bar"][i]` must equal
///   `expected_strings[i]`.
///
/// Evaluation and batch-construction errors are propagated; mismatches and failed
/// downcasts panic.
fn get_indexed_field_mixed_test(
    list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
    expected_strings: Vec<Vec<Option<&str>>>,
    expected_ints: Vec<Option<i64>>,
) -> Result<()> {
    let column_name = "s";
    let list_item = Field::new("item", DataType::Utf8, true);
    let fields = vec![
        Field::new("foo", DataType::Int64, true),
        Field::new("bar", DataType::List(Box::new(list_item)), true),
    ];
    let struct_type = DataType::Struct(fields.clone());
    let schema = Schema::new(vec![Field::new(column_name, struct_type, true)]);
    let struct_array = build_struct(fields, list_of_tuples.clone());

    let struct_expr = col(column_name, &schema).unwrap();
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
    let num_rows = batch.num_rows();

    // Access the integer child by name.
    let foo_expr = Arc::new(GetIndexedFieldExpr::new(
        struct_expr.clone(),
        ScalarValue::Utf8(Some("foo".to_string())),
    ));
    let foo_array = foo_expr.evaluate(&batch)?.into_array(num_rows);
    let foo_values = foo_array
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("failed to downcast to Int64Array");
    assert_eq!(&Int64Array::from(expected_ints), foo_values);

    // Access the list child by name.
    let bar_expr = Arc::new(GetIndexedFieldExpr::new(
        struct_expr,
        ScalarValue::Utf8(Some("bar".to_string())),
    ));
    let bar_array = bar_expr.evaluate(&batch)?.into_array(num_rows);
    let bar_lists = bar_array
        .as_any()
        .downcast_ref::<ListArray>()
        .unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", bar_array));
    let expected_lists = build_utf8_lists(
        list_of_tuples
            .into_iter()
            .map(|(_, strings)| strings)
            .collect(),
    );
    assert_eq!(&expected_lists, bar_lists);

    // Chain an integer index on top of the named access.
    for (index, expected_column) in (0_i64..).zip(expected_strings) {
        let nested_expr = Arc::new(GetIndexedFieldExpr::new(
            bar_expr.clone(),
            ScalarValue::Int64(Some(index)),
        ));
        let nested_array = nested_expr.evaluate(&batch)?.into_array(num_rows);
        let nested_strings = nested_array
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap_or_else(|| {
                panic!("failed to downcast to StringArray : {:?}", nested_array)
            });
        assert_eq!(&StringArray::from(expected_column), nested_strings);
    }

    Ok(())
}

/// Runs the mixed struct/list access check on three rows, one of which has a null
/// integer and each of which has a three-element list containing one null.
#[test]
fn get_indexed_field_struct() -> Result<()> {
    // One tuple per struct row: (`foo` value, `bar` list contents).
    let rows = vec![
        (Some(10), vec![Some("a"), Some("b"), None]),
        (Some(15), vec![None, Some("c"), Some("d")]),
        (None, vec![Some("e"), None, Some("f")]),
    ];

    // Entry `i` holds the i-th list element of every row, in row order.
    let strings_by_index = vec![
        vec![Some("a"), None, Some("e")],
        vec![Some("b"), Some("c"), None],
        vec![None, Some("d"), Some("f")],
    ];

    let ints_by_row = vec![Some(10), Some(15), None];

    get_indexed_field_mixed_test(rows, strings_by_index, ints_by_row)
}
}
44 changes: 44 additions & 0 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5476,3 +5476,47 @@ async fn query_nested_get_indexed_field() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

/// End-to-end SQL check of key-based access on a struct column.
///
/// Registers a table `structs` with one non-nullable struct column `some_struct`
/// whose only field `bar` is a list of Int64, then verifies that
/// `some_struct["bar"]` returns the whole list per row and that
/// `some_struct["bar"][0]` returns the first element of each list.
#[tokio::test]
async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true)));
    // Nested schema of { "some_struct": { "bar": [i64] } }
    let struct_fields = vec![Field::new("bar", nested_dt, true)];
    let schema = Arc::new(Schema::new(vec![Field::new(
        "some_struct",
        DataType::Struct(struct_fields.clone()),
        false,
    )]));

    let builder = PrimitiveBuilder::<Int64Type>::new(3);
    let nested_lb = ListBuilder::new(builder);
    let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]);
    for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]] {
        let lb = sb.field_builder::<ListBuilder<Int64Builder>>(0).unwrap();
        for int in int_vec {
            lb.values().append_value(int).unwrap();
        }
        lb.append(true).unwrap();
        // Advance the struct itself by one valid row so its length stays in sync
        // with the child list builder.
        sb.append(true).unwrap();
    }
    let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(sb.finish())])?;
    let table = MemTable::try_new(schema, vec![vec![data]])?;
    let table_a = Arc::new(table);

    ctx.register_table("structs", table_a)?;

    // Access the struct's list field by its utf8 key.
    let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
    let actual = execute(&mut ctx, sql).await;
    let expected = vec![
        vec!["[0, 1, 2, 3]"],
        vec!["[4, 5, 6, 7]"],
        vec!["[8, 9, 10, 11]"],
    ];
    assert_eq!(expected, actual);

    // Chain an integer index onto the key-based access to reach into the list.
    let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
    let actual = execute(&mut ctx, sql).await;
    let expected = vec![vec!["0"], vec!["4"], vec!["8"]];
    assert_eq!(expected, actual);
    Ok(())
}

0 comments on commit 6a7dbbb

Please sign in to comment.