Skip to content

Commit

Permalink
[SYSTEMDS-2996] countDistinctApprox Builtin function
Browse files Browse the repository at this point in the history
This commit adds the countDistinctApprox instruction to allow
faster approximate counting of distinct elements in a matrix.
It also adds Spark support for the new instruction.

Closes #1531
Closes #1554

(Just to make sure GitHub sees that you are the author)
Co-authored-by: Badrul Chowdhury <badrul_chowdhury@apple.com>
  • Loading branch information
BACtaki authored and Baunsgaard committed Mar 2, 2022
1 parent 463bc97 commit 5590135
Show file tree
Hide file tree
Showing 34 changed files with 1,932 additions and 373 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Builtins.java
Expand Up @@ -93,7 +93,6 @@ public enum Builtins {
COS("cos", false),
COSH("cosh", false),
COUNT_DISTINCT("countDistinct",false),
COUNT_DISTINCT_APPROX("countDistinctApprox",false),
COV("cov", false),
COX("cox", true),
CSPLINE("cspline", true),
Expand Down Expand Up @@ -306,6 +305,7 @@ public enum Builtins {
//parameterized builtin functions
AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
CVLM("cvlm", true, false),
GROUPEDAGG("aggregate", "groupedAggregate", false, true),
INVCDF("icdf", false, true),
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/apache/sysds/common/Types.java
Expand Up @@ -153,6 +153,9 @@ public boolean isRow() {
/**
 * Checks whether this constant is {@code Col}.
 *
 * @return {@code true} if this instance is {@code Col}, {@code false} otherwise
 */
public boolean isCol() {
return this == Col;
}
/**
 * Checks whether this constant is {@code RowCol}.
 *
 * @return {@code true} if this instance is {@code RowCol}, {@code false} otherwise
 */
public boolean isRowCol() {
return this == RowCol;
}
@Override
public String toString() {
switch(this) {
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/conf/DMLConfig.java
Expand Up @@ -25,6 +25,7 @@
import java.io.StringWriter;
import java.util.HashMap;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
Expand Down Expand Up @@ -245,6 +246,7 @@ private void parseConfig () throws ParserConfigurationException, SAXException, I
private DocumentBuilder getDocumentBuilder() throws ParserConfigurationException {
if (_documentBuilder == null) {
DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true); // Prevent XML Injection
factory.setIgnoringComments(true); //ignore XML comments
_documentBuilder = factory.newDocumentBuilder();
}
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/org/apache/sysds/lops/PartialAggregate.java
Expand Up @@ -217,7 +217,7 @@ private String getOpcode() {
}

/**
* Instruction generation for for CP and Spark
* Instruction generation for CP and Spark
*/
@Override
public String getInstructions(String input1, String output)
Expand Down Expand Up @@ -348,8 +348,11 @@ else if( dir == Direction.Col )
}

case COUNT_DISTINCT_APPROX: {
if(dir == Direction.RowCol )
return "uacdap";
switch (dir) {
case RowCol: return "uacdap";
case Row: return "uacdapr";
case Col: return "uacdapc";
}
break;
}
}
Expand Down
Expand Up @@ -623,10 +623,10 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
case MEAN:
//checkNumParameters(2, false); // mean(Y) or mean(Y,W)
if (getSecondExpr() != null) {
checkNumParameters (2);
checkNumParameters(2);
}
else {
checkNumParameters (1);
checkNumParameters(1);
}

checkMatrixParam(getFirstExpr());
Expand Down Expand Up @@ -933,15 +933,13 @@ else if( getOpCode() == Builtins.RBIND ) {
output.setValueType(ValueType.INT64);
break;
case COUNT_DISTINCT:
case COUNT_DISTINCT_APPROX:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(), DataType.MATRIX);
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
break;

case LINEAGE:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
Expand All @@ -951,14 +949,12 @@ else if( getOpCode() == Builtins.RBIND ) {
output.setBlocksize(0);
output.setValueType(ValueType.STRING);
break;

case LIST:
output.setDataType(DataType.LIST);
output.setValueType(ValueType.UNKNOWN);
output.setDimensions(getAllExpr().length, 1);
output.setBlocksize(-1);
break;

case EXISTS:
checkNumParameters(1);
checkStringOrDataIdentifier(getFirstExpr());
Expand Down Expand Up @@ -1825,9 +1821,9 @@ public VariableSet variablesUpdated() {
protected void checkNumParameters(int count) { //always unconditional
if (getFirstExpr() == null && _args.length > 0) {
raiseValidateError("Missing argument for function " + this.getOpCode(), false,
LanguageErrorCodes.INVALID_PARAMETERS);
LanguageErrorCodes.INVALID_PARAMETERS);
}

// Not sure the rationale for the first two if loops, but will keep them for backward compatibility
if (((count == 1) && (getSecondExpr() != null || getThirdExpr() != null))
|| ((count == 2) && (getThirdExpr() != null))) {
Expand All @@ -1843,7 +1839,7 @@ protected void checkNumParameters(int count) { //always unconditional
} else if (count == 0 && (_args.length > 0
|| getSecondExpr() != null || getThirdExpr() != null)) {
raiseValidateError("Missing argument for function " + this.getOpCode()
+ "(). This function doesn't take any arguments.", false);
+ "(). This function doesn't take any arguments.", false);
}
}

Expand All @@ -1870,7 +1866,7 @@ protected void checkDataTypeParam(Expression e, DataType... dt) { //always uncon
if( !ArrayUtils.contains(dt, e.getOutput().getDataType()) )
raiseValidateError("Non-matching expected data type for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}

protected void checkMatrixFrameParam(Expression e) { //always unconditional
if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.FRAME) {
raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
Expand Down
60 changes: 40 additions & 20 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Expand Up @@ -30,6 +30,22 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
Expand Down Expand Up @@ -62,22 +78,6 @@
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOp3;
import org.apache.sysds.common.Types.OpOpDG;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.OpOpDnn;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.common.Types.ParamBuiltinOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.PrintStatement.PRINTTYPE;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
Expand All @@ -91,7 +91,6 @@
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;


public class DMLTranslator
{
private static final Log LOG = LogFactory.getLog(DMLTranslator.class.getName());
Expand Down Expand Up @@ -2035,6 +2034,29 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
target.getValueType(), ParamBuiltinOp.LIST, paramHops);
break;

case COUNT_DISTINCT_APPROX:
// Default direction and data type
Direction dir = Direction.RowCol;
DataType dataType = DataType.SCALAR;

LiteralOp dirOp = (LiteralOp) paramHops.get("dir");
if (dirOp != null) {
String dirString = dirOp.getStringValue().toUpperCase();
if (dirString.equals(Direction.RowCol.toString())) {
dir = Direction.RowCol;
dataType = DataType.SCALAR;
} else if (dirString.equals(Direction.Row.toString())) {
dir = Direction.Row;
dataType = DataType.MATRIX;
} else if (dirString.equals(Direction.Col.toString())) {
dir = Direction.Col;
dataType = DataType.MATRIX;
}
}

currBuiltinOp = new AggUnaryOp(target.getName(), dataType, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
break;
default:
throw new ParseException(source.printErrorLocation() +
"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode());
Expand Down Expand Up @@ -2335,11 +2357,9 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D
case PROD:
case VAR:
case COUNT_DISTINCT:
case COUNT_DISTINCT_APPROX:
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
break;

case MEAN:
if ( expr2 == null ) {
// example: x = mean(Y);
Expand Down

0 comments on commit 5590135

Please sign in to comment.