Skip to content

Commit

Permalink
[SYSTEMDS-3119] Multi-return eval function calls (evalList)
Browse files Browse the repository at this point in the history
This patch extends the existing eval() second-order function calls
(which return a single matrix) by evalList that bundles multiple
returns into a named list. This approach allows reusing all existing
primitives as they are, yet supports better state management in data
cleaning pipelines. In detail, we provide a new language-level
builtin function evalList, but both eval and evalList are parsed
to eval operations, simply with different output type, and at
runtime, we handle the functions accordingly.

Additional changes that showed up during the tests include:
* New rewrite for list indexes (avoid unnecessary instructions)
* Extended rewrite for DAG splits after data-dependent operators
  (include persistent writes into consideration to avoid Spark ops)
* Cleanup right indexing lop construction (old MR code)
* Fix for invalid dimensions checks for list indexing
* Fix selected tests for reduced # expected spark jobs
  • Loading branch information
mboehm7 committed Dec 29, 2021
1 parent a96a76d commit 9da2fb8
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 109 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Expand Up @@ -123,6 +123,7 @@ public enum Builtins {
EXECUTE_PIPELINE("executePipeline", true),
EXP("exp", false),
EVAL("eval", false),
EVALLIST("evalList", false),
FIX_INVALID_LENGTHS("fixInvalidLengths", true),
FF_TRAIN("ffTrain", true),
FF_PREDICT("ffPredict", true),
Expand Down
15 changes: 5 additions & 10 deletions src/main/java/org/apache/sysds/hops/IndexingOp.java
Expand Up @@ -26,7 +26,6 @@
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.RightIndex;
Expand Down Expand Up @@ -140,10 +139,8 @@ public Lop constructLops()
SparkAggType aggtype = (method==IndexingMethod.MR_VRIX || isBlockAligned()) ?
SparkAggType.NONE : SparkAggType.MULTI_BLOCK;

Lop dummy = Data.createLiteralLop(ValueType.INT64, Integer.toString(-1));
RightIndex reindex = new RightIndex(
input.constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(),
getInput().get(3).constructLops(), getInput().get(4).constructLops(), dummy, dummy,
RightIndex reindex = new RightIndex(input.constructLops(), getInput(1).constructLops(),
getInput(2).constructLops(), getInput(3).constructLops(), getInput(4).constructLops(),
getDataType(), getValueType(), aggtype, et);

setOutputDimensions(reindex);
Expand All @@ -152,11 +149,9 @@ public Lop constructLops()
}
else //CP or GPU
{
Lop dummy = Data.createLiteralLop(ValueType.INT64, Integer.toString(-1));
RightIndex reindex = new RightIndex(
input.constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(),
getInput().get(3).constructLops(), getInput().get(4).constructLops(), dummy, dummy,
getDataType(), getValueType(), et);
RightIndex reindex = new RightIndex(input.constructLops(), getInput(1).constructLops(),
getInput(2).constructLops(), getInput(3).constructLops(), getInput(4).constructLops(),
getDataType(), getValueType(), et);

setOutputDimensions(reindex);
setLineNumbers(reindex);
Expand Down
Expand Up @@ -150,6 +150,11 @@ private static boolean checkAndReplaceEvalFunctionCall(DMLProgram prog, Statemen
+ "applicable for replacement, but list inputs not yet supported.");
continue;
}
if( eval.getDataType().isList() ) {
LOG.warn("IPA: eval("+fnamespace+"::"+fname+") "
+ "applicable for replacement, but list output not yet supported.");
continue;
}
if( fstmt.getOutputParams().size() != 1 || !fstmt.getOutputParams().get(0).getDataType().isMatrix() ) {
LOG.warn("IPA: eval("+fnamespace+"::"+fname+") "
+ "applicable for replacement, but function output is not a matrix.");
Expand Down
Expand Up @@ -165,6 +165,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X)
hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y));
hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12))
Expand Down Expand Up @@ -1390,12 +1391,23 @@ private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos)
mm.refreshSizeInformation();

hi = mm;
LOG.debug("Applied simplifySlicedMatrixMult");

LOG.debug("Applied simplifySlicedMatrixMult");
}

return hi;
}

/**
 * Rewrites right indexing on lists to use a literal column upper bound,
 * e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1], avoiding unnecessary instructions
 * for computing ncol(L) (presumably safe because list right-indexing selects
 * whole entries and thus a single logical column — confirm with IndexingOp
 * list semantics).
 *
 * @param hi the hop to inspect and potentially rewrite in place
 * @return the (possibly modified) input hop
 */
