diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java index 8ba2ca6dcb82..39428dc20d7a 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ConcatOperatorConversion.java @@ -28,7 +28,7 @@ public class ConcatOperatorConversion extends DirectOperatorConversion { - private static final SqlFunction SQL_FUNCTION = OperatorConversions + public static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("CONCAT") .operandTypeChecker(OperandTypes.SAME_VARIADIC) .returnTypeCascadeNullable(SqlTypeName.VARCHAR) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java index 40a13479b8c1..acdbc1029caa 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/TextcatOperatorConversion.java @@ -28,7 +28,7 @@ public class TextcatOperatorConversion extends DirectOperatorConversion { - private static final SqlFunction SQL_FUNCTION = OperatorConversions + public static final SqlFunction SQL_FUNCTION = OperatorConversions .operatorBuilder("textcat") .operandTypes(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER) .requiredOperandCount(2) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java index 06c73fa9d3bb..78d8ca7d0982 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/CalciteRulesManager.java @@ 
-31,6 +31,7 @@ import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.plan.volcano.AbstractConverter; +import org.apache.calcite.plan.volcano.VolcanoPlanner; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider; @@ -55,13 +56,16 @@ import org.apache.druid.sql.calcite.rule.DruidTableScanRule; import org.apache.druid.sql.calcite.rule.ExtensionCalciteRuleProvider; import org.apache.druid.sql.calcite.rule.FilterDecomposeCoalesceRule; +import org.apache.druid.sql.calcite.rule.FilterDecomposeConcatRule; import org.apache.druid.sql.calcite.rule.FilterJoinExcludePushToChildRule; +import org.apache.druid.sql.calcite.rule.FlattenConcatRule; import org.apache.druid.sql.calcite.rule.ProjectAggregatePruneUnusedCallRule; import org.apache.druid.sql.calcite.rule.SortCollapseRule; import org.apache.druid.sql.calcite.rule.logical.DruidLogicalRules; import org.apache.druid.sql.calcite.run.EngineFeature; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Set; @@ -88,7 +92,7 @@ public class CalciteRulesManager * 3) {@link CoreRules#JOIN_COMMUTE}, {@link JoinPushThroughJoinRule#RIGHT}, {@link JoinPushThroughJoinRule#LEFT}, * and {@link CoreRules#FILTER_INTO_JOIN}, which are part of {@link #FANCY_JOIN_RULES}. * 4) {@link CoreRules#PROJECT_FILTER_TRANSPOSE} because PartialDruidQuery would like to have the Project on top of the Filter - - * this rule could create a lot of non-usefull plans. + * this rule could create a lot of non-useful plans. 
*/ private static final List BASE_RULES = ImmutableList.of( @@ -228,50 +232,87 @@ public CalciteRulesManager(final Set extensionCalc public List programs(final PlannerContext plannerContext) { final boolean isDebug = plannerContext.queryContext().isDebug(); - - // Program that pre-processes the tree before letting the full-on VolcanoPlanner loose. - final List prePrograms = new ArrayList<>(); - prePrograms.add(new LoggingProgram("Start", isDebug)); - prePrograms.add(Programs.subQuery(DefaultRelMetadataProvider.INSTANCE)); - prePrograms.add(new LoggingProgram("Finished subquery program", isDebug)); - prePrograms.add(DecorrelateAndTrimFieldsProgram.INSTANCE); - prePrograms.add(new LoggingProgram("Finished decorrelate and trim fields program", isDebug)); - prePrograms.add(buildCoalesceProgram()); - prePrograms.add(new LoggingProgram("Finished coalesce program", isDebug)); - prePrograms.add(buildReductionProgram(plannerContext)); - prePrograms.add(new LoggingProgram("Finished expression reduction program", isDebug)); - - final Program preProgram = Programs.sequence(prePrograms.toArray(new Program[0])); + final Program druidPreProgram = buildPreProgram(plannerContext, true); + final Program bindablePreProgram = buildPreProgram(plannerContext, false); return ImmutableList.of( Programs.sequence( - preProgram, + druidPreProgram, Programs.ofRules(druidConventionRuleSet(plannerContext)), new LoggingProgram("After Druid volcano planner program", isDebug) ), Programs.sequence( - preProgram, + bindablePreProgram, Programs.ofRules(bindableConventionRuleSet(plannerContext)), new LoggingProgram("After bindable volcano planner program", isDebug) ), Programs.sequence( - preProgram, + druidPreProgram, Programs.ofRules(logicalConventionRuleSet(plannerContext)), new LoggingProgram("After logical volcano planner program", isDebug) ) ); } - private Program buildReductionProgram(final PlannerContext plannerContext) + /** + * Build the program that runs prior to the cost-based {@link 
VolcanoPlanner}. + * + * @param plannerContext planner context + * @param isDruid whether this is a Druid program + */ + private Program buildPreProgram(final PlannerContext plannerContext, final boolean isDruid) + { + final boolean isDebug = plannerContext.queryContext().isDebug(); + + // Program that pre-processes the tree before letting the full-on VolcanoPlanner loose. + final List prePrograms = new ArrayList<>(); + prePrograms.add(new LoggingProgram("Start", isDebug)); + prePrograms.add(Programs.subQuery(DefaultRelMetadataProvider.INSTANCE)); + prePrograms.add(new LoggingProgram("Finished subquery program", isDebug)); + prePrograms.add(DecorrelateAndTrimFieldsProgram.INSTANCE); + prePrograms.add(new LoggingProgram("Finished decorrelate and trim fields program", isDebug)); + prePrograms.add(buildReductionProgram(plannerContext, isDruid)); + prePrograms.add(new LoggingProgram("Finished expression reduction program", isDebug)); + + return Programs.sequence(prePrograms.toArray(new Program[0])); + } + + /** + * Builds an expression reduction program using {@link #REDUCTION_RULES} (built-in to Calcite) plus some + * Druid-specific rules. + */ + private Program buildReductionProgram(final PlannerContext plannerContext, final boolean isDruid) { - List hepRules = new ArrayList(REDUCTION_RULES); + final List hepRules = new ArrayList<>(); + + if (isDruid) { + // Must run before REDUCTION_RULES, since otherwise ReduceExpressionsRule#pushPredicateIntoCase may + // make it impossible to convert to COALESCE. + hepRules.add(new CaseToCoalesceRule()); + hepRules.add(new CoalesceLookupRule()); + + // Flatten calls to CONCAT, which happen easily with the || operator since it only accepts two arguments. + hepRules.add(new FlattenConcatRule()); + + // Decompose filters on COALESCE to promote more usage of indexes. + hepRules.add(new FilterDecomposeCoalesceRule()); + } + + // Calcite's builtin reduction rules. 
+ hepRules.addAll(REDUCTION_RULES); + + if (isDruid) { + // Decompose filters on CONCAT to promote more usage of indexes. Runs after REDUCTION_RULES because + // this rule benefits from reduction of effectively-literal calls to actual literals. + hepRules.add(new FilterDecomposeConcatRule()); + } + // Apply CoreRules#FILTER_INTO_JOIN early to avoid exploring less optimal plans. - if (plannerContext.getJoinAlgorithm().requiresSubquery()) { + if (isDruid && plannerContext.getJoinAlgorithm().requiresSubquery()) { hepRules.add(CoreRules.FILTER_INTO_JOIN); } - return buildHepProgram( - hepRules - ); + + return buildHepProgram(hepRules); } private static class LoggingProgram implements Program @@ -372,7 +413,13 @@ public List baseRuleSet(final PlannerContext plannerContext) return rules.build(); } - private static Program buildHepProgram(final Iterable rules) + /** + * Build a {@link HepProgram} to apply rules mechanically as part of {@link #buildPreProgram}. Rules are applied + * one-by-one. + * + * @param rules rules to apply + */ + private static Program buildHepProgram(final Collection rules) { final HepProgramBuilder builder = HepProgram.builder(); builder.addMatchLimit(CalciteRulesManager.HEP_DEFAULT_MATCH_LIMIT); @@ -382,20 +429,6 @@ private static Program buildHepProgram(final Iterable rule return Programs.of(builder.build(), true, DefaultRelMetadataProvider.INSTANCE); } - /** - * Program that performs various manipulations related to COALESCE. - */ - private static Program buildCoalesceProgram() - { - return buildHepProgram( - ImmutableList.of( - new CaseToCoalesceRule(), - new CoalesceLookupRule(), - new FilterDecomposeCoalesceRule() - ) - ); - } - /** * Based on Calcite's Programs.DecorrelateProgram and Programs.TrimFieldsProgram, which are private and only * accessible through Programs.standard (which we don't want, since it also adds Enumerable rules). 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterDecomposeConcatRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterDecomposeConcatRule.java new file mode 100644 index 000000000000..1a28392a5b4e --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/FilterDecomposeConcatRule.java @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.calcite.rule; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multiset; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.rules.SubstitutionRule; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.common.config.NullHandling; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Transform calls like [CONCAT(x, '-', y) = 'a-b'] => [x = 'a' AND y = 'b']. + */ +public class FilterDecomposeConcatRule extends RelOptRule implements SubstitutionRule +{ + public FilterDecomposeConcatRule() + { + super(operand(Filter.class, any())); + } + + @Override + public void onMatch(RelOptRuleCall call) + { + final Filter oldFilter = call.rel(0); + final DecomposeConcatShuttle shuttle = new DecomposeConcatShuttle( + oldFilter.getCluster().getRexBuilder()); + final RexNode newCondition = oldFilter.getCondition().accept(shuttle); + + //noinspection ObjectEquality + if (newCondition != oldFilter.getCondition()) { + call.transformTo( + call.builder() + .push(oldFilter.getInput()) + .filter(newCondition).build() + ); + + call.getPlanner().prune(oldFilter); + } + } + + /** + * Shuttle that decomposes predicates on top of CONCAT calls. 
+ */ + static class DecomposeConcatShuttle extends RexShuttle + { + private final RexBuilder rexBuilder; + + DecomposeConcatShuttle(final RexBuilder rexBuilder) + { + this.rexBuilder = rexBuilder; + } + + @Override + public RexNode visitCall(final RexCall call) + { + final RexNode newCall; + final boolean negate; + + if (call.isA(SqlKind.EQUALS) || call.isA(SqlKind.NOT_EQUALS)) { + // Convert: [CONCAT(x, '-', y) = 'a-b'] => [x = 'a' AND y = 'b'] + // Convert: [CONCAT(x, '-', y) <> 'a-b'] => [NOT (x = 'a' AND y = 'b')] + negate = call.isA(SqlKind.NOT_EQUALS); + final RexNode lhs = call.getOperands().get(0); + final RexNode rhs = call.getOperands().get(1); + + if (FlattenConcatRule.isNonTrivialStringConcat(lhs) && RexUtil.isLiteral(rhs, true)) { + newCall = tryDecomposeConcatEquals((RexCall) lhs, rhs, rexBuilder); + } else if (FlattenConcatRule.isNonTrivialStringConcat(rhs) && RexUtil.isLiteral(lhs, true)) { + newCall = tryDecomposeConcatEquals((RexCall) rhs, lhs, rexBuilder); + } else { + newCall = null; + } + } else if ((call.isA(SqlKind.IS_NULL) || call.isA(SqlKind.IS_NOT_NULL)) + && FlattenConcatRule.isNonTrivialStringConcat(Iterables.getOnlyElement(call.getOperands()))) { + negate = call.isA(SqlKind.IS_NOT_NULL); + final RexCall concatCall = (RexCall) Iterables.getOnlyElement(call.getOperands()); + if (NullHandling.sqlCompatible()) { + // Convert: [CONCAT(x, '-', y) IS NULL] => [x IS NULL OR y IS NULL] + newCall = RexUtil.composeDisjunction( + rexBuilder, + Iterables.transform( + concatCall.getOperands(), + operand -> rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, operand) + ) + ); + } else { + // Treat [CONCAT(x, '-', y) IS NULL] as [CONCAT(x, '-', y) = ''] + newCall = tryDecomposeConcatEquals(concatCall, rexBuilder.makeLiteral(""), rexBuilder); + } + } else { + negate = false; + newCall = null; + } + + if (newCall != null) { + // Found a CONCAT comparison to decompose. + return negate ? 
rexBuilder.makeCall(SqlStdOperatorTable.NOT, newCall) : newCall; + } else { + // Didn't find anything interesting. Visit children of original call. + return super.visitCall(call); + } + } + } + + /** + * Convert [CONCAT(x, '-', y) = 'a-b'] => [x = 'a' AND y = 'b']. + * + * @param concatCall the call to concat, i.e. CONCAT(x, '-', y) + * @param matchRexNode the literal being matched, i.e. 'a-b' + * @param rexBuilder rex builder + */ + @Nullable + private static RexNode tryDecomposeConcatEquals( + final RexCall concatCall, + final RexNode matchRexNode, + final RexBuilder rexBuilder + ) + { + final String matchValue = getAsString(matchRexNode); + if (matchValue == null) { + return null; + } + + // We can decompose if all nonliterals are separated by literals, and if each literal appears in the matchValue + // string exactly the number of times that it appears in the call to CONCAT. (In this case, the concatenation can + // be unambiguously reversed.) + final StringBuilder regexBuilder = new StringBuilder(); + final List nonLiterals = new ArrayList<>(); + final Multiset literalCounter = HashMultiset.create(); + boolean expectLiteral = false; // If true, next operand must be a literal. + for (int i = 0; i < concatCall.getOperands().size(); i++) { + final RexNode operand = concatCall.getOperands().get(i); + if (RexUtil.isLiteral(operand, true)) { + final String operandValue = getAsString(operand); + if (operandValue == null || operandValue.isEmpty()) { + return null; + } + + regexBuilder.append(Pattern.quote(operandValue)); + literalCounter.add(operandValue); + expectLiteral = false; + } else { + if (expectLiteral) { + return null; + } + + nonLiterals.add(operand); + regexBuilder.append("(.*)"); + expectLiteral = true; + } + } + + // Verify, using literalCounter, that each literal appears in the matchValue the correct number of times. 
+ for (Multiset.Entry entry : literalCounter.entrySet()) { + final int occurrences = countOccurrences(matchValue, entry.getElement()); + if (occurrences > entry.getCount()) { + // If occurrences > entry.getCount(), the match is ambiguous; consider concat(x, 'x', y) = '2x3x4' + return null; + } else if (occurrences < entry.getCount()) { + return impossibleMatch(nonLiterals, rexBuilder); + } + } + + // Apply the regex to the matchValue to get the expected value of each non-literal. + final Pattern regex = Pattern.compile(regexBuilder.toString(), Pattern.DOTALL); + final Matcher matcher = regex.matcher(matchValue); + if (matcher.matches()) { + final List conditions = new ArrayList<>(nonLiterals.size()); + for (int i = 0; i < nonLiterals.size(); i++) { + final RexNode operand = nonLiterals.get(i); + conditions.add( + rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, + operand, + rexBuilder.makeLiteral(matcher.group(i + 1)) + ) + ); + } + + return RexUtil.composeConjunction(rexBuilder, conditions); + } else { + return impossibleMatch(nonLiterals, rexBuilder); + } + } + + /** + * Generate an expression for the case where matching is impossible. + * + * This expression might be FALSE and might be UNKNOWN depending on whether any of the inputs are null. Use the + * construct "x IS NULL AND UNKNOWN" for each arg x to CONCAT, which is FALSE if x is not null and UNKNOWN if x + * is null. Then OR them all together, so the entire expression is FALSE if all args are not null, and UNKNOWN if any arg is null. + * + * @param nonLiterals non-literal arguments to CONCAT + */ + private static RexNode impossibleMatch(final List nonLiterals, final RexBuilder rexBuilder) + { + if (NullHandling.sqlCompatible()) { + // This expression might be FALSE and might be UNKNOWN depending on whether any of the inputs are null. Use the + // construct "x IS NULL AND UNKNOWN" for each arg x to CONCAT, which is FALSE if x is not null and UNKNOWN if + // x is null. 
Then OR them all together, so the entire expression is FALSE if all args are not null, and + // UNKNOWN if any arg is null. + final RexLiteral unknown = + rexBuilder.makeNullLiteral(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN)); + return RexUtil.composeDisjunction( + rexBuilder, + Iterables.transform( + nonLiterals, + operand -> rexBuilder.makeCall( + SqlStdOperatorTable.AND, + rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, operand), + unknown + ) + ) + ); + } else { + return rexBuilder.makeLiteral(false); + } + } + + /** + * Given a literal (which may be wrapped in a cast), remove the cast call (if any) and read it as a string. + * Returns null if the rex can't be read as a string. + */ + @Nullable + private static String getAsString(final RexNode rexNode) + { + if (!SqlTypeFamily.STRING.contains(rexNode.getType())) { + // We don't expect this to happen, since this method is used when reading from RexNodes that are expected + // to be strings. But if it does (CONCAT operator that accepts non-strings?), return null so we skip the + // optimization. + return null; + } + + // Get matchValue from the matchLiteral (remove cast call if any, then read as string). + final RexNode matchLiteral = RexUtil.removeCast(rexNode); + if (SqlTypeFamily.STRING.contains(matchLiteral.getType())) { + return RexLiteral.stringValue(matchLiteral); + } else if (SqlTypeFamily.NUMERIC.contains(matchLiteral.getType())) { + return String.valueOf(RexLiteral.value(matchLiteral)); + } else { + return null; + } + } + + /** + * Count the number of occurrences of substring in string. Considers overlapping occurrences as multiple occurrences; + * for example the string "--" is counted as appearing twice in "---". 
+ */ + private static int countOccurrences(final String string, final String substring) + { + int count = 0; + int i = -1; + + while ((i = string.indexOf(substring, i + 1)) >= 0) { + count++; + } + + return count; + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/FlattenConcatRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/FlattenConcatRule.java new file mode 100644 index 000000000000..589c590b1076 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/FlattenConcatRule.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.sql.calcite.rule; + +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.rules.SubstitutionRule; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.math.expr.Function; +import org.apache.druid.sql.calcite.expression.builtin.ConcatOperatorConversion; +import org.apache.druid.sql.calcite.expression.builtin.TextcatOperatorConversion; + +import java.util.ArrayList; +import java.util.List; + +/** + * Flattens calls to CONCAT. Useful because otherwise [a || b || c] would get planned as [CONCAT(CONCAT(a, b), c)]. 
+ */ +public class FlattenConcatRule extends RelOptRule implements SubstitutionRule +{ + public FlattenConcatRule() + { + super(operand(RelNode.class, any())); + } + + @Override + public void onMatch(RelOptRuleCall call) + { + final RelNode oldNode = call.rel(0); + final FlattenConcatShuttle shuttle = new FlattenConcatShuttle(oldNode.getCluster().getRexBuilder()); + final RelNode newNode = oldNode.accept(shuttle); + + //noinspection ObjectEquality + if (newNode != oldNode) { + call.transformTo(newNode); + call.getPlanner().prune(oldNode); + } + } + + private static class FlattenConcatShuttle extends RexShuttle + { + private final RexBuilder rexBuilder; + + public FlattenConcatShuttle(RexBuilder rexBuilder) + { + this.rexBuilder = rexBuilder; + } + + @Override + public RexNode visitCall(RexCall call) + { + if (isNonTrivialStringConcat(call)) { + final List newOperands = new ArrayList<>(); + for (final RexNode operand : call.getOperands()) { + if (isNonTrivialStringConcat(operand)) { + // Recursively flatten. We only flatten non-trivial CONCAT calls, because trivial ones (which do not + // reference any inputs) are reduced to constants by ReduceExpressionsRule. + final RexNode visitedOperand = visitCall((RexCall) operand); + + if (isStringConcat(visitedOperand)) { + newOperands.addAll(((RexCall) visitedOperand).getOperands()); + } else { + newOperands.add(visitedOperand); + } + } else if (RexUtil.isNullLiteral(operand, true) && NullHandling.sqlCompatible()) { + return rexBuilder.makeNullLiteral(call.getType()); + } else { + newOperands.add(operand); + } + } + + if (!newOperands.equals(call.getOperands())) { + return rexBuilder.makeCall(ConcatOperatorConversion.SQL_FUNCTION, newOperands); + } else { + return call; + } + } else { + return super.visitCall(call); + } + } + } + + /** + * Whether a rex is a string concatenation operator. All of these end up being converted to + * {@link Function.ConcatFunc}. 
+ */ + static boolean isStringConcat(final RexNode rexNode) + { + if (SqlTypeFamily.STRING.contains(rexNode.getType()) && rexNode instanceof RexCall) { + final SqlOperator operator = ((RexCall) rexNode).getOperator(); + return ConcatOperatorConversion.SQL_FUNCTION.equals(operator) + || TextcatOperatorConversion.SQL_FUNCTION.equals(operator) + || SqlStdOperatorTable.CONCAT.equals(operator); + } else { + return false; + } + } + + /** + * Whether a rex is a string concatenation involving at least one input field. + */ + static boolean isNonTrivialStringConcat(final RexNode rexNode) + { + return isStringConcat(rexNode) && !RelOptUtil.InputFinder.bits(rexNode).isEmpty(); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index 7226811dba04..97f41c425e43 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -161,7 +161,7 @@ public void testMultiValueStringWorksLikeStringGroupByWithFilter() new DefaultDimensionSpec("v0", "_d0", ColumnType.STRING) ) ) - .setDimFilter(equality("v0", "bfoo", ColumnType.STRING)) + .setDimFilter(equality("dim3", "b", ColumnType.STRING)) .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) .setLimitSpec(new DefaultLimitSpec( ImmutableList.of(new OrderByColumnSpec( @@ -248,7 +248,7 @@ public void testMultiValueStringWorksLikeStringScanWithFilter() .dataSource(CalciteTests.DATASOURCE3) .eternityInterval() .virtualColumns(expressionVirtualColumn("v0", "concat(\"dim3\",'foo')", ColumnType.STRING)) - .filters(equality("v0", "bfoo", ColumnType.STRING)) + .filters(equality("dim3", "b", ColumnType.STRING)) .columns(ImmutableList.of("v0")) .context(QUERY_CONTEXT_DEFAULT) .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 2f795f54ed74..232228f5654e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -11886,16 +11886,21 @@ public void testConcat() new Object[]{"abc-abc_abc"} ) ); + } + @Test + public void testConcat2() + { + // Tests flattening CONCAT, and tests reduction of concat('x', 'y') => 'xy' testQuery( - "SELECT CONCAt(dim1, CONCAt(dim2,'x'), m2, 9999, dim1) as dimX FROM foo", + "SELECT CONCAt(dim1, CONCAt(dim2,concat('x', 'y')), m2, 9999, dim1) as dimX FROM foo", ImmutableList.of( newScanQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) .intervals(querySegmentSpec(Filtration.eternity())) .virtualColumns(expressionVirtualColumn( "v0", - "concat(\"dim1\",concat(\"dim2\",'x'),\"m2\",9999,\"dim1\")", + "concat(\"dim1\",\"dim2\",'xy',\"m2\",9999,\"dim1\")", ColumnType.STRING )) .columns("v0") @@ -11904,12 +11909,12 @@ public void testConcat() .build() ), ImmutableList.of( - new Object[]{"ax1.09999"}, - new Object[]{NullHandling.sqlCompatible() ? null : "10.1x2.0999910.1"}, // dim2 is null - new Object[]{"2x3.099992"}, - new Object[]{"1ax4.099991"}, - new Object[]{"defabcx5.09999def"}, - new Object[]{NullHandling.sqlCompatible() ? null : "abcx6.09999abc"} // dim2 is null + new Object[]{"axy1.09999"}, + new Object[]{NullHandling.sqlCompatible() ? null : "10.1xy2.0999910.1"}, // dim2 is null + new Object[]{"2xy3.099992"}, + new Object[]{"1axy4.099991"}, + new Object[]{"defabcxy5.09999def"}, + new Object[]{NullHandling.sqlCompatible() ? 
null : "abcxy6.09999abc"} // dim2 is null ) ); } @@ -11942,10 +11947,14 @@ public void testConcatGroup() new Object[]{"def-def_def"} ) ); + } - final List secondResults; + @Test + public void testConcatGroup2() + { + final List results; if (useDefault) { - secondResults = ImmutableList.of( + results = ImmutableList.of( new Object[]{"10.1x2.0999910.1"}, new Object[]{"1ax4.099991"}, new Object[]{"2x3.099992"}, @@ -11954,7 +11963,7 @@ public void testConcatGroup() new Object[]{"defabcx5.09999def"} ); } else { - secondResults = ImmutableList.of( + results = ImmutableList.of( new Object[]{null}, new Object[]{"1ax4.099991"}, new Object[]{"2x3.099992"}, @@ -11962,6 +11971,7 @@ public void testConcatGroup() new Object[]{"defabcx5.09999def"} ); } + testQuery( "SELECT CONCAT(dim1, CONCAT(dim2,'x'), m2, 9999, dim1) as dimX FROM foo GROUP BY 1", ImmutableList.of( @@ -11970,7 +11980,7 @@ public void testConcatGroup() .setInterval(querySegmentSpec(Filtration.eternity())) .setVirtualColumns(expressionVirtualColumn( "v0", - "concat(\"dim1\",concat(\"dim2\",'x'),\"m2\",9999,\"dim1\")", + "concat(\"dim1\",\"dim2\",'x',\"m2\",9999,\"dim1\")", ColumnType.STRING )) .setDimensions(dimensions(new DefaultDimensionSpec("v0", "d0"))) @@ -11979,7 +11989,172 @@ public void testConcatGroup() .build() ), - secondResults + results + ); + } + + @Test + public void testConcatDecomposeAlwaysFalseOrUnknown() + { + testQuery( + "SELECT CONCAT(dim1, 'x', dim2) as dimX\n" + + "FROM foo\n" + + "WHERE CONCAT(dim1, 'x', dim2) IN ('1a', '3x4')", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn("v0", "concat(\"dim1\",'x',\"dim2\")", ColumnType.STRING)) + .filters(and( + equality("dim1", "3", ColumnType.STRING), + equality("dim2", "4", ColumnType.STRING) + )) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + 
.build() + ), + ImmutableList.of() + ); + } + + @Test + public void testConcatDecomposeAlwaysFalseOrUnknownNegated() + { + testQuery( + "SELECT CONCAT(dim1, 'x', dim2) as dimX\n" + + "FROM foo\n" + + "WHERE CONCAT(dim1, 'x', dim2) NOT IN ('1a', '3x4', '4x5')\n", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn( + "v0", + "concat(\"dim1\",'x',\"dim2\")", + ColumnType.STRING + )) + .filters( + NullHandling.sqlCompatible() + ? and( + or( + not(equality("dim1", "3", ColumnType.STRING)), + not(equality("dim2", "4", ColumnType.STRING)) + ), + or( + not(equality("dim1", "4", ColumnType.STRING)), + not(equality("dim2", "5", ColumnType.STRING)) + ), + notNull("dim1"), + notNull("dim2") + ) + : and( + or( + not(equality("dim1", "3", ColumnType.STRING)), + not(equality("dim2", "4", ColumnType.STRING)) + ), + or( + not(equality("dim1", "4", ColumnType.STRING)), + not(equality("dim2", "5", ColumnType.STRING)) + ) + ) + ) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"xa"}, + new Object[]{"2x"}, + new Object[]{"1xa"}, + new Object[]{"defxabc"} + ) + : ImmutableList.of( + new Object[]{"xa"}, + new Object[]{"10.1x"}, + new Object[]{"2x"}, + new Object[]{"1xa"}, + new Object[]{"defxabc"}, + new Object[]{"abcx"} + ) + ); + } + + @Test + public void testConcatDecomposeIsNull() + { + testQuery( + "SELECT dim1, dim2, CONCAT(dim1, 'x', dim2) as dimX\n" + + "FROM foo\n" + + "WHERE CONCAT(dim1, 'x', dim2) IS NULL", + ImmutableList.of( + NullHandling.sqlCompatible() + ? 
newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn( + "v0", + "concat(\"dim1\",'x',\"dim2\")", + ColumnType.STRING + )) + .filters(or(isNull("dim1"), isNull("dim2"))) + .columns("dim1", "dim2", "v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + : Druids.newScanQueryBuilder() + .dataSource( + InlineDataSource.fromIterable( + ImmutableList.of(), + RowSignature.builder() + .add("dim1", ColumnType.STRING) + .add("dim2", ColumnType.STRING) + .add("dimX", ColumnType.STRING) + .build() + ) + ) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("dim1", "dim2", "dimX") + .resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .legacy(false) + .build() + + ), + NullHandling.sqlCompatible() + ? ImmutableList.of( + new Object[]{"10.1", null, null}, + new Object[]{"abc", null, null} + ) + : ImmutableList.of() + ); + } + + @Test + public void testConcatDoubleBarsDecompose() + { + testQuery( + "SELECT dim1 || LOWER('x') || dim2 || 'z' as dimX\n" + + "FROM foo\n" + + "WHERE dim1 || LOWER('x') || dim2 || 'z' IN ('1xaz', '3x4z')", + ImmutableList.of( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns(expressionVirtualColumn("v0", "concat(\"dim1\",'x',\"dim2\",'z')", ColumnType.STRING)) + .filters(or( + and(equality("dim1", "1", ColumnType.STRING), equality("dim2", "a", ColumnType.STRING)), + and(equality("dim1", "3", ColumnType.STRING), equality("dim2", "4", ColumnType.STRING)) + )) + .columns("v0") + .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"1xaz"} + ) ); } @@ -13816,8 +13991,10 @@ public void testStringAggExpression() cannotVectorize(); skipVectorize(); testQuery( - // TODO(gianm): '||' used to be 
CONCAT('|', '|'), but for some reason this is no longer being reduced - "SELECT STRING_AGG(DISTINCT CONCAT(dim1, dim2), ','), STRING_AGG(DISTINCT CONCAT(dim1, dim2), '||') FROM foo", + "SELECT\n" + + " STRING_AGG(DISTINCT CONCAT(dim1, dim2), ','),\n" + + " STRING_AGG(DISTINCT CONCAT(dim1, dim2), CONCAT('|', '|'))\n" + + "FROM foo", ImmutableList.of( Druids.newTimeseriesQueryBuilder() .dataSource(CalciteTests.DATASOURCE1) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/FilterDecomposeConcatRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/FilterDecomposeConcatRuleTest.java new file mode 100644 index 000000000000..601d02c1ca8d --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/FilterDecomposeConcatRuleTest.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.druid.sql.calcite.rule;
+
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.sql.calcite.expression.builtin.ConcatOperatorConversion;
+import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
+import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+
+/**
+ * Tests for {@link FilterDecomposeConcatRule.DecomposeConcatShuttle}.
+ *
+ * Each test builds a filter expression around a CONCAT call, runs it through the shuttle, and compares the
+ * result with the expression the test expects. Tests whose names say the match is ambiguous, or whose CONCAT
+ * has no literal delimiter between two inputs, expect the expression to come back unchanged.
+ */
+public class FilterDecomposeConcatRuleTest extends InitializedNullHandlingTest
+{
+  private final RelDataTypeFactory types = DruidTypeSystem.TYPE_FACTORY;
+  private final RexBuilder builder = new RexBuilder(types);
+  private final RexShuttle decomposer = new FilterDecomposeConcatRule.DecomposeConcatShuttle(builder);
+
+  @Test
+  public void test_notConcat()
+  {
+    // LOWER(x) = '2' contains no CONCAT call.
+    final RexNode lowerEqualsTwo = equals(
+        builder.makeCall(SqlStdOperatorTable.LOWER, inputRef(0)),
+        literal("2")
+    );
+
+    assertUnchanged(lowerEqualsTwo);
+  }
+
+  @Test
+  public void test_oneInput()
+  {
+    final RexNode prefixed = concat(literal("it's "), inputRef(0));
+    final RexNode expected = and(equals(inputRef(0), literal("2")));
+
+    assertRewritten(expected, equals(prefixed, literal("it's 2")));
+  }
+
+  @Test
+  public void test_oneInput_lhsLiteral()
+  {
+    final RexNode prefixed = concat(literal("it's "), inputRef(0));
+    final RexNode expected = and(equals(inputRef(0), literal("2")));
+
+    // Same as test_oneInput, with the literal on the left-hand side of the comparison.
+    assertRewritten(expected, equals(literal("it's 2"), prefixed));
+  }
+
+  @Test
+  public void test_oneInput_noLiteral()
+  {
+    final RexNode soleInput = concat(inputRef(0));
+    final RexNode expected = and(equals(inputRef(0), literal("it's 2")));
+
+    assertRewritten(expected, equals(literal("it's 2"), soleInput));
+  }
+
+  @Test
+  public void test_twoInputs()
+  {
+    final RexNode delimited = concatWithDelimiter("x");
+    final RexNode expected = bothInputsEqual("2", "3");
+
+    assertRewritten(expected, equals(delimited, literal("2x3")));
+  }
+
+  @Test
+  public void test_twoInputs_castNumberInputRef()
+  {
+    // CAST(x AS VARCHAR) when x is BIGINT
+    final RexNode castedBigint = builder.makeCast(
+        types.createTypeWithNullability(types.createSqlType(SqlTypeName.VARCHAR), true),
+        builder.makeInputRef(
+            types.createTypeWithNullability(types.createSqlType(SqlTypeName.BIGINT), true),
+            0
+        )
+    );
+
+    final RexNode delimited = concat(castedBigint, literal("x"), inputRef(1));
+    final RexNode expected = and(
+        equals(castedBigint, literal("2")),
+        equals(inputRef(1), literal("3"))
+    );
+
+    assertRewritten(expected, equals(delimited, literal("2x3")));
+  }
+
+  @Test
+  public void test_twoInputs_notEquals()
+  {
+    final RexNode filter = notEquals(concatWithDelimiter("x"), literal("2x3"));
+
+    // The expected form is the negation of the rewrite expected for the equality case.
+    assertRewritten(not(bothInputsEqual("2", "3")), filter);
+  }
+
+  @Test
+  public void test_twoInputs_castNumberLiteral()
+  {
+    // CAST(3 AS VARCHAR) used as a delimiter between the two inputs.
+    final RexNode castedThree = builder.makeCast(
+        types.createSqlType(SqlTypeName.VARCHAR),
+        builder.makeExactLiteral(BigDecimal.valueOf(3L))
+    );
+
+    final RexNode delimited = concat(inputRef(0), castedThree, inputRef(1), literal("4"));
+    final RexNode expected = bothInputsEqual("x", "y");
+
+    assertRewritten(expected, equals(delimited, literal("x3y4")));
+  }
+
+  @Test
+  public void test_twoInputs_noLiteral()
+  {
+    // No delimiter literal sits between the two inputs.
+    assertUnchanged(equals(concat(inputRef(0), inputRef(1)), literal("2x3")));
+  }
+
+  @Test
+  public void test_twoInputs_isNull()
+  {
+    final RexNode filter = isNull(concatWithDelimiter("x"));
+
+    assertRewritten(anyInputIsNull(), filter);
+  }
+
+  @Test
+  public void test_twoInputs_isNotNull()
+  {
+    final RexNode filter = notNull(concatWithDelimiter("x"));
+
+    assertRewritten(not(anyInputIsNull()), filter);
+  }
+
+  @Test
+  public void test_twoInputs_tooManyXes()
+  {
+    final RexNode filter = equals(
+        concatWithDelimiter("x"),
+        literal("2xx3") // ambiguous match
+    );
+
+    assertUnchanged(filter);
+  }
+
+  @Test
+  public void test_twoInputs_notEnoughXes()
+  {
+    final RexNode filter = equals(
+        concatWithDelimiter("x"),
+        literal("2z3") // doesn't match concat pattern
+    );
+
+    assertRewritten(impossibleMatch(), filter);
+  }
+
+  @Test
+  public void test_twoInputs_delimitersWrongOrder()
+  {
+    final RexNode filter = equals(
+        concat(literal("z"), inputRef(0), literal("x"), inputRef(1)),
+        literal("x2z3") // doesn't match concat pattern
+    );
+
+    assertRewritten(impossibleMatch(), filter);
+  }
+
+  @Test
+  public void test_twoInputs_emptyDelimiter()
+  {
+    final RexNode filter = equals(
+        concatWithDelimiter(""),
+        literal("23") // must be recognized as ambiguous
+    );
+
+    assertUnchanged(filter);
+  }
+
+  // NOTE(review): "Deliminters" looks like a misspelling of "Delimiters"; the original method name is kept as is.
+  @Test
+  public void test_twoInputs_ambiguousOverlappingDeliminters()
+  {
+    final RexNode filter = equals(
+        concatWithDelimiter("--"),
+        literal("2---3") // must be recognized as ambiguous
+    );
+
+    assertUnchanged(filter);
+  }
+
+  @Test
+  public void test_twoInputs_impossibleOverlappingDelimiters()
+  {
+    final RexNode filter = equals(
+        concat(inputRef(0), literal("--"), inputRef(1), literal("--")),
+        literal("2---3") // must be recognized as impossible
+    );
+
+    assertRewritten(impossibleMatch(), filter);
+  }
+
+  @Test
+  public void test_twoInputs_backToBackLiterals()
+  {
+    final RexNode delimited = concat(inputRef(0), literal("x"), literal("y"), inputRef(1));
+    final RexNode expected = bothInputsEqual("2", "3");
+
+    assertRewritten(expected, equals(delimited, literal("2xy3")));
+  }
+
+  /**
+   * Asserts that applying the shuttle to {@code input} yields {@code expected}.
+   */
+  private void assertRewritten(RexNode expected, RexNode input)
+  {
+    Assert.assertEquals(expected, decomposer.apply(input));
+  }
+
+  /**
+   * Asserts that applying the shuttle to {@code expression} yields an expression equal to {@code expression}.
+   */
+  private void assertUnchanged(RexNode expression)
+  {
+    Assert.assertEquals(expression, decomposer.apply(expression));
+  }
+
+  /**
+   * Builds CONCAT($0, delimiter, $1) over the two nullable VARCHAR input refs.
+   */
+  private RexNode concatWithDelimiter(String delimiter)
+  {
+    return concat(inputRef(0), literal(delimiter), inputRef(1));
+  }
+
+  /**
+   * Builds the conjunction {@code $0 = first AND $1 = second}.
+   */
+  private RexNode bothInputsEqual(String first, String second)
+  {
+    return and(equals(inputRef(0), literal(first)), equals(inputRef(1), literal(second)));
+  }
+
+  /**
+   * Expected rewrite of a null check on a two-input CONCAT: {@code $0 IS NULL OR $1 IS NULL} in SQL-compatible
+   * null handling mode, and the literal {@code false} otherwise.
+   */
+  private RexNode anyInputIsNull()
+  {
+    if (NullHandling.sqlCompatible()) {
+      return or(isNull(inputRef(0)), isNull(inputRef(1)));
+    }
+    return builder.makeLiteral(false);
+  }
+
+  /**
+   * Expected rewrite when the compared literal does not fit the CONCAT pattern: in SQL-compatible null
+   * handling mode, each input's IS NULL check ANDed with a null BOOLEAN literal and the two ORed together;
+   * otherwise the literal {@code false}.
+   */
+  private RexNode impossibleMatch()
+  {
+    final RexLiteral unknown = builder.makeNullLiteral(types.createSqlType(SqlTypeName.BOOLEAN));
+    if (NullHandling.sqlCompatible()) {
+      return or(
+          and(isNull(inputRef(0)), unknown),
+          and(isNull(inputRef(1)), unknown)
+      );
+    }
+    return builder.makeLiteral(false);
+  }
+
+  private RexNode concat(RexNode... args)
+  {
+    return builder.makeCall(ConcatOperatorConversion.SQL_FUNCTION, args);
+  }
+
+  /**
+   * Reference to input field {@code i}, typed as nullable VARCHAR.
+   */
+  private RexNode inputRef(int i)
+  {
+    return builder.makeInputRef(
+        types.createTypeWithNullability(types.createSqlType(SqlTypeName.VARCHAR), true),
+        i
+    );
+  }
+
+  private RexNode literal(String s)
+  {
+    return builder.makeLiteral(s);
+  }
+
+  private RexNode and(RexNode... args)
+  {
+    return RexUtil.composeConjunction(builder, Arrays.asList(args));
+  }
+
+  private RexNode or(RexNode... args)
+  {
+    return RexUtil.composeDisjunction(builder, Arrays.asList(args));
+  }
+
+  private RexNode not(RexNode arg)
+  {
+    return builder.makeCall(SqlStdOperatorTable.NOT, arg);
+  }
+
+  private RexNode equals(RexNode arg, RexNode value)
+  {
+    return builder.makeCall(SqlStdOperatorTable.EQUALS, arg, value);
+  }
+
+  private RexNode notEquals(RexNode arg, RexNode value)
+  {
+    return builder.makeCall(SqlStdOperatorTable.NOT_EQUALS, arg, value);
+  }
+
+  private RexNode isNull(RexNode arg)
+  {
+    return builder.makeCall(SqlStdOperatorTable.IS_NULL, arg);
+  }
+
+  private RexNode notNull(RexNode arg)
+  {
+    return builder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, arg);
+  }
+}