Skip to content

Commit

Permalink
[SYSTEMDS-3234] Multi-threaded covariance/central-moment operations
Browse files Browse the repository at this point in the history
Inspired by performance issues in SYSTEMDS-3233, this patch
introduces multi-threaded cov/cm operations which were still
single-threaded. These operations are mostly executed in parfor
contexts, but if large memory requirements force a lower degree of
parallelism in parfor, we should distribute the remaining
parallelism to intra-operation parallelism like many other ops.

Furthermore, this patch also cleans up the instruction construction,
parsing, and core cov/cm operations in order to share a common code
path in LibMatrixAgg.

On the scenario of SYSTEMDS-3233 this patch improved end-to-end
performance from 261s to 144s eliminating cov/cm as top-2 heavy
hitters (now right indexing due to column indexing on sparse matrix).
On 100M row (800MB) input vectors and 100 operations, the total
runtime improved as follows (server with 32 vcores):
* 100x cov(100M, 100M): 105s -> 7.7s (13.6x)
* 100x cm(100M):        109s -> 9.1s (12x)
  • Loading branch information
mboehm7 committed Dec 18, 2021
1 parent 5cc5239 commit 2aad571
Show file tree
Hide file tree
Showing 14 changed files with 289 additions and 254 deletions.
24 changes: 13 additions & 11 deletions src/main/java/org/apache/sysds/hops/BinaryOp.java
Expand Up @@ -59,7 +59,7 @@
* Semantic: align indices (sort), then perform operation
*/

