Skip to content

Commit

Permalink
[FLINK-2105] Implement Sort-Merge Outer Join algorithm
Browse files Browse the repository at this point in the history
This closes #907
  • Loading branch information
r-pogalz authored and fhueske committed Aug 4, 2015
1 parent df9f481 commit 941ac6d
Show file tree
Hide file tree
Showing 11 changed files with 974 additions and 36 deletions.
Expand Up @@ -115,20 +115,20 @@ public void abort() {
}

/**
* Calls the <code>JoinFunction#match()</code> method for all two key-value pairs that share the same key and come
* from different inputs. The output of the <code>match()</code> method is forwarded.
* Calls the <code>JoinFunction#join()</code> method for all two key-value pairs that share the same key and come
* from different inputs. The output of the <code>join()</code> method is forwarded.
* <p>
* This method first zig-zags between the two sorted inputs in order to find a common
* key, and then calls the match stub with the cross product of the values.
* key, and then calls the join stub with the cross product of the values.
*
* @throws Exception Forwards all exceptions from the user code and the I/O system.
* @see org.apache.flink.runtime.operators.util.JoinTaskIterator#callWithNextKey(org.apache.flink.api.common.functions.FlatJoinFunction, org.apache.flink.util.Collector)
*/
@Override
public abstract boolean callWithNextKey(final FlatJoinFunction<T1, T2, O> matchFunction, final Collector<O> collector)
public abstract boolean callWithNextKey(final FlatJoinFunction<T1, T2, O> joinFunction, final Collector<O> collector)
throws Exception;

