Skip to content

Commit

Permalink
[SYSTEMDS-3413] Row/Col aggregation for countDistinct
Browse files Browse the repository at this point in the history
This patch converts countDistinct() from a non-parameterized builtin to
a parameterized builtin function to allow for one new parameter: dir, the
aggregation direction. The value of dir can be r or c, denoting row-wise and
column-wise aggregation respectively. This patch only implements the CP
case; the SP case will throw a NotImplementedException and will
be addressed in a subsequent patch.

Closes #1677
  • Loading branch information
BACtaki authored and Baunsgaard committed Oct 28, 2022
1 parent 11d0773 commit 6c9f0ff
Show file tree
Hide file tree
Showing 41 changed files with 1,496 additions and 171 deletions.
10 changes: 7 additions & 3 deletions src/main/java/org/apache/sysds/common/Builtins.java
Expand Up @@ -37,7 +37,7 @@
* building SystemDS, these scripts are packaged into the jar as well.
*/
public enum Builtins {
//builtin functions
// Builtin functions without parameters
ABSTAIN("abstain", true),
ABS("abs", false),
ACOS("acos", false),
Expand Down Expand Up @@ -93,7 +93,6 @@ public enum Builtins {
CORRECTTYPOSAPPLY("correctTyposApply", true),
COS("cos", false),
COSH("cosh", false),
COUNT_DISTINCT("countDistinct",false),
COV("cov", false),
COX("cox", true),
CSPLINE("cspline", true),
Expand Down Expand Up @@ -305,10 +304,15 @@ public enum Builtins {
XGBOOSTPREDICT_CLASS("xgboostPredictClassification", true),
XOR("xor", false),

//parameterized builtin functions
// Parameterized functions with parameters
AUTODIFF("autoDiff", false, true),
CDF("cdf", false, true),
COUNT_DISTINCT("countDistinct",false, true),
COUNT_DISTINCT_ROW("countDistinctRow",false, true),
COUNT_DISTINCT_COL("countDistinctCol",false, true),
COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
COUNT_DISTINCT_APPROX_ROW("countDistinctApproxRow", false, true),
COUNT_DISTINCT_APPROX_COL("countDistinctApproxCol", false, true),
CVLM("cvlm", true, false),
GROUPEDAGG("aggregate", "groupedAggregate", false, true),
INVCDF("icdf", false, true),
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/apache/sysds/common/Types.java
Expand Up @@ -197,9 +197,9 @@ public enum AggOp {
PROD(4), SUM_PROD(5),
TRACE(6), MEAN(7), VAR(8),
MAXINDEX(9), MININDEX(10),
COUNT_DISTINCT(11),
COUNT_DISTINCT_APPROX(12);
COUNT_DISTINCT(11), COUNT_DISTINCT_ROW(12), COUNT_DISTINCT_COL(13),
COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), COUNT_DISTINCT_APPROX_COL(16);

@Override
public String toString() {
switch(this) {
Expand Down
29 changes: 24 additions & 5 deletions src/main/java/org/apache/sysds/lops/PartialAggregate.java
Expand Up @@ -342,19 +342,38 @@ else if( dir == Direction.Col )
}

case COUNT_DISTINCT: {
if(dir == Direction.RowCol )
return "uacd";
break;
switch (dir) {
case RowCol: return "uacd";
case Row: return "uacdr";
case Col: return "uacdc";
default:
throw new LopsException("PartialAggregate.getOpcode() - "
+ "Unknown aggregate direction: " + dir);
}
}


case COUNT_DISTINCT_ROW:
return "uacdr";

case COUNT_DISTINCT_COL:
return "uacdc";

case COUNT_DISTINCT_APPROX: {
switch (dir) {
case RowCol: return "uacdap";
case Row: return "uacdapr";
case Col: return "uacdapc";
default:
throw new LopsException("PartialAggregate.getOpcode() - "
+ "Unknown aggregate direction: " + dir);
}
break;
}

case COUNT_DISTINCT_APPROX_ROW:
return "uacdapr";

case COUNT_DISTINCT_APPROX_COL:
return "uacdapc";
}

//should never come here for normal compilation
Expand Down
Expand Up @@ -943,14 +943,6 @@ else if( getOpCode() == Builtins.RBIND ) {
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
break;
case COUNT_DISTINCT:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(), DataType.MATRIX);
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
output.setValueType(ValueType.INT64);
break;
case LINEAGE:
checkNumParameters(1);
checkDataTypeParam(getFirstExpr(),
Expand Down
22 changes: 18 additions & 4 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Expand Up @@ -2034,15 +2034,16 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
target.getValueType(), ParamBuiltinOp.TOSTRING, paramHops) :
HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), OpOp2.PLUS);
break;

case LISTNV:
currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(),
target.getValueType(), ParamBuiltinOp.LIST, paramHops);
break;