private static Hop simplifyListIndexing(Hop hi) {
	//guard: only list right-indexing with a non-literal column upper bound
	if( !(hi instanceof IndexingOp) || !hi.getDataType().isList()
		|| hi.getInput(4) instanceof LiteralOp )
		return hi;
	//pin the column upper bound (input 4) to the constant 1
	HopRewriteUtils.replaceChildReference(hi, hi.getInput(4), new LiteralOp(1));
	LOG.debug("Applied simplifyListIndexing (line "+hi.getBeginLine()+").");
	return hi;
}

private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
{
Expand Down
Expand Up @@ -230,7 +230,7 @@ private void rCollectDataDependentOperators( Hop hop, ArrayList<Hop> cand )
return;

//prevent unnecessary dag split (dims known or no consumer operations)
boolean noSplitRequired = (HopRewriteUtils.hasOnlyWriteParents(hop, true, true)
boolean noSplitRequired = (HopRewriteUtils.hasOnlyWriteParents(hop, true, false)
|| hop.dimsKnown() || DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE);
boolean investigateChilds = true;

Expand Down
58 changes: 14 additions & 44 deletions src/main/java/org/apache/sysds/lops/RightIndex.java
Expand Up @@ -36,46 +36,42 @@ public class RightIndex extends Lop
//optional attribute for spark exec type
private SparkAggType _aggtype = SparkAggType.MULTI_BLOCK;

public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU, Lop rowDim, Lop colDim,
DataType dt, ValueType vt, ExecType et, boolean forleft)
public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
DataType dt, ValueType vt, ExecType et, boolean forleft)
{
super(Lop.Type.RightIndex, dt, vt);
init(input, rowL, rowU, colL, colU, rowDim, colDim, dt, vt, et, forleft);
init(input, rowL, rowU, colL, colU, dt, vt, et, forleft);
}

public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU, Lop rowDim, Lop colDim,
DataType dt, ValueType vt, ExecType et)
public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
DataType dt, ValueType vt, ExecType et)
{
super(Lop.Type.RightIndex, dt, vt);
init(input, rowL, rowU, colL, colU, rowDim, colDim, dt, vt, et, false);
init(input, rowL, rowU, colL, colU, dt, vt, et, false);
}

public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU, Lop rowDim, Lop colDim,
DataType dt, ValueType vt, SparkAggType aggtype, ExecType et)
public RightIndex(Lop input, Lop rowL, Lop rowU, Lop colL, Lop colU,
DataType dt, ValueType vt, SparkAggType aggtype, ExecType et)
{
super(Lop.Type.RightIndex, dt, vt);
_aggtype = aggtype;
init(input, rowL, rowU, colL, colU, rowDim, colDim, dt, vt, et, false);
init(input, rowL, rowU, colL, colU, dt, vt, et, false);
}

