Skip to content

Commit

Permalink
[SYSTEMDS-3695] Fix missing frame nary-append spark instruction
Browse files Browse the repository at this point in the history
Closes #2026.
  • Loading branch information
e-strauss authored and mboehm7 committed Jun 3, 2024
1 parent 589f574 commit 7d1f081
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@
package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
Expand All @@ -47,8 +54,16 @@
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import static org.apache.sysds.hops.BinaryOp.AppendMethod.MR_MAPPEND;
import static org.apache.sysds.hops.BinaryOp.AppendMethod.MR_RAPPEND;
import static org.apache.sysds.hops.OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE;
import static org.apache.sysds.runtime.instructions.spark.FrameAppendMSPInstruction.appendFrameMSP;
import static org.apache.sysds.runtime.instructions.spark.FrameAppendRSPInstruction.appendFrameRSP;

public class BuiltinNarySPInstruction extends SPInstruction implements LineageTraceable
{
private CPOperand[] inputs;
Expand All @@ -75,32 +90,82 @@ public static BuiltinNarySPInstruction parseInstruction ( String str ) {
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
JavaPairRDD<MatrixIndexes,MatrixBlock> out = null;
DataCharacteristics mcOut = null;
DataCharacteristics dcout = null;
boolean inputIsMatrix = inputs[0].isMatrix();


if( getOpcode().equals("cbind") || getOpcode().equals("rbind") ) {
//compute output characteristics
boolean cbind = getOpcode().equals("cbind");
mcOut = computeAppendOutputDataCharacteristics(sec, inputs, cbind);

//get consolidated input via union over shifted and padded inputs
DataCharacteristics off = new MatrixCharacteristics(0, 0, mcOut.getBlocksize(), 0);
for( CPOperand input : inputs ) {
DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec
.getBinaryMatrixBlockRDDHandleForVariable( input.getName() )
.flatMapToPair(new ShiftMatrix(off, mcIn, cbind))
.mapToPair(new PadBlocksFunction(mcOut)); //just padding
out = (out != null) ? out.union(in) : in;
updateAppendDataCharacteristics(mcIn, off, cbind);
dcout = computeAppendOutputDataCharacteristics(sec, inputs, cbind);
if(inputIsMatrix){
//get consolidated input via union over shifted and padded inputs
DataCharacteristics off = new MatrixCharacteristics(0, 0, dcout.getBlocksize(), 0);
for( CPOperand input : inputs ) {
DataCharacteristics mcIn = sec.getDataCharacteristics(input.getName());
JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec
.getBinaryMatrixBlockRDDHandleForVariable(input.getName())
.flatMapToPair(new ShiftMatrix(off, mcIn, cbind))
.mapToPair(new PadBlocksFunction(dcout)); //just padding
out = (out != null) ? out.union(in) : in;
updateAppendDataCharacteristics(mcIn, off, cbind);
}
//aggregate partially overlapping blocks w/ single shuffle
int numPartOut = SparkUtils.getNumPreferredPartitions(dcout);
out = RDDAggregateUtils.mergeByKey(out, numPartOut, false);
}
//FRAME
else {
JavaPairRDD<Long,FrameBlock> outFrame =
sec.getFrameBinaryBlockRDDHandleForVariable( inputs[0].getName() );
dcout = new MatrixCharacteristics(sec.getDataCharacteristics(inputs[0].getName()));
FrameObject fo = new FrameObject(sec.getFrameObject(inputs[0].getName()));
boolean[] broadcasted = new boolean[inputs.length];
broadcasted[0] = false;

for(int i = 1; i < inputs.length; i++){
DataCharacteristics dcIn = sec.getDataCharacteristics(inputs[i].getName());
final int blk_size = dcout.getBlocksize() <= 0 ? DEFAULT_FRAME_BLOCKSIZE : dcout.getBlocksize();

broadcasted[i] = BinaryOp.FORCED_APPEND_METHOD == MR_MAPPEND
|| BinaryOp.FORCED_APPEND_METHOD == null && cbind && dcIn.getCols() <= blk_size
&& OptimizerUtils.checkSparkBroadcastMemoryBudget(
dcout.getCols(), dcIn.getCols(), blk_size, dcIn.getNonZeros());

//easy case: broadcast & map
if(broadcasted[i]){
outFrame = appendFrameMSP(outFrame, sec.getBroadcastForFrameVariable(inputs[i].getName()));
}
//general case for frames:
else{
if(BinaryOp.FORCED_APPEND_METHOD != null && BinaryOp.FORCED_APPEND_METHOD != MR_RAPPEND)
throw new DMLRuntimeException("Forced append type ["
+BinaryOp.FORCED_APPEND_METHOD+"] is not supported for frames");

JavaPairRDD<Long,FrameBlock> in2 =
sec.getFrameBinaryBlockRDDHandleForVariable(inputs[i].getName() );
outFrame = appendFrameRSP(outFrame, in2, dcout.getRows(), cbind);
}
updateAppendDataCharacteristics(dcIn, dcout, cbind);
if(cbind)
fo.setSchema(fo.mergeSchemas(sec.getFrameObject(inputs[i].getName())));
}

//set output RDD and add lineage
sec.getDataCharacteristics(output.getName()).set(dcout);
sec.setRDDHandleForVariable(output.getName(), outFrame);
sec.getFrameObject(output.getName()).setSchema(fo.getSchema());
for( int i = 0; i < inputs.length; i++)
if(broadcasted[i])
sec.addLineageBroadcast(output.getName(), inputs[i].getName());
else
sec.addLineageRDD(output.getName(), inputs[i].getName());
return;
}

//aggregate partially overlapping blocks w/ single shuffle
int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut);
out = RDDAggregateUtils.mergeByKey(out, numPartOut, false);
}
else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"}, getOpcode()) ) {
//compute output characteristics
mcOut = computeMinMaxOutputDataCharacteristics(sec, inputs);
dcout = computeMinMaxOutputDataCharacteristics(sec, inputs);

//get scalars and consolidated input via join
List<ScalarObject> scalars = sec.getScalarInputs(inputs);
Expand All @@ -118,13 +183,43 @@ else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"}, getOpcode()) ) {
}

//set output RDD and add lineage
sec.getDataCharacteristics(output.getName()).set(mcOut);
sec.getDataCharacteristics(output.getName()).set(dcout);
sec.setRDDHandleForVariable(output.getName(), out);
for( CPOperand input : inputs )
if( !input.isScalar() )
sec.addLineageRDD(output.getName(), input.getName());
}


