Skip to content

Commit

Permalink
[SYSTEMDS-3500] Fix perftest regression via new contains-value function
Browse files Browse the repository at this point in the history
A while ago the MLogReg script was extended with robustness checks for
NaN inputs. In the perftest MLogReg 1M_1K_dense (8GB), this led to an
unnecessary performance regression with a 20GB driver because
input and output (16GB) exceed the 70% memory budget. Given that
sum(isNaN(X)) > 0 is likely false, we now expose an already existing block
operation contains(X, pattern) that has only half the memory requirements.
We added the CP, SPARK, and FED instructions as well as related tests.
  • Loading branch information
mboehm7 committed Feb 23, 2023
1 parent 096ca06 commit 9446bf6
Show file tree
Hide file tree
Showing 17 changed files with 294 additions and 60 deletions.
6 changes: 3 additions & 3 deletions scripts/builtin/multiLogReg.dml
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] Y, Int icpt = 2,
D = ncol (X);

# Robustness for datasets with missing values (causing NaN gradients)
numNaNs = sum(isNaN(X))
if( numNaNs > 0 ) {
hasNaNs = contains(target=X, pattern=NaN);
if( hasNaNs > 0 ) {
if(verbose)
print("multiLogReg: matrix X contains "+numNaNs+" missing values, replacing with 0.")
print("multiLogReg: matrix X contains "+sum(isNaN(X))+" missing values, replacing with 0.")
X = replace(target=X, pattern=NaN, replacement=0);
}

Expand Down
6 changes: 3 additions & 3 deletions scripts/builtin/multiLogRegPredict.dml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ m_multiLogRegPredict = function(Matrix[Double] X, Matrix[Double] B, Matrix[Doubl
stop("multiLogRegPredict: mismatching ncol(X) and nrow(B): "+ncol(X)+" "+nrow(B));

# Robustness for datasets with missing values (causing NaN probabilities)
numNaNs = sum(isNaN(X))
if( numNaNs > 0 ) {
print("multiLogRegPredict: matrix X contains "+numNaNs+" missing values, replacing with 0.")
hasNaNs = contains(target=X, pattern=NaN);
if( hasNaNs > 0 ) {
print("multiLogRegPredict: matrix X contains "+sum(isNaN(X))+" missing values, replacing with 0.")
X = replace(target=X, pattern=NaN, replacement=0);
}
accuracy = 0.0 # initialize variable
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ public enum Builtins {
// Parameterized functions with parameters
AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
CONTAINS("contains", false, true),
COUNT_DISTINCT("countDistinct",false, true),
COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
COUNT_DISTINCT_APPROX_ROW("rowCountDistinctApprox", false, true),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) {
}

public enum ParamBuiltinOp {
AUTODIFF, INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
LOWER_TRI, UPPER_TRI,
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
TOKENIZE, TOSTRING, LIST, PARAMSERV
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ public Lop constructLops()
case REXPAND: {
constructLopsRExpand(inputlops, et);
break;
}
}
case CONTAINS:
case CDF:
case INVCDF:
case REPLACE:
Expand Down
21 changes: 5 additions & 16 deletions src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,6 @@ public String getInstructions(String output)

break;

case REPLACE: {
sb.append( "replace" );
sb.append( OPERAND_DELIMITOR );
sb.append(compileGenericParamMap(_inputParams));
break;
}

case LOWER_TRI: {
sb.append( "lowertri" );
sb.append( OPERAND_DELIMITOR );
Expand Down Expand Up @@ -174,11 +167,14 @@ public String getInstructions(String output)

break;

case CONTAINS:
case REPLACE:
case TOKENIZE:
case TRANSFORMAPPLY:
case TRANSFORMDECODE:
case TRANSFORMCOLMAP:
case TRANSFORMMETA:{
case TRANSFORMMETA:
case PARAMSERV: {
sb.append(_operation.name().toLowerCase()); //opcode
sb.append(OPERAND_DELIMITOR);
sb.append(compileGenericParamMap(_inputParams));
Expand All @@ -202,14 +198,7 @@ public String getInstructions(String output)
sb.append(compileGenericParamMap(_inputParams));
break;
}

case PARAMSERV: {
sb.append("paramserv");
sb.append(OPERAND_DELIMITOR);
sb.append(compileGenericParamMap(_inputParams));
break;
}


default:
throw new LopsException(this.printErrorLocation() + "In ParameterizedBuiltin Lop, Unknown operation: " + _operation);
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
target.getValueType(), source.getOpCode(), paramHops);
break;

case CONTAINS:
case GROUPEDAGG:
case RMEMPTY:
case REPLACE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
validateReplace(output, conditional);
break;

case CONTAINS:
validateContains(output, conditional);
break;

case ORDER:
validateOrder(output, conditional);
break;
Expand Down Expand Up @@ -725,28 +729,24 @@ private void validateExtractTriangular(DataIdentifier output, Builtins op, bool
output.setDimensions(in.getDim1(), in.getDim2());
}

/**
 * Validates a call to the parameterized builtin {@code contains}.
 * The named parameter 'target' is checked via checkTargetParam and the
 * named parameter 'pattern' must be a scalar; the output is then marked
 * with boolean properties.
 *
 * @param output      data identifier describing the result of the call
 * @param conditional flag forwarded to the parameter checks
 */
private void validateContains(DataIdentifier output, boolean conditional) {
	// validate the target first, then the scalar search pattern
	checkTargetParam(getVarParam("target"), conditional);
	checkScalarParam("contains", "pattern", conditional);

	// the result of contains is a boolean scalar
	output.setBooleanProperties();
}

private void validateReplace(DataIdentifier output, boolean conditional) {
//check existence and correctness of arguments
Expression target = getVarParam("target");
if( target.getOutput().getDataType() != DataType.FRAME ){
checkTargetParam(target, conditional);
}

Expression pattern = getVarParam("pattern");
if( pattern==null ) {
raiseValidateError("Named parameter 'pattern' missing. Please specify the replacement pattern.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
else if( pattern.getOutput().getDataType() != DataType.SCALAR ){
raiseValidateError("Replacement pattern 'pattern' is of type '"+pattern.getOutput().getDataType()+"'. Please, specify a scalar replacement pattern.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}

Expression replacement = getVarParam("replacement");
if( replacement==null ) {
raiseValidateError("Named parameter 'replacement' missing. Please specify the replacement value.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
else if( replacement.getOutput().getDataType() != DataType.SCALAR ){
raiseValidateError("Replacement value 'replacement' is of type '"+replacement.getOutput().getDataType()+"'. Please, specify a scalar replacement value.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
checkScalarParam("replace", "pattern", conditional);
checkScalarParam("replace", "replacement", conditional);

// Output is a matrix with same dims as input
output.setDataType(target.getOutput().getDataType());
Expand All @@ -756,6 +756,19 @@ else if( replacement.getOutput().getDataType() != DataType.SCALAR ){
output.setValueType(ValueType.FP64);
output.setDimensions(target.getOutput().getDim1(), target.getOutput().getDim2());
}

/**
 * Checks that a named parameter exists and is of data type scalar,
 * raising a validation error with INVALID_PARAMETERS otherwise.
 *
 * @param group       name of the builtin function the parameter belongs to
 *                    (e.g., "replace" or "contains"), used in error messages
 * @param param       name of the parameter to check
 * @param conditional flag forwarded to raiseValidateError
 */
private void checkScalarParam(String group, String param, boolean conditional) {
	Expression eparam = getVarParam(param);
	if( eparam==null ) {
		// name the actual missing parameter instead of always saying "pattern",
		// which was misleading for other parameters such as 'replacement'
		raiseValidateError("Named parameter '"+param+"' missing. Please specify the "+param+" for "+group+".",
			conditional, LanguageErrorCodes.INVALID_PARAMETERS);
	}
	else if( eparam.getOutput().getDataType() != DataType.SCALAR ){
		raiseValidateError(group + " parameter '"+param+"' is of type '"
			+ eparam.getOutput().getDataType()+"'. Please, specify a scalar "+param+".",
			conditional, LanguageErrorCodes.INVALID_PARAMETERS);
	}
}

private void validateOrder(DataIdentifier output, boolean conditional) {
//check existence and correctness of arguments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,19 @@ else if( aop.aggOp.increOp.fn instanceof Mean ) {
throw new DMLRuntimeException(ex);
}
}


/**
 * Aggregates the boolean scalar results of multiple federated responses
 * via a logical OR. The first data item of every response is read as a
 * scalar object; all futures are resolved, even after a true value was seen.
 *
 * @param tmp futures of the federated responses to aggregate
 * @return true if at least one response carries a true boolean value
 * @throws DMLRuntimeException wrapping any exception raised while
 *         obtaining or reading a response
 */
public static boolean aggBooleanScalar(Future<FederatedResponse>[] tmp) {
	try {
		boolean anyTrue = false;
		for( int i = 0; i < tmp.length; i++ ) {
			ScalarObject so = (ScalarObject) tmp[i].get().getData()[0];
			if( so.getBooleanValue() )
				anyTrue = true;
		}
		return anyTrue;
	}
	catch (Exception e) {
		throw new DMLRuntimeException(e);
	}
}

public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
if (aop.isRowAggregate() && map.getType() == FType.ROW)
return bind(ffr, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ public class CPInstructionParser extends InstructionParser {

// Parameterized Builtin Functions
String2CPInstructionType.put( "autoDiff" , CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "contains", CPType.ParameterizedBuiltin);
String2CPInstructionType.put("paramserv", CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "nvlist", CPType.ParameterizedBuiltin);
String2CPInstructionType.put( "cdf", CPType.ParameterizedBuiltin);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "isinf", SPType.Unary);

// Parameterized Builtin Functions
String2SPInstructionType.put( "autoDiff" , SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "autoDiff", SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "contains", SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "groupedagg", SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "mapgroupedagg", SPType.ParameterizedBuiltin);
String2SPInstructionType.put( "rmempty", SPType.ParameterizedBuiltin);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,14 @@ else if(opcode.equalsIgnoreCase("groupedagg")) {
}
else if(opcode.equalsIgnoreCase("rmempty") || opcode.equalsIgnoreCase("replace") ||
opcode.equalsIgnoreCase("rexpand") || opcode.equalsIgnoreCase("lowertri") ||
opcode.equalsIgnoreCase("uppertri")) {
opcode.equalsIgnoreCase("uppertri") ) {
func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinCPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
}
else if(opcode.equals("transformapply") || opcode.equals("transformdecode") ||
opcode.equals("transformcolmap") || opcode.equals("transformmeta") || opcode.equals("tokenize") ||
opcode.equals("toString") || opcode.equals("nvlist") || opcode.equals("autoDiff")) {
else if(opcode.equals("transformapply") || opcode.equals("transformdecode")
|| opcode.equalsIgnoreCase("contains") || opcode.equals("transformcolmap")
|| opcode.equals("transformmeta") || opcode.equals("tokenize")
|| opcode.equals("toString") || opcode.equals("nvlist") || opcode.equals("autoDiff")) {
return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
}
else if("paramserv".equals(opcode)) {
Expand Down Expand Up @@ -235,6 +236,14 @@ else if(opcode.equalsIgnoreCase("rmempty")) {
ec.releaseMatrixInput(params.get("select"));
}
}
else if(opcode.equalsIgnoreCase("contains")) {
String varName = params.get("target");
MatrixBlock target = ec.getMatrixInput(varName);
double pattern = Double.parseDouble(params.get("pattern"));
boolean ret = target.containsValue(pattern);
ec.releaseMatrixInput(varName);
ec.setScalarOutput(output.getName(), new BooleanObject(ret));
}
else if(opcode.equalsIgnoreCase("replace")) {
if(ec.isFrameObject(params.get("target"))){
FrameBlock target = ec.getFrameInput(params.get("target"));
Expand All @@ -255,7 +264,6 @@ else if(opcode.equalsIgnoreCase("replace")) {
ec.setMatrixOutput(output.getName(), ret);
targetObj.release();
}

}
else if(opcode.equals("lowertri") || opcode.equals("uppertri")) {
MatrixBlock target = ec.getMatrixInput(params.get("target"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
Expand All @@ -85,8 +87,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
protected final HashMap<String, String> params;

private static final String[] PARAM_BUILTINS = new String[]{
"replace", "rmempty", "lowertri", "uppertri", "transformdecode", "transformapply", "tokenize"};

"contains", "replace", "rmempty", "lowertri", "uppertri",
"transformdecode", "transformapply", "tokenize"};

protected ParameterizedBuiltinFEDInstruction(Operator op, HashMap<String, String> paramsMap, CPOperand out,
String opcode, String istr) {
Expand All @@ -110,7 +112,8 @@ public static ParameterizedBuiltinFEDInstruction parseInstruction(String str) {
ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
}
else if(opcode.equals("transformapply") || opcode.equals("transformdecode") || opcode.equals("tokenize")) {
else if(opcode.equals("transformapply") || opcode.equals("transformdecode")
|| opcode.equals("tokenize") || opcode.equals("contains") ) {
return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str);
}
else {
Expand Down Expand Up @@ -140,15 +143,17 @@ public static LinkedHashMap<String, String> constructParameterMap(String[] param
return paramMap;
}

public static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinCPInstruction inst,
ExecutionContext ec) {
/**
 * Creates a federated parameterized builtin instruction from a CP
 * instruction, if applicable.
 *
 * @param inst CP instruction to convert
 * @param ec   execution context used to look up the instruction's target
 * @return the federated instruction, or null if the opcode is not listed in
 *         PARAM_BUILTINS or the target is not federated (excluding the
 *         BROADCAST type)
 */
public static ParameterizedBuiltinFEDInstruction parseInstruction(
	ParameterizedBuiltinCPInstruction inst, ExecutionContext ec)
{
	// only opcodes listed in PARAM_BUILTINS qualify
	if( !ArrayUtils.contains(PARAM_BUILTINS, inst.getOpcode()) )
		return null;
	// target is only looked up once the opcode check has passed
	if( !inst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
		return null;
	return ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
}

public static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinSPInstruction inst,
ExecutionContext ec) {
public static ParameterizedBuiltinFEDInstruction parseInstruction(
ParameterizedBuiltinSPInstruction inst, ExecutionContext ec)
{
if( inst.getOpcode().equalsIgnoreCase("replace") && inst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
return ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
return null;
Expand All @@ -167,13 +172,21 @@ private static ParameterizedBuiltinFEDInstruction parseInstruction(Parameterized
@Override
public void processInstruction(ExecutionContext ec) {
String opcode = getOpcode();
if(opcode.equalsIgnoreCase("replace")) {
if(opcode.equalsIgnoreCase("contains")) {
FederationMap map = getTarget(ec).getFedMapping();
FederatedRequest fr1 = FederationUtils.callInstruction(instString,
output, new CPOperand[] {getTargetOperand()}, new long[] {map.getID()});
FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
boolean ret = FederationUtils.aggBooleanScalar(tmp);
ec.setVariable(output.getName(), new BooleanObject(ret));
}
else if(opcode.equalsIgnoreCase("replace")) {
// similar to unary federated instructions, get federated input
// execute instruction, and derive federated output matrix
CacheableData<?> mo = getTarget(ec);
FederatedRequest fr1 = FederationUtils.callInstruction(instString,
output,
new CPOperand[] {getTargetOperand()},
FederatedRequest fr1 = FederationUtils.callInstruction(
instString, output, new CPOperand[] {getTargetOperand()},
new long[] {mo.getFedMapping().getID()});
Future<FederatedResponse>[] ret = mo.getFedMapping().execute(getTID(), true, fr1);

Expand Down
Loading

0 comments on commit 9446bf6

Please sign in to comment.