Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions datafusion/core/tests/sql/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use datafusion::logical_plan::{provider_as_source, LogicalPlanBuilder, UNNAMED_TABLE};
use datafusion::test_util::scan_empty;
use datafusion_expr::when;
use tempfile::TempDir;

use super::*;
Expand Down Expand Up @@ -220,6 +221,48 @@ async fn preserve_nullability_on_projection() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn project_cast_dictionary() {
let ctx = SessionContext::new();

let host: DictionaryArray<Int32Type> = vec![Some("host1"), None, Some("host2")]
.into_iter()
.collect();

let batch = RecordBatch::try_from_iter(vec![("host", Arc::new(host) as _)]).unwrap();

let t = MemTable::try_new(batch.schema(), vec![vec![batch]]).unwrap();

// Note that `host` is a dictionary array but `lit("")` is a DataType::Utf8 that needs to be cast
let expr = when(col("host").is_null(), lit(""))
.otherwise(col("host"))
.unwrap();

let projection = None;
let builder = LogicalPlanBuilder::scan(
"cpu_load_short",
provider_as_source(Arc::new(t)),
projection,
)
.unwrap();

let logical_plan = builder.project(vec![expr]).unwrap().build().unwrap();

let physical_plan = ctx.create_physical_plan(&logical_plan).await.unwrap();
let actual = collect(physical_plan, ctx.task_ctx()).await.unwrap();

let expected = vec![
"+------------------------------------------------------------------------------------+",
"| CASE WHEN #cpu_load_short.host IS NULL THEN Utf8(\"\") ELSE #cpu_load_short.host END |",
"+------------------------------------------------------------------------------------+",
"| host1 |",
"| |",
"| host2 |",
"+------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
}

#[tokio::test]
async fn projection_on_memory_scan() -> Result<()> {
let schema = Schema::new(vec![
Expand Down
108 changes: 91 additions & 17 deletions datafusion/physical-expr/src/expressions/try_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ impl PhysicalExpr for TryCastExpr {
&array,
&self.cast_type,
)?)),
ColumnarValue::Scalar(scalar)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the actual code change -- the rest of the PR is tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this prevent this from using the scalar comparison kernels, effectively reversing #2808

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 that is a good point -- perhaps I'll have a go at a proper fix for #2874

if matches!(self.cast_type, DataType::Dictionary(_, _)) =>
{
// ScalarValues do not preserve dictionary encoding
// (so they don't survive the round trip),
// https://github.com/apache/arrow-datafusion/issues/2874
// Until that is fixed, "unpack" the ColumnarValue here
let array = scalar.to_array_of_size(batch.num_rows());
Ok(ColumnarValue::Array(kernels::cast::cast(
&array,
&self.cast_type,
)?))
}
ColumnarValue::Scalar(scalar) => {
let scalar_array = scalar.to_array();
let cast_array = kernels::cast::cast(&scalar_array, &self.cast_type)?;
Expand Down Expand Up @@ -119,8 +132,8 @@ mod tests {
use super::*;
use crate::expressions::col;
use arrow::array::{
BasicDecimalArray, DecimalArray, DecimalBuilder, StringArray,
Time64NanosecondArray,
as_string_array, BasicDecimalArray, DecimalArray, DecimalBuilder,
DictionaryArray, StringArray, Time64NanosecondArray,
};
use arrow::util::decimal::{BasicDecimal, Decimal128};
use arrow::{
Expand Down Expand Up @@ -185,10 +198,13 @@ mod tests {
// 3. evaluate the expression
// 4. verify that the resulting expression is of type B
// 5. verify that the resulting values are downcastable and correct
//
// $VALUE_FN is an expression (like `result.value`) that extracts the value at index `i`
macro_rules! generic_test_cast {
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{
($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr, $VALUE_FN:expr) => {{
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed to add a way to extract the value of the result at element i because it is different for a DictionaryArray

let schema = Schema::new(vec![Field::new("a", $A_TYPE, true)]);
let a = $A_ARRAY::from($A_VEC);
let a = $A_ARRAY::from_iter($A_VEC);

let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;

Expand Down Expand Up @@ -222,7 +238,7 @@ mod tests {
// verify that the result itself is correct
for (i, x) in $VEC.iter().enumerate() {
match x {
Some(x) => assert_eq!(result.value(i), *x),
Some(x) => assert_eq!($VALUE_FN(result, i), *x),
None => assert!(!result.is_valid(i)),
}
}
Expand Down Expand Up @@ -396,7 +412,8 @@ mod tests {
Some(convert(3)),
Some(convert(4)),
Some(convert(5)),
]
],
|result: &DecimalArray, i| result.value(i)
);

// int16
Expand All @@ -413,7 +430,8 @@ mod tests {
Some(convert(3)),
Some(convert(4)),
Some(convert(5)),
]
],
|result: &DecimalArray, i| result.value(i)
);

// int32
Expand All @@ -430,7 +448,8 @@ mod tests {
Some(convert(3)),
Some(convert(4)),
Some(convert(5)),
]
],
|result: &DecimalArray, i| result.value(i)
);

// int64
Expand All @@ -447,7 +466,8 @@ mod tests {
Some(convert(3)),
Some(convert(4)),
Some(convert(5)),
]
],
|result: &DecimalArray, i| result.value(i)
);

// int64 to different scale
Expand All @@ -464,7 +484,8 @@ mod tests {
Some(convert(300)),
Some(convert(400)),
Some(convert(500)),
]
],
|result: &DecimalArray, i| result.value(i)
);

// float32
Expand All @@ -481,7 +502,8 @@ mod tests {
Some(convert(300)),
Some(convert(112)),
Some(convert(550)),
]
],
|result: &DecimalArray, i| result.value(i)
);

