[FLINK-28919][table] Add built-in generate_series function.
liuyongvs committed Sep 5, 2022
1 parent df0bc11 commit a28f2f9
Showing 5 changed files with 291 additions and 12 deletions.
@@ -185,6 +185,30 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
.internal()
.build();

public static final BuiltInFunctionDefinition GENERATE_SERIES =
BuiltInFunctionDefinition.newBuilder()
.name("GENERATE_SERIES")
.kind(TABLE)
.inputTypeStrategy(
or(
sequence(
new String[] {"start", "stop"},
new ArgumentTypeStrategy[] {
logical(LogicalTypeFamily.NUMERIC),
logical(LogicalTypeFamily.NUMERIC)
}),
sequence(
new String[] {"start", "stop", "step"},
new ArgumentTypeStrategy[] {
logical(LogicalTypeFamily.NUMERIC),
logical(LogicalTypeFamily.NUMERIC),
logical(LogicalTypeFamily.NUMERIC)
})))
.outputTypeStrategy(COMMON)
.runtimeClass(
"org.apache.flink.table.runtime.functions.table.GenerateSeriesFunction")
.build();

// --------------------------------------------------------------------------------------------
// Logic functions
// --------------------------------------------------------------------------------------------
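For orientation, here is a minimal sketch of how the new definition can be exercised from SQL once this commit is in place, mirroring the LATERAL TABLE usage in the tests below. The class name, table-environment setup, and literal values are illustrative only and not part of the change.

```java
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class GenerateSeriesUsageSketch {
    public static void main(String[] args) {
        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inBatchMode());

        // Two-argument form: the step defaults to 1, so this emits 0, 1, 2, 3.
        tEnv.executeSql(
                        "SELECT v FROM (VALUES (1)) AS t(x), "
                                + "LATERAL TABLE(GENERATE_SERIES(0, 3)) AS T(v)")
                .print();

        // Three-argument form: a negative step yields a descending series 4.0, 3.5, ..., 2.0.
        // The step is cast to DOUBLE because the runtime converter below handles integer and
        // floating-point result types.
        tEnv.executeSql(
                        "SELECT v FROM (VALUES (1)) AS t(x), "
                                + "LATERAL TABLE(GENERATE_SERIES(4, 2, CAST(-0.5 AS DOUBLE))) AS T(v)")
                .print();
    }
}
```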
@@ -17,6 +17,8 @@
*/
package org.apache.flink.table.planner.runtime.batch.sql

import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.typeutils.Types
import org.apache.flink.table.planner.runtime.utils.BatchTestBase
@@ -275,4 +277,87 @@ class CorrelateITCase2 extends BatchTestBase {
)
)
}

@Test
def testGenerateSeries(): Unit = {
registerCollection("t1", nullData4, type4, "a,b,c")

checkResult("SELECT a, b, c, v FROM t1, LATERAL TABLE(GENERATE_SERIES(1, 0)) AS T(v)", Seq())

checkResult(
"SELECT a, b, c, v FROM t1, LATERAL TABLE(GENERATE_SERIES(0, 0)) AS T(v)",
Seq(
row("book", 1, 12, 0),
row("book", 2, null, 0),
row("book", 4, 11, 0),
row("fruit", 3, 44, 0),
row("fruit", 4, null, 0),
row("fruit", 5, null, 0))
)

checkResult(
"SELECT a, b, c, v FROM t1, LATERAL TABLE(GENERATE_SERIES(0, 1)) AS T(v)",
Seq(
row("book", 1, 12, 0),
row("book", 1, 12, 1),
row("book", 2, null, 0),
row("book", 2, null, 1),
row("book", 4, 11, 0),
row("book", 4, 11, 1),
row("fruit", 3, 44, 0),
row("fruit", 3, 44, 1),
row("fruit", 4, null, 0),
row("fruit", 4, null, 1),
row("fruit", 5, null, 0),
row("fruit", 5, null, 1)
)
)

checkResult(
"SELECT a, b, c, v FROM t1 t1 join " +
"LATERAL TABLE(GENERATE_SERIES(1614325532, 1614325539)) AS T(v) ON TRUE " +
"where c is not null and substring(cast(v as varchar), 10, 1) = cast(b as varchar)",
Seq(row("book", 4, 11, 1614325534L), row("fruit", 3, 44, 1614325533L))
)
}

@Test
def testGenerateSeriesWithDifferentArgsType(): Unit = {
registerCollection("t1", Seq(row("book")), new RowTypeInfo(STRING_TYPE_INFO), "a")

checkResult(
"SELECT a, v FROM t1, LATERAL TABLE(" +
"GENERATE_SERIES(cast(2 AS SMALLINT), cast(4 AS BIGINT), cast(0.5 AS FLOAT))) AS T(v)",
Seq(
row("book", 2.0),
row("book", 2.5),
row("book", 3.0),
row("book", 3.5),
row("book", 4.0)
)
)

checkResult(
"SELECT a, v FROM t1, LATERAL TABLE(" +
"GENERATE_SERIES(cast(4 AS SMALLINT), cast(2 AS BIGINT), cast(-0.5 AS FLOAT))) AS T(v)",
Seq(
row("book", 4.0),
row("book", 3.5),
row("book", 2.5),
row("book", 3.0),
row("book", 2.0)
)
)
}

// The step size is zero; the original IllegalArgumentException is wrapped in a RuntimeException.
@Test(expected = classOf[RuntimeException])
def testGenerateSeriesWithZeroStep(): Unit = {
registerCollection("t1", Seq(row("book")), new RowTypeInfo(STRING_TYPE_INFO), "a")

checkResult(
"SELECT a, v FROM t1, LATERAL TABLE(GENERATE_SERIES(2, 4, 0)) AS T(v)",
Seq()
)
}
}
@@ -543,18 +543,6 @@ class MiscITCase extends BatchTestBase {
Seq(row("abcd", "f%g", "abcd"), row("e fg", null, "e"), row("e fg", null, "fg"))
)

checkResult(
"SELECT f, g, v FROM testTable," +
"LATERAL TABLE(GENERATE_SERIES(0, CAST(b AS INTEGER))) AS T(v)",
Seq(
row("abcd", "f%g", 0),
row(null, "hij_k", 0),
row(null, "hij_k", 1),
row("e fg", null, 0),
row("e fg", null, 1),
row("e fg", null, 2))
)

checkResult(
"SELECT f, g, v FROM testTable," +
"LATERAL TABLE(JSON_TUPLE('{\"a1\": \"b1\", \"a2\": \"b2\", \"e fg\": \"b3\"}'," +
@@ -36,6 +36,14 @@ import scala.collection.mutable

class CorrelateITCase extends StreamingTestBase {

val nullableData: List[(String, Int, Integer)] = List(
("book", 1, 12),
("book", 2, null),
("book", 4, 11),
("fruit", 4, null),
("fruit", 3, 44),
("fruit", 5, null))

@Before
override def before(): Unit = {
super.before()
@@ -389,6 +397,61 @@
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}

@Test
def testGenerateSeries(): Unit = {
val t1 = env.fromCollection(nullableData).toTable(tEnv, 'a, 'b, 'c)
tEnv.registerTable("t1", t1)

val sql = "SELECT a, b, c, v FROM t1, LATERAL TABLE(GENERATE_SERIES(0, 0)) AS T(v)"

val result = tEnv.sqlQuery(sql)
val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink)
tEnv.asInstanceOf[TableEnvironmentInternal].registerTableSinkInternal("MySink", sink)
result.executeInsert("MySink").await()

val expected = List(
"book,1,12,0",
"book,2,null,0",
"book,4,11,0",
"fruit,3,44,0",
"fruit,4,null,0",
"fruit,5,null,0")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}

@Test
def testGenerateSeriesWithFilter(): Unit = {
val t1 = env.fromCollection(nullableData).toTable(tEnv, 'a, 'b, 'c)
tEnv.registerTable("t1", t1)

val sql = "SELECT a, b, c, v FROM t1 t1 join " +
"LATERAL TABLE(GENERATE_SERIES(1614325532, 1614325539)) AS T(v) ON TRUE where c is not null" +
" and substring(cast(v as varchar), 10, 1) = cast(b as varchar)"

val result = tEnv.sqlQuery(sql)
val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink)
tEnv.asInstanceOf[TableEnvironmentInternal].registerTableSinkInternal("MySink", sink)
result.executeInsert("MySink").await()

val expected = List("book,4,11,1614325534", "fruit,3,44,1614325533")
assertEquals(expected.sorted, sink.getAppendResults.sorted)
}

@Test
def testGenerateSeriesWithEmptyOutput(): Unit = {
val t1 = env.fromCollection(nullableData).toTable(tEnv, 'a, 'b, 'c)
tEnv.registerTable("t1", t1)

val sql = "SELECT a, b, c, v FROM t1, LATERAL TABLE(GENERATE_SERIES(1, 0)) AS T(v)"

val result = tEnv.sqlQuery(sql)
val sink = TestSinkUtil.configureSink(result, new TestingAppendTableSink)
tEnv.asInstanceOf[TableEnvironmentInternal].registerTableSinkInternal("MySink", sink)
result.executeInsert("MySink").await()

assertEquals(List.empty, sink.getAppendResults.sorted)
}

// TODO support agg
// @Test
// def testCountStarOnCorrelate(): Unit = {
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.functions.table;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.SpecializedFunction.SpecializedContext;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;

import java.util.List;
import java.util.stream.Collectors;

/**
 * Implementation of the table functions {@code GENERATE_SERIES(start, stop)} and {@code
 * GENERATE_SERIES(start, stop, step)}, which generate a series of values from start to stop with
 * the given step size (1 by default).
 */
@Internal
public class GenerateSeriesFunction extends BuiltInTableFunction<Object> {

private static final long serialVersionUID = 1L;

private final transient DataType outputDataType;
private final LogicalType outputLogicalType;

public GenerateSeriesFunction(SpecializedContext specializedContext) {
super(BuiltInFunctionDefinitions.GENERATE_SERIES, specializedContext);

// The output type in the context is already wrapped, however, the result of the
// function is not. Therefore, we need a custom output type.
final List<LogicalType> actualTypes =
specializedContext.getCallContext().getArgumentDataTypes().stream()
.map(DataType::getLogicalType)
.collect(Collectors.toList());
this.outputLogicalType = LogicalTypeMerging.findCommonType(actualTypes).get();
this.outputDataType = DataTypes.of(outputLogicalType).toInternal();
}

@Override
public DataType getOutputDataType() {
return outputDataType;
}

public void eval(Number start, Number stop) {
eval(start, stop, 1);
}

public void eval(Number start, Number stop, Number step) {
if (isZero(step)) {
throw new IllegalArgumentException("step size cannot equal zero");
}
double s = start.doubleValue();
if (step.doubleValue() > 0) {
while (s <= stop.doubleValue()) {
collect(converter(s));
s += step.doubleValue();
}
} else {
while (s >= stop.doubleValue()) {
collect(converter(s));
s += step.doubleValue();
}
}
}

private Object converter(double s) {
switch (outputLogicalType.getTypeRoot()) {
case TINYINT:
return Byte.valueOf((byte) s);
case SMALLINT:
return Short.valueOf((short) s);
case INTEGER:
return Integer.valueOf((int) s);
case BIGINT:
return Long.valueOf((long) s);
case FLOAT:
return Float.valueOf((float) s);
case DOUBLE:
return Double.valueOf(s);
default:
throw new UnsupportedOperationException(
"Unsupported type: " + outputLogicalType.getTypeRoot());
}
}

private boolean isZero(Number number) {
if (number instanceof Byte
|| number instanceof Short
|| number instanceof Integer
|| number instanceof Long) {
// Compare via longValue(); casting e.g. an Integer directly to Long would throw a ClassCastException.
return number.longValue() == 0L;
} else if (number instanceof Float) {
return ((Float) number).compareTo(0.0f) == 0;
} else if (number instanceof Double) {
return ((Double) number).compareTo(0.0d) == 0;
} else {
throw new UnsupportedOperationException("Unsupported step type: " + number.getClass());
}
}
}
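The constructor above derives the emitted type by merging the argument types. Below is a small standalone sketch of that derivation, with types chosen to mirror the mixed-argument batch test; the class name is hypothetical.

```java
import java.util.Arrays;
import java.util.Optional;

import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.FloatType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;

public class CommonTypeSketch {
    public static void main(String[] args) {
        // GENERATE_SERIES(CAST(2 AS SMALLINT), CAST(4 AS BIGINT), CAST(0.5 AS FLOAT)) emits
        // values of the common type of its three arguments, computed as follows.
        Optional<LogicalType> common =
                LogicalTypeMerging.findCommonType(
                        Arrays.asList(new SmallIntType(), new BigIntType(), new FloatType()));
        // Expected to be a floating-point type, which is why the test above sees 2.0, 2.5, ...
        common.ifPresent(System.out::println);
    }
}
```

Note that the converter in GenerateSeriesFunction only handles the integer and floating-point type roots, so arguments whose merged type is, for example, DECIMAL would reach the UnsupportedOperationException branch.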
