Skip to content

Commit

Permalink
[FLINK-2576] Add outer join operator to Scala DataSet API
Browse files Browse the repository at this point in the history
  • Loading branch information
jkovacs committed Sep 16, 2015
1 parent 1ccca5b commit b66b1b0
Show file tree
Hide file tree
Showing 3 changed files with 347 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.flink.api.java.Utils.CountHelper
import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.java.functions.{FirstReducer, KeySelector}
import org.apache.flink.api.java.io.{DiscardingOutputFormat, PrintingOutputFormat, TextOutputFormat}
import org.apache.flink.api.java.operators.JoinOperator.JoinType
import org.apache.flink.api.java.operators.Keys.ExpressionKeys
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.{DataSet => JavaDataSet, Utils}
Expand Down Expand Up @@ -840,11 +841,11 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {

/**
* Creates a new DataSet by joining `this` DataSet with the `other` DataSet. To specify the join
* keys the `where` and `isEqualTo` methods must be used. For example:
* keys the `where` and `equalTo` methods must be used. For example:
* {{{
* val left: DataSet[(String, Int, Int)] = ...
* val right: DataSet[(Int, String, Int)] = ...
* val joined = left.join(right).where(0).isEqualTo(1)
* val joined = left.join(right).where(0).equalTo(1)
* }}}
*
* The default join result is a DataSet with 2-Tuples of the joined values. In the above example
Expand All @@ -854,7 +855,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* {{{
* val left: DataSet[(String, Int, Int)] = ...
* val right: DataSet[(Int, String, Int)] = ...
* val joined = left.join(right).where(0).isEqualTo(1) { (l, r) =>
* val joined = left.join(right).where(0).equalTo(1) { (l, r) =>
* (l._1, r._2)
* }
* }}}
Expand All @@ -864,7 +865,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* {{{
* val left: DataSet[(String, Int, Int)] = ...
* val right: DataSet[(Int, String, Int)] = ...
* val joined = left.join(right).where(0).isEqualTo(1) {
* val joined = left.join(right).where(0).equalTo(1) {
* (l, r, out: Collector[(String, Int)]) =>
* if (l._2 > 4) {
* out.collect((l._1, r._3))
Expand Down Expand Up @@ -899,6 +900,96 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
def joinWithHuge[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, JoinHint.BROADCAST_HASH_FIRST)

/**
* Creates a new DataSet by performing a full outer join of `this` DataSet
* with the `other` DataSet, by combining two elements of two DataSets on
* key equality.
* Elements of both DataSets that do not have a matching element on the
* opposing side are joined with `null` and emitted to the resulting DataSet.
*
* To specify the join keys the `where` and `equalTo` methods must be used. For example:
* {{{
* val left: DataSet[(String, Int, Int)] = ...
* val right: DataSet[(Int, String, Int)] = ...
* val joined = left.fullOuterJoin(right).where(0).equalTo(1)
* }}}
*
* When using an outer join you are required to specify a join function. For example:
* {{{
* val joined = left.fullOuterJoin(right).where(0).equalTo(1) {
* (left, right) =>
* val a = if (left == null) null else left._1
* val b = if (right == null) null else right._3
* (a, b)
* }
* }}}
*/
def fullOuterJoin[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, JoinHint.OPTIMIZER_CHOOSES, JoinType.FULL_OUTER)

/**
* Special [[fullOuterJoin]] operation for explicitly telling the system what join strategy to
* use. If null is given as the join strategy, then the optimizer will pick the strategy.
*/
def fullOuterJoin[O](other: DataSet[O], strategy: JoinHint): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, strategy, JoinType.FULL_OUTER)

/**
* An outer join on the left side.
*
* Elements of the left side (i.e. `this`) that do not have a matching element on the other
* side are joined with `null` and emitted to the resulting DataSet.
*
* @param other The other DataSet with which this DataSet is joined.
* @return An UnfinishedJoinOperation to continue with the definition of the join transformation
* @see #fullOuterJoin
*/
def leftOuterJoin[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, JoinHint.OPTIMIZER_CHOOSES, JoinType.LEFT_OUTER)

/**
* An outer join on the left side.
*
* Elements of the left side (i.e. `this`) that do not have a matching element on the other
* side are joined with `null` and emitted to the resulting DataSet.
*
* @param other The other DataSet with which this DataSet is joined.
* @param strategy The strategy that should be used execute the join. If { @code null} is given,
* then the optimizer will pick the join strategy.
* @return An UnfinishedJoinOperation to continue with the definition of the join transformation
* @see #fullOuterJoin
*/
def leftOuterJoin[O](other: DataSet[O], strategy: JoinHint): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, strategy, JoinType.LEFT_OUTER)

