Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
Expand Down Expand Up @@ -945,7 +946,8 @@ else if( OptimizerUtils.isSparkExecutionMode() ) {
&& TemplateUtils.hasSingleOperation(tpl) )
|| (tpl instanceof CNodeRow && (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
|| ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
|| ((CNodeRow)tpl).getRowType()==RowType.ROW_AGG )
|| (((CNodeRow)tpl).getRowType()==RowType.ROW_AGG && !TemplateUtils.isBinary(tpl.getOutput(),
CNodeBinary.BinType.AGGMAX_ROWMAXS_VECTMULT)))
&& TemplateUtils.hasSingleOperation(tpl))
|| TemplateUtils.hasNoOperation(tpl) )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
public class CNodeBinary extends CNode {

public enum BinType {
AGGMAX_ROWMAXS_VECTMULT,
//matrix multiplication operations
DOT_PRODUCT, VECT_MATRIXMULT, VECT_OUTERMULT_ADD,
//vector-scalar-add operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs,
boolean scalarVector, boolean scalarInput, boolean vectorVector)
{
switch (type) {
case AGGMAX_ROWMAXS_VECTMULT:
return sparseLhs ? "\tdouble %TMP% = LibSpoofPrimitives.aggMaxRowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
"\tdouble %TMP% = LibSpoofPrimitives.aggMaxRowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
case DOT_PRODUCT:
return sparseLhs ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
" double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@

import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysds.hops.codegen.cplan.CNodeRow;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
Expand Down Expand Up @@ -56,6 +58,9 @@ public CNodeTpl simplifyCPlan(CNodeTpl tpl) {
}
else {
tpl.setOutput(rSimplifyCNode(tpl.getOutput()));
if(tpl instanceof CNodeRow && TemplateUtils.isBinary(tpl.getOutput(), BinType.AGGMAX_ROWMAXS_VECTMULT)) {
((CNodeRow) tpl).setNumVectorIntermediates(((CNodeRow) tpl).getNumVectorIntermediates()-2);
}
}

return tpl;
Expand All @@ -73,10 +78,24 @@ private static CNode rSimplifyCNode(CNode node) {
node = rewriteBinaryPow2Vect(node); //X^2 -> X*X
node = rewriteBinaryMult2(node); //x*2 -> x+x;
node = rewriteBinaryMult2Vect(node); //X*2 -> X+X;

node = rewriteMaxRowMaxsVectMult(node); // max(rowMaxs(G * t(c)), c); see components.dml

return node;
}


private static CNode rewriteMaxRowMaxsVectMult(CNode node) {
if(TemplateUtils.isBinary(node, BinType.MAX)) {
CNode left = node.getInput().get(0);
CNode right = node.getInput().get(1);
return (TemplateUtils.isUnary(left, UnaryType.ROW_MAXS) &&
TemplateUtils.isBinary(left.getInput().get(0), BinType.VECT_MULT) &&
TemplateUtils.isUnary(right, UnaryType.LOOKUP_R) ? new CNodeBinary(left.getInput().get(0).getInput().get(0),
right.getInput().get(0), BinType.AGGMAX_ROWMAXS_VECTMULT) : node);
}
else
return(node);
}

private static CNode rewriteRowCountNnz(CNode node) {
return (TemplateUtils.isUnary(node, UnaryType.ROW_SUMS)
&& TemplateUtils.isBinary(node.getInput().get(0), BinType.VECT_NOTEQUAL_SCALAR)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ public static boolean hasSingleOperation(CNodeTpl tpl) {
&& !TemplateUtils.isUnary(output,
UnaryType.EXP, UnaryType.LOG, UnaryType.ROW_COUNTNNZS))
|| (output instanceof CNodeBinary
&& !TemplateUtils.isBinary(output, BinType.VECT_OUTERMULT_ADD))
&& (!TemplateUtils.isBinary(output, BinType.VECT_OUTERMULT_ADD) ||
!TemplateUtils.isBinary(output, BinType.AGGMAX_ROWMAXS_VECTMULT)))
|| output instanceof CNodeTernary
&& ((CNodeTernary)output).getType() == TernaryType.IFELSE)
&& hasOnlyDataNodeOrLookupInputs(output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,23 @@ public class LibSpoofPrimitives
private static ThreadLocal<VectorBuffer> memPool = new ThreadLocal<VectorBuffer>() {
@Override protected VectorBuffer initialValue() { return new VectorBuffer(0,0,0); }
};


public static double aggMaxRowMaxsVectMult(double[] a, double[] b, int ai, int bi, int len) {
double val = Double.NEGATIVE_INFINITY;
int j=0;
for( int i = ai; i < ai+len; i++ )
val = Math.max(a[i]*b[j++], val);
return Math.max(val, b[bi]);
}

public static double aggMaxRowMaxsVectMult(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
double val = Double.NEGATIVE_INFINITY;
for( int i = ai; i < ai+len; i++ )
val = Math.max(a[i]*b[aix[i]], val);
return Math.max(val, b[bi]);
}

// forwarded calls to LibMatrixMult

public static double dotProduct(double[] a, double[] b, int ai, int bi, int len) {
if( a == null || b == null ) return 0;
return LibMatrixMult.dotProduct(a, b, ai, bi, len);
Expand Down