public class BinaryOp extends MultiThreadedHop{
public class BinaryOp extends MultiThreadedHop {
// private static final Log LOG = LogFactory.getLog(BinaryOp.class.getName());

//we use the full remote memory budget (but reduced by sort buffer),
Expand Down Expand Up @@ -179,7 +179,9 @@ else if(isMatrixScalar || isMatrixMatrix) {

@Override
public boolean isMultiThreadedOpType() {
return !getDataType().isScalar();
return !getDataType().isScalar()
|| getOp() == OpOp2.COV
|| getOp() == OpOp2.MOMENT;
}

@Override
Expand Down Expand Up @@ -279,26 +281,26 @@ private void constructLopsMedian(ExecType et) {
setLops(pick);
}

private void constructLopsCentralMoment(ExecType et)
{
private void constructLopsCentralMoment(ExecType et) {
// The output data type is a SCALAR if central moment
// gets computed in CP/SPARK, and it will be MATRIX otherwise.
DataType dt = DataType.SCALAR;
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
CentralMoment cm = new CentralMoment(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
dt, getValueType(), et);

getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
dt, getValueType(), k, et);
setLineNumbers(cm);
cm.getOutputParameters().setDimensions(0, 0, 0, -1);
setLops(cm);
}

private void constructLopsCovariance(ExecType et) {
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
CoVariance cov = new CoVariance(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getDataType(), getValueType(), et);
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getDataType(), getValueType(), k, et);
cov.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(cov);
setLops(cov);
Expand Down
24 changes: 11 additions & 13 deletions src/main/java/org/apache/sysds/hops/TernaryOp.java
Expand Up @@ -203,18 +203,17 @@ public Lop constructLops()
/**
* Method to construct LOPs when op = CENTRALMOMENT.
*/
private void constructLopsCentralMoment()
{
private void constructLopsCentralMoment() {
if ( _op != OpOp3.MOMENT )
throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.MOMENT );

ExecType et = optFindExecType();

int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
CentralMoment cm = new CentralMoment(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getDataType(), getValueType(), et);
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getDataType(), getValueType(), k, et);
cm.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(cm);
setLops(cm);
Expand All @@ -228,13 +227,12 @@ private void constructLopsCovariance() {
throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.COV );

ExecType et = optFindExecType();


int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
CoVariance cov = new CoVariance(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getDataType(), getValueType(), et);
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getDataType(), getValueType(), k, et);
cov.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(cov);
setLops(cov);
Expand Down
37 changes: 23 additions & 14 deletions src/main/java/org/apache/sysds/lops/CentralMoment.java
Expand Up @@ -30,6 +30,18 @@
*/
public class CentralMoment extends Lop
{
private final int _numThreads;

/**
 * Creates a central-moment lop with two inputs by delegating to the
 * three-input constructor with a {@code null} third input.
 *
 * @param input1 first input lop
 * @param input2 second input lop
 * @param dt output data type
 * @param vt output value type
 * @param numThreads thread count stored on this lop
 * @param et execution type
 */
public CentralMoment(Lop input1, Lop input2, DataType dt, ValueType vt, int numThreads, ExecType et) {
this(input1, input2, null, dt, vt, numThreads, et);
}

/**
 * Creates a central-moment lop with up to three inputs.
 *
 * @param input1 first input lop
 * @param input2 second input lop
 * @param input3 third input lop; {@code null} when called through the
 *               two-input constructor
 * @param dt output data type
 * @param vt output value type
 * @param numThreads thread count stored in {@code _numThreads}
 * @param et execution type, forwarded to {@code init}
 */
public CentralMoment(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, int numThreads, ExecType et) {
super(Lop.Type.CentralMoment, dt, vt);
// NOTE(review): init's body is only partly visible here; it ends by calling
// lps.setProperties(inputs, et) -- confirm it does not depend on _numThreads.
init(input1, input2, input3, et);
_numThreads = numThreads;
}

/**
* Constructor to perform central moment.
* input1 <- data (weighted or unweighted)
Expand All @@ -54,15 +66,6 @@ private void init(Lop input1, Lop input2, Lop input3, ExecType et) {
lps.setProperties(inputs, et);
}

public CentralMoment(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et) {
this(input1, input2, null, dt, vt, et);
}

public CentralMoment(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.CentralMoment, dt, vt);
init(input1, input2, input3, et);
}

@Override
public String toString() {
return "Operation = CentralMoment";
Expand All @@ -77,21 +80,27 @@ public String toString() {
*/
@Override
public String getInstructions(String input1, String input2, String input3, String output) {
StringBuilder sb = new StringBuilder();
if( input3 == null ) {
return InstructionUtils.concatOperands(
sb.append(InstructionUtils.concatOperands(
getExecType().toString(), "cm",
getInputs().get(0).prepInputOperand(input1),
getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
prepOutputOperand(output)));
}
else {
return InstructionUtils.concatOperands(
sb.append(InstructionUtils.concatOperands(
getExecType().toString(), "cm",
getInputs().get(0).prepInputOperand(input1),
getInputs().get(1).prepInputOperand(input2),
getInputs().get((input3!=null)?2:1).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
prepOutputOperand(output)));
}
if( getExecType() == ExecType.CP ) {
sb.append(OPERAND_DELIMITOR);
sb.append(String.valueOf(_numThreads));
}
return sb.toString();
}

/**
Expand All @@ -104,4 +113,4 @@ public String getInstructions(String input1, String input2, String input3, Strin
public String getInstructions(String input1, String input2, String output) {
	// Variant without a third input: forward to the four-argument overload,
	// which switches on a null third input.
	final String noThirdInput = null;
	return getInstructions(input1, input2, noThirdInput, output);
}
}
}
43 changes: 14 additions & 29 deletions src/main/java/org/apache/sysds/lops/CoVariance.java
Expand Up @@ -29,44 +29,31 @@
*/
public class CoVariance extends Lop
{

public CoVariance(Lop input1, DataType dt, ValueType vt, ExecType et) {
super(Lop.Type.CoVariance, dt, vt);
init(input1, null, null, et);
}
private final int _numThreads;

public CoVariance(Lop input1, Lop input2, DataType dt, ValueType vt, ExecType et) {
this(input1, input2, null, dt, vt, et);
/**
 * Creates a covariance lop with two inputs by delegating to the
 * three-input constructor with a {@code null} third input.
 *
 * @param input1 first input lop
 * @param input2 second input lop
 * @param dt output data type
 * @param vt output value type
 * @param numThreads thread count stored on this lop
 * @param et execution type
 */
public CoVariance(Lop input1, Lop input2, DataType dt, ValueType vt, int numThreads, ExecType et) {
this(input1, input2, null, dt, vt, numThreads, et);
}

public CoVariance(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, ExecType et) {
/**
 * Creates a covariance lop with two mandatory inputs and an optional third.
 *
 * @param input1 first input lop
 * @param input2 second input lop; {@code init} throws a LopsException when it is {@code null}
 * @param input3 optional third input lop; only registered by {@code init} when non-null
 * @param dt output data type
 * @param vt output value type
 * @param numThreads thread count stored in {@code _numThreads}
 * @param et execution type, forwarded to {@code init}
 */
public CoVariance(Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, int numThreads, ExecType et) {
super(Lop.Type.CoVariance, dt, vt);
// wires the input/output links and sets the lop properties for et
init(input1, input2, input3, et);
_numThreads = numThreads;
}

private void init(Lop input1, Lop input2, Lop input3, ExecType et) {
/*
* When et = MR: the covariance lop has a single input lop, which
* denotes the combined input data -- the output of combinebinary if
* unweighted, or the output of combineternary if weighted.
*
* When et = CP: the covariance lop must have at least two input lops, which
* denote the two input columns on which covariance is computed. It also
* takes an optional third argument when weighted covariance is computed.
*/
if ( input2 == null )
throw new LopsException(this.printErrorLocation() + "Invalid inputs to covariance lop.");

addInput(input1);
input1.addOutput(this);

if ( input2 == null ) {
throw new LopsException(this.printErrorLocation() + "Invalid inputs to covariance lop.");
}
addInput(input2);
input2.addOutput(this);

if ( input3 != null ) {
addInput(input3);
input3.addOutput(this);
}

lps.setProperties(inputs, et);
}

Expand Down Expand Up @@ -102,19 +89,17 @@ public String getInstructions(String input1, String input2, String input3, Strin

sb.append( getInputs().get(0).prepInputOperand(input1));
sb.append( OPERAND_DELIMITOR );

if( input2 != null ) {
sb.append( getInputs().get(1).prepInputOperand(input2));
sb.append( OPERAND_DELIMITOR );
}

sb.append( getInputs().get(1).prepInputOperand(input2));
sb.append( OPERAND_DELIMITOR );
if( input3 != null ) {
sb.append( getInputs().get(2).prepInputOperand(input3));
sb.append( OPERAND_DELIMITOR );
}

sb.append( prepOutputOperand(output));
sb.append( OPERAND_DELIMITOR );
sb.append(_numThreads);

return sb.toString();
}
}
}
Expand Up @@ -74,6 +74,10 @@ public static CM getCMFnObject( AggregateOperationTypes type ) {
//execution due to state in cm object (buff2, buff3)
return new CM( type );
}

/**
 * Returns a CM function object for the same aggregate operation type as
 * the given one, obtained through the type-based factory.
 *
 * @param fn existing CM function object whose operation type is reused
 * @return CM function object produced by {@code getCMFnObject(AggregateOperationTypes)}
 */
public static CM getCMFnObject(CM fn) {
	final AggregateOperationTypes sourceType = fn._type;
	return getCMFnObject(sourceType);
}

public AggregateOperationTypes getAggOpType() {
return _type;
Expand Down
Expand Up @@ -59,6 +59,7 @@ private COV() {
* @param w2 ?
* @return result
*/
@Override
public Data execute(Data in1, double u, double v, double w2)
{
CM_COV_Object cov1=(CM_COV_Object) in1;
Expand Down
Expand Up @@ -81,6 +81,10 @@ public Data execute(Data in1, double in2, double in3) {
throw new DMLRuntimeException("execute(): should not be invoked from base class.");
}

/**
 * Base-class placeholder for the (Data, double, double, double) variant.
 * It always throws.
 *
 * @param in1 unused
 * @param in2 unused
 * @param in3 unused
 * @param in4 unused
 * @return never returns normally
 * @throws DMLRuntimeException always
 */
public Data execute(Data in1, double in2, double in3, double in4) {
	final String message = "execute(): should not be invoked from base class.";
	throw new DMLRuntimeException(message);
}

/**
 * Base-class placeholder for the (Data, Data) variant. It always throws.
 *
 * @param in1 unused
 * @param in2 unused
 * @return never returns normally
 * @throws DMLRuntimeException always
 */
public Data execute(Data in1, Data in2) {
	final String message = "execute(): should not be invoked from base class.";
	throw new DMLRuntimeException(message);
}
Expand Down
Expand Up @@ -1063,8 +1063,7 @@ public static String replaceOperandName(String instStr) {
* @return the instruction string with the given inputs concatenated
*/
public static String concatOperands(String... inputs) {
concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
return _strBuilders.get().toString();
return concatBaseOperandsWithDelim(Lop.OPERAND_DELIMITOR, inputs);
}

/**
Expand All @@ -1073,18 +1072,18 @@ public static String concatOperands(String... inputs) {
* @return concatenated input parts
*/
public static String concatOperandParts(String... inputs) {
concatBaseOperandsWithDelim(Instruction.VALUETYPE_PREFIX, inputs);
return _strBuilders.get().toString();
return concatBaseOperandsWithDelim(Instruction.VALUETYPE_PREFIX, inputs);
}

private static void concatBaseOperandsWithDelim(String delim, String... inputs){
private static String concatBaseOperandsWithDelim(String delim, String... inputs){
StringBuilder sb = _strBuilders.get();
sb.setLength(0); //reuse allocated space
for( int i=0; i<inputs.length-1; i++ ) {
sb.append(inputs[i]);
sb.append(delim);
}
sb.append(inputs[inputs.length-1]);
return sb.toString();
}

public static String concatStrings(String... inputs) {
Expand Down
Expand Up @@ -19,8 +19,6 @@

package org.apache.sysds.runtime.instructions.cp;

import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.CM;
Expand All @@ -37,52 +35,37 @@ private CentralMomentCPInstruction(CMOperator cm, CPOperand in1, CPOperand in2,
}

public static CentralMomentCPInstruction parseInstruction(String str) {
CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
CPOperand in2 = null;
CPOperand in3 = null;
CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);

String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];

//check supported opcode
if( !opcode.equalsIgnoreCase("cm") ) {
throw new DMLRuntimeException("Unsupported opcode "+opcode);
}

if ( parts.length == 4 ) {
// Example: CP.cm.mVar0.Var1.mVar2; (without weights)
in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
parseUnaryInstruction(str, in1, in2, out);
}
else if ( parts.length == 5) {
// CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
in3 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
parseUnaryInstruction(str, in1, in2, in3, out);
}


InstructionUtils.checkNumFields(str, 4, 5); //w/o opcode
CPOperand in1 = new CPOperand(parts[1]); //data
CPOperand in2 = new CPOperand(parts[2]); //scalar
CPOperand in3 = (parts.length==5) ? null : new CPOperand(parts[3]); //weights
CPOperand out = new CPOperand(parts[parts.length-2]);
int numThreads = Integer.parseInt(parts[parts.length-1]);

/*
* Exact order of the central moment MAY NOT be known at compilation time.
* We first try to parse the second argument as an integer, and if we fail,
* we simply pass -1 so that getCMAggOpType() picks up AggregateOperationTypes.INVALID.
* It must be updated at run time in processInstruction() method.
*/

int cmOrder;
try {
if ( in3 == null ) {
cmOrder = Integer.parseInt(in2.getName());
}
else {
cmOrder = Integer.parseInt(in3.getName());
}
} catch(NumberFormatException e) {
cmOrder = Integer.parseInt((in3==null) ? in2.getName() : in3.getName());
}
catch(NumberFormatException e) {
cmOrder = -1; // unknown at compilation time
}

AggregateOperationTypes opType = CMOperator.getCMAggOpType(cmOrder);
CMOperator cm = new CMOperator(CM.getCMFnObject(opType), opType);
CMOperator cm = new CMOperator(CM.getCMFnObject(opType), opType, numThreads);
return new CentralMomentCPInstruction(cm, in1, in2, in3, out, opcode, str);
}

Expand Down

0 comments on commit 2aad571

Please sign in to comment.