From cb47472a8d7458ab33fc441938e753ec025d1ca1 Mon Sep 17 00:00:00 2001 From: arnabp Date: Tue, 22 Mar 2022 22:07:48 +0100 Subject: [PATCH 1/2] [SYSTEMDS-3338] Multi-threaded local Qsort instruction This patch updates the QuantileSort instruction to use a multithreaded sort. This change improves quantile by 2.5x for 100M rows. --- .../java/org/apache/sysds/hops/BinaryOp.java | 14 +++-- .../java/org/apache/sysds/hops/TernaryOp.java | 7 ++- .../java/org/apache/sysds/hops/UnaryOp.java | 20 +++---- .../java/org/apache/sysds/lops/SortKeys.java | 57 ++++++++++++------- .../cp/QuantileSortCPInstruction.java | 39 ++++++++++--- .../runtime/matrix/data/MatrixBlock.java | 6 +- 6 files changed, 99 insertions(+), 44 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 73deda4f95f..1151135ae0e 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -181,7 +181,10 @@ else if(isMatrixScalar || isMatrixMatrix) { public boolean isMultiThreadedOpType() { return !getDataType().isScalar() || getOp() == OpOp2.COV - || getOp() == OpOp2.MOMENT; + || getOp() == OpOp2.MOMENT + || getOp() == OpOp2.IQM + || getOp() == OpOp2.MEDIAN + || getOp() == OpOp2.QUANTILE; } @Override @@ -233,11 +236,12 @@ public Lop constructLops() } private void constructLopsIQM(ExecType et) { + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, - getInput().get(0).getDataType(), getInput().get(0).getValueType(), et); + getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), @@ -256,11 +260,12 @@ private void constructLopsIQM(ExecType et) { } private void constructLopsMedian(ExecType et) { + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, - getInput().get(0).getDataType(), getInput().get(0).getValueType(), et); + getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), @@ -317,10 +322,11 @@ private void constructLopsQuantile(ExecType et) { else pick_op = PickByCount.OperationTypes.RANGEPICK; + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), SortKeys.OperationTypes.WithoutWeights, - DataType.MATRIX, ValueType.FP64, et ); + DataType.MATRIX, ValueType.FP64, et, k ); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java index b7ad4fd3334..a754f1be23c 100644 --- a/src/main/java/org/apache/sysds/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java @@ -149,7 +149,9 @@ public boolean isGPUEnabled() { public boolean isMultiThreadedOpType() { return _op == OpOp3.IFELSE || _op == OpOp3.MINUS_MULT - || _op == OpOp3.PLUS_MULT; + || _op == OpOp3.PLUS_MULT + || _op == OpOp3.QUANTILE + || _op == OpOp3.INTERQUANTILE; } @Override @@ -247,9 +249,10 @@ private void constructLopsQuantile() { throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.QUANTILE + " or " + OpOp3.INTERQUANTILE ); ExecType et = optFindExecType(); + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop(getInput().get(0).constructLops(), getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights, - getInput().get(0).getDataType(), getInput().get(0).getValueType(), et); + getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k); PickByCount pick = new PickByCount(sort, getInput().get(2).constructLops(), getDataType(), getValueType(), (_op == OpOp3.QUANTILE) ? PickByCount.OperationTypes.VALUEPICK : PickByCount.OperationTypes.RANGEPICK, et, true); diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 009d6fff197..25a1202c5fa 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -197,12 +197,11 @@ else if(_op == OpOp1.MEDIAN) { private Lop constructLopsMedian() { ExecType et = optFindExecType(); - - + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); SortKeys sort = SortKeys.constructSortByValueLop( getInput().get(0).constructLops(), SortKeys.OperationTypes.WithoutWeights, - DataType.MATRIX, ValueType.FP64, et ); + DataType.MATRIX, ValueType.FP64, et, k ); sort.getOutputParameters().setDimensions( getInput().get(0).getDim1(), getInput().get(0).getDim2(), @@ -225,14 +224,13 @@ private Lop constructLopsMedian() private Lop constructLopsIQM() { - ExecType et = optFindExecType(); - + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); Hop input = getInput().get(0); - SortKeys sort = SortKeys.constructSortByValueLop( - input.constructLops(), - SortKeys.OperationTypes.WithoutWeights, - DataType.MATRIX, ValueType.FP64, et ); + SortKeys sort = SortKeys.constructSortByValueLop( + input.constructLops(), + SortKeys.OperationTypes.WithoutWeights, + DataType.MATRIX, ValueType.FP64, et, k ); sort.getOutputParameters().setDimensions( input.getDim1(), input.getDim2(), @@ -456,7 +454,9 @@ public boolean isExpensiveUnaryOperation() { || _op == OpOp1.LOG || _op == OpOp1.SIGMOID || _op == OpOp1.COMPRESS - || _op == OpOp1.DECOMPRESS); + || _op == OpOp1.DECOMPRESS + || _op == OpOp1.MEDIAN + || _op == OpOp1.IQM); } public boolean isMetadataOperation() { diff --git a/src/main/java/org/apache/sysds/lops/SortKeys.java b/src/main/java/org/apache/sysds/lops/SortKeys.java index 01c1594195d..f7a0f823884 100644 --- a/src/main/java/org/apache/sysds/lops/SortKeys.java +++ b/src/main/java/org/apache/sysds/lops/SortKeys.java @@ -39,31 +39,34 @@ public enum OperationTypes { } private OperationTypes operation; - + + private int _numThreads; + public OperationTypes getOpType() { return operation; } - public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et) { + public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads) { super(Lop.Type.SortKeys, dt, vt); - init(input, null, op, et); + init(input, null, op, et, numThreads); } public SortKeys(Lop input, boolean desc, OperationTypes op, DataType dt, ValueType vt, ExecType et) { super(Lop.Type.SortKeys, dt, vt); - init(input, null, op, et); + init(input, null, op, et, 1); } - public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et) { - super(Lop.Type.SortKeys, dt, vt); - init(input1, input2, op, et); + public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads) { + super(Lop.Type.SortKeys, dt, vt); + init(input1, input2, op, et, numThreads); } - private void init(Lop input1, Lop input2, OperationTypes op, ExecType et) { + private void init(Lop input1, Lop input2, OperationTypes op, ExecType et, int numThreads) { addInput(input1); input1.addOutput(this); operation = op; + _numThreads = numThreads; // SortKeys can accept a optional second input only when executing in CP // Example: sorting with weights inside CP @@ -82,43 +85,57 @@ public String toString() { @Override public String getInstructions(String input, String output) { - return InstructionUtils.concatOperands( + StringBuilder sb = new StringBuilder(); + sb.append(InstructionUtils.concatOperands( getExecType().name(), OPCODE, getInputs().get(0).prepInputOperand(input), - prepOutputOperand(output)); + prepOutputOperand(output))); + + if( getExecType() == ExecType.CP ) { + sb.append( OPERAND_DELIMITOR ); + sb.append(_numThreads); + } + return sb.toString(); } @Override public String getInstructions(String input1, String input2, String output) { - return InstructionUtils.concatOperands( + StringBuilder sb = new StringBuilder(); + sb.append(InstructionUtils.concatOperands( getExecType().name(), OPCODE, getInputs().get(0).prepInputOperand(input1), getInputs().get(1).prepInputOperand(input2), - prepOutputOperand(output)); + prepOutputOperand(output))); + + if( getExecType() == ExecType.CP ) { + sb.append( OPERAND_DELIMITOR ); + sb.append(_numThreads); + } + return sb.toString(); } - + // This method is invoked in two cases: // 1) SortKeys (both weighted and unweighted) executes in MR // 2) Unweighted SortKeys executes in CP - public static SortKeys constructSortByValueLop(Lop input1, OperationTypes op, - DataType dt, ValueType vt, ExecType et) { - + public static SortKeys constructSortByValueLop(Lop input1, OperationTypes op, + DataType dt, ValueType vt, ExecType et, int numThreads) { + for (Lop lop : input1.getOutputs()) { if ( lop.type == Lop.Type.SortKeys ) { return (SortKeys)lop; } } - - SortKeys retVal = new SortKeys(input1, op, dt, vt, et); + + SortKeys retVal = new SortKeys(input1, op, dt, vt, et, numThreads); retVal.setAllPositions(input1.getFilename(), input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), input1.getEndColumn()); return retVal; } // This method is invoked ONLY for the case of Weighted SortKeys executing in CP public static SortKeys constructSortByValueLop(Lop input1, Lop input2, OperationTypes op, - DataType dt, ValueType vt, ExecType et) { + DataType dt, ValueType vt, ExecType et, int numThreads) { HashSet set1 = new HashSet<>(); set1.addAll(input1.getOutputs()); @@ -131,7 +148,7 @@ public static SortKeys constructSortByValueLop(Lop input1, Lop input2, Operation } } - SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et); + SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et, numThreads); retVal.setAllPositions(input1.getFilename(), input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), input1.getEndColumn()); return retVal; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java index 3e953d22cd0..78f81a8b1dd 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/QuantileSortCPInstruction.java @@ -36,9 +36,11 @@ * */ public class QuantileSortCPInstruction extends UnaryCPInstruction { + int _numThreads; - private QuantileSortCPInstruction(CPOperand in, CPOperand out, String opcode, String istr) { + private QuantileSortCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, int k) { this(in, null, out, opcode, istr); + _numThreads = k; } private QuantileSortCPInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, @@ -46,6 +48,26 @@ private QuantileSortCPInstruction(CPOperand in1, CPOperand in2, CPOperand out, S super(CPType.QSort, null, in1, in2, out, opcode, istr); } + private static void parseInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr); + + String opcode = parts[0]; + out.split(parts[parts.length-2]); + + switch(parts.length) { + case 4: + in1.split(parts[1]); + in2 = null; + break; + case 5: + in1.split(parts[1]); + in2.split(parts[2]); + break; + default: + throw new DMLRuntimeException("Unexpected number of operands in the instruction: " + instr); + } + } + public static QuantileSortCPInstruction parseInstruction ( String str ) { CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); CPOperand in2 = null; @@ -55,15 +77,18 @@ public static QuantileSortCPInstruction parseInstruction ( String str ) { String opcode = parts[0]; if ( opcode.equalsIgnoreCase(SortKeys.OPCODE) ) { - if ( parts.length == 3 ) { + int k = Integer.parseInt(parts[parts.length-1]); //#threads + if ( parts.length == 4 ) { // Example: sort:mVar1:mVar2 (input=mVar1, output=mVar2) - parseUnaryInstruction(str, in1, out); - return new QuantileSortCPInstruction(in1, out, opcode, str); + InstructionUtils.checkNumFields(str, 3); + parseInstruction(str, in1, null, out); + return new QuantileSortCPInstruction(in1, out, opcode, str, k); } - else if ( parts.length == 4 ) { + else if ( parts.length == 5 ) { // Example: sort:mVar1:mVar2:mVar3 (input=mVar1, weights=mVar2, output=mVar3) + InstructionUtils.checkNumFields(str, 4); in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - parseUnaryInstruction(str, in1, in2, out); + parseInstruction(str, in1, in2, out); return new QuantileSortCPInstruction(in1, in2, out, opcode, str); } else { @@ -85,7 +110,7 @@ public void processInstruction(ExecutionContext ec) { } //process core instruction - MatrixBlock resultBlock = matBlock.sortOperations(wtBlock, new MatrixBlock()); + MatrixBlock resultBlock = matBlock.sortOperations(wtBlock, new MatrixBlock(), _numThreads); //release inputs ec.releaseMatrixInput(input1.getName()); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index eb152ab5834..5df5d62e724 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -4820,6 +4820,10 @@ public final MatrixBlock sortOperations(MatrixValue weights){ } public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { + return sortOperations(weights, result, 1); + } + + public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result, int k) { boolean wtflag = (weights!=null); MatrixBlock wts= (weights == null ? null : checkType(weights)); @@ -4877,7 +4881,7 @@ public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { // Sort td and tw based on values inside td (ascending sort), incl copy into result SortIndex sfn = new SortIndex(1, false, false); - ReorgOperator rop = new ReorgOperator(sfn); + ReorgOperator rop = new ReorgOperator(sfn, k); LibMatrixReorg.reorg(tdw, result, rop); return result; From 5c880a18e5b3f3af37caff0aeb2e6e0d1e941c2b Mon Sep 17 00:00:00 2001 From: arnabp Date: Wed, 23 Mar 2022 14:35:50 +0100 Subject: [PATCH 2/2] [SYSTEMDS-3338] Multi-threaded local Qsort instruction Fix instruction string parsing for federated local qsort. --- .../instructions/InstructionUtils.java | 10 +++- .../cp/QuantileSortCPInstruction.java | 9 ++- .../fed/QuantileSortFEDInstruction.java | 55 ++++++++++++++----- 3 files changed, 55 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index f22fdfe5506..39fbef23077 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -225,7 +225,15 @@ public static String[] getInstructionPartsWithValueType( String str ) { return ret; } - + + public static String stripThreadCount(String str) { + String[] parts = str.split(Instruction.OPERAND_DELIM, -1); + String[] ret = new String[parts.length-1]; + for (int i=0; i 1 ? InstructionUtils.stripThreadCount(instString) : instString; + newInst = InstructionUtils.replaceOperand(newInst, 1, "append"); newInst = InstructionUtils.concatOperands(newInst, "true"); FederatedRequest[] fr1 = in.getFedMapping().broadcastSliced(weights, false); FederatedRequest fr2 = FederationUtils.callInstruction(newInst, output, @@ -123,7 +150,7 @@ public void processColumnQSort(ExecutionContext ec) { FederatedResponse response = data .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, - new GetSorted(data.getVarID(), varID, wtBlock))).get(); + new GetSorted(data.getVarID(), varID, wtBlock, _numThreads))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); } @@ -145,17 +172,19 @@ private static class GetSorted extends FederatedUDF { private static final long serialVersionUID = -1969015577260167645L; private final long _outputID; private final MatrixBlock _weights; + private final int _numThreads; - protected GetSorted(long input, long outputID, MatrixBlock weights) { + protected GetSorted(long input, long outputID, MatrixBlock weights, int k) { super(new long[] {input}); _outputID = outputID; _weights = weights; + _numThreads = k; } @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease(); - MatrixBlock res = mb.sortOperations(_weights, new MatrixBlock()); + MatrixBlock res = mb.sortOperations(_weights, new MatrixBlock(), _numThreads); MatrixObject mout = ExecutionContext.createMatrixObject(res);