Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support large-utf8 in groupby #35

Merged
merged 2 commits into from Apr 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 51 additions & 0 deletions datafusion/src/execution/context.rs
Expand Up @@ -1646,6 +1646,57 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn group_by_largeutf8() {
{
let mut ctx = ExecutionContext::new();

// input data looks like:
// A, 1
// B, 2
// A, 2
// A, 4
// C, 1
// A, 1

let str_array: LargeStringArray = vec!["A", "B", "A", "A", "C", "A"]
.into_iter()
.map(Some)
.collect();
let str_array = Arc::new(str_array);

let val_array: Int64Array = vec![1, 2, 2, 4, 1, 1].into();
let val_array = Arc::new(val_array);

let schema = Arc::new(Schema::new(vec![
Field::new("str", str_array.data_type().clone(), false),
Field::new("val", val_array.data_type().clone(), false),
]));

let batch =
RecordBatch::try_new(schema.clone(), vec![str_array, val_array]).unwrap();

let provider = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();

let results =
plan_and_collect(&mut ctx, "SELECT str, count(val) FROM t GROUP BY str")
.await
.expect("ran plan correctly");

let expected = vec![
"+-----+------------+",
"| str | COUNT(val) |",
"+-----+------------+",
"| A | 4 |",
"| B | 1 |",
"| C | 1 |",
"+-----+------------+",
];
assert_batches_sorted_eq!(expected, &results);
}
}

#[tokio::test]
async fn group_by_dictionary() {
async fn run_test_case<K: ArrowDictionaryKeyType>() {
Expand Down
9 changes: 7 additions & 2 deletions datafusion/src/physical_plan/group_scalar.rs
Expand Up @@ -37,6 +37,7 @@ pub(crate) enum GroupByScalar {
Int32(i32),
Int64(i64),
Utf8(Box<String>),
LargeUtf8(Box<String>),
Boolean(bool),
TimeMillisecond(i64),
TimeMicrosecond(i64),
Expand Down Expand Up @@ -74,6 +75,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
GroupByScalar::TimeNanosecond(*v)
}
ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
ScalarValue::LargeUtf8(Some(v)) => {
GroupByScalar::LargeUtf8(Box::new(v.clone()))
}
ScalarValue::Float32(None)
| ScalarValue::Float64(None)
| ScalarValue::Boolean(None)
Expand Down Expand Up @@ -116,6 +120,7 @@ impl From<&GroupByScalar> for ScalarValue {
GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
GroupByScalar::LargeUtf8(v) => ScalarValue::LargeUtf8(Some(v.to_string())),
GroupByScalar::TimeMillisecond(v) => {
ScalarValue::TimestampMillisecond(Some(*v))
}
Expand Down Expand Up @@ -191,14 +196,14 @@ mod tests {
#[test]
fn from_scalar_unsupported() {
// Use any ScalarValue type not supported by GroupByScalar.
let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
let scalar_value = ScalarValue::Binary(Some(vec![1, 2]));
let result = GroupByScalar::try_from(&scalar_value);

match result {
Err(DataFusionError::Internal(error_message)) => assert_eq!(
error_message,
String::from(
"Cannot convert a ScalarValue with associated DataType LargeUtf8"
"Cannot convert a ScalarValue with associated DataType Binary"
)
),
_ => panic!("Unexpected result"),
Expand Down
18 changes: 17 additions & 1 deletion datafusion/src/physical_plan/hash_aggregate.rs
Expand Up @@ -59,7 +59,8 @@ use ordered_float::OrderedFloat;
use pin_project_lite::pin_project;

use arrow::array::{
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
LargeStringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray,
};
use async_trait::async_trait;

Expand Down Expand Up @@ -540,6 +541,14 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
// store the string value
vec.extend_from_slice(value.as_bytes());
}
DataType::LargeUtf8 => {
let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
let value = array.value(row);
// store the size
vec.extend_from_slice(&value.len().to_le_bytes());
// store the string value
vec.extend_from_slice(value.as_bytes());
}
DataType::Date32 => {
let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
vec.extend_from_slice(&array.value(row).to_le_bytes());
Expand Down Expand Up @@ -953,6 +962,9 @@ fn create_batch_from_map(
GroupByScalar::Utf8(str) => {
Arc::new(StringArray::from(vec![&***str]))
}
GroupByScalar::LargeUtf8(str) => {
Arc::new(LargeStringArray::from(vec![&***str]))
}
GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])),
GroupByScalar::TimeMillisecond(n) => {
Arc::new(TimestampMillisecondArray::from(vec![*n]))
Expand Down Expand Up @@ -1103,6 +1115,10 @@ fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
let array = col.as_any().downcast_ref::<StringArray>().unwrap();
Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
}
DataType::LargeUtf8 => {
    let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
    // Must produce the LargeUtf8 variant (not Utf8): create_batch_from_map
    // rebuilds GroupByScalar::LargeUtf8 as a LargeStringArray, so using the
    // Utf8 variant here would emit a StringArray group column that mismatches
    // the output schema's LargeUtf8 data type.
    Ok(GroupByScalar::LargeUtf8(Box::new(array.value(row).into())))
}
DataType::Boolean => {
let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(GroupByScalar::Boolean(array.value(row)))
Expand Down
3 changes: 3 additions & 0 deletions datafusion/src/physical_plan/hash_join.rs
Expand Up @@ -831,6 +831,9 @@ pub fn create_hashes<'a>(
DataType::Utf8 => {
hash_array!(StringArray, col, str, hashes_buffer, random_state);
}
DataType::LargeUtf8 => {
hash_array!(LargeStringArray, col, str, hashes_buffer, random_state);
}
_ => {
// This is internal because we should have caught this before.
return Err(DataFusionError::Internal(
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/type_coercion.rs
Expand Up @@ -196,7 +196,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
| Float64
),
Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)),
Utf8 => true,
Utf8 | LargeUtf8 => true,
_ => false,
}
}
Expand Down