Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.flink.table.plan.nodes.dataset

import org.apache.calcite.plan._
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
import org.apache.calcite.rex.RexNode
Expand Down Expand Up @@ -47,6 +48,7 @@ class DataSetSingleRowJoin(
rowRelDataType: RelDataType,
joinCondition: RexNode,
joinRowType: RelDataType,
joinType: JoinRelType,
ruleDescription: String)
extends BiRel(cluster, traitSet, leftNode, rightNode)
with DataSetRel {
Expand All @@ -63,6 +65,7 @@ class DataSetSingleRowJoin(
getRowType,
joinCondition,
joinRowType,
joinType,
ruleDescription)
}

Expand Down Expand Up @@ -97,7 +100,6 @@ class DataSetSingleRowJoin(
tableEnv.getConfig,
leftDataSet.getType,
rightDataSet.getType,
leftIsSingle,
joinCondition,
broadcastSetName)

Expand All @@ -118,14 +120,18 @@ class DataSetSingleRowJoin(
config: TableConfig,
inputType1: TypeInformation[Row],
inputType2: TypeInformation[Row],
firstIsSingle: Boolean,
joinCondition: RexNode,
broadcastInputSetName: String)
: FlatMapFunction[Row, Row] = {

val isOuterJoin = joinType match {
case JoinRelType.LEFT | JoinRelType.RIGHT => true
case _ => false
}

val codeGenerator = new CodeGenerator(
config,
false,
isOuterJoin,
inputType1,
Some(inputType2))

Expand All @@ -138,30 +144,57 @@ class DataSetSingleRowJoin(
val condition = codeGenerator.generateExpression(joinCondition)

val joinMethodBody =
s"""
|${condition.code}
|if (${condition.resultTerm}) {
| ${conversion.code}
| ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm});
|}
|""".stripMargin
if (joinType == JoinRelType.INNER) {
s"""
|${condition.code}
|if (${condition.resultTerm}) {
| ${conversion.code}
| ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm});
|}
|""".stripMargin
} else {
val singleNode =
if (!leftIsSingle) {
rightNode
}
else {
leftNode
}

val notSuitedToCondition = singleNode
.getRowType
.getFieldList
.map(field => getRowType.getFieldNames.indexOf(field.getName))
.map(i => s"${conversion.resultTerm}.setField($i,null);")

s"""
|${condition.code}
|${conversion.code}
|if(!${condition.resultTerm}){
|${notSuitedToCondition.mkString("\n")}
|}
|${codeGenerator.collectorTerm}.collect(${conversion.resultTerm});
|""".stripMargin
}

val genFunction = codeGenerator.generateFunction(
ruleDescription,
classOf[FlatJoinFunction[Row, Row, Row]],
joinMethodBody,
returnType)

if (firstIsSingle) {
new MapJoinRightRunner[Row, Row, Row](
if (!leftIsSingle) {
new MapJoinLeftRunner[Row, Row, Row](
genFunction.name,
genFunction.code,
isOuterJoin,
genFunction.returnType,
broadcastInputSetName)
} else {
new MapJoinLeftRunner[Row, Row, Row](
new MapJoinRightRunner[Row, Row, Row](
genFunction.name,
genFunction.code,
isOuterJoin,
genFunction.returnType,
broadcastInputSetName)
}
Expand All @@ -181,7 +214,7 @@ class DataSetSingleRowJoin(
}

/**
  * Builds the operator name for this node from its join type, in the form
  * "NestedLoop" + capitalized join type + "Join" (e.g. "NestedLoopInnerJoin").
  */
private def joinTypeToString: String = {
  // Lower-case with a fixed locale so the result does not depend on the JVM default
  // locale (under a Turkish default locale "INNER" would otherwise become "ınner").
  val typeName = joinType.toString.toLowerCase(java.util.Locale.ROOT).capitalize
  s"NestedLoop${typeName}Join"
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ private class FlinkLogicalJoinConverter
val join: LogicalJoin = call.rel(0).asInstanceOf[LogicalJoin]
val joinInfo = join.analyzeCondition

hasEqualityPredicates(join, joinInfo) || isSingleRowInnerJoin(join)
(hasEqualityPredicates(join, joinInfo)
|| isSingleRowInnerJoin(join)
|| isOuterJoinWithSingleRowAtOuterSide(join, joinInfo))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the same condition here as in DataSetSingleRowJoinRule.matches():

join.getJoinType match {
  case JoinRelType.INNER if isSingleRow(join.getRight) || isSingleRow(join.getLeft) => true
  case JoinRelType.LEFT if isSingleRow(join.getRight) => true
  case JoinRelType.RIGHT if isSingleRow(join.getLeft) => true
  case _ => false
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fhueske Thanks for your guidance!

}

override def convert(rel: RelNode): RelNode = {
Expand All @@ -101,6 +103,18 @@ private class FlinkLogicalJoinConverter
!joinInfo.pairs().isEmpty && (joinInfo.isEqui || join.getJoinType == JoinRelType.INNER)
}



/**
  * Checks whether the join is a LEFT join whose right input is a single row, or a
  * RIGHT join whose left input is a single row.
  *
  * These are the same outer-join cases that DataSetSingleRowJoinRule.matches() accepts.
  *
  * @param join     the logical join to inspect
  * @param joinInfo the analyzed join condition; kept so existing callers compile, unused
  * @return true if the join is an outer join with a single-row input on the required side
  */
private def isOuterJoinWithSingleRowAtOuterSide(
    join: LogicalJoin,
    joinInfo: JoinInfo): Boolean = {

  // The number of equi-join keys says nothing about how many rows an input produces,
  // so the inputs themselves are tested instead of joinInfo's key lists.
  join.getJoinType match {
    case JoinRelType.LEFT => isSingleRow(join.getRight)
    case JoinRelType.RIGHT => isSingleRow(join.getLeft)
    case _ => false
  }
}

private def isSingleRowInnerJoin(join: LogicalJoin): Boolean = {
if (join.getJoinType == JoinRelType.INNER) {
isSingleRow(join.getRight) || isSingleRow(join.getLeft)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ class DataSetSingleRowJoinRule
/**
  * Decides whether this rule applies to the matched join.
  *
  * The rule matches:
  *  - INNER joins where either input is a single row,
  *  - LEFT joins whose right input is a single row,
  *  - RIGHT joins whose left input is a single row.
  *
  * @param call the rule call whose first operand is the FlinkLogicalJoin to test
  * @return true if the join falls into one of the cases above
  */
override def matches(call: RelOptRuleCall): Boolean = {
  val join = call.rel(0).asInstanceOf[FlinkLogicalJoin]

  // A single match on the join type replaces the earlier inner-join-only if/else.
  join.getJoinType match {
    case JoinRelType.INNER => isSingleRow(join.getLeft) || isSingleRow(join.getRight)
    case JoinRelType.LEFT => isSingleRow(join.getRight)
    case JoinRelType.RIGHT => isSingleRow(join.getLeft)
    case _ => false
  }
}

Expand Down Expand Up @@ -79,6 +80,7 @@ class DataSetSingleRowJoinRule
rel.getRowType,
join.getCondition,
join.getRowType,
join.getJoinType,
description)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,24 @@
package org.apache.flink.table.runtime

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.types.Row
import org.apache.flink.util.Collector

/**
  * Map-side join runner that receives the left join input (IN1) record by record and
  * joins each record with the value held in `broadcastSet`, which is passed to the join
  * function as the right side (IN2).
  *
  * @param name             passed through to MapSideJoinRunner
  * @param code             passed through to MapSideJoinRunner
  * @param outerJoin        if true, the join function is still invoked with a null right
  *                         side when `broadcastSet` is None
  * @param returnType       passed through to MapSideJoinRunner
  * @param broadcastSetName passed through to MapSideJoinRunner
  */
class MapJoinLeftRunner[IN1, IN2, OUT](
    name: String,
    code: String,
    outerJoin: Boolean,
    returnType: TypeInformation[OUT],
    broadcastSetName: String)
  extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](name, code, returnType, broadcastSetName) {

  override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = {
    broadcastSet match {
      case None =>
        // No single-row value: only an outer join still emits, with a null right side.
        if (outerJoin) {
          function.join(multiInput, null.asInstanceOf[IN2], out)
        }
      case Some(singleRow) =>
        function.join(multiInput, singleRow, out)
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,24 @@
package org.apache.flink.table.runtime

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.types.Row
import org.apache.flink.util.Collector

/**
  * Map-side join runner that receives the right join input (IN2) record by record and
  * joins each record with the value held in `broadcastSet`, which is passed to the join
  * function as the left side (IN1).
  *
  * @param name             passed through to MapSideJoinRunner
  * @param code             passed through to MapSideJoinRunner
  * @param outerJoin        if true, the join function is still invoked with a null left
  *                         side when `broadcastSet` is None
  * @param returnType       passed through to MapSideJoinRunner
  * @param broadcastSetName passed through to MapSideJoinRunner
  */
class MapJoinRightRunner[IN1, IN2, OUT](
    name: String,
    code: String,
    outerJoin: Boolean,
    returnType: TypeInformation[OUT],
    broadcastSetName: String)
  extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](name, code, returnType, broadcastSetName) {

  override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit = {
    broadcastSet match {
      case None =>
        // No single-row value: only an outer join still emits, with a null left side.
        if (outerJoin) {
          function.join(null.asInstanceOf[IN1], multiInput, out)
        }
      case Some(singleRow) =>
        function.join(singleRow, multiInput, out)
    }
  }

}
Loading