@SuppressWarnings("unused")
private static class AlignBlkTask implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> {
	private static final long serialVersionUID = 1333460067852261573L;

	//total output rows; decides whether a single output block is sufficient
	private final long max_rows;

	public AlignBlkTask(long rows) {
		max_rows = rows;
	}

	@Override
	public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> kv) throws Exception {
		//multi-block alignment strategies are not supported yet
		if( max_rows > DEFAULT_FRAME_BLOCKSIZE )
			throw new NotImplementedException("Other Alignment strategies need to be implemented");

		Long rowIndex = kv._1;
		FrameBlock block = kv._2;

		//single output block: allocate a target frame of max_rows rows and copy the
		//incoming block into its row range via left indexing (1-based block index
		//mapped to a 0-based inclusive row range)
		FrameBlock target = new FrameBlock(block.getSchema());
		target.ensureAllocatedColumns((int) max_rows);
		target = target.leftIndexingOperations(block, rowIndex.intValue() - 1,
			rowIndex.intValue() + block.getNumRows() - 2, 0, block.getNumColumns() - 1, null);

		List<Tuple2<Long, FrameBlock>> result = new ArrayList<>();
		result.add(new Tuple2<>(1L, target));
		return result.iterator();
	}
}

private static DataCharacteristics computeAppendOutputDataCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) {
DataCharacteristics mcIn1 = sec.getDataCharacteristics(inputs[0].getName());
DataCharacteristics mcOut = new MatrixCharacteristics(0, 0, mcIn1.getBlocksize(), 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ public void processInstruction(ExecutionContext ec) {
//execute map-append operations (partitioning preserving if keys for blocks not changing)
JavaPairRDD<Long,FrameBlock> out = null;
if( preservesPartitioning(_cbind) ) {
out = in1.mapPartitionsToPair(
new MapSideAppendPartitionFunction(in2), true);
out = appendFrameMSP(in1, in2);
}
else
throw new DMLRuntimeException("Append type rbind not supported for frame mappend, instead use rappend");
Expand All @@ -74,13 +73,20 @@ public void processInstruction(ExecutionContext ec) {
sec.getFrameObject(output.getName()).setSchema(sec.getFrameObject(input1.getName()).getSchema());
}

/**
 * Map-side frame cbind: appends the broadcast right-hand frame to every
 * partition of the left input. Partitioning is preserved because cbind
 * leaves the row-block keys unchanged.
 *
 * @param in1 left input frame as binary-block RDD
 * @param in2 broadcast right-hand frame
 * @return RDD of column-appended frame blocks
 */
public static JavaPairRDD<Long, FrameBlock> appendFrameMSP(JavaPairRDD<Long, FrameBlock> in1, PartitionedBroadcast<FrameBlock> in2) {
	return in1.mapPartitionsToPair(new MapSideAppendPartitionFunction(in2), true);
}

private static boolean preservesPartitioning( boolean cbind ) {
	//Partitions of input1 are preserved for cbind (row-block keys unchanged),
	//whereas rbind shifts block keys and hence does not preserve partitioning.
	return cbind;
}

private static class MapSideAppendPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long,FrameBlock>>, Long, FrameBlock>
private static class MapSideAppendPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<Long,FrameBlock>>, Long, FrameBlock>
{
private static final long serialVersionUID = -3997051891171313830L;

Expand Down Expand Up @@ -118,8 +124,17 @@ protected Tuple2<Long, FrameBlock> computeNext(Tuple2<Long, FrameBlock> arg)

int rowix = (ix.intValue()-1)/OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE+1;
int colix = 1;


FrameBlock in2 = _pm.getBlock(rowix, colix);

//if misalignment -> slice out fb from RHS
if(in1.getNumRows() != in2.getNumRows()){
int start = ix.intValue() - 1 - (rowix-1)*OptimizerUtils.DEFAULT_FRAME_BLOCKSIZE;
int end = start + in1.getNumRows() - 1;
in2 = in2.slice(start, end);
}

FrameBlock out = in1.append(in2, true); //cbind
return new Tuple2<>(ix, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,9 @@ public void processInstruction(ExecutionContext ec) {
JavaPairRDD<Long,FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() );
JavaPairRDD<Long,FrameBlock> out = null;
long leftRows = sec.getDataCharacteristics(input1.getName()).getRows();

if(_cbind) {
JavaPairRDD<Long,FrameBlock> in1Aligned = in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
in1Aligned = FrameRDDAggregateUtils.mergeByKey(in1Aligned);
JavaPairRDD<Long,FrameBlock> in2Aligned = in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
in2Aligned = FrameRDDAggregateUtils.mergeByKey(in2Aligned);

out = in1Aligned.join(in2Aligned).mapValues(new ReduceSideColumnsFunction(_cbind));
} else { //rbind
JavaPairRDD<Long,FrameBlock> right = in2.mapToPair( new ReduceSideAppendRowsFunction(leftRows));
out = in1.union(right);
}


out = appendFrameRSP(in1, in2, leftRows, _cbind);

//put output RDD handle into symbol table
updateBinaryAppendOutputDataCharacteristics(sec, _cbind);
sec.setRDDHandleForVariable(output.getName(), out);
Expand All @@ -73,6 +63,19 @@ public void processInstruction(ExecutionContext ec) {
sec.getFrameObject(output.getName()).setSchema(sec.getFrameObject(input1.getName()).getSchema());
}

/**
 * Reduce-side frame append of two binary-block frame RDDs.
 * For rbind, the right input's row indexes are shifted past the left input
 * and the two RDDs are unioned. For cbind, both inputs are re-blocked onto a
 * common row-block alignment, merged by key, and matching blocks are joined
 * and concatenated column-wise.
 *
 * @param in1 left input frame RDD
 * @param in2 right input frame RDD
 * @param leftRows number of rows of the left input
 * @param cbind true for column append, false for row append
 * @return RDD of appended frame blocks
 */
public static JavaPairRDD<Long, FrameBlock> appendFrameRSP(JavaPairRDD<Long, FrameBlock> in1, JavaPairRDD<Long, FrameBlock> in2, long leftRows, boolean cbind) {
	if( !cbind ) { //rbind
		JavaPairRDD<Long, FrameBlock> shifted = in2.mapToPair(new ReduceSideAppendRowsFunction(leftRows));
		return in1.union(shifted);
	}
	//cbind: align both sides, then join and concatenate columns
	JavaPairRDD<Long, FrameBlock> leftAligned = FrameRDDAggregateUtils.mergeByKey(
		in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows)));
	JavaPairRDD<Long, FrameBlock> rightAligned = FrameRDDAggregateUtils.mergeByKey(
		in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows)));
	return leftAligned.join(rightAligned).mapValues(new ReduceSideColumnsFunction(cbind));
}

private static class ReduceSideColumnsFunction implements Function<Tuple2<FrameBlock, FrameBlock>, FrameBlock>
{
private static final long serialVersionUID = -97824903649667646L;
Expand Down Expand Up @@ -109,7 +112,7 @@ public Tuple2<Long,FrameBlock> call(Tuple2<Long, FrameBlock> arg0)
}
}

private static class ReduceSideAppendAlignFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
private static class ReduceSideAppendAlignFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
{
private static final long serialVersionUID = 5850400295183766409L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ public void testCP() {
runWerTest(ExecType.CP);
}

// @Test
// public void testSpark() {
// runWerTest(ExecType.SPARK);
// }
@Test
public void testSpark() {
	//Spark backend test, enabled by the frame nary-append spark instruction fix
	runWerTest(ExecType.SPARK);
}

private void runWerTest(ExecType instType) {
ExecMode platformOld = setExecMode(instType);
Expand Down
Loading

0 comments on commit 7d1f081

Please sign in to comment.