diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index 0b450dc76bc38..3c8ddddf07b1e 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -19,6 +19,9 @@
* to be registered after the page loads. */
$(function() {
$("span.expand-additional-metrics").click(function(){
+ var status = window.localStorage.getItem("expand-additional-metrics") == "true";
+ status = !status;
+
// Expand the list of additional metrics.
var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
$(additionalMetricsDiv).toggleClass('collapsed');
@@ -26,17 +29,31 @@ $(function() {
// Switch the class of the arrow from open to closed.
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-additional-metrics", "" + status);
});
+ if (window.localStorage.getItem("expand-additional-metrics") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-additional-metrics", "false");
+ $("span.expand-additional-metrics").trigger("click");
+ }
+
stripeSummaryTable();
$('input[type="checkbox"]').click(function() {
- var column = "table ." + $(this).attr("name");
+ var name = $(this).attr("name");
+ var column = "table ." + name;
+ var status = window.localStorage.getItem(name) == "true";
+ status = !status;
$(column).toggle();
stripeSummaryTable();
+ window.localStorage.setItem(name, "" + status);
});
$("#select-all-metrics").click(function() {
+ var status = window.localStorage.getItem("select-all-metrics") == "true";
+ status = !status;
if (this.checked) {
// Toggle all un-checked options.
$('input[type="checkbox"]:not(:checked)').trigger('click');
@@ -44,6 +61,21 @@ $(function() {
// Toggle all checked options.
$('input[type="checkbox"]:checked').trigger('click');
}
+ window.localStorage.setItem("select-all-metrics", "" + status);
+ });
+
+ if (window.localStorage.getItem("select-all-metrics") == "true") {
+ $("#select-all-metrics").attr('checked', 'checked');
+ }
+
+ $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() {
+ var name = $(this).attr("name");
+ // If name is undefined, then skip it because it's the "select-all-metrics" checkbox
+ if (name && window.localStorage.getItem(name) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(name, "false");
+ $(this).trigger("click");
+ }
});
// Trigger a click on the checkbox if a user clicks the label next to it.
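
For reference, the handlers above all persist UI state with the same idiom: flip a localStorage flag on each click, and at page load replay a stored "true" by resetting the flag to "false" and firing a synthetic click so the normal handler re-expands the section. A minimal standalone sketch of that idiom (the "my-section" names are illustrative, not part of this patch):

    $(function() {
      $("span.expand-my-section").click(function() {
        // Flip the persisted flag every time the section is toggled.
        var expanded = window.localStorage.getItem("expand-my-section") == "true";
        $("#my-section").toggleClass("collapsed");
        window.localStorage.setItem("expand-my-section", "" + !expanded);
      });

      // On load, replay a persisted "expanded" state: reset the flag to false first
      // so that the click handler above flips it back to true.
      if (window.localStorage.getItem("expand-my-section") == "true") {
        window.localStorage.setItem("expand-my-section", "false");
        $("span.expand-my-section").trigger("click");
      }
    });
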
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index 9fa53baaf4212..4a893bc0189aa 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -72,6 +72,14 @@ var StagePageVizConstants = {
rankSep: 40
};
+/*
+ * Return "expand-dag-viz-arrow-job" if forJob is true.
+ * Otherwise, return "expand-dag-viz-arrow-stage".
+ */
+function expandDagVizArrowKey(forJob) {
+ return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage";
+}
+
/*
* Show or hide the RDD DAG visualization.
*
@@ -79,6 +87,9 @@ var StagePageVizConstants = {
* This is the narrow interface called from the Scala UI code.
*/
function toggleDagViz(forJob) {
+ var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true";
+ status = !status;
+
var arrowSelector = ".expand-dag-viz-arrow";
$(arrowSelector).toggleClass('arrow-closed');
$(arrowSelector).toggleClass('arrow-open');
@@ -93,8 +104,24 @@ function toggleDagViz(forJob) {
// Save the graph for later so we don't have to render it again
graphContainer().style("display", "none");
}
+
+ window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status);
}
+$(function (){
+ if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(expandDagVizArrowKey(false), "false");
+ toggleDagViz(false);
+ }
+
+ if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem(expandDagVizArrowKey(true), "false");
+ toggleDagViz(true);
+ }
+});
+
/*
* Render the RDD DAG visualization.
*
diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
index ca74ef9d7e94e..f4453c71df1ea 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js
@@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) {
setupJobEventAction();
$("span.expand-application-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-application-timeline") == "true";
+ status = !status;
+
$("#application-timeline").toggleClass('collapsed');
// Switch the class of the arrow from open to closed.
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open');
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-application-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-application-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-application-timeline", "false");
+ $("span.expand-application-timeline").trigger('click');
+ }
+});
+
function drawJobTimeline(groupArray, eventObjArray, startTime) {
var groups = new vis.DataSet(groupArray);
var items = new vis.DataSet(eventObjArray);
@@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) {
setupStageEventAction();
$("span.expand-job-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-job-timeline") == "true";
+ status = !status;
+
$("#job-timeline").toggleClass('collapsed');
// Switch the class of the arrow from open to closed.
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open');
$(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed');
+
+ window.localStorage.setItem("expand-job-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-job-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-job-timeline", "false");
+ $("span.expand-job-timeline").trigger('click');
+ }
+});
+
function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) {
var groups = new vis.DataSet(groupArray);
var items = new vis.DataSet(eventObjArray);
@@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma
setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline);
$("span.expand-task-assignment-timeline").click(function() {
+ var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true";
+ status = !status;
+
$("#task-assignment-timeline").toggleClass("collapsed");
// Switch the class of the arrow from open to closed.
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open");
$(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed");
+
+ window.localStorage.setItem("expand-task-assignment-timeline", "" + status);
});
}
+$(function (){
+ if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") {
+ // Set it to false so that the click function can revert it
+ window.localStorage.setItem("expand-task-assignment-timeline", "false");
+ $("span.expand-task-assignment-timeline").trigger('click');
+ }
+});
+
function setupExecutorEventAction() {
$(".item.box.executor").each(function () {
$(this).hover(
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index e93eb93124e51..b48836d5c8897 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
@@ -148,6 +149,14 @@ abstract class TaskContext extends Serializable {
@DeveloperApi
def taskMetrics(): TaskMetrics
+ /**
+ * ::DeveloperApi::
+ * Returns all metrics sources with the given name that are associated with the instance
+ * that runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]].
+ */
+ @DeveloperApi
+ def getMetricsSources(sourceName: String): Seq[Source]
+
/**
* Returns the manager for this task's managed memory.
*/
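
A hedged sketch of how task code might consume this new API, assuming a custom Source named "MySource" has already been registered with the executor's MetricsSystem (the source name, the counter name, and the rdd value are illustrative, not part of this patch):

    import org.apache.spark.TaskContext
    import org.apache.spark.metrics.source.Source

    // Inside a task body, look up every source registered under the (illustrative)
    // name "MySource" and bump a counter on each of them.
    rdd.foreachPartition { _ =>
      val sources: Seq[Source] = TaskContext.get().getMetricsSources("MySource")
      sources.foreach(_.metricRegistry.counter("records.partitionsSeen").inc())
    }
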
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 6e394f1b12445..9ee168ae016f8 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -20,6 +20,8 @@ package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
@@ -29,6 +31,7 @@ private[spark] class TaskContextImpl(
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ @transient private val metricsSystem: MetricsSystem,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -95,6 +98,9 @@ private[spark] class TaskContextImpl(
override def isInterrupted(): Boolean = interrupted
+ override def getMetricsSources(sourceName: String): Seq[Source] =
+ metricsSystem.getSourcesByName(sourceName)
+
@transient private val accumulators = new HashMap[Long, Accumulable[_, _]]
private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9087debde8c41..66624ffbe4790 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -210,7 +210,10 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val (value, accumUpdates) = try {
- task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
+ task.run(
+ taskAttemptId = taskId,
+ attemptNumber = attemptNumber,
+ metricsSystem = env.metricsSystem)
} finally {
// Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
// when changing this, make sure to update both copies.
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 67f64d5e278de..4517f465ebd3b 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -142,6 +142,9 @@ private[spark] class MetricsSystem private (
} else { defaultName }
}
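+ /** Return all registered metrics sources with the given source name. */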
+ def getSourcesByName(sourceName: String): Seq[Source] =
+ sources.filter(_.sourceName == sourceName)
+
def registerSource(source: Source) {
sources += source
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 71a219a4f3414..b829d06923404 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -682,6 +682,7 @@ class DAGScheduler(
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ metricsSystem = env.metricsSystem,
runningLocally = true)
TaskContext.setTaskContext(taskContext)
try {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 76a19aeac4679..d11a00956a9a9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
+import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
@@ -61,13 +62,18 @@ private[spark] abstract class Task[T](
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
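+ * @param metricsSystem the metrics system of the instance that runs this task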
* @return the result of the task along with updates of Accumulators.
*/
- final def run(taskAttemptId: Long, attemptNumber: Int): (T, AccumulatorUpdates) = {
+ final def run(
+ taskAttemptId: Long,
+ attemptNumber: Int,
+ metricsSystem: MetricsSystem)
+ : (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
stageId = stageId,
partitionId = partitionId,
taskAttemptId = taskAttemptId,
attemptNumber = attemptNumber,
taskMemoryManager = taskMemoryManager,
+ metricsSystem = metricsSystem,
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f14c603ac6891..c65b3e517773e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -169,9 +169,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on all executors
private def makeOffers() {
- launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) =>
+ // Filter out executors that are being killed
+ val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_))
+ val workOffers = activeExecutors.map { case (id, executorData) =>
new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
- }.toSeq))
+ }.toSeq
+ launchTasks(scheduler.resourceOffers(workOffers))
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -181,9 +184,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on just one executor
private def makeOffers(executorId: String) {
- val executorData = executorDataMap(executorId)
- launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))))
+ // Filter out executors that are being killed
+ if (!executorsPendingToRemove.contains(executorId)) {
+ val executorData = executorDataMap(executorId)
+ val workOffers = Seq(
+ new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+ launchTasks(scheduler.resourceOffers(workOffers))
+ }
}
// Launch tasks returned by a set of resource offers
diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
new file mode 100644
index 0000000000000..17d7b39c2d951
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.{Node, Unparsed}
+
+/**
+ * A data source that provides data for a page.
+ *
+ * @param pageSize the number of rows in a page
+ */
+private[ui] abstract class PagedDataSource[T](val pageSize: Int) {
+
+ if (pageSize <= 0) {
+ throw new IllegalArgumentException("Page size must be positive")
+ }
+
+ /**
+ * Return the size of all data.
+ */
+ protected def dataSize: Int
+
+ /**
+ * Slice a range of data.
+ */
+ protected def sliceData(from: Int, to: Int): Seq[T]
+
+ /**
+ * Slice the data for this page
+ */
+ def pageData(page: Int): PageData[T] = {
+ val totalPages = (dataSize + pageSize - 1) / pageSize
+ if (page <= 0 || page > totalPages) {
+ throw new IndexOutOfBoundsException(
+ s"Page $page is out of range. Please select a page number between 1 and $totalPages.")
+ }
+ val from = (page - 1) * pageSize
+ val to = dataSize.min(page * pageSize)
+ PageData(totalPages, sliceData(from, to))
+ }
+
+}
+
+/**
+ * The data returned by `PagedDataSource.pageData`, including the total number of pages and the
+ * data in the requested page.
+ */
+private[ui] case class PageData[T](totalPage: Int, data: Seq[T])
+
+/**
+ * A paged table that will generate an HTML table for a specified page and also the page navigation.
+ */
+private[ui] trait PagedTable[T] {
+
+ def tableId: String
+
+ def tableCssClass: String
+
+ def dataSource: PagedDataSource[T]
+
+ def headers: Seq[Node]
+
+ def row(t: T): Seq[Node]
+
+ def table(page: Int): Seq[Node] = {
+ val _dataSource = dataSource
+ try {
+ val PageData(totalPages, data) = _dataSource.pageData(page)
+ <div>
+ {pageNavigation(page, _dataSource.pageSize, totalPages)}
+ <table class={tableCssClass} id={tableId}>
+ {headers}
+ <tbody>
+ {data.map(row)}
+ </tbody>
+ </table>
+ </div>
+ } catch {
+ case e: IndexOutOfBoundsException =>
+ val PageData(totalPages, _) = _dataSource.pageData(1)
+ <div>
+ {pageNavigation(1, _dataSource.pageSize, totalPages)}
+ <div class="alert alert-error">{e.getMessage}</div>
+ </div>
+ }
+ }
+
+ /**
+ * Return a page navigation.
+ *
+ * - If the totalPages is 1, the page navigation will be empty.
+ * - If the totalPages is more than 1, it will create a page navigation including a group of
+ *   page numbers and a form to submit the page number.
+ *
+ * Here are some examples of the page navigation:
+ * {{{
+ * << < 11 12 13* 14 15 16 17 18 19 20 > >>
+ *
+ * This is the first group, so "<<" is hidden.
+ * < 1 2* 3 4 5 6 7 8 9 10 > >>
+ *
+ * This is the first group and the first page, so "<<" and "<" are hidden.
+ * 1* 2 3 4 5 6 7 8 9 10 > >>
+ *
+ * Assume totalPages is 19. This is the last group, so ">>" is hidden.
+ * << < 11 12 13* 14 15 16 17 18 19 >
+ *
+ * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden.
+ * << < 11 12 13 14 15 16 17 18 19*
+ *
+ * * means the current page number
+ * << means jumping to the first page of the previous group.
+ * < means jumping to the previous page.
+ * >> means jumping to the first page of the next group.
+ * > means jumping to the next page.
+ * }}}
+ */
+ private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = {
+ if (totalPages == 1) {
+ Nil
+ } else {
+ // A group includes all page numbers that will be shown in the page navigation.
+ // A group size of 10 means 10 page numbers will be shown at a time.
+ // The first group is pages 1 to 10, the second is 11 to 20, and so on.
+ val groupSize = 10
+ val firstGroup = 0
+ val lastGroup = (totalPages - 1) / groupSize
+ val currentGroup = (page - 1) / groupSize
+ val startPage = currentGroup * groupSize + 1
+ val endPage = totalPages.min(startPage + groupSize - 1)
+ val pageTags = (startPage to endPage).map { p =>
+ if (p == page) {
+ // The current page should be disabled so that it cannot be clicked.
+ <li class="disabled"><a href="#">{p}</a></li>
+ } else {
+ <li><a href={pageLink(p)}>{p}</a></li>
+ }
+ }
+ val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction
+ // When clicking the "Go" button, it will call this javascript method and then call
+ // "goButtonJsFuncName"
+ val formJs =
+ s"""$$(function(){
+ | $$( "#form-task-page" ).submit(function(event) {
+ | var page = $$("#form-task-page-no").val()
+ | var pageSize = $$("#form-task-page-size").val()
+ | pageSize = pageSize ? pageSize: 100;
+ | if (page != "") {
+ | ${goButtonJsFuncName}(page, pageSize);
+ | }
+ | event.preventDefault();
+ | });
+ |});
+ """.stripMargin
+
+
+ }
+ }
+
+ /**
+ * Return a link to jump to a page.
+ */
+ def pageLink(page: Int): String
+
+ /**
+ * Only the implementation knows how to create the URL with a page number and the page size, so we
+ * leave this to the implementation. The implementation should create a JavaScript method that
+ * accepts a page number along with the page size and jumps to the page. The return value is the
+ * name of this method and its JavaScript code.
+ */
+ def goButtonJavascriptFunction: (String, String)
+}
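
For illustration, a minimal in-memory data source built on the abstraction above, assuming it sits in the org.apache.spark.ui package (PagedDataSource and PageData are private[ui]); the SeqDataSource name and sample data are illustrative, not part of this patch:

    // A trivial PagedDataSource backed by an in-memory sequence.
    class SeqDataSource[T](data: Seq[T], pageSize: Int) extends PagedDataSource[T](pageSize) {
      // Total number of rows behind the table.
      override def dataSize: Int = data.size
      // Rows in the half-open range [from, to).
      override def sliceData(from: Int, to: Int): Seq[T] = data.slice(from, to)
    }

    // Page 2 with a page size of 10 over 25 rows: 3 total pages, rows 11 to 20.
    val PageData(totalPages, rows) = new SeqDataSource((1 to 25).toSeq, pageSize = 10).pageData(2)
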
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 6e077bf3e70d5..cf04b5e59239b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ui.jobs
+import java.net.URLEncoder
import java.util.Date
import javax.servlet.http.HttpServletRequest
@@ -27,13 +28,14 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
-import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
+import org.apache.spark.ui._
import org.apache.spark.ui.jobs.UIData._
-import org.apache.spark.ui.scope.RDDOperationGraph
import org.apache.spark.util.{Utils, Distribution}
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
+ import StagePage._
+
private val progressListener = parent.progressListener
private val operationGraphListener = parent.operationGraphListener
@@ -74,6 +76,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
val parameterAttempt = request.getParameter("attempt")
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
+ val parameterTaskPage = request.getParameter("task.page")
+ val parameterTaskSortColumn = request.getParameter("task.sort")
+ val parameterTaskSortDesc = request.getParameter("task.desc")
+ val parameterTaskPageSize = request.getParameter("task.pageSize")
+
+ val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
+ val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index")
+ val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false)
+ val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100)
+
// If this is set, expand the dag visualization by default
val expandDagVizParam = request.getParameter("expandDagViz")
val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean
@@ -231,52 +243,47 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
accumulableRow,
accumulables.values.toSeq)
- val taskHeadersAndCssClasses: Seq[(String, String)] =
- Seq(
- ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""),
- ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""),
- ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY),
- ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
- ("GC Time", ""),
- ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
- ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
- {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
- {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++
- {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
- {if (stageData.hasShuffleRead) {
- Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
- ("Shuffle Read Size / Records", ""),
- ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE))
- } else {
- Nil
- }} ++
- {if (stageData.hasShuffleWrite) {
- Seq(("Write Time", ""), ("Shuffle Write Size / Records", ""))
- } else {
- Nil
- }} ++
- {if (stageData.hasBytesSpilled) {
- Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
- } else {
- Nil
- }} ++
- Seq(("Errors", ""))
-
- val unzipped = taskHeadersAndCssClasses.unzip
-
val currentTime = System.currentTimeMillis()
- val taskTable = UIUtils.listingTable(
- unzipped._1,
- taskRow(
+ val (taskTable, taskTableHTML) = try {
+ val _taskTable = new TaskPagedTable(
+ UIUtils.prependBaseUri(parent.basePath) +
+ s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
+ tasks,
hasAccumulators,
stageData.hasInput,
stageData.hasOutput,
stageData.hasShuffleRead,
stageData.hasShuffleWrite,
stageData.hasBytesSpilled,
- currentTime),
- tasks,
- headerClasses = unzipped._2)
+ currentTime,
+ pageSize = taskPageSize,
+ sortColumn = taskSortColumn,
+ desc = taskSortDesc
+ )
+ (_taskTable, _taskTable.table(taskPage))
+ } catch {
+ case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) =>
+ (null, <div class="alert alert-error">{e.getMessage}</div>)
+ }
+
+ val jsForScrollingDownToTaskTable =
+
+
+ val taskIdsInPage = if (taskTable == null) Set.empty[Long]
+ else taskTable.dataSource.slicedTaskIds
+
// Excludes tasks which failed and have incomplete metrics
val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined)
@@ -499,12 +506,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
dagViz ++
maybeExpandDagViz ++
showAdditionalMetrics ++
- makeTimeline(stageData.taskData.values.toSeq, currentTime) ++
+ makeTimeline(
+ // Only show the tasks in the table
+ stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)),
+ currentTime) ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
<h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++
maybeAccumulableTable ++
- <h4>Tasks</h4> ++ taskTable
+ <h4 id="tasks-section">Tasks</h4> ++ taskTableHTML ++ jsForScrollingDownToTaskTable
UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true)
}
}
@@ -679,164 +689,619 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
- def taskRow(
- hasAccumulators: Boolean,
- hasInput: Boolean,
- hasOutput: Boolean,
- hasShuffleRead: Boolean,
- hasShuffleWrite: Boolean,
- hasBytesSpilled: Boolean,
- currentTime: Long)(taskData: TaskUIData): Seq[Node] = {
- taskData match { case TaskUIData(info, metrics, errorMessage) =>
- val duration = if (info.status == "RUNNING") info.timeRunning(currentTime)
- else metrics.map(_.executorRunTime).getOrElse(1L)
- val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
- else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
- val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L)
- val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
- val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
- val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
- val gettingResultTime = getGettingResultTime(info, currentTime)
-
- val maybeAccumulators = info.accumulables
- val accumulatorsReadable = maybeAccumulators.map { acc =>
- StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
+}
+
+private[ui] object StagePage {
+ private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = {
+ if (info.gettingResult) {
+ if (info.finished) {
+ info.finishTime - info.gettingResultTime
+ } else {
+ // The task is still fetching the result.
+ currentTime - info.gettingResultTime
}
+ } else {
+ 0L
+ }
+ }
- val maybeInput = metrics.flatMap(_.inputMetrics)
- val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("")
- val inputReadable = maybeInput
- .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})")
- .getOrElse("")
- val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("")
-
- val maybeOutput = metrics.flatMap(_.outputMetrics)
- val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("")
- val outputReadable = maybeOutput
- .map(m => s"${Utils.bytesToString(m.bytesWritten)}")
- .getOrElse("")
- val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("")
-
- val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics)
- val shuffleReadBlockedTimeSortable = maybeShuffleRead
- .map(_.fetchWaitTime.toString).getOrElse("")
- val shuffleReadBlockedTimeReadable =
- maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("")
-
- val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead)
- val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("")
- val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("")
- val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("")
-
- val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead)
- val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("")
- val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("")
-
- val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics)
- val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("")
- val shuffleWriteReadable = maybeShuffleWrite
- .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("")
- val shuffleWriteRecords = maybeShuffleWrite
- .map(_.shuffleRecordsWritten.toString).getOrElse("")
-
- val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime)
- val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("")
- val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms =>
- if (ms == 0) "" else UIUtils.formatDuration(ms)
- }.getOrElse("")
-
- val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled)
- val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("")
- val memoryBytesSpilledReadable =
- maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("")
-
- val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled)
- val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("")
- val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("")
-
-
- {info.index} |
- {info.taskId} |
- {
- if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString
- } |
- {info.status} |
- {info.taskLocality} |
- {info.executorId} / {info.host} |
- {UIUtils.formatDate(new Date(info.launchTime))} |
-
- {formatDuration}
- |
-
- {UIUtils.formatDuration(schedulerDelay.toLong)}
- |
-
- {UIUtils.formatDuration(taskDeserializationTime.toLong)}
- |
-
- {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""}
- |
-
- {UIUtils.formatDuration(serializationTime)}
- |
-
- {UIUtils.formatDuration(gettingResultTime)}
- |
- {if (hasAccumulators) {
-
- {Unparsed(accumulatorsReadable.mkString(" "))}
- |
- }}
- {if (hasInput) {
-
- {s"$inputReadable / $inputRecords"}
- |
- }}
- {if (hasOutput) {
-
- {s"$outputReadable / $outputRecords"}
- |
- }}
+ private[ui] def getSchedulerDelay(
+ info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = {
+ if (info.finished) {
+ val totalExecutionTime = info.finishTime - info.launchTime
+ val executorOverhead = (metrics.executorDeserializeTime +
+ metrics.resultSerializationTime)
+ math.max(
+ 0,
+ totalExecutionTime - metrics.executorRunTime - executorOverhead -
+ getGettingResultTime(info, currentTime))
+ } else {
+ // The task is still running and the metrics like executorRunTime are not available.
+ 0L
+ }
+ }
+}
+
+private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String)
+
+private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String)
+
+private[ui] case class TaskTableRowShuffleReadData(
+ shuffleReadBlockedTimeSortable: Long,
+ shuffleReadBlockedTimeReadable: String,
+ shuffleReadSortable: Long,
+ shuffleReadReadable: String,
+ shuffleReadRemoteSortable: Long,
+ shuffleReadRemoteReadable: String)
+
+private[ui] case class TaskTableRowShuffleWriteData(
+ writeTimeSortable: Long,
+ writeTimeReadable: String,
+ shuffleWriteSortable: Long,
+ shuffleWriteReadable: String)
+
+private[ui] case class TaskTableRowBytesSpilledData(
+ memoryBytesSpilledSortable: Long,
+ memoryBytesSpilledReadable: String,
+ diskBytesSpilledSortable: Long,
+ diskBytesSpilledReadable: String)
+
+/**
+ * Contains all data needed for sorting and generating HTML. This is used instead of TaskUIData
+ * to avoid creating duplicate contents while sorting the data.
+ */
+private[ui] case class TaskTableRowData(
+ index: Int,
+ taskId: Long,
+ attempt: Int,
+ speculative: Boolean,
+ status: String,
+ taskLocality: String,
+ executorIdAndHost: String,
+ launchTime: Long,
+ duration: Long,
+ formatDuration: String,
+ schedulerDelay: Long,
+ taskDeserializationTime: Long,
+ gcTime: Long,
+ serializationTime: Long,
+ gettingResultTime: Long,
+ accumulators: Option[String], // HTML
+ input: Option[TaskTableRowInputData],
+ output: Option[TaskTableRowOutputData],
+ shuffleRead: Option[TaskTableRowShuffleReadData],
+ shuffleWrite: Option[TaskTableRowShuffleWriteData],
+ bytesSpilled: Option[TaskTableRowBytesSpilledData],
+ error: String)
+
+private[ui] class TaskDataSource(
+ tasks: Seq[TaskUIData],
+ hasAccumulators: Boolean,
+ hasInput: Boolean,
+ hasOutput: Boolean,
+ hasShuffleRead: Boolean,
+ hasShuffleWrite: Boolean,
+ hasBytesSpilled: Boolean,
+ currentTime: Long,
+ pageSize: Int,
+ sortColumn: String,
+ desc: Boolean) extends PagedDataSource[TaskTableRowData](pageSize) {
+ import StagePage._
+
+ // Convert TaskUIData to TaskTableRowData, which contains the final contents to show in the
+ // table, so that we avoid creating duplicate contents while sorting the data.
+ private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc))
+
+ private var _slicedTaskIds: Set[Long] = null
+
+ override def dataSize: Int = data.size
+
+ override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = {
+ val r = data.slice(from, to)
+ _slicedTaskIds = r.map(_.taskId).toSet
+ r
+ }
+
+ def slicedTaskIds: Set[Long] = _slicedTaskIds
+
+ private def taskRow(taskData: TaskUIData): TaskTableRowData = {
+ val TaskUIData(info, metrics, errorMessage) = taskData
+ val duration = if (info.status == "RUNNING") info.timeRunning(currentTime)
+ else metrics.map(_.executorRunTime).getOrElse(1L)
+ val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
+ else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
+ val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L)
+ val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
+ val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
+ val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
+ val gettingResultTime = getGettingResultTime(info, currentTime)
+
+ val maybeAccumulators = info.accumulables
+ val accumulatorsReadable = maybeAccumulators.map { acc =>
+ StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
+ }
+
+ val maybeInput = metrics.flatMap(_.inputMetrics)
+ val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
+ val inputReadable = maybeInput
+ .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})")
+ .getOrElse("")
+ val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("")
+
+ val maybeOutput = metrics.flatMap(_.outputMetrics)
+ val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L)
+ val outputReadable = maybeOutput
+ .map(m => s"${Utils.bytesToString(m.bytesWritten)}")
+ .getOrElse("")
+ val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("")
+
+ val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics)
+ val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L)
+ val shuffleReadBlockedTimeReadable =
+ maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("")
+
+ val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead)
+ val shuffleReadSortable = totalShuffleBytes.getOrElse(0L)
+ val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("")
+ val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("")
+
+ val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead)
+ val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L)
+ val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("")
+
+ val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics)
+ val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L)
+ val shuffleWriteReadable = maybeShuffleWrite
+ .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("")
+ val shuffleWriteRecords = maybeShuffleWrite
+ .map(_.shuffleRecordsWritten.toString).getOrElse("")
+
+ val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime)
+ val writeTimeSortable = maybeWriteTime.getOrElse(0L)
+ val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms =>
+ if (ms == 0) "" else UIUtils.formatDuration(ms)
+ }.getOrElse("")
+
+ val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled)
+ val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L)
+ val memoryBytesSpilledReadable =
+ maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("")
+
+ val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled)
+ val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L)
+ val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("")
+
+ val input =
+ if (hasInput) {
+ Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords"))
+ } else {
+ None
+ }
+
+ val output =
+ if (hasOutput) {
+ Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / $outputRecords"))
+ } else {
+ None
+ }
+
+ val shuffleRead =
+ if (hasShuffleRead) {
+ Some(TaskTableRowShuffleReadData(
+ shuffleReadBlockedTimeSortable,
+ shuffleReadBlockedTimeReadable,
+ shuffleReadSortable,
+ s"$shuffleReadReadable / $shuffleReadRecords",
+ shuffleReadRemoteSortable,
+ shuffleReadRemoteReadable
+ ))
+ } else {
+ None
+ }
+
+ val shuffleWrite =
+ if (hasShuffleWrite) {
+ Some(TaskTableRowShuffleWriteData(
+ writeTimeSortable,
+ writeTimeReadable,
+ shuffleWriteSortable,
+ s"$shuffleWriteReadable / $shuffleWriteRecords"
+ ))
+ } else {
+ None
+ }
+
+ val bytesSpilled =
+ if (hasBytesSpilled) {
+ Some(TaskTableRowBytesSpilledData(
+ memoryBytesSpilledSortable,
+ memoryBytesSpilledReadable,
+ diskBytesSpilledSortable,
+ diskBytesSpilledReadable
+ ))
+ } else {
+ None
+ }
+
+ TaskTableRowData(
+ info.index,
+ info.taskId,
+ info.attempt,
+ info.speculative,
+ info.status,
+ info.taskLocality.toString,
+ s"${info.executorId} / ${info.host}",
+ info.launchTime,
+ duration,
+ formatDuration,
+ schedulerDelay,
+ taskDeserializationTime,
+ gcTime,
+ serializationTime,
+ gettingResultTime,
+ if (hasAccumulators) Some(accumulatorsReadable.mkString("<br/>")) else None,
+ input,
+ output,
+ shuffleRead,
+ shuffleWrite,
+ bytesSpilled,
+ errorMessage.getOrElse("")
+ )
+ }
+
+ /**
+ * Return Ordering according to sortColumn and desc
+ */
+ private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = {
+ val ordering = sortColumn match {
+ case "Index" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Int.compare(x.index, y.index)
+ }
+ case "ID" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.taskId, y.taskId)
+ }
+ case "Attempt" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Int.compare(x.attempt, y.attempt)
+ }
+ case "Status" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.String.compare(x.status, y.status)
+ }
+ case "Locality Level" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.String.compare(x.taskLocality, y.taskLocality)
+ }
+ case "Executor ID / Host" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost)
+ }
+ case "Launch Time" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.launchTime, y.launchTime)
+ }
+ case "Duration" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.duration, y.duration)
+ }
+ case "Scheduler Delay" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay)
+ }
+ case "Task Deserialization Time" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime)
+ }
+ case "GC Time" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.gcTime, y.gcTime)
+ }
+ case "Result Serialization Time" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.serializationTime, y.serializationTime)
+ }
+ case "Getting Result Time" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime)
+ }
+ case "Accumulators" =>
+ if (hasAccumulators) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.String.compare(x.accumulators.get, y.accumulators.get)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Accumulators because of no accumulators")
+ }
+ case "Input Size / Records" =>
+ if (hasInput) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Input Size / Records because of no inputs")
+ }
+ case "Output Size / Records" =>
+ if (hasOutput) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Output Size / Records because of no outputs")
+ }
+ // ShuffleRead
+ case "Shuffle Read Blocked Time" =>
+ if (hasShuffleRead) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable,
+ y.shuffleRead.get.shuffleReadBlockedTimeSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads")
+ }
+ case "Shuffle Read Size / Records" =>
+ if (hasShuffleRead) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable,
+ y.shuffleRead.get.shuffleReadSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Shuffle Read Size / Records because of no shuffle reads")
+ }
+ case "Shuffle Remote Reads" =>
+ if (hasShuffleRead) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable,
+ y.shuffleRead.get.shuffleReadRemoteSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Shuffle Remote Reads because of no shuffle reads")
+ }
+ // ShuffleWrite
+ case "Write Time" =>
+ if (hasShuffleWrite) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable,
+ y.shuffleWrite.get.writeTimeSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Write Time because of no shuffle writes")
+ }
+ case "Shuffle Write Size / Records" =>
+ if (hasShuffleWrite) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable,
+ y.shuffleWrite.get.shuffleWriteSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Shuffle Write Size / Records because of no shuffle writes")
+ }
+ // BytesSpilled
+ case "Shuffle Spill (Memory)" =>
+ if (hasBytesSpilled) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable,
+ y.bytesSpilled.get.memoryBytesSpilledSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Shuffle Spill (Memory) because of no spills")
+ }
+ case "Shuffle Spill (Disk)" =>
+ if (hasBytesSpilled) {
+ new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable,
+ y.bytesSpilled.get.diskBytesSpilledSortable)
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Cannot sort by Shuffle Spill (Disk) because of no spills")
+ }
+ case "Errors" => new Ordering[TaskTableRowData] {
+ override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+ Ordering.String.compare(x.error, y.error)
+ }
+ case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
+ }
+ if (desc) {
+ ordering.reverse
+ } else {
+ ordering
+ }
+ }
+
+}
+
+private[ui] class TaskPagedTable(
+ basePath: String,
+ data: Seq[TaskUIData],
+ hasAccumulators: Boolean,
+ hasInput: Boolean,
+ hasOutput: Boolean,
+ hasShuffleRead: Boolean,
+ hasShuffleWrite: Boolean,
+ hasBytesSpilled: Boolean,
+ currentTime: Long,
+ pageSize: Int,
+ sortColumn: String,
+ desc: Boolean) extends PagedTable[TaskTableRowData] {
+
+ override def tableId: String = ""
+
+ override def tableCssClass: String = "table table-bordered table-condensed table-striped"
+
+ override val dataSource: TaskDataSource = new TaskDataSource(
+ data,
+ hasAccumulators,
+ hasInput,
+ hasOutput,
+ hasShuffleRead,
+ hasShuffleWrite,
+ hasBytesSpilled,
+ currentTime,
+ pageSize,
+ sortColumn,
+ desc
+ )
+
+ override def pageLink(page: Int): String = {
+ val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
+ s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" +
+ s"&task.pageSize=${pageSize}"
+ }
+
+ override def goButtonJavascriptFunction: (String, String) = {
+ val jsFuncName = "goToTaskPage"
+ val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
+ val jsFunc = s"""
+ |currentTaskPageSize = ${pageSize}
+ |function goToTaskPage(page, pageSize) {
+ | // Set page to 1 if the page size changes
+ | page = pageSize == currentTaskPageSize ? page : 1;
+ | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" +
+ | "&task.page=" + page + "&task.pageSize=" + pageSize;
+ | window.location.href = url;
+ |}
+ """.stripMargin
+ (jsFuncName, jsFunc)
+ }
+
+ def headers: Seq[Node] = {
+ val taskHeadersAndCssClasses: Seq[(String, String)] =
+ Seq(
+ ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""),
+ ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""),
+ ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY),
+ ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME),
+ ("GC Time", ""),
+ ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
+ ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
+ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
+ {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
+ {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
{if (hasShuffleRead) {
-
- {shuffleReadBlockedTimeReadable}
- |
-
- {s"$shuffleReadReadable / $shuffleReadRecords"}
- |
-
- {shuffleReadRemoteReadable}
- |
- }}
+ Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
+ ("Shuffle Read Size / Records", ""),
+ ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE))
+ } else {
+ Nil
+ }} ++
{if (hasShuffleWrite) {
-
- {writeTimeReadable}
- |
-
- {s"$shuffleWriteReadable / $shuffleWriteRecords"}
- |
- }}
+ Seq(("Write Time", ""), ("Shuffle Write Size / Records", ""))
+ } else {
+ Nil
+ }} ++
{if (hasBytesSpilled) {
-
- {memoryBytesSpilledReadable}
- |
-
- {diskBytesSpilledReadable}
- |
- }}
- {errorMessageCell(errorMessage)}
-
+ Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
+ } else {
+ Nil
+ }} ++
+ Seq(("Errors", ""))
+
+ if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) {
+ throw new IllegalArgumentException(s"Unknown column: $sortColumn")
}
+
+ val headerRow: Seq[Node] = {
+ taskHeadersAndCssClasses.map { case (header, cssClass) =>
+ if (header == sortColumn) {
+ val headerLink =
+ s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" +
+ s"&task.pageSize=${pageSize}"
+ val js = Unparsed(s"window.location.href='${headerLink}'")
+ val arrow = if (desc) "▾" else "▴" // UP or DOWN
+ <th class={cssClass} onclick={js} style="cursor: pointer;">
+ {header}
+ <span>{Unparsed(arrow)}</span>
+ </th>
+ } else {
+ val headerLink =
+ s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}"
+ val js = Unparsed(s"window.location.href='${headerLink}'")
+ <th class={cssClass} onclick={js} style="cursor: pointer;">
+ {header}
+ </th>
+ }
+ }
+ }
+ <thead>{headerRow}</thead>
+ }
+
+ def row(task: TaskTableRowData): Seq[Node] = {
+ <tr>
+ <td>{task.index}</td>
+ <td>{task.taskId}</td>
+ <td>{if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString}</td>
+ <td>{task.status}</td>
+ <td>{task.taskLocality}</td>
+ <td>{task.executorIdAndHost}</td>
+ <td>{UIUtils.formatDate(new Date(task.launchTime))}</td>
+ <td>{task.formatDuration}</td>
+ <td class={TaskDetailsClassNames.SCHEDULER_DELAY}>
+ {UIUtils.formatDuration(task.schedulerDelay)}
+ </td>
+ <td class={TaskDetailsClassNames.TASK_DESERIALIZATION_TIME}>
+ {UIUtils.formatDuration(task.taskDeserializationTime)}
+ </td>
+ <td>
+ {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""}
+ </td>
+ <td class={TaskDetailsClassNames.RESULT_SERIALIZATION_TIME}>
+ {UIUtils.formatDuration(task.serializationTime)}
+ </td>
+ <td class={TaskDetailsClassNames.GETTING_RESULT_TIME}>
+ {UIUtils.formatDuration(task.gettingResultTime)}
+ </td>
+ {if (task.accumulators.nonEmpty) {
+ <td>{Unparsed(task.accumulators.get)}</td>
+ }}
+ {if (task.input.nonEmpty) {
+ <td>{task.input.get.inputReadable}</td>
+ }}
+ {if (task.output.nonEmpty) {
+ <td>{task.output.get.outputReadable}</td>
+ }}
+ {if (task.shuffleRead.nonEmpty) {
+ <td class={TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME}>
+ {task.shuffleRead.get.shuffleReadBlockedTimeReadable}
+ </td>
+ <td>{task.shuffleRead.get.shuffleReadReadable}</td>
+ <td class={TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE}>
+ {task.shuffleRead.get.shuffleReadRemoteReadable}
+ </td>
+ }}
+ {if (task.shuffleWrite.nonEmpty) {
+ <td>{task.shuffleWrite.get.writeTimeReadable}</td>
+ <td>{task.shuffleWrite.get.shuffleWriteReadable}</td>
+ }}
+ {if (task.bytesSpilled.nonEmpty) {
+ <td>{task.bytesSpilled.get.memoryBytesSpilledReadable}</td>
+ <td>{task.bytesSpilled.get.diskBytesSpilledReadable}</td>
+ }}
+ {errorMessageCell(task.error)}
+ </tr>
}
- private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = {
- val error = errorMessage.getOrElse("")
+ private def errorMessageCell(error: String): Seq[Node] = {
val isMultiline = error.indexOf('\n') >= 0
// Display the first line by default
val errorSummary = StringEscapeUtils.escapeHtml4(
@@ -860,32 +1325,4 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
<td>{errorSummary}{details}</td>
}
-
- private def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = {
- if (info.gettingResult) {
- if (info.finished) {
- info.finishTime - info.gettingResultTime
- } else {
- // The task is still fetching the result.
- currentTime - info.gettingResultTime
- }
- } else {
- 0L
- }
- }
-
- private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = {
- if (info.finished) {
- val totalExecutionTime = info.finishTime - info.launchTime
- val executorOverhead = (metrics.executorDeserializeTime +
- metrics.resultSerializationTime)
- math.max(
- 0,
- totalExecutionTime - metrics.executorRunTime - executorOverhead -
- getGettingResultTime(info, currentTime))
- } else {
- // The task is still running and the metrics like executorRunTime are not available.
- 0L
- }
- }
}
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 43626b4ef4880..ebead830c6466 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging {
cls.getName.contains("$anonfun$")
}
- // Get a list of the classes of the outer objects of a given closure object, obj;
+ // Get the outer objects and their classes for a given closure object, obj;
// the outer objects are defined as any closures that obj is nested within, plus
// possibly the class that the outermost closure is in, if any. We stop searching
// for outer objects beyond that because cloning the user's object is probably
// not a good idea (whereas we can clone closure objects just fine since we
// understand how all their fields are used).
- private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+ private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = {
for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
f.setAccessible(true)
val outer = f.get(obj)
// The outer pointer may be null if we have cleaned this closure before
if (outer != null) {
if (isClosure(f.getType)) {
- return f.getType :: getOuterClasses(outer)
+ val recurRet = getOuterClassesAndObjects(outer)
+ return (f.getType :: recurRet._1, outer :: recurRet._2)
} else {
- return f.getType :: Nil // Stop at the first $outer that is not a closure
+ return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure
}
}
}
- Nil
+ (Nil, Nil)
}
-
- // Get a list of the outer objects for a given closure object.
- private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
- for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
- f.setAccessible(true)
- val outer = f.get(obj)
- // The outer pointer may be null if we have cleaned this closure before
- if (outer != null) {
- if (isClosure(f.getType)) {
- return outer :: getOuterObjects(outer)
- } else {
- return outer :: Nil // Stop at the first $outer that is not a closure
- }
- }
- }
- Nil
- }
-
/**
* Return a list of classes that represent closures enclosed in the given closure object.
*/
@@ -205,8 +188,7 @@ private[spark] object ClosureCleaner extends Logging {
// A list of enclosing objects and their respective classes, from innermost to outermost
// An outer object at a given index is of type outer class at the same index
- val outerClasses = getOuterClasses(func)
- val outerObjects = getOuterObjects(func)
+ val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
// For logging purposes only
val declaredFields = func.getClass.getDeclaredFields
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index adf69a4e78e71..a078f14af52a1 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -92,8 +92,8 @@ private[spark] object JsonProtocol {
executorRemovedToJson(executorRemoved)
case logStart: SparkListenerLogStart =>
logStartToJson(logStart)
- // These aren't used, but keeps compiler happy
- case SparkListenerExecutorMetricsUpdate(_, _) => JNothing
+ case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
+ executorMetricsUpdateToJson(metricsUpdate)
}
}
@@ -224,6 +224,19 @@ private[spark] object JsonProtocol {
("Spark Version" -> SPARK_VERSION)
}
+ def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = {
+ val execId = metricsUpdate.execId
+ val taskMetrics = metricsUpdate.taskMetrics
+ ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~
+ ("Executor ID" -> execId) ~
+ ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) =>
+ ("Task ID" -> taskId) ~
+ ("Stage ID" -> stageId) ~
+ ("Stage Attempt ID" -> stageAttemptId) ~
+ ("Task Metrics" -> taskMetricsToJson(metrics))
+ })
+ }
+
/** ------------------------------------------------------------------- *
* JSON serialization methods for classes SparkListenerEvents depend on |
* -------------------------------------------------------------------- */
@@ -463,6 +476,7 @@ private[spark] object JsonProtocol {
val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded)
val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved)
val logStart = Utils.getFormattedClassName(SparkListenerLogStart)
+ val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate)
(json \ "Event").extract[String] match {
case `stageSubmitted` => stageSubmittedFromJson(json)
@@ -481,6 +495,7 @@ private[spark] object JsonProtocol {
case `executorAdded` => executorAddedFromJson(json)
case `executorRemoved` => executorRemovedFromJson(json)
case `logStart` => logStartFromJson(json)
+ case `metricsUpdate` => executorMetricsUpdateFromJson(json)
}
}
@@ -598,6 +613,18 @@ private[spark] object JsonProtocol {
SparkListenerLogStart(sparkVersion)
}
+ def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = {
+ val execInfo = (json \ "Executor ID").extract[String]
+ val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json =>
+ val taskId = (json \ "Task ID").extract[Long]
+ val stageId = (json \ "Stage ID").extract[Int]
+ val stageAttemptId = (json \ "Stage Attempt ID").extract[Int]
+ val metrics = taskMetricsFromJson(json \ "Task Metrics")
+ (taskId, stageId, stageAttemptId, metrics)
+ }
+ SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics)
+ }
+
/** --------------------------------------------------------------------- *
* JSON deserialization methods for classes SparkListenerEvents depend on |
* ---------------------------------------------------------------------- */
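
Note: with the new serializer/deserializer pair, an executor metrics update can round-trip through the event log instead of being dropped. A hedged sketch of the envelope it reads and writes, using plain json4s for illustration only; the field names are taken from the code above and the "Task Metrics" payload is elided:

    import org.json4s._
    import org.json4s.jackson.JsonMethods._

    implicit val formats: Formats = DefaultFormats

    // Minimal example of the JSON written by executorMetricsUpdateToJson.
    val event = parse("""
      {"Event": "SparkListenerExecutorMetricsUpdate",
       "Executor ID": "exec3",
       "Metrics Updated": [{"Task ID": 1, "Stage ID": 2, "Stage Attempt ID": 3}]}
    """)

    val execId = (event \ "Executor ID").extract[String]           // "exec3"
    val perTask = (event \ "Metrics Updated").extract[List[JValue]]
    val taskIds = perTask.map(t => (t \ "Task ID").extract[Long])   // List(1L)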
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1e4531ef395ae..d166037351c31 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
@@ -470,14 +470,27 @@ class ExternalAppendOnlyMap[K, V, C](
item
}
- // TODO: Ensure this gets called even if the iterator isn't drained.
private def cleanup() {
batchIndex = batchOffsets.length // Prevent reading any other batch
val ds = deserializeStream
- deserializeStream = null
- fileStream = null
- ds.close()
- file.delete()
+ if (ds != null) {
+ ds.close()
+ deserializeStream = null
+ }
+ if (fileStream != null) {
+ fileStream.close()
+ fileStream = null
+ }
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+
+ val context = TaskContext.get()
+ // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in
+ // a TaskContext.
+ if (context != null) {
+ context.addTaskCompletionListener(context => cleanup())
}
}
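
Note: the fix above registers spill-file cleanup as a task completion listener instead of relying on the iterator being fully drained, and makes cleanup safe to call more than once. The pattern, sketched outside Spark with a toy listener interface (TaskContext.addTaskCompletionListener has the same shape):

    // Toy stand-in for TaskContext's completion-listener mechanism.
    class ToyTaskContext {
      private var listeners = List.empty[() => Unit]
      def addTaskCompletionListener(f: () => Unit): Unit = listeners ::= f
      def markTaskCompleted(): Unit = listeners.foreach(_.apply())
    }

    class SpillReader(file: java.io.File, context: ToyTaskContext) {
      private var stream: java.io.InputStream = new java.io.FileInputStream(file)

      private def cleanup(): Unit = {
        if (stream != null) { stream.close(); stream = null }  // idempotent close
        if (file.exists()) file.delete()                       // remove the spill file
      }

      // Cleanup now runs even if the caller never drains the iterator, because the
      // context fires completion listeners when the task ends.
      if (context != null) context.addTaskCompletionListener(() => cleanup())
    }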
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index dfd86d3e51e7d..1b04a3b1cff0e 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1011,7 +1011,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index af81e46a657d3..618a5fb24710f 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
// Local computation should not persist the resulting value, so don't expect a put().
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
- val context = new TaskContextImpl(0, 0, 0, 0, null, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 32f04d54eff94..3e8816a4c65be 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0, 0, null)
+ val tContext = new TaskContextImpl(0, 0, 0, 0, null, null)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index b9b0eccb0d834..9201d1e1f328b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -24,11 +24,27 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
+import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.metrics.source.JvmSource
class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
+ test("provide metrics sources") {
+ val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile
+ val conf = new SparkConf(loadDefaults = false)
+ .set("spark.metrics.conf", filePath)
+ sc = new SparkContext("local", "test", conf)
+ val rdd = sc.makeRDD(1 to 1)
+ val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => {
+ tc.getMetricsSources("jvm").count {
+ case source: JvmSource => true
+ case _ => false
+ }
+ }).sum
+ assert(result > 0)
+ }
+
test("calls TaskCompletionListener after failure") {
TaskContextSuite.completed = false
sc = new SparkContext("local", "test")
@@ -44,13 +60,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val task = new ResultTask[String, String](0, 0,
sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
- task.run(0, 0)
+ task.run(0, 0, null)
}
assert(TaskContextSuite.completed === true)
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0, 0, null)
+ val context = new TaskContextImpl(0, 0, 0, 0, null, null)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
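
Note: the new "provide metrics sources" test exercises the metrics-source lookup added to TaskContext. From user code inside a task the call looks roughly like the following sketch, which assumes a running SparkContext `sc` and a metrics configuration that registers a "jvm" source:

    import org.apache.spark.TaskContext

    // Count how many metric sources named "jvm" are visible to each task.
    val counts = sc.runJob(sc.parallelize(1 to 4, 2),
      (tc: TaskContext, it: Iterator[Int]) => tc.getMetricsSources("jvm").size)
    // With the JVM source enabled in metrics.properties, each element is >= 1.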
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
index 6c9cb448e7833..db718ecabbdb9 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
shuffleHandle,
reduceId,
reduceId + 1,
- new TaskContextImpl(0, 0, 0, 0, null),
+ new TaskContextImpl(0, 0, 0, 0, null, null),
blockManager,
mapOutputTracker)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 64f3fbdcebed9..cf8bd8ae69625 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0, 0, null),
+ new TaskContextImpl(0, 0, 0, 0, null, null),
transfer,
blockManager,
blocksByAddress,
@@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0, 0, null)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala
new file mode 100644
index 0000000000000..cc76c141c53cc
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.Node
+
+import org.apache.spark.SparkFunSuite
+
+class PagedDataSourceSuite extends SparkFunSuite {
+
+ test("basic") {
+ val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ assert(dataSource1.pageData(1) === PageData(3, (1 to 2)))
+
+ val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ assert(dataSource2.pageData(2) === PageData(3, (3 to 4)))
+
+ val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ assert(dataSource3.pageData(3) === PageData(3, Seq(5)))
+
+ val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ val e1 = intercept[IndexOutOfBoundsException] {
+ dataSource4.pageData(4)
+ }
+ assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.")
+
+ val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2)
+ val e2 = intercept[IndexOutOfBoundsException] {
+ dataSource5.pageData(0)
+ }
+ assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.")
+
+ }
+}
+
+class PagedTableSuite extends SparkFunSuite {
+ test("pageNavigation") {
+ // Create a fake PagedTable to test pageNavigation
+ val pagedTable = new PagedTable[Int] {
+ override def tableId: String = ""
+
+ override def tableCssClass: String = ""
+
+ override def dataSource: PagedDataSource[Int] = null
+
+ override def pageLink(page: Int): String = page.toString
+
+ override def headers: Seq[Node] = Nil
+
+ override def row(t: Int): Seq[Node] = Nil
+
+ override def goButtonJavascriptFunction: (String, String) = ("", "")
+ }
+
+ assert(pagedTable.pageNavigation(1, 10, 1) === Nil)
+ assert(
+ (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">"))
+ assert(
+ (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2"))
+
+ assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) ===
+ (1 to 10).map(_.toString) ++ Seq(">", ">>"))
+ assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>"))
+
+ assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (91 to 100).map(_.toString))
+ assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">"))
+
+ assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>"))
+ assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) ===
+ Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">"))
+ }
+}
+
+private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int)
+ extends PagedDataSource[T](pageSize) {
+
+ override protected def dataSize: Int = seq.size
+
+ override protected def sliceData(from: Int, to: Int): Seq[T] = seq.slice(from, to)
+}
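
Note: the page boundaries asserted above follow from simple arithmetic: with 5 items and a page size of 2 there are ceil(5 / 2) = 3 pages, and page p covers indices [(p - 1) * pageSize, min(p * pageSize, dataSize)). A standalone sketch of that slicing, independent of the PagedDataSource class whose internals are not shown in this patch:

    // Standalone paging helper mirroring the behaviour the test expects.
    def page[T](data: Seq[T], pageSize: Int, pageNumber: Int): (Int, Seq[T]) = {
      val totalPages = (data.size + pageSize - 1) / pageSize
      if (pageNumber < 1 || pageNumber > totalPages) {
        throw new IndexOutOfBoundsException(
          s"Page $pageNumber is out of range. " +
            s"Please select a page number between 1 and $totalPages.")
      }
      val from = (pageNumber - 1) * pageSize
      val to = math.min(from + pageSize, data.size)
      (totalPages, data.slice(from, to))
    }

    // page(1 to 5, 2, 2) yields (3, Seq(3, 4)), matching dataSource2 above.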
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
index 3147c937769d2..a829b099025e9 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
// Accessors for private methods
private val _isClosure = PrivateMethod[Boolean]('isClosure)
private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses)
- private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses)
- private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects)
+ private val _getOuterClassesAndObjects =
+ PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects)
private def isClosure(obj: AnyRef): Boolean = {
ClosureCleaner invokePrivate _isClosure(obj)
@@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
ClosureCleaner invokePrivate _getInnerClosureClasses(closure)
}
- private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
- ClosureCleaner invokePrivate _getOuterClasses(closure)
- }
-
- private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
- ClosureCleaner invokePrivate _getOuterObjects(closure)
+ private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = {
+ ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure)
}
test("get inner closure classes") {
@@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure2 = () => localValue
val closure3 = () => someSerializableValue
val closure4 = () => someSerializableMethod()
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerClasses4 = getOuterClasses(closure4)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
- val outerObjects3 = getOuterObjects(closure3)
- val outerObjects4 = getOuterObjects(closure4)
+
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3)
+ val (outerClasses4, outerObjects4) = getOuterClassesAndObjects(closure4)
// The classes and objects should have the same size
assert(outerClasses1.size === outerObjects1.size)
@@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val x = 1
val closure1 = () => 1
val closure2 = () => x
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
// These inner closures only reference local variables, and so do not have $outer pointers
@@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure1 = () => 1
val closure2 = () => y
val closure3 = () => localValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerObjects1 = getOuterObjects(closure1)
- val outerObjects2 = getOuterObjects(closure2)
- val outerObjects3 = getOuterObjects(closure3)
+ val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3)
assert(outerClasses1.size === outerObjects1.size)
assert(outerClasses2.size === outerObjects2.size)
assert(outerClasses3.size === outerObjects3.size)
@@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure1 = () => 1
val closure2 = () => localValue
val closure3 = () => someSerializableValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
+ val (outerClasses1, _) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, _) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, _) = getOuterClassesAndObjects(closure3)
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
@@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri
val closure2 = () => a
val closure3 = () => localValue
val closure4 = () => someSerializableValue
- val outerClasses1 = getOuterClasses(closure1)
- val outerClasses2 = getOuterClasses(closure2)
- val outerClasses3 = getOuterClasses(closure3)
- val outerClasses4 = getOuterClasses(closure4)
+ val (outerClasses1, _) = getOuterClassesAndObjects(closure1)
+ val (outerClasses2, _) = getOuterClassesAndObjects(closure2)
+ val (outerClasses3, _) = getOuterClassesAndObjects(closure3)
+ val (outerClasses4, _) = getOuterClassesAndObjects(closure4)
// First, find only fields accessed directly, not transitively, by these closures
val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
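
Note: the suite reaches the now-private combined helper through ScalaTest's PrivateMethodTester, the same mechanism it already used for the two separate accessors. A minimal hedged example of that pattern on a toy object, assuming ScalaTest's FunSuite:

    import org.scalatest.{FunSuite, PrivateMethodTester}

    object Greeter {
      private def greet(name: String): String = s"hello, $name"
    }

    class GreeterSuite extends FunSuite with PrivateMethodTester {
      test("invoke a private method") {
        // The symbol names the private method to look up reflectively.
        val _greet = PrivateMethod[String]('greet)
        assert((Greeter invokePrivate _greet("spark")) === "hello, spark")
      }
    }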
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index e0ef9c70a5fc3..dde95f3778434 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -83,6 +83,9 @@ class JsonProtocolSuite extends SparkFunSuite {
val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1",
new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap))
val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason")
+ val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq(
+ (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800,
+ hasHadoopInput = true, hasOutput = true))))
testEvent(stageSubmitted, stageSubmittedJsonString)
testEvent(stageCompleted, stageCompletedJsonString)
@@ -102,6 +105,7 @@ class JsonProtocolSuite extends SparkFunSuite {
testEvent(applicationEnd, applicationEndJsonString)
testEvent(executorAdded, executorAddedJsonString)
testEvent(executorRemoved, executorRemovedJsonString)
+ testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString)
}
test("Dependent Classes") {
@@ -440,10 +444,20 @@ class JsonProtocolSuite extends SparkFunSuite {
case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) =>
assertEquals(e1.environmentDetails, e2.environmentDetails)
case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) =>
- assert(e1.executorId == e1.executorId)
+ assert(e1.executorId === e2.executorId)
assertEquals(e1.executorInfo, e2.executorInfo)
case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) =>
- assert(e1.executorId == e1.executorId)
+ assert(e1.executorId === e2.executorId)
+ case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) =>
+ assert(e1.execId === e2.execId)
+ assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => {
+ val (taskId1, stageId1, stageAttemptId1, metrics1) = a
+ val (taskId2, stageId2, stageAttemptId2, metrics2) = b
+ assert(taskId1 === taskId2)
+ assert(stageId1 === stageId2)
+ assert(stageAttemptId1 === stageAttemptId2)
+ assertEquals(metrics1, metrics2)
+ })
case (e1, e2) =>
assert(e1 === e2)
case _ => fail("Events don't match in types!")
@@ -1598,4 +1612,55 @@ class JsonProtocolSuite extends SparkFunSuite {
| "Removed Reason": "test reason"
|}
"""
+
+ private val executorMetricsUpdateJsonString =
+ s"""
+ |{
+ | "Event": "SparkListenerExecutorMetricsUpdate",
+ | "Executor ID": "exec3",
+ | "Metrics Updated": [
+ | {
+ | "Task ID": 1,
+ | "Stage ID": 2,
+ | "Stage Attempt ID": 3,
+ | "Task Metrics": {
+ | "Host Name": "localhost",
+ | "Executor Deserialize Time": 300,
+ | "Executor Run Time": 400,
+ | "Result Size": 500,
+ | "JVM GC Time": 600,
+ | "Result Serialization Time": 700,
+ | "Memory Bytes Spilled": 800,
+ | "Disk Bytes Spilled": 0,
+ | "Input Metrics": {
+ | "Data Read Method": "Hadoop",
+ | "Bytes Read": 2100,
+ | "Records Read": 21
+ | },
+ | "Output Metrics": {
+ | "Data Write Method": "Hadoop",
+ | "Bytes Written": 1200,
+ | "Records Written": 12
+ | },
+ | "Updated Blocks": [
+ | {
+ | "Block ID": "rdd_0_0",
+ | "Status": {
+ | "Storage Level": {
+ | "Use Disk": true,
+ | "Use Memory": true,
+ | "Use ExternalBlockStore": false,
+ | "Deserialized": false,
+ | "Replication": 2
+ | },
+ | "Memory Size": 0,
+ | "ExternalBlockStore Size": 0,
+ | "Disk Size": 0
+ | }
+ | }
+ | ]
+ | }
+ | }]
+ |}
+ """.stripMargin
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index fe57d17f1ec14..280ae0e546358 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according to printf-style format strings
*/
-case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
+case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
require(children.nonEmpty, "printf() should take at least 1 argument")
@@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
private def format: Expression = children(0)
private def args: Seq[Expression] = children.tail
+ override def inputTypes: Seq[AbstractDataType] =
+ children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType)
+
+
override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
@@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
}
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val pattern = children.head.gen(ctx)
+
+ val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
+ val argListCode = argListGen.map(_._2.code + "\n")
+
+ val argListString = argListGen.foldLeft("")((s, v) => {
+ val nullSafeString =
+ if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
+ // Java primitives get boxed in order to allow null values.
+ s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
+ s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
+ } else {
+ s"(${v._2.isNull}) ? null : ${v._2.primitive}"
+ }
+ s + "," + nullSafeString
+ })
+
+ val form = ctx.freshName("formatter")
+ val formatter = classOf[java.util.Formatter].getName
+ val sb = ctx.freshName("sb")
+ val stringBuffer = classOf[StringBuffer].getName
+ s"""
+ ${pattern.code}
+ boolean ${ev.isNull} = ${pattern.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${argListCode.mkString}
+ $stringBuffer $sb = new $stringBuffer();
+ $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
+ $form.format(${pattern.primitive}.toString() $argListString);
+ ${ev.primitive} = UTF8String.fromString($sb.toString());
+ }
+ """
+ }
+
override def prettyName: String = "printf"
}
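
Note: the generated Java above mirrors what the interpreted eval path already does: a null pattern yields null, primitive arguments are boxed so that null arguments print as "null", and the result comes from java.util.Formatter with the US locale. A plain-Scala sketch of that runtime behaviour (not of the codegen itself):

    import java.util.{Formatter, Locale}

    // Roughly what the generated code computes for printf(pattern, args*).
    def sqlPrintf(pattern: String, args: Any*): String = {
      if (pattern == null) {
        null
      } else {
        val sb = new StringBuffer()
        // Boxing via AnyRef allows nulls to flow through and print as "null".
        new Formatter(sb, Locale.US)
          .format(pattern, args.map(_.asInstanceOf[AnyRef]): _*)
        sb.toString
      }
    }

    // sqlPrintf("aa%d%s", 123, "a")    == "aa123a"
    // sqlPrintf("aa%d%s", null, "cc")  == "aanullcc"  (matches the new tests below)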
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 506a447492dfc..3c2d88731beb4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("FORMAT") {
- val f = 'f.string.at(0)
- val d1 = 'd.int.at(1)
- val s1 = 's.string.at(2)
-
- val row1 = create_row("aa%d%s", 12, "cc")
- val row2 = create_row(null, 12, "cc")
- checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+ checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
- checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+ checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
+ checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
- checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
- checkEvaluation(StringFormat(f, d1, s1), null, row2)
+ checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
+ checkEvaluation(
+ StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
+ checkEvaluation(
+ StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}
test("INSTR") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index d1f855903ca4b..3702e73b4e74f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
+
+ val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+ Row("aa123cc", "aa123cc"))
+
+ checkAnswer(
+ df2.selectExpr("printf(a, b, c)"),
+ Row("aa123cc"))
}
test("string instr function") {