diff --git a/docs/spark-mlcontext-programming-guide.md b/docs/spark-mlcontext-programming-guide.md index 71db1a49e53..d0e1f0fc473 100644 --- a/docs/spark-mlcontext-programming-guide.md +++ b/docs/spark-mlcontext-programming-guide.md @@ -1108,6 +1108,7 @@ minOut = min(Xin) maxOut = max(Xin) meanOut = mean(Xin) """ +val mm = new MatrixMetadata(numRows, numCols) val minMaxMeanScript = dml(minMaxMean).in("Xin", df, mm).out("minOut", "maxOut", "meanOut") val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Double, Double]("minOut", "maxOut", "meanOut") @@ -1131,6 +1132,9 @@ maxOut = max(Xin) meanOut = mean(Xin) " +scala> val mm = new MatrixMetadata(numRows, numCols) +mm: org.apache.sysml.api.mlcontext.MatrixMetadata = rows: 10000, columns: 1000, non-zeros: None, rows per block: None, columns per block: None + scala> val minMaxMeanScript = dml(minMaxMean).in("Xin", df, mm).out("minOut", "maxOut", "meanOut") minMaxMeanScript: org.apache.sysml.api.mlcontext.Script = Inputs: @@ -1147,17 +1151,57 @@ scala> val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Doub PROGRAM --MAIN PROGRAM ----GENERIC (lines 1-8) [recompile=false] -------(4959) PRead Xin [10000,1000,1000,1000,10000000] [0,0,76 -> 76MB] [chkpt] -------(4960) ua(minRC) (4959) [0,0,-1,-1,-1] [76,0,0 -> 76MB] -------(4968) PWrite minOut (4960) [0,0,-1,-1,-1] [0,0,0 -> 0MB] -------(4961) ua(maxRC) (4959) [0,0,-1,-1,-1] [76,0,0 -> 76MB] -------(4974) PWrite maxOut (4961) [0,0,-1,-1,-1] [0,0,0 -> 0MB] -------(4962) ua(meanRC) (4959) [0,0,-1,-1,-1] [76,0,0 -> 76MB] -------(4980) PWrite meanOut (4962) [0,0,-1,-1,-1] [0,0,0 -> 0MB] - -min: Double = 3.682402316407263E-8 -max: Double = 0.999999984664141 -mean: Double = 0.49997351913605814 +------(12) TRead Xin [10000,1000,1000,1000,10000000] [0,0,76 -> 76MB] [chkpt], CP +------(13) ua(minRC) (12) [0,0,-1,-1,-1] [76,0,0 -> 76MB], CP +------(21) TWrite minOut (13) [0,0,-1,-1,-1] [0,0,0 -> 0MB], CP +------(14) ua(maxRC) (12) [0,0,-1,-1,-1] 
[76,0,0 -> 76MB], CP +------(27) TWrite maxOut (14) [0,0,-1,-1,-1] [0,0,0 -> 0MB], CP +------(15) ua(meanRC) (12) [0,0,-1,-1,-1] [76,0,0 -> 76MB], CP +------(33) TWrite meanOut (15) [0,0,-1,-1,-1] [0,0,0 -> 0MB], CP + +min: Double = 5.16651366133658E-9 +max: Double = 0.9999999368927975 +mean: Double = 0.5001096515241128 + +{% endhighlight %} + + + + + +The level of program explanation that is displayed when explain is enabled can be set with the MLContext `setExplainLevel` method. The explain levels (the values of `MLContext.ExplainLevel`) are NONE, HOPS, RUNTIME, RECOMPILE_HOPS, and RECOMPILE_RUNTIME. + +
+ +
+{% highlight scala %} +ml.setExplainLevel(MLContext.ExplainLevel.RUNTIME) +val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Double, Double]("minOut", "maxOut", "meanOut") +{% endhighlight %} +
+ +
+{% highlight scala %} +scala> ml.setExplainLevel(MLContext.ExplainLevel.RUNTIME) + +scala> val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Double, Double]("minOut", "maxOut", "meanOut") + +PROGRAM ( size CP/SP = 9/0 ) +--MAIN PROGRAM +----GENERIC (lines 1-8) [recompile=false] +------CP uamin Xin.MATRIX.DOUBLE _Var8.SCALAR.DOUBLE 8 +------CP uamax Xin.MATRIX.DOUBLE _Var9.SCALAR.DOUBLE 8 +------CP uamean Xin.MATRIX.DOUBLE _Var10.SCALAR.DOUBLE 8 +------CP assignvar _Var8.SCALAR.DOUBLE.false minOut.SCALAR.DOUBLE +------CP assignvar _Var9.SCALAR.DOUBLE.false maxOut.SCALAR.DOUBLE +------CP assignvar _Var10.SCALAR.DOUBLE.false meanOut.SCALAR.DOUBLE +------CP rmvar _Var8 +------CP rmvar _Var9 +------CP rmvar _Var10 + +min: Double = 5.16651366133658E-9 +max: Double = 0.9999999368927975 +mean: Double = 0.5001096515241128 {% endhighlight %}
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index 85e2143b932..8f809f8be32 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -50,6 +50,7 @@ import org.apache.sysml.runtime.instructions.spark.functions.SparkListener; import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.utils.Explain.ExplainType; /** * The MLContext API offers programmatic access to SystemML on Spark from @@ -99,9 +100,49 @@ public class MLContext { */ private boolean statistics = false; + /** + * The level and type of program explanation that should be displayed if + * explain is set to true. + */ + private ExplainLevel explainLevel = null; + private List scriptHistoryStrings = new ArrayList(); private Map scripts = new LinkedHashMap(); + /** + * The different explain levels supported by SystemML. + * + */ + public enum ExplainLevel { + /** Explain disabled */ + NONE, + /** Explain program and HOPs */ + HOPS, + /** Explain runtime program */ + RUNTIME, + /** Explain HOPs, including recompile */ + RECOMPILE_HOPS, + /** Explain runtime program, including recompile */ + RECOMPILE_RUNTIME; + + public ExplainType getExplainType() { + switch (this) { + case NONE: + return ExplainType.NONE; + case HOPS: + return ExplainType.HOPS; + case RUNTIME: + return ExplainType.RUNTIME; + case RECOMPILE_HOPS: + return ExplainType.RECOMPILE_HOPS; + case RECOMPILE_RUNTIME: + return ExplainType.RECOMPILE_RUNTIME; + default: + return ExplainType.HOPS; + } + } + }; + /** * Retrieve the currently active MLContext. This is used internally by * SystemML via MLContextProxy. 
@@ -225,6 +266,7 @@ public void setConfigProperty(String propertyName, String propertyValue) { public MLResults execute(Script script) { ScriptExecutor scriptExecutor = new ScriptExecutor(sparkMonitoringUtil); scriptExecutor.setExplain(explain); + scriptExecutor.setExplainLevel(explainLevel); scriptExecutor.setStatistics(statistics); return execute(script, scriptExecutor); } @@ -311,6 +353,17 @@ public void setExplain(boolean explain) { this.explain = explain; } + /** + * Set the level of program explanation that should be displayed if explain + * is set to true. + * + * @param explainLevel + * the level of program explanation + */ + public void setExplainLevel(ExplainLevel explainLevel) { + this.explainLevel = explainLevel; + } + /** * Used internally by MLContextProxy. * @@ -503,13 +556,13 @@ public void clear() { } public void close() { - //reset static status (refs to sc / mlcontext) + // reset static status (refs to sc / mlcontext) SparkExecutionContext.resetSparkContextStatic(); MLContextProxy.setActive(false); activeMLContext = null; - - //clear local status, but do not stop sc as it - //may be used or stopped externally + + // clear local status, but do not stop sc as it + // may be used or stopped externally clear(); resetConfig(); sc = null; diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index 4702af27a8f..cd4797cc963 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -26,6 +26,7 @@ import org.apache.commons.lang3.StringUtils; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.jmlc.JMLCUtils; +import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; import org.apache.sysml.api.monitoring.SparkMonitoringUtil; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; @@ -48,6 +49,7 @@ import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.utils.Explain; import org.apache.sysml.utils.Explain.ExplainCounts; +import org.apache.sysml.utils.Explain.ExplainType; import org.apache.sysml.utils.Statistics; /** @@ -114,6 +116,7 @@ public class ScriptExecutor { protected Script script; protected boolean explain = false; protected boolean statistics = false; + protected ExplainLevel explainLevel; /** * ScriptExecutor constructor. @@ -197,7 +200,12 @@ protected void rewriteHops() { protected void showExplanation() { if (explain) { try { - System.out.println(Explain.explain(dmlProgram)); + if (explainLevel == null) { + System.out.println(Explain.explain(dmlProgram)); + } else { + ExplainType explainType = explainLevel.getExplainType(); + System.out.println(Explain.explain(dmlProgram, runtimeProgram, explainType)); + } } catch (HopsException e) { throw new MLContextException("Exception occurred while explaining dml program", e); } catch (DMLRuntimeException e) { @@ -276,10 +284,10 @@ protected void createAndInitializeExecutionContext() { *
  • {@link #validateScript()}
  • *
  • {@link #constructHops()}
  • *
  • {@link #rewriteHops()}
  • - *
  • {@link #showExplanation()}
  • *
  • {@link #rewritePersistentReadsAndWrites()}
  • *
  • {@link #constructLops()}
  • *
  • {@link #generateRuntimeProgram()}
  • + *
  • {@link #showExplanation()}
  • *
  • {@link #globalDataFlowOptimization()}
  • *
  • {@link #countCompiledMRJobsAndSparkInstructions()}
  • *
  • {@link #initializeCachingAndScratchSpace()}
  • @@ -304,10 +312,10 @@ public MLResults execute(Script script) { validateScript(); constructHops(); rewriteHops(); - showExplanation(); rewritePersistentReadsAndWrites(); constructLops(); generateRuntimeProgram(); + showExplanation(); globalDataFlowOptimization(); countCompiledMRJobsAndSparkInstructions(); initializeCachingAndScratchSpace(); @@ -621,4 +629,21 @@ public void setStatistics(boolean statistics) { this.statistics = statistics; } + /** + * Set the level of program explanation that should be displayed if explain + * is set to true. + * + * @param explainLevel + * the level of program explanation + */ + public void setExplainLevel(ExplainLevel explainLevel) { + this.explainLevel = explainLevel; + if (explainLevel == null) { + DMLScript.EXPLAIN = ExplainType.NONE; + } else { + ExplainType explainType = explainLevel.getExplainType(); + DMLScript.EXPLAIN = explainType; + } + } + }