From 3559b3bc0491608d8c5ca9d55b57736df29ae369 Mon Sep 17 00:00:00 2001 From: rubenada Date: Thu, 28 Nov 2019 12:12:36 +0100 Subject: [PATCH] [CALCITE-3542] Implement RepeatUnion All=false --- .../enumerable/EnumerableRepeatUnion.java | 22 ++-- .../apache/calcite/util/BuiltInMethod.java | 4 +- .../EnumerableRepeatUnionHierarchyTest.java | 114 +++++++++++----- .../enumerable/EnumerableRepeatUnionTest.java | 66 ++++++++++ .../calcite/linq4j/EnumerableDefaults.java | 122 +++++++++++------- site/_docs/algebra.md | 4 +- 6 files changed, 240 insertions(+), 92 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java index 33b234affff..d86967e05ac 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableRepeatUnion.java @@ -25,6 +25,7 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.RepeatUnion; import org.apache.calcite.util.BuiltInMethod; +import org.apache.calcite.util.Util; import java.util.List; @@ -53,12 +54,8 @@ public class EnumerableRepeatUnion extends RepeatUnion implements EnumerableRel } @Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) { - if (!all) { - throw new UnsupportedOperationException( - "Only EnumerableRepeatUnion ALL is supported"); - } - // return repeatUnionAll(, , iterationLimit); + // return repeatUnion(, , iterationLimit, all, ); BlockBuilder builder = new BlockBuilder(); RelNode seed = getSeedRel(); @@ -70,17 +67,20 @@ public class EnumerableRepeatUnion extends RepeatUnion implements EnumerableRel Expression seedExp = builder.append("seed", seedResult.block); Expression iterativeExp = builder.append("iteration", iterationResult.block); + PhysType physType = PhysTypeImpl.of( + implementor.getTypeFactory(), + getRowType(), + pref.prefer(seedResult.format)); + Expression unionExp = Expressions.call( - BuiltInMethod.REPEAT_UNION_ALL.method, + BuiltInMethod.REPEAT_UNION.method, seedExp, iterativeExp, - Expressions.constant(iterationLimit, int.class)); + Expressions.constant(iterationLimit, int.class), + Expressions.constant(all, boolean.class), + Util.first(physType.comparer(), Expressions.call(BuiltInMethod.IDENTITY_COMPARER.method))); builder.add(unionExp); - PhysType physType = PhysTypeImpl.of( - implementor.getTypeFactory(), - getRowType(), - pref.prefer(seedResult.format)); return implementor.result(physType, builder.toBlock()); } diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 3adecbbb2a5..1e8f22c646a 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -221,8 +221,8 @@ public enum BuiltInMethod { Comparator.class), UNION(ExtendedEnumerable.class, "union", Enumerable.class), CONCAT(ExtendedEnumerable.class, "concat", Enumerable.class), - REPEAT_UNION_ALL(EnumerableDefaults.class, "repeatUnionAll", Enumerable.class, - Enumerable.class, int.class), + REPEAT_UNION(EnumerableDefaults.class, "repeatUnion", Enumerable.class, + Enumerable.class, int.class, boolean.class, EqualityComparer.class), LAZY_COLLECTION_SPOOL(EnumerableDefaults.class, "lazyCollectionSpool", Collection.class, Enumerable.class), INTERSECT(ExtendedEnumerable.class, "intersect", Enumerable.class, boolean.class), diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionHierarchyTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionHierarchyTest.java index 85bd7061248..5107d46d1a3 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionHierarchyTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionHierarchyTest.java @@ -20,6 +20,7 @@ import org.apache.calcite.adapter.java.ReflectiveSchema; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.Schema; import org.apache.calcite.test.CalciteAssert; import org.apache.calcite.test.HierarchySchema; @@ -30,7 +31,9 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.function.Function; /** @@ -55,33 +58,65 @@ public class EnumerableRepeatUnionHierarchyTest { private static final String EMP4 = "empid=4; name=Emp4"; private static final String EMP5 = "empid=5; name=Emp5"; + private static final int[] ID1 = new int[]{1}; + private static final String ID1_STR = Arrays.toString(ID1); + private static final int[] ID2 = new int[]{2}; + private static final String ID2_STR = Arrays.toString(ID2); + private static final int[] ID3 = new int[]{3}; + private static final String ID3_STR = Arrays.toString(ID3); + private static final int[] ID4 = new int[]{4}; + private static final String ID4_STR = Arrays.toString(ID4); + private static final int[] ID5 = new int[]{5}; + private static final String ID5_STR = Arrays.toString(ID5); + private static final int[] ID3_5 = new int[]{3, 5}; + private static final String ID3_5_STR = Arrays.toString(ID3_5); + private static final int[] ID1_3 = new int[]{1, 3}; + private static final String ID1_3_STR = Arrays.toString(ID1_3); + public static Iterable data() { + return Arrays.asList(new Object[][] { - { 1, true, -1, new String[]{EMP1} }, - { 2, true, -2, new String[]{EMP2, EMP1} }, - { 3, true, -1, new String[]{EMP3, EMP2, EMP1} }, - { 4, true, -5, new String[]{EMP4, EMP1} }, - { 5, true, -1, new String[]{EMP5, EMP2, EMP1} }, - { 3, true, 0, new String[]{EMP3} }, - { 3, true, 1, new String[]{EMP3, EMP2} }, - { 3, true, 2, new String[]{EMP3, EMP2, EMP1} }, - { 3, true, 10, new String[]{EMP3, EMP2, EMP1} }, - - { 1, false, -1, new String[]{EMP1, EMP2, EMP4, EMP3, EMP5} }, - { 2, false, -10, new String[]{EMP2, EMP3, EMP5} }, - { 3, false, -100, new String[]{EMP3} }, - { 4, false, -1, new String[]{EMP4} }, - { 1, false, 0, new String[]{EMP1} }, - { 1, false, 1, new String[]{EMP1, EMP2, EMP4} }, - { 1, false, 2, new String[]{EMP1, EMP2, EMP4, EMP3, EMP5} }, - { 1, false, 20, new String[]{EMP1, EMP2, EMP4, EMP3, EMP5} }, + { true, ID1, ID1_STR, true, -1, new String[]{EMP1} }, + { true, ID2, ID2_STR, true, -2, new String[]{EMP2, EMP1} }, + { true, ID3, ID3_STR, true, -1, new String[]{EMP3, EMP2, EMP1} }, + { true, ID4, ID4_STR, true, -5, new String[]{EMP4, EMP1} }, + { true, ID5, ID5_STR, true, -1, new String[]{EMP5, EMP2, EMP1} }, + { true, ID3, ID3_STR, true, 0, new String[]{EMP3} }, + { true, ID3, ID3_STR, true, 1, new String[]{EMP3, EMP2} }, + { true, ID3, ID3_STR, true, 2, new String[]{EMP3, EMP2, EMP1} }, + { true, ID3, ID3_STR, true, 10, new String[]{EMP3, EMP2, EMP1} }, + + { true, ID1, ID1_STR, false, -1, new String[]{EMP1, EMP2, EMP4, EMP3, EMP5} }, + { true, ID2, ID2_STR, false, -10, new String[]{EMP2, EMP3, EMP5} }, + { true, ID3, ID3_STR, false, -100, new String[]{EMP3} }, + { true, ID4, ID4_STR, false, -1, new String[]{EMP4} }, + { true, ID1, ID1_STR, false, 0, new String[]{EMP1} }, + { true, ID1, ID1_STR, false, 1, new String[]{EMP1, EMP2, EMP4} }, + { true, ID1, ID1_STR, false, 2, new String[]{EMP1, EMP2, EMP4, EMP3, EMP5} }, + { true, ID1, ID1_STR, false, 20, new String[]{EMP1, EMP2, EMP4, EMP3, EMP5} }, + + // tests to verify all=true vs all=false + { true, ID3_5, ID3_5_STR, true, -1, new String[]{EMP3, EMP5, EMP2, EMP2, EMP1, EMP1} }, + { false, ID3_5, ID3_5_STR, true, -1, new String[]{EMP3, EMP5, EMP2, EMP1} }, + { true, ID3_5, ID3_5_STR, true, 0, new String[]{EMP3, EMP5} }, + { false, ID3_5, ID3_5_STR, true, 0, new String[]{EMP3, EMP5} }, + { true, ID3_5, ID3_5_STR, true, 1, new String[]{EMP3, EMP5, EMP2, EMP2} }, + { false, ID3_5, ID3_5_STR, true, 1, new String[]{EMP3, EMP5, EMP2} }, + { true, ID1_3, ID1_3_STR, false, -1, new String[]{EMP1, EMP3, EMP2, EMP4, EMP3, EMP5} }, + { false, ID1_3, ID1_3_STR, false, -1, new String[]{EMP1, EMP3, EMP2, EMP4, EMP5} }, }); } - @ParameterizedTest + @ParameterizedTest(name = "{index} : hierarchy(startIds:{2}, ascendant:{3}, " + + "maxDepth:{4}, all:{0})") @MethodSource("data") - public void testHierarchy(int startId, boolean ascendant, - int maxDepth, String[] expected) { + public void testHierarchy( + boolean all, + int[] startIds, + String startIdsStr, + boolean ascendant, + int maxDepth, + String[] expected) { final String fromField; final String toField; if (ascendant) { @@ -96,27 +131,40 @@ public void testHierarchy(int startId, boolean ascendant, CalciteAssert.that() .withSchema("s", schema) .query("?") - .withRel(hierarchy(startId, fromField, toField, maxDepth)) + .withRel(buildHierarchy(all, startIds, fromField, toField, maxDepth)) .returnsOrdered(expected); } - private Function hierarchy(int startId, String fromField, - String toField, int maxDepth) { + private Function buildHierarchy( + boolean all, + int[] startIds, + String fromField, + String toField, + int maxDepth) { // WITH RECURSIVE delta(empid, name) as ( - // SELECT empid, name FROM emps WHERE empid = - // UNION ALL + // SELECT empid, name FROM emps WHERE empid IN () + // UNION [ALL] // SELECT e.empid, e.name FROM delta d // JOIN hierarchies h ON d.empid = h. // JOIN emps e ON h. = e.empid // ) // SELECT empid, name FROM delta - return builder -> builder - .scan("s", "emps") - .filter( + return builder -> { + builder + .scan("s", "emps"); + + final List filters = new ArrayList<>(); + for (int startId : startIds) { + filters.add( builder.equals( builder.field("empid"), - builder.literal(startId))) + builder.literal(startId))); + } + + builder + .filter( + builder.or(filters)) .project( builder.field("emps", "empid"), builder.field("emps", "name")) @@ -137,8 +185,10 @@ private Function hierarchy(int startId, String fromField, .project( builder.field("emps", "empid"), builder.field("emps", "name")) - .repeatUnion("#DELTA#", true, maxDepth) - .build(); + .repeatUnion("#DELTA#", all, maxDepth); + + return builder.build(); + }; } } diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java index 1040f5bbdb9..866c9171a3b 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableRepeatUnionTest.java @@ -59,6 +59,72 @@ public class EnumerableRepeatUnionTest { .returnsOrdered("i=1", "i=2", "i=3", "i=4", "i=5", "i=6", "i=7", "i=8", "i=9", "i=10"); } + @Test public void testGenerateNumbers2() { + CalciteAssert.that() + .query("?") + .withRel( + // WITH RECURSIVE aux(i) AS ( + // VALUES (0) + // UNION -- (ALL would generate an infinite loop!) + // SELECT (i+1)%10 FROM aux WHERE i < 10 + // ) + // SELECT * FROM aux + builder -> builder + .values(new String[] { "i" }, 0) + .transientScan("AUX") + .filter( + builder.call(SqlStdOperatorTable.LESS_THAN, + builder.field(0), + builder.literal(10))) + .project( + builder.call(SqlStdOperatorTable.MOD, + builder.call(SqlStdOperatorTable.PLUS, + builder.field(0), + builder.literal(1)), + builder.literal(10))) + .repeatUnion("AUX", false) + .build()) + .returnsOrdered("i=0", "i=1", "i=2", "i=3", "i=4", "i=5", "i=6", "i=7", "i=8", "i=9"); + } + + @Test public void testGenerateNumbers3() { + CalciteAssert.that() + .query("?") + .withRel( + // WITH RECURSIVE aux(i, j) AS ( + // VALUES (0, 0) + // UNION -- (ALL would generate an infinite loop!) + // SELECT (i+1)%10, j FROM aux WHERE i < 10 + // ) + // SELECT * FROM aux + builder -> builder + .values(new String[] { "i", "j" }, 0, 0) + .transientScan("AUX") + .filter( + builder.call(SqlStdOperatorTable.LESS_THAN, + builder.field(0), + builder.literal(10))) + .project( + builder.call(SqlStdOperatorTable.MOD, + builder.call(SqlStdOperatorTable.PLUS, + builder.field(0), + builder.literal(1)), + builder.literal(10)), + builder.field(1)) + .repeatUnion("AUX", false) + .build()) + .returnsOrdered("i=0; j=0", + "i=1; j=0", + "i=2; j=0", + "i=3; j=0", + "i=4; j=0", + "i=5; j=0", + "i=6; j=0", + "i=7; j=0", + "i=8; j=0", + "i=9; j=0"); + } + @Test public void testFactorial() { CalciteAssert.that() .query("?") diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 4e55feb373d..af1c0352d32 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -3937,20 +3937,25 @@ public void close() { private static final Object DUMMY = new Object(); /** - * Repeat Union All enumerable: it will evaluate the seed enumerable once, and then + * Repeat Union enumerable: it will evaluate the seed enumerable once, and then * it will start to evaluate the iteration enumerable over and over until either it returns * no results, or an optional maximum numbers of iterations is reached * @param seed seed enumerable * @param iteration iteration enumerable * @param iterationLimit maximum numbers of repetitions for the iteration enumerable * (negative value means no limit) + * @param all whether duplicates will be considered or not + * @param comparer {@link EqualityComparer} to control duplicates, + * only used if {@code all} is {@code false} * @param record type */ @SuppressWarnings("unchecked") - public static Enumerable repeatUnionAll( - Enumerable seed, - Enumerable iteration, - int iterationLimit) { + public static Enumerable repeatUnion( + Enumerable seed, + Enumerable iteration, + int iterationLimit, + boolean all, + EqualityComparer comparer) { return new AbstractEnumerable() { @Override public Enumerator enumerator() { return new Enumerator() { @@ -3960,67 +3965,94 @@ public static Enumerable repeatUnionAll( private final Enumerator seedEnumerator = seed.enumerator(); private Enumerator iterativeEnumerator = null; + // Set to control duplicates, only used if "all" is false + private final Set> processed = new HashSet<>(); + private final Function1> wrapper = wrapperFor(comparer); + @Override public TSource current() { - if (this.current == DUMMY) { + if (current == DUMMY) { throw new NoSuchElementException(); } - return this.current; + return current; + } + + private boolean checkValue(TSource value) { + if (all) { + return true; // no need to check duplicates + } + + // check duplicates + final Wrapped wrapped = wrapper.apply(value); + if (!processed.contains(wrapped)) { + processed.add(wrapped); + return true; + } + + return false; } @Override public boolean moveNext() { // if we are not done with the seed moveNext on it - if (!this.seedProcessed) { - if (this.seedEnumerator.moveNext()) { - this.current = this.seedEnumerator.current(); - return true; + while (!seedProcessed) { + if (seedEnumerator.moveNext()) { + TSource value = seedEnumerator.current(); + if (checkValue(value)) { + current = value; + return true; + } } else { - this.seedProcessed = true; + seedProcessed = true; } } // if we are done with the seed, moveNext on the iterative part while (true) { - if (iterationLimit >= 0 && this.currentIteration == iterationLimit) { + if (iterationLimit >= 0 && currentIteration == iterationLimit) { // max number of iterations reached, we are done - this.current = (TSource) DUMMY; + current = (TSource) DUMMY; return false; } - if (this.iterativeEnumerator == null) { - this.iterativeEnumerator = iteration.enumerator(); + if (iterativeEnumerator == null) { + iterativeEnumerator = iteration.enumerator(); } - if (this.iterativeEnumerator.moveNext()) { - this.current = this.iterativeEnumerator.current(); - return true; + while (iterativeEnumerator.moveNext()) { + TSource value = iterativeEnumerator.current(); + if (checkValue(value)) { + current = value; + return true; + } } - if (this.current == DUMMY) { + if (current == DUMMY) { // current iteration did not return any value, we are done return false; } // current iteration level (which returned some values) is finished, go to next one - this.current = (TSource) DUMMY; - this.iterativeEnumerator.close(); - this.iterativeEnumerator = null; - this.currentIteration++; + current = (TSource) DUMMY; + iterativeEnumerator.close(); + iterativeEnumerator = null; + currentIteration++; } } @Override public void reset() { - this.seedEnumerator.reset(); - if (this.iterativeEnumerator != null) { - this.iterativeEnumerator.close(); - this.iterativeEnumerator = null; + seedEnumerator.reset(); + seedProcessed = false; + processed.clear(); + if (iterativeEnumerator != null) { + iterativeEnumerator.close(); + iterativeEnumerator = null; } - this.currentIteration = 0; + currentIteration = 0; } @Override public void close() { - this.seedEnumerator.close(); - if (this.iterativeEnumerator != null) { - this.iterativeEnumerator.close(); + seedEnumerator.close(); + if (iterativeEnumerator != null) { + iterativeEnumerator.close(); } } }; @@ -4045,36 +4077,36 @@ public static Enumerable lazyCollectionSpool( private final Collection tempCollection = new ArrayList<>(); @Override public TSource current() { - if (this.current == DUMMY) { + if (current == DUMMY) { throw new NoSuchElementException(); } - return this.current; + return current; } @Override public boolean moveNext() { - if (this.inputEnumerator.moveNext()) { - this.current = this.inputEnumerator.current(); - this.tempCollection.add(this.current); + if (inputEnumerator.moveNext()) { + current = inputEnumerator.current(); + tempCollection.add(current); return true; } - this.flush(); + flush(); return false; } private void flush() { - this.collection.clear(); - this.collection.addAll(this.tempCollection); - this.tempCollection.clear(); + collection.clear(); + collection.addAll(tempCollection); + tempCollection.clear(); } @Override public void reset() { - this.inputEnumerator.reset(); - this.collection.clear(); - this.tempCollection.clear(); + inputEnumerator.reset(); + collection.clear(); + tempCollection.clear(); } @Override public void close() { - this.inputEnumerator.close(); + inputEnumerator.close(); } }; } diff --git a/site/_docs/algebra.md b/site/_docs/algebra.md index d589fb0b5d5..13ec1eac369 100644 --- a/site/_docs/algebra.md +++ b/site/_docs/algebra.md @@ -307,7 +307,7 @@ return the `RelBuilder`. |:------------------- |:----------- | `scan(tableName)` | Creates a [TableScan]({{ site.apiRoot }}/org/apache/calcite/rel/core/TableScan.html). | `functionScan(operator, n, expr...)`
`functionScan(operator, n, exprList)` | Creates a [TableFunctionScan]({{ site.apiRoot }}/org/apache/calcite/rel/core/TableFunctionScan.html) of the `n` most recent relational expressions. -| `transientScan(tableName [, rowType])` | Creates a [TableScan]({{ site.apiRoot }}/org/apache/calcite/rel/core/TableScan.html) on a [TransientTable]]({{ site.apiRoot }}/org/apache/calcite/schema/TransientTable.html) with the given type (if not specified, the most recent relational expression's type will be used). +| `transientScan(tableName [, rowType])` | Creates a [TableScan]({{ site.apiRoot }}/org/apache/calcite/rel/core/TableScan.html) on a [TransientTable]({{ site.apiRoot }}/org/apache/calcite/schema/TransientTable.html) with the given type (if not specified, the most recent relational expression's type will be used). | `values(fieldNames, value...)`
`values(rowType, tupleList)` | Creates a [Values]({{ site.apiRoot }}/org/apache/calcite/rel/core/Values.html). | `filter([variablesSet, ] exprList)`
`filter([variablesSet, ] expr...)` | Creates a [Filter]({{ site.apiRoot }}/org/apache/calcite/rel/core/Filter.html) over the AND of the given predicates; if `variablesSet` is specified, the predicates may reference those variables. | `project(expr...)`
`project(exprList [, fieldNames])` | Creates a [Project]({{ site.apiRoot }}/org/apache/calcite/rel/core/Project.html). To override the default name, wrap expressions using `alias`, or specify the `fieldNames` argument. @@ -328,7 +328,7 @@ return the `RelBuilder`. | `union(all [, n])` | Creates a [Union]({{ site.apiRoot }}/org/apache/calcite/rel/core/Union.html) of the `n` (default two) most recent relational expressions. | `intersect(all [, n])` | Creates an [Intersect]({{ site.apiRoot }}/org/apache/calcite/rel/core/Intersect.html) of the `n` (default two) most recent relational expressions. | `minus(all)` | Creates a [Minus]({{ site.apiRoot }}/org/apache/calcite/rel/core/Minus.html) of the two most recent relational expressions. -| `repeatUnion(tableName, all [, n])` | Creates a [RepeatUnion]({{ site.apiRoot }}/org/apache/calcite/rel/core/RepeatUnion.html) associated to a [TransientTable]]({{ site.apiRoot }}/org/apache/calcite/schema/TransientTable.html) of the two most recent relational expressions, with `n` maximum number of iterations (default -1, i.e. no limit). +| `repeatUnion(tableName, all [, n])` | Creates a [RepeatUnion]({{ site.apiRoot }}/org/apache/calcite/rel/core/RepeatUnion.html) associated to a [TransientTable]({{ site.apiRoot }}/org/apache/calcite/schema/TransientTable.html) of the two most recent relational expressions, with `n` maximum number of iterations (default -1, i.e. no limit). | `snapshot(period)` | Creates a [Snapshot]({{ site.apiRoot }}/org/apache/calcite/rel/core/Snapshot.html) of the given snapshot period. | `match(pattern, strictStart,` `strictEnd, patterns, measures,` `after, subsets, allRows,` `partitionKeys, orderKeys,` `interval)` | Creates a [Match]({{ site.apiRoot }}/org/apache/calcite/rel/core/Match.html).