From 6c9f0ff1304125111ee20d3a3309f45f65bc6661 Mon Sep 17 00:00:00 2001 From: Badrul Chowdhury Date: Fri, 28 Oct 2022 14:38:21 +0200 Subject: [PATCH] [SYSTEMDS-3413] Row/Col aggregation for countDistinct This patch converts countDistinct() from a non-parameterized builtin to a parameterized builtin function to allow for 1 new parameter: dir for direction. The value of dir can be r and c, denoting row-wise and column-wise aggregation respectively. This patch only implements CP and the SP case will throw a NotImplementedException()- the latter case will be addressed in a subsequent patch. Closes #1677 --- .../org/apache/sysds/common/Builtins.java | 10 +- .../java/org/apache/sysds/common/Types.java | 6 +- .../apache/sysds/lops/PartialAggregate.java | 29 +++- .../parser/BuiltinFunctionExpression.java | 8 - .../apache/sysds/parser/DMLTranslator.java | 22 ++- ...arameterizedBuiltinFunctionExpression.java | 60 +++++++- .../instructions/CPInstructionParser.java | 2 + .../instructions/InstructionUtils.java | 51 +++++-- .../instructions/SPInstructionParser.java | 4 +- .../cp/AggregateUnaryCPInstruction.java | 70 +++------ .../AggregateUnarySketchSPInstruction.java | 34 ++--- .../matrix/data/LibMatrixCountDistinct.java | 4 +- .../operators/CountDistinctOperator.java | 45 ++---- .../component/matrix/CountDistinctTest.java | 5 +- .../countDistinct/CountDistinctApproxCol.java | 2 +- .../countDistinct/CountDistinctApproxRow.java | 2 +- .../countDistinct/CountDistinctCol.java | 103 +++++++++++++ .../countDistinct/CountDistinctColAlias.java | 103 +++++++++++++ .../countDistinct/CountDistinctRow.java | 103 +++++++++++++ .../countDistinct/CountDistinctRowAlias.java | 103 +++++++++++++ .../countDistinct/CountDistinctRowCol.java | 2 +- .../CountDistinctRowColParameterized.java | 55 +++++++ .../CountDistinctRowOrColBase.java | 32 ++-- .../CountDistinctApproxCol.java | 118 +++++++++++++++ .../CountDistinctApproxColAlias.java | 118 +++++++++++++++ .../CountDistinctApproxRow.java | 118 +++++++++++++++ .../CountDistinctApproxRowAlias.java | 118 +++++++++++++++ .../CountDistinctApproxRowCol.java | 5 +- ...ountDistinctApproxRowColParameterized.java | 141 ++++++++++++++++++ ...countDistinct.dml => countDistinctCol.dml} | 2 +- .../countDistinct/countDistinctColAlias.dml | 24 +++ .../countDistinct/countDistinctRow.dml | 24 +++ .../countDistinct/countDistinctRowAlias.dml | 24 +++ .../countDistinct/countDistinctRowCol.dml | 24 +++ .../countDistinctRowColParameterized.dml | 24 +++ .../countDistinctApproxCol.dml | 0 .../countDistinctApproxColAlias.dml | 24 +++ .../countDistinctApproxRow.dml | 0 .../countDistinctApproxRowAlias.dml | 24 +++ .../countDistinctApproxRowCol.dml | 24 +++ ...ountDistinctApproxRowColParameterized.dml} | 0 41 files changed, 1496 insertions(+), 171 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java rename src/test/java/org/apache/sysds/test/functions/{countDistinct => countDistinctApprox}/CountDistinctApproxRowCol.java (96%) create mode 100644 src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java rename src/test/scripts/functions/countDistinct/{countDistinct.dml => countDistinctCol.dml} (96%) create mode 100644 src/test/scripts/functions/countDistinct/countDistinctColAlias.dml create mode 100644 src/test/scripts/functions/countDistinct/countDistinctRow.dml create mode 100644 src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml create mode 100644 src/test/scripts/functions/countDistinct/countDistinctRowCol.dml create mode 100644 src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml rename src/test/scripts/functions/{countDistinct => countDistinctApprox}/countDistinctApproxCol.dml (100%) create mode 100644 src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml rename src/test/scripts/functions/{countDistinct => countDistinctApprox}/countDistinctApproxRow.dml (100%) create mode 100644 src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml create mode 100644 src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml rename src/test/scripts/functions/{countDistinct/countDistinctApproxRowCol.dml => countDistinctApprox/countDistinctApproxRowColParameterized.dml} (100%) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 5e9509696b3..262212570ec 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -37,7 +37,7 @@ * building SystemDS, these scripts are packaged into the jar as well. */ public enum Builtins { - //builtin functions + // Builtin functions without parameters ABSTAIN("abstain", true), ABS("abs", false), ACOS("acos", false), @@ -93,7 +93,6 @@ public enum Builtins { CORRECTTYPOSAPPLY("correctTyposApply", true), COS("cos", false), COSH("cosh", false), - COUNT_DISTINCT("countDistinct",false), COV("cov", false), COX("cox", true), CSPLINE("cspline", true), @@ -305,10 +304,15 @@ public enum Builtins { XGBOOSTPREDICT_CLASS("xgboostPredictClassification", true), XOR("xor", false), - //parameterized builtin functions + // Parameterized functions with parameters AUTODIFF("autoDiff", false, true), CDF("cdf", false, true), + COUNT_DISTINCT("countDistinct",false, true), + COUNT_DISTINCT_ROW("countDistinctRow",false, true), + COUNT_DISTINCT_COL("countDistinctCol",false, true), COUNT_DISTINCT_APPROX("countDistinctApprox", false, true), + COUNT_DISTINCT_APPROX_ROW("countDistinctApproxRow", false, true), + COUNT_DISTINCT_APPROX_COL("countDistinctApproxCol", false, true), CVLM("cvlm", true, false), GROUPEDAGG("aggregate", "groupedAggregate", false, true), INVCDF("icdf", false, true), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index c013b7890bb..4f613a40d77 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -197,9 +197,9 @@ public enum AggOp { PROD(4), SUM_PROD(5), TRACE(6), MEAN(7), VAR(8), MAXINDEX(9), MININDEX(10), - COUNT_DISTINCT(11), - COUNT_DISTINCT_APPROX(12); - + COUNT_DISTINCT(11), COUNT_DISTINCT_ROW(12), COUNT_DISTINCT_COL(13), + COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), COUNT_DISTINCT_APPROX_COL(16); + @Override public String toString() { switch(this) { diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java b/src/main/java/org/apache/sysds/lops/PartialAggregate.java index 050d87a3fda..0481c7373ac 100644 --- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java +++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java @@ -342,19 +342,38 @@ else if( dir == Direction.Col ) } case COUNT_DISTINCT: { - if(dir == Direction.RowCol ) - return "uacd"; - break; + switch (dir) { + case RowCol: return "uacd"; + case Row: return "uacdr"; + case Col: return "uacdc"; + default: + throw new LopsException("PartialAggregate.getOpcode() - " + + "Unknown aggregate direction: " + dir); + } } - + + case COUNT_DISTINCT_ROW: + return "uacdr"; + + case COUNT_DISTINCT_COL: + return "uacdc"; + case COUNT_DISTINCT_APPROX: { switch (dir) { case RowCol: return "uacdap"; case Row: return "uacdapr"; case Col: return "uacdapc"; + default: + throw new LopsException("PartialAggregate.getOpcode() - " + + "Unknown aggregate direction: " + dir); } - break; } + + case COUNT_DISTINCT_APPROX_ROW: + return "uacdapr"; + + case COUNT_DISTINCT_APPROX_COL: + return "uacdapc"; } //should never come here for normal compilation diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index c5db658dfa3..c3aca47d381 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -943,14 +943,6 @@ else if( getOpCode() == Builtins.RBIND ) { output.setBlocksize(0); output.setValueType(ValueType.INT64); break; - case COUNT_DISTINCT: - checkNumParameters(1); - checkDataTypeParam(getFirstExpr(), DataType.MATRIX); - output.setDataType(DataType.SCALAR); - output.setDimensions(0, 0); - output.setBlocksize(0); - output.setValueType(ValueType.INT64); - break; case LINEAGE: checkNumParameters(1); checkDataTypeParam(getFirstExpr(), diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 315f54ff721..553bf56fc52 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2034,15 +2034,16 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu target.getValueType(), ParamBuiltinOp.TOSTRING, paramHops) : HopRewriteUtils.createBinary(paramHops.get("target"), new LiteralOp(""), OpOp2.PLUS); break; + case LISTNV: currBuiltinOp = new ParameterizedBuiltinOp(target.getName(), target.getDataType(), target.getValueType(), ParamBuiltinOp.LIST, paramHops); break; + case COUNT_DISTINCT: case COUNT_DISTINCT_APPROX: - // Default direction and data type - Direction dir = Direction.RowCol; - DataType dataType = DataType.SCALAR; + Direction dir = Direction.RowCol; // Default direction + DataType dataType = DataType.SCALAR; // Default output data type LiteralOp dirOp = (LiteralOp) paramHops.get("dir"); if (dirOp != null) { @@ -2062,6 +2063,19 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu currBuiltinOp = new AggUnaryOp(target.getName(), dataType, target.getValueType(), AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data")); break; + + case COUNT_DISTINCT_ROW: + case COUNT_DISTINCT_APPROX_ROW: + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), + AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data")); + break; + + case COUNT_DISTINCT_COL: + case COUNT_DISTINCT_APPROX_COL: + currBuiltinOp = new AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(), + AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data")); + break; + default: throw new ParseException(source.printErrorLocation() + "processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + source.getOpCode()); @@ -2361,10 +2375,10 @@ private Hop processBuiltinFunctionExpression(BuiltinFunctionExpression source, D case SUM: case PROD: case VAR: - case COUNT_DISTINCT: currBuiltinOp = new AggUnaryOp(target.getName(), DataType.SCALAR, target.getValueType(), AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr); break; + case MEAN: if ( expr2 == null ) { // example: x = mean(Y); diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index c83b6f39112..bdfd38c5a46 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -246,7 +246,15 @@ public void validateExpression(HashMap ids, HashMap varParams = getVarParams(); + + // "data" is the only parameter that is allowed to be unnamed + if (varParams.containsKey(null)) { + varParams.put("data", varParams.remove(null)); + } + + // Validate the number of parameters + String fname = getOpCode().getName(); + String usageMessage = "function " + fname + " takes at least 1 and at most 2 parameters"; + if (varParams.size() < 1) { + raiseValidateError("Too few parameters: " + usageMessage, conditional); + } + + if (varParams.size() > 2) { + raiseValidateError("Too many parameters: " + usageMessage, conditional); + } + + // Check parameter names are valid + Set validParameterNames = CollectionUtils.asSet("data", "dir"); + checkInvalidParameters(getOpCode(), varParams, validParameterNames); + + // Check parameter expression data types match expected + checkDataType(false, fname, "data", DataType.MATRIX, conditional); + checkDataValueType(false, fname, "data", DataType.MATRIX, ValueType.FP64, conditional); + + // We need the dimensions of the input matrix to determine the output matrix characteristics + // Validate data parameter, lookup previously defined var or resolve expression + Identifier dataId = varParams.get("data").getOutput(); + if (dataId == null) { + raiseValidateError("Cannot parse input parameter \"data\" to function " + fname, conditional); + } + + checkStringParam(true, fname, "dir", conditional); + // Check data value of "dir" parameter + validateAggregationDirection(dataId, output); + } + private void validateCountDistinctApprox(DataIdentifier output, boolean conditional) { Set validTypeNames = CollectionUtils.asSet("KMV"); HashMap varParams = getVarParams(); @@ -390,7 +437,7 @@ private void validateCountDistinctApprox(DataIdentifier output, boolean conditio checkStringParam(true, fname, "type", conditional); // Check data value of "type" parameter - if (varParams.keySet().contains("type")) { + if (varParams.containsKey("type")) { String typeString = varParams.get("type").toString().toUpperCase(); if (!validTypeNames.contains(typeString)) { raiseValidateError("Unrecognized type for optional parameter " + typeString, conditional); @@ -402,7 +449,12 @@ private void validateCountDistinctApprox(DataIdentifier output, boolean conditio checkStringParam(true, fname, "dir", conditional); // Check data value of "dir" parameter - if (varParams.keySet().contains("dir")) { + validateAggregationDirection(dataId, output); + } + + private void validateAggregationDirection(Identifier dataId, DataIdentifier output) { + HashMap varParams = getVarParams(); + if (varParams.containsKey("dir")) { String directionString = varParams.get("dir").toString().toUpperCase(); // Set output type and dimensions based on direction @@ -435,9 +487,7 @@ private void validateCountDistinctApprox(DataIdentifier output, boolean conditio } else { raiseValidateError("Invalid argument: " + directionString + " is not recognized"); } - - // default to dir="rc" - } else { + } else { // default to dir="rc" output.setDataType(DataType.SCALAR); output.setDimensions(0, 0); output.setBlocksize(0); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index f2d3080ddcb..b83a78674b7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -117,6 +117,8 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "exists" , CPType.AggregateUnary); String2CPInstructionType.put( "lineage" , CPType.AggregateUnary); String2CPInstructionType.put( "uacd" , CPType.AggregateUnary); + String2CPInstructionType.put( "uacdr" , CPType.AggregateUnary); + String2CPInstructionType.put( "uacdc" , CPType.AggregateUnary); String2CPInstructionType.put( "uacdap" , CPType.AggregateUnary); String2CPInstructionType.put( "uacdapr" , CPType.AggregateUnary); String2CPInstructionType.put( "uacdapc" , CPType.AggregateUnary); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 2c8468955fc..d87e772709e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -82,26 +82,16 @@ import org.apache.sysds.runtime.functionobjects.ReduceDiag; import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.functionobjects.Xor; -import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; +import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE; import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType; import org.apache.sysds.runtime.matrix.data.LibCommonsMath; -import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; -import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; -import org.apache.sysds.runtime.matrix.operators.CMOperator; +import org.apache.sysds.runtime.matrix.operators.*; import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes; -import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator; -import org.apache.sysds.runtime.matrix.operators.Operator; -import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; -import org.apache.sysds.runtime.matrix.operators.ScalarOperator; -import org.apache.sysds.runtime.matrix.operators.TernaryOperator; -import org.apache.sysds.runtime.matrix.operators.UnaryOperator; public class InstructionUtils @@ -287,7 +277,14 @@ public static boolean isUnaryMetadata(String opcode) { public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode) { return parseBasicAggregateUnaryOperator(opcode, 1); } - + + /** + * Parse the given opcode into an aggregate unary operator. + * + * @param opcode opcode + * @param numThreads number of threads + * @return Parsed aggregate unary operator object. Caller must handle possible null return value. + */ public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode, int numThreads) { AggregateUnaryOperator aggun = null; @@ -420,7 +417,31 @@ else if ( opcode.equalsIgnoreCase("uacmin") ) { AggregateOperator agg = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject("min")); aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), numThreads); } - + else if ( opcode.equalsIgnoreCase("uacd") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT, + Direction.RowCol, ReduceAll.getReduceAllFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdr") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT, + Direction.Row, ReduceCol.getReduceColFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdc") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT, + Direction.Col, ReduceRow.getReduceRowFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdap") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + Direction.RowCol, ReduceAll.getReduceAllFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdapr") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + Direction.Row, ReduceCol.getReduceColFnObject()); + } + else if ( opcode.equalsIgnoreCase("uacdapc") ) { + aggun = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + Direction.Col, ReduceRow.getReduceRowFnObject()); + } + return aggun; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index 73329b954d2..9496cf465ec 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -126,7 +126,9 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "uac*" , SPType.AggregateUnary); String2SPInstructionType.put( "uatrace" , SPType.AggregateUnary); String2SPInstructionType.put( "uaktrace", SPType.AggregateUnary); - String2SPInstructionType.put( "uacdap" , SPType.AggregateUnary); + String2SPInstructionType.put( "uacd" , SPType.AggregateUnary); + String2SPInstructionType.put( "uacdr" , SPType.AggregateUnary); + String2SPInstructionType.put( "uacdc" , SPType.AggregateUnary); // Aggregate unary sketch operators String2SPInstructionType.put( "uacdap" , SPType.AggregateUnarySketch); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java index ddf00ada2b7..6fc01075205 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java @@ -33,11 +33,13 @@ import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.functionobjects.ReduceRow; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock; import org.apache.sysds.runtime.lineage.LineageDedupUtils; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.SmallestPriorityQueue; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -45,6 +47,9 @@ import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.utils.Explain; +import java.util.HashSet; +import java.util.Set; + public class AggregateUnaryCPInstruction extends UnaryCPInstruction { // private static final Log LOG = LogFactory.getLog(AggregateUnaryCPInstruction.class.getName()); @@ -81,36 +86,19 @@ public static AggregateUnaryCPInstruction parseInstruction(String str) { return new AggregateUnaryCPInstruction(new SimpleOperator(Builtin.getBuiltinFnObject(opcode)), in1, out, AUType.valueOf(opcode.toUpperCase()), opcode, str); } - else if(opcode.equalsIgnoreCase("uacd")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT) - .setDirection(Types.Direction.RowCol) - .setIndexFunction(ReduceAll.getReduceAllFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT, - opcode, str); + else if(opcode.equalsIgnoreCase("uacd") + || opcode.equalsIgnoreCase("uacdr") + || opcode.equalsIgnoreCase("uacdc")){ + AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode, + Integer.parseInt(parts[3])); + return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.COUNT_DISTINCT, opcode, str); } - else if(opcode.equalsIgnoreCase("uacdap")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX) - .setDirection(Types.Direction.RowCol) - .setIndexFunction(ReduceAll.getReduceAllFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT_APPROX, - opcode, str); - } - else if(opcode.equalsIgnoreCase("uacdapr")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX) - .setDirection(Types.Direction.Row) - .setIndexFunction(ReduceCol.getReduceColFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT_APPROX, - opcode, str); - } - else if(opcode.equalsIgnoreCase("uacdapc")){ - CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX) - .setDirection(Types.Direction.Col) - .setIndexFunction(ReduceRow.getReduceRowFnObject()); - - return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT_APPROX, + else if(opcode.equalsIgnoreCase("uacdap") + || opcode.equalsIgnoreCase("uacdapr") + || opcode.equalsIgnoreCase("uacdapc")){ + AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode, + Integer.parseInt(parts[3])); + return new AggregateUnaryCPInstruction(aggun, in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str); } else if(opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin")){ @@ -199,34 +187,18 @@ else if( input1.getDataType().isMatrix() || input1.getDataType().isFrame() ) { ec.setScalarOutput(output_name, new StringObject(out)); break; } - case COUNT_DISTINCT: { - if( !ec.getVariables().keySet().contains(input1.getName()) ) - throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist."); - MatrixBlock input = ec.getMatrixInput(input1.getName()); - - // Operator type: test and cast - if (!(_optr instanceof CountDistinctOperator)) { - throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName()); - } - CountDistinctOperator op = (CountDistinctOperator) (_optr); - - //TODO add support for row or col count distinct. - int res = (int) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0); - ec.releaseMatrixInput(input1.getName()); - ec.setScalarOutput(output_name, new IntObject(res)); - break; - } + case COUNT_DISTINCT: case COUNT_DISTINCT_APPROX: { if(!ec.getVariables().keySet().contains(input1.getName())) { throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist."); } - MatrixBlock input = ec.getMatrixInput(input1.getName()); + + // Operator type: test and cast if (!(_optr instanceof CountDistinctOperator)) { throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName()); } - - CountDistinctOperator op = (CountDistinctOperator) _optr; // It is safe to cast at this point + CountDistinctOperator op = (CountDistinctOperator) _optr; if (op.getDirection().isRowCol()) { long res = (long) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java index 71bc75fd457..703828e3a13 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java @@ -53,20 +53,6 @@ public class AggregateUnarySketchSPInstruction extends UnarySPInstruction { protected AggregateUnarySketchSPInstruction(Operator op, CPOperand in, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String instr) { super(SPType.AggregateUnarySketch, op, in, out, opcode, instr); this.op = (CountDistinctOperator) super.getOperator(); - - if (opcode.equals("uacdap")) { - this.op.setDirection(Types.Direction.RowCol) - .setIndexFunction(ReduceAll.getReduceAllFnObject()); - } else if (opcode.equals("uacdapr")) { - this.op.setDirection(Types.Direction.Row) - .setIndexFunction(ReduceCol.getReduceColFnObject()); - } else if (opcode.equals("uacdapc")) { - this.op.setDirection(Types.Direction.Col) - .setIndexFunction(ReduceRow.getReduceRowFnObject()); - } else { - throw new DMLException("Unrecognized opcode " + opcode); - } - this.aggtype = aggtype; } @@ -79,7 +65,19 @@ public static AggregateUnarySketchSPInstruction parseInstruction(String str) { CPOperand out = new CPOperand(parts[2]); AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]); - CountDistinctOperator cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Hash.HashType.LinearHash); + CountDistinctOperator cdop = null; + if (opcode.equals("uacdap")) { + cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.RowCol, + ReduceAll.getReduceAllFnObject(), Hash.HashType.LinearHash); + } else if (opcode.equals("uacdapr")) { + cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.Row, + ReduceCol.getReduceColFnObject(), Hash.HashType.LinearHash); + } else if (opcode.equals("uacdapc")) { + cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.Col, + ReduceRow.getReduceRowFnObject(), Hash.HashType.LinearHash); + } else { + throw new DMLException("Unrecognized opcode: " + opcode); + } return new AggregateUnarySketchSPInstruction(cdop, in1, out, aggtype, opcode, str); } @@ -147,7 +145,7 @@ private void processMatrixSketch(ExecutionContext ec) { out3 = out2.mapValues(new CalculateAggregateSketchFunction(this.op)); - updateUnaryAggOutputDataCharacteristics(sec, this.op.getIndexFunction()); + updateUnaryAggOutputDataCharacteristics(sec, this.op.indexFn); // put output RDD handle into symbol table sec.setRDDHandleForVariable(output.getName(), out3); @@ -173,7 +171,7 @@ public CorrMatrixBlock call(Tuple2 arg0) throws Exce MatrixBlock blkIn = arg0._2(); MatrixIndexes ixOut = new MatrixIndexes(); - this.op.getIndexFunction().execute(ixIn, ixOut); + this.op.indexFn.execute(ixIn, ixOut); return LibMatrixCountDistinct.createSketch(blkIn, this.op); } @@ -222,7 +220,7 @@ public Tuple2 call(Tuple2(idxOut, blkOut); } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java index 814b7737f96..ee97c8d3405 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java @@ -52,7 +52,6 @@ public interface LibMatrixCountDistinct { * Public method to count the number of distinct values inside a matrix. Depending on which CountDistinctOperator * selected it either gets the absolute number or a estimated value. * - * TODO: Support counting num distinct in rows, or columns axis. * TODO: If the MatrixBlock type is CompressedMatrix, simply read the values from the ColGroups. * * @param in the input matrix to count number distinct values in @@ -252,7 +251,7 @@ else if(blkIn instanceof CompressedMatrixBlock) { } } } else { // Col aggregation - blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumRows()); + blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumColumns()); blkOut.allocateBlock(); // All dense and sparse formats (COO, CSR, MCSR) are row-major formats, so there is no obvious way to iterate @@ -300,7 +299,6 @@ else if(blkIn instanceof CompressedMatrixBlock) { if (csrBlock.isEmpty(rix)) { continue; } - distinct.clear(); int rpos = csrBlock.pos(rix); int clen = csrBlock.size(rix); int[] cixs = csrBlock.indexes(); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java index 1c430c91340..c33accf9430 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java @@ -22,19 +22,20 @@ import org.apache.sysds.common.Types; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.functionobjects.IndexFunction; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType; import org.apache.sysds.utils.Hash.HashType; -public class CountDistinctOperator extends Operator { +public class CountDistinctOperator extends AggregateUnaryOperator { private static final long serialVersionUID = 7615123453265129670L; private final CountDistinctOperatorTypes operatorType; + private final Types.Direction direction; private final HashType hashType; - private Types.Direction direction; - private IndexFunction indexFunction; - public CountDistinctOperator(AUType opType) { - super(true); + public CountDistinctOperator(AUType opType, Types.Direction direction, IndexFunction indexFunction) { + super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, 1); + switch(opType) { case COUNT_DISTINCT: this.operatorType = CountDistinctOperatorTypes.COUNT; @@ -46,25 +47,15 @@ public CountDistinctOperator(AUType opType) { throw new DMLRuntimeException(opType + " not supported for CountDistinct Operator"); } this.hashType = HashType.LinearHash; + this.direction = direction; } - public CountDistinctOperator(CountDistinctOperatorTypes operatorType) { - super(true); - this.operatorType = operatorType; - this.hashType = HashType.StandardJava; - } - - public CountDistinctOperator(CountDistinctOperatorTypes operatorType, HashType hashType) { - super(true); - this.operatorType = operatorType; - this.hashType = hashType; - } + public CountDistinctOperator(CountDistinctOperatorTypes operatorType, Types.Direction direction, + IndexFunction indexFunction, HashType hashType) { + super(new AggregateOperator(0, Plus.getPlusFnObject()), indexFunction, 1); - public CountDistinctOperator(CountDistinctOperatorTypes operatorType, IndexFunction indexFunction, - HashType hashType) { - super(true); this.operatorType = operatorType; - this.indexFunction = indexFunction; + this.direction = direction; this.hashType = hashType; } @@ -76,21 +67,7 @@ public HashType getHashType() { return hashType; } - public IndexFunction getIndexFunction() { - return indexFunction; - } - - public CountDistinctOperator setIndexFunction(IndexFunction indexFunction) { - this.indexFunction = indexFunction; - return this; - } - public Types.Direction getDirection() { return direction; } - - public CountDistinctOperator setDirection(Types.Direction direction) { - this.direction = direction; - return this; - } } diff --git a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java index 5de18c4b3e7..4b4909e27a8 100644 --- a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java +++ b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java @@ -29,6 +29,7 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.api.DMLException; import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; @@ -138,7 +139,9 @@ else if(et != CountDistinctOperatorTypes.COUNT) { @Test public void testEstimation() { try { - CountDistinctOperator op = new CountDistinctOperator(et, ht).setDirection(Types.Direction.RowCol); + CountDistinctOperator op = new CountDistinctOperator(et, Types.Direction.RowCol, + ReduceAll.getReduceAllFnObject(), ht); + if(expectedException != null) { assertThrows(expectedException.getClass(), () -> { LibMatrixCountDistinct.estimateDistinctValues(in, op); diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java index 5a7eccc447f..69f5fa1ef1d 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java @@ -26,7 +26,7 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase { private final static String TEST_NAME = "countDistinctApproxCol"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxCol.class.getSimpleName() + "/"; @Override diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java index c9aa75e3755..07f3fcac380 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java @@ -26,7 +26,7 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase { private final static String TEST_NAME = "countDistinctApproxRow"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; @Override diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java new file mode 100644 index 00000000000..fb26da4da29 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctCol.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinct; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.junit.Test; + +public class CountDistinctCol extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctCol"; + private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctCol.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Col; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + super.setRunSparkTests(false); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 1000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java new file mode 100644 index 00000000000..08620d13d1a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAlias.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinct; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.junit.Test; + +public class CountDistinctColAlias extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctColAlias"; + private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctColAlias.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Col; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + super.setRunSparkTests(false); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 1000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java new file mode 100644 index 00000000000..568c7516d0f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRow.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinct; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.junit.Test; + +public class CountDistinctRow extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctRow"; + private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctRow.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Row; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + super.setRunSparkTests(false); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 10000, cols = 1000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java new file mode 100644 index 00000000000..c9e24cd38d7 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAlias.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinct; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.junit.Test; + +public class CountDistinctRowAlias extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctRowAlias"; + private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowAlias.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Row; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + super.setRunSparkTests(false); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 10000, cols = 1000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java index 3de4a61bcdd..8f7d9acb8f6 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java @@ -24,7 +24,7 @@ public class CountDistinctRowCol extends CountDistinctRowColBase { - public String TEST_NAME = "countDistinct"; + public String TEST_NAME = "countDistinctRowCol"; public String TEST_DIR = "functions/countDistinct/"; public String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowCol.class.getSimpleName() + "/"; diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java new file mode 100644 index 00000000000..02048595e61 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColParameterized.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinct; + +import org.apache.sysds.common.Types.ExecType; +import org.junit.Test; + +public class CountDistinctRowColParameterized extends CountDistinctRowColBase { + + public String TEST_NAME = "countDistinctRowColParameterized"; + public String TEST_DIR = "functions/countDistinct/"; + public String TEST_CLASS_DIR = TEST_DIR + CountDistinctRowColParameterized.class.getSimpleName() + "/"; + + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + protected String getTestName() { + return TEST_NAME; + } + + protected String getTestDir() { + return TEST_DIR; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + super.percentTolerance = 0.0; + } + + @Test + public void testSimple1by1() { + // test simple 1 by 1. + ExecType ex = ExecType.CP; + countDistinctScalarTest(1, 1, 1, 1.0, ex, 0.00001); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java index a880c0d0dd5..0d517776a23 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java @@ -30,6 +30,8 @@ import org.apache.sysds.test.TestUtils; import org.junit.Test; +import static org.junit.Assume.assumeTrue; + public abstract class CountDistinctRowOrColBase extends CountDistinctBase { @Override @@ -43,6 +45,8 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase { protected abstract Types.Direction getDirection(); + private boolean runSparkTests = true; + protected void addTestConfiguration() { TestUtils.clearAssertionInformation(); addTestConfiguration(getTestName(), new TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"})); @@ -50,19 +54,8 @@ protected void addTestConfiguration() { this.percentTolerance = 0.2; } - /** - * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch. - */ - @Test - public void testCPDenseXLarge() { - Types.ExecType ex = Types.ExecType.CP; - - int actualDistinctCount = 10000; - int rows = 10000, cols = 10000; - double sparsity = 0.9; - double tolerance = actualDistinctCount * this.percentTolerance; - - countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + public void setRunSparkTests(boolean runSparkTests) { + this.runSparkTests = runSparkTests; } @Test @@ -91,6 +84,8 @@ public void testCPDenseSmall() { @Test public void testSparkSparseLargeMultiBlockAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -103,6 +98,8 @@ public void testSparkSparseLargeMultiBlockAggregation() { @Test public void testSparkDenseLargeMultiBlockAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -115,6 +112,8 @@ public void testSparkDenseLargeMultiBlockAggregation() { @Test public void testSparkSparseLargeNoneAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -127,6 +126,8 @@ public void testSparkSparseLargeNoneAggregation() { @Test public void testSparkDenseLargeNoneAggregation() { + assumeTrue(runSparkTests); + Types.ExecType execType = Types.ExecType.SPARK; int actualDistinctCount = 10; @@ -145,9 +146,8 @@ protected void testCPSparseLarge(SparseBlock.Type sparseBlockType, Types.Directi } blkIn = new MatrixBlock(blkIn, sparseBlockType, true); - CountDistinctOperator op = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX) - .setDirection(direction) - .setIndexFunction(ReduceCol.getReduceColFnObject()); + CountDistinctOperator op = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX, + direction, ReduceCol.getReduceColFnObject()); MatrixBlock blkOut = LibMatrixCountDistinct.estimateDistinctValues(blkIn, op); double[][] expectedMatrix = getExpectedMatrixRowOrCol(direction, cols, rows, actualDistinctCount); diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java new file mode 100644 index 00000000000..6752bc29bbb --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxCol.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinctApprox; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; +import org.junit.Test; + +public class CountDistinctApproxCol extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctApproxCol"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxCol.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Col; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 1000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. + */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java new file mode 100644 index 00000000000..e87813f464c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAlias.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinctApprox; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; +import org.junit.Test; + +public class CountDistinctApproxColAlias extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctApproxColAlias"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxColAlias.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Col; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 1000, cols = 10000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 1000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. + */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java new file mode 100644 index 00000000000..6e4678f5a8d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRow.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinctApprox; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; +import org.junit.Test; + +public class CountDistinctApproxRow extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctApproxRow"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRow.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Row; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 10000, cols = 1000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. + */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java new file mode 100644 index 00000000000..99b6d60e044 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAlias.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinctApprox; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowOrColBase; +import org.junit.Test; + +public class CountDistinctApproxRowAlias extends CountDistinctRowOrColBase { + + private final static String TEST_NAME = "countDistinctApproxRowAlias"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowAlias.class.getSimpleName() + "/"; + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } + + @Override + protected Types.Direction getDirection() { + return Types.Direction.Row; + } + + @Override + public void setUp() { + super.addTestConfiguration(); + } + + @Test + public void testCPSparseLargeDefaultMCSR() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + @Test + public void testCPSparseLargeCSR() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPSparseLargeCOO() { + int actualDistinctCount = 10; + int rows = 10000, cols = 1000; + double sparsity = 0.1; + double tolerance = actualDistinctCount * this.percentTolerance; + + super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity, + tolerance); + } + + @Test + public void testCPDenseLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 100; + int rows = 10000, cols = 1000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } + + /** + * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch in CP exec mode. + */ + @Test + public void testCPDenseXLarge() { + Types.ExecType ex = Types.ExecType.CP; + + int actualDistinctCount = 10000; + int rows = 10000, cols = 10000; + double sparsity = 0.9; + double tolerance = actualDistinctCount * this.percentTolerance; + + countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java similarity index 96% rename from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java rename to src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java index e59b0028874..4c0f27bd5bb 100644 --- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowCol.java @@ -17,15 +17,16 @@ * under the License. */ -package org.apache.sysds.test.functions.countDistinct; +package org.apache.sysds.test.functions.countDistinctApprox; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowColBase; import org.junit.Test; public class CountDistinctApproxRowCol extends CountDistinctRowColBase { private final static String TEST_NAME = "countDistinctApproxRowCol"; - private final static String TEST_DIR = "functions/countDistinct/"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowCol.class.getSimpleName() + "/"; @Override diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java new file mode 100644 index 00000000000..df532a95d80 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowColParameterized.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.countDistinctApprox; + +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.test.functions.countDistinct.CountDistinctRowColBase; +import org.junit.Test; + +public class CountDistinctApproxRowColParameterized extends CountDistinctRowColBase { + + private final static String TEST_NAME = "countDistinctApproxRowColParameterized"; + private final static String TEST_DIR = "functions/countDistinctApprox/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CountDistinctApproxRowColParameterized.class.getSimpleName() + "/"; + + @Override + public void setUp() { + super.addTestConfiguration(); + super.percentTolerance = 0.2; + } + + @Test + public void testCPSparseLarge() { + ExecType ex = ExecType.CP; + double tolerance = 9000 * percentTolerance; + countDistinctScalarTest(9000, 10000, 5000, 0.1, ex, tolerance); + } + + @Test + public void testSparkSparseLarge() { + ExecType ex = ExecType.SPARK; + double tolerance = 9000 * percentTolerance; + countDistinctScalarTest(9000, 10000, 5000, 0.1, ex, tolerance); + } + + @Test + public void testCPSparseSmall() { + ExecType ex = ExecType.CP; + double tolerance = 9000 * percentTolerance; + countDistinctScalarTest(9000, 999, 999, 0.1, ex, tolerance); + } + + @Test + public void testSparkSparseSmall() { + ExecType ex = ExecType.SPARK; + double tolerance = 9000 * percentTolerance; + countDistinctScalarTest(9000, 999, 999, 0.1, ex, tolerance); + } + + @Test + public void testCPDenseXSmall() { + ExecType ex = ExecType.CP; + double tolerance = 5 * percentTolerance; + countDistinctScalarTest(5, 5, 10, 1.0, ex, tolerance); + } + + @Test + public void testSparkDenseXSmall() { + ExecType ex = ExecType.SPARK; + double tolerance = 5 * percentTolerance; + countDistinctScalarTest(5, 10, 5, 1.0, ex, tolerance); + } + + @Test + public void testCPEmpty() { + ExecType ex = ExecType.CP; + countDistinctScalarTest(1, 0, 0, 0.1, ex, 0); + } + + @Test + public void testSparkEmpty() { + ExecType ex = ExecType.SPARK; + countDistinctScalarTest(1, 0, 0, 0.1, ex, 0); + } + + @Test + public void testCPSingleValue() { + ExecType ex = ExecType.CP; + countDistinctScalarTest(1, 1, 1, 1.0, ex, 0); + } + + @Test + public void testSparkSingleValue() { + ExecType ex = ExecType.SPARK; + countDistinctScalarTest(1, 1, 1, 1.0, ex, 0); + } + + // Corresponding execType=SPARK tests for CP tests in base class + // + @Test + public void testSparkDense1Unique() { + ExecType ex = ExecType.SPARK; + double tolerance = 0.00001; + countDistinctScalarTest(1, 100, 1000, 1.0, ex, tolerance); + } + + @Test + public void testSparkDense2Unique() { + ExecType ex = ExecType.SPARK; + double tolerance = 0.00001; + countDistinctScalarTest(2, 100, 1000, 1.0, ex, tolerance); + } + + @Test + public void testSparkDense120Unique() { + ExecType ex = ExecType.SPARK; + double tolerance = 0.00001 + 120 * percentTolerance; + countDistinctScalarTest(120, 100, 1000, 1.0, ex, tolerance); + } + + @Override + protected String getTestClassDir() { + return TEST_CLASS_DIR; + } + + @Override + protected String getTestName() { + return TEST_NAME; + } + + @Override + protected String getTestDir() { + return TEST_DIR; + } +} diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctCol.dml similarity index 96% rename from src/test/scripts/functions/countDistinct/countDistinct.dml rename to src/test/scripts/functions/countDistinct/countDistinctCol.dml index 3b21bc89f1b..3f2918ee1ea 100644 --- a/src/test/scripts/functions/countDistinct/countDistinct.dml +++ b/src/test/scripts/functions/countDistinct/countDistinctCol.dml @@ -20,5 +20,5 @@ #------------------------------------------------------------- input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) -res = countDistinct(input) +res = countDistinct(input, dir="c") write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml new file mode 100644 index 00000000000..3eeb8ed54ab --- /dev/null +++ b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinctCol(input, dir="c") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctRow.dml b/src/test/scripts/functions/countDistinct/countDistinctRow.dml new file mode 100644 index 00000000000..f8665f6fc75 --- /dev/null +++ b/src/test/scripts/functions/countDistinct/countDistinctRow.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinct(input, dir="r") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml new file mode 100644 index 00000000000..62d7196ce17 --- /dev/null +++ b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinctRow(input, dir="r") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml b/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml new file mode 100644 index 00000000000..7ac9dd53fc9 --- /dev/null +++ b/src/test/scripts/functions/countDistinct/countDistinctRowCol.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinct(input) # default is dir="rc" +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml b/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml new file mode 100644 index 00000000000..9bd22867d0a --- /dev/null +++ b/src/test/scripts/functions/countDistinct/countDistinctRowColParameterized.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinct(input, dir="rc") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxCol.dml similarity index 100% rename from src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml rename to src/test/scripts/functions/countDistinctApprox/countDistinctApproxCol.dml diff --git a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml new file mode 100644 index 00000000000..83a9f5070c0 --- /dev/null +++ b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinctApproxCol(input, dir="c", type="KMV") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRow.dml similarity index 100% rename from src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml rename to src/test/scripts/functions/countDistinctApprox/countDistinctApproxRow.dml diff --git a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml new file mode 100644 index 00000000000..f4be4801568 --- /dev/null +++ b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinctApproxRow(input, dir="r", type="KMV") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml new file mode 100644 index 00000000000..21245ecfbb6 --- /dev/null +++ b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowCol.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, seed = 7)) +res = countDistinctApprox(input, type="KMV") +write(res, $5, format="text") diff --git a/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowColParameterized.dml similarity index 100% rename from src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml rename to src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowColParameterized.dml