Skip to content

Commit

Permalink
[FLINK-9714][table] Support versioned joins with processing time
Browse files Browse the repository at this point in the history
  • Loading branch information
pnowojski committed Sep 21, 2018
1 parent dde089b commit 00add9c
Show file tree
Hide file tree
Showing 10 changed files with 1,068 additions and 15 deletions.
Expand Up @@ -23,7 +23,7 @@ import java.util.Collections
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core._
import org.apache.calcite.rex.{RexBuilder, RexNode}
import org.apache.calcite.rex.{RexBuilder, RexCall, RexNode}
import org.apache.calcite.sql.`type`.{OperandTypes, ReturnTypes}
import org.apache.calcite.sql.{SqlFunction, SqlFunctionCategory, SqlKind}
import org.apache.flink.util.Preconditions.checkArgument
Expand Down Expand Up @@ -101,6 +101,38 @@ object LogicalTemporalTableJoin {
OperandTypes.ANY)),
SqlFunctionCategory.SYSTEM)

def isRowtimeCall(call: RexCall): Boolean = {
checkArgument(call.getOperator == TEMPORAL_JOIN_CONDITION)
call.getOperands.size() == 3
}

def isProctimeCall(call: RexCall): Boolean = {
checkArgument(call.getOperator == TEMPORAL_JOIN_CONDITION)
call.getOperands.size() == 2
}

def makeRowTimeTemporalJoinConditionCall(
rexBuilder: RexBuilder,
leftTimeAttribute: RexNode,
rightTimeAttribute: RexNode,
rightPrimaryKeyExpression: RexNode): RexNode = {
rexBuilder.makeCall(
TEMPORAL_JOIN_CONDITION,
leftTimeAttribute,
rightTimeAttribute,
rightPrimaryKeyExpression)
}

def makeProcTimeTemporalJoinConditionCall(
rexBuilder: RexBuilder,
leftTimeAttribute: RexNode,
rightPrimaryKeyExpression: RexNode): RexNode = {
rexBuilder.makeCall(
TEMPORAL_JOIN_CONDITION,
leftTimeAttribute,
rightPrimaryKeyExpression)
}

/**
* See [[LogicalTemporalTableJoin]]
*/
Expand All @@ -119,8 +151,8 @@ object LogicalTemporalTableJoin {
traitSet,
left,
right,
rexBuilder.makeCall(
TEMPORAL_JOIN_CONDITION,
makeRowTimeTemporalJoinConditionCall(
rexBuilder,
leftTimeAttribute,
rightTimeAttribute,
rightPrimaryKeyExpression))
Expand Down Expand Up @@ -148,8 +180,8 @@ object LogicalTemporalTableJoin {
traitSet,
left,
right,
rexBuilder.makeCall(
TEMPORAL_JOIN_CONDITION,
makeProcTimeTemporalJoinConditionCall(
rexBuilder,
leftTimeAttribute,
rightPrimaryKeyExpression))
}
Expand Down
Expand Up @@ -48,6 +48,15 @@ trait CommonJoin {
}
}

private[flink] def temporalJoinToString(
inputType: RelDataType,
joinCondition: RexNode,
joinType: JoinRelType,
expression: (RexNode, List[String], Option[List[RexNode]]) => String): String = {

"Temporal" + joinToString(inputType, joinCondition, joinType, expression)
}

private[flink] def joinToString(
inputType: RelDataType,
joinCondition: RexNode,
Expand Down
@@ -0,0 +1,237 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.plan.nodes.datastream

import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
import org.apache.calcite.rex._
import org.apache.flink.api.common.functions.FlatJoinFunction
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableException, ValidationException}
import org.apache.flink.table.calcite.FlinkTypeFactory._
import org.apache.flink.table.codegen.GeneratedFunction
import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin
import org.apache.flink.table.plan.logical.rel.LogicalTemporalTableJoin._
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.plan.util.RexDefaultVisitor
import org.apache.flink.table.runtime.join.TemporalJoin
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.types.Row
import org.apache.flink.util.Preconditions.checkState

