Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 51 additions & 30 deletions core/src/main/java/org/apache/calcite/rex/RexSimplify.java
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,6 @@ private RexNode simplifyCase(RexCall call) {
}
return last;
}
trueFalse:
if (call.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) {
// Optimize CASE where every branch returns constant true or constant
// false.
Expand Down Expand Up @@ -448,45 +447,67 @@ private RexNode simplifyCase(RexCall call) {
return disjunction;
}
}
// 2) Another simplification
// CASE
// WHEN p1 THEN TRUE
// WHEN p2 THEN FALSE
// WHEN p3 THEN TRUE
// ELSE FALSE
// END
// if p1...pn cannot be nullable
for (Ord<Pair<RexNode, RexNode>> pair : Ord.zip(pairs)) {
if (pair.e.getKey().getType().isNullable()) {
break trueFalse;
}
if (!pair.e.getValue().isAlwaysTrue()
&& !pair.e.getValue().isAlwaysFalse()
&& (!unknownAsFalse || !RexUtil.isNull(pair.e.getValue()))) {
break trueFalse;

if (pairs.size() < 20) {
// 2) Simplify generic boolean CASE into AND/OR:
// CASE
// WHEN p1 THEN e1
// WHEN p2 THEN e2
// ELSE <else>
// END
// (p1 and e1) or (!p1 and p2 and e2) or (!p1 and !p2 and <else>)
// p1..pn not nullable
// if e1..en, <else> are TRUE/FALSE they may
// simplify the final epxression
// p1..pn being TRUE/FALSE were already eliminated by now by casePairs

// We start the NOT predicates with a TRUE
// It will be eliminated by simplification
// It has the role of an empty initial expression
boolean successAnyExpression = true;

RexNode notPredicates = rexBuilder.makeLiteral(true);
final List<RexNode> terms = new ArrayList<>();
for (Ord<Pair<RexNode, RexNode>> pair : Ord.zip(pairs)) {
if (pair.e.getKey().getType().isNullable()) {
successAnyExpression = false;
break;
}
final RexNode toAdd = RexUtil.composeConjunction(rexBuilder,
ImmutableList.<RexNode>of(
notPredicates,
pair.e.getKey(),
pair.e.getValue()),
false);
terms.add(toAdd);
// If the value is TRUE we can skip the key from the NOT list:
// p1 or (!p1 and p2) => p1 or p2
if (!pair.e.getValue().isAlwaysTrue()) {
notPredicates = RexUtil.composeConjunction(rexBuilder,
ImmutableList.<RexNode>of(
notPredicates,
RexUtil.not(pair.e.getKey())),
false);
notPredicates = simplify(notPredicates);
}
}
}
final List<RexNode> terms = new ArrayList<>();
final List<RexNode> notTerms = new ArrayList<>();
for (Ord<Pair<RexNode, RexNode>> pair : Ord.zip(pairs)) {
if (pair.e.getValue().isAlwaysTrue()) {
terms.add(RexUtil.andNot(rexBuilder, pair.e.getKey(), notTerms));
} else {
notTerms.add(pair.e.getKey());

if (successAnyExpression) {
final RexNode disjunction = simplify(RexUtil.composeDisjunction(rexBuilder, terms));
if (!call.getType().equals(disjunction.getType())) {
return rexBuilder.makeCast(call.getType(), disjunction);
}
return disjunction;
}
}
final RexNode disjunction = RexUtil.composeDisjunction(rexBuilder, terms);
if (!call.getType().equals(disjunction.getType())) {
return rexBuilder.makeCast(call.getType(), disjunction);
}
return disjunction;
}
if (newOperands.equals(operands)) {
return call;
}
return call.clone(call.getType(), newOperands);
}


/** Given "CASE WHEN p1 THEN v1 ... ELSE e END"
* returns [(p1, v1), ..., (true, e)]. */
private static List<Pair<RexNode, RexNode>> casePairs(RexBuilder rexBuilder,
Expand Down
152 changes: 149 additions & 3 deletions core/src/test/java/org/apache/calcite/test/RexProgramTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1100,12 +1100,12 @@ private void checkExponentialCnf(int n) {

// case: remove false branches
checkSimplify(case_(eq(bRef, cRef), dRef, falseLiteral, aRef, eRef),
"CASE(=(?0.b, ?0.c), ?0.d, ?0.e)");
"OR(AND(=(?0.b, ?0.c), ?0.d), AND(<>(?0.b, ?0.c), ?0.e))");

// case: true branches become the last branch
checkSimplify(
case_(eq(bRef, cRef), dRef, trueLiteral, aRef, eq(cRef, dRef), eRef, cRef),
"CASE(=(?0.b, ?0.c), ?0.d, ?0.a)");
"OR(AND(=(?0.b, ?0.c), ?0.d), AND(<>(?0.b, ?0.c), ?0.a))");

// case: singleton
checkSimplify(case_(trueLiteral, aRef, eq(cRef, dRef), eRef, cRef), "?0.a");
Expand All @@ -1117,7 +1117,7 @@ private void checkExponentialCnf(int n) {
// case: trailing false and null, no simplification
checkSimplify2(
case_(aRef, trueLiteral, bRef, trueLiteral, cRef, falseLiteral, unknownLiteral),
"CASE(?0.a, true, ?0.b, true, ?0.c, false, null)",
"OR(?0.a, ?0.b, AND(null, NOT(?0.c)))",
"CAST(OR(?0.a, ?0.b)):BOOLEAN");

// case: form an AND of branches that return true
Expand Down Expand Up @@ -1318,6 +1318,152 @@ private void checkExponentialCnf(int n) {
assertThat(result, is(caseNode));
}

@Test public void testSimplifyCaseAsAndOr() throws Exception {
final RelDataType intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
final RelDataType strType = typeFactory.createSqlType(SqlTypeName.VARCHAR, 255);
final RelDataType rowType = typeFactory.builder()
.add("i", intType)
.add("s", strType)
.build();

final RexDynamicParam range = rexBuilder.makeDynamicParam(rowType, 0);
final RexNode intRef = rexBuilder.makeFieldAccess(range, 0);
final RexNode strRef = rexBuilder.makeFieldAccess(range, 1);
final RexNode val3 = rexBuilder.makeLiteral(3, intType, false);
final RexNode val5 = rexBuilder.makeLiteral(5, intType, false);
final RexNode val7 = rexBuilder.makeLiteral(7, intType, false);
final RexNode val10 = rexBuilder.makeLiteral(10, intType, false);
final RexNode val11 = rexBuilder.makeLiteral(11, intType, false);
final RexNode val13 = rexBuilder.makeLiteral(13, intType, false);
final RexNode val25 = rexBuilder.makeLiteral(25, intType, false);
final RexNode val50 = rexBuilder.makeLiteral(50, intType, false);
final RexNode val100 = rexBuilder.makeLiteral(100, intType, false);
final RexNode val200 = rexBuilder.makeLiteral(200, intType, false);
final RexNode valFoo = rexBuilder.makeLiteral("foo");
final RexNode valBar = rexBuilder.makeLiteral("bar");
final RexNode valFar = rexBuilder.makeLiteral("far");
final RexNode valCar = rexBuilder.makeLiteral("car");

// CASE
// WHEN i = 3 THEN TRUE
// WHEN i = 5 THEN TRUE
// ELSE FALSE
// END
checkSimplifyFilter(
case_(
eq(intRef, val3), trueLiteral,
eq(intRef, val5), trueLiteral,
falseLiteral),
"OR(=(?0.i, 3), =(?0.i, 5))");

// CASE
// WHEN $1 = 3 THEN TRUE
// WHEN $1 = 5 THEN TRUE
// ELSE TRUE
// END
checkSimplifyFilter(
case_(
eq(intRef, val3), trueLiteral,
eq(intRef, val5), trueLiteral,
trueLiteral),
"true");

// CASE
// WHEN $1 = 3 THEN $2 = foo
// WHEN $1 = 5 THEN $2= bar
// ELSE TRUE
// END
checkSimplifyFilter(
case_(
eq(intRef, val3), eq(strRef, valFoo),
eq(intRef, val5), eq(strRef, valBar),
trueLiteral),
"OR("
+ "AND(=(?0.i, 3), =(?0.s, 'foo')), "
+ "AND(<>(?0.i, 3), =(?0.i, 5), =(?0.s, 'bar')), "
+ "AND(<>(?0.i, 3), <>(?0.i, 5)))");

// CASE
// WHEN $1 = 3 THEN $2 = foo
// WHEN $1 = 5 THEN $2= bar
// ELSE FALSE
// END
checkSimplifyFilter(
case_(
eq(intRef, val3), eq(strRef, valFoo),
eq(intRef, val5), eq(strRef, valBar),
falseLiteral),
"OR("
+ "AND(=(?0.i, 3), =(?0.s, 'foo')), "
+ "AND(<>(?0.i, 3), =(?0.i, 5), =(?0.s, 'bar')))");

// CASE
// WHEN $1 = 3 THEN $2 = foo
// WHEN $1 > 10 THEN
// CASE
// WHEN $i = 11 THEN $2 = bar
// WHEN $i = 13 THEN $2 = far
// ELSE ...
// WHEN $1 < 3 THEN $2 = car
// ELSE ...
// END
checkSimplifyFilter(
case_(
eq(intRef, val3), eq(strRef, valFoo),
gt(intRef, val10), case_(
eq(intRef, val11), eq(strRef, valBar),
eq(intRef, val13), eq(strRef, valFar),
falseLiteral),
lt(intRef, val3), eq(strRef, valCar),
trueLiteral),
"OR("
+ "AND(=(?0.i, 3), =(?0.s, 'foo')), "
+ "AND(<>(?0.i, 3), >(?0.i, 10), OR("
+ "AND(=(?0.i, 11), =(?0.s, 'bar')), "
+ "AND(<>(?0.i, 11), =(?0.i, 13), =(?0.s, 'far')))), "
+ "AND(<>(?0.i, 3), <(?0.i, 3), =(?0.s, 'car')), "
+ "AND(<>(?0.i, 3), <=(?0.i, 10), >=(?0.i, 3)))");

// CASE
// WHEN $1 = 3 THEN $2 = foo
// WHEN CASE
// WHEN $i >100 THEN $1 < 200
// WHEN $i < 50 THEN $1 > 25
// ELSE ...
// END THEN $2=bar
// WHEN $1 < 3 THEN $2 = car
// ELSE ...
// END
checkSimplifyFilter(
case_(
eq(intRef, val3), eq(strRef, valFoo),
case_(
gt(intRef, val100), lt(intRef, val200),
lt(intRef, val50), gt(intRef, val25),
trueLiteral), eq(strRef, valCar),
lt(intRef, val3), eq(strRef, valCar),
falseLiteral),
"OR("
+ "AND(=(?0.i, 3), =(?0.s, 'foo')),"
+ " AND(<>(?0.i, 3), OR("
+ "AND(>(?0.i, 100), <(?0.i, 200)),"
+ " AND(<=(?0.i, 100), <(?0.i, 50), >(?0.i, 25)),"
+ " AND(<=(?0.i, 100), >=(?0.i, 50))"
+ "), =(?0.s, 'car')),"
+ " AND(<>(?0.i, 3), OR("
+ "<=(?0.i, 100), "
+ ">=(?0.i, 200)"
+ "), OR("
+ ">(?0.i, 100), "
+ ">=(?0.i, 50), "
+ "<=(?0.i, 25)"
+ "), OR("
+ ">(?0.i, 100), "
+ "<(?0.i, 50)), "
+ "<(?0.i, 3), "
+ "=(?0.s, 'car')))");
}

@Test public void testSimplifyAnd() {
RelDataType booleanNotNullableType =
typeFactory.createTypeWithNullability(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ where case when (sal = 1000) then
<Resource name="planAfter">
<![CDATA[
LogicalProject(SAL=[$5])
LogicalFilter(condition=[CASE(=($5, 1000), =($5, 1000), =($5, 2000))])
LogicalFilter(condition=[OR(=($5, 1000), AND(<>($5, 1000), =($5, 2000)))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand Down Expand Up @@ -5631,7 +5631,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
<![CDATA[
LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
LogicalProject($f0=[1])
LogicalFilter(condition=[=($7, 10)])
LogicalFilter(condition=[AND(<>($7, 20), =($7, 10))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand All @@ -5656,7 +5656,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
<![CDATA[
LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
LogicalProject($f0=[1])
LogicalFilter(condition=[=($7, 10)])
LogicalFilter(condition=[OR(AND(<>($7, 20), =($7, 10)), AND(<>($7, 20), null))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand All @@ -5683,7 +5683,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
<![CDATA[
LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
LogicalProject($f0=[1])
LogicalFilter(condition=[OR(=($7, 30), =($7, 10))])
LogicalFilter(condition=[OR(=($7, 30), AND(<>($7, 20), =($7, 10)))])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,7 @@ from dept]]>
</Resource>
<Resource name="plan">
<![CDATA[
LogicalProject(NAME=[$1], EXPR$1=[CASE(=($2, 0), false, IS NOT NULL($6), true, <($3, $2), null, false)])
LogicalProject(NAME=[$1], EXPR$1=[OR(AND(<>($2, 0), IS NOT NULL($6)), AND(<>($2, 0), <($3, $2), null))])
LogicalJoin(condition=[=($4, $5)], joinType=[left])
LogicalProject($f0=[$0], $f1=[$1], $f2=[$2], $f3=[$3], $f4=[$0])
LogicalJoin(condition=[true], joinType=[inner])
Expand Down