Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/main/java/org/apache/sysds/hops/BinaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ else if(isMatrixScalar || isMatrixMatrix) {
public boolean isMultiThreadedOpType() {
return !getDataType().isScalar()
|| getOp() == OpOp2.COV
|| getOp() == OpOp2.MOMENT;
|| getOp() == OpOp2.MOMENT
|| getOp() == OpOp2.IQM
|| getOp() == OpOp2.MEDIAN
|| getOp() == OpOp2.QUANTILE;
}

@Override
Expand Down Expand Up @@ -233,11 +236,12 @@ public Lop constructLops()
}

private void constructLopsIQM(ExecType et) {
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
SortKeys sort = SortKeys.constructSortByValueLop(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
SortKeys.OperationTypes.WithWeights,
getInput().get(0).getDataType(), getInput().get(0).getValueType(), et);
getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k);
sort.getOutputParameters().setDimensions(
getInput().get(0).getDim1(),
getInput().get(0).getDim2(),
Expand All @@ -256,11 +260,12 @@ private void constructLopsIQM(ExecType et) {
}

private void constructLopsMedian(ExecType et) {
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
SortKeys sort = SortKeys.constructSortByValueLop(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
SortKeys.OperationTypes.WithWeights,
getInput().get(0).getDataType(), getInput().get(0).getValueType(), et);
getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k);
sort.getOutputParameters().setDimensions(
getInput().get(0).getDim1(),
getInput().get(0).getDim2(),
Expand Down Expand Up @@ -317,10 +322,11 @@ private void constructLopsQuantile(ExecType et) {
else
pick_op = PickByCount.OperationTypes.RANGEPICK;

int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
SortKeys sort = SortKeys.constructSortByValueLop(
getInput().get(0).constructLops(),
SortKeys.OperationTypes.WithoutWeights,
DataType.MATRIX, ValueType.FP64, et );
DataType.MATRIX, ValueType.FP64, et, k );
sort.getOutputParameters().setDimensions(
getInput().get(0).getDim1(),
getInput().get(0).getDim2(),
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/org/apache/sysds/hops/TernaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ public boolean isGPUEnabled() {
public boolean isMultiThreadedOpType() {
return _op == OpOp3.IFELSE
|| _op == OpOp3.MINUS_MULT
|| _op == OpOp3.PLUS_MULT;
|| _op == OpOp3.PLUS_MULT
|| _op == OpOp3.QUANTILE
|| _op == OpOp3.INTERQUANTILE;
}

@Override
Expand Down Expand Up @@ -247,9 +249,10 @@ private void constructLopsQuantile() {
throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.QUANTILE + " or " + OpOp3.INTERQUANTILE );

ExecType et = optFindExecType();
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
SortKeys sort = SortKeys.constructSortByValueLop(getInput().get(0).constructLops(),
getInput().get(1).constructLops(), SortKeys.OperationTypes.WithWeights,
getInput().get(0).getDataType(), getInput().get(0).getValueType(), et);
getInput().get(0).getDataType(), getInput().get(0).getValueType(), et, k);
PickByCount pick = new PickByCount(sort, getInput().get(2).constructLops(),
getDataType(), getValueType(), (_op == OpOp3.QUANTILE) ?
PickByCount.OperationTypes.VALUEPICK : PickByCount.OperationTypes.RANGEPICK, et, true);
Expand Down
20 changes: 10 additions & 10 deletions src/main/java/org/apache/sysds/hops/UnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,11 @@ else if(_op == OpOp1.MEDIAN) {
private Lop constructLopsMedian()
{
ExecType et = optFindExecType();


int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
SortKeys sort = SortKeys.constructSortByValueLop(
getInput().get(0).constructLops(),
SortKeys.OperationTypes.WithoutWeights,
DataType.MATRIX, ValueType.FP64, et );
DataType.MATRIX, ValueType.FP64, et, k );
sort.getOutputParameters().setDimensions(
getInput().get(0).getDim1(),
getInput().get(0).getDim2(),
Expand All @@ -225,14 +224,13 @@ private Lop constructLopsMedian()

private Lop constructLopsIQM()
{

ExecType et = optFindExecType();

int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
Hop input = getInput().get(0);
SortKeys sort = SortKeys.constructSortByValueLop(
input.constructLops(),
SortKeys.OperationTypes.WithoutWeights,
DataType.MATRIX, ValueType.FP64, et );
SortKeys sort = SortKeys.constructSortByValueLop(
input.constructLops(),
SortKeys.OperationTypes.WithoutWeights,
DataType.MATRIX, ValueType.FP64, et, k );
sort.getOutputParameters().setDimensions(
input.getDim1(),
input.getDim2(),
Expand Down Expand Up @@ -456,7 +454,9 @@ public boolean isExpensiveUnaryOperation() {
|| _op == OpOp1.LOG
|| _op == OpOp1.SIGMOID
|| _op == OpOp1.COMPRESS
|| _op == OpOp1.DECOMPRESS);
|| _op == OpOp1.DECOMPRESS
|| _op == OpOp1.MEDIAN
|| _op == OpOp1.IQM);
}

public boolean isMetadataOperation() {
Expand Down
57 changes: 37 additions & 20 deletions src/main/java/org/apache/sysds/lops/SortKeys.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,34 @@ public enum OperationTypes {
}

private OperationTypes operation;


private int _numThreads;

public OperationTypes getOpType() {
return operation;
}

public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
public SortKeys(Lop input, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads) {
super(Lop.Type.SortKeys, dt, vt);
init(input, null, op, et);
init(input, null, op, et, numThreads);
}

public SortKeys(Lop input, boolean desc, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.SortKeys, dt, vt);
init(input, null, op, et);
init(input, null, op, et, 1);
}

public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.SortKeys, dt, vt);
init(input1, input2, op, et);
public SortKeys(Lop input1, Lop input2, OperationTypes op, DataType dt, ValueType vt, ExecType et, int numThreads) {
super(Lop.Type.SortKeys, dt, vt);
init(input1, input2, op, et, numThreads);
}

private void init(Lop input1, Lop input2, OperationTypes op, ExecType et) {
private void init(Lop input1, Lop input2, OperationTypes op, ExecType et, int numThreads) {
addInput(input1);
input1.addOutput(this);

operation = op;
_numThreads = numThreads;

// SortKeys can accept a optional second input only when executing in CP
// Example: sorting with weights inside CP
Expand All @@ -82,43 +85,57 @@ public String toString() {

@Override
public String getInstructions(String input, String output) {
return InstructionUtils.concatOperands(
StringBuilder sb = new StringBuilder();
sb.append(InstructionUtils.concatOperands(
getExecType().name(),
OPCODE,
getInputs().get(0).prepInputOperand(input),
prepOutputOperand(output));
prepOutputOperand(output)));

if( getExecType() == ExecType.CP ) {
sb.append( OPERAND_DELIMITOR );
sb.append(_numThreads);
}
return sb.toString();
}

@Override
public String getInstructions(String input1, String input2, String output) {
return InstructionUtils.concatOperands(
StringBuilder sb = new StringBuilder();
sb.append(InstructionUtils.concatOperands(
getExecType().name(),
OPCODE,
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
prepOutputOperand(output));
prepOutputOperand(output)));

if( getExecType() == ExecType.CP ) {
sb.append( OPERAND_DELIMITOR );
sb.append(_numThreads);
}
return sb.toString();
}

// This method is invoked in two cases:
// 1) SortKeys (both weighted and unweighted) executes in MR
// 2) Unweighted SortKeys executes in CP
public static SortKeys constructSortByValueLop(Lop input1, OperationTypes op,
DataType dt, ValueType vt, ExecType et) {
public static SortKeys constructSortByValueLop(Lop input1, OperationTypes op,
DataType dt, ValueType vt, ExecType et, int numThreads) {

for (Lop lop : input1.getOutputs()) {
if ( lop.type == Lop.Type.SortKeys ) {
return (SortKeys)lop;
}
}
SortKeys retVal = new SortKeys(input1, op, dt, vt, et);

SortKeys retVal = new SortKeys(input1, op, dt, vt, et, numThreads);
retVal.setAllPositions(input1.getFilename(), input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), input1.getEndColumn());
return retVal;
}

// This method is invoked ONLY for the case of Weighted SortKeys executing in CP
public static SortKeys constructSortByValueLop(Lop input1, Lop input2, OperationTypes op,
DataType dt, ValueType vt, ExecType et) {
DataType dt, ValueType vt, ExecType et, int numThreads) {

HashSet<Lop> set1 = new HashSet<>();
set1.addAll(input1.getOutputs());
Expand All @@ -131,7 +148,7 @@ public static SortKeys constructSortByValueLop(Lop input1, Lop input2, Operation
}
}

SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et);
SortKeys retVal = new SortKeys(input1, input2, op, dt, vt, et, numThreads);
retVal.setAllPositions(input1.getFilename(), input1.getBeginLine(), input1.getBeginColumn(), input1.getEndLine(), input1.getEndColumn());
return retVal;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,15 @@ public static String[] getInstructionPartsWithValueType( String str ) {

return ret;
}


public static String stripThreadCount(String str) {
String[] parts = str.split(Instruction.OPERAND_DELIM, -1);
String[] ret = new String[parts.length-1];
for (int i=0; i<parts.length-1; i++) //strip-off the thread count
ret[i] = parts[i];
return concatOperands(ret);
}

public static ExecType getExecType( String str ) {
try{
int ix = str.indexOf(Instruction.OPERAND_DELIM);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,35 @@
*
*/
public class QuantileSortCPInstruction extends UnaryCPInstruction {
int _numThreads;

private QuantileSortCPInstruction(CPOperand in, CPOperand out, String opcode, String istr) {
this(in, null, out, opcode, istr);
private QuantileSortCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, int k) {
this(in, null, out, opcode, istr, k);
}

private QuantileSortCPInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
String istr, int k) {
super(CPType.QSort, null, in1, in2, out, opcode, istr);
_numThreads = k;
}

private static void parseInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);

out.split(parts[parts.length-2]);

switch(parts.length) {
case 4:
in1.split(parts[1]);
in2 = null;
break;
case 5:
in1.split(parts[1]);
in2.split(parts[2]);
break;
default:
throw new DMLRuntimeException("Unexpected number of operands in the instruction: " + instr);
}
}

public static QuantileSortCPInstruction parseInstruction ( String str ) {
Expand All @@ -55,16 +76,19 @@ public static QuantileSortCPInstruction parseInstruction ( String str ) {
String opcode = parts[0];

if ( opcode.equalsIgnoreCase(SortKeys.OPCODE) ) {
if ( parts.length == 3 ) {
int k = Integer.parseInt(parts[parts.length-1]); //#threads
if ( parts.length == 4 ) {
// Example: sort:mVar1:mVar2 (input=mVar1, output=mVar2)
parseUnaryInstruction(str, in1, out);
return new QuantileSortCPInstruction(in1, out, opcode, str);
InstructionUtils.checkNumFields(str, 3);
parseInstruction(str, in1, null, out);
return new QuantileSortCPInstruction(in1, out, opcode, str, k);
}
else if ( parts.length == 4 ) {
else if ( parts.length == 5 ) {
// Example: sort:mVar1:mVar2:mVar3 (input=mVar1, weights=mVar2, output=mVar3)
InstructionUtils.checkNumFields(str, 4);
in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
parseUnaryInstruction(str, in1, in2, out);
return new QuantileSortCPInstruction(in1, in2, out, opcode, str);
parseInstruction(str, in1, in2, out);
return new QuantileSortCPInstruction(in1, in2, out, opcode, str, k);
}
else {
throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
Expand All @@ -85,7 +109,7 @@ public void processInstruction(ExecutionContext ec) {
}

//process core instruction
MatrixBlock resultBlock = matBlock.sortOperations(wtBlock, new MatrixBlock());
MatrixBlock resultBlock = matBlock.sortOperations(wtBlock, new MatrixBlock(), _numThreads);

//release inputs
ec.releaseMatrixInput(input1.getName());
Expand Down
Loading