private void init(Lop inputMatrix, Lop rowL, Lop rowU, Lop colL, Lop colU, Lop leftMatrixRowDim,
Lop leftMatrixColDim, DataType dt, ValueType vt, ExecType et, boolean forleft)
{
private void init(Lop inputMatrix, Lop rowL, Lop rowU, Lop colL, Lop colU,
DataType dt, ValueType vt, ExecType et, boolean forleft)
{
addInput(inputMatrix);
addInput(rowL);
addInput(rowU);
addInput(colL);
addInput(colU);
addInput(leftMatrixRowDim);
addInput(leftMatrixColDim);

inputMatrix.addOutput(this);
rowL.addOutput(this);
rowU.addOutput(this);
colL.addOutput(this);
colU.addOutput(this);
leftMatrixRowDim.addOutput(this);
leftMatrixColDim.addOutput(this);
lps.setProperties(inputs, et);
forLeftIndexing=forleft;
}
Expand All @@ -93,7 +89,7 @@ public SparkAggType getAggType() {
}

@Override
public String getInstructions(String input, String rowl, String rowu, String coll, String colu, String leftRowDim, String leftColDim, String output) {
public String getInstructions(String input, String rowl, String rowu, String coll, String colu, String output) {
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
sb.append( OPERAND_DELIMITOR );
Expand Down Expand Up @@ -124,40 +120,14 @@ public String getInstructions(String input, String rowl, String rowu, String col
//in case of spark, we also compile the optional aggregate flag into the instruction.
if( getExecType() == ExecType.SPARK ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _aggtype );
sb.append( _aggtype );
}

return sb.toString();
}

@Override
public String getInstructions(int input_index1, int input_index2, int input_index3, int input_index4, int input_index5, int input_index6, int input_index7, int output_index) {
/*
* Example: B = A[row_l:row_u, col_l:col_u]
* A - input matrix (input_index1)
* row_l - lower bound in row dimension
* row_u - upper bound in row dimension
* col_l - lower bound in column dimension
* col_u - upper bound in column dimension
*
* Since row_l,row_u,col_l,col_u are scalars, values for input_index(2,3,4,5,6,7)
* will be equal to -1. They should be ignored and the scalar value labels must
* be derived from input lops.
*/
String rowl = getInputs().get(1).prepScalarLabel();
String rowu = getInputs().get(2).prepScalarLabel();
String coll = getInputs().get(3).prepScalarLabel();
String colu = getInputs().get(4).prepScalarLabel();

String left_nrow = getInputs().get(5).prepScalarLabel();
String left_ncol = getInputs().get(6).prepScalarLabel();

return getInstructions(Integer.toString(input_index1), rowl, rowu, coll, colu, left_nrow, left_ncol, Integer.toString(output_index));
}

@Override
public String toString() {
return getOpcode();
}

}
Expand Up @@ -550,11 +550,13 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri

switch (getOpCode()) {
case EVAL:
case EVALLIST:
if (_args.length == 0)
raiseValidateError("Function eval should provide at least one argument, i.e., the function name.", false);
checkValueTypeParam(_args[0], ValueType.STRING);
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
boolean listReturn = (getOpCode()==Builtins.EVALLIST);
output.setDataType(listReturn ? DataType.LIST : DataType.MATRIX);
output.setValueType(listReturn ? ValueType.UNKNOWN : ValueType.FP64);
output.setDimensions(-1, -1);
output.setBlocksize(ConfigurationManager.getBlocksize());
break;
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Expand Up @@ -2244,6 +2244,7 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D
switch (source.getOpCode()) {

case EVAL:
case EVALLIST:
currBuiltinOp = new NaryOp(target.getName(), target.getDataType(), target.getValueType(),
OpOpN.EVAL, processAllExpressions(source.getAllExpr(), hops));
break;
Expand Down
24 changes: 13 additions & 11 deletions src/main/java/org/apache/sysds/parser/StatementBlock.java
Expand Up @@ -994,19 +994,21 @@ else if (!(target instanceof IndexedIdentifier)){

// validate that size of LHS index ranges is being assigned:
// (a) a matrix value of same size as LHS
// (b) singleton value (semantics: initialize enitre submatrix with this value)
// (b) singleton value (semantics: initialize entire submatrix with this value)
IndexPair targetSize = ((IndexedIdentifier)target).calculateIndexedDimensions(ids.getVariables(), currConstVars, conditional);

if (targetSize._row >= 1 && source.getOutput().getDim1() > 1 && targetSize._row != source.getOutput().getDim1()){
target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions "
+ targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions "
+ source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + " cols ", conditional);
}

if (targetSize._col >= 1 && source.getOutput().getDim2() > 1 && targetSize._col != source.getOutput().getDim2()){
target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions "
+ targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions "
+ source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + " cols ", conditional);
if( target.getDataType().isMatrixOrFrame() ) {
if (targetSize._row >= 1 && source.getOutput().getDim1() > 1 && targetSize._row != source.getOutput().getDim1()){
target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions "
+ targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions "
+ source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + " cols ", conditional);
}

if (targetSize._col >= 1 && source.getOutput().getDim2() > 1 && targetSize._col != source.getOutput().getDim2()){
target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions "
+ targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions "
+ source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + " cols ", conditional);
}
}
((IndexedIdentifier)target).setDimensions(targetSize._row, targetSize._col);
}
Expand Down
Expand Up @@ -72,6 +72,10 @@ public List<String> getInputParamNames() {
return _inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());
}

/**
 * Collects the names of all function output parameters, in declaration order.
 *
 * @return a new mutable list of output parameter names
 */
public List<String> getOutputParamNames() {
	List<String> names = new ArrayList<>(_outputParams.size());
	for( DataIdentifier d : _outputParams )
		names.add(d.getName());
	return names;
}