class DataStreamTemporalJoinToCoProcessTranslator private (
textualRepresentation: String,
config: TableConfig,
returnType: TypeInformation[Row],
leftSchema: RowSchema,
rightSchema: RowSchema,
joinInfo: JoinInfo,
rexBuilder: RexBuilder,
leftTimeAttribute: RexNode,
rightTimeAttribute: Option[RexNode],
rightPrimaryKeyExpression: RexNode,
remainingNonEquiJoinPredicates: RexNode)
extends DataStreamJoinToCoProcessTranslator(
config,
returnType,
leftSchema,
rightSchema,
joinInfo,
rexBuilder) {

override val nonEquiJoinPredicates: Option[RexNode] = Some(remainingNonEquiJoinPredicates)

override protected def createCoProcessFunction(
joinType: JoinRelType,
queryConfig: StreamQueryConfig,
joinFunction: GeneratedFunction[FlatJoinFunction[Row, Row, Row], Row])
: CoProcessFunction[CRow, CRow, CRow] = {

if (rightTimeAttribute.isDefined) {
throw new ValidationException(
s"Currently only proctime temporal joins are supported in [$textualRepresentation]")
}

joinType match {
case JoinRelType.INNER =>
new TemporalJoin(
leftSchema.typeInfo,
rightSchema.typeInfo,
joinFunction.name,
joinFunction.code,
queryConfig)
case _ =>
throw new ValidationException(
s"Only ${JoinRelType.INNER} temporal join is supported in [$textualRepresentation]")
}
}
}

