diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 716d604ca31b4..066512d159d00 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -214,7 +214,6 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We
val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"
val jobPage = Option(request.getParameter(jobTag + ".page")).map(_.toInt).getOrElse(1)
- val currentTime = System.currentTimeMillis()
try {
new JobPagedTable(
@@ -226,7 +225,6 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We
UIUtils.prependBaseUri(request, parent.basePath),
"jobs", // subPath
killEnabled,
- currentTime,
jobIdTitle
).table(jobPage)
} catch {
@@ -399,7 +397,6 @@ private[ui] class JobDataSource(
store: AppStatusStore,
jobs: Seq[v1.JobData],
basePath: String,
- currentTime: Long,
pageSize: Int,
sortColumn: String,
desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) {
@@ -410,15 +407,9 @@ private[ui] class JobDataSource(
// so that we can avoid creating duplicate contents during sorting the data
private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc))
- private var _slicedJobIds: Set[Int] = null
-
override def dataSize: Int = data.size
- override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = {
- val r = data.slice(from, to)
- _slicedJobIds = r.map(_.jobData.jobId).toSet
- r
- }
+ override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = data.slice(from, to)
private def jobRow(jobData: v1.JobData): JobTableRowData = {
val duration: Option[Long] = JobDataUtil.getDuration(jobData)
@@ -479,17 +470,17 @@ private[ui] class JobPagedTable(
basePath: String,
subPath: String,
killEnabled: Boolean,
- currentTime: Long,
jobIdTitle: String
) extends PagedTable[JobTableRowData] {
+
private val (sortColumn, desc, pageSize) = getTableParameters(request, jobTag, jobIdTitle)
private val parameterPath = basePath + s"/$subPath/?" + getParameterOtherTable(request, jobTag)
+ private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
override def tableId: String = jobTag + "-table"
override def tableCssClass: String =
- "table table-bordered table-sm table-striped " +
- "table-head-clickable table-cell-width-limited"
+ "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited"
override def pageSizeFormField: String = jobTag + ".pageSize"
@@ -499,13 +490,11 @@ private[ui] class JobPagedTable(
store,
data,
basePath,
- currentTime,
pageSize,
sortColumn,
desc)
override def pageLink(page: Int): String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
parameterPath +
s"&$pageNumberFormField=$page" +
s"&$jobTag.sort=$encodedSortColumn" +
@@ -514,10 +503,8 @@ private[ui] class JobPagedTable(
s"#$tableHeaderId"
}
- override def goButtonFormPath: String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
+ override def goButtonFormPath: String =
s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId"
- }
override def headers: Seq[Node] = {
// Information for each header: title, sortable, tooltip
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 1b072274541c8..47ba951953cec 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -212,7 +212,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
stageData,
UIUtils.prependBaseUri(request, parent.basePath) +
s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}",
- currentTime,
pageSize = taskPageSize,
sortColumn = taskSortColumn,
desc = taskSortDesc,
@@ -452,7 +451,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
private[ui] class TaskDataSource(
stage: StageData,
- currentTime: Long,
pageSize: Int,
sortColumn: String,
desc: Boolean,
@@ -474,8 +472,6 @@ private[ui] class TaskDataSource(
_tasksToShow
}
- def tasks: Seq[TaskData] = _tasksToShow
-
def executorLogs(id: String): Map[String, String] = {
executorIdToLogs.getOrElseUpdate(id,
store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty))
@@ -486,7 +482,6 @@ private[ui] class TaskDataSource(
private[ui] class TaskPagedTable(
stage: StageData,
basePath: String,
- currentTime: Long,
pageSize: Int,
sortColumn: String,
desc: Boolean,
@@ -494,6 +489,8 @@ private[ui] class TaskPagedTable(
import ApiHelper._
+ private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
+
override def tableId: String = "task-table"
override def tableCssClass: String =
@@ -505,14 +502,12 @@ private[ui] class TaskPagedTable(
override val dataSource: TaskDataSource = new TaskDataSource(
stage,
- currentTime,
pageSize,
sortColumn,
desc,
store)
override def pageLink(page: Int): String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
basePath +
s"&$pageNumberFormField=$page" +
s"&task.sort=$encodedSortColumn" +
@@ -520,10 +515,7 @@ private[ui] class TaskPagedTable(
s"&$pageSizeFormField=$pageSize"
}
- override def goButtonFormPath: String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
- s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc"
- }
+ override def goButtonFormPath: String = s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc"
def headers: Seq[Node] = {
import ApiHelper._
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index f9e84c2b2f4ec..9e6eb418fe134 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -116,8 +116,7 @@ private[ui] class StagePagedTable(
override def tableId: String = stageTag + "-table"
override def tableCssClass: String =
- "table table-bordered table-sm table-striped " +
- "table-head-clickable table-cell-width-limited"
+ "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited"
override def pageSizeFormField: String = stageTag + ".pageSize"
@@ -125,7 +124,9 @@ private[ui] class StagePagedTable(
private val (sortColumn, desc, pageSize) = getTableParameters(request, stageTag, "Stage Id")
- val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" +
+ private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
+
+ private val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" +
getParameterOtherTable(request, stageTag)
override val dataSource = new StageDataSource(
@@ -138,7 +139,6 @@ private[ui] class StagePagedTable(
)
override def pageLink(page: Int): String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
parameterPath +
s"&$pageNumberFormField=$page" +
s"&$stageTag.sort=$encodedSortColumn" +
@@ -147,10 +147,8 @@ private[ui] class StagePagedTable(
s"#$tableHeaderId"
}
- override def goButtonFormPath: String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
+ override def goButtonFormPath: String =
s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId"
- }
override def headers: Seq[Node] = {
// stageHeadersAndCssClasses has three parts: header title, sortable and tooltip information.
@@ -311,15 +309,9 @@ private[ui] class StageDataSource(
// table so that we can avoid creating duplicate contents during sorting the data
private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc))
- private var _slicedStageIds: Set[Int] = _
-
override def dataSize: Int = data.size
- override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = {
- val r = data.slice(from, to)
- _slicedStageIds = r.map(_.stageId).toSet
- r
- }
+ override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = data.slice(from, to)
private def stageRow(stageData: v1.StageData): StageTableRowData = {
val formattedSubmissionTime = stageData.submissionTime match {
@@ -350,7 +342,6 @@ private[ui] class StageDataSource(
val shuffleWrite = stageData.shuffleWriteBytes
val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else ""
-
new StageTableRowData(
stageData,
Some(stageData),
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 26bbff5e54d25..844d9b7cf2c27 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -487,6 +487,7 @@ private[spark] object JsonProtocol {
("Callsite" -> rddInfo.callSite) ~
("Parent IDs" -> parentIds) ~
("Storage Level" -> storageLevel) ~
+ ("Barrier" -> rddInfo.isBarrier) ~
("Number of Partitions" -> rddInfo.numPartitions) ~
("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~
("Memory Size" -> rddInfo.memSize) ~
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index 6191e41b4118f..54899bfcf34fa 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.concurrent.Eventually
import org.scalatest.time.SpanSugar._
import org.apache.spark._
+import org.apache.spark.internal.config
import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
@@ -37,10 +38,10 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
.setAppName("test-cluster")
.set(TEST_NO_STAGE_RETRY, true)
sc = new SparkContext(conf)
+ TestUtils.waitUntilExecutorsUp(sc, numWorker, 60000)
}
- // TODO (SPARK-31730): re-enable it
- ignore("global sync by barrier() call") {
+ test("global sync by barrier() call") {
initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
@@ -57,10 +58,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
}
test("share messages with allGather() call") {
- val conf = new SparkConf()
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -78,10 +76,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
}
test("throw exception if we attempt to synchronize with different blocking calls") {
- val conf = new SparkConf()
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -100,10 +95,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
}
test("successively sync with allGather and barrier") {
- val conf = new SparkConf()
- .setMaster("local-cluster[4, 1, 1024]")
- .setAppName("test-cluster")
- sc = new SparkContext(conf)
+ initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
@@ -129,8 +121,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
assert(times2.max - times2.min <= 1000)
}
- // TODO (SPARK-31730): re-enable it
- ignore("support multiple barrier() call within a single task") {
+ test("support multiple barrier() call within a single task") {
initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
@@ -285,6 +276,9 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
test("SPARK-31485: barrier stage should fail if only partial tasks are launched") {
initLocalClusterSparkContext(2)
+ // The delay timer needs to be reset whenever a task is scheduled; otherwise all the tasks
+ // could end up being scheduled at the ANY locality level.
+ sc.conf.set(config.LEGACY_LOCALITY_WAIT_RESET, true)
val rdd0 = sc.parallelize(Seq(0, 1, 2, 3), 2)
val dep = new OneToOneDependency[Int](rdd0)
// set up a barrier stage with 2 tasks and both tasks prefer executor 0 (only 1 core) for
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index 61ea21fa86c5a..7c23e4449f461 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.deploy.history.{EventLogFileReader, SingleEventLogFileWr
import org.apache.spark.deploy.history.EventLogTestHelper._
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.{EVENT_LOG_DIR, EVENT_LOG_ENABLED}
import org.apache.spark.io._
import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem}
import org.apache.spark.resource.ResourceProfile
@@ -100,6 +101,49 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
testStageExecutorMetricsEventLogging()
}
+ test("SPARK-31764: isBarrier should be logged in event log") {
+ val conf = new SparkConf()
+ conf.set(EVENT_LOG_ENABLED, true)
+ conf.set(EVENT_LOG_DIR, testDirPath.toString)
+ val sc = new SparkContext("local", "test-SPARK-31764", conf)
+ val appId = sc.applicationId
+
+ sc.parallelize(1 to 10)
+ .barrier()
+ .mapPartitions(_.map(elem => (elem, elem)))
+ .filter(elem => elem._1 % 2 == 0)
+ .reduceByKey(_ + _)
+ .collect
+ sc.stop()
+
+ val eventLogStream = EventLogFileReader.openEventLog(new Path(testDirPath, appId), fileSystem)
+ val events = readLines(eventLogStream).map(line => JsonProtocol.sparkEventFromJson(parse(line)))
+ val jobStartEvents = events
+ .filter(event => event.isInstanceOf[SparkListenerJobStart])
+ .map(_.asInstanceOf[SparkListenerJobStart])
+
+ assert(jobStartEvents.size === 1)
+ val stageInfos = jobStartEvents.head.stageInfos
+ assert(stageInfos.size === 2)
+
+ val stage0 = stageInfos(0)
+ val rddInfosInStage0 = stage0.rddInfos
+ assert(rddInfosInStage0.size === 3)
+ val sortedRddInfosInStage0 = rddInfosInStage0.sortBy(_.scope.get.name)
+ assert(sortedRddInfosInStage0(0).scope.get.name === "filter")
+ assert(sortedRddInfosInStage0(0).isBarrier === true)
+ assert(sortedRddInfosInStage0(1).scope.get.name === "mapPartitions")
+ assert(sortedRddInfosInStage0(1).isBarrier === true)
+ assert(sortedRddInfosInStage0(2).scope.get.name === "parallelize")
+ assert(sortedRddInfosInStage0(2).isBarrier === false)
+
+ val stage1 = stageInfos(1)
+ val rddInfosInStage1 = stage1.rddInfos
+ assert(rddInfosInStage1.size === 1)
+ assert(rddInfosInStage1(0).scope.get.name === "reduceByKey")
+ assert(rddInfosInStage1(0).isBarrier === false) // reduceByKey
+ }
+
/* ----------------- *
* Actual test logic *
* ----------------- */
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 5d34a56473375..3d52199b01327 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -98,7 +98,6 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
val taskTable = new TaskPagedTable(
stageData,
basePath = "/a/b/c",
- currentTime = 0,
pageSize = 10,
sortColumn = "Index",
desc = false,
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index bc7f8b5d719db..248142a5ad633 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -1100,6 +1100,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 201,
| "Number of Cached Partitions": 301,
| "Memory Size": 401,
@@ -1623,6 +1624,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 200,
| "Number of Cached Partitions": 300,
| "Memory Size": 400,
@@ -1668,6 +1670,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 400,
| "Number of Cached Partitions": 600,
| "Memory Size": 800,
@@ -1684,6 +1687,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 401,
| "Number of Cached Partitions": 601,
| "Memory Size": 801,
@@ -1729,6 +1733,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 600,
| "Number of Cached Partitions": 900,
| "Memory Size": 1200,
@@ -1745,6 +1750,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 601,
| "Number of Cached Partitions": 901,
| "Memory Size": 1201,
@@ -1761,6 +1767,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 602,
| "Number of Cached Partitions": 902,
| "Memory Size": 1202,
@@ -1806,6 +1813,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 800,
| "Number of Cached Partitions": 1200,
| "Memory Size": 1600,
@@ -1822,6 +1830,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 801,
| "Number of Cached Partitions": 1201,
| "Memory Size": 1601,
@@ -1838,6 +1847,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 802,
| "Number of Cached Partitions": 1202,
| "Memory Size": 1602,
@@ -1854,6 +1864,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Deserialized": true,
| "Replication": 1
| },
+ | "Barrier" : false,
| "Number of Partitions": 803,
| "Number of Cached Partitions": 1203,
| "Memory Size": 1603,
diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3
index 3c3ce2dcdd6d4..b5a10b5dba378 100644
--- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3
@@ -183,7 +183,6 @@ metrics-jmx/4.1.1//metrics-jmx-4.1.1.jar
metrics-json/4.1.1//metrics-json-4.1.1.jar
metrics-jvm/4.1.1//metrics-jvm-4.1.1.jar
minlog/1.3.0//minlog-1.3.0.jar
-mssql-jdbc/6.2.1.jre7//mssql-jdbc-6.2.1.jre7.jar
netty-all/4.1.47.Final//netty-all-4.1.47.Final.jar
nimbus-jose-jwt/4.41.1//nimbus-jose-jwt-4.41.1.jar
objenesis/2.5.1//objenesis-2.5.1.jar
diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py
index 72e32d4e16e14..13be9592d771f 100755
--- a/dev/run-tests-jenkins.py
+++ b/dev/run-tests-jenkins.py
@@ -198,7 +198,7 @@ def main():
# format: http://linux.die.net/man/1/timeout
# must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher
# then this. Please consult with the build manager or a committer when it should be increased.
- tests_timeout = "400m"
+ tests_timeout = "500m"
# Array to capture all test names to run on the pull request. These tests are represented
# by their file equivalents in the dev/tests/ directory.
diff --git a/docs/img/webui-structured-streaming-detail.png b/docs/img/webui-structured-streaming-detail.png
new file mode 100644
index 0000000000000..f4850523c5c2f
Binary files /dev/null and b/docs/img/webui-structured-streaming-detail.png differ
diff --git a/docs/sql-ref-datetime-pattern.md b/docs/sql-ref-datetime-pattern.md
index 4275f03335b33..48e85b450e6b2 100644
--- a/docs/sql-ref-datetime-pattern.md
+++ b/docs/sql-ref-datetime-pattern.md
@@ -76,7 +76,7 @@ The count of pattern letters determines the format.
- Year: The count of letters determines the minimum field width below which padding is used. If the count of letters is two, then a reduced two digit form is used. For printing, this outputs the rightmost two digits. For parsing, this will parse using the base value of 2000, resulting in a year within the range 2000 to 2099 inclusive. If the count of letters is less than four (but not two), then the sign is only output for negative years. Otherwise, the sign is output if the pad width is exceeded when 'G' is not present.
-- Month: If the number of pattern letters is 3 or more, the month is interpreted as text; otherwise, it is interpreted as a number. The text form is depend on letters - 'M' denotes the 'standard' form, and 'L' is for 'stand-alone' form. The difference between the 'standard' and 'stand-alone' forms is trickier to describe as there is no difference in English. However, in other languages there is a difference in the word used when the text is used alone, as opposed to in a complete date. For example, the word used for a month when used alone in a date picker is different to the word used for month in association with a day and year in a date. In Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters (more than 4 letters is invalid):
+- Month: It follows the rule of Number/Text. The text form depends on the letters used - 'M' denotes the 'standard' form, and 'L' is for the 'stand-alone' form. These two forms differ only in certain languages. For example, in Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters:
- `'M'` or `'L'`: Month number in a year starting from 1. There is no difference between 'M' and 'L'. Month from 1 to 9 are printed without padding.
```sql
spark-sql> select date_format(date '1970-01-01', "M");
@@ -107,8 +107,8 @@ The count of pattern letters determines the format.
```
- `'MMMM'`: full textual month representation in the standard form. It is used for parsing/formatting months as a part of dates/timestamps.
```sql
- spark-sql> select date_format(date '1970-01-01', "MMMM yyyy");
- January 1970
+ spark-sql> select date_format(date '1970-01-01', "d MMMM");
+ 1 January
spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'd MMMM', 'locale', 'RU'));
1 января
```
diff --git a/docs/web-ui.md b/docs/web-ui.md
index 3c35dbeec86a2..e2e612cef3e54 100644
--- a/docs/web-ui.md
+++ b/docs/web-ui.md
@@ -407,6 +407,34 @@ Here is the list of SQL metrics:
+## Structured Streaming Tab
+When running Structured Streaming jobs in micro-batch mode, a Structured Streaming tab will be
+available on the Web UI. The overview page displays some brief statistics for running and completed
+queries. You can also check the latest exception of a failed query. For detailed statistics,
+click the "run id" of a query in the tables.
+
+<p style="text-align: center;">
+  <img src="img/webui-structured-streaming-detail.png" title="Structured Streaming Query Statistics" alt="Structured Streaming Query Statistics">
+</p>
+
+The statistics page displays some useful metrics for insight into the status of your streaming
+queries. Currently, it contains the following metrics.
+
+* **Input Rate.** The aggregate (across all sources) rate of data arriving.
+* **Process Rate.** The aggregate (across all sources) rate at which Spark is processing data.
+* **Input Rows.** The aggregate (across all sources) number of records processed in a trigger.
+* **Batch Duration.** The process duration of each batch.
+* **Operation Duration.** The amount of time, in milliseconds, taken to perform the various operations.
+The tracked operations are listed below.
+ * addBatch: Adds result data of the current batch to the sink.
+ * getBatch: Gets a new batch of data to process.
+ * latestOffset: Gets the latest offsets for sources.
+ * queryPlanning: Generates the execution plan.
+ * walCommit: Writes the offsets to the metadata log.
+
+As an early-release version, the statistics page is still under development and will be improved in
+future releases.
+
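+As a quick way to populate this tab, here is a minimal micro-batch query sketch (it uses the
+built-in `rate` source and the `console` sink; the option value and trigger interval are purely
+illustrative) that can be pasted into `spark-shell`:
+
+```scala
+import org.apache.spark.sql.streaming.Trigger
+
+// Emit a small synthetic stream of (timestamp, value) rows and write every
+// micro-batch to the console; the running query and its completed runs then
+// appear on the Structured Streaming tab.
+val query = spark.readStream
+  .format("rate")
+  .option("rowsPerSecond", "10")
+  .load()
+  .writeStream
+  .format("console")
+  .trigger(Trigger.ProcessingTime("5 seconds"))
+  .start()
+
+// query.stop()  // stop the query once you are done exploring the UI
+```
+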
## Streaming Tab
The web UI includes a Streaming tab if the application uses Spark streaming. This tab displays
scheduling delay and processing time for each micro-batch in the data stream, which can be useful
diff --git a/external/avro/src/test/resources/before_1582_date_v2_4.avro b/external/avro/src/test/resources/before_1582_date_v2_4.avro
deleted file mode 100644
index 96aa7cbf176a5..0000000000000
Binary files a/external/avro/src/test/resources/before_1582_date_v2_4.avro and /dev/null differ
diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_5.avro b/external/avro/src/test/resources/before_1582_date_v2_4_5.avro
new file mode 100644
index 0000000000000..5c15601f7ee4b
Binary files /dev/null and b/external/avro/src/test/resources/before_1582_date_v2_4_5.avro differ
diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_6.avro b/external/avro/src/test/resources/before_1582_date_v2_4_6.avro
new file mode 100644
index 0000000000000..212ea1d5efa5c
Binary files /dev/null and b/external/avro/src/test/resources/before_1582_date_v2_4_6.avro differ
diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro
new file mode 100644
index 0000000000000..c3445e3999bc1
Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro differ
diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro
new file mode 100644
index 0000000000000..96008d2378b1f
Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro differ
diff --git a/external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro
similarity index 52%
rename from external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro
rename to external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro
index dbaec814eb954..be12a0782073c 100644
Binary files a/external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro and b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro differ
diff --git a/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro
new file mode 100644
index 0000000000000..262f5dd6e77a4
Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro differ
diff --git a/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro b/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro
deleted file mode 100644
index efe5e71a58813..0000000000000
Binary files a/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro and /dev/null differ
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index a5c1fb15add5c..e2ae489446d85 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.avro
import java.io._
import java.net.URL
-import java.nio.file.{Files, Paths}
+import java.nio.file.{Files, Paths, StandardCopyOption}
import java.sql.{Date, Timestamp}
import java.util.{Locale, UUID}
@@ -38,7 +38,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.TestingUDT.IntervalData
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA, UTC}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -1529,23 +1529,82 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession {
}
}
+ // It generates input files for the test below:
+ // "SPARK-31183: compatibility with Spark 2.4 in reading dates/timestamps"
+ ignore("SPARK-31855: generate test files for checking compatibility with Spark 2.4") {
+ val resourceDir = "external/avro/src/test/resources"
+ val version = "2_4_6"
+ def save(
+ in: Seq[String],
+ t: String,
+ dstFile: String,
+ options: Map[String, String] = Map.empty): Unit = {
+ withTempDir { dir =>
+ in.toDF("dt")
+ .select($"dt".cast(t))
+ .repartition(1)
+ .write
+ .mode("overwrite")
+ .options(options)
+ .format("avro")
+ .save(dir.getCanonicalPath)
+ Files.copy(
+ dir.listFiles().filter(_.getName.endsWith(".avro")).head.toPath,
+ Paths.get(resourceDir, dstFile),
+ StandardCopyOption.REPLACE_EXISTING)
+ }
+ }
+ withDefaultTimeZone(LA) {
+ withSQLConf(
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) {
+ save(
+ Seq("1001-01-01"),
+ "date",
+ s"before_1582_date_v$version.avro")
+ save(
+ Seq("1001-01-01 01:02:03.123"),
+ "timestamp",
+ s"before_1582_timestamp_millis_v$version.avro",
+ // scalastyle:off line.size.limit
+ Map("avroSchema" ->
+ s"""
+ | {
+ | "namespace": "logical",
+ | "type": "record",
+ | "name": "test",
+ | "fields": [
+ | {"name": "dt", "type": ["null", {"type": "long","logicalType": "timestamp-millis"}], "default": null}
+ | ]
+ | }
+ |""".stripMargin))
+ // scalastyle:on line.size.limit
+ save(
+ Seq("1001-01-01 01:02:03.123456"),
+ "timestamp",
+ s"before_1582_timestamp_micros_v$version.avro")
+ }
+ }
+ }
+
test("SPARK-31183: compatibility with Spark 2.4 in reading dates/timestamps") {
// test reading the existing 2.4 files and new 3.0 files (with rebase on/off) together.
- def checkReadMixedFiles(fileName: String, dt: String, dataStr: String): Unit = {
+ def checkReadMixedFiles(
+ fileName: String,
+ dt: String,
+ dataStr: String,
+ checkDefaultLegacyRead: String => Unit): Unit = {
withTempPaths(2) { paths =>
paths.foreach(_.delete())
val path2_4 = getResourceAvroFilePath(fileName)
val path3_0 = paths(0).getCanonicalPath
val path3_0_rebase = paths(1).getCanonicalPath
if (dt == "date") {
- val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("date"))
+ val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("dt"))
// By default we should fail to write ancient datetime values.
- var e = intercept[SparkException](df.write.format("avro").save(path3_0))
+ val e = intercept[SparkException](df.write.format("avro").save(path3_0))
assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException])
- // By default we should fail to read ancient datetime values.
- e = intercept[SparkException](spark.read.format("avro").load(path2_4).collect())
- assert(e.getCause.isInstanceOf[SparkUpgradeException])
+ checkDefaultLegacyRead(path2_4)
withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) {
df.write.format("avro").mode("overwrite").save(path3_0)
@@ -1562,25 +1621,23 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession {
1.to(3).map(_ => Row(java.sql.Date.valueOf(dataStr))))
}
} else {
- val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("ts"))
+ val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("dt"))
val avroSchema =
s"""
|{
| "type" : "record",
| "name" : "test_schema",
| "fields" : [
- | {"name": "ts", "type": {"type": "long", "logicalType": "$dt"}}
+ | {"name": "dt", "type": {"type": "long", "logicalType": "$dt"}}
| ]
|}""".stripMargin
// By default we should fail to write ancient datetime values.
- var e = intercept[SparkException] {
+ val e = intercept[SparkException] {
df.write.format("avro").option("avroSchema", avroSchema).save(path3_0)
}
assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException])
- // By default we should fail to read ancient datetime values.
- e = intercept[SparkException](spark.read.format("avro").load(path2_4).collect())
- assert(e.getCause.isInstanceOf[SparkUpgradeException])
+ checkDefaultLegacyRead(path2_4)
withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) {
df.write.format("avro").option("avroSchema", avroSchema).mode("overwrite").save(path3_0)
@@ -1600,11 +1657,33 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession {
}
}
- checkReadMixedFiles("before_1582_date_v2_4.avro", "date", "1001-01-01")
- checkReadMixedFiles(
- "before_1582_ts_micros_v2_4.avro", "timestamp-micros", "1001-01-01 01:02:03.123456")
- checkReadMixedFiles(
- "before_1582_ts_millis_v2_4.avro", "timestamp-millis", "1001-01-01 01:02:03.124")
+ def failInRead(path: String): Unit = {
+ val e = intercept[SparkException](spark.read.format("avro").load(path).collect())
+ assert(e.getCause.isInstanceOf[SparkUpgradeException])
+ }
+ def successInRead(path: String): Unit = spark.read.format("avro").load(path).collect()
+ Seq(
+ // By default we should fail to read ancient datetime values when Avro files don't
+ // contain the Spark version.
+ "2_4_5" -> failInRead _,
+ "2_4_6" -> successInRead _
+ ).foreach { case (version, checkDefaultRead) =>
+ checkReadMixedFiles(
+ s"before_1582_date_v$version.avro",
+ "date",
+ "1001-01-01",
+ checkDefaultRead)
+ checkReadMixedFiles(
+ s"before_1582_timestamp_micros_v$version.avro",
+ "timestamp-micros",
+ "1001-01-01 01:02:03.123456",
+ checkDefaultRead)
+ checkReadMixedFiles(
+ s"before_1582_timestamp_millis_v$version.avro",
+ "timestamp-millis",
+ "1001-01-01 01:02:03.123",
+ checkDefaultRead)
+ }
}
test("SPARK-31183: rebasing microseconds timestamps in write") {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index febeba7e13fcb..e0b128e369816 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml
import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params
val w = this match {
case p: HasWeightCol =>
if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
- col($(p.weightCol)).cast(DoubleType)
+ checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType)))
} else {
lit(1.0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 5459a0fab9135..e65295dbdaf55 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -22,6 +22,7 @@ import org.json4s.DefaultFormats
import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
@@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") (
}
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- col($(weightCol)).cast(DoubleType)
+ checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
} else {
lit(1.0)
}
@@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") (
import spark.implicits._
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- col($(weightCol)).cast(DoubleType)
+ checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
} else {
lit(1.0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 6c7112b80569f..b09f11dcfe156 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") (
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- col($(weightCol)).cast(DoubleType)
+ checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
} else {
lit(1.0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 6d4137b638dcc..18fd220b4ca9c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
@@ -417,7 +418,7 @@ class GaussianMixture @Since("2.0.0") (
instr.logNumFeatures(numFeatures)
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- col($(weightCol)).cast(DoubleType)
+ checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
} else {
lit(1.0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a42c920e24987..806015b633c23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") (
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- col($(weightCol)).cast(DoubleType)
+ checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
} else {
lit(1.0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index fac4d92b1810c..52be22f714981 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -131,7 +132,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
col($(rawPredictionCol)),
col($(labelCol)).cast(DoubleType),
if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
- else col($(weightCol)).cast(DoubleType)).rdd.map {
+ else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
case Row(rawPrediction: Vector, label: Double, weight: Double) =>
(rawPrediction(1), label, weight)
case Row(rawPrediction: Double, label: Double, weight: Double) =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index 19790fd270619..fa2c25a5912a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util._
@@ -139,7 +140,7 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
} else {
dataset.select(col($(predictionCol)),
vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
- col(weightColName).cast(DoubleType))
+ checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
}
val metrics = new ClusteringMetrics(df)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
index 8bf4ee1ecadfb..a785d063f1476 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -300,7 +300,6 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
(featureSum: DenseVector, squaredNormSum: Double, weightSum: Double),
(features, squaredNorm, weight)
) =>
- require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
BLAS.axpy(weight, features, featureSum)
(featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight)
},
@@ -503,7 +502,6 @@ private[evaluation] object CosineSilhouette extends Silhouette {
seqOp = {
case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
(normalizedFeatures, weight)) =>
- require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum)
(normalizedFeaturesSum, weightSum + weight)
},
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index ad1b70915e157..3d77792c4fc88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -186,7 +187,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
SchemaUtils.checkNumericType(schema, $(labelCol))
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
- col($(weightCol)).cast(DoubleType)
+ checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
} else {
lit(1.0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index aca017762deca..f0b7c345c3285 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
@@ -122,7 +123,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
val predictionAndLabelsWithWeights = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
- if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+ if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
+ else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
.rdd
.map { case Row(prediction: Double, label: Double, weight: Double) =>
(prediction, label, weight) }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 0f03231079866..a0b6d11a46be9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -71,4 +71,10 @@ object functions {
)
}
}
+
+ private[ml] def checkNonNegativeWeight = udf {
+ value: Double =>
+ require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
+ value
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index fa41a98749f32..0ee895a95a288 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.optim._
import org.apache.spark.ml.param._
@@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
"GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
"set to false. To fit a model with 0 features, fitIntercept must be set to true." )
- val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
+ val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol)))
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fe4de57de60f2..ec2640e9ef225 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
col($(featuresCol))
}
- val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0)
+ val w =
+ if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0)
dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
- case Row(label: Double, feature: Double, weight: Double) =>
- (label, feature, weight)
+ case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight)
}
}
diff --git a/pom.xml b/pom.xml
index fd4cebcd37319..deaf87f15539c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1357,6 +1357,10 @@
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP-java7</artifactId>
+ <exclusion>
+   <groupId>com.microsoft.sqlserver</groupId>
+   <artifactId>mssql-jdbc</artifactId>
+ </exclusion>
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b80149afa2af4..4f29f2f0be1e8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,7 +25,6 @@
from tempfile import NamedTemporaryFile
from py4j.protocol import Py4JError
-from py4j.java_gateway import is_instance_of
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -865,17 +864,10 @@ def union(self, rdds):
first_jrdd_deserializer = rdds[0]._jrdd_deserializer
if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
rdds = [x._reserialize() for x in rdds]
- gw = SparkContext._gateway
cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
- is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
- jrdds = gw.new_array(cls, len(rdds))
+ jrdds = SparkContext._gateway.new_array(cls, len(rdds))
for i in range(0, len(rdds)):
- if is_jrdd:
- jrdds[i] = rdds[i]._jrdd
- else:
- # zip could return JavaPairRDD hence we ensure `_jrdd`
- # to be `JavaRDD` by wrapping it in a `map`
- jrdds[i] = rdds[i].map(lambda x: x)._jrdd
+ jrdds[i] = rdds[i]._jrdd
return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
def broadcast(self, value):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a3ce87096e790..65b902cf3c4d5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2219,6 +2219,20 @@ def semanticHash(self):
"""
return self._jdf.semanticHash()
+ @since(3.1)
+ def inputFiles(self):
+ """
+ Returns a best-effort snapshot of the files that compose this :class:`DataFrame`.
+ This method simply asks each constituent BaseRelation for its respective files and
+ takes the union of all results. Depending on the source relations, this may not find
+ all input files. Duplicates are removed.
+
+ >>> df = spark.read.load("examples/src/main/resources/people.json", format="json")
+ >>> len(df.inputFiles())
+ 1
+ """
+ return list(self._jdf.inputFiles())
+
where = copy_func(
filter,
sinceversion=1.3,
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 4dd15d14b9c53..ff0b10a9306cf 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -154,6 +154,9 @@ def create_array(s, t):
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
+ elif type(s.dtype) == pd.CategoricalDtype:
+ # Note: This can be removed once the minimum pyarrow version is >= 0.16.1
+ s = s.astype(s.dtypes.categories.dtype)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index d1edf3f9c47c1..4b70c8a2e95e1 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -114,6 +114,8 @@ def from_arrow_type(at):
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in at])
+ elif types.is_dictionary(at):
+ spark_type = from_arrow_type(at.value_type)
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 004c79f290213..c59765dd79eb9 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -415,6 +415,33 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
for case in cases:
run_test(*case)
+ def test_createDataFrame_with_category_type(self):
+ pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
+ pdf["B"] = pdf["A"].astype('category')
+ category_first_element = dict(enumerate(pdf['B'].cat.categories))[0]
+
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
+ arrow_df = self.spark.createDataFrame(pdf)
+ arrow_type = arrow_df.dtypes[1][1]
+ result_arrow = arrow_df.toPandas()
+ arrow_first_category_element = result_arrow["B"][0]
+
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+ df = self.spark.createDataFrame(pdf)
+ spark_type = df.dtypes[1][1]
+ result_spark = df.toPandas()
+ spark_first_category_element = result_spark["B"][0]
+
+ assert_frame_equal(result_spark, result_arrow)
+
+ # ensure original category elements are string
+ self.assertIsInstance(category_first_element, str)
+ # both the Arrow-enabled and Arrow-disabled DataFrames must map the pandas category column to string
+ self.assertEqual(spark_type, 'string')
+ self.assertEqual(arrow_type, 'string')
+ self.assertIsInstance(arrow_first_category_element, str)
+ self.assertIsInstance(spark_first_category_element, str)
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 9861178158f85..062e61663a332 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -17,6 +17,8 @@
import os
import pydoc
+import shutil
+import tempfile
import time
import unittest
@@ -820,6 +822,22 @@ def test_same_semantics_error(self):
with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"):
self.spark.range(10).sameSemantics(1)
+ def test_input_files(self):
+ tpath = tempfile.mkdtemp()
+ shutil.rmtree(tpath)
+ try:
+ self.spark.range(1, 100, 1, 10).write.parquet(tpath)
+ # read parquet file and get the input files list
+ input_files_list = self.spark.read.parquet(tpath).inputFiles()
+
+ # input files list should contain 10 entries
+ self.assertEquals(len(input_files_list), 10)
+ # all file paths in list must contain tpath
+ for file_path in input_files_list:
+ self.assertTrue(tpath in file_path)
+ finally:
+ shutil.rmtree(tpath)
+
class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
# These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 7260e80e2cfca..2d38efd39f902 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -897,6 +897,24 @@ def test_timestamp_dst(self):
result = df.withColumn('time', foo_udf(df.time))
self.assertEquals(df.collect(), result.collect())
+ def test_udf_category_type(self):
+
+ @pandas_udf('string')
+ def to_category_func(x):
+ return x.astype('category')
+
+ pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
+ df = self.spark.createDataFrame(pdf)
+ df = df.withColumn("B", to_category_func(df['A']))
+ result_spark = df.toPandas()
+
+ spark_type = df.dtypes[1][1]
+ # the resulting Spark column type must be string, matching the declared pandas_udf return type
+ self.assertEqual(spark_type, 'string')
+
+ # Check result of column 'B' must be equal to column 'A' in type and values
+ pd.testing.assert_series_equal(result_spark["A"], result_spark["B"], check_names=False)
+
@unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
def test_type_annotation(self):
# Regression test to check if type hints can be used. See SPARK-23569.
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 04dfe68e57a3a..62ad4221d7078 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -168,15 +168,6 @@ def test_zip_chaining(self):
set([(x, (x, x)) for x in 'abc'])
)
- def test_union_pair_rdd(self):
- # Regression test for SPARK-31788
- rdd = self.sc.parallelize([1, 2])
- pair_rdd = rdd.zip(rdd)
- self.assertEqual(
- self.sc.union([pair_rdd, pair_rdd]).collect(),
- [((1, 1), (2, 2)), ((1, 1), (2, 2))]
- )
-
def test_deleting_input_files(self):
# Regression test for SPARK-1025
tempFile = tempfile.NamedTemporaryFile(delete=False)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d4799cace4531..e2559d4c07297 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -339,7 +339,7 @@ object FunctionRegistry {
expression[GetJsonObject]("get_json_object"),
expression[InitCap]("initcap"),
expression[StringInstr]("instr"),
- expression[Lower]("lcase"),
+ expression[Lower]("lcase", true),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
expression[Like]("like"),
@@ -350,7 +350,7 @@ object FunctionRegistry {
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
expression[ParseUrl]("parse_url"),
- expression[StringLocate]("position"),
+ expression[StringLocate]("position", true),
expression[FormatString]("printf", true),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index fe3fea5e35b1b..26f5bee72092c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
-import java.time.LocalDate
+import java.time.{Instant, LocalDate}
import scala.language.implicitConversions
@@ -152,6 +152,7 @@ package object dsl {
implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
implicit def decimalToLiteral(d: Decimal): Literal = Literal(d)
implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t)
+ implicit def instantToLiteral(i: Instant): Literal = Literal(i)
implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a)
implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 82d689477080d..f7fe467cea830 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -144,7 +144,7 @@ object TimeWindow {
case class PreciseTimestampConversion(
child: Expression,
fromType: DataType,
- toType: DataType) extends UnaryExpression with ExpectsInputTypes {
+ toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
override def dataType: DataType = toType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 7b819db32e425..342b14eaa3390 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
> SELECT _FUNC_ 0;
-1
""")
-case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class BitwiseNot(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
@@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
0
""",
since = "3.0.0")
-case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class BitwiseCount(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 4fd68dcfe5156..b32e9ee05f1ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -141,7 +141,7 @@ object Size {
""",
group = "map_funcs")
case class MapKeys(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI
""",
group = "map_funcs")
case class MapValues(child: Expression)
- extends UnaryExpression with ExpectsInputTypes {
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -361,7 +361,8 @@ case class MapValues(child: Expression)
""",
group = "map_funcs",
since = "3.0.0")
-case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class MapEntries(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
""",
group = "map_funcs",
since = "2.4.0")
-case class MapFromEntries(child: Expression) extends UnaryExpression {
+case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant {
@transient
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
@@ -873,7 +874,7 @@ object ArraySortLike {
group = "array_funcs")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
- extends BinaryExpression with ArraySortLike {
+ extends BinaryExpression with ArraySortLike with NullIntolerant {
def this(e: Expression) = this(e, Literal(true))
@@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
Reverse logic for arrays is available since 2.4.0.
"""
)
-case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Reverse(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))
@@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
""",
group = "array_funcs")
case class ArrayContains(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = BooleanType
@@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression)
since = "2.4.0")
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
- extends BinaryArrayExpressionWithImplicitCast {
+ extends BinaryArrayExpressionWithImplicitCast with NullIntolerant {
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
@@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
since = "2.4.0")
// scalastyle:on line.size.limit
case class Slice(x: Expression, start: Expression, length: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = x.dataType
@@ -1688,7 +1690,8 @@ case class ArrayJoin(
""",
group = "array_funcs",
since = "2.4.0")
-case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class ArrayMin(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def nullable: Boolean = true
@@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
""",
group = "array_funcs",
since = "2.4.0")
-case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class ArrayMax(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def nullable: Boolean = true
@@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
group = "array_funcs",
since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(right.dataType)
@@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression)
""",
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression)
- extends GetMapValueUtil with GetArrayItemUtil {
+ extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant {
@transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType
@@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
""",
group = "array_funcs",
since = "2.4.0")
-case class Flatten(child: Expression) extends UnaryExpression {
+case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant {
private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
@@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
group = "array_funcs",
since = "2.4.0")
case class ArrayRemove(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = left.dataType
@@ -3081,7 +3085,7 @@ trait ArraySetLike {
group = "array_funcs",
since = "2.4.0")
case class ArrayDistinct(child: Expression)
- extends UnaryExpression with ArraySetLike with ExpectsInputTypes {
+ extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
@@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression)
/**
* Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]].
*/
-trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike {
+trait ArrayBinaryLike
+ extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant {
override protected def dt: DataType = dataType
override protected def et: DataType = elementType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 858c91a4d8e86..1b4a705e804f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
@@ -255,7 +255,7 @@ object CreateMap {
{1.0:"2",3.0:"4"}
""", since = "2.4.0")
case class MapFromArrays(left: Expression, right: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
@@ -311,7 +311,12 @@ case object NamePlaceholder extends LeafExpression with Unevaluable {
/**
* Returns a Row containing the evaluation of all children expressions.
*/
-object CreateStruct extends FunctionBuilder {
+object CreateStruct {
+ /**
+ * Returns a named struct with generated field names, or with the existing names when
+ * available. It should not be used for `struct` expressions or functions explicitly
+ * called by users.
+ */
def apply(children: Seq[Expression]): CreateNamedStruct = {
CreateNamedStruct(children.zipWithIndex.flatMap {
case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e)
@@ -320,12 +325,23 @@ object CreateStruct extends FunctionBuilder {
})
}
+ /**
+ * Returns a named struct with a pretty SQL name. It shows the pretty SQL string
+ * in its output column name as if `struct(...)` were called. It should be
+ * used for `struct` expressions or functions explicitly called by users.
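+ *
+ * For example, calling `create` on two column expressions `a` and `b` tags the resulting
+ * `CreateNamedStruct` with `FUNC_ALIAS = "struct"`, so its SQL string renders as
+ * `struct(a, b)` rather than `named_struct(...)`.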
+ */
+ def create(children: Seq[Expression]): CreateNamedStruct = {
+ val expr = CreateStruct(children)
+ expr.setTagValue(FUNC_ALIAS, "struct")
+ expr
+ }
+
/**
* Entry to use in the function registry.
*/
val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = {
val info: ExpressionInfo = new ExpressionInfo(
- "org.apache.spark.sql.catalyst.expressions.NamedStruct",
+ classOf[CreateNamedStruct].getCanonicalName,
null,
"struct",
"_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.",
@@ -335,7 +351,7 @@ object CreateStruct extends FunctionBuilder {
"",
"",
"")
- ("struct", (info, this))
+ ("struct", (info, this.create))
}
}
@@ -433,7 +449,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
""".stripMargin, isNull = FalseLiteral)
}
- override def prettyName: String = "named_struct"
+ // An alias may have been set by `CreateStruct.create`. If so, this is the struct
+ // function explicitly called by a user and we should respect it in the SQL string
+ // as `struct(...)`.
+ override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct")
+
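+ // When the alias is set, render only the value expressions (the odd-indexed children,
+ // since `children` interleaves name literals and values), e.g. `struct(a, b)`.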
+ override def sql: String = getTagValue(FUNC_ALIAS).map { alias =>
+ val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ")
+ s"$alias($childrenSQL)"
+ }.getOrElse(super.sql)
}
/**
@@ -452,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
since = "2.0.1")
// scalastyle:on line.size.limit
case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
- extends TernaryExpression with ExpectsInputTypes {
+ extends TernaryExpression with ExpectsInputTypes with NullIntolerant {
def this(child: Expression, pairDelim: Expression) = {
this(child, pairDelim, Literal(":"))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 5140db90c5954..f9ccf3c8c811f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -211,7 +211,8 @@ case class StructsToCsv(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
+ with NullIntolerant {
override def nullable: Boolean = true
def this(options: Map[String, String], child: Expression) = this(options, child, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 7dc008a2e5df8..4f3db1b8a57ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -198,7 +198,7 @@ case class CurrentBatchTimestamp(
group = "datetime_funcs",
since = "1.5.0")
case class DateAdd(startDate: Expression, days: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = days
@@ -234,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression)
group = "datetime_funcs",
since = "1.5.0")
case class DateSub(startDate: Expression, days: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = days
@@ -266,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression)
group = "datetime_funcs",
since = "1.5.0")
case class Hour(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -298,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None)
group = "datetime_funcs",
since = "1.5.0")
case class Minute(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -330,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None)
group = "datetime_funcs",
since = "1.5.0")
case class Second(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -353,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None)
}
case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(child: Expression) = this(child, None)
@@ -385,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No
""",
group = "datetime_funcs",
since = "1.5.0")
-case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfYear(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -402,7 +407,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
}
abstract class NumberToTimestampBase extends UnaryExpression
- with ExpectsInputTypes {
+ with ExpectsInputTypes with NullIntolerant {
protected def upScaleFactor: Long
@@ -487,7 +492,8 @@ case class MicrosToTimestamp(child: Expression)
""",
group = "datetime_funcs",
since = "1.5.0")
-case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Year(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -503,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
}
}
-case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class YearOfWeek(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -528,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa
""",
group = "datetime_funcs",
since = "1.5.0")
-case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Quarter(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -553,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
""",
group = "datetime_funcs",
since = "1.5.0")
-case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Month(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -577,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
30
""",
since = "1.5.0")
-case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class DayOfMonth(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -647,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek {
}
}
-abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
+abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -665,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes {
group = "datetime_funcs",
since = "1.5.0")
// scalastyle:on line.size.limit
-case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class WeekOfYear(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -704,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
since = "1.5.0")
// scalastyle:on line.size.limit
case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None)
- extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(left: Expression, right: Expression) = this(left, right, None)
@@ -1154,7 +1166,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
""",
group = "datetime_funcs",
since = "1.5.0")
-case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class LastDay(startDate: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def child: Expression = startDate
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -1192,7 +1205,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
since = "1.5.0")
// scalastyle:on line.size.limit
case class NextDay(startDate: Expression, dayOfWeek: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = dayOfWeek
@@ -1248,7 +1261,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
* Adds an interval to timestamp.
*/
case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
- extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
+ extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant {
def this(start: Expression, interval: Expression) = this(start, interval, None)
@@ -1306,7 +1319,7 @@ case class DateAddInterval(
interval: Expression,
timeZoneId: Option[String] = None,
ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
- extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression {
+ extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant {
override def left: Expression = start
override def right: Expression = interval
@@ -1380,7 +1393,7 @@ case class DateAddInterval(
since = "1.5.0")
// scalastyle:on line.size.limit
case class FromUTCTimestamp(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
override def dataType: DataType = TimestampType
@@ -1440,7 +1453,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class AddMonths(startDate: Expression, numMonths: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = startDate
override def right: Expression = numMonths
@@ -1494,7 +1507,8 @@ case class MonthsBetween(
date2: Expression,
roundOff: Expression,
timeZoneId: Option[String] = None)
- extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None)
@@ -1552,7 +1566,7 @@ case class MonthsBetween(
since = "1.5.0")
// scalastyle:on line.size.limit
case class ToUTCTimestamp(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
override def dataType: DataType = TimestampType
@@ -1906,7 +1920,7 @@ case class TruncTimestamp(
group = "datetime_funcs",
since = "1.5.0")
case class DateDiff(endDate: Expression, startDate: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = endDate
override def right: Expression = startDate
@@ -1960,7 +1974,7 @@ private case class GetTimestamp(
group = "datetime_funcs",
since = "3.0.0")
case class MakeDate(year: Expression, month: Expression, day: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def children: Seq[Expression] = Seq(year, month, day)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType)
@@ -2031,7 +2045,8 @@ case class MakeTimestamp(
sec: Expression,
timezone: Option[Expression] = None,
timeZoneId: Option[String] = None)
- extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+ extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes
+ with NullIntolerant {
def this(
year: Expression,
@@ -2307,7 +2322,7 @@ case class Extract(field: Expression, source: Expression, child: Expression)
* between the given timestamps.
*/
case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = endTimestamp
override def right: Expression = startTimestamp
@@ -2328,7 +2343,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi
* Returns the interval from the `left` date (inclusive) to the `right` date (exclusive).
*/
case class SubtractDates(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType)
override def dataType: DataType = CalendarIntervalType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 9014ebfe2f96a..c2c70b2ab08e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
* Note: this expression is internal and created only by the optimizer,
* we don't need to do type check for it.
*/
-case class UnscaledValue(child: Expression) extends UnaryExpression {
+case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"
@@ -49,7 +49,7 @@ case class MakeDecimal(
child: Expression,
precision: Int,
scale: Int,
- nullOnOverflow: Boolean) extends UnaryExpression {
+ nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {
def this(child: Expression, precision: Int, scale: Int) = {
this(child, precision, scale, !SQLConf.get.ansiEnabled)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 4c8c58ae232f4..5e21b58f070ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
> SELECT _FUNC_('Spark');
8cde774d6f7333752ed72cacddb05126
""")
-case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Md5(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
@@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput
""")
// scalastyle:on line.size.limit
case class Sha2(left: Expression, right: Expression)
- extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def dataType: DataType = StringType
override def nullable: Boolean = true
@@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression)
> SELECT _FUNC_('Spark');
85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c
""")
-case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Sha1(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
@@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu
> SELECT _FUNC_('Spark');
1557323817
""")
-case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Crc32(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 1a569a7b89fe1..baab224691bc1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -31,7 +31,7 @@ abstract class ExtractIntervalPart(
val dataType: DataType,
func: CalendarInterval => Any,
funcName: String)
- extends UnaryExpression with ExpectsInputTypes with Serializable {
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType)
@@ -82,7 +82,7 @@ object ExtractIntervalPart {
abstract class IntervalNumOperation(
interval: Expression,
num: Expression)
- extends BinaryExpression with ImplicitCastInputTypes with Serializable {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def left: Expression = interval
override def right: Expression = num
@@ -160,7 +160,7 @@ case class MakeInterval(
hours: Expression,
mins: Expression,
secs: Expression)
- extends SeptenaryExpression with ImplicitCastInputTypes {
+ extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(
years: Expression,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 205e5271517c3..f4568f860ac0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -519,7 +519,8 @@ case class JsonToStructs(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes
+ with NullIntolerant {
// The JSON input data might be missing certain fields. We force the nullability
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
@@ -638,7 +639,8 @@ case class StructsToJson(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
- extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
+ extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback
+ with ExpectsInputTypes with NullIntolerant {
override def nullable: Boolean = true
def this(options: Map[String, String], child: Expression) = this(options, child, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 66e6334e3a450..fe8ea2a3c6733 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param name The short name of the function
*/
abstract class UnaryMathExpression(val f: Double => Double, name: String)
- extends UnaryExpression with Serializable with ImplicitCastInputTypes {
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
@@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String)
* @param name The short name of the function
*/
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
- extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
@@ -324,7 +324,7 @@ case class Acosh(child: Expression)
-16
""")
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
@@ -452,7 +452,8 @@ object Factorial {
> SELECT _FUNC_(5);
120
""")
-case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Factorial(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -491,7 +492,9 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas
> SELECT _FUNC_(1);
0.0
""")
-case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG")
+case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") {
+ override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln")
+}
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 2.",
@@ -546,6 +549,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p,
// scalastyle:on line.size.limit
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
override def funcName: String = "rint"
+ override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint")
}
@ExpressionDescription(
@@ -732,7 +736,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
""")
// scalastyle:on line.size.limit
case class Bin(child: Expression)
- extends UnaryExpression with Serializable with ImplicitCastInputTypes {
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = StringType
@@ -831,7 +835,8 @@ object Hex {
> SELECT _FUNC_('Spark SQL');
537061726B2053514C
""")
-case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Hex(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, BinaryType, StringType))
@@ -866,7 +871,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput
> SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8');
Spark SQL
""")
-case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Unhex(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
@@ -952,7 +958,7 @@ case class Pow(left: Expression, right: Expression)
4
""")
case class ShiftLeft(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -986,7 +992,7 @@ case class ShiftLeft(left: Expression, right: Expression)
2
""")
case class ShiftRight(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
@@ -1020,7 +1026,7 @@ case class ShiftRight(left: Expression, right: Expression)
2
""")
case class ShiftRightUnsigned(left: Expression, right: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType), IntegerType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 3f60ca388a807..28924fac48eef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
""",
since = "1.5.0")
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = ArrayType(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
@@ -433,7 +433,7 @@ object RegExpExtract {
""",
since = "1.5.0")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
// last regex in string, we will update the pattern iff regexp value changed.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 0b9fb8f85fe3c..334a079fc1892 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes {
""",
since = "1.0.1")
case class Upper(child: Expression)
- extends UnaryExpression with String2StringExpression {
+ extends UnaryExpression with String2StringExpression with NullIntolerant {
// scalastyle:off caselocale
override def convert(v: UTF8String): UTF8String = v.toUpperCase
@@ -356,7 +356,8 @@ case class Upper(child: Expression)
sparksql
""",
since = "1.0.1")
-case class Lower(child: Expression) extends UnaryExpression with String2StringExpression {
+case class Lower(child: Expression)
+ extends UnaryExpression with String2StringExpression with NullIntolerant {
// scalastyle:off caselocale
override def convert(v: UTF8String): UTF8String = v.toLowerCase
@@ -365,6 +366,9 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
}
+
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("lower")
}
/** A base trait for functions that compare two strings, returning a boolean. */
@@ -432,7 +436,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
since = "2.3.0")
// scalastyle:on line.size.limit
case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(srcExpr: Expression, searchExpr: Expression) = {
this(srcExpr, searchExpr, Literal(""))
@@ -598,7 +602,7 @@ object StringTranslate {
since = "1.5.0")
// scalastyle:on line.size.limit
case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
@transient private var lastMatching: UTF8String = _
@transient private var lastReplace: UTF8String = _
@@ -663,7 +667,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
since = "1.5.0")
// scalastyle:on line.size.limit
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
- with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
@@ -1035,7 +1039,7 @@ case class StringTrimRight(
since = "1.5.0")
// scalastyle:on line.size.limit
case class StringInstr(str: Expression, substr: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = str
override def right: Expression = substr
@@ -1077,7 +1081,7 @@ case class StringInstr(str: Expression, substr: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -1182,7 +1186,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
""")
}
- override def prettyName: String = "locate"
+ override def prettyName: String =
+ getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("locate")
}
/**
@@ -1205,7 +1210,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
""",
since = "1.5.0")
case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(str: Expression, len: Expression) = {
this(str, len, Literal(" "))
@@ -1246,7 +1251,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera
""",
since = "1.5.0")
case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" "))
- extends TernaryExpression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(str: Expression, len: Expression) = {
this(str, len, Literal(" "))
@@ -1536,7 +1541,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
Spark Sql
""",
since = "1.5.0")
-case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class InitCap(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[DataType] = Seq(StringType)
override def dataType: DataType = StringType
@@ -1563,7 +1569,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI
""",
since = "1.5.0")
case class StringRepeat(str: Expression, times: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = str
override def right: Expression = times
@@ -1593,7 +1599,7 @@ case class StringRepeat(str: Expression, times: Expression)
""",
since = "1.5.0")
case class StringSpace(child: Expression)
- extends UnaryExpression with ImplicitCastInputTypes {
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -1738,7 +1744,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run
""",
since = "1.5.0")
// scalastyle:on line.size.limit
-case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Length(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -1766,7 +1773,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn
72
""",
since = "2.3.0")
-case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class BitLength(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -1797,7 +1805,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas
9
""",
since = "2.3.0")
-case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class OctetLength(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -1828,7 +1837,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC
""",
since = "1.5.0")
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
- with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
@@ -1853,7 +1862,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
M460
""",
since = "1.5.0")
-case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class SoundEx(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def dataType: DataType = StringType
@@ -1879,7 +1889,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT
50
""",
since = "1.5.0")
-case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Ascii(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -1921,7 +1932,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
""",
since = "2.3.0")
// scalastyle:on line.size.limit
-case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Chr(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(LongType)
@@ -1964,7 +1976,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput
U3BhcmsgU1FM
""",
since = "1.5.0")
-case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Base64(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType)
@@ -1992,7 +2005,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
Spark SQL
""",
since = "1.5.0")
-case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class UnBase64(child: Expression)
+ extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType)
@@ -2024,7 +2038,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
since = "1.5.0")
// scalastyle:on line.size.limit
case class Decode(bin: Expression, charset: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = bin
override def right: Expression = charset
@@ -2064,7 +2078,7 @@ case class Decode(bin: Expression, charset: Expression)
since = "1.5.0")
// scalastyle:on line.size.limit
case class Encode(value: Expression, charset: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+ extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def left: Expression = value
override def right: Expression = charset
@@ -2108,7 +2122,7 @@ case class Encode(value: Expression, charset: Expression)
""",
since = "1.5.0")
case class FormatNumber(x: Expression, d: Expression)
- extends BinaryExpression with ExpectsInputTypes {
+ extends BinaryExpression with ExpectsInputTypes with NullIntolerant {
override def left: Expression = x
override def right: Expression = d
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 55e06cb9e8471..e08a10ecac71c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String
*
* This is not the world's most efficient implementation due to type conversion, but works.
*/
-abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+abstract class XPathExtract
+ extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant {
override def left: Expression = xml
override def right: Expression = path
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index b65221c236bfe..85c6600685bd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -208,3 +208,161 @@ object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with Predicat
}
}
}
+
+sealed abstract class BuildSide
+
+case object BuildRight extends BuildSide
+
+case object BuildLeft extends BuildSide
+
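+/**
+ * Shared helpers for choosing the build side of hash-based joins.
+ *
+ * A minimal usage sketch (the strategy name below is hypothetical, not part of this patch):
+ * {{{
+ *   object ExampleJoinSelection extends JoinSelectionHelper {
+ *     def broadcastSide(join: Join, conf: SQLConf): Option[BuildSide] =
+ *       getBroadcastBuildSide(join.left, join.right, join.joinType, join.hint,
+ *         hintOnly = true, conf)
+ *   }
+ * }}}
+ */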
+trait JoinSelectionHelper {
+
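+ /**
+ * Returns the side to broadcast, if any. When `hintOnly` is true only the join hints are
+ * consulted; otherwise a side qualifies when its estimated size allows broadcasting and no
+ * NO_BROADCAST_HASH hint disables it. If both sides qualify (and the join type allows it),
+ * the smaller side is chosen.
+ */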
+ def getBroadcastBuildSide(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ hint: JoinHint,
+ hintOnly: Boolean,
+ conf: SQLConf): Option[BuildSide] = {
+ val buildLeft = if (hintOnly) {
+ hintToBroadcastLeft(hint)
+ } else {
+ canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint)
+ }
+ val buildRight = if (hintOnly) {
+ hintToBroadcastRight(hint)
+ } else {
+ canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint)
+ }
+ getBuildSide(
+ canBuildLeft(joinType) && buildLeft,
+ canBuildRight(joinType) && buildRight,
+ left,
+ right
+ )
+ }
+
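+ /**
+ * Returns the side on which to build the hash map for a shuffled hash join, if any. When
+ * `hintOnly` is true only the SHUFFLE_HASH hints are consulted; otherwise a side qualifies
+ * when a single partition is small enough to build a local hash map and the side is much
+ * smaller than the other one.
+ */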
+ def getShuffleHashJoinBuildSide(
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ hint: JoinHint,
+ hintOnly: Boolean,
+ conf: SQLConf): Option[BuildSide] = {
+ val buildLeft = if (hintOnly) {
+ hintToShuffleHashJoinLeft(hint)
+ } else {
+ canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right)
+ }
+ val buildRight = if (hintOnly) {
+ hintToShuffleHashJoinRight(hint)
+ } else {
+ canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left)
+ }
+ getBuildSide(
+ canBuildLeft(joinType) && buildLeft,
+ canBuildRight(joinType) && buildRight,
+ left,
+ right
+ )
+ }
+
+ def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = {
+ if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
+ }
+
+ /**
+ * Matches a plan whose output should be small enough to be used in a broadcast join.
+ */
+ def canBroadcastBySize(plan: LogicalPlan, conf: SQLConf): Boolean = {
+ plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
+ }
+
+ def canBuildLeft(joinType: JoinType): Boolean = {
+ joinType match {
+ case _: InnerLike | RightOuter => true
+ case _ => false
+ }
+ }
+
+ def canBuildRight(joinType: JoinType): Boolean = {
+ joinType match {
+ case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
+ case _ => false
+ }
+ }
+
+ def hintToBroadcastLeft(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(BROADCAST))
+ }
+
+ def hintToBroadcastRight(hint: JoinHint): Boolean = {
+ hint.rightHint.exists(_.strategy.contains(BROADCAST))
+ }
+
+ def hintToNotBroadcastLeft(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH))
+ }
+
+ def hintToNotBroadcastRight(hint: JoinHint): Boolean = {
+ hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH))
+ }
+
+ def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH))
+ }
+
+ def hintToShuffleHashJoinRight(hint: JoinHint): Boolean = {
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH))
+ }
+
+ def hintToSortMergeJoin(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) ||
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
+ }
+
+ def hintToShuffleReplicateNL(hint: JoinHint): Boolean = {
+ hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) ||
+ hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
+ }
+
+ private def getBuildSide(
+ canBuildLeft: Boolean,
+ canBuildRight: Boolean,
+ left: LogicalPlan,
+ right: LogicalPlan): Option[BuildSide] = {
+ if (canBuildLeft && canBuildRight) {
+ // Returns the smaller side based on its estimated physical size if we want to build
+ // both sides.
+ Some(getSmallerSide(left, right))
+ } else if (canBuildLeft) {
+ Some(BuildLeft)
+ } else if (canBuildRight) {
+ Some(BuildRight)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Matches a plan whose single partition should be small enough to build a hash table.
+ *
+ * Note: this assumes that the number of partitions is fixed; additional work is required if
+ * it's dynamic.
+ */
+ private def canBuildLocalHashMapBySize(plan: LogicalPlan, conf: SQLConf): Boolean = {
+ plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+ }
+
+ /**
+ * Returns whether plan a is much smaller (3X) than plan b.
+ *
+ * The cost of building a hash map is higher than sorting, so we should only build a hash map
+ * on a table that is much smaller than the other one. Since we do not have row-count
+ * statistics, we use the size in bytes here as an estimation.
+ */
+ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
+ a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
+ }
+}
+
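The precedence implemented by `getBuildSide` above (prefer the smaller side when both qualify, otherwise whichever single side qualifies) can be shown with a standalone sketch; the helper name and plain types below are hypothetical stand-ins for `LogicalPlan` statistics and `JoinHint`, not part of the patch:

sealed trait Side
case object PickLeft extends Side
case object PickRight extends Side

// Hypothetical, simplified mirror of getBuildSide: sizes are plain BigInt estimates.
def pickBuildSide(
    canBuildLeft: Boolean,
    canBuildRight: Boolean,
    leftSize: BigInt,
    rightSize: BigInt): Option[Side] = {
  if (canBuildLeft && canBuildRight) {
    // Both sides qualify: build on the smaller estimated side.
    Some(if (rightSize <= leftSize) PickRight else PickLeft)
  } else if (canBuildLeft) {
    Some(PickLeft)
  } else if (canBuildRight) {
    Some(PickRight)
  } else {
    None
  }
}

// Example: only the right side is small enough, so it becomes the build side.
assert(pickBuildSide(canBuildLeft = false, canBuildRight = true, BigInt(20000000), BigInt(1000)) == Some(PickRight))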
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c0cecf8536c39..03571a740df3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1534,7 +1534,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
* Create a [[CreateStruct]] expression.
*/
override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
- CreateStruct(ctx.argument.asScala.map(expression))
+ CreateStruct.create(ctx.argument.asScala.map(expression))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
index 0ea54c28cb285..353c074caa75e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
@@ -217,9 +217,18 @@ private object DateTimeFormatterHelper {
toFormatter(builder, TimestampFormatter.defaultLocale)
}
+ private final val bugInStandAloneForm = {
+ // Java 8 has a bug for stand-alone form. See https://bugs.openjdk.java.net/browse/JDK-8114833
+ // Note: we only check the US locale so that the check stays static. It can produce false
+ // negatives, as some locales are not affected by the bug. Since `L`/`q` are rarely used, we
+ // choose not to complicate the check here.
+ // TODO: remove it when we drop Java 8 support.
+ val formatter = DateTimeFormatter.ofPattern("LLL qqq", Locale.US)
+ formatter.format(LocalDate.of(2000, 1, 1)) == "1 1"
+ }
final val unsupportedLetters = Set('A', 'c', 'e', 'n', 'N', 'p')
final val unsupportedNarrowTextStyle =
- Set("GGGGG", "MMMMM", "LLLLL", "EEEEE", "uuuuu", "QQQQQ", "qqqqq", "uuuuu")
+ Seq("G", "M", "L", "E", "u", "Q", "q").map(_ * 5).toSet
/**
* In Spark 3.0, we switch to the Proleptic Gregorian calendar and use DateTimeFormatter for
@@ -244,6 +253,12 @@ private object DateTimeFormatterHelper {
for (style <- unsupportedNarrowTextStyle if patternPart.contains(style)) {
throw new IllegalArgumentException(s"Too many pattern letters: ${style.head}")
}
+ if (bugInStandAloneForm && (patternPart.contains("LLL") || patternPart.contains("qqq"))) {
+ throw new IllegalArgumentException("Java 8 has a bug in supporting the stand-alone " +
+ "form (3 or more 'L' or 'q' in the pattern string). Please use 'M' or 'Q' instead, " +
+ "or upgrade your Java version. For more details, please read " +
+ "https://bugs.openjdk.java.net/browse/JDK-8114833")
+ }
// The meaning of 'u' was day number of week in SimpleDateFormat, it was changed to year
// in DateTimeFormatter. Substitute 'u' to 'e' and use DateTimeFormatter to parse the
// string. If parsable, return the result; otherwise, fall back to 'u', and then use the
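A quick way to see the symptom the static check above guards against (an illustrative snippet, not part of the patch; it assumes a JVM with a Scala REPL):

import java.time.LocalDate
import java.time.format.DateTimeFormatter
import java.util.Locale

// Stand-alone month ("LLL") and quarter ("qqq") patterns for January 2000, US locale.
val formatter = DateTimeFormatter.ofPattern("LLL qqq", Locale.US)
val out = formatter.format(LocalDate.of(2000, 1, 1))
// On a JDK 8 build affected by JDK-8114833 this prints "1 1" (numbers instead of text);
// on unaffected JDKs it prints the localized text forms, e.g. "Jan Q1".
println(s"output = '$out', affected = ${out == "1 1"}")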
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
index de2fd312b7db5..8428964d45707 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
@@ -127,16 +127,37 @@ class FractionTimestampFormatter(zoneId: ZoneId)
override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter
// The new formatter will omit the trailing 0 in the timestamp string, but the legacy formatter
- // can't. Here we borrow the code from Spark 2.4 DateTimeUtils.timestampToString to omit the
- // trailing 0 for the legacy formatter as well.
+ // can't. Here we use the legacy formatter to format the given timestamp up to the seconds, and
+ // a custom implementation to append the fractional part without trailing zeros.
override def format(ts: Timestamp): String = {
- val timestampString = ts.toString
val formatted = legacyFormatter.format(ts)
-
- if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
- formatted + timestampString.substring(19)
- } else {
+ var nanos = ts.getNanos
+ if (nanos == 0) {
formatted
+ } else {
+ // Formats non-zero seconds fraction w/o trailing zeros. For example:
+ // formatted = '2020-05-27 15:55:30'
+ // nanos = 001234000
+ // Counts the length of the fractional part: 001234000 -> 6
+ var fracLen = 9
+ while (nanos % 10 == 0) {
+ nanos /= 10
+ fracLen -= 1
+ }
+ // Places `nanos` = 1234 after '2020-05-27 15:55:30.'
+ val fracOffset = formatted.length + 1
+ val totalLen = fracOffset + fracLen
+ // The buffer for the final result: '2020-05-27 15:55:30.001234'
+ val buf = new Array[Char](totalLen)
+ formatted.getChars(0, formatted.length, buf, 0)
+ buf(formatted.length) = '.'
+ var i = totalLen
+ do {
+ i -= 1
+ buf(i) = ('0' + (nanos % 10)).toChar
+ nanos /= 10
+ } while (i > fracOffset)
+ new String(buf)
}
}
}
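The trailing-zero trimming above can be exercised in isolation; the following sketch (hypothetical helper name, plain String/Int arguments instead of java.sql.Timestamp) reproduces the same result using padding instead of a char buffer:

// Appends the non-zero nanosecond fraction without trailing zeros, e.g.
// appendFraction("1970-01-01 00:00:00", 1234000) == "1970-01-01 00:00:00.001234"
def appendFraction(formatted: String, nanosOfSecond: Int): String = {
  var nanos = nanosOfSecond
  if (nanos == 0) {
    formatted
  } else {
    var fracLen = 9
    while (nanos % 10 == 0) {
      nanos /= 10
      fracLen -= 1
    }
    // Left-pad the remaining digits to fracLen, e.g. 1234 with fracLen = 6 -> "001234".
    val frac = nanos.toString.reverse.padTo(fracLen, '0').reverse
    s"$formatted.$frac"
  }
}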
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e5bff7f7af007..6af995cab64fe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -240,7 +240,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkCast(1.5, "1.5")
checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
- checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
}
test("cast from string") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 02d6d847dc063..1ca7380ead413 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -792,7 +792,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
// Test escaping of format
- GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote")) :: Nil)
+ GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote"), UTC_OPT) :: Nil)
}
test("unix_timestamp") {
@@ -862,7 +862,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
// Test escaping of format
GenerateUnsafeProjection.generate(
- UnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil)
+ UnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil)
}
test("to_unix_timestamp") {
@@ -940,7 +940,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
// Test escaping of format
GenerateUnsafeProjection.generate(
- ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil)
+ ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil)
}
test("datediff") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala
new file mode 100644
index 0000000000000..3513cfa14808f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.AttributeMap
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH}
+import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
+import org.apache.spark.sql.internal.SQLConf
+
+class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper {
+
+ private val left = StatsTestPlan(
+ outputList = Seq('a.int, 'b.int, 'c.int),
+ rowCount = 20000000,
+ size = Some(20000000),
+ attributeStats = AttributeMap(Seq()))
+
+ private val right = StatsTestPlan(
+ outputList = Seq('d.int),
+ rowCount = 1000,
+ size = Some(1000),
+ attributeStats = AttributeMap(Seq()))
+
+ private val hintBroadcast = Some(HintInfo(Some(BROADCAST)))
+ private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH)))
+ private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH)))
+
+ test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a left hint") {
+ val broadcastSide = getBroadcastBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(hintBroadcast, None),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildLeft))
+ }
+
+ test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a right hint") {
+ val broadcastSide = getBroadcastBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, hintBroadcast),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildRight))
+ }
+
+ test("getBroadcastBuildSide (hintOnly = true) return smaller side with both having hints") {
+ val broadcastSide = getBroadcastBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(hintBroadcast, hintBroadcast),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildRight))
+ }
+
+ test("getBroadcastBuildSide (hintOnly = true) return None when no side has a hint") {
+ val broadcastSide = getBroadcastBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, None),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === None)
+ }
+
+ test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right is broadcastable") {
+ val broadcastSide = getBroadcastBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, None),
+ hintOnly = false,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildRight))
+ }
+
+ test("getBroadcastBuildSide (hintOnly = false) return None when right has no broadcast hint") {
+ val broadcastSide = getBroadcastBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, hintNotToBroadcast),
+ hintOnly = false,
+ SQLConf.get
+ )
+ assert(broadcastSide === None)
+ }
+
+ test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with only a left hint") {
+ val broadcastSide = getShuffleHashJoinBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(hintShuffleHash, None),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildLeft))
+ }
+
+ test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with only a right hint") {
+ val broadcastSide = getShuffleHashJoinBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, hintShuffleHash),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildRight))
+ }
+
+ test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when both have hints") {
+ val broadcastSide = getShuffleHashJoinBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(hintShuffleHash, hintShuffleHash),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildRight))
+ }
+
+ test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side has a hint") {
+ val broadcastSide = getShuffleHashJoinBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, None),
+ hintOnly = true,
+ SQLConf.get
+ )
+ assert(broadcastSide === None)
+ }
+
+ test("getShuffleHashJoinBuildSide (hintOnly = false) return BuildRight when right is smaller") {
+ val broadcastSide = getShuffleHashJoinBuildSide(
+ left,
+ right,
+ Inner,
+ JoinHint(None, None),
+ hintOnly = false,
+ SQLConf.get
+ )
+ assert(broadcastSide === Some(BuildRight))
+ }
+
+ test("getSmallerSide should return BuildRight") {
+ assert(getSmallerSide(left, right) === BuildRight)
+ }
+
+ test("canBroadcastBySize should return true if the plan size is less than 10MB") {
+ assert(canBroadcastBySize(left, SQLConf.get) === false)
+ assert(canBroadcastBySize(right, SQLConf.get) === true)
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
index 4324d3cff63d7..7ff9b46bc6719 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
@@ -135,6 +135,9 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers
test("format fraction of second") {
val formatter = TimestampFormatter.getFractionFormatter(UTC)
Seq(
+ -999999 -> "1969-12-31 23:59:59.000001",
+ -999900 -> "1969-12-31 23:59:59.0001",
+ -1 -> "1969-12-31 23:59:59.999999",
0 -> "1970-01-01 00:00:00",
1 -> "1970-01-01 00:00:00.000001",
1000 -> "1970-01-01 00:00:00.001",
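The new negative inputs cover timestamps before the epoch. As a quick sanity check (illustrative snippet, not part of the suite), java.time splits a negative instant into a smaller epoch second plus a positive nanosecond-of-second, which is why -1 microsecond renders as 1969-12-31 23:59:59.999999:

import java.time.Instant

val micros = -1L
val instant = Instant.EPOCH.plusNanos(micros * 1000)
// epochSecond = -1, nano-of-second = 999999000
println(instant) // 1969-12-31T23:59:59.999999Z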
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 1df812d1aa809..89915d254883d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -22,6 +22,7 @@ import java.util.UUID
import org.apache.hadoop.fs.Path
+import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker}
@@ -50,7 +51,7 @@ import org.apache.spark.util.Utils
class QueryExecution(
val sparkSession: SparkSession,
val logical: LogicalPlan,
- val tracker: QueryPlanningTracker = new QueryPlanningTracker) {
+ val tracker: QueryPlanningTracker = new QueryPlanningTracker) extends Logging {
// TODO: Move the planner an optimizer into here from SessionState.
protected def planner = sparkSession.sessionState.planner
@@ -133,26 +134,42 @@ class QueryExecution(
tracker.measurePhase(phase)(block)
}
- def simpleString: String = simpleString(false)
-
- def simpleString(formatted: Boolean): String = withRedaction {
+ def simpleString: String = {
val concat = new PlanStringConcat()
- concat.append("== Physical Plan ==\n")
+ simpleString(false, SQLConf.get.maxToStringFields, concat.append)
+ withRedaction {
+ concat.toString
+ }
+ }
+
+ private def simpleString(
+ formatted: Boolean,
+ maxFields: Int,
+ append: String => Unit): Unit = {
+ append("== Physical Plan ==\n")
if (formatted) {
try {
- ExplainUtils.processPlan(executedPlan, concat.append)
+ ExplainUtils.processPlan(executedPlan, append)
} catch {
- case e: AnalysisException => concat.append(e.toString)
- case e: IllegalArgumentException => concat.append(e.toString)
+ case e: AnalysisException => append(e.toString)
+ case e: IllegalArgumentException => append(e.toString)
}
} else {
- QueryPlan.append(executedPlan, concat.append, verbose = false, addSuffix = false)
+ QueryPlan.append(executedPlan,
+ append, verbose = false, addSuffix = false, maxFields = maxFields)
}
- concat.append("\n")
- concat.toString
+ append("\n")
}
def explainString(mode: ExplainMode): String = {
+ val concat = new PlanStringConcat()
+ explainString(mode, SQLConf.get.maxToStringFields, concat.append)
+ withRedaction {
+ concat.toString
+ }
+ }
+
+ private def explainString(mode: ExplainMode, maxFields: Int, append: String => Unit): Unit = {
val queryExecution = if (logical.isStreaming) {
// This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
// output mode does not matter since there is no `Sink`.
@@ -165,19 +182,19 @@ class QueryExecution(
mode match {
case SimpleMode =>
- queryExecution.simpleString
+ queryExecution.simpleString(false, maxFields, append)
case ExtendedMode =>
- queryExecution.toString
+ queryExecution.toString(maxFields, append)
case CodegenMode =>
try {
- org.apache.spark.sql.execution.debug.codegenString(queryExecution.executedPlan)
+ org.apache.spark.sql.execution.debug.writeCodegen(append, queryExecution.executedPlan)
} catch {
- case e: AnalysisException => e.toString
+ case e: AnalysisException => append(e.toString)
}
case CostMode =>
- queryExecution.stringWithStats
+ queryExecution.stringWithStats(maxFields, append)
case FormattedMode =>
- queryExecution.simpleString(formatted = true)
+ queryExecution.simpleString(formatted = true, maxFields = maxFields, append)
}
}
@@ -204,27 +221,39 @@ class QueryExecution(
override def toString: String = withRedaction {
val concat = new PlanStringConcat()
- writePlans(concat.append, SQLConf.get.maxToStringFields)
- concat.toString
+ toString(SQLConf.get.maxToStringFields, concat.append)
+ withRedaction {
+ concat.toString
+ }
+ }
+
+ private def toString(maxFields: Int, append: String => Unit): Unit = {
+ writePlans(append, maxFields)
}
- def stringWithStats: String = withRedaction {
+ def stringWithStats: String = {
val concat = new PlanStringConcat()
+ stringWithStats(SQLConf.get.maxToStringFields, concat.append)
+ withRedaction {
+ concat.toString
+ }
+ }
+
+ private def stringWithStats(maxFields: Int, append: String => Unit): Unit = {
-    val maxFields = SQLConf.get.maxToStringFields
// trigger to compute stats for logical plans
try {
optimizedPlan.stats
} catch {
- case e: AnalysisException => concat.append(e.toString + "\n")
+ case e: AnalysisException => append(e.toString + "\n")
}
// only show optimized logical plan and physical plan
- concat.append("== Optimized Logical Plan ==\n")
- QueryPlan.append(optimizedPlan, concat.append, verbose = true, addSuffix = true, maxFields)
- concat.append("\n== Physical Plan ==\n")
- QueryPlan.append(executedPlan, concat.append, verbose = true, addSuffix = false, maxFields)
- concat.append("\n")
- concat.toString
+ append("== Optimized Logical Plan ==\n")
+ QueryPlan.append(optimizedPlan, append, verbose = true, addSuffix = true, maxFields)
+ append("\n== Physical Plan ==\n")
+ QueryPlan.append(executedPlan, append, verbose = true, addSuffix = false, maxFields)
+ append("\n")
}
/**
@@ -261,19 +290,26 @@ class QueryExecution(
/**
* Dumps debug information about query execution into the specified file.
*
+ * @param path path of the file the debug info is written to.
* @param maxFields maximum number of fields converted to string representation.
+ * @param explainMode the explain mode to be used to generate the string
+ * representation of the plan.
*/
- def toFile(path: String, maxFields: Int = Int.MaxValue): Unit = {
+ def toFile(
+ path: String,
+ maxFields: Int = Int.MaxValue,
+ explainMode: Option[String] = None): Unit = {
val filePath = new Path(path)
val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf())
val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath)))
- val append = (s: String) => {
- writer.write(s)
- }
try {
- writePlans(append, maxFields)
- writer.write("\n== Whole Stage Codegen ==\n")
- org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan)
+ val mode = explainMode.map(ExplainMode.fromString(_)).getOrElse(ExtendedMode)
+ explainString(mode, maxFields, writer.write)
+ if (mode != CodegenMode) {
+ writer.write("\n== Whole Stage Codegen ==\n")
+ org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan)
+ }
+ log.info(s"Debug information was written at: $filePath")
} finally {
writer.close()
}
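With the new signature, callers can choose the explain mode used for the dump. A hypothetical usage, assuming an active SparkSession `spark` (the output path is illustrative):

val df = spark.range(100).selectExpr("id % 3 AS k", "id").groupBy("k").count()
// Writes the plan in "formatted" explain mode; omitting explainMode keeps the extended output.
df.queryExecution.toFile("/tmp/query-debug.txt", explainMode = Some("formatted"))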
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 12a1a1e7fc16e..302aae08d588b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper, NormalizeFloatingNumbers}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
@@ -135,93 +134,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Supports both equi-joins and non-equi-joins.
* Supports only inner like joins.
*/
- object JoinSelection extends Strategy with PredicateHelper {
-
- /**
- * Matches a plan whose output should be small enough to be used in broadcast join.
- */
- private def canBroadcast(plan: LogicalPlan): Boolean = {
- plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
- }
-
- /**
- * Matches a plan whose single partition should be small enough to build a hash table.
- *
- * Note: this assume that the number of partition is fixed, requires additional work if it's
- * dynamic.
- */
- private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
- plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
- }
-
- /**
- * Returns whether plan a is much smaller (3X) than plan b.
- *
- * The cost to build hash map is higher than sorting, we should only build hash map on a table
- * that is much smaller than other one. Since we does not have the statistic for number of rows,
- * use the size of bytes here as estimation.
- */
- private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
- a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
- }
-
- private def canBuildRight(joinType: JoinType): Boolean = joinType match {
- case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
- case _ => false
- }
-
- private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
- case _: InnerLike | RightOuter => true
- case _ => false
- }
-
- private def getBuildSide(
- wantToBuildLeft: Boolean,
- wantToBuildRight: Boolean,
- left: LogicalPlan,
- right: LogicalPlan): Option[BuildSide] = {
- if (wantToBuildLeft && wantToBuildRight) {
- // returns the smaller side base on its estimated physical size, if we want to build the
- // both sides.
- Some(getSmallerSide(left, right))
- } else if (wantToBuildLeft) {
- Some(BuildLeft)
- } else if (wantToBuildRight) {
- Some(BuildRight)
- } else {
- None
- }
- }
-
- private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = {
- if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
- }
-
- private def hintToBroadcastLeft(hint: JoinHint): Boolean = {
- hint.leftHint.exists(_.strategy.contains(BROADCAST))
- }
-
- private def hintToBroadcastRight(hint: JoinHint): Boolean = {
- hint.rightHint.exists(_.strategy.contains(BROADCAST))
- }
-
- private def hintToShuffleHashLeft(hint: JoinHint): Boolean = {
- hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH))
- }
-
- private def hintToShuffleHashRight(hint: JoinHint): Boolean = {
- hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH))
- }
-
- private def hintToSortMergeJoin(hint: JoinHint): Boolean = {
- hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) ||
- hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
- }
-
- private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = {
- hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) ||
- hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
- }
+ object JoinSelection extends Strategy
+ with PredicateHelper
+ with JoinSelectionHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
@@ -245,33 +160,31 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have
// other choice.
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) =>
- def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
- val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
- val wantToBuildRight = canBuildRight(joinType) && buildRight
- getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide =>
- Seq(joins.BroadcastHashJoinExec(
- leftKeys,
- rightKeys,
- joinType,
- buildSide,
- condition,
- planLater(left),
- planLater(right)))
+ def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = {
+ getBroadcastBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map {
+ buildSide =>
+ Seq(joins.BroadcastHashJoinExec(
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide,
+ condition,
+ planLater(left),
+ planLater(right)))
}
}
- def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = {
- val wantToBuildLeft = canBuildLeft(joinType) && buildLeft
- val wantToBuildRight = canBuildRight(joinType) && buildRight
- getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide =>
- Seq(joins.ShuffledHashJoinExec(
- leftKeys,
- rightKeys,
- joinType,
- buildSide,
- condition,
- planLater(left),
- planLater(right)))
+ def createShuffleHashJoin(onlyLookingAtHint: Boolean) = {
+ getShuffleHashJoinBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map {
+ buildSide =>
+ Seq(joins.ShuffledHashJoinExec(
+ leftKeys,
+ rightKeys,
+ joinType,
+ buildSide,
+ condition,
+ planLater(left),
+ planLater(right)))
}
}
@@ -293,14 +206,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def createJoinWithoutHint() = {
- createBroadcastHashJoin(
- canBroadcast(left) && !hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)),
- canBroadcast(right) && !hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH)))
+ createBroadcastHashJoin(false)
.orElse {
if (!conf.preferSortMergeJoin) {
- createShuffleHashJoin(
- canBuildLocalHashMap(left) && muchSmaller(left, right),
- canBuildLocalHashMap(right) && muchSmaller(right, left))
+ createShuffleHashJoin(false)
} else {
None
}
@@ -315,9 +224,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint))
+ createBroadcastHashJoin(true)
.orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None }
- .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint)))
+ .orElse(createShuffleHashJoin(true))
.orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
.getOrElse(createJoinWithoutHint())
@@ -374,7 +283,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def createJoinWithoutHint() = {
- createBroadcastNLJoin(canBroadcast(left), canBroadcast(right))
+ createBroadcastNLJoin(canBroadcastBySize(left, conf), canBroadcastBySize(right, conf))
.orElse(createCartesianProduct())
.getOrElse {
// This join could be very slow or OOM
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index eb5fcd3b24227..bc924e6978ddc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -156,7 +156,7 @@ case class AdaptiveSparkPlanExec(
var currentLogicalPlan = currentPhysicalPlan.logicalLink.get
var result = createQueryStages(currentPhysicalPlan)
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
- val errors = new mutable.ArrayBuffer[SparkException]()
+ val errors = new mutable.ArrayBuffer[Throwable]()
var stagesToReplace = Seq.empty[QueryStageExec]
while (!result.allChildStagesMaterialized) {
currentPhysicalPlan = result.newPlan
@@ -176,9 +176,7 @@ case class AdaptiveSparkPlanExec(
}(AdaptiveSparkPlanExec.executionContext)
} catch {
case e: Throwable =>
- val ex = new SparkException(
- s"Early failed query stage found: ${stage.treeString}", e)
- cleanUpAndThrowException(Seq(ex), Some(stage.id))
+ cleanUpAndThrowException(Seq(e), Some(stage.id))
}
}
}
@@ -191,10 +189,9 @@ case class AdaptiveSparkPlanExec(
events.drainTo(rem)
(Seq(nextMsg) ++ rem.asScala).foreach {
case StageSuccess(stage, res) =>
- stage.resultOption = Some(res)
+ stage.resultOption.set(Some(res))
case StageFailure(stage, ex) =>
- errors.append(
- new SparkException(s"Failed to materialize query stage: ${stage.treeString}.", ex))
+ errors.append(ex)
}
// In case of errors, we cancel all running stages and throw exception.
@@ -328,11 +325,11 @@ case class AdaptiveSparkPlanExec(
context.stageCache.get(e.canonicalized) match {
case Some(existingStage) if conf.exchangeReuseEnabled =>
val stage = reuseQueryStage(existingStage, e)
- // This is a leaf stage and is not materialized yet even if the reused exchange may has
- // been completed. It will trigger re-optimization later and stage materialization will
- // finish in instant if the underlying exchange is already completed.
+ val isMaterialized = stage.resultOption.get().isDefined
CreateStageResult(
- newPlan = stage, allChildStagesMaterialized = false, newStages = Seq(stage))
+ newPlan = stage,
+ allChildStagesMaterialized = isMaterialized,
+ newStages = if (isMaterialized) Seq.empty else Seq(stage))
case _ =>
val result = createQueryStages(e.child)
@@ -349,10 +346,11 @@ case class AdaptiveSparkPlanExec(
newStage = reuseQueryStage(queryStage, e)
}
}
-
- // We've created a new stage, which is obviously not ready yet.
- CreateStageResult(newPlan = newStage,
- allChildStagesMaterialized = false, newStages = Seq(newStage))
+ val isMaterialized = newStage.resultOption.get().isDefined
+ CreateStageResult(
+ newPlan = newStage,
+ allChildStagesMaterialized = isMaterialized,
+ newStages = if (isMaterialized) Seq.empty else Seq(newStage))
} else {
CreateStageResult(newPlan = newPlan,
allChildStagesMaterialized = false, newStages = result.newStages)
@@ -361,7 +359,7 @@ case class AdaptiveSparkPlanExec(
case q: QueryStageExec =>
CreateStageResult(newPlan = q,
- allChildStagesMaterialized = q.resultOption.isDefined, newStages = Seq.empty)
+ allChildStagesMaterialized = q.resultOption.get().isDefined, newStages = Seq.empty)
case _ =>
if (plan.children.isEmpty) {
@@ -537,31 +535,28 @@ case class AdaptiveSparkPlanExec(
* materialization errors and stage cancellation errors.
*/
private def cleanUpAndThrowException(
- errors: Seq[SparkException],
+ errors: Seq[Throwable],
earlyFailedStage: Option[Int]): Unit = {
- val runningStages = currentPhysicalPlan.collect {
+ currentPhysicalPlan.foreach {
// earlyFailedStage is the stage which failed before calling doMaterialize,
// so we should avoid calling cancel on it to re-trigger the failure again.
- case s: QueryStageExec if !earlyFailedStage.contains(s.id) => s
- }
- val cancelErrors = new mutable.ArrayBuffer[SparkException]()
- try {
- runningStages.foreach { s =>
+ case s: QueryStageExec if !earlyFailedStage.contains(s.id) =>
try {
s.cancel()
} catch {
case NonFatal(t) =>
- cancelErrors.append(
- new SparkException(s"Failed to cancel query stage: ${s.treeString}", t))
+ logError(s"Exception in cancelling query stage: ${s.treeString}", t)
}
- }
- } finally {
- val ex = new SparkException(
- "Adaptive execution failed due to stage materialization failures.", errors.head)
- errors.tail.foreach(ex.addSuppressed)
- cancelErrors.foreach(ex.addSuppressed)
- throw ex
+ case _ =>
+ }
+ val e = if (errors.size == 1) {
+ errors.head
+ } else {
+ val se = new SparkException("Multiple failures in stage materialization.", errors.head)
+ errors.tail.foreach(se.addSuppressed)
+ se
}
+ throw e
}
}
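The error handling above collapses stage failures into a single exception to throw. The same pattern in isolation (a generic exception type is used here instead of SparkException):

// Rethrows a single failure as-is; wraps multiple failures, attaching the rest as suppressed.
def combineErrors(errors: Seq[Throwable]): Throwable = {
  require(errors.nonEmpty, "expected at least one error")
  if (errors.size == 1) {
    errors.head
  } else {
    val e = new RuntimeException("Multiple failures in stage materialization.", errors.head)
    errors.tail.foreach(e.addSuppressed)
    e
  }
}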
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala
index 0f2868e41cc39..aba83b1337109 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
case class DemoteBroadcastHashJoin(conf: SQLConf) extends Rule[LogicalPlan] {
private def shouldDemote(plan: LogicalPlan): Boolean = plan match {
- case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.resultOption.isDefined
+ case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.resultOption.get().isDefined
&& stage.mapStats.isDefined =>
val mapStats = stage.mapStats.get
val partitionCnt = mapStats.bytesByPartitionId.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
index d60c3ca72f6f6..ac98342277bc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
/**
* Strategy for plans containing [[LogicalQueryStage]] nodes:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 5416fde222cb6..3620f27058af2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.adaptive
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.internal.SQLConf
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index f414f854b92ae..4e83b4344fbf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.adaptive
import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.{Future, Promise}
@@ -25,6 +26,7 @@ import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
@@ -82,7 +84,7 @@ abstract class QueryStageExec extends LeafExecNode {
/**
* Compute the statistics of the query stage if executed, otherwise None.
*/
- def computeStats(): Option[Statistics] = resultOption.map { _ =>
+ def computeStats(): Option[Statistics] = resultOption.get().map { _ =>
// Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`.
val exchange = plan match {
case r: ReusedExchangeExec => r.child
@@ -94,7 +96,9 @@ abstract class QueryStageExec extends LeafExecNode {
@transient
@volatile
- private[adaptive] var resultOption: Option[Any] = None
+ protected var _resultOption = new AtomicReference[Option[Any]](None)
+
+ private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption
override def output: Seq[Attribute] = plan.output
override def outputPartitioning: Partitioning = plan.outputPartitioning
@@ -147,14 +151,16 @@ case class ShuffleQueryStageExec(
throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString)
}
- override def doMaterialize(): Future[Any] = {
+ override def doMaterialize(): Future[Any] = attachTree(this, "execute") {
shuffle.mapOutputStatisticsFuture
}
override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
- ShuffleQueryStageExec(
+ val reuse = ShuffleQueryStageExec(
newStageId,
ReusedExchangeExec(newOutput, shuffle))
+ reuse._resultOption = this._resultOption
+ reuse
}
override def cancel(): Unit = {
@@ -171,8 +177,8 @@ case class ShuffleQueryStageExec(
* this method returns None, as there is no map statistics.
*/
def mapStats: Option[MapOutputStatistics] = {
- assert(resultOption.isDefined, "ShuffleQueryStageExec should already be ready")
- val stats = resultOption.get.asInstanceOf[MapOutputStatistics]
+ assert(resultOption.get().isDefined, "ShuffleQueryStageExec should already be ready")
+ val stats = resultOption.get().get.asInstanceOf[MapOutputStatistics]
Option(stats)
}
}
@@ -212,9 +218,11 @@ case class BroadcastQueryStageExec(
}
override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
- BroadcastQueryStageExec(
+ val reuse = BroadcastQueryStageExec(
newStageId,
ReusedExchangeExec(newOutput, broadcast))
+ reuse._resultOption = this._resultOption
+ reuse
}
override def cancel(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index eb091758910cd..cfc653a23840d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.dynamicpruning
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal, PredicateHelper}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.catalyst.rules.Rule
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 08128d8f69dab..707ed1402d1ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 888e7af7c07ed..52b476f9cf134 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f90a51c1f234..c7c3e1672f034 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ExplainUtils, RowIterator}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 755a63e545ef1..2b7cd65e7d96f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SparkPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala
deleted file mode 100644
index 134376628ae7f..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-/**
- * Physical execution operators for join operations.
- */
-package object joins {
-
- sealed abstract class BuildSide
-
- case object BuildRight extends BuildSide
-
- case object BuildLeft extends BuildSide
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
index 8c23f2cbb86ba..33539c01ee5dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -203,11 +203,10 @@ private[ui] class ExecutionPagedTable(
private val (sortColumn, desc, pageSize) = getTableParameters(request, executionTag, "ID")
+ private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
+
override val dataSource = new ExecutionDataSource(
- request,
- parent,
data,
- basePath,
currentTime,
pageSize,
sortColumn,
@@ -222,11 +221,9 @@ private[ui] class ExecutionPagedTable(
override def tableId: String = s"$executionTag-table"
override def tableCssClass: String =
- "table table-bordered table-sm table-striped " +
- "table-head-clickable table-cell-width-limited"
+ "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited"
override def pageLink(page: Int): String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
parameterPath +
s"&$pageNumberFormField=$page" +
s"&$executionTag.sort=$encodedSortColumn" +
@@ -239,10 +236,8 @@ private[ui] class ExecutionPagedTable(
override def pageNumberFormField: String = s"$executionTag.page"
- override def goButtonFormPath: String = {
- val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())
+ override def goButtonFormPath: String =
s"$parameterPath&$executionTag.sort=$encodedSortColumn&$executionTag.desc=$desc#$tableHeaderId"
- }
override def headers: Seq[Node] = {
// Information for each header: title, sortable, tooltip
@@ -348,7 +343,6 @@ private[ui] class ExecutionPagedTable(
private[ui] class ExecutionTableRowData(
- val submissionTime: Long,
val duration: Long,
val executionUIData: SQLExecutionUIData,
val runningJobData: Seq[Int],
@@ -357,10 +351,7 @@ private[ui] class ExecutionTableRowData(
private[ui] class ExecutionDataSource(
- request: HttpServletRequest,
- parent: SQLTab,
executionData: Seq[SQLExecutionUIData],
- basePath: String,
currentTime: Long,
pageSize: Int,
sortColumn: String,
@@ -373,20 +364,13 @@ private[ui] class ExecutionDataSource(
// in the table so that we can avoid creating duplicate contents during sorting the data
private val data = executionData.map(executionRow).sorted(ordering(sortColumn, desc))
- private var _sliceExecutionIds: Set[Int] = _
-
override def dataSize: Int = data.size
- override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = {
- val r = data.slice(from, to)
- _sliceExecutionIds = r.map(_.executionUIData.executionId.toInt).toSet
- r
- }
+ override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = data.slice(from, to)
private def executionRow(executionUIData: SQLExecutionUIData): ExecutionTableRowData = {
- val submissionTime = executionUIData.submissionTime
val duration = executionUIData.completionTime.map(_.getTime())
- .getOrElse(currentTime) - submissionTime
+ .getOrElse(currentTime) - executionUIData.submissionTime
val runningJobData = if (showRunningJobs) {
executionUIData.jobs.filter {
@@ -407,7 +391,6 @@ private[ui] class ExecutionDataSource(
} else Seq.empty
new ExecutionTableRowData(
- submissionTime,
duration,
executionUIData,
runningJobData,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5481337bf6cee..0cca3e7b47c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1306,7 +1306,7 @@ object functions {
* @since 1.4.0
*/
@scala.annotation.varargs
- def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) }
+ def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) }
/**
* Creates a new struct column that composes multiple input columns.
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index 5603cb988b8e7..af0a22b036030 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -18,6 +18,8 @@
package test.org.apache.spark.sql;
import java.io.Serializable;
+import java.sql.Timestamp;
+import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.LocalDate;
import java.util.*;
@@ -210,6 +212,17 @@ private static Row createRecordSpark22000Row(Long index) {
return new GenericRow(values);
}
+ private static String timestampToString(Timestamp ts) {
+ String timestampString = String.valueOf(ts);
+ String formatted = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(ts);
+
+ if (timestampString.length() > 19 && !timestampString.substring(19).equals(".0")) {
+ return formatted + timestampString.substring(19);
+ } else {
+ return formatted;
+ }
+ }
+
private static RecordSpark22000 createRecordSpark22000(Row recordRow) {
RecordSpark22000 record = new RecordSpark22000();
record.setShortField(String.valueOf(recordRow.getShort(0)));
@@ -219,7 +232,7 @@ private static RecordSpark22000 createRecordSpark22000(Row recordRow) {
record.setDoubleField(String.valueOf(recordRow.getDouble(4)));
record.setStringField(recordRow.getString(5));
record.setBooleanField(String.valueOf(recordRow.getBoolean(6)));
- record.setTimestampField(String.valueOf(recordRow.getTimestamp(7)));
+ record.setTimestampField(timestampToString(recordRow.getTimestamp(7)));
// This would figure out that null value will not become "null".
record.setNullIntField(null);
return record;
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 23173c8ba1f11..d245aa5a17345 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -2,7 +2,7 @@
## Summary
- Number of queries: 337
- Number of expressions that missing example: 34
- - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
+ - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,struct,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
## Schema of Built-in Functions
| Class name | Function name or alias | Query example | Output schema |
| ---------- | ---------------------- | ------------- | ------------- |
@@ -79,6 +79,7 @@
| org.apache.spark.sql.catalyst.expressions.CreateArray | array | SELECT array(1, 2, 3) | struct> |
| org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct