[SYSTEMDS-3695] Fix frame builtin nary append for Spark backend #2026

Closed · wants to merge 1 commit
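For context: this instruction is compiled when a DML script appends more than two frames in a single cbind/rbind call. Below is a minimal sketch of how to reach the Spark code path through the MLContext API; the class name, session setup, frame sizes, and the setExecutionType call are illustrative assumptions, not part of this PR.

import org.apache.spark.sql.SparkSession;
import org.apache.sysds.api.mlcontext.MLContext;
import org.apache.sysds.api.mlcontext.Script;
import static org.apache.sysds.api.mlcontext.ScriptFactory.dml;

public class NaryFrameAppendExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("nary-frame-append").master("local[*]").getOrCreate();
        MLContext ml = new MLContext(spark);
        // force Spark execution so the SP instruction is compiled instead of the CP one
        // (assumes the SystemML-era ExecutionType enum is still exposed by MLContext)
        ml.setExecutionType(MLContext.ExecutionType.SPARK);

        // an rbind with three frame arguments triggers the nary append path
        // handled by BuiltinNarySPInstruction
        Script s = dml(
            "A = as.frame(rand(rows=2000, cols=3));"
            + "B = as.frame(rand(rows=2000, cols=3));"
            + "C = as.frame(rand(rows=2000, cols=3));"
            + "R = rbind(A, B, C);"
            + "print(nrow(R));");
        ml.execute(s);
        spark.stop();
    }
}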
src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
@@ -20,12 +20,19 @@
package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -34,6 +41,7 @@
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction.ShiftMatrix;
import org.apache.sysds.runtime.instructions.spark.functions.MapInputSignature;
import org.apache.sysds.runtime.instructions.spark.functions.MapJoinSignature;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -47,8 +55,16 @@
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import static org.apache.sysds.hops.BinaryOp.AppendMethod.MR_MAPPEND;
import static org.apache.sysds.hops.BinaryOp.AppendMethod.MR_RAPPEND;
import static org.apache.sysds.hops.OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE;
import static org.apache.sysds.runtime.instructions.spark.FrameAppendMSPInstruction.appendFrameMSP;
import static org.apache.sysds.runtime.instructions.spark.FrameAppendRSPInstruction.appendFrameRSP;

public class BuiltinNarySPInstruction extends SPInstruction implements LineageTraceable
{
private CPOperand[] inputs;
@@ -75,32 +91,78 @@ public static BuiltinNarySPInstruction parseInstruction ( String str ) {
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
JavaPairRDD<MatrixIndexes,MatrixBlock> out = null;
DataCharacteristics mcOut = null;
DataCharacteristics dcout = null;
boolean inputIsMatrix = inputs[0].isMatrix();


if( getOpcode().equals("cbind") || getOpcode().equals("rbind") ) {
//compute output characteristics
boolean cbind = getOpcode().equals("cbind");
mcOut = computeAppendOutputDataCharacteristics(sec, inputs, cbind);

//get consolidated input via union over shifted and padded inputs
DataCharacteristics off = new MatrixCharacteristics(0, 0, mcOut.getBlocksize(), 0);
for( CPOperand input : inputs ) {
DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec
.getBinaryMatrixBlockRDDHandleForVariable( input.getName() )
.flatMapToPair(new ShiftMatrix(off, mcIn, cbind))
.mapToPair(new PadBlocksFunction(mcOut)); //just padding
out = (out != null) ? out.union(in) : in;
updateAppendDataCharacteristics(mcIn, off, cbind);
dcout = computeAppendOutputDataCharacteristics(sec, inputs, cbind);
if(inputIsMatrix){
//get consolidated input via union over shifted and padded inputs
DataCharacteristics off = new MatrixCharacteristics(0, 0, dcout.getBlocksize(), 0);
for( CPOperand input : inputs ) {
DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec
.getBinaryMatrixBlockRDDHandleForVariable(input.getName())
.flatMapToPair(new ShiftMatrix(off, mcIn, cbind))
.mapToPair(new PadBlocksFunction(dcout)); //just padding
out = (out != null) ? out.union(in) : in;
updateAppendDataCharacteristics(mcIn, off, cbind);
}
//aggregate partially overlapping blocks w/ single shuffle
int numPartOut = SparkUtils.getNumPreferredPartitions(dcout);
out = RDDAggregateUtils.mergeByKey(out, numPartOut, false);
}
//FRAME
else {
JavaPairRDD<Long,FrameBlock> outFrame = sec.getFrameBinaryBlockRDDHandleForVariable( inputs[0].getName() );
dcout = new MatrixCharacteristics(sec.getDataCharacteristics(inputs[0].getName()));
FrameObject fo = new FrameObject(sec.getFrameObject(inputs[0].getName()));
boolean[] broadcasted = new boolean[inputs.length];
broadcasted[0] = false;
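// fold the remaining inputs into the running result one at a time, picking per
// input between a broadcast-based map-side append and a shuffle-based reduce-side append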

for(int i = 1; i < inputs.length; i++){
DataCharacteristics dcIn = sec.getDataCharacteristics(inputs[i].getName());
final int blk_size = dcout.getBlocksize() <= 0 ? DEFAULT_FRAME_BLOCKSIZE : dcout.getBlocksize();

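// broadcast input i iff the map-side append is forced, or no method is forced and
// this is a cbind whose right-hand side fits into a single column block and into
// the Spark broadcast memory budget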
broadcasted[i] = BinaryOp.FORCED_APPEND_METHOD == MR_MAPPEND ||
BinaryOp.FORCED_APPEND_METHOD == null && cbind && dcIn.getCols() <= blk_size &&
OptimizerUtils.checkSparkBroadcastMemoryBudget(dcout.getCols(), dcIn.getCols(), blk_size, dcIn.getNonZeros());

//easy case: broadcast & map
if(broadcasted[i]){
outFrame = appendFrameMSP(outFrame, sec.getBroadcastForFrameVariable(inputs[i].getName()));
}
//general case for frames:
else{
if(BinaryOp.FORCED_APPEND_METHOD != null && BinaryOp.FORCED_APPEND_METHOD != MR_RAPPEND)
throw new DMLRuntimeException("Forced append type ["+BinaryOp.FORCED_APPEND_METHOD+"] is not supported for frames");

JavaPairRDD<Long,FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable( inputs[i].getName() );
outFrame = appendFrameRSP(outFrame, in2, dcout.getRows(), cbind);
}
updateAppendDataCharacteristics(dcIn, dcout, cbind);
if(cbind)
fo.setSchema(fo.mergeSchemas(sec.getFrameObject(inputs[i].getName())));
}

//set output RDD and add lineage
sec.getDataCharacteristics(output.getName()).set(dcout);
sec.setRDDHandleForVariable(output.getName(), outFrame);
sec.getFrameObject(output.getName()).setSchema(fo.getSchema());
for( int i = 0; i < inputs.length; i++)
if(broadcasted[i])
sec.addLineageBroadcast(output.getName(), inputs[i].getName());
else
sec.addLineageRDD(output.getName(), inputs[i].getName());
return;
}

//aggregate partially overlapping blocks w/ single shuffle
int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut);
out = RDDAggregateUtils.mergeByKey(out, numPartOut, false);
}
else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"}, getOpcode()) ) {
//compute output characteristics
mcOut = computeMinMaxOutputDataCharacteristics(sec, inputs);
dcout = computeMinMaxOutputDataCharacteristics(sec, inputs);

//get scalars and consolidated input via join
List<ScalarObject> scalars = sec.getScalarInputs(inputs);
@@ -118,13 +180,41 @@ else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"}, getOpcode()) ) {
}

//set output RDD and add lineage
sec.getDataCharacteristics(output.getName()).set(mcOut);
sec.getDataCharacteristics(output.getName()).set(dcout);
sec.setRDDHandleForVariable(output.getName(), out);
for( CPOperand input : inputs )
if( !input.isScalar() )
sec.addLineageRDD(output.getName(), input.getName());
}


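// aligns frame blocks to output block boundaries; only the single-output-block case
// (max_rows <= DEFAULT_FRAME_BLOCKSIZE) is implemented so far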
private static class AlignBlkTask implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
long max_rows;

public AlignBlkTask(long rows) {
max_rows = rows;
}

@Override
public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> longFrameBlockTuple2) throws Exception {
Long index = longFrameBlockTuple2._1;
FrameBlock fb = longFrameBlockTuple2._2;
ArrayList<Tuple2<Long, FrameBlock>> list = new ArrayList<Tuple2<Long, FrameBlock>>();
//single output block
if(max_rows <= DEFAULT_FRAME_BLOCKSIZE){
FrameBlock fbout = new FrameBlock(fb.getSchema());
fbout.ensureAllocatedColumns((int) max_rows);
fbout = fbout.leftIndexingOperations(fb, index.intValue() - 1, index.intValue() + fb.getNumRows() - 2, 0, fb.getNumColumns() - 1, null);
list.add(new Tuple2<>(1L, fbout));
} else {
throw new NotImplementedException("Other Alignment strategies need to be implemented");
//long aligned_index = (index / DEFAULT_FRAME_BLOCKSIZE)*OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE+1;
//list.add(new Tuple2<>(index / DEFAULT_FRAME_BLOCKSIZE + 1, fb));
}

return list.iterator();
}
}

private static DataCharacteristics computeAppendOutputDataCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) {
DataCharacteristics mcIn1 = sec.getDataCharacteristics(inputs[0].getName());
DataCharacteristics mcOut = new MatrixCharacteristics(0, 0, mcIn1.getBlocksize(), 0);
src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendMSPInstruction.java
@@ -53,8 +53,7 @@ public void processInstruction(ExecutionContext ec) {
//execute map-append operations (partitioning preserving if keys for blocks not changing)
JavaPairRDD<Long,FrameBlock> out = null;
if( preservesPartitioning(_cbind) ) {
out = in1.mapPartitionsToPair(
new MapSideAppendPartitionFunction(in2), true);
out = appendFrameMSP(in1, in2);
}
else
throw new DMLRuntimeException("Append type rbind not supported for frame mappend, instead use rappend");
@@ -74,13 +73,20 @@ public void processInstruction(ExecutionContext ec) {
sec.getFrameObject(output.getName()).setSchema(sec.getFrameObject(input1.getName()).getSchema());
}

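// extracted helper: broadcast-based (map-side) cbind append, so that
// BuiltinNarySPInstruction can reuse the same code path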
public static JavaPairRDD<Long, FrameBlock> appendFrameMSP(JavaPairRDD<Long, FrameBlock> in1, PartitionedBroadcast<FrameBlock> in2) {
JavaPairRDD<Long, FrameBlock> out;
out = in1.mapPartitionsToPair(
new MapSideAppendPartitionFunction(in2), true);
return out;
}

private static boolean preservesPartitioning( boolean cbind ) {
//Partitions for input1 will be preserved in case of cbind,
// where as in case of rbind partitions will not be preserved.
return cbind;
}

private static class MapSideAppendPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long,FrameBlock>>, Long, FrameBlock>
{
private static final long serialVersionUID = -3997051891171313830L;

@@ -118,8 +124,17 @@ protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg)

int rowix = (ix.intValue()-1)/OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE+1;
int colix = 1;


FrameBlock in2 = _pm.getBlock(rowix, colix);

//if misalignment -> slice out fb from RHS
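// e.g., assuming DEFAULT_FRAME_BLOCKSIZE = 1000: a left block starting at global
// row ix=2501 with 300 rows maps to broadcast block rowix=3 and slices local
// rows [500, 799] (0-based, inclusive) out of it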
if(in1.getNumRows() != in2.getNumRows()){
int start = ix.intValue() - 1 - (rowix-1)*OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE;
int end = start + in1.getNumRows() - 1;
in2 = in2.slice(start, end);
}

FrameBlock out = in1.append(in2, true); //cbind
return new Tuple2<>(ix, out);
}
src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
@@ -45,19 +45,9 @@ public void processInstruction(ExecutionContext ec) {
JavaPairRDD<Long,FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() );
JavaPairRDD<Long,FrameBlock> out = null;
long leftRows = sec.getDataCharacteristics(input1.getName()).getRows();

if(_cbind) {
JavaPairRDD<Long,FrameBlock> in1Aligned = in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
in1Aligned = FrameRDDAggregateUtils.mergeByKey(in1Aligned);
JavaPairRDD<Long,FrameBlock> in2Aligned = in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
in2Aligned = FrameRDDAggregateUtils.mergeByKey(in2Aligned);

out = in1Aligned.join(in2Aligned).mapValues(new ReduceSideColumnsFunction(_cbind));
} else { //rbind
JavaPairRDD<Long,FrameBlock> right = in2.mapToPair( new ReduceSideAppendRowsFunction(leftRows));
out = in1.union(right);
}


out = appendFrameRSP(in1, in2, leftRows, _cbind);

//put output RDD handle into symbol table
updateBinaryAppendOutputDataCharacteristics(sec, _cbind);
sec.setRDDHandleForVariable(output.getName(), out);
@@ -73,6 +63,19 @@ public void processInstruction(ExecutionContext ec) {
sec.getFrameObject(output.getName()).setSchema(sec.getFrameObject(input1.getName()).getSchema());
}

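// extracted helper: shuffle-based (reduce-side) append, reused by
// BuiltinNarySPInstruction for the general frame case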
public static JavaPairRDD<Long, FrameBlock> appendFrameRSP(JavaPairRDD<Long, FrameBlock> in1, JavaPairRDD<Long, FrameBlock> in2, long leftRows, boolean cbind) {
if(cbind) {
JavaPairRDD<Long,FrameBlock> in1Aligned = in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
in1Aligned = FrameRDDAggregateUtils.mergeByKey(in1Aligned);
JavaPairRDD<Long,FrameBlock> in2Aligned = in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
in2Aligned = FrameRDDAggregateUtils.mergeByKey(in2Aligned);
return in1Aligned.join(in2Aligned).mapValues(new ReduceSideColumnsFunction(cbind));
} else { //rbind
JavaPairRDD<Long,FrameBlock> right = in2.mapToPair( new ReduceSideAppendRowsFunction(leftRows));
return in1.union(right);
}
}

private static class ReduceSideColumnsFunction implements Function<Tuple2<FrameBlock, FrameBlock>, FrameBlock>
{
private static final long serialVersionUID = -97824903649667646L;
@@ -109,7 +112,7 @@ public Tuple2<Long,FrameBlock> call(Tuple2<Long, FrameBlock> arg0)
}
}

private static class ReduceSideAppendAlignFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
{
private static final long serialVersionUID = 5850400295183766409L;

(test for the wer builtin)
@@ -46,10 +46,10 @@ public void testCP() {
runWerTest(ExecType.CP);
}

// @Test
// public void testSpark() {
// runWerTest(ExecType.SPARK);
// }
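// re-enabled: nary frame append now runs on the Spark backend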
@Test
public void testSpark() {
runWerTest(ExecType.SPARK);
}

private void runWerTest(ExecType instType) {
ExecMode platformOld = setExecMode(instType);