public ArrayList<DataIdentifier> getInputParams(){
return _inputParams;
}
Expand Down
Expand Up @@ -67,7 +67,14 @@ public EvalNaryCPInstruction(Operator op, String opcode, String istr, CPOperand

@Override
public void processInstruction(ExecutionContext ec) {
//1. get the namespace and func
// There are two main types of eval function calls, which share most of the
// code for lazy function loading and execution:
// a) a single-return eval fcall returns a matrix which is bound to the output
// (if the function returns multiple objects, the first one is used as output)
// b) a multi-return eval fcall gets all returns of the function call and
// creates a named list using the names of the function signature

//1. get the namespace and function names
String funcName = ec.getScalarInput(inputs[0]).getStringValue();
String nsName = null; //default namespace
if( funcName.contains(Program.KEY_DELIM) ) {
Expand All @@ -76,14 +83,13 @@ public void processInstruction(ExecutionContext ec) {
nsName = parts[0];
}

// bound the inputs to avoiding being deleted after the function call
// bind the inputs to avoid them being deleted after the function call
CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, inputs.length);
List<String> boundOutputNames = new ArrayList<>();
boundOutputNames.add(output.getName());


//2. copy the created output matrix
MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));

MatrixObject outputMO = !output.isMatrix() ? null :
new MatrixObject(ec.getMatrixObject(output.getName()));

//3. lazy loading of dml-bodied builtin functions (incl. rename
// of function name to dml-bodied builtin scheme (data-type-specific)
DataType dt1 = boundInputs[0].getDataType().isList() ?
Expand Down Expand Up @@ -138,34 +144,51 @@ public void processInstruction(ExecutionContext ec) {
boundInputs2[i] = new CPOperand(varName, in);
}
boundInputs = boundInputs2;
lineageInputs = DMLScript.LINEAGE
? lo.getLineageItems().toArray(new LineageItem[lo.getLength()]) : null;
lineageInputs = !DMLScript.LINEAGE ? null :
lo.getLineageItems().toArray(new LineageItem[lo.getLength()]);
}

// bind the outputs
List<String> boundOutputNames = new ArrayList<>();
if( output.getDataType().isMatrix() )
boundOutputNames.add(output.getName());
else //list
boundOutputNames.addAll(fpb.getOutputParamNames());

//5. call the function (to unoptimized function)
FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(nsName, funcName,
false, boundInputs, lineageInputs, fpb.getInputParamNames(), boundOutputNames, "eval func");
fcpi.processInstruction(ec);

//6. convert the result to matrix
Data newOutput = ec.getVariable(output);
if (!(newOutput instanceof MatrixObject)) {
MatrixBlock mb = null;
if (newOutput instanceof ScalarObject) {
//convert scalar to matrix
mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
} else if (newOutput instanceof FrameObject) {
//convert frame to matrix
mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
ec.cleanupCacheableData((FrameObject) newOutput);
//6a. convert the result to matrix
if( output.getDataType().isMatrix() ) {
Data newOutput = ec.getVariable(output);
if (!(newOutput instanceof MatrixObject)) {
MatrixBlock mb = null;
if (newOutput instanceof ScalarObject) {
//convert scalar to matrix
mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
} else if (newOutput instanceof FrameObject) {
//convert frame to matrix
mb = DataConverter.convertToMatrixBlock(((FrameObject) newOutput).acquireRead());
ec.cleanupCacheableData((FrameObject) newOutput);
}
else {
throw new DMLRuntimeException("Invalid eval return type: "+newOutput.getDataType().name()
+ " (valid: matrix/frame/scalar; where frames or scalars are converted to output matrices)");
}
outputMO.acquireModify(mb);
outputMO.release();
ec.setVariable(output.getName(), outputMO);
}
else {
throw new DMLRuntimeException("Invalid eval return type: "+newOutput.getDataType().name()
+ " (valid: matrix/frame/scalar; where frames or scalars are converted to output matrices)");
}
outputMO.acquireModify(mb);
outputMO.release();
ec.setVariable(output.getName(), outputMO);
}
//6b. wrap outputs in named list (evalList)
else {
Data[] ldata = boundOutputNames.stream()
.map(n -> ec.getVariable(n)).toArray(Data[]::new);
String[] lnames = boundOutputNames.toArray(new String[0]);
ListObject listOutput = new ListObject(ldata, lnames);
ec.setVariable(output.getName(), listOutput);
}

//7. cleanup of variable expanded from list
Expand Down
Expand Up @@ -736,7 +736,7 @@ private void processMoveInstruction(ExecutionContext ec) {

if ( srcData == null ) {
throw new DMLRuntimeException("Unexpected error: could not find a data object "
+ "for variable name:" + getInput1().getName() + ", while processing instruction ");
+ "for variable name: " + getInput1().getName() + ", while processing instruction ");
}

// remove existing variable bound to target name and
Expand Down

0 comments on commit 9da2fb8

Please sign in to comment.