object DataStreamTemporalJoinToCoProcessTranslator {
def create(
textualRepresentation: String,
config: TableConfig,
returnType: TypeInformation[Row],
leftSchema: RowSchema,
rightSchema: RowSchema,
joinInfo: JoinInfo,
rexBuilder: RexBuilder): DataStreamTemporalJoinToCoProcessTranslator = {

checkState(
!joinInfo.isEqui,
"Missing %s in join condition",
TEMPORAL_JOIN_CONDITION)

val nonEquiJoinRex: RexNode = joinInfo.getRemaining(rexBuilder)
val temporalJoinConditionExtractor = new TemporalJoinConditionExtractor(
textualRepresentation,
leftSchema.typeInfo.getTotalFields,
joinInfo,
rexBuilder)

val remainingNonEquiJoinPredicates = temporalJoinConditionExtractor.apply(nonEquiJoinRex)

checkState(
temporalJoinConditionExtractor.leftTimeAttribute.isDefined &&
temporalJoinConditionExtractor.rightPrimaryKeyExpression.isDefined,
"Missing %s in join condition",
TEMPORAL_JOIN_CONDITION)

new DataStreamTemporalJoinToCoProcessTranslator(
textualRepresentation,
config,
returnType,
leftSchema,
rightSchema,
joinInfo,
rexBuilder,
temporalJoinConditionExtractor.leftTimeAttribute.get,
temporalJoinConditionExtractor.rightTimeAttribute,
temporalJoinConditionExtractor.rightPrimaryKeyExpression.get,
remainingNonEquiJoinPredicates)
}

private class TemporalJoinConditionExtractor(
textualRepresentation: String,
rightKeysStartingOffset: Int,
joinInfo: JoinInfo,
rexBuilder: RexBuilder)

extends RexShuttle {

var leftTimeAttribute: Option[RexNode] = None

var rightTimeAttribute: Option[RexNode] = None

var rightPrimaryKeyExpression: Option[RexNode] = None

override def visitCall(call: RexCall): RexNode = {
if (call.getOperator != TEMPORAL_JOIN_CONDITION) {
return super.visitCall(call)
}

checkState(
leftTimeAttribute.isEmpty
&& rightPrimaryKeyExpression.isEmpty
&& rightTimeAttribute.isEmpty,
"Multiple %s functions in [%s]",
TEMPORAL_JOIN_CONDITION,
textualRepresentation)

if (LogicalTemporalTableJoin.isRowtimeCall(call)) {
leftTimeAttribute = Some(call.getOperands.get(0))
rightTimeAttribute = Some(call.getOperands.get(1))

rightPrimaryKeyExpression = Some(validateRightPrimaryKey(call.getOperands.get(2)))

if (!isRowtimeIndicatorType(rightTimeAttribute.get.getType)) {
throw new ValidationException(
s"Non rowtime timeAttribute [${rightTimeAttribute.get.getType}] " +
s"used to create TemporalTableFunction")
}
if (!isRowtimeIndicatorType(leftTimeAttribute.get.getType)) {
throw new ValidationException(
s"Non rowtime timeAttribute [${leftTimeAttribute.get.getType}] " +
s"passed as the argument to TemporalTableFunction")
}

throw new TableException("Event time temporal joins are not yet supported.")
}
else if (LogicalTemporalTableJoin.isProctimeCall(call)) {
leftTimeAttribute = Some(call.getOperands.get(0))
rightPrimaryKeyExpression = Some(validateRightPrimaryKey(call.getOperands.get(1)))

if (!isProctimeIndicatorType(leftTimeAttribute.get.getType)) {
throw new ValidationException(
s"Non processing timeAttribute [${leftTimeAttribute.get.getType}] " +
s"passed as the argument to TemporalTableFunction")
}
}
else {
throw new IllegalStateException(
s"Unsupported invocation $call in [$textualRepresentation]")
}
rexBuilder.makeLiteral(true)
}

private def validateRightPrimaryKey(rightPrimaryKey: RexNode): RexNode = {
if (joinInfo.rightKeys.size() != 1) {
throw new ValidationException(
s"Only single column join key is supported. " +
s"Found ${joinInfo.rightKeys} in [$textualRepresentation]")
}
val rightKey = joinInfo.rightKeys.get(0) + rightKeysStartingOffset

val primaryKeyVisitor = new PrimaryKeyVisitor(textualRepresentation)
rightPrimaryKey.accept(primaryKeyVisitor)

primaryKeyVisitor.inputReference match {
case None =>
throw new IllegalStateException(
s"Failed to find primary key reference in [$textualRepresentation]")
case Some(primaryKeyInputReference) if primaryKeyInputReference != rightKey =>
throw new ValidationException(
s"Join key [$rightKey] must be the same as " +
s"temporal table's primary key [$primaryKeyInputReference] " +
s"in [$textualRepresentation]")
case _ =>
rightPrimaryKey
}
}
}

/**
* Extracts input references from primary key expression.
*/
private class PrimaryKeyVisitor(textualRepresentation: String)
extends RexDefaultVisitor[RexNode] {

var inputReference: Option[Int] = None

override def visitInputRef(inputRef: RexInputRef): RexNode = {
inputReference = Some(inputRef.getIndex)
inputRef
}

override def visitNode(rexNode: RexNode): RexNode = {
throw new ValidationException(
s"Unsupported right primary key expression [$rexNode] in [$textualRepresentation]")
}
}
}
Expand Up @@ -22,12 +22,20 @@ import org.apache.calcite.plan._
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.core.{JoinInfo, JoinRelType}
import org.apache.calcite.rex.RexNode
import org.apache.flink.api.common.functions.FlatJoinFunction
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment}
import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException}
import org.apache.flink.table.codegen.FunctionCodeGenerator
import org.apache.flink.table.plan.schema.RowSchema
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.table.runtime.CRowKeySelector
import org.apache.flink.table.runtime.join._
import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo}
import org.apache.flink.types.Row
import org.apache.flink.util.Preconditions.checkState

import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer

/**
* RelNode for a stream join with [[org.apache.flink.table.functions.TemporalTableFunction]].
*/
Expand Down Expand Up @@ -74,9 +82,14 @@ class DataStreamTemporalTableJoin(
ruleDescription)
}

override def translateToPlan(
tableEnv: StreamTableEnvironment,
queryConfig: StreamQueryConfig): DataStream[CRow] = {
throw new NotImplementedError()
}
}
override protected def createTranslator(
tableEnv: StreamTableEnvironment): DataStreamJoinToCoProcessTranslator = {
DataStreamTemporalJoinToCoProcessTranslator.create(
this.toString,
tableEnv.getConfig,
schema.typeInfo,
leftSchema,
rightSchema,
joinInfo,
cluster.getRexBuilder)
}}

0 comments on commit 00add9c

Please sign in to comment.