Skip to content

Commit

Permalink
[SYSTEMDS-3525] Binary Inplace Operations
Browse files Browse the repository at this point in the history
This commit initializes the in-place logic for binary operations.
Initially this is only used in the very specific case of division by a vector
that does not contain NaN or zero, where the input is not used by any other
operator.

Additionally, this commit adds a parameterized test that verifies equivalent
behavior of the in-place operations and the normal operations.

Closes #1808
  • Loading branch information
Baunsgaard committed Apr 24, 2023
1 parent 107dae9 commit 1329f3d
Show file tree
Hide file tree
Showing 11 changed files with 921 additions and 89 deletions.
45 changes: 31 additions & 14 deletions src/main/java/org/apache/sysds/hops/AggUnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,28 @@ else if( _direction == Direction.Row && dc.rowsKnown() )
}


/**
 * Checks whether the first input is a non-DataOp hop (i.e., not a checkpoint)
 * whose own execution type resolves to Spark.
 *
 * @return true if the first input is not a DataOp and its exec type is SPARK
 */
private boolean inputAlreadySpark(){
	// DataOp inputs are handled separately; do not query their exec type here
	if(getInput(0) instanceof DataOp)
		return false;
	return getInput(0).optFindExecType() == ExecType.SPARK;
}

/**
 * Checks whether the first input is a DataOp that reports having only an RDD.
 *
 * @return true if the first input is a DataOp and its hasOnlyRDD() is true
 */
private boolean inputOnlyRDD(){
	if(!(getInput(0) instanceof DataOp))
		return false;
	return ((DataOp) getInput(0)).hasOnlyRDD();
}

/**
 * Checks whether the first input has exactly one parent
 * (i.e., this unary aggregate is its only consumer).
 *
 * @return true if the parent list of the first input has size one
 */
private boolean onlyOneParent(){
	final int numParents = getInput(0).getParent().size();
	return numParents == 1;
}

/**
 * Checks whether every parent of the first input, other than this hop,
 * resolves (non-transitively) to the Spark execution type.
 *
 * @return true if no other parent has an exec type different from SPARK
 */
private boolean allParentsSpark(){
	// A violating parent is any hop other than this one that does not run in Spark.
	return getInput(0).getParent().stream()
		.noneMatch(h -> h != this && h.optFindExecType(false) != ExecType.SPARK);
}

/**
 * Checks whether aggregating the first input along this hop's direction
 * is reported as not requiring an aggregation step.
 *
 * @return the negation of requiresAggregation for the first input and direction
 */
private boolean inputDoesNotRequireAggregation(){
	final boolean needsAggregation = requiresAggregation(getInput(0), _direction);
	return !needsAggregation;
}

