diff --git a/docs/spark-mlcontext-programming-guide.md b/docs/spark-mlcontext-programming-guide.md
index 71db1a49e53..d0e1f0fc473 100644
--- a/docs/spark-mlcontext-programming-guide.md
+++ b/docs/spark-mlcontext-programming-guide.md
@@ -1108,6 +1108,7 @@ minOut = min(Xin)
maxOut = max(Xin)
meanOut = mean(Xin)
"""
+val mm = new MatrixMetadata(numRows, numCols)
val minMaxMeanScript = dml(minMaxMean).in("Xin", df, mm).out("minOut", "maxOut", "meanOut")
val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Double, Double]("minOut", "maxOut", "meanOut")
@@ -1131,6 +1132,9 @@ maxOut = max(Xin)
meanOut = mean(Xin)
"
+scala> val mm = new MatrixMetadata(numRows, numCols)
+mm: org.apache.sysml.api.mlcontext.MatrixMetadata = rows: 10000, columns: 1000, non-zeros: None, rows per block: None, columns per block: None
+
scala> val minMaxMeanScript = dml(minMaxMean).in("Xin", df, mm).out("minOut", "maxOut", "meanOut")
minMaxMeanScript: org.apache.sysml.api.mlcontext.Script =
Inputs:
@@ -1147,17 +1151,57 @@ scala> val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Doub
PROGRAM
--MAIN PROGRAM
----GENERIC (lines 1-8) [recompile=false]
-------(4959) PRead Xin [10000,1000,1000,1000,10000000] [0,0,76 -> 76MB] [chkpt]
-------(4960) ua(minRC) (4959) [0,0,-1,-1,-1] [76,0,0 -> 76MB]
-------(4968) PWrite minOut (4960) [0,0,-1,-1,-1] [0,0,0 -> 0MB]
-------(4961) ua(maxRC) (4959) [0,0,-1,-1,-1] [76,0,0 -> 76MB]
-------(4974) PWrite maxOut (4961) [0,0,-1,-1,-1] [0,0,0 -> 0MB]
-------(4962) ua(meanRC) (4959) [0,0,-1,-1,-1] [76,0,0 -> 76MB]
-------(4980) PWrite meanOut (4962) [0,0,-1,-1,-1] [0,0,0 -> 0MB]
-
-min: Double = 3.682402316407263E-8
-max: Double = 0.999999984664141
-mean: Double = 0.49997351913605814
+------(12) TRead Xin [10000,1000,1000,1000,10000000] [0,0,76 -> 76MB] [chkpt], CP
+------(13) ua(minRC) (12) [0,0,-1,-1,-1] [76,0,0 -> 76MB], CP
+------(21) TWrite minOut (13) [0,0,-1,-1,-1] [0,0,0 -> 0MB], CP
+------(14) ua(maxRC) (12) [0,0,-1,-1,-1] [76,0,0 -> 76MB], CP
+------(27) TWrite maxOut (14) [0,0,-1,-1,-1] [0,0,0 -> 0MB], CP
+------(15) ua(meanRC) (12) [0,0,-1,-1,-1] [76,0,0 -> 76MB], CP
+------(33) TWrite meanOut (15) [0,0,-1,-1,-1] [0,0,0 -> 0MB], CP
+
+min: Double = 5.16651366133658E-9
+max: Double = 0.9999999368927975
+mean: Double = 0.5001096515241128
+
+{% endhighlight %}
+
+
+
+
+
+Different explain levels can be set. The explain levels are NONE, HOPS, RUNTIME, RECOMPILE_HOPS, and RECOMPILE_RUNTIME. Note that an explain level only takes effect when explain output has been enabled via `setExplain(true)`.
+
+
+
+
+{% highlight scala %}
+ml.setExplainLevel(MLContext.ExplainLevel.RUNTIME)
+val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Double, Double]("minOut", "maxOut", "meanOut")
+{% endhighlight %}
+
+
+
+{% highlight scala %}
+scala> ml.setExplainLevel(MLContext.ExplainLevel.RUNTIME)
+
+scala> val (min, max, mean) = ml.execute(minMaxMeanScript).getTuple[Double, Double, Double]("minOut", "maxOut", "meanOut")
+
+PROGRAM ( size CP/SP = 9/0 )
+--MAIN PROGRAM
+----GENERIC (lines 1-8) [recompile=false]
+------CP uamin Xin.MATRIX.DOUBLE _Var8.SCALAR.DOUBLE 8
+------CP uamax Xin.MATRIX.DOUBLE _Var9.SCALAR.DOUBLE 8
+------CP uamean Xin.MATRIX.DOUBLE _Var10.SCALAR.DOUBLE 8
+------CP assignvar _Var8.SCALAR.DOUBLE.false minOut.SCALAR.DOUBLE
+------CP assignvar _Var9.SCALAR.DOUBLE.false maxOut.SCALAR.DOUBLE
+------CP assignvar _Var10.SCALAR.DOUBLE.false meanOut.SCALAR.DOUBLE
+------CP rmvar _Var8
+------CP rmvar _Var9
+------CP rmvar _Var10
+
+min: Double = 5.16651366133658E-9
+max: Double = 0.9999999368927975
+mean: Double = 0.5001096515241128
{% endhighlight %}
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index 85e2143b932..8f809f8be32 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -50,6 +50,7 @@
import org.apache.sysml.runtime.instructions.spark.functions.SparkListener;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.utils.Explain.ExplainType;
/**
* The MLContext API offers programmatic access to SystemML on Spark from
@@ -99,9 +100,49 @@ public class MLContext {
*/
private boolean statistics = false;
+ /**
+ * The level and type of program explanation that should be displayed if
+ * explain is set to true.
+ */
+ private ExplainLevel explainLevel = null;
+
private List<String> scriptHistoryStrings = new ArrayList<String>();
private Map<String, Script> scripts = new LinkedHashMap<String, Script>();
+ /**
+ * The different explain levels supported by SystemML.
+ *
+ */
+ public enum ExplainLevel {
+ /** Explain disabled */
+ NONE,
+ /** Explain program and HOPs */
+ HOPS,
+ /** Explain runtime program */
+ RUNTIME,
+ /** Explain HOPs, including recompile */
+ RECOMPILE_HOPS,
+ /** Explain runtime program, including recompile */
+ RECOMPILE_RUNTIME;
+
+ public ExplainType getExplainType() {
+ switch (this) {
+ case NONE:
+ return ExplainType.NONE;
+ case HOPS:
+ return ExplainType.HOPS;
+ case RUNTIME:
+ return ExplainType.RUNTIME;
+ case RECOMPILE_HOPS:
+ return ExplainType.RECOMPILE_HOPS;
+ case RECOMPILE_RUNTIME:
+ return ExplainType.RECOMPILE_RUNTIME;
+ default:
+ return ExplainType.HOPS;
+ }
+ }
+ };
+
/**
* Retrieve the currently active MLContext. This is used internally by
* SystemML via MLContextProxy.
@@ -225,6 +266,7 @@ public void setConfigProperty(String propertyName, String propertyValue) {
public MLResults execute(Script script) {
ScriptExecutor scriptExecutor = new ScriptExecutor(sparkMonitoringUtil);
scriptExecutor.setExplain(explain);
+ scriptExecutor.setExplainLevel(explainLevel);
scriptExecutor.setStatistics(statistics);
return execute(script, scriptExecutor);
}
@@ -311,6 +353,17 @@ public void setExplain(boolean explain) {
this.explain = explain;
}
+ /**
+ * Set the level of program explanation that should be displayed if explain
+ * is set to true.
+ *
+ * @param explainLevel
+ * the level of program explanation
+ */
+ public void setExplainLevel(ExplainLevel explainLevel) {
+ this.explainLevel = explainLevel;
+ }
+
/**
* Used internally by MLContextProxy.
*
@@ -503,13 +556,13 @@ public void clear() {
}
public void close() {
- //reset static status (refs to sc / mlcontext)
+ // reset static status (refs to sc / mlcontext)
SparkExecutionContext.resetSparkContextStatic();
MLContextProxy.setActive(false);
activeMLContext = null;
-
- //clear local status, but do not stop sc as it
- //may be used or stopped externally
+
+ // clear local status, but do not stop sc as it
+ // may be used or stopped externally
clear();
resetConfig();
sc = null;
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index 4702af27a8f..cd4797cc963 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -26,6 +26,7 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.jmlc.JMLCUtils;
+import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
@@ -48,6 +49,7 @@
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.Explain.ExplainCounts;
+import org.apache.sysml.utils.Explain.ExplainType;
import org.apache.sysml.utils.Statistics;
/**
@@ -114,6 +116,7 @@ public class ScriptExecutor {
protected Script script;
protected boolean explain = false;
protected boolean statistics = false;
+ protected ExplainLevel explainLevel;
/**
* ScriptExecutor constructor.
@@ -197,7 +200,12 @@ protected void rewriteHops() {
protected void showExplanation() {
if (explain) {
try {
- System.out.println(Explain.explain(dmlProgram));
+ if (explainLevel == null) {
+ System.out.println(Explain.explain(dmlProgram));
+ } else {
+ ExplainType explainType = explainLevel.getExplainType();
+ System.out.println(Explain.explain(dmlProgram, runtimeProgram, explainType));
+ }
} catch (HopsException e) {
throw new MLContextException("Exception occurred while explaining dml program", e);
} catch (DMLRuntimeException e) {
@@ -276,10 +284,10 @@ protected void createAndInitializeExecutionContext() {
* {@link #validateScript()}
* {@link #constructHops()}
* {@link #rewriteHops()}
- * {@link #showExplanation()}
* {@link #rewritePersistentReadsAndWrites()}
* {@link #constructLops()}
* {@link #generateRuntimeProgram()}
+ * {@link #showExplanation()}
* {@link #globalDataFlowOptimization()}
* {@link #countCompiledMRJobsAndSparkInstructions()}
* {@link #initializeCachingAndScratchSpace()}
@@ -304,10 +312,10 @@ public MLResults execute(Script script) {
validateScript();
constructHops();
rewriteHops();
- showExplanation();
rewritePersistentReadsAndWrites();
constructLops();
generateRuntimeProgram();
+ showExplanation();
globalDataFlowOptimization();
countCompiledMRJobsAndSparkInstructions();
initializeCachingAndScratchSpace();
@@ -621,4 +629,21 @@ public void setStatistics(boolean statistics) {
this.statistics = statistics;
}
+ /**
+ * Set the level of program explanation that should be displayed if explain
+ * is set to true.
+ *
+ * @param explainLevel
+ * the level of program explanation
+ */
+ public void setExplainLevel(ExplainLevel explainLevel) {
+ this.explainLevel = explainLevel;
+ if (explainLevel == null) {
+ DMLScript.EXPLAIN = ExplainType.NONE;
+ } else {
+ ExplainType explainType = explainLevel.getExplainType();
+ DMLScript.EXPLAIN = explainType;
+ }
+ }
+
}