case COUNT_DISTINCT:
case COUNT_DISTINCT_APPROX:
// Default direction and data type
Direction dir = Direction.RowCol;
DataType dataType = DataType.SCALAR;
Direction dir = Direction.RowCol; // Default direction
DataType dataType = DataType.SCALAR; // Default output data type

LiteralOp dirOp = (LiteralOp) paramHops.get("dir");
if (dirOp != null) {
Expand All @@ -2062,6 +2063,19 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
currBuiltinOp = new AggUnaryOp(target.getName(), dataType, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
break;

case COUNT_DISTINCT_ROW:
case COUNT_DISTINCT_APPROX_ROW:
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data"));
break;

case COUNT_DISTINCT_COL:
case COUNT_DISTINCT_APPROX_COL:
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
break;

default:
throw new ParseException(source.printErrorLocation() +
"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode());
Expand Down Expand Up @@ -2361,10 +2375,10 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D
case SUM:
case PROD:
case VAR:
case COUNT_DISTINCT:
currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(),
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
break;

case MEAN:
if ( expr2 == null ) {
// example: x = mean(Y);
Expand Down
Expand Up @@ -246,7 +246,15 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
validateParamserv(output, conditional);
break;

case COUNT_DISTINCT:
case COUNT_DISTINCT_ROW:
case COUNT_DISTINCT_COL:
validateCountDistinct(output, conditional);
break;

case COUNT_DISTINCT_APPROX:
case COUNT_DISTINCT_APPROX_ROW:
case COUNT_DISTINCT_APPROX_COL:
validateCountDistinctApprox(output, conditional);
break;

Expand Down Expand Up @@ -353,6 +361,45 @@ private void validateParamserv(DataIdentifier output, boolean conditional) {
output.setBlocksize(-1);
}

/**
 * Validates the parameters of the exact count-distinct builtins (this method is
 * dispatched for COUNT_DISTINCT, COUNT_DISTINCT_ROW and COUNT_DISTINCT_COL) and
 * delegates to {@code validateAggregationDirection} to set the output characteristics.
 *
 * <p>Accepted parameters are "data" (required, may be passed unnamed) and "dir"
 * (optional string).
 *
 * @param output      identifier describing the result of this expression; passed on to
 *                    {@code validateAggregationDirection}
 * @param conditional forwarded to every {@code raiseValidateError}/check call
 */
private void validateCountDistinct(DataIdentifier output, boolean conditional) {
	HashMap<String, Expression> varParams = getVarParams();

	// "data" is the only parameter that is allowed to be unnamed
	if (varParams.containsKey(null)) {
		varParams.put("data", varParams.remove(null));
	}

	// Validate the number of parameters
	String fname = getOpCode().getName();
	String usageMessage = "function " + fname + " takes at least 1 and at most 2 parameters";
	if (varParams.size() < 1) {
		raiseValidateError("Too few parameters: " + usageMessage, conditional);
	}

	if (varParams.size() > 2) {
		raiseValidateError("Too many parameters: " + usageMessage, conditional);
	}

	// Check parameter names are valid
	Set<String> validParameterNames = CollectionUtils.asSet("data", "dir");
	checkInvalidParameters(getOpCode(), varParams, validParameterNames);

	// Check parameter expression data types match expected
	checkDataType(false, fname, "data", DataType.MATRIX, conditional);
	checkDataValueType(false, fname, "data", DataType.MATRIX, ValueType.FP64, conditional);

	// We need the dimensions of the input matrix to determine the output matrix characteristics.
	// Look the expression up null-safely: if "data" is absent (e.g. only "dir" was passed) and the
	// checks above did not abort, dereferencing the map result directly would throw a raw
	// NullPointerException instead of reporting the validation error below.
	// NOTE(review): presumably raiseValidateError only warns when conditional is true - confirm.
	Expression dataExpr = varParams.get("data");
	Identifier dataId = (dataExpr != null) ? dataExpr.getOutput() : null;
	if (dataId == null) {
		raiseValidateError("Cannot parse input parameter \"data\" to function " + fname, conditional);
	}

	checkStringParam(true, fname, "dir", conditional);
	// Check data value of "dir" parameter
	validateAggregationDirection(dataId, output);
}

private void validateCountDistinctApprox(DataIdentifier output, boolean conditional) {
Set<String> validTypeNames = CollectionUtils.asSet("KMV");
HashMap<String, Expression> varParams = getVarParams();
Expand Down Expand Up @@ -390,7 +437,7 @@ private void validateCountDistinctApprox(DataIdentifier output, boolean conditio

checkStringParam(true, fname, "type", conditional);
// Check data value of "type" parameter
if (varParams.keySet().contains("type")) {
if (varParams.containsKey("type")) {
String typeString = varParams.get("type").toString().toUpperCase();
if (!validTypeNames.contains(typeString)) {
raiseValidateError("Unrecognized type for optional parameter " + typeString, conditional);
Expand All @@ -402,7 +449,12 @@ private void validateCountDistinctApprox(DataIdentifier output, boolean conditio

checkStringParam(true, fname, "dir", conditional);
// Check data value of "dir" parameter
if (varParams.keySet().contains("dir")) {
validateAggregationDirection(dataId, output);
}

private void validateAggregationDirection(Identifier dataId, DataIdentifier output) {
HashMap<String, Expression> varParams = getVarParams();
if (varParams.containsKey("dir")) {
String directionString = varParams.get("dir").toString().toUpperCase();

// Set output type and dimensions based on direction
Expand Down Expand Up @@ -435,9 +487,7 @@ private void validateCountDistinctApprox(DataIdentifier output, boolean conditio
} else {
raiseValidateError("Invalid argument: " + directionString + " is not recognized");
}

// default to dir="rc"
} else {
} else { // default to dir="rc"
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlocksize(0);
Expand Down
Expand Up @@ -117,6 +117,8 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "exists" , CPType.AggregateUnary);
String2CPInstructionType.put( "lineage" , CPType.AggregateUnary);
String2CPInstructionType.put( "uacd" , CPType.AggregateUnary);
String2CPInstructionType.put( "uacdr" , CPType.AggregateUnary);
String2CPInstructionType.put( "uacdc" , CPType.AggregateUnary);
String2CPInstructionType.put( "uacdap" , CPType.AggregateUnary);
String2CPInstructionType.put( "uacdapr" , CPType.AggregateUnary);
String2CPInstructionType.put( "uacdapc" , CPType.AggregateUnary);
Expand Down
Expand Up @@ -82,26 +82,16 @@
import org.apache.sysds.runtime.functionobjects.ReduceDiag;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.Xor;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.*;
import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;


public class InstructionUtils
Expand Down Expand Up @@ -287,7 +277,14 @@ public static boolean isUnaryMetadata(String opcode) {
/**
 * Parse the given opcode into an aggregate unary operator, using a thread count of 1.
 *
 * @param opcode opcode
 * @return whatever the two-argument overload returns for this opcode; may be null
 */
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) {
	final int defaultNumThreads = 1;
	return parseBasicAggregateUnaryOperator(opcode, defaultNumThreads);
}


/**
* Parse the given opcode into an aggregate unary operator.
*
* @param opcode opcode
* @param numThreads number of threads
* @return Parsed aggregate unary operator object. Caller must handle possible null return value.
*/
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode, int numThreads)
{
AggregateUnaryOperator aggun = null;
Expand Down Expand Up @@ -420,7 +417,31 @@ else if ( opcode.equalsIgnoreCase("uacmin") ) {
AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads);
}

else if ( opcode.equalsIgnoreCase("uacd") ) {
aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT,
Direction.RowCol, ReduceAll.getReduceAllFnObject());
}
else if ( opcode.equalsIgnoreCase("uacdr") ) {
aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT,
Direction.Row, ReduceCol.getReduceColFnObject());
}
else if ( opcode.equalsIgnoreCase("uacdc") ) {
aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT,
Direction.Col, ReduceRow.getReduceRowFnObject());
}
else if ( opcode.equalsIgnoreCase("uacdap") ) {
aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
Direction.RowCol, ReduceAll.getReduceAllFnObject());
}
else if ( opcode.equalsIgnoreCase("uacdapr") ) {
aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
Direction.Row, ReduceCol.getReduceColFnObject());
}
else if ( opcode.equalsIgnoreCase("uacdapc") ) {
aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
Direction.Col, ReduceRow.getReduceRowFnObject());
}

return aggun;
}

Expand Down
Expand Up @@ -126,7 +126,9 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "uac*" , SPType.AggregateUnary);
String2SPInstructionType.put( "uatrace" , SPType.AggregateUnary);
String2SPInstructionType.put( "uaktrace", SPType.AggregateUnary);
String2SPInstructionType.put( "uacdap" , SPType.AggregateUnary);
String2SPInstructionType.put( "uacd" , SPType.AggregateUnary);
String2SPInstructionType.put( "uacdr" , SPType.AggregateUnary);
String2SPInstructionType.put( "uacdc" , SPType.AggregateUnary);

// Aggregate unary sketch operators
String2SPInstructionType.put( "uacdap" , SPType.AggregateUnarySketch);
Expand Down

0 comments on commit 6c9f0ff

Please sign in to comment.