// float64
Expand All @@ -498,7 +520,8 @@ mod tests {
Some(convert(30000)),
Some(convert(11234)),
Some(convert(55000)),
]
],
|result: &DecimalArray, i| result.value(i)
);
Ok(())
}
Expand All @@ -517,7 +540,8 @@ mod tests {
Some(3_u32),
Some(4_u32),
Some(5_u32)
]
],
|result: &UInt32Array, i| result.value(i)
);
Ok(())
}
Expand All @@ -530,7 +554,8 @@ mod tests {
vec![1, 2, 3, 4, 5],
StringArray,
DataType::Utf8,
vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]
vec![Some("1"), Some("2"), Some("3"), Some("4"), Some("5")],
|result: &StringArray, i| result.value(i).to_string()
);
Ok(())
}
Expand All @@ -540,10 +565,58 @@ mod tests {
generic_test_cast!(
StringArray,
DataType::Utf8,
vec!["a", "2", "3", "b", "5"],
vec![Some("a"), Some("2"), Some("3"), Some("b"), Some("5")],
Int32Array,
DataType::Int32,
vec![None, Some(2), Some(3), None, Some(5)]
vec![None, Some(2), Some(3), None, Some(5)],
|result: &Int32Array, i| result.value(i)
);
Ok(())
}

#[test]
fn test_try_cast_string_dict_to_utf8() -> Result<()> {
let dict_type = DataType::Dictionary(
Box::new(DataType::Int32), // key_type
Box::new(DataType::Utf8), // value_type
);

// define a type alias so we can use the macro
type DictArrayType = DictionaryArray<Int32Type>;

generic_test_cast!(
DictArrayType,
dict_type,
vec![Some("a"), Some("b")],
StringArray,
DataType::Utf8,
vec![Some("a"), Some("b")],
|result: &StringArray, i| result.value(i).to_string()
);
Ok(())
}

#[allow(clippy::redundant_clone)]
#[test]
fn test_try_cast_utf8_to_string_dict() -> Result<()> {
let dict_type = DataType::Dictionary(
Box::new(DataType::Int32), // key_type
Box::new(DataType::Utf8), // value_type
);

// define a type alias so we can use the macro
type DictArrayType = DictionaryArray<Int32Type>;

generic_test_cast!(
StringArray,
DataType::Utf8,
vec![Some("a"), Some("b")],
DictArrayType,
dict_type.clone(),
vec![Some("a"), Some("b")],
|result: &DictArrayType, i| {
as_string_array(result.values()).value(i).to_string()
}
);
Ok(())
}
Expand All @@ -562,7 +635,8 @@ mod tests {
original.clone(),
TimestampNanosecondArray,
DataType::Timestamp(TimeUnit::Nanosecond, None),
expected
expected,
|result: &TimestampNanosecondArray, i| result.value(i)
);
Ok(())
}
Expand Down