Skip to content

Commit

Permalink
[FLINK-12248][table-planner-blink] Support e2e over window in blink b…
Browse files Browse the repository at this point in the history
…atch (#8206)
  • Loading branch information
JingsongLi authored and KurtYoung committed Apr 19, 2019
1 parent e34e6e6 commit 8e1643b
Show file tree
Hide file tree
Showing 21 changed files with 3,181 additions and 86 deletions.
Expand Up @@ -63,12 +63,12 @@ public static Expression call(FunctionDefinition functionDefinition, List<Expres
return new CallExpression(functionDefinition, args);
}

public static Expression and(Expression... args) {
return new CallExpression(AND, Arrays.asList(args));
public static Expression and(Expression arg1, Expression arg2) {
return new CallExpression(AND, Arrays.asList(arg1, arg2));
}

public static Expression or(Expression... args) {
return new CallExpression(OR, Arrays.asList(args));
public static Expression or(Expression arg1, Expression arg2) {
return new CallExpression(OR, Arrays.asList(arg1, arg2));
}

public static Expression not(Expression arg) {
Expand Down
Expand Up @@ -67,17 +67,33 @@ public RexNodeConverter(RelBuilder relBuilder) {

@Override
public RexNode visitCall(CallExpression call) {
List<RexNode> child = call.getChildren().stream()
.map(expression -> expression.accept(RexNodeConverter.this))
.collect(Collectors.toList());
switch (call.getFunctionDefinition().getType()) {
case SCALAR_FUNCTION:
return visitScalarFunc(call.getFunctionDefinition(), child);
return visitScalarFunc(call);
default: throw new UnsupportedOperationException();
}
}

private RexNode visitScalarFunc(FunctionDefinition def, List<RexNode> child) {
private List<RexNode> convertCallChildren(CallExpression call) {
return call.getChildren().stream()
.map(expression -> expression.accept(RexNodeConverter.this))
.collect(Collectors.toList());
}

private RexNode visitScalarFunc(CallExpression call) {
FunctionDefinition def = call.getFunctionDefinition();

if (call.getFunctionDefinition().equals(BuiltInFunctionDefinitions.CAST)) {
RexNode child = call.getChildren().get(0).accept(this);
TypeLiteralExpression type = (TypeLiteralExpression) call.getChildren().get(1);
return relBuilder.getRexBuilder().makeAbstractCast(
typeFactory.createTypeFromInternalType(
createInternalTypeFromTypeInfo(type.getType()),
child.getType().isNullable()),
child);
}

List<RexNode> child = convertCallChildren(call);
if (BuiltInFunctionDefinitions.IF.equals(def)) {
return relBuilder.call(FlinkSqlOperatorTable.CASE, child);
} else if (BuiltInFunctionDefinitions.IS_NULL.equals(def)) {
Expand Down Expand Up @@ -117,6 +133,10 @@ && isTemporal(toInternalType(child.get(1).getType()))) {
return relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, child);
} else if (BuiltInFunctionDefinitions.GREATER_THAN.equals(def)) {
return relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, child);
} else if (BuiltInFunctionDefinitions.AND.equals(def)) {
return relBuilder.call(FlinkSqlOperatorTable.AND, child);
} else if (BuiltInFunctionDefinitions.NOT.equals(def)) {
return relBuilder.call(FlinkSqlOperatorTable.NOT, child);
} else {
throw new UnsupportedOperationException(def.getName());
}
Expand Down
Expand Up @@ -70,10 +70,7 @@ public Expression[] accumulateExpressions() {
// sequence = if (lastValues equalTo orderKeys) sequence else sequence + 1
accExpressions[0] = ifThenElse(orderKeyEqualsExpression(), sequence, plus(sequence, literal(1L)));
Expression[] operands = operands();
for (int i = 0; i < operands.length; ++i) {
// lastValue_i = orderKey[i]
accExpressions[i + 1] = operands[i];
}
System.arraycopy(operands, 0, accExpressions, 1, operands.length);
return accExpressions;
}

Expand Down
Expand Up @@ -82,10 +82,7 @@ public Expression[] accumulateExpressions() {
accExpressions[1] = ifThenElse(and(orderKeyEqualsExpression(), not(equalTo(sequence, literal(0L)))),
sequence, currNumber);
Expression[] operands = operands();
for (int i = 0; i < operands.length; ++i) {
// lastValue_i = orderKey[i]
accExpressions[i + 2] = operands[i];
}
System.arraycopy(operands, 0, accExpressions, 2, operands.length);
return accExpressions;
}

Expand Down
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionBuilder;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.type.DecimalType;
import org.apache.flink.table.type.InternalType;
Expand All @@ -30,8 +31,9 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Optional;

import static org.apache.flink.table.expressions.ExpressionBuilder.and;
import static org.apache.flink.table.expressions.ExpressionBuilder.equalTo;
import static org.apache.flink.table.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.expressions.ExpressionBuilder.isNull;
Expand Down Expand Up @@ -92,11 +94,8 @@ protected Expression orderKeyEqualsExpression() {
ifThenElse(isNull(operand(i)), literal(true), literal(false)),
equalTo(lasValue, operand(i)));
}
if (orderKeyEquals.length == 0) {
return literal(true);
} else {
return and(orderKeyEquals);
}
Optional<Expression> ret = Arrays.stream(orderKeyEquals).reduce(ExpressionBuilder::and);
return ret.orElseGet(() -> literal(true));
}

protected Expression generateInitLiteral(InternalType orderType) {
Expand All @@ -116,6 +115,8 @@ protected Expression generateInitLiteral(InternalType orderType) {
return literal(0.0d);
} else if (orderType instanceof DecimalType) {
return literal(java.math.BigDecimal.ZERO);
} else if (orderType.equals(InternalTypes.STRING)) {
return literal("");
} else if (orderType.equals(InternalTypes.DATE)) {
return literal(new Date(0));
} else if (orderType.equals(InternalTypes.TIME)) {
Expand Down
Expand Up @@ -121,7 +121,7 @@ class ImperativeAggCodeGen(
// do not set dataview into the acc in createAccumulator
val accField = if (isAccTypeInternal) {
// do not need convert to internal type
s"$functionTerm.createAccumulator()"
s"($accTypeInternalTerm) $functionTerm.createAccumulator()"
} else {
genToInternal(ctx, externalAccType, s"$functionTerm.createAccumulator()")
}
Expand Down Expand Up @@ -156,7 +156,7 @@ class ImperativeAggCodeGen(

override def resetAccumulator(generator: ExprCodeGenerator): String = {
if (isAccTypeInternal) {
s"$accInternalTerm = $functionTerm.createAccumulator();"
s"$accInternalTerm = ($accTypeInternalTerm) $functionTerm.createAccumulator();"
} else {
s"""
|$accExternalTerm = $functionTerm.createAccumulator();
Expand Down
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.codegen.over

import org.apache.flink.table.`type`.{InternalType, RowType}
import org.apache.flink.table.api.TableConfig
import org.apache.flink.table.codegen.CodeGenUtils.{BASE_ROW, newName}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.codegen.{CodeGenUtils, CodeGeneratorContext, GenerateUtils}
import org.apache.flink.table.generated.{GeneratedRecordComparator, RecordComparator}

/**
* RANGE allow the compound ORDER BY and the random type when the bound is current row.
*/
class MultiFieldRangeBoundComparatorCodeGenerator(
conf: TableConfig,
inType: RowType,
keys: Array[Int],
keyTypes: Array[InternalType],
keyOrders: Array[Boolean],
nullsIsLasts: Array[Boolean],
isLowerBound: Boolean = true) {

def generateBoundComparator(name: String): GeneratedRecordComparator = {
val className = newName(name)
val input = CodeGenUtils.DEFAULT_INPUT1_TERM
val current = CodeGenUtils.DEFAULT_INPUT2_TERM

// In order to avoid the loss of precision in long cast to int.
def generateReturnCode(comp: String): String = {
if (isLowerBound) s"return $comp >= 0 ? 1 : -1;" else s"return $comp > 0 ? 1 : -1;"
}

val ctx = CodeGeneratorContext(conf)

val compareCode = GenerateUtils.generateRowCompare(
ctx, keys, keyTypes, keyOrders, nullsIsLasts, input, current)

val code =
j"""
public class $className implements ${classOf[RecordComparator].getCanonicalName} {

private final Object[] references;
${ctx.reuseMemberCode()}

public $className(Object[] references) {
this.references = references;
${ctx.reuseInitCode()}
${ctx.reuseOpenCode()}
}

@Override
public int compare($BASE_ROW $input, $BASE_ROW $current) {
int ret = compareInternal($input, $current);
${generateReturnCode("ret")}
}

private int compareInternal($BASE_ROW $input, $BASE_ROW $current) {
$compareCode
return 0;
}

}
""".stripMargin
new GeneratedRecordComparator(className, code, ctx.references.toArray)
}
}

0 comments on commit 8e1643b

Please sign in to comment.