From 4ae63f907282d977dad0ec0262a55d6508b9f12b Mon Sep 17 00:00:00 2001 From: Mark Dokter Date: Fri, 18 Mar 2022 01:11:19 +0100 Subject: [PATCH] [SYSTEMDS-3334] Code-gen rewrite AGGMAX_ROWMAXS_VECTMULT This patch adds a rewrite to avoid vector intermediates in the generated row template of connected components by doing the elementwise multiplication, row_maxs and max in one pass. --- .../sysds/hops/codegen/SpoofCompiler.java | 4 +++- .../sysds/hops/codegen/cplan/CNodeBinary.java | 1 + .../sysds/hops/codegen/cplan/java/Binary.java | 3 +++ .../codegen/template/CPlanOpRewriter.java | 23 +++++++++++++++++-- .../hops/codegen/template/TemplateUtils.java | 3 ++- .../runtime/codegen/LibSpoofPrimitives.java | 18 +++++++++++++-- 6 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java index ade88775e14..2872510f3d4 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java @@ -38,6 +38,7 @@ import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.codegen.cplan.CNode; +import org.apache.sysds.hops.codegen.cplan.CNodeBinary; import org.apache.sysds.hops.codegen.cplan.CNodeCell; import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg; @@ -945,7 +946,8 @@ else if( OptimizerUtils.isSparkExecutionMode() ) { && TemplateUtils.hasSingleOperation(tpl) ) || (tpl instanceof CNodeRow && (((CNodeRow)tpl).getRowType()==RowType.NO_AGG || ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1 - || ((CNodeRow)tpl).getRowType()==RowType.ROW_AGG ) + || (((CNodeRow)tpl).getRowType()==RowType.ROW_AGG && !TemplateUtils.isBinary(tpl.getOutput(), + CNodeBinary.BinType.AGGMAX_ROWMAXS_VECTMULT))) && TemplateUtils.hasSingleOperation(tpl)) || TemplateUtils.hasNoOperation(tpl) ) { diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java index 2e6bcd5d487..659bfa056d6 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java @@ -30,6 +30,7 @@ public class CNodeBinary extends CNode { public enum BinType { + AGGMAX_ROWMAXS_VECTMULT, //matrix multiplication operations DOT_PRODUCT, VECT_MATRIXMULT, VECT_OUTERMULT_ADD, //vector-scalar-add operations diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java index ecb7878f669..9cf66400ba7 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java +++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java @@ -28,6 +28,9 @@ public String getTemplate(BinType type, boolean sparseLhs, boolean sparseRhs, boolean scalarVector, boolean scalarInput, boolean vectorVector) { switch (type) { + case AGGMAX_ROWMAXS_VECTMULT: + return sparseLhs ? "\tdouble %TMP% = LibSpoofPrimitives.aggMaxRowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : + "\tdouble %TMP% = LibSpoofPrimitives.aggMaxRowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; case DOT_PRODUCT: return sparseLhs ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n"; diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java index 2b981ee8938..b463e8ac807 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java +++ b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java @@ -23,9 +23,11 @@ import org.apache.sysds.hops.LiteralOp; import org.apache.sysds.hops.codegen.cplan.CNode; +import org.apache.sysds.hops.codegen.cplan.CNodeBinary; import org.apache.sysds.hops.codegen.cplan.CNodeData; import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg; import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct; +import org.apache.sysds.hops.codegen.cplan.CNodeRow; import org.apache.sysds.hops.codegen.cplan.CNodeTpl; import org.apache.sysds.hops.codegen.cplan.CNodeUnary; import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType; @@ -56,6 +58,9 @@ public CNodeTpl simplifyCPlan(CNodeTpl tpl) { } else { tpl.setOutput(rSimplifyCNode(tpl.getOutput())); + if(tpl instanceof CNodeRow && TemplateUtils.isBinary(tpl.getOutput(), BinType.AGGMAX_ROWMAXS_VECTMULT)) { + ((CNodeRow) tpl).setNumVectorIntermediates(((CNodeRow) tpl).getNumVectorIntermediates()-2); + } } return tpl; @@ -73,10 +78,24 @@ private static CNode rSimplifyCNode(CNode node) { node = rewriteBinaryPow2Vect(node); //X^2 -> X*X node = rewriteBinaryMult2(node); //x*2 -> x+x; node = rewriteBinaryMult2Vect(node); //X*2 -> X+X; - + node = rewriteMaxRowMaxsVectMult(node); // max(rowMaxs(G * t(c)), c); see components.dml + return node; } - + + private static CNode rewriteMaxRowMaxsVectMult(CNode node) { + if(TemplateUtils.isBinary(node, BinType.MAX)) { + CNode left = node.getInput().get(0); + CNode right = node.getInput().get(1); + return (TemplateUtils.isUnary(left, UnaryType.ROW_MAXS) && + TemplateUtils.isBinary(left.getInput().get(0), BinType.VECT_MULT) && + TemplateUtils.isUnary(right, UnaryType.LOOKUP_R) ? new CNodeBinary(left.getInput().get(0).getInput().get(0), + right.getInput().get(0), BinType.AGGMAX_ROWMAXS_VECTMULT) : node); + } + else + return(node); + } + private static CNode rewriteRowCountNnz(CNode node) { return (TemplateUtils.isUnary(node, UnaryType.ROW_SUMS) && TemplateUtils.isBinary(node.getInput().get(0), BinType.VECT_NOTEQUAL_SCALAR) diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java index f61305fa13c..a94993f2547 100644 --- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java @@ -391,7 +391,8 @@ public static boolean hasSingleOperation(CNodeTpl tpl) { && !TemplateUtils.isUnary(output, UnaryType.EXP, UnaryType.LOG, UnaryType.ROW_COUNTNNZS)) || (output instanceof CNodeBinary - && !TemplateUtils.isBinary(output, BinType.VECT_OUTERMULT_ADD)) + && (!TemplateUtils.isBinary(output, BinType.VECT_OUTERMULT_ADD) || + !TemplateUtils.isBinary(output, BinType.AGGMAX_ROWMAXS_VECTMULT))) || output instanceof CNodeTernary && ((CNodeTernary)output).getType() == TernaryType.IFELSE) && hasOnlyDataNodeOrLookupInputs(output); diff --git a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java index 905b39226dc..c94cf114fb3 100644 --- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java @@ -50,9 +50,23 @@ public class LibSpoofPrimitives private static ThreadLocal memPool = new ThreadLocal() { @Override protected VectorBuffer initialValue() { return new VectorBuffer(0,0,0); } }; - + + public static double aggMaxRowMaxsVectMult(double[] a, double[] b, int ai, int bi, int len) { + double val = Double.NEGATIVE_INFINITY; + int j=0; + for( int i = ai; i < ai+len; i++ ) + val = Math.max(a[i]*b[j++], val); + return Math.max(val, b[bi]); + } + + public static double aggMaxRowMaxsVectMult(double[] a, double[] b, int[] aix, int ai, int bi, int len) { + double val = Double.NEGATIVE_INFINITY; + for( int i = ai; i < ai+len; i++ ) + val = Math.max(a[i]*b[aix[i]], val); + return Math.max(val, b[bi]); + } + // forwarded calls to LibMatrixMult - public static double dotProduct(double[] a, double[] b, int ai, int bi, int len) { if( a == null || b == null ) return 0; return LibMatrixMult.dotProduct(a, b, ai, bi, len);