Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/hops/Hop.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import org.apache.sysds.runtime.util.UtilFunctions;

public abstract class Hop implements ParseInfo {
private static final Log LOG = LogFactory.getLog(Hop.class.getName());
protected static final Log LOG = LogFactory.getLog(Hop.class.getName());

public static final long CPThreshold = 2000;

Expand Down
25 changes: 9 additions & 16 deletions src/main/java/org/apache/sysds/hops/UnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,35 +130,31 @@ public Lop constructLops()
//reuse existing lop
if( getLops() != null )
return getLops();
int k;
final int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
try {
Hop input = getInput().get(0);
Lop ret = null;
switch(_op){
case COMPRESS:
k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
ret = new Compression(input.constructLops(), getDataType(), getValueType(), optFindExecType(), 0);
break;
case DECOMPRESS:
k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
ret = new DeCompression(input.constructLops(), getDataType(), getValueType(), optFindExecType());
break;
case LOCAL:
ret = new Local(input.constructLops(), getDataType(), getValueType());
break;
default:
final boolean isScalarIn = getInput().get(0).getDataType() == DataType.SCALAR;
if(getDataType() == DataType.SCALAR // value type casts or matrix to scalar
|| (_op == OpOp1.CAST_AS_MATRIX && getInput().get(0).getDataType() == DataType.SCALAR) ||
(_op == OpOp1.CAST_AS_FRAME && getInput().get(0).getDataType() == DataType.SCALAR)) {
if(_op == OpOp1.IQM) { // special handling IQM
|| (_op == OpOp1.CAST_AS_MATRIX && isScalarIn) // cast matrix
|| (_op == OpOp1.CAST_AS_FRAME && isScalarIn)) { // cast frame
if(_op == OpOp1.IQM) // special handling IQM
ret = constructLopsIQM();
}
else if(_op == OpOp1.MEDIAN) {
else if(_op == OpOp1.MEDIAN)
ret = constructLopsMedian();
}
else { // general case SCALAR/CAST (always in CP)
ret = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType());
}
else // general case SCALAR/CAST (always in CP) & always single threaded
ret = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType(), k);
}
else { // general case MATRIX
ExecType et = optFindExecType();
Expand All @@ -168,13 +164,10 @@ else if(_op == OpOp1.MEDIAN) {
// TODO additional physical operation if offsets fit in memory
ret = constructLopsSparkCumulativeUnary();
}
else // default unary
{
else {// default unary
final boolean inplace = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE &&
input.getParent().size() == 1 && (!(input instanceof DataOp) || !((DataOp) input).isRead());

k = isCumulativeUnaryOperation() || isExpensiveUnaryOperation() ?
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads) : 1;
ret = new Unary(input.constructLops(), _op, getDataType(), getValueType(), et, k, inplace);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/lops/Lop.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

public abstract class Lop
{
private static final Log LOG = LogFactory.getLog(Lop.class.getName());
protected static final Log LOG = LogFactory.getLog(Lop.class.getName());

public enum Type {
Data, DataGen, //CP/MR read/write/datagen
Expand Down
30 changes: 16 additions & 14 deletions src/main/java/org/apache/sysds/lops/Unary.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,22 @@ public String getInstructions(String input1, String output) {

// Unary operators with one input
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
sb.append( Lop.OPERAND_DELIMITOR );
sb.append( getOpcode() );
sb.append( OPERAND_DELIMITOR );
sb.append( getInputs().get(0).prepInputOperand(input1) );
sb.append( OPERAND_DELIMITOR );
sb.append( prepOutputOperand(output) );

//num threads for cumulative cp ops
if( (getExecType() == ExecType.CP || getExecType() == ExecType.FED) && isMultiThreadedOp(operation) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
sb.append( OPERAND_DELIMITOR );
sb.append( _inplace );
sb.append(getExecType());
sb.append(Lop.OPERAND_DELIMITOR);
sb.append(getOpcode());
sb.append(OPERAND_DELIMITOR);
sb.append(getInputs().get(0).prepInputOperand(input1));
sb.append(OPERAND_DELIMITOR);
sb.append(prepOutputOperand(output));

if(getExecType() == ExecType.CP || getExecType() == ExecType.FED) {
sb.append(OPERAND_DELIMITOR);
sb.append(_numThreads);
if(isMultiThreadedOp(operation)) {

sb.append(OPERAND_DELIMITOR);
sb.append(_inplace);
}
}

appendFedOut(sb);
Expand Down
25 changes: 18 additions & 7 deletions src/main/java/org/apache/sysds/lops/UnaryCP.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;

public class UnaryCP extends Lop
{
private OpOp1 operation;
public class UnaryCP extends Lop {
private final OpOp1 operation;
private final int _numThreads;

/**
* Constructor to perform a scalar operation
Expand All @@ -38,22 +38,32 @@ public class UnaryCP extends Lop
* @param dt data type of the output
* @param vt value type of the output
* @param et exec type
* @param k parallelization degree
*/
public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et, int k) {
super(Lop.Type.UnaryCP, dt, vt);
operation = op;
addInput(input);
input.addOutput(this);
lps.setProperties(inputs, et);
_numThreads = k;
}

public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et) {
this(input, op, dt, vt, et, 1);
}

public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, int k) {
this(input, op, dt, vt, ExecType.CP, k);
}

public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt) {
this(input, op, dt, vt, ExecType.CP);
this(input, op, dt, vt, ExecType.CP, 1);
}

@Override
public String toString() {
return "Operation: " + operation;
return "Operation: " + getInstructions("", "");
}

private String getOpCode() {
Expand All @@ -65,6 +75,7 @@ public String getInstructions(String input, String output) {
return InstructionUtils.concatOperands(
getExecType().name(), getOpCode(),
getInputs().get(0).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
prepOutputOperand(output),
Integer.toString(_numThreads));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ public Set<String> getTmpParforFunctions() {
@Override
public String toString(){
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
sb.append(this.getClass().getSimpleName().toString());
if(_prog != null)
sb.append("\nProgram: " + _prog.toString());
if(_variables != null)
Expand Down
Loading