protected void crossMatchingGroup(Iterator<T1> values1, Iterator<T2> values2, FlatJoinFunction<T1, T2, O> matchFunction, Collector<O> collector) throws Exception {
protected void crossMatchingGroup(Iterator<T1> values1, Iterator<T2> values2, FlatJoinFunction<T1, T2, O> joinFunction, Collector<O> collector) throws Exception {
final T1 firstV1 = values1.next();
final T2 firstV2 = values2.next();

Expand All @@ -143,45 +143,45 @@ protected void crossMatchingGroup(Iterator<T1> values1, Iterator<T2> values2, Fl
if (v2HasNext) {
// both sides contain more than one value
// TODO: Decide which side to spill and which to block!
crossMwithNValues(firstV1, values1, firstV2, values2, matchFunction, collector);
crossMwithNValues(firstV1, values1, firstV2, values2, joinFunction, collector);
} else {
crossSecond1withNValues(firstV2, firstV1, values1, matchFunction, collector);
crossSecond1withNValues(firstV2, firstV1, values1, joinFunction, collector);
}
} else {
if (v2HasNext) {
crossFirst1withNValues(firstV1, firstV2, values2, matchFunction, collector);
crossFirst1withNValues(firstV1, firstV2, values2, joinFunction, collector);
} else {
// both sides contain only one value
matchFunction.join(firstV1, firstV2, collector);
joinFunction.join(firstV1, firstV2, collector);
}
}
}

/**
* Crosses a single value from the first input with N values, all sharing a common key.
* Effectively realizes a <i>1:N</i> match (join).
* Effectively realizes a <i>1:N</i> join.
*
* @param val1 The value form the <i>1</i> side.
* @param firstValN The first of the values from the <i>N</i> side.
* @param valsN Iterator over remaining <i>N</i> side values.
* @throws Exception Forwards all exceptions thrown by the stub.
*/
private void crossFirst1withNValues(final T1 val1, final T2 firstValN,
final Iterator<T2> valsN, final FlatJoinFunction<T1, T2, O> matchFunction, final Collector<O> collector)
final Iterator<T2> valsN, final FlatJoinFunction<T1, T2, O> joinFunction, final Collector<O> collector)
throws Exception {
T1 copy1 = createCopy(serializer1, val1, this.copy1);
matchFunction.join(copy1, firstValN, collector);
joinFunction.join(copy1, firstValN, collector);

// set copy and match first element
// set copy and join first element
boolean more = true;
do {
final T2 nRec = valsN.next();

if (valsN.hasNext()) {
copy1 = createCopy(serializer1, val1, this.copy1);
matchFunction.join(copy1, nRec, collector);
joinFunction.join(copy1, nRec, collector);
} else {
matchFunction.join(val1, nRec, collector);
joinFunction.join(val1, nRec, collector);
more = false;
}
}
Expand All @@ -190,28 +190,28 @@ private void crossFirst1withNValues(final T1 val1, final T2 firstValN,

/**
* Crosses a single value from the second side with N values, all sharing a common key.
* Effectively realizes a <i>N:1</i> match (join).
* Effectively realizes a <i>N:1</i> join.
*
* @param val1 The value form the <i>1</i> side.
* @param firstValN The first of the values from the <i>N</i> side.
* @param valsN Iterator over remaining <i>N</i> side values.
* @throws Exception Forwards all exceptions thrown by the stub.
*/
private void crossSecond1withNValues(T2 val1, T1 firstValN,
Iterator<T1> valsN, FlatJoinFunction<T1, T2, O> matchFunction, Collector<O> collector) throws Exception {
Iterator<T1> valsN, FlatJoinFunction<T1, T2, O> joinFunction, Collector<O> collector) throws Exception {
T2 copy2 = createCopy(serializer2, val1, this.copy2);
matchFunction.join(firstValN, copy2, collector);
joinFunction.join(firstValN, copy2, collector);

// set copy and match first element
// set copy and join first element
boolean more = true;
do {
final T1 nRec = valsN.next();

if (valsN.hasNext()) {
copy2 = createCopy(serializer2, val1, this.copy2);
matchFunction.join(nRec, copy2, collector);
joinFunction.join(nRec, copy2, collector);
} else {
matchFunction.join(nRec, val1, collector);
joinFunction.join(nRec, val1, collector);
more = false;
}
}
Expand All @@ -220,7 +220,7 @@ private void crossSecond1withNValues(T2 val1, T1 firstValN,

private void crossMwithNValues(final T1 firstV1, Iterator<T1> spillVals,
final T2 firstV2, final Iterator<T2> blockVals,
final FlatJoinFunction<T1, T2, O> matchFunction, final Collector<O> collector) throws Exception {
final FlatJoinFunction<T1, T2, O> joinFunction, final Collector<O> collector) throws Exception {
// ==================================================
// We have one first (head) element from both inputs (firstV1 and firstV2)
// We have an iterator for both inputs.
Expand All @@ -237,13 +237,13 @@ private void crossMwithNValues(final T1 firstV1, Iterator<T1> spillVals,
// 5) cross the head of the spilling side with the next block
// 6) cross the spilling iterator with the next block.

// match the first values first
// join the first values first
T1 copy1 = this.createCopy(serializer1, firstV1, this.copy1);
T2 blockHeadCopy = this.createCopy(serializer2, firstV2, this.blockHeadCopy);
T1 spillHeadCopy = null;

// --------------- 1) Cross the heads -------------------
matchFunction.join(copy1, firstV2, collector);
joinFunction.join(copy1, firstV2, collector);

// for the remaining values, we do a block-nested-loops join
SpillingResettableIterator<T1> spillIt = null;
Expand All @@ -256,7 +256,7 @@ private void crossMwithNValues(final T1 firstV1, Iterator<T1> spillVals,
while (this.blockIt.hasNext()) {
final T2 nextBlockRec = this.blockIt.next();
copy1 = this.createCopy(serializer1, firstV1, this.copy1);
matchFunction.join(copy1, nextBlockRec, collector);
joinFunction.join(copy1, nextBlockRec, collector);
}
this.blockIt.reset();

Expand Down Expand Up @@ -286,15 +286,15 @@ private void crossMwithNValues(final T1 firstV1, Iterator<T1> spillVals,

// -------- 3) cross the iterator of the spilling side with the head of the block side --------
T2 copy2 = this.createCopy(serializer2, blockHeadCopy, this.copy2);
matchFunction.join(copy1, copy2, collector);
joinFunction.join(copy1, copy2, collector);

// -------- 4) cross the iterator of the spilling side with the first block --------
while (this.blockIt.hasNext()) {
T2 nextBlockRec = this.blockIt.next();

// get instances of key and block value
copy1 = this.createCopy(serializer1, nextSpillVal, this.copy1);
matchFunction.join(copy1, nextBlockRec, collector);
joinFunction.join(copy1, nextBlockRec, collector);
}
// reset block iterator
this.blockIt.reset();
Expand All @@ -316,7 +316,7 @@ private void crossMwithNValues(final T1 firstV1, Iterator<T1> spillVals,
while (this.blockIt.hasNext()) {
copy1 = this.createCopy(serializer1, spillHeadCopy, this.copy1);
final T2 nextBlockVal = blockIt.next();
matchFunction.join(copy1, nextBlockVal, collector);
joinFunction.join(copy1, nextBlockVal, collector);
}
this.blockIt.reset();

Expand All @@ -329,7 +329,7 @@ private void crossMwithNValues(final T1 firstV1, Iterator<T1> spillVals,
// get instances of key and block value
final T2 nextBlockVal = this.blockIt.next();
copy1 = this.createCopy(serializer1, nextSpillVal, this.copy1);
matchFunction.join(copy1, nextBlockVal, collector);
joinFunction.join(copy1, nextBlockVal, collector);
}

// reset block iterator
Expand Down
@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.runtime.operators.sort;

import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.memorymanager.MemoryAllocationException;
import org.apache.flink.runtime.memorymanager.MemoryManager;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;

import java.util.Iterator;

/**
* An implementation of the {@link org.apache.flink.runtime.operators.util.JoinTaskIterator} that realizes the
* outer join through a sort-merge join strategy.
*/
public abstract class AbstractMergeOuterJoinIterator<T1, T2, O> extends AbstractMergeIterator<T1, T2, O> {

public static enum OuterJoinType {LEFT, RIGHT, FULL}

private final OuterJoinType outerJoinType;

private boolean initialized = false;
private boolean it1Empty = false;
private boolean it2Empty = false;


public AbstractMergeOuterJoinIterator(
OuterJoinType outerJoinType,
MutableObjectIterator<T1> input1,
MutableObjectIterator<T2> input2,
TypeSerializer<T1> serializer1, TypeComparator<T1> comparator1,
TypeSerializer<T2> serializer2, TypeComparator<T2> comparator2,
TypePairComparator<T1, T2> pairComparator,
MemoryManager memoryManager,
IOManager ioManager,
int numMemoryPages,
AbstractInvokable parentTask)
throws MemoryAllocationException {
super(input1, input2, serializer1, comparator1, serializer2, comparator2, pairComparator, memoryManager, ioManager, numMemoryPages, parentTask);

this.outerJoinType = outerJoinType;
}

/**
* Calls the <code>JoinFunction#join()</code> method for all two key-value pairs that share the same key and come
* from different inputs. Furthermore, depending on the outer join type (LEFT, RIGHT, FULL), all key-value pairs where no
* matching partner from the other input exists are joined with null.
* The output of the <code>join()</code> method is forwarded.
*
* @throws Exception Forwards all exceptions from the user code and the I/O system.
* @see org.apache.flink.runtime.operators.util.JoinTaskIterator#callWithNextKey(org.apache.flink.api.common.functions.FlatJoinFunction, org.apache.flink.util.Collector)
*/
@Override
public boolean callWithNextKey(final FlatJoinFunction<T1, T2, O> joinFunction, final Collector<O> collector) throws Exception {
if (!initialized) {
//first run, set iterators to first elements
it1Empty = !this.iterator1.nextKey();
it2Empty = !this.iterator2.nextKey();
initialized = true;
}

if (it1Empty && it2Empty) {
return false;
} else if (it2Empty) {
if (outerJoinType == OuterJoinType.LEFT || outerJoinType == OuterJoinType.FULL) {
joinLeftKeyValuesWithNull(iterator1.getValues(), joinFunction, collector);
it1Empty = !iterator1.nextKey();
return true;
} else {
//consume rest of left side
while (iterator1.nextKey()) ;
it1Empty = true;
return false;
}
} else if (it1Empty) {
if (outerJoinType == OuterJoinType.RIGHT || outerJoinType == OuterJoinType.FULL) {
joinRightKeyValuesWithNull(iterator2.getValues(), joinFunction, collector);
it2Empty = !iterator2.nextKey();
return true;
} else {
//consume rest of right side
while (iterator2.nextKey()) ;
it2Empty = true;
return false;
}
} else {
final TypePairComparator<T1, T2> comparator = super.pairComparator;
comparator.setReference(this.iterator1.getCurrent());
T2 current2 = this.iterator2.getCurrent();

// zig zag
while (true) {
// determine the relation between the (possibly composite) keys
final int comp = comparator.compareToReference(current2);

if (comp == 0) {
break;
}

if (comp < 0) {
//right key < left key
if (outerJoinType == OuterJoinType.RIGHT || outerJoinType == OuterJoinType.FULL) {
//join right key values with null in case of right or full outer join
joinRightKeyValuesWithNull(iterator2.getValues(), joinFunction, collector);
it2Empty = !iterator2.nextKey();
return true;
} else {
//skip this right key if it is a left outer join
if (!this.iterator2.nextKey()) {
//if right side is empty, join current left key values with null
joinLeftKeyValuesWithNull(iterator1.getValues(), joinFunction, collector);
it1Empty = !iterator1.nextKey();
it2Empty = true;
return true;
}
current2 = this.iterator2.getCurrent();
}
} else {
//right key > left key
if (outerJoinType == OuterJoinType.LEFT || outerJoinType == OuterJoinType.FULL) {
//join left key values with null in case of left or full outer join
joinLeftKeyValuesWithNull(iterator1.getValues(), joinFunction, collector);
it1Empty = !iterator1.nextKey();
return true;
} else {
//skip this left key if it is a right outer join
if (!this.iterator1.nextKey()) {
//if right side is empty, join current right key values with null
joinRightKeyValuesWithNull(iterator2.getValues(), joinFunction, collector);
it1Empty = true;
it2Empty = !iterator2.nextKey();
return true;
}
comparator.setReference(this.iterator1.getCurrent());
}
}
}

// here, we have a common key! call the join function with the cross product of the
// values
final Iterator<T1> values1 = this.iterator1.getValues();
final Iterator<T2> values2 = this.iterator2.getValues();

crossMatchingGroup(values1, values2, joinFunction, collector);
it1Empty = !iterator1.nextKey();
it2Empty = !iterator2.nextKey();
return true;
}
}

private void joinLeftKeyValuesWithNull(Iterator<T1> values, FlatJoinFunction<T1, T2, O> joinFunction, Collector<O> collector) throws Exception {
while (values.hasNext()) {
T1 next = values.next();
this.copy1 = createCopy(serializer1, next, copy1);
joinFunction.join(copy1, null, collector);
}
}

private void joinRightKeyValuesWithNull(Iterator<T2> values, FlatJoinFunction<T1, T2, O> joinFunction, Collector<O> collector) throws Exception {
while (values.hasNext()) {
T2 next = values.next();
this.copy2 = createCopy(serializer2, next, copy2);
joinFunction.join(null, copy2, collector);
}
}

}

0 comments on commit 941ac6d

Please sign in to comment.