diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 73deda4f95f..1151135ae0e 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -181,7 +181,10 @@ else if(isMatrixScalar || isMatrixMatrix) { public boolean isMultiThreadedOpType() { return !getDataType().isScalar() || getOp() == OpOp2.COV - || getOp() == OpOp2.MOMENT; + || getOp() == OpOp2.MOMENT + || getOp() == OpOp2.IQM + || getOp() == OpOp2.MEDIAN + || getOp() == OpOp2.QUANTILE; } @Override @@ -233,11 +236,12 @@ public Lop constructLops() } private void constructLopsIQM(ExecType et) { + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, - getInput().get(0).getDataType(), getInput().get(0).getValueType(), et); + getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), @@ -256,11 +260,12 @@ private void constructLopsIQM(ExecType et) { } private void constructLopsMedian(ExecType et) { + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, - getInput().get(0).getDataType(), getInput().get(0).getValueType(), et); + getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), @@ -317,10 +322,11 @@ private void constructLopsQuantile(ExecType et) { else pick_op = PickByCount.OperationTypes.RANGEPICK; + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), SortKeys.OperationTypes.WithoutWeights, - DataType.MATRIX, ValueType.FP64, et ); + DataType.MATRIX, ValueType.FP64, et, k ); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java index b7ad4fd3334..a754f1be23c 100644 --- a/src/main/java/org/apache/sysds/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java @@ -149,7 +149,9 @@ public boolean isGPUEnabled() { public boolean isMultiThreadedOpType() { return _op == OpOp3.IFELSE || _op == OpOp3.MINUS_MULT - || _op == OpOp3.PLUS_MULT; + || _op == OpOp3.PLUS_MULT + || _op == OpOp3.QUANTILE + || _op == OpOp3.INTERQUANTILE; } @Override @@ -247,9 +249,10 @@ private void constructLopsQuantile() { throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.QUANTILE + " or " + OpOp3.INTERQUANTILE ); ExecType et = optFindExecType(); + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop(getInput().get(0).constructLops(), getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, - getInput().get(0).getDataType(), getInput().get(0).getValueType(), et); + getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k); PickByCount pick = new PickByCount(sort, getInput().get(2).constructLops(), getDataType(), getValueType(), (_op == OpOp3.QUANTILE) ? PickByCount.OperationTypes.VALUEPICK : PickByCount.OperationTypes.RANGEPICK, et, true); diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 009d6fff197..25a1202c5fa 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -197,12 +197,11 @@ else if(_op == OpOp1.MEDIAN) { private Lop constructLopsMedian() { ExecType et = optFindExecType(); - - + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), SortKeys.OperationTypes.WithoutWeights, - DataType.MATRIX, ValueType.FP64, et ); + DataType.MATRIX, ValueType.FP64, et, k ); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), @@ -225,14 +224,13 @@ private Lop constructLopsMedian() private Lop constructLopsIQM() { - ExecType et = optFindExecType(); - + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); Hop input = getInput().get(0); - SortKeys sort = SortKeys.constructSortByValueLop( - input.constructLops(), - SortKeys.OperationTypes.WithoutWeights, - DataType.MATRIX, ValueType.FP64, et ); + SortKeys sort = SortKeys.constructSortByValueLop( + input.constructLops(), + SortKeys.OperationTypes.WithoutWeights, + DataType.MATRIX, ValueType.FP64, et, k ); sort.getOutputParameters().setDimensions( input.getDim1(), input.getDim2(), @@ -456,7 +454,9 @@ public boolean isExpensiveUnaryOperation() { || _op == OpOp1.LOG || _op == OpOp1.SIGMOID || _op == OpOp1.COMPRESS - || _op == OpOp1.DECOMPRESS); + || _op == OpOp1.DECOMPRESS + || _op == OpOp1.MEDIAN + || _op == OpOp1.IQM); } public boolean isMetadataOperation() { diff --git a/src/main/java/org/apache/sysds/lops/SortKeys.java b/src/main/java/org/apache/sysds/lops/SortKeys.java index 01c1594195d..f7a0f823884 100644 --- a/src/main/java/org/apache/sysds/lops/SortKeys.java +++ b/src/main/java/org/apache/sysds/lops/SortKeys.java @@ -39,31 +39,34 @@ public enum OperationTypes { } private OperationTypes operation; - + + private int _numThreads; + public OperationTypes getOpType() { return operation; } - public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et) { + public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads) { super(Lop.Type.SortKeys, dt, vt); - init(input, null, op, et); + init(input, null, op, et, numThreads); } public SortKeys(Lop input, boolean desc, OperationTypes op, DataType dt, ValueType vt, ExecType et) { super(Lop.Type.SortKeys, dt, vt); - init(input, null, op, et); + init(input, null, op, et, 1); } - public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et) { - super(Lop.Type.SortKeys, dt, vt); - init(input1, input2, op, et); + public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads) { + super(Lop.Type.SortKeys, dt, vt); + init(input1, input2, op, et, numThreads); } - private void init(Lop input1, Lop input2, OperationTypes op, ExecType et) { + private void init(Lop input1, Lop input2, OperationTypes op, ExecType et, int numThreads) { addInput(input1); input1.addOutput(this); operation = op; + _numThreads = numThreads; // SortKeys can accept a optional second input only when executing in CP // Example: sorting with weights inside CP @@ -82,43 +85,57 @@ public String toString() { @Override public String getInstructions(String input, String output) { - return InstructionUtils.concatOperands( + StringBuilder sb = new StringBuilder(); + sb.append(InstructionUtils.concatOperands( getExecType().name(), OPCODE, getInputs().get(0).prepInputOperand(input), - prepOutputOperand(output)); + prepOutputOperand(output))); + + if( getExecType() == ExecType.CP ) { + sb.append( OPERAND_DELIMITOR ); + sb.append(_numThreads); + } + return sb.toString(); } @Override public String getInstructions(String input1, String input2, String output) { - return InstructionUtils.concatOperands( + StringBuilder sb = new StringBuilder(); + sb.append(InstructionUtils.concatOperands( getExecType().name(), OPCODE, getInputs().get(0).prepInputOperand(input1), getInputs().get(1).prepInputOperand(input2), - prepOutputOperand(output)); + prepOutputOperand(output))); + + if( getExecType() == ExecType.CP ) { + sb.append( OPERAND_DELIMITOR ); + sb.append(_numThreads); + } + return sb.toString(); } - + // This method is invoked in two cases: // 1) SortKeys (both weighted and unweighted) executes in MR // 2) Unweighted SortKeys executes in CP - public static SortKeys constructSortByValueLop(Lop input1, OperationTypes op, - DataType dt, ValueType vt, ExecType et) { - + public static SortKeys constructSortByValueLop(Lop input1, OperationTypes op, + DataType dt, ValueType vt, ExecType et, int numThreads) { + for (Lop lop : input1.getOutputs()) { if ( lop.type == Lop.Type.SortKeys ) { return (SortKeys)lop; } } - - SortKeys retVal = new SortKeys(input1, op, dt, vt, et); + + SortKeys retVal = new SortKeys(input1, op, dt, vt, et, numThreads); retVal.setAllPositions(input1.getFilename(), input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), input1.getEndColumn()); return retVal; } // This method is invoked ONLY for the case of Weighted SortKeys executing in CP public static SortKeys constructSortByValueLop(Lop input1, Lop input2, OperationTypes op, - DataType dt, ValueType vt, ExecType et) { + DataType dt, ValueType vt, ExecType et, int numThreads) { HashSet set1 = new HashSet<>(); set1.addAll(input1.getOutputs()); @@ -131,7 +148,7 @@ public static SortKeys constructSortByValueLop(Lop input1, Lop input2, Operation } } - SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et); + SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et, numThreads); retVal.setAllPositions(input1.getFilename(), input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), input1.getEndColumn()); return retVal; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index f22fdfe5506..39fbef23077 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -225,7 +225,15 @@ public static String[] getInstructionPartsWithValueType( String str ) { return ret; } - + + public static String stripThreadCount(String str) { + String[] parts = str.split(Instruction.OPERAND_DELIM, -1); + String[] ret = new String[parts.length-1]; + for (int i=0; i 1 ? InstructionUtils.stripThreadCount(instString) : instString; + newInst = InstructionUtils.replaceOperand(newInst, 1, "append"); newInst = InstructionUtils.concatOperands(newInst, "true"); FederatedRequest[] fr1 = in.getFedMapping().broadcastSliced(weights, false); FederatedRequest fr2 = FederationUtils.callInstruction(newInst, output, @@ -123,7 +150,7 @@ public void processColumnQSort(ExecutionContext ec) { FederatedResponse response = data .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new GetSorted(data.getVarID(), varID, wtBlock))).get(); + new GetSorted(data.getVarID(), varID, wtBlock, _numThreads))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); } @@ -145,17 +172,19 @@ private static class GetSorted extends FederatedUDF { private static final long serialVersionUID = -1969015577260167645L; private final long _outputID; private final MatrixBlock _weights; + private final int _numThreads; - protected GetSorted(long input, long outputID, MatrixBlock weights) { + protected GetSorted(long input, long outputID, MatrixBlock weights, int k) { super(new long[] {input}); _outputID = outputID; _weights = weights; + _numThreads = k; } @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); - MatrixBlock res = mb.sortOperations(_weights, new MatrixBlock()); + MatrixBlock res = mb.sortOperations(_weights, new MatrixBlock(), _numThreads); MatrixObject mout = ExecutionContext.createMatrixObject(res); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index eb152ab5834..5df5d62e724 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -4820,6 +4820,10 @@ public final MatrixBlock sortOperations(MatrixValue weights){ } public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { + return sortOperations(weights, result, 1); + } + + public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result, int k) { boolean wtflag = (weights!=null); MatrixBlock wts= (weights == null ? null : checkType(weights)); @@ -4877,7 +4881,7 @@ public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { // Sort td and tw based on values inside td (ascending sort), incl copy into result SortIndex sfn = new SortIndex(1, false, false); - ReorgOperator rop = new ReorgOperator(sfn); + ReorgOperator rop = new ReorgOperator(sfn, k); LibMatrixReorg.reorg(tdw, result, rop); return result;