/**
* An outer join on the right side.
*
* Elements of the right side (i.e. `other`) that do not have a matching element on `this`
* side are joined with `null` and emitted to the resulting DataSet.
*
* @param other The other DataSet with which this DataSet is joined.
* @return An UnfinishedJoinOperation to continue with the definition of the join transformation
* @see #fullOuterJoin
*/
def rightOuterJoin[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, JoinHint.OPTIMIZER_CHOOSES, JoinType.RIGHT_OUTER)

/**
* An outer join on the right side.
*
* Elements of the right side (i.e. `other`) that do not have a matching element on `this`
* side are joined with `null` and emitted to the resulting DataSet.
*
* @param other The other DataSet with which this DataSet is joined.
* @param strategy The strategy that should be used execute the join. If { @code null} is given,
* then the optimizer will pick the join strategy.
* @return An UnfinishedJoinOperation to continue with the definition of the join transformation
* @see #fullOuterJoin
*/
def rightOuterJoin[O](other: DataSet[O], strategy: JoinHint): UnfinishedJoinOperation[T, O] =
new UnfinishedJoinOperation(this, other, strategy, JoinType.RIGHT_OUTER)

// --------------------------------------------------------------------------------------------
// Co-Group
// --------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
*/
package org.apache.flink.api.scala

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.api.common.operators.Operator
import org.apache.flink.api.common.operators.base.AbstractJoinOperatorBase
import org.apache.flink.api.common.{InvalidProgramException, ExecutionConfig}
import org.apache.flink.api.common.functions.{FlatJoinFunction, JoinFunction, Partitioner, RichFlatJoinFunction}
import org.apache.flink.api.common.operators.base.AbstractJoinOperatorBase.JoinHint
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin.WrappingFlatJoinFunction
import org.apache.flink.api.java.operators.JoinOperator.EquiJoin
import org.apache.flink.api.java.operators.JoinOperator.{JoinType, EquiJoin}
import org.apache.flink.api.java.operators._
import org.apache.flink.api.scala.typeutils.{CaseClassSerializer, CaseClassTypeInfo}
import org.apache.flink.util.Collector
Expand Down Expand Up @@ -87,7 +89,8 @@ class JoinDataSet[L, R](
joiner,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint,
getCallLocationName())
getCallLocationName(),
defaultJoin.getJoinType)

if (customPartitioner != null) {
wrap(joinOperator.withPartitioner(customPartitioner))
Expand Down Expand Up @@ -117,7 +120,8 @@ class JoinDataSet[L, R](
joiner,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint,
getCallLocationName())
getCallLocationName(),
defaultJoin.getJoinType)

if (customPartitioner != null) {
wrap(joinOperator.withPartitioner(customPartitioner))
Expand Down Expand Up @@ -145,7 +149,8 @@ class JoinDataSet[L, R](
joiner,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint,
getCallLocationName())
getCallLocationName(),
defaultJoin.getJoinType)

if (customPartitioner != null) {
wrap(joinOperator.withPartitioner(customPartitioner))
Expand Down Expand Up @@ -174,7 +179,8 @@ class JoinDataSet[L, R](
generatedFunction, fun,
implicitly[TypeInformation[O]],
defaultJoin.getJoinHint,
getCallLocationName())
getCallLocationName(),
defaultJoin.getJoinType)

if (customPartitioner != null) {
wrap(joinOperator.withPartitioner(customPartitioner))
Expand Down Expand Up @@ -223,9 +229,13 @@ class JoinDataSet[L, R](
class UnfinishedJoinOperation[L, R](
leftSet: DataSet[L],
rightSet: DataSet[R],
val joinHint: JoinHint)
val joinHint: JoinHint,
val joinType: JoinType)
extends UnfinishedKeyPairOperation[L, R, JoinDataSet[L, R]](leftSet, rightSet) {

def this(leftSet: DataSet[L], rightSet: DataSet[R], joinHint: JoinHint) =
this(leftSet, rightSet, joinHint, JoinType.INNER)

private[flink] def finish(leftKey: Keys[L], rightKey: Keys[R]) = {
val joiner = new FlatJoinFunction[L, R, (L, R)] {
def join(left: L, right: R, out: Collector[(L, R)]) = {
Expand Down Expand Up @@ -253,7 +263,16 @@ class UnfinishedJoinOperation[L, R](
}
val joinOperator = new EquiJoin[L, R, (L, R)](
leftSet.javaSet, rightSet.javaSet, leftKey, rightKey, joiner, returnType, joinHint,
getCallLocationName())
getCallLocationName(), joinType) {

override protected def translateToDataFlow(input1: Operator[L], input2: Operator[R]):
AbstractJoinOperatorBase[_, _, (L, R), _] = {
if (joinType.isOuter) {
throw new InvalidProgramException("Must specify a custom join function for outer join")
}
super.translateToDataFlow(input1, input2)
}
}

new JoinDataSet(joinOperator, leftSet, rightSet, leftKey, rightKey)
}
Expand Down

0 comments on commit b66b1b0

Please sign in to comment.