@Override
protected ExecType optFindExecType(boolean transitive) {

Expand All @@ -351,17 +373,14 @@ protected ExecType optFindExecType(boolean transitive) {
}
else
{
if ( OptimizerUtils.isMemoryBasedOptLevel() )
{
if ( OptimizerUtils.isMemoryBasedOptLevel()) {
_etype = findExecTypeByMemEstimate();
}
// Choose CP, if the input dimensions are below threshold or if the input is a vector
else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() )
{
else if(getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector()) {
_etype = ExecType.CP;
}
else
{
else {
_etype = REMOTE;
}

Expand All @@ -372,14 +391,12 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto
//spark-specific decision refinement (execute unary aggregate w/ spark input and
//single parent also in spark because it's likely cheap and reduces data transfer)
//we also allow multiple parents, if all other parents are already in Spark mode
if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP
&& ((!(getInput(0) instanceof DataOp) //input is not checkpoint
&& getInput(0).optFindExecType() == ExecType.SPARK)
|| (getInput(0) instanceof DataOp && ((DataOp)getInput(0)).hasOnlyRDD()))
&& (getInput(0).getParent().size()==1 //uagg is only parent, or
|| getInput(0).getParent().stream().filter(h -> h != this)
.allMatch(h -> h.optFindExecType(false) == ExecType.SPARK)
|| !requiresAggregation(getInput(0), _direction)) ) //w/o agg

boolean shouldEvaluateIfSpark = transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP;

if( shouldEvaluateIfSpark
&& (inputAlreadySpark() || inputOnlyRDD())
&& (onlyOneParent() || allParentsSpark() || inputDoesNotRequireAggregation() ))
{
//pull unary aggregate into spark
_etype = ExecType.SPARK;
Expand Down
93 changes: 76 additions & 17 deletions src/main/java/org/apache/sysds/hops/BinaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.sysds.hops;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
Expand All @@ -27,6 +29,7 @@
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
Expand All @@ -52,21 +55,21 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;


/* Binary (cell operations): aij + bij
/** Binary (cell operations): aij + bij
* Properties:
* Symbol: *, -, +, ...
* 2 Operands
* Semantic: align indices (sort), then perform operation
*/

public class BinaryOp extends MultiThreadedHop {
// private static final Log LOG = LogFactory.getLog(BinaryOp.class.getName());
protected static final Log LOG = LogFactory.getLog(BinaryOp.class.getName());

//we use the full remote memory budget (but reduced by sort buffer),
public static final double APPEND_MEM_MULTIPLIER = 1.0;

private OpOp2 op;
private boolean outer = false;
private boolean inplace = false;

public static AppendMethod FORCED_APPEND_METHOD = null;
public static MMBinaryMethod FORCED_BINARY_METHOD = null;
Expand Down Expand Up @@ -126,6 +129,10 @@ public void setOuterVectorOperation(boolean flag) {
public boolean isOuter(){
return outer;
}

/**
 * Indicates whether this binary operation has been flagged for in-place execution.
 *
 * @return the current value of the inplace flag
 */
public boolean isInplace(){
	return this.inplace;
}

@Override
public boolean isGPUEnabled() {
Expand Down Expand Up @@ -435,7 +442,7 @@ ot, getDataType(), getValueType(), et,
else { //general case
tmp = new Binary(getInput(0).constructLops(), getInput(1).constructLops(),
op, getDataType(), getValueType(), et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), inplace);
}

setOutputDimensions(tmp);
Expand Down Expand Up @@ -477,7 +484,7 @@ && getInput().get(0).dimsKnown() && getInput().get(1).dimsKnown()) {
else
binary = new Binary(getInput(0).constructLops(), getInput(1).constructLops(),
op, getDataType(), getValueType(), et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), inplace);

setOutputDimensions(binary);
setLineNumbers(binary);
Expand Down Expand Up @@ -700,6 +707,44 @@ public boolean allowsAllExecTypes()
return true;
}

/**
 * Checks whether the given hop is a parameterized builtin REPLACE operation.
 *
 * @param h the hop to inspect (a null hop yields false)
 * @return true if h is a ParameterizedBuiltinOp with operation REPLACE
 */
private static boolean isReplace(Hop h) {
	if(!(h instanceof ParameterizedBuiltinOp))
		return false;
	final ParameterizedBuiltinOp builtin = (ParameterizedBuiltinOp) h;
	return builtin.getOp() == ParamBuiltinOp.REPLACE;
}

/**
 * Checks whether a replace operation uses literal "pattern" and "replacement"
 * parameters equal to the expected values.
 *
 * @param h       the replace operation to inspect
 * @param pattern the expected pattern value (NaN is matched against NaN)
 * @param replace the expected replacement value
 * @return true only if both parameters are literals and match the expected values
 */
private static boolean isReplaceWithPattern(ParameterizedBuiltinOp h, double pattern, double replace) {
	final Hop patternHop = h.getParameterHop("pattern");
	final Hop replacementHop = h.getParameterHop("replacement");

	// Only compile-time literals can be verified here.
	if(!(patternHop instanceof LiteralOp) || !(replacementHop instanceof LiteralOp))
		return false;

	final double actualPattern = ((LiteralOp) patternHop).getDoubleValue();
	final double actualReplace = ((LiteralOp) replacementHop).getDoubleValue();

	final boolean bothNaN = Double.isNaN(pattern) && Double.isNaN(actualPattern);
	final boolean samePattern = bothNaN || Double.compare(pattern, actualPattern) == 0;
	final boolean sameReplace = Double.compare(replace, actualReplace) == 0;
	return samePattern && sameReplace;
}

/**
 * Checks whether the given hop is the output of two directly nested replace
 * operations that together map both NaN and 0 to 1 (in either nesting order).
 *
 * NOTE: despite the method name, the patterns verified here are NaN and 0
 * (not infinity); these are the values checked before a division is marked
 * as in-place by the caller.
 *
 * @param p1 the hop to inspect (typically the right-hand side of the division)
 * @return true if p1 is replace(replace(X, a, 1), b, 1) with {a, b} = {NaN, 0}
 */
private static boolean doesNotContainNanAndInf(Hop p1) {
	if(!isReplace(p1))
		return false;
	final ParameterizedBuiltinOp pp1 = (ParameterizedBuiltinOp) p1;

	// Resolve the replaced data by parameter name rather than by input position:
	// the inputs of a parameterized builtin are looked up by name elsewhere in this
	// class (see "pattern"/"replacement"), so position 0 is not guaranteed to be
	// the target. NOTE(review): assumes the data parameter of replace is named
	// "target" - confirm against ParameterizedBuiltinOp.getTargetHop().
	final Hop p2 = pp1.getParameterHop("target");
	if(!isReplace(p2)) // also covers a missing (null) target
		return false;
	final ParameterizedBuiltinOp pp2 = (ParameterizedBuiltinOp) p2;

	final boolean outerNanInnerZero = isReplaceWithPattern(pp1, Double.NaN, 1) && isReplaceWithPattern(pp2, 0, 1);
	return outerNanInnerZero //
		|| (isReplaceWithPattern(pp2, Double.NaN, 1) && isReplaceWithPattern(pp1, 0, 1));
}

/**
 * Checks whether the summed memory estimates of both inputs are strictly
 * below the local memory budget.
 *
 * @return true if estimate(input 0) + estimate(input 1) is less than the local budget
 */
private boolean memOfInputIsLessThanBudget() {
	final double combinedInputMemory = getInput().get(0).getMemEstimate() + getInput().get(1).getMemEstimate();
	return combinedInputMemory < OptimizerUtils.getLocalMemBudget();
}

@Override
protected ExecType optFindExecType(boolean transitive) {

Expand Down Expand Up @@ -755,20 +800,34 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) {
checkAndSetInvalidCPDimsAndSize();
}

//spark-specific decision refinement (execute unary scalar w/ spark input and
//single parent also in spark because it's likely cheap and reduces intermediates)
if( transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED
&& getDataType().isMatrix() && (dt1.isScalar() || dt2.isScalar())
&& supportsMatrixScalarOperations() //scalar operations
&& !(getInput().get(dt1.isScalar()?1:0) instanceof DataOp) //input is not checkpoint
&& getInput().get(dt1.isScalar()?1:0).getParent().size()==1 //unary scalar is only parent
&& !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar()?1:0)) //single block triggered exec
&& getInput().get(dt1.isScalar()?1:0).optFindExecType() == ExecType.SPARK )
{
//pull unary scalar operation into spark
//spark-specific decision refinement (execute unary scalar w/ spark input and
// single parent also in spark because it's likely cheap and reduces intermediates)
if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED &&
getDataType().isMatrix() // output should be a matrix
&& (dt1.isScalar() || dt2.isScalar()) // one side should be scalar
&& supportsMatrixScalarOperations() // scalar operations
&& !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint
&& getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent
&& !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec
&& getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) {
// pull unary scalar operation into spark
_etype = ExecType.SPARK;
}


if( transitive && _etypeForced != ExecType.SPARK && _etypeForced != ExecType.FED && //
getDataType().isMatrix() // Output is a matrix
&& op == OpOp2.DIV // Operation is division
&& dt1.isMatrix() // Left hand side is a Matrix
// right hand side is a scalar or a vector.
&& (dt2.isScalar() || (dt2.isMatrix() & getInput().get(1).isVector())) //
&& memOfInputIsLessThanBudget() //
&& getInput().get(0).getExecType() != ExecType.SPARK // Is not already a spark operation
&& doesNotContainNanAndInf(getInput().get(1)) // Guaranteed not to densify the operation
) {
inplace = true;
_etype = ExecType.CP;
}

//ensure cp exec type for single-node operations
if ( op == OpOp2.SOLVE ) {
if (isGPUEnabled())
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,18 @@ && getTargetHop().areDimsBelowThreshold() ) {
_etype = ExecType.CP;
}

// If previous instructions were in spark force aggregating
// parameterized operations to be executed in spark
if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP) {
switch(_op) {
case CONTAINS:
if(getTargetHop().optFindExecType() == ExecType.SPARK)
_etype = ExecType.SPARK;
default:
// Do not change execution type.
}
}

//mark for recompile (forever)
setRequiresRecompileIfNecessary();

Expand Down
14 changes: 11 additions & 3 deletions src/main/java/org/apache/sysds/lops/Binary.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@


/**
* Lop to perform binary operation. Both inputs must be matrices or vectors.
* Example - A = B + C, where B and C are matrices or vectors.
* Lop to perform binary operation. Both inputs must be matrices, vectors or scalars.
* Example - A = B + C.
*/

public class Binary extends Lop
{
private OpOp2 operation;
private final int _numThreads;
private final boolean inplace;

/**
* Constructor to perform a binary operation.
Expand All @@ -55,9 +55,14 @@ public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecT
}

/**
 * Constructor to perform a binary operation with a given degree of parallelism
 * and without in-place execution.
 *
 * @param input1 left input lop
 * @param input2 right input lop
 * @param op     binary operation type
 * @param dt     output data type
 * @param vt     output value type
 * @param et     execution type
 * @param k      number of threads
 */
public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecType et, int k) {
	this(input1, input2, op, dt, vt, et, k, false);
}

/**
 * Constructor to perform a binary operation with a given degree of parallelism
 * and an explicit in-place flag.
 *
 * @param input1  left input lop
 * @param input2  right input lop
 * @param op      binary operation type
 * @param dt      output data type
 * @param vt      output value type
 * @param et      execution type
 * @param k       number of threads
 * @param inplace whether the operation is flagged as in-place; the generated
 *                instruction only carries the flag for the CP execution type
 */
public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecType et, int k, boolean inplace) {
	super(Lop.Type.Binary, dt, vt);
	init(input1, input2, op, dt, vt, et);
	this.inplace = inplace;
	this._numThreads = k;
}

private void init(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType vt, ExecType et) {
Expand Down Expand Up @@ -107,6 +112,9 @@ public String getInstructions(String input1, String input2, String output) {
else if( getExecType() == ExecType.FED )
ret = InstructionUtils.concatOperands(ret, String.valueOf(_numThreads), _fedOutput.name());

if (getExecType() == ExecType.CP && inplace)
ret = InstructionUtils.concatOperands(ret, "InPlace");

return ret;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public enum IType {
FEDERATED
}

private static final Log LOG = LogFactory.getLog(Instruction.class.getName());
protected static final Log LOG = LogFactory.getLog(Instruction.class.getName());
protected final Operator _optr;

protected Instruction(Operator _optr){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MA

private static String[] parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
InstructionUtils.checkNumFields ( parts, 3, 4, 5 );
InstructionUtils.checkNumFields ( parts, 3, 4, 5, 6 );
in1.split(parts[1]);
in2.split(parts[2]);
out.split(parts[3]);
Expand Down
Loading

0 comments on commit 1329f3d

Please sign in to comment.