Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ enum EvalMethod {
///
/// CASE WHEN condition THEN column [ELSE NULL] END
InfallibleExprOrNull,
/// This is a specialization for a specific use case where we can take a fast path
/// if there is just one when/then pair and both the `then` and `else` expressions
/// are literal values.
///
/// CASE WHEN condition THEN literal ELSE literal END
ScalarOrScalar,
}

/// The CASE expression is similar to a series of nested if/else and there are two forms that
Expand Down Expand Up @@ -140,6 +145,12 @@ impl CaseExpr {
&& else_expr.is_none()
{
EvalMethod::InfallibleExprOrNull
} else if when_then_expr.len() == 1
&& when_then_expr[0].1.as_any().is::<Literal>()
&& else_expr.is_some()
&& else_expr.as_ref().unwrap().as_any().is::<Literal>()
{
EvalMethod::ScalarOrScalar
} else {
EvalMethod::NoExpression
};
Expand Down Expand Up @@ -344,6 +355,38 @@ impl CaseExpr {
internal_err!("predicate did not evaluate to an array")
}
}

/// Fast path for `CASE WHEN condition THEN literal ELSE literal END`.
///
/// Evaluates the single WHEN predicate against `batch`, evaluates the THEN and
/// ELSE expressions once each as one-element scalars, and combines them with
/// `zip` using the predicate as the selection mask.
///
/// The result is always returned as `ColumnarValue::Array`, even though both
/// branch values are scalars.
///
/// # Errors
///
/// Returns an error if the return type cannot be resolved, if evaluating any of
/// the expressions fails, if the WHEN result cannot be viewed as a
/// `BooleanArray`, or if `zip` fails.
///
/// # Panics
///
/// Panics if `self.else_expr` is `None`. `EvalMethod::ScalarOrScalar` is only
/// selected when an ELSE expression is present, so this is treated as an
/// invariant here.
fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;

// Evaluate the WHEN predicate and materialize it as an array with
// `batch.num_rows()` entries, then view it as a boolean array.
let when_value = self.when_then_expr[0].0.evaluate(batch)?;
let when_value = when_value.into_array(batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|e| {
DataFusionError::Context(
"WHEN expression did not return a BooleanArray".to_string(),
Box::new(e),
)
})?;

// Treat NULL predicate results as false so those rows take the ELSE value.
// When there are no nulls the mask is borrowed as-is to avoid building a new array.
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
_ => Cow::Owned(prep_null_mask_filter(when_value)),
};

// Evaluate the THEN expression and wrap it as a one-element `Scalar`
// rather than expanding it to the full batch length.
let then_value = self.when_then_expr[0].1.evaluate(batch)?;
let then_value = Scalar::new(then_value.into_array(1)?);

// Keep the ELSE value's data type consistent with the CASE return type by
// casting it. If `try_cast` returns an error, the error is discarded and the
// original, uncast expression is used instead.
let e = self.else_expr.as_ref().unwrap();
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
.unwrap_or_else(|_| Arc::clone(e));
let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);

// Build the output from the mask: THEN value where it is true, ELSE value otherwise.
Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the input is a `ColumnarValue::Scalar`, shouldn't the output also be a `ColumnarValue::Scalar` (rather than a `ColumnarValue::Array`)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the output will be an array containing values based on two scalar arguments.

SELECT CASE WHEN a > 2 THEN 'even' ELSE 'odd' END FROM foo
----
odd
even
even

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it would make sense (in a separate PR) to produce a dictionary array in this case, since it will only ever contain two distinct values? 🤔

}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -406,6 +449,7 @@ impl PhysicalExpr for CaseExpr {
// Specialization for CASE WHEN expr THEN column [ELSE NULL] END
self.case_column_or_null(batch)
}
EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
}
}

Expand Down
60 changes: 59 additions & 1 deletion datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# create test data
statement ok
create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6);
create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6), (null, null), (6, null), (null, 7);

# CASE WHEN with condition
query T
Expand All @@ -26,6 +26,9 @@ SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo
one
three
?
?
?
?

# CASE WHEN with no condition
query I
Expand All @@ -34,6 +37,9 @@ SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo
2
3
5
NULL
6
7

# column or explicit null
query I
Expand All @@ -42,6 +48,9 @@ SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo
NULL
4
6
NULL
NULL
7

# column or implicit null
query I
Expand All @@ -50,3 +59,52 @@ SELECT CASE WHEN a > 2 THEN b END FROM foo
NULL
4
6
NULL
NULL
7

# scalar or scalar (string)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a test where both arguments are scalars (like `CASE WHEN 1 > 2 THEN 'true' ELSE 'false' END`)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

query T
SELECT CASE WHEN a > 2 THEN 'even' ELSE 'odd' END FROM foo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like NULL handling in this specialized implementation is not tested; we could add a (NULL, NULL) row to the `foo` table.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I have added this.

----
odd
even
even
odd
even
odd

# scalar or scalar (int)
query I
SELECT CASE WHEN a > 2 THEN 1 ELSE 0 END FROM foo
----
0
1
1
0
1
0

# predicate binary expression with scalars (does not make much sense because the expression in
# this case is always false, so this expression could be rewritten as a literal 0 during planning)
query I
SELECT CASE WHEN 1 > 2 THEN 1 ELSE 0 END FROM foo
----
0
0
0
0
0
0

# predicate using boolean literal (does not make much sense because the expression in
# this case is always false, so this expression could be rewritten as a literal 0 during planning)
query I
SELECT CASE WHEN false THEN 1 ELSE 0 END FROM foo
----
0
0
0
0
0
0