diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 716d604ca31b4..066512d159d00 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -214,7 +214,6 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" val jobPage = Option(request.getParameter(jobTag + ".page")).map(_.toInt).getOrElse(1) - val currentTime = System.currentTimeMillis() try { new JobPagedTable( @@ -226,7 +225,6 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We UIUtils.prependBaseUri(request, parent.basePath), "jobs", // subPath killEnabled, - currentTime, jobIdTitle ).table(jobPage) } catch { @@ -399,7 +397,6 @@ private[ui] class JobDataSource( store: AppStatusStore, jobs: Seq[v1.JobData], basePath: String, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[JobTableRowData](pageSize) { @@ -410,15 +407,9 @@ private[ui] class JobDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = jobs.map(jobRow).sorted(ordering(sortColumn, desc)) - private var _slicedJobIds: Set[Int] = null - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = { - val r = data.slice(from, to) - _slicedJobIds = r.map(_.jobData.jobId).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[JobTableRowData] = data.slice(from, to) private def jobRow(jobData: v1.JobData): JobTableRowData = { val duration: Option[Long] = JobDataUtil.getDuration(jobData) @@ -479,17 +470,17 @@ private[ui] class JobPagedTable( basePath: String, subPath: String, killEnabled: Boolean, - currentTime: Long, jobIdTitle: String ) extends PagedTable[JobTableRowData] { + private val (sortColumn, desc, pageSize) = getTableParameters(request, jobTag, jobIdTitle) private val parameterPath = basePath + s"/$subPath/?" 
+ getParameterOtherTable(request, jobTag) + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) override def tableId: String = jobTag + "-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = jobTag + ".pageSize" @@ -499,13 +490,11 @@ private[ui] class JobPagedTable( store, data, basePath, - currentTime, pageSize, sortColumn, desc) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$jobTag.sort=$encodedSortColumn" + @@ -514,10 +503,8 @@ private[ui] class JobPagedTable( s"#$tableHeaderId" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$jobTag.sort=$encodedSortColumn&$jobTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { // Information for each header: title, sortable, tooltip diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 1b072274541c8..47ba951953cec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -212,7 +212,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We stageData, UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage/?id=${stageId}&attempt=${stageAttemptId}", - currentTime, pageSize = taskPageSize, sortColumn = taskSortColumn, desc = taskSortDesc, @@ -452,7 +451,6 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We private[ui] class TaskDataSource( stage: StageData, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, @@ -474,8 +472,6 @@ private[ui] class TaskDataSource( _tasksToShow } - def tasks: Seq[TaskData] = _tasksToShow - def executorLogs(id: String): Map[String, String] = { executorIdToLogs.getOrElseUpdate(id, store.asOption(store.executorSummary(id)).map(_.executorLogs).getOrElse(Map.empty)) @@ -486,7 +482,6 @@ private[ui] class TaskDataSource( private[ui] class TaskPagedTable( stage: StageData, basePath: String, - currentTime: Long, pageSize: Int, sortColumn: String, desc: Boolean, @@ -494,6 +489,8 @@ private[ui] class TaskPagedTable( import ApiHelper._ + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def tableId: String = "task-table" override def tableCssClass: String = @@ -505,14 +502,12 @@ private[ui] class TaskPagedTable( override val dataSource: TaskDataSource = new TaskDataSource( stage, - currentTime, pageSize, sortColumn, desc, store) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) basePath + s"&$pageNumberFormField=$page" + s"&task.sort=$encodedSortColumn" + @@ -520,10 +515,7 @@ private[ui] class TaskPagedTable( s"&$pageSizeFormField=$pageSize" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) - s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" - } + override def goButtonFormPath: String = s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" def headers: Seq[Node] = { import ApiHelper._ diff 
--git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index f9e84c2b2f4ec..9e6eb418fe134 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -116,8 +116,7 @@ private[ui] class StagePagedTable( override def tableId: String = stageTag + "-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = stageTag + ".pageSize" @@ -125,7 +124,9 @@ private[ui] class StagePagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, stageTag, "Stage Id") - val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + + private val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + getParameterOtherTable(request, stageTag) override val dataSource = new StageDataSource( @@ -138,7 +139,6 @@ private[ui] class StagePagedTable( ) override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$stageTag.sort=$encodedSortColumn" + @@ -147,10 +147,8 @@ private[ui] class StagePagedTable( s"#$tableHeaderId" } - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$stageTag.sort=$encodedSortColumn&$stageTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { // stageHeadersAndCssClasses has three parts: header title, sortable and tooltip information. 
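The UI-table hunks above hoist the URL-encoded sort column into a single `val encodedSortColumn`, so `pageLink` and `goButtonFormPath` stop re-encoding it on every call, and they drop the unused `currentTime` parameter and sliced-id bookkeeping. A minimal, self-contained sketch of that hoisting pattern follows; the `LinkBuilder` class and its parameters are illustrative only, not Spark's actual API.

```scala
import java.net.URLEncoder
import java.nio.charset.StandardCharsets.UTF_8

// Sketch of the refactoring above: encode the sort column once in a val and
// reuse it from every link builder instead of re-encoding it per call.
class LinkBuilder(basePath: String, tag: String, sortColumn: String, desc: Boolean) {
  // Hoisted, mirroring `encodedSortColumn` in the paged tables changed above.
  private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name())

  def pageLink(page: Int): String =
    s"$basePath&${tag}.page=$page&${tag}.sort=$encodedSortColumn&${tag}.desc=$desc"

  def goButtonFormPath: String =
    s"$basePath&${tag}.sort=$encodedSortColumn&${tag}.desc=$desc"
}
```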
@@ -311,15 +309,9 @@ private[ui] class StageDataSource( // table so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = _ - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = { - val r = data.slice(from, to) - _slicedStageIds = r.map(_.stageId).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[StageTableRowData] = data.slice(from, to) private def stageRow(stageData: v1.StageData): StageTableRowData = { val formattedSubmissionTime = stageData.submissionTime match { @@ -350,7 +342,6 @@ private[ui] class StageDataSource( val shuffleWrite = stageData.shuffleWriteBytes val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" - new StageTableRowData( stageData, Some(stageData), diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 26bbff5e54d25..844d9b7cf2c27 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -487,6 +487,7 @@ private[spark] object JsonProtocol { ("Callsite" -> rddInfo.callSite) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ + ("Barrier" -> rddInfo.isBarrier) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ ("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~ ("Memory Size" -> rddInfo.memSize) ~ diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 6191e41b4118f..54899bfcf34fa 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { @@ -37,10 +38,10 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with .setAppName("test-cluster") .set(TEST_NO_STAGE_RETRY, true) sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, numWorker, 60000) } - // TODO (SPARK-31730): re-enable it - ignore("global sync by barrier() call") { + test("global sync by barrier() call") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => @@ -57,10 +58,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("share messages with allGather() call") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -78,10 +76,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("throw exception if we attempt to synchronize with different blocking calls") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val 
rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -100,10 +95,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with } test("successively sync with allGather and barrier") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) + initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() @@ -129,8 +121,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with assert(times2.max - times2.min <= 1000) } - // TODO (SPARK-31730): re-enable it - ignore("support multiple barrier() call within a single task") { + test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { it => @@ -285,6 +276,9 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with test("SPARK-31485: barrier stage should fail if only partial tasks are launched") { initLocalClusterSparkContext(2) + // It's required to reset the delay timer when a task is scheduled, otherwise all the tasks + // could get scheduled at ANY level. + sc.conf.set(config.LEGACY_LOCALITY_WAIT_RESET, true) val rdd0 = sc.parallelize(Seq(0, 1, 2, 3), 2) val dep = new OneToOneDependency[Int](rdd0) // set up a barrier stage with 2 tasks and both tasks prefer executor 0 (only 1 core) for diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 61ea21fa86c5a..7c23e4449f461 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.deploy.history.{EventLogFileReader, SingleEventLogFileWr import org.apache.spark.deploy.history.EventLogTestHelper._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{EVENT_LOG_DIR, EVENT_LOG_ENABLED} import org.apache.spark.io._ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} import org.apache.spark.resource.ResourceProfile @@ -100,6 +101,49 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit testStageExecutorMetricsEventLogging() } + test("SPARK-31764: isBarrier should be logged in event log") { + val conf = new SparkConf() + conf.set(EVENT_LOG_ENABLED, true) + conf.set(EVENT_LOG_DIR, testDirPath.toString) + val sc = new SparkContext("local", "test-SPARK-31764", conf) + val appId = sc.applicationId + + sc.parallelize(1 to 10) + .barrier() + .mapPartitions(_.map(elem => (elem, elem))) + .filter(elem => elem._1 % 2 == 0) + .reduceByKey(_ + _) + .collect + sc.stop() + + val eventLogStream = EventLogFileReader.openEventLog(new Path(testDirPath, appId), fileSystem) + val events = readLines(eventLogStream).map(line => JsonProtocol.sparkEventFromJson(parse(line))) + val jobStartEvents = events + .filter(event => event.isInstanceOf[SparkListenerJobStart]) + .map(_.asInstanceOf[SparkListenerJobStart]) + + assert(jobStartEvents.size === 1) + val stageInfos = jobStartEvents.head.stageInfos + assert(stageInfos.size === 2) + + val stage0 = stageInfos(0) + val rddInfosInStage0 = stage0.rddInfos + 
assert(rddInfosInStage0.size === 3) + val sortedRddInfosInStage0 = rddInfosInStage0.sortBy(_.scope.get.name) + assert(sortedRddInfosInStage0(0).scope.get.name === "filter") + assert(sortedRddInfosInStage0(0).isBarrier === true) + assert(sortedRddInfosInStage0(1).scope.get.name === "mapPartitions") + assert(sortedRddInfosInStage0(1).isBarrier === true) + assert(sortedRddInfosInStage0(2).scope.get.name === "parallelize") + assert(sortedRddInfosInStage0(2).isBarrier === false) + + val stage1 = stageInfos(1) + val rddInfosInStage1 = stage1.rddInfos + assert(rddInfosInStage1.size === 1) + assert(rddInfosInStage1(0).scope.get.name === "reduceByKey") + assert(rddInfosInStage1(0).isBarrier === false) // reduceByKey + } + /* ----------------- * * Actual test logic * * ----------------- */ diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 5d34a56473375..3d52199b01327 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -98,7 +98,6 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val taskTable = new TaskPagedTable( stageData, basePath = "/a/b/c", - currentTime = 0, pageSize = 10, sortColumn = "Index", desc = false, diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index bc7f8b5d719db..248142a5ad633 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -1100,6 +1100,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 201, | "Number of Cached Partitions": 301, | "Memory Size": 401, @@ -1623,6 +1624,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 200, | "Number of Cached Partitions": 300, | "Memory Size": 400, @@ -1668,6 +1670,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 400, | "Number of Cached Partitions": 600, | "Memory Size": 800, @@ -1684,6 +1687,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 401, | "Number of Cached Partitions": 601, | "Memory Size": 801, @@ -1729,6 +1733,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 600, | "Number of Cached Partitions": 900, | "Memory Size": 1200, @@ -1745,6 +1750,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 601, | "Number of Cached Partitions": 901, | "Memory Size": 1201, @@ -1761,6 +1767,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 602, | "Number of Cached Partitions": 902, | "Memory Size": 1202, @@ -1806,6 +1813,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of 
Partitions": 800, | "Number of Cached Partitions": 1200, | "Memory Size": 1600, @@ -1822,6 +1830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 801, | "Number of Cached Partitions": 1201, | "Memory Size": 1601, @@ -1838,6 +1847,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 802, | "Number of Cached Partitions": 1202, | "Memory Size": 1602, @@ -1854,6 +1864,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Deserialized": true, | "Replication": 1 | }, + | "Barrier" : false, | "Number of Partitions": 803, | "Number of Cached Partitions": 1203, | "Memory Size": 1603, diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index 3c3ce2dcdd6d4..b5a10b5dba378 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -183,7 +183,6 @@ metrics-jmx/4.1.1//metrics-jmx-4.1.1.jar metrics-json/4.1.1//metrics-json-4.1.1.jar metrics-jvm/4.1.1//metrics-jvm-4.1.1.jar minlog/1.3.0//minlog-1.3.0.jar -mssql-jdbc/6.2.1.jre7//mssql-jdbc-6.2.1.jre7.jar netty-all/4.1.47.Final//netty-all-4.1.47.Final.jar nimbus-jose-jwt/4.41.1//nimbus-jose-jwt-4.41.1.jar objenesis/2.5.1//objenesis-2.5.1.jar diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 72e32d4e16e14..13be9592d771f 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -198,7 +198,7 @@ def main(): # format: http://linux.die.net/man/1/timeout # must be less than the timeout configured on Jenkins. Usually Jenkins's timeout is higher # then this. Please consult with the build manager or a committer when it should be increased. - tests_timeout = "400m" + tests_timeout = "500m" # Array to capture all test names to run on the pull request. These tests are represented # by their file equivalents in the dev/tests/ directory. diff --git a/docs/img/webui-structured-streaming-detail.png b/docs/img/webui-structured-streaming-detail.png new file mode 100644 index 0000000000000..f4850523c5c2f Binary files /dev/null and b/docs/img/webui-structured-streaming-detail.png differ diff --git a/docs/sql-ref-datetime-pattern.md b/docs/sql-ref-datetime-pattern.md index 4275f03335b33..48e85b450e6b2 100644 --- a/docs/sql-ref-datetime-pattern.md +++ b/docs/sql-ref-datetime-pattern.md @@ -76,7 +76,7 @@ The count of pattern letters determines the format. - Year: The count of letters determines the minimum field width below which padding is used. If the count of letters is two, then a reduced two digit form is used. For printing, this outputs the rightmost two digits. For parsing, this will parse using the base value of 2000, resulting in a year within the range 2000 to 2099 inclusive. If the count of letters is less than four (but not two), then the sign is only output for negative years. Otherwise, the sign is output if the pad width is exceeded when 'G' is not present. -- Month: If the number of pattern letters is 3 or more, the month is interpreted as text; otherwise, it is interpreted as a number. The text form is depend on letters - 'M' denotes the 'standard' form, and 'L' is for 'stand-alone' form. The difference between the 'standard' and 'stand-alone' forms is trickier to describe as there is no difference in English. 
However, in other languages there is a difference in the word used when the text is used alone, as opposed to in a complete date. For example, the word used for a month when used alone in a date picker is different to the word used for month in association with a day and year in a date. In Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters (more than 4 letters is invalid): +- Month: It follows the rule of Number/Text. The text form depends on the letters used - 'M' denotes the 'standard' form, and 'L' is for the 'stand-alone' form. These two forms differ only in certain languages. For example, in Russian, 'Июль' is the stand-alone form of July, and 'Июля' is the standard form. Here are examples for all supported pattern letters: - `'M'` or `'L'`: Month number in a year starting from 1. There is no difference between 'M' and 'L'. Month from 1 to 9 are printed without padding. ```sql spark-sql> select date_format(date '1970-01-01', "M"); @@ -107,8 +107,8 @@ The count of pattern letters determines the format. ``` - `'MMMM'`: full textual month representation in the standard form. It is used for parsing/formatting months as a part of dates/timestamps. ```sql - spark-sql> select date_format(date '1970-01-01', "MMMM yyyy"); - January 1970 + spark-sql> select date_format(date '1970-01-01', "d MMMM"); + 1 January spark-sql> select to_csv(named_struct('date', date '1970-01-01'), map('dateFormat', 'd MMMM', 'locale', 'RU')); 1 января ``` diff --git a/docs/web-ui.md b/docs/web-ui.md index 3c35dbeec86a2..e2e612cef3e54 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -407,6 +407,34 @@ Here is the list of SQL metrics: +## Structured Streaming Tab +When running Structured Streaming jobs in micro-batch mode, a Structured Streaming tab will be +available on the Web UI. The overview page displays some brief statistics for running and completed +queries. Also, you can check the latest exception of a failed query. For detailed statistics, please +click a "run id" in the tables. + 
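A minimal Scala query that exercises this tab, assuming the built-in `rate` source and `console` sink (any micro-batch source and sink would do; the object name and trigger interval are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

object StreamingTabDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("streaming-tab-demo").getOrCreate()

    // The `rate` source emits (timestamp, value) rows; any micro-batch query run
    // through this session appears under the Structured Streaming tab.
    val counts = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .load()
      .groupBy()
      .count()

    val query = counts.writeStream
      .outputMode("complete")
      .format("console")
      .trigger(Trigger.ProcessingTime("5 seconds"))
      .start()

    query.awaitTermination()
  }
}
```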

+<img src="img/webui-structured-streaming-detail.png" title="Structured Streaming Query Statistics" alt="Structured Streaming Query Statistics">
+ +The statistics page displays some useful metrics for insight into the status of your streaming +queries. Currently, it contains the following metrics. + +* **Input Rate.** The aggregate (across all sources) rate of data arriving. +* **Process Rate.** The aggregate (across all sources) rate at which Spark is processing data. +* **Input Rows.** The aggregate (across all sources) number of records processed in a trigger. +* **Batch Duration.** The process duration of each batch. +* **Operation Duration.** The amount of time taken to perform various operations in milliseconds. +The tracked operations are listed as follows. + * addBatch: Adds result data of the current batch to the sink. + * getBatch: Gets a new batch of data to process. + * latestOffset: Gets the latest offsets for sources. + * queryPlanning: Generates the execution plan. + * walCommit: Writes the offsets to the metadata log. + +As an early-release version, the statistics page is still under development and will be improved in +future releases. + ## Streaming Tab The web UI includes a Streaming tab if the application uses Spark streaming. This tab displays scheduling delay and processing time for each micro-batch in the data stream, which can be useful diff --git a/external/avro/src/test/resources/before_1582_date_v2_4.avro b/external/avro/src/test/resources/before_1582_date_v2_4.avro deleted file mode 100644 index 96aa7cbf176a5..0000000000000 Binary files a/external/avro/src/test/resources/before_1582_date_v2_4.avro and /dev/null differ diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_5.avro b/external/avro/src/test/resources/before_1582_date_v2_4_5.avro new file mode 100644 index 0000000000000..5c15601f7ee4b Binary files /dev/null and b/external/avro/src/test/resources/before_1582_date_v2_4_5.avro differ diff --git a/external/avro/src/test/resources/before_1582_date_v2_4_6.avro b/external/avro/src/test/resources/before_1582_date_v2_4_6.avro new file mode 100644 index 0000000000000..212ea1d5efa5c Binary files /dev/null and b/external/avro/src/test/resources/before_1582_date_v2_4_6.avro differ diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro new file mode 100644 index 0000000000000..c3445e3999bc1 Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_5.avro differ diff --git a/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro new file mode 100644 index 0000000000000..96008d2378b1f Binary files /dev/null and b/external/avro/src/test/resources/before_1582_timestamp_micros_v2_4_6.avro differ diff --git a/external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro similarity index 52% rename from external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro rename to external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro index dbaec814eb954..be12a0782073c 100644 Binary files a/external/avro/src/test/resources/before_1582_ts_millis_v2_4.avro and b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_5.avro differ diff --git a/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro new file mode 100644 index 0000000000000..262f5dd6e77a4 Binary files 
/dev/null and b/external/avro/src/test/resources/before_1582_timestamp_millis_v2_4_6.avro differ diff --git a/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro b/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro deleted file mode 100644 index efe5e71a58813..0000000000000 Binary files a/external/avro/src/test/resources/before_1582_ts_micros_v2_4.avro and /dev/null differ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index a5c1fb15add5c..e2ae489446d85 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.avro import java.io._ import java.net.URL -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Paths, StandardCopyOption} import java.sql.{Date, Timestamp} import java.util.{Locale, UUID} @@ -38,7 +38,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.TestingUDT.IntervalData import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA, UTC} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.{DataSource, FilePartition} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec @@ -1529,23 +1529,82 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { } } + // It generates input files for the test below: + // "SPARK-31183: compatibility with Spark 2.4 in reading dates/timestamps" + ignore("SPARK-31855: generate test files for checking compatibility with Spark 2.4") { + val resourceDir = "external/avro/src/test/resources" + val version = "2_4_6" + def save( + in: Seq[String], + t: String, + dstFile: String, + options: Map[String, String] = Map.empty): Unit = { + withTempDir { dir => + in.toDF("dt") + .select($"dt".cast(t)) + .repartition(1) + .write + .mode("overwrite") + .options(options) + .format("avro") + .save(dir.getCanonicalPath) + Files.copy( + dir.listFiles().filter(_.getName.endsWith(".avro")).head.toPath, + Paths.get(resourceDir, dstFile), + StandardCopyOption.REPLACE_EXISTING) + } + } + withDefaultTimeZone(LA) { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) { + save( + Seq("1001-01-01"), + "date", + s"before_1582_date_v$version.avro") + save( + Seq("1001-01-01 01:02:03.123"), + "timestamp", + s"before_1582_timestamp_millis_v$version.avro", + // scalastyle:off line.size.limit + Map("avroSchema" -> + s""" + | { + | "namespace": "logical", + | "type": "record", + | "name": "test", + | "fields": [ + | {"name": "dt", "type": ["null", {"type": "long","logicalType": "timestamp-millis"}], "default": null} + | ] + | } + |""".stripMargin)) + // scalastyle:on line.size.limit + save( + Seq("1001-01-01 01:02:03.123456"), + "timestamp", + s"before_1582_timestamp_micros_v$version.avro") + } + } + } + test("SPARK-31183: compatibility with Spark 2.4 in reading dates/timestamps") { // test reading the existing 2.4 files and new 3.0 files (with rebase on/off) together. 
- def checkReadMixedFiles(fileName: String, dt: String, dataStr: String): Unit = { + def checkReadMixedFiles( + fileName: String, + dt: String, + dataStr: String, + checkDefaultLegacyRead: String => Unit): Unit = { withTempPaths(2) { paths => paths.foreach(_.delete()) val path2_4 = getResourceAvroFilePath(fileName) val path3_0 = paths(0).getCanonicalPath val path3_0_rebase = paths(1).getCanonicalPath if (dt == "date") { - val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("date")) + val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("dt")) // By default we should fail to write ancient datetime values. - var e = intercept[SparkException](df.write.format("avro").save(path3_0)) + val e = intercept[SparkException](df.write.format("avro").save(path3_0)) assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) - // By default we should fail to read ancient datetime values. - e = intercept[SparkException](spark.read.format("avro").load(path2_4).collect()) - assert(e.getCause.isInstanceOf[SparkUpgradeException]) + checkDefaultLegacyRead(path2_4) withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { df.write.format("avro").mode("overwrite").save(path3_0) @@ -1562,25 +1621,23 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { 1.to(3).map(_ => Row(java.sql.Date.valueOf(dataStr)))) } } else { - val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("ts")) + val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("dt")) val avroSchema = s""" |{ | "type" : "record", | "name" : "test_schema", | "fields" : [ - | {"name": "ts", "type": {"type": "long", "logicalType": "$dt"}} + | {"name": "dt", "type": {"type": "long", "logicalType": "$dt"}} | ] |}""".stripMargin // By default we should fail to write ancient datetime values. - var e = intercept[SparkException] { + val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(path3_0) } assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) - // By default we should fail to read ancient datetime values. - e = intercept[SparkException](spark.read.format("avro").load(path2_4).collect()) - assert(e.getCause.isInstanceOf[SparkUpgradeException]) + checkDefaultLegacyRead(path2_4) withSQLConf(SQLConf.LEGACY_AVRO_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { df.write.format("avro").option("avroSchema", avroSchema).mode("overwrite").save(path3_0) @@ -1600,11 +1657,33 @@ abstract class AvroSuite extends QueryTest with SharedSparkSession { } } - checkReadMixedFiles("before_1582_date_v2_4.avro", "date", "1001-01-01") - checkReadMixedFiles( - "before_1582_ts_micros_v2_4.avro", "timestamp-micros", "1001-01-01 01:02:03.123456") - checkReadMixedFiles( - "before_1582_ts_millis_v2_4.avro", "timestamp-millis", "1001-01-01 01:02:03.124") + def failInRead(path: String): Unit = { + val e = intercept[SparkException](spark.read.format("avro").load(path).collect()) + assert(e.getCause.isInstanceOf[SparkUpgradeException]) + } + def successInRead(path: String): Unit = spark.read.format("avro").load(path).collect() + Seq( + // By default we should fail to read ancient datetime values when parquet files don't + // contain Spark version. 
+ "2_4_5" -> failInRead _, + "2_4_6" -> successInRead _ + ).foreach { case (version, checkDefaultRead) => + checkReadMixedFiles( + s"before_1582_date_v$version.avro", + "date", + "1001-01-01", + checkDefaultRead) + checkReadMixedFiles( + s"before_1582_timestamp_micros_v$version.avro", + "timestamp-micros", + "1001-01-01 01:02:03.123456", + checkDefaultRead) + checkReadMixedFiles( + s"before_1582_timestamp_millis_v$version.avro", + "timestamp-millis", + "1001-01-01 01:02:03.123", + checkDefaultRead) + } } test("SPARK-31183: rebasing microseconds timestamps in write") { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index febeba7e13fcb..e0b128e369816 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.Since import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params val w = this match { case p: HasWeightCol => if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) { - col($(p.weightCol)).cast(DoubleType) + checkNonNegativeWeight((col($(p.weightCol)).cast(DoubleType))) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5459a0fab9135..e65295dbdaf55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -22,6 +22,7 @@ import org.json4s.DefaultFormats import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasWeightCol @@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") ( } val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } @@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") ( import spark.implicits._ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 6c7112b80569f..b09f11dcfe156 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") ( val handlePersistence = dataset.storageLevel == StorageLevel.NONE val w = if 
(isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 6d4137b638dcc..18fd220b4ca9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ @@ -417,7 +418,7 @@ class GaussianMixture @Since("2.0.0") ( instr.logNumFeatures(numFeatures) val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a42c920e24987..806015b633c23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, PipelineStage} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") ( val handlePersistence = dataset.storageLevel == StorageLevel.NONE val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index fac4d92b1810c..52be22f714981 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -131,7 +132,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va col($(rawPredictionCol)), col($(labelCol)).cast(DoubleType), if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) - else col($(weightCol)).cast(DoubleType)).rdd.map { + else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map { case Row(rawPrediction: Vector, label: Double, weight: Double) => (rawPrediction(1), label, weight) case Row(rawPrediction: Double, label: Double, weight: Double) => diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 
19790fd270619..fa2c25a5912a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util._ @@ -139,7 +140,7 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str } else { dataset.select(col($(predictionCol)), vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata), - col(weightColName).cast(DoubleType)) + checkNonNegativeWeight(col(weightColName).cast(DoubleType))) } val metrics = new ClusteringMetrics(df) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala index 8bf4ee1ecadfb..a785d063f1476 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala @@ -300,7 +300,6 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette { (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double), (features, squaredNorm, weight) ) => - require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.") BLAS.axpy(weight, features, featureSum) (featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight) }, @@ -503,7 +502,6 @@ private[evaluation] object CosineSilhouette extends Silhouette { seqOp = { case ((normalizedFeaturesSum: DenseVector, weightSum: Double), (normalizedFeatures, weight)) => - require(weight >= 0.0, s"illegal weight value: $weight. 
weight must be >= 0.0.") BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum) (normalizedFeaturesSum, weightSum + weight) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index ad1b70915e157..3d77792c4fc88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -186,7 +187,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid SchemaUtils.checkNumericType(schema, $(labelCol)) val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) { - col($(weightCol)).cast(DoubleType) + checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) } else { lit(1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index aca017762deca..f0b7c345c3285 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Since +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} @@ -122,7 +123,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui val predictionAndLabelsWithWeights = dataset .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), - if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))) + if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) + else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))) .rdd .map { case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala index 0f03231079866..a0b6d11a46be9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -71,4 +71,10 @@ object functions { ) } } + + private[ml] def checkNonNegativeWeight = udf { + value: Double => + require(value >= 0, s"illegal weight value: $value. 
weight must be >= 0.0.") + value + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index fa41a98749f32..0ee895a95a288 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ @@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " + "set to false. To fit a model with 0 features, fitIntercept must be set to true." ) - val w = if (!hasWeightCol) lit(1.0) else col($(weightCol)) + val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol))) val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType) val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index fe4de57de60f2..ec2640e9ef225 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.functions.checkNonNegativeWeight import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { col($(featuresCol)) } - val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0) + val w = + if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0) dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { - case Row(label: Double, feature: Double, weight: Double) => - (label, feature, weight) + case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) } } diff --git a/pom.xml b/pom.xml index fd4cebcd37319..deaf87f15539c 100644 --- a/pom.xml +++ b/pom.xml @@ -1357,6 +1357,10 @@ com.zaxxer HikariCP-java7 + + com.microsoft.sqlserver + mssql-jdbc + diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b80149afa2af4..4f29f2f0be1e8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -25,7 +25,6 @@ from tempfile import NamedTemporaryFile from py4j.protocol import Py4JError -from py4j.java_gateway import is_instance_of from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -865,17 +864,10 @@ def union(self, rdds): first_jrdd_deserializer = rdds[0]._jrdd_deserializer if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): rdds = [x._reserialize() for x in rdds] - gw = 
SparkContext._gateway cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD - is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls) - jrdds = gw.new_array(cls, len(rdds)) + jrdds = SparkContext._gateway.new_array(cls, len(rdds)) for i in range(0, len(rdds)): - if is_jrdd: - jrdds[i] = rdds[i]._jrdd - else: - # zip could return JavaPairRDD hence we ensure `_jrdd` - # to be `JavaRDD` by wrapping it in a `map` - jrdds[i] = rdds[i].map(lambda x: x)._jrdd + jrdds[i] = rdds[i]._jrdd return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a3ce87096e790..65b902cf3c4d5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2219,6 +2219,20 @@ def semanticHash(self): """ return self._jdf.semanticHash() + @since(3.1) + def inputFiles(self): + """ + Returns a best-effort snapshot of the files that compose this :class:`DataFrame`. + This method simply asks each constituent BaseRelation for its respective files and + takes the union of all results. Depending on the source relations, this may not find + all input files. Duplicates are removed. + + >>> df = spark.read.load("examples/src/main/resources/people.json", format="json") + >>> len(df.inputFiles()) + 1 + """ + return list(self._jdf.inputFiles()) + where = copy_func( filter, sinceversion=1.3, diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 4dd15d14b9c53..ff0b10a9306cf 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -154,6 +154,9 @@ def create_array(s, t): # Ensure timestamp series are in expected form for Spark internal representation if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s, self._timezone) + elif type(s.dtype) == pd.CategoricalDtype: + # Note: This can be removed once minimum pyarrow version is >= 0.16.1 + s = s.astype(s.dtypes.categories.dtype) try: array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck) except pa.ArrowException as e: diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index d1edf3f9c47c1..4b70c8a2e95e1 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -114,6 +114,8 @@ def from_arrow_type(at): return StructType( [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) for field in at]) + elif types.is_dictionary(at): + spark_type = from_arrow_type(at.value_type) else: raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) return spark_type diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 004c79f290213..c59765dd79eb9 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -415,6 +415,33 @@ def run_test(num_records, num_parts, max_records, use_delay=False): for case in cases: run_test(*case) + def test_createDateFrame_with_category_type(self): + pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]}) + pdf["B"] = pdf["A"].astype('category') + category_first_element = dict(enumerate(pdf['B'].cat.categories))[0] + + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}): + arrow_df = self.spark.createDataFrame(pdf) + arrow_type = arrow_df.dtypes[1][1] + result_arrow = arrow_df.toPandas() + arrow_first_category_element = result_arrow["B"][0] + + with 
self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): + df = self.spark.createDataFrame(pdf) + spark_type = df.dtypes[1][1] + result_spark = df.toPandas() + spark_first_category_element = result_spark["B"][0] + + assert_frame_equal(result_spark, result_arrow) + + # ensure original category elements are string + self.assertIsInstance(category_first_element, str) + # spark data frame and arrow execution mode enabled data frame type must match pandas + self.assertEqual(spark_type, 'string') + self.assertEqual(arrow_type, 'string') + self.assertIsInstance(arrow_first_category_element, str) + self.assertIsInstance(spark_first_category_element, str) + @unittest.skipIf( not have_pandas or not have_pyarrow, diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 9861178158f85..062e61663a332 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -17,6 +17,8 @@ import os import pydoc +import shutil +import tempfile import time import unittest @@ -820,6 +822,22 @@ def test_same_semantics_error(self): with self.assertRaisesRegexp(ValueError, "should be of DataFrame.*int"): self.spark.range(10).sameSemantics(1) + def test_input_files(self): + tpath = tempfile.mkdtemp() + shutil.rmtree(tpath) + try: + self.spark.range(1, 100, 1, 10).write.parquet(tpath) + # read parquet file and get the input files list + input_files_list = self.spark.read.parquet(tpath).inputFiles() + + # input files list should contain 10 entries + self.assertEquals(len(input_files_list), 10) + # all file paths in list must contain tpath + for file_path in input_files_list: + self.assertTrue(tpath in file_path) + finally: + shutil.rmtree(tpath) + class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 7260e80e2cfca..2d38efd39f902 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -897,6 +897,24 @@ def test_timestamp_dst(self): result = df.withColumn('time', foo_udf(df.time)) self.assertEquals(df.collect(), result.collect()) + def test_udf_category_type(self): + + @pandas_udf('string') + def to_category_func(x): + return x.astype('category') + + pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]}) + df = self.spark.createDataFrame(pdf) + df = df.withColumn("B", to_category_func(df['A'])) + result_spark = df.toPandas() + + spark_type = df.dtypes[1][1] + # spark data frame and arrow execution mode enabled data frame type must match pandas + self.assertEqual(spark_type, 'string') + + # Check result of column 'B' must be equal to column 'A' in type and values + pd.testing.assert_series_equal(result_spark["A"], result_spark["B"], check_names=False) + @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.") def test_type_annotation(self): # Regression test to check if type hints can be used. See SPARK-23569. 
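The new PySpark `inputFiles` above delegates to the JVM-side `Dataset.inputFiles` (it returns `list(self._jdf.inputFiles())`). A usage sketch in Scala mirroring the `test_input_files` test added in `test_dataframe.py`; the path and object name are illustrative:

```scala
import org.apache.spark.sql.SparkSession

object InputFilesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("input-files").getOrCreate()

    val path = "/tmp/input-files-demo"  // illustrative location
    spark.range(1, 100, 1, 10).write.mode("overwrite").parquet(path)

    // Best-effort, de-duplicated list of the files backing the DataFrame.
    val files: Array[String] = spark.read.parquet(path).inputFiles
    assert(files.length == 10)  // one part file per partition written above
    files.foreach(println)

    spark.stop()
  }
}
```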
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 04dfe68e57a3a..62ad4221d7078 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -168,15 +168,6 @@ def test_zip_chaining(self): set([(x, (x, x)) for x in 'abc']) ) - def test_union_pair_rdd(self): - # Regression test for SPARK-31788 - rdd = self.sc.parallelize([1, 2]) - pair_rdd = rdd.zip(rdd) - self.assertEqual( - self.sc.union([pair_rdd, pair_rdd]).collect(), - [((1, 1), (2, 2)), ((1, 1), (2, 2))] - ) - def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d4799cace4531..e2559d4c07297 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -339,7 +339,7 @@ object FunctionRegistry { expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), expression[StringInstr]("instr"), - expression[Lower]("lcase"), + expression[Lower]("lcase", true), expression[Length]("length"), expression[Levenshtein]("levenshtein"), expression[Like]("like"), @@ -350,7 +350,7 @@ object FunctionRegistry { expression[StringTrimLeft]("ltrim"), expression[JsonTuple]("json_tuple"), expression[ParseUrl]("parse_url"), - expression[StringLocate]("position"), + expression[StringLocate]("position", true), expression[FormatString]("printf", true), expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index fe3fea5e35b1b..26f5bee72092c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import java.time.LocalDate +import java.time.{Instant, LocalDate} import scala.language.implicitConversions @@ -152,6 +152,7 @@ package object dsl { implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d) implicit def decimalToLiteral(d: Decimal): Literal = Literal(d) implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t) + implicit def instantToLiteral(i: Instant): Literal = Literal(i) implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a) implicit def symbolToUnresolvedAttribute(s: Symbol): analysis.UnresolvedAttribute = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 82d689477080d..f7fe467cea830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -144,7 +144,7 @@ object TimeWindow { case class PreciseTimestampConversion( child: Expression, fromType: DataType, - toType: DataType) extends UnaryExpression with ExpectsInputTypes { + toType: DataType) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: 
Seq[AbstractDataType] = Seq(fromType) override def dataType: DataType = toType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 7b819db32e425..342b14eaa3390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -127,7 +127,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme > SELECT _FUNC_ 0; -1 """) -case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseNot(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) @@ -164,7 +165,8 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp 0 """, since = "3.0.0") -case class BitwiseCount(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class BitwiseCount(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4fd68dcfe5156..b32e9ee05f1ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -141,7 +141,7 @@ object Size { """, group = "map_funcs") case class MapKeys(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -332,7 +332,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """, group = "map_funcs") case class MapValues(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -361,7 +361,8 @@ case class MapValues(child: Expression) """, group = "map_funcs", since = "3.0.0") -case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class MapEntries(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(MapType) @@ -649,7 +650,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """, group = "map_funcs", since = "2.4.0") -case class MapFromEntries(child: Expression) extends UnaryExpression { +case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant { @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { @@ -873,7 +874,7 @@ object ArraySortLike { group = "array_funcs") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) - extends BinaryExpression with ArraySortLike { + extends BinaryExpression with 
ArraySortLike with NullIntolerant { def this(e: Expression) = this(e, Literal(true)) @@ -1017,7 +1018,8 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) Reverse logic for arrays is available since 2.4.0. """ ) -case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Reverse(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { // Input types are utilized by type coercion in ImplicitTypeCasts. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType)) @@ -1086,7 +1088,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "array_funcs") case class ArrayContains(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BooleanType @@ -1185,7 +1187,7 @@ case class ArrayContains(left: Expression, right: Expression) since = "2.4.0") // scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) - extends BinaryArrayExpressionWithImplicitCast { + extends BinaryArrayExpressionWithImplicitCast with NullIntolerant { override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => @@ -1410,7 +1412,7 @@ case class ArraysOverlap(left: Expression, right: Expression) since = "2.4.0") // scalastyle:on line.size.limit case class Slice(x: Expression, start: Expression, length: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = x.dataType @@ -1688,7 +1690,8 @@ case class ArrayJoin( """, group = "array_funcs", since = "2.4.0") -case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMin(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1755,7 +1758,8 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast """, group = "array_funcs", since = "2.4.0") -case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class ArrayMax(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def nullable: Boolean = true @@ -1831,7 +1835,7 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast group = "array_funcs", since = "2.4.0") case class ArrayPosition(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(right.dataType) @@ -1909,7 +1913,7 @@ case class ArrayPosition(left: Expression, right: Expression) """, since = "2.4.0") case class ElementAt(left: Expression, right: Expression) - extends GetMapValueUtil with GetArrayItemUtil { + extends GetMapValueUtil with GetArrayItemUtil with NullIntolerant { @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType @@ -2245,7 +2249,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio """, group = "array_funcs", since = "2.4.0") -case class Flatten(child: Expression) extends 
UnaryExpression { +case class Flatten(child: Expression) extends UnaryExpression with NullIntolerant { private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType] @@ -2884,7 +2888,7 @@ case class ArrayRepeat(left: Expression, right: Expression) group = "array_funcs", since = "2.4.0") case class ArrayRemove(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = left.dataType @@ -3081,7 +3085,7 @@ trait ArraySetLike { group = "array_funcs", since = "2.4.0") case class ArrayDistinct(child: Expression) - extends UnaryExpression with ArraySetLike with ExpectsInputTypes { + extends UnaryExpression with ArraySetLike with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) @@ -3219,7 +3223,8 @@ case class ArrayDistinct(child: Expression) /** * Will become common base class for [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]]. */ -trait ArrayBinaryLike extends BinaryArrayExpressionWithImplicitCast with ArraySetLike { +trait ArrayBinaryLike + extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with NullIntolerant { override protected def dt: DataType = dataType override protected def et: DataType = elementType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 858c91a4d8e86..1b4a705e804f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ @@ -255,7 +255,7 @@ object CreateMap { {1.0:"2",3.0:"4"} """, since = "2.4.0") case class MapFromArrays(left: Expression, right: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) @@ -311,7 +311,12 @@ case object NamePlaceholder extends LeafExpression with Unevaluable { /** * Returns a Row containing the evaluation of all children expressions. */ -object CreateStruct extends FunctionBuilder { +object CreateStruct { + /** + * Returns a named struct with generated names or using the names when available. + * It should not be used for `struct` expressions or functions explicitly called + * by users. + */ def apply(children: Seq[Expression]): CreateNamedStruct = { CreateNamedStruct(children.zipWithIndex.flatMap { case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) @@ -320,12 +325,23 @@ object CreateStruct extends FunctionBuilder { }) } + /** + * Returns a named struct with a pretty SQL name. It will show the pretty SQL string + * in its output column name as if `struct(...)` was called. 
Should be + * used for `struct` expressions or functions explicitly called by users. + */ + def create(children: Seq[Expression]): CreateNamedStruct = { + val expr = CreateStruct(children) + expr.setTagValue(FUNC_ALIAS, "struct") + expr + } + /** * Entry to use in the function registry. */ val registryEntry: (String, (ExpressionInfo, FunctionBuilder)) = { val info: ExpressionInfo = new ExpressionInfo( - "org.apache.spark.sql.catalyst.expressions.NamedStruct", + classOf[CreateNamedStruct].getCanonicalName, null, "struct", "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.", @@ -335,7 +351,7 @@ object CreateStruct extends FunctionBuilder { "", "", "") - ("struct", (info, this)) + ("struct", (info, this.create)) } } @@ -433,7 +449,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { """.stripMargin, isNull = FalseLiteral) } - override def prettyName: String = "named_struct" + // There is an alias set at `CreateStruct.create`. If there is an alias, + // this is the struct function explicitly called by a user and we should + // respect it in the SQL string as `struct(...)`. + override def prettyName: String = getTagValue(FUNC_ALIAS).getOrElse("named_struct") + + override def sql: String = getTagValue(FUNC_ALIAS).map { alias => + val childrenSQL = children.indices.filter(_ % 2 == 1).map(children(_).sql).mkString(", ") + s"$alias($childrenSQL)" + }.getOrElse(super.sql) } /** @@ -452,7 +476,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { since = "2.0.1") // scalastyle:on line.size.limit case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) - extends TernaryExpression with ExpectsInputTypes { + extends TernaryExpression with ExpectsInputTypes with NullIntolerant { def this(child: Expression, pairDelim: Expression) = { this(child, pairDelim, Literal(":")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index 5140db90c5954..f9ccf3c8c811f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -211,7 +211,8 @@ case class StructsToCsv( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 7dc008a2e5df8..4f3db1b8a57ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -198,7 +198,7 @@ case class CurrentBatchTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateAdd(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { 
override def left: Expression = startDate override def right: Expression = days @@ -234,7 +234,7 @@ case class DateAdd(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class DateSub(startDate: Expression, days: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = days @@ -266,7 +266,8 @@ case class DateSub(startDate: Expression, days: Expression) group = "datetime_funcs", since = "1.5.0") case class Hour(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -298,7 +299,8 @@ case class Hour(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Minute(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -330,7 +332,8 @@ case class Minute(child: Expression, timeZoneId: Option[String] = None) group = "datetime_funcs", since = "1.5.0") case class Second(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -353,7 +356,8 @@ case class Second(child: Expression, timeZoneId: Option[String] = None) } case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(child: Expression) = this(child, None) @@ -385,7 +389,8 @@ case class SecondWithFraction(child: Expression, timeZoneId: Option[String] = No """, group = "datetime_funcs", since = "1.5.0") -case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -402,7 +407,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } abstract class NumberToTimestampBase extends UnaryExpression - with ExpectsInputTypes { + with ExpectsInputTypes with NullIntolerant { protected def upScaleFactor: Long @@ -487,7 +492,8 @@ case class MicrosToTimestamp(child: Expression) """, group = "datetime_funcs", since = "1.5.0") -case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Year(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -503,7 +509,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } } -case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class YearOfWeek(child: Expression) + extends 
UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -528,7 +535,8 @@ case class YearOfWeek(child: Expression) extends UnaryExpression with ImplicitCa """, group = "datetime_funcs", since = "1.5.0") -case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Quarter(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -553,7 +561,8 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI """, group = "datetime_funcs", since = "1.5.0") -case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Month(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -577,7 +586,8 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp 30 """, since = "1.5.0") -case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfMonth(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -647,7 +657,7 @@ case class WeekDay(child: Expression) extends DayWeek { } } -abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -665,7 +675,8 @@ abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { group = "datetime_funcs", since = "1.5.0") // scalastyle:on line.size.limit -case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class WeekOfYear(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -704,7 +715,8 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa since = "1.5.0") // scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(left: Expression, right: Expression) = this(left, right, None) @@ -1154,7 +1166,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ """, group = "datetime_funcs", since = "1.5.0") -case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class LastDay(startDate: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def child: Expression = startDate override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -1192,7 +1205,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC since = "1.5.0") // scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = 
dayOfWeek @@ -1248,7 +1261,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) * Adds an interval to timestamp. */ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) - extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { def this(start: Expression, interval: Expression) = this(start, interval, None) @@ -1306,7 +1319,7 @@ case class DateAddInterval( interval: Expression, timeZoneId: Option[String] = None, ansiEnabled: Boolean = SQLConf.get.ansiEnabled) - extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression { + extends BinaryExpression with ExpectsInputTypes with TimeZoneAwareExpression with NullIntolerant { override def left: Expression = start override def right: Expression = interval @@ -1380,7 +1393,7 @@ case class DateAddInterval( since = "1.5.0") // scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1440,7 +1453,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class AddMonths(startDate: Expression, numMonths: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = startDate override def right: Expression = numMonths @@ -1494,7 +1507,8 @@ case class MonthsBetween( date2: Expression, roundOff: Expression, timeZoneId: Option[String] = None) - extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends TernaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this(date1: Expression, date2: Expression) = this(date1, date2, Literal.TrueLiteral, None) @@ -1552,7 +1566,7 @@ case class MonthsBetween( since = "1.5.0") // scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) override def dataType: DataType = TimestampType @@ -1906,7 +1920,7 @@ case class TruncTimestamp( group = "datetime_funcs", since = "1.5.0") case class DateDiff(endDate: Expression, startDate: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = endDate override def right: Expression = startDate @@ -1960,7 +1974,7 @@ private case class GetTimestamp( group = "datetime_funcs", since = "3.0.0") case class MakeDate(year: Expression, month: Expression, day: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(year, month, day) override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType) @@ -2031,7 +2045,8 @@ case class MakeTimestamp( sec: Expression, timezone: Option[Expression] = None, timeZoneId: 
Option[String] = None) - extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes + with NullIntolerant { def this( year: Expression, @@ -2307,7 +2322,7 @@ case class Extract(field: Expression, source: Expression, child: Expression) * between the given timestamps. */ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = endTimestamp override def right: Expression = startTimestamp @@ -2328,7 +2343,7 @@ case class SubtractTimestamps(endTimestamp: Expression, startTimestamp: Expressi * Returns the interval from the `left` date (inclusive) to the `right` date (exclusive). */ case class SubtractDates(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType) override def dataType: DataType = CalendarIntervalType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 9014ebfe2f96a..c2c70b2ab08e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types._ * Note: this expression is internal and created only by the optimizer, * we don't need to do type check for it. 
*/ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -49,7 +49,7 @@ case class MakeDecimal( child: Expression, precision: Int, scale: Int, - nullOnOverflow: Boolean) extends UnaryExpression { + nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant { def this(child: Expression, precision: Int, scale: Int) = { this(child, precision, scale, !SQLConf.get.ansiEnabled) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 4c8c58ae232f4..5e21b58f070ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -53,7 +53,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} > SELECT _FUNC_('Spark'); 8cde774d6f7333752ed72cacddb05126 """) -case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Md5(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -89,7 +90,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput """) // scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def dataType: DataType = StringType override def nullable: Boolean = true @@ -160,7 +161,8 @@ case class Sha2(left: Expression, right: Expression) > SELECT _FUNC_('Spark'); 85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c """) -case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Sha1(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -187,7 +189,8 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu > SELECT _FUNC_('Spark'); 1557323817 """) -case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Crc32(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 1a569a7b89fe1..baab224691bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -31,7 +31,7 @@ abstract class ExtractIntervalPart( val dataType: DataType, func: CalendarInterval => Any, funcName: String) - extends UnaryExpression with ExpectsInputTypes with Serializable { + extends UnaryExpression with ExpectsInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType) @@ -82,7 +82,7 @@ object ExtractIntervalPart { abstract class IntervalNumOperation( interval: Expression, num: Expression) - extends 
BinaryExpression with ImplicitCastInputTypes with Serializable { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def left: Expression = interval override def right: Expression = num @@ -160,7 +160,7 @@ case class MakeInterval( hours: Expression, mins: Expression, secs: Expression) - extends SeptenaryExpression with ImplicitCastInputTypes { + extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant { def this( years: Expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 205e5271517c3..f4568f860ac0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -519,7 +519,8 @@ case class JsonToStructs( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes + with NullIntolerant { // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder @@ -638,7 +639,8 @@ case class StructsToJson( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback + with ExpectsInputTypes with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 66e6334e3a450..fe8ea2a3c6733 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -57,7 +57,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(val f: Double => Double, name: String) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -111,7 +111,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -324,7 +324,7 @@ case class Acosh(child: Expression) -16 """) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends 
TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) @@ -452,7 +452,8 @@ object Factorial { > SELECT _FUNC_(5); 120 """) -case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Factorial(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -491,7 +492,9 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas > SELECT _FUNC_(1); 0.0 """) -case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") +case class Log(child: Expression) extends UnaryLogExpression(StrictMath.log, "LOG") { + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln") +} @ExpressionDescription( usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 2.", @@ -546,6 +549,7 @@ case class Log1p(child: Expression) extends UnaryLogExpression(StrictMath.log1p, // scalastyle:on line.size.limit case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint") } @ExpressionDescription( @@ -732,7 +736,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia """) // scalastyle:on line.size.limit case class Bin(child: Expression) - extends UnaryExpression with Serializable with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable { override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType @@ -831,7 +835,8 @@ object Hex { > SELECT _FUNC_('Spark SQL'); 537061726B2053514C """) -case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Hex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) @@ -866,7 +871,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8'); Spark SQL """) -case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Unhex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -952,7 +958,7 @@ case class Pow(left: Expression, right: Expression) 4 """) case class ShiftLeft(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -986,7 +992,7 @@ case class ShiftLeft(left: Expression, right: Expression) 2 """) case class ShiftRight(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) @@ -1020,7 +1026,7 @@ case class ShiftRight(left: Expression, right: 
Expression) 2 """) case class ShiftRightUnsigned(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType), IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3f60ca388a807..28924fac48eef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -283,7 +283,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress """, since = "1.5.0") case class StringSplit(str: Expression, regex: Expression, limit: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = ArrayType(StringType) override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -325,7 +325,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { // last regex in string, we will update the pattern iff regexp value changed. @transient private var lastRegex: UTF8String = _ @@ -433,7 +433,7 @@ object RegExpExtract { """, since = "1.5.0") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. 
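The Catalyst changes up to this point do two things: mix `NullIntolerant` into a large set of unary/binary/ternary expressions (bitwise, collection, datetime, decimal, hash, interval, JSON, math, regexp), and register `lcase`/`position` with `FUNC_ALIAS` so the name the user wrote is kept as the pretty name. A rough sketch of how this surfaces to a user follows; it is illustrative only and assumes a local SparkSession. The optimizer rules that consume `NullIntolerant` (such as inferring `IsNotNull` constraints from filter and join conditions) are not shown and may differ across versions:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Null intolerance: each of these returns NULL when an input is NULL. Tagging the
# expressions as NullIntolerant documents that contract for the optimizer.
spark.sql(
    "SELECT date_add(CAST(NULL AS DATE), 1), regexp_replace(NULL, 'a', 'b'), md5(NULL)"
).show()
# expected: a single row where every column is NULL

# FUNC_ALIAS: the alias under which the function was registered is used as the pretty
# name, so the generated column name says lcase(...) rather than lower(...).
print(spark.sql("SELECT lcase('Spark SQL')").columns)  # expected: ['lcase(Spark SQL)']
```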
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 0b9fb8f85fe3c..334a079fc1892 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -334,7 +334,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { """, since = "1.0.1") case class Upper(child: Expression) - extends UnaryExpression with String2StringExpression { + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -356,7 +356,8 @@ case class Upper(child: Expression) sparksql """, since = "1.0.1") -case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { +case class Lower(child: Expression) + extends UnaryExpression with String2StringExpression with NullIntolerant { // scalastyle:off caselocale override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -365,6 +366,9 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } + + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("lower") } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -432,7 +436,7 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate since = "2.3.0") // scalastyle:on line.size.limit case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(srcExpr: Expression, searchExpr: Expression) = { this(srcExpr, searchExpr, Literal("")) @@ -598,7 +602,7 @@ object StringTranslate { since = "1.5.0") // scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { @transient private var lastMatching: UTF8String = _ @transient private var lastReplace: UTF8String = _ @@ -663,7 +667,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac since = "1.5.0") // scalastyle:on line.size.limit case class FindInSet(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1035,7 +1039,7 @@ case class StringTrimRight( since = "1.5.0") // scalastyle:on line.size.limit case class StringInstr(str: Expression, substr: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = substr @@ -1077,7 +1081,7 @@ case class StringInstr(str: Expression, substr: Expression) since = "1.5.0") // scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) - extends TernaryExpression with 
ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -1182,7 +1186,8 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) """) } - override def prettyName: String = "locate" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("locate") } /** @@ -1205,7 +1210,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) """, since = "1.5.0") case class StringLPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1246,7 +1251,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression = Litera """, since = "1.5.0") case class StringRPad(str: Expression, len: Expression, pad: Expression = Literal(" ")) - extends TernaryExpression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(str: Expression, len: Expression) = { this(str, len, Literal(" ")) @@ -1536,7 +1541,8 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC Spark Sql """, since = "1.5.0") -case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class InitCap(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType @@ -1563,7 +1569,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI """, since = "1.5.0") case class StringRepeat(str: Expression, times: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = str override def right: Expression = times @@ -1593,7 +1599,7 @@ case class StringRepeat(str: Expression, times: Expression) """, since = "1.5.0") case class StringSpace(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes { + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -1738,7 +1744,8 @@ case class Left(str: Expression, len: Expression, child: Expression) extends Run """, since = "1.5.0") // scalastyle:on line.size.limit -case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Length(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1766,7 +1773,8 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastIn 72 """, since = "2.3.0") -case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class BitLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1797,7 
+1805,8 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCas 9 """, since = "2.3.0") -case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class OctetLength(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -1828,7 +1837,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitC """, since = "1.5.0") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { + with ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) @@ -1853,7 +1862,8 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres M460 """, since = "1.5.0") -case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class SoundEx(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def dataType: DataType = StringType @@ -1879,7 +1889,8 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT 50 """, since = "1.5.0") -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Ascii(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -1921,7 +1932,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp """, since = "2.3.0") // scalastyle:on line.size.limit -case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Chr(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(LongType) @@ -1964,7 +1976,8 @@ case class Chr(child: Expression) extends UnaryExpression with ImplicitCastInput U3BhcmsgU1FM """, since = "1.5.0") -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Base64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -1992,7 +2005,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn Spark SQL """, since = "1.5.0") -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class UnBase64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -2024,7 +2038,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast since = "1.5.0") // scalastyle:on line.size.limit case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = bin override def right: Expression = charset @@ -2064,7 +2078,7 @@ case class Decode(bin: Expression, charset: Expression) since = "1.5.0") // scalastyle:on line.size.limit case 
class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { override def left: Expression = value override def right: Expression = charset @@ -2108,7 +2122,7 @@ case class Encode(value: Expression, charset: Expression) """, since = "1.5.0") case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with NullIntolerant { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 55e06cb9e8471..e08a10ecac71c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -30,7 +30,8 @@ import org.apache.spark.unsafe.types.UTF8String * * This is not the world's most efficient implementation due to type conversion, but works. */ -abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback { +abstract class XPathExtract + extends BinaryExpression with ExpectsInputTypes with CodegenFallback with NullIntolerant { override def left: Expression = xml override def right: Expression = path diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index b65221c236bfe..85c6600685bd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -208,3 +208,161 @@ object ExtractPythonUDFFromJoinCondition extends Rule[LogicalPlan] with Predicat } } } + +sealed abstract class BuildSide + +case object BuildRight extends BuildSide + +case object BuildLeft extends BuildSide + +trait JoinSelectionHelper { + + def getBroadcastBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToBroadcastLeft(hint) + } else { + canBroadcastBySize(left, conf) && !hintToNotBroadcastLeft(hint) + } + val buildRight = if (hintOnly) { + hintToBroadcastRight(hint) + } else { + canBroadcastBySize(right, conf) && !hintToNotBroadcastRight(hint) + } + getBuildSide( + canBuildLeft(joinType) && buildLeft, + canBuildRight(joinType) && buildRight, + left, + right + ) + } + + def getShuffleHashJoinBuildSide( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + hint: JoinHint, + hintOnly: Boolean, + conf: SQLConf): Option[BuildSide] = { + val buildLeft = if (hintOnly) { + hintToShuffleHashJoinLeft(hint) + } else { + canBuildLocalHashMapBySize(left, conf) && muchSmaller(left, right) + } + val buildRight = if (hintOnly) { + hintToShuffleHashJoinRight(hint) + } else { + canBuildLocalHashMapBySize(right, conf) && muchSmaller(right, left) + } + getBuildSide( + canBuildLeft(joinType) && buildLeft, + canBuildRight(joinType) && buildRight, + left, + right + ) + } + + def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + } + + /** + * Matches a plan whose output should be small 
enough to be used in broadcast join. + */ + def canBroadcastBySize(plan: LogicalPlan, conf: SQLConf): Boolean = { + plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold + } + + def canBuildLeft(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | RightOuter => true + case _ => false + } + } + + def canBuildRight(joinType: JoinType): Boolean = { + joinType match { + case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _ => false + } + } + + def hintToBroadcastLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(BROADCAST)) + } + + def hintToBroadcastRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(BROADCAST)) + } + + def hintToNotBroadcastLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) + } + + def hintToNotBroadcastRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH)) + } + + def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) + } + + def hintToShuffleHashJoinRight(hint: JoinHint): Boolean = { + hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) + } + + def hintToSortMergeJoin(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) + } + + def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { + hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || + hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) + } + + private def getBuildSide( + canBuildLeft: Boolean, + canBuildRight: Boolean, + left: LogicalPlan, + right: LogicalPlan): Option[BuildSide] = { + if (canBuildLeft && canBuildRight) { + // returns the smaller side base on its estimated physical size, if we want to build the + // both sides. + Some(getSmallerSide(left, right)) + } else if (canBuildLeft) { + Some(BuildLeft) + } else if (canBuildRight) { + Some(BuildRight) + } else { + None + } + } + + /** + * Matches a plan whose single partition should be small enough to build a hash table. + * + * Note: this assume that the number of partition is fixed, requires additional work if it's + * dynamic. + */ + private def canBuildLocalHashMapBySize(plan: LogicalPlan, conf: SQLConf): Boolean = { + plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + } + + /** + * Returns whether plan a is much smaller (3X) than plan b. + * + * The cost to build hash map is higher than sorting, we should only build hash map on a table + * that is much smaller than other one. Since we does not have the statistic for number of rows, + * use the size of bytes here as estimation. + */ + private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { + a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c0cecf8536c39..03571a740df3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1534,7 +1534,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a [[CreateStruct]] expression. 
*/ override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.argument.asScala.map(expression)) + CreateStruct.create(ctx.argument.asScala.map(expression)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala index 0ea54c28cb285..353c074caa75e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala @@ -217,9 +217,18 @@ private object DateTimeFormatterHelper { toFormatter(builder, TimestampFormatter.defaultLocale) } + private final val bugInStandAloneForm = { + // Java 8 has a bug for stand-alone form. See https://bugs.openjdk.java.net/browse/JDK-8114833 + // Note: we only check the US locale so that it's a static check. It can produce false-negative + // as some locales are not affected by the bug. Since `L`/`q` is rarely used, we choose to not + // complicate the check here. + // TODO: remove it when we drop Java 8 support. + val formatter = DateTimeFormatter.ofPattern("LLL qqq", Locale.US) + formatter.format(LocalDate.of(2000, 1, 1)) == "1 1" + } final val unsupportedLetters = Set('A', 'c', 'e', 'n', 'N', 'p') final val unsupportedNarrowTextStyle = - Set("GGGGG", "MMMMM", "LLLLL", "EEEEE", "uuuuu", "QQQQQ", "qqqqq", "uuuuu") + Seq("G", "M", "L", "E", "u", "Q", "q").map(_ * 5).toSet /** * In Spark 3.0, we switch to the Proleptic Gregorian calendar and use DateTimeFormatter for @@ -244,6 +253,12 @@ private object DateTimeFormatterHelper { for (style <- unsupportedNarrowTextStyle if patternPart.contains(style)) { throw new IllegalArgumentException(s"Too many pattern letters: ${style.head}") } + if (bugInStandAloneForm && (patternPart.contains("LLL") || patternPart.contains("qqq"))) { + throw new IllegalArgumentException("Java 8 has a bug to support stand-alone " + + "form (3 or more 'L' or 'q' in the pattern string). Please use 'M' or 'Q' instead, " + + "or upgrade your Java version. For more details, please read " + + "https://bugs.openjdk.java.net/browse/JDK-8114833") + } // The meaning of 'u' was day number of week in SimpleDateFormat, it was changed to year // in DateTimeFormatter. Substitute 'u' to 'e' and use DateTimeFormatter to parse the // string. If parsable, return the result; otherwise, fall back to 'u', and then use the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index de2fd312b7db5..8428964d45707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -127,16 +127,37 @@ class FractionTimestampFormatter(zoneId: ZoneId) override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter // The new formatter will omit the trailing 0 in the timestamp string, but the legacy formatter - // can't. Here we borrow the code from Spark 2.4 DateTimeUtils.timestampToString to omit the - // trailing 0 for the legacy formatter as well. + // can't. Here we use the legacy formatter to format the given timestamp up to seconds fractions, + // and custom implementation to format the fractional part without trailing zeros. 
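+ // For example, a timestamp whose fractional part is 0.00123 seconds has nanos = 1230000;
+ // stripping the 4 trailing zeros leaves 123 with a fraction length of 5, so ".00123" is
+ // appended after the seconds produced by the legacy formatter.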
override def format(ts: Timestamp): String = { - val timestampString = ts.toString val formatted = legacyFormatter.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { + var nanos = ts.getNanos + if (nanos == 0) { formatted + } else { + // Formats non-zero seconds fraction w/o trailing zeros. For example: + // formatted = '2020-05:27 15:55:30' + // nanos = 001234000 + // Counts the length of the fractional part: 001234000 -> 6 + var fracLen = 9 + while (nanos % 10 == 0) { + nanos /= 10 + fracLen -= 1 + } + // Places `nanos` = 1234 after '2020-05:27 15:55:30.' + val fracOffset = formatted.length + 1 + val totalLen = fracOffset + fracLen + // The buffer for the final result: '2020-05:27 15:55:30.001234' + val buf = new Array[Char](totalLen) + formatted.getChars(0, formatted.length, buf, 0) + buf(formatted.length) = '.' + var i = totalLen + do { + i -= 1 + buf(i) = ('0' + (nanos % 10)).toChar + nanos /= 10 + } while (i > fracOffset) + new String(buf) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e5bff7f7af007..6af995cab64fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -240,7 +240,6 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkCast(1.5, "1.5") checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) - checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble) } test("cast from string") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 02d6d847dc063..1ca7380ead413 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -792,7 +792,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } // Test escaping of format - GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote"), UTC_OPT) :: Nil) } test("unix_timestamp") { @@ -862,7 +862,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - UnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) + UnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) } test("to_unix_timestamp") { @@ -940,7 +940,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } // Test escaping of format GenerateUnsafeProjection.generate( - ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) + ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote"), UTC_OPT) :: Nil) } test("datediff") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala new file mode 100644 index 0000000000000..3513cfa14808f --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.AttributeMap +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, JoinHint, NO_BROADCAST_HASH, SHUFFLE_HASH} +import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan +import org.apache.spark.sql.internal.SQLConf + +class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { + + private val left = StatsTestPlan( + outputList = Seq('a.int, 'b.int, 'c.int), + rowCount = 20000000, + size = Some(20000000), + attributeStats = AttributeMap(Seq())) + + private val right = StatsTestPlan( + outputList = Seq('d.int), + rowCount = 1000, + size = Some(1000), + attributeStats = AttributeMap(Seq())) + + private val hintBroadcast = Some(HintInfo(Some(BROADCAST))) + private val hintNotToBroadcast = Some(HintInfo(Some(NO_BROADCAST_HASH))) + private val hintShuffleHash = Some(HintInfo(Some(SHUFFLE_HASH))) + + test("getBroadcastBuildSide (hintOnly = true) return BuildLeft with only a left hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(hintBroadcast, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildLeft)) + } + + test("getBroadcastBuildSide (hintOnly = true) return BuildRight with only a right hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, hintBroadcast), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = true) return smaller side with both having hints") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(hintBroadcast, hintBroadcast), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = true) return None when no side has a hint") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getBroadcastBuildSide (hintOnly = false) return BuildRight when right is broadcastable") { + val broadcastSide = getBroadcastBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getBroadcastBuildSide (hintOnly = false) return None when right has no broadcast hint") { + val broadcastSide = getBroadcastBuildSide( + left, 
+ right, + Inner, + JoinHint(None, hintNotToBroadcast), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildLeft with only a left hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(hintShuffleHash, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildLeft)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return BuildRight with only a right hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, hintShuffleHash), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return smaller side when both have hints") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(hintShuffleHash, hintShuffleHash), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getShuffleHashJoinBuildSide (hintOnly = true) return None when no side has a hint") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = true, + SQLConf.get + ) + assert(broadcastSide === None) + } + + test("getShuffleHashJoinBuildSide (hintOnly = false) return BuildRight when right is smaller") { + val broadcastSide = getShuffleHashJoinBuildSide( + left, + right, + Inner, + JoinHint(None, None), + hintOnly = false, + SQLConf.get + ) + assert(broadcastSide === Some(BuildRight)) + } + + test("getSmallerSide should return BuildRight") { + assert(getSmallerSide(left, right) === BuildRight) + } + + test("canBroadcastBySize should return true if the plan size is less than 10MB") { + assert(canBroadcastBySize(left, SQLConf.get) === false) + assert(canBroadcastBySize(right, SQLConf.get) === true) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala index 4324d3cff63d7..7ff9b46bc6719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala @@ -135,6 +135,9 @@ class TimestampFormatterSuite extends SparkFunSuite with SQLHelper with Matchers test("format fraction of second") { val formatter = TimestampFormatter.getFractionFormatter(UTC) Seq( + -999999 -> "1969-12-31 23:59:59.000001", + -999900 -> "1969-12-31 23:59:59.0001", + -1 -> "1969-12-31 23:59:59.999999", 0 -> "1970-01-01 00:00:00", 1 -> "1970-01-01 00:00:00.000001", 1000 -> "1970-01-01 00:00:00.001", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 1df812d1aa809..89915d254883d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -22,6 +22,7 @@ import java.util.UUID import org.apache.hadoop.fs.Path +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, QueryPlanningTracker} @@ -50,7 +51,7 @@ import org.apache.spark.util.Utils class QueryExecution( val sparkSession: SparkSession, val logical: LogicalPlan, - val tracker:
QueryPlanningTracker = new QueryPlanningTracker) { + val tracker: QueryPlanningTracker = new QueryPlanningTracker) extends Logging { // TODO: Move the planner an optimizer into here from SessionState. protected def planner = sparkSession.sessionState.planner @@ -133,26 +134,42 @@ class QueryExecution( tracker.measurePhase(phase)(block) } - def simpleString: String = simpleString(false) - - def simpleString(formatted: Boolean): String = withRedaction { + def simpleString: String = { val concat = new PlanStringConcat() - concat.append("== Physical Plan ==\n") + simpleString(false, SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def simpleString( + formatted: Boolean, + maxFields: Int, + append: String => Unit): Unit = { + append("== Physical Plan ==\n") if (formatted) { try { - ExplainUtils.processPlan(executedPlan, concat.append) + ExplainUtils.processPlan(executedPlan, append) } catch { - case e: AnalysisException => concat.append(e.toString) - case e: IllegalArgumentException => concat.append(e.toString) + case e: AnalysisException => append(e.toString) + case e: IllegalArgumentException => append(e.toString) } } else { - QueryPlan.append(executedPlan, concat.append, verbose = false, addSuffix = false) + QueryPlan.append(executedPlan, + append, verbose = false, addSuffix = false, maxFields = maxFields) } - concat.append("\n") - concat.toString + append("\n") } def explainString(mode: ExplainMode): String = { + val concat = new PlanStringConcat() + explainString(mode, SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def explainString(mode: ExplainMode, maxFields: Int, append: String => Unit): Unit = { val queryExecution = if (logical.isStreaming) { // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the // output mode does not matter since there is no `Sink`. 
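A minimal usage sketch (assuming a live `SparkSession` named `spark` is in scope; the query is illustrative only): the public `explainString(mode)` still returns a single redacted string, while the private overloads above stream the plan text through an `append` callback.

import org.apache.spark.sql.execution.{FormattedMode, SimpleMode}

val qe = spark.range(10).selectExpr("id % 3 AS key").groupBy("key").count().queryExecution
// Same "== Physical Plan ==" text as before, now built through the append-based overload.
println(qe.explainString(SimpleMode))
// Formatted output, rendered via ExplainUtils.processPlan.
println(qe.explainString(FormattedMode))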
@@ -165,19 +182,19 @@ class QueryExecution( mode match { case SimpleMode => - queryExecution.simpleString + queryExecution.simpleString(false, maxFields, append) case ExtendedMode => - queryExecution.toString + queryExecution.toString(maxFields, append) case CodegenMode => try { - org.apache.spark.sql.execution.debug.codegenString(queryExecution.executedPlan) + org.apache.spark.sql.execution.debug.writeCodegen(append, queryExecution.executedPlan) } catch { - case e: AnalysisException => e.toString + case e: AnalysisException => append(e.toString) } case CostMode => - queryExecution.stringWithStats + queryExecution.stringWithStats(maxFields, append) case FormattedMode => - queryExecution.simpleString(formatted = true) + queryExecution.simpleString(formatted = true, maxFields = maxFields, append) } } @@ -204,27 +221,39 @@ class QueryExecution( override def toString: String = withRedaction { val concat = new PlanStringConcat() - writePlans(concat.append, SQLConf.get.maxToStringFields) - concat.toString + toString(SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def toString(maxFields: Int, append: String => Unit): Unit = { + writePlans(append, maxFields) } - def stringWithStats: String = withRedaction { + def stringWithStats: String = { val concat = new PlanStringConcat() + stringWithStats(SQLConf.get.maxToStringFields, concat.append) + withRedaction { + concat.toString + } + } + + private def stringWithStats(maxFields: Int, append: String => Unit): Unit = { val maxFields = SQLConf.get.maxToStringFields // trigger to compute stats for logical plans try { optimizedPlan.stats } catch { - case e: AnalysisException => concat.append(e.toString + "\n") + case e: AnalysisException => append(e.toString + "\n") } // only show optimized logical plan and physical plan - concat.append("== Optimized Logical Plan ==\n") - QueryPlan.append(optimizedPlan, concat.append, verbose = true, addSuffix = true, maxFields) - concat.append("\n== Physical Plan ==\n") - QueryPlan.append(executedPlan, concat.append, verbose = true, addSuffix = false, maxFields) - concat.append("\n") - concat.toString + append("== Optimized Logical Plan ==\n") + QueryPlan.append(optimizedPlan, append, verbose = true, addSuffix = true, maxFields) + append("\n== Physical Plan ==\n") + QueryPlan.append(executedPlan, append, verbose = true, addSuffix = false, maxFields) + append("\n") } /** @@ -261,19 +290,26 @@ class QueryExecution( /** * Dumps debug information about query execution into the specified file. * + * @param path path of the file the debug info is written to. * @param maxFields maximum number of fields converted to string representation. + * @param explainMode the explain mode to be used to generate the string + * representation of the plan. 
*/ - def toFile(path: String, maxFields: Int = Int.MaxValue): Unit = { + def toFile( + path: String, + maxFields: Int = Int.MaxValue, + explainMode: Option[String] = None): Unit = { val filePath = new Path(path) val fs = filePath.getFileSystem(sparkSession.sessionState.newHadoopConf()) val writer = new BufferedWriter(new OutputStreamWriter(fs.create(filePath))) - val append = (s: String) => { - writer.write(s) - } try { - writePlans(append, maxFields) - writer.write("\n== Whole Stage Codegen ==\n") - org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan) + val mode = explainMode.map(ExplainMode.fromString(_)).getOrElse(ExtendedMode) + explainString(mode, maxFields, writer.write) + if (mode != CodegenMode) { + writer.write("\n== Whole Stage Codegen ==\n") + org.apache.spark.sql.execution.debug.writeCodegen(writer.write, executedPlan) + } + log.info(s"Debug information was written at: $filePath") } finally { writer.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 12a1a1e7fc16e..302aae08d588b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper, NormalizeFloatingNumbers} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -33,7 +33,6 @@ import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemoryPlan @@ -135,93 +134,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Supports both equi-joins and non-equi-joins. * Supports only inner like joins. */ - object JoinSelection extends Strategy with PredicateHelper { - - /** - * Matches a plan whose output should be small enough to be used in broadcast join. - */ - private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold - } - - /** - * Matches a plan whose single partition should be small enough to build a hash table. - * - * Note: this assume that the number of partition is fixed, requires additional work if it's - * dynamic. - */ - private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions - } - - /** - * Returns whether plan a is much smaller (3X) than plan b. - * - * The cost to build hash map is higher than sorting, we should only build hash map on a table - * that is much smaller than other one. 
Since we does not have the statistic for number of rows, - * use the size of bytes here as estimation. - */ - private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes - } - - private def canBuildRight(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true - case _ => false - } - - private def canBuildLeft(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | RightOuter => true - case _ => false - } - - private def getBuildSide( - wantToBuildLeft: Boolean, - wantToBuildRight: Boolean, - left: LogicalPlan, - right: LogicalPlan): Option[BuildSide] = { - if (wantToBuildLeft && wantToBuildRight) { - // returns the smaller side base on its estimated physical size, if we want to build the - // both sides. - Some(getSmallerSide(left, right)) - } else if (wantToBuildLeft) { - Some(BuildLeft) - } else if (wantToBuildRight) { - Some(BuildRight) - } else { - None - } - } - - private def getSmallerSide(left: LogicalPlan, right: LogicalPlan) = { - if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - } - - private def hintToBroadcastLeft(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(BROADCAST)) - } - - private def hintToBroadcastRight(hint: JoinHint): Boolean = { - hint.rightHint.exists(_.strategy.contains(BROADCAST)) - } - - private def hintToShuffleHashLeft(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_HASH)) - } - - private def hintToShuffleHashRight(hint: JoinHint): Boolean = { - hint.rightHint.exists(_.strategy.contains(SHUFFLE_HASH)) - } - - private def hintToSortMergeJoin(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) - } - - private def hintToShuffleReplicateNL(hint: JoinHint): Boolean = { - hint.leftHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) || - hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) - } + object JoinSelection extends Strategy + with PredicateHelper + with JoinSelectionHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { @@ -245,33 +160,31 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have // other choice. 
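+ // Note: createBroadcastHashJoin(true) and createShuffleHashJoin(true) below consult only the
+ // join hints (via JoinSelectionHelper), while the `false` variants used in
+ // createJoinWithoutHint() fall back to the size-based checks (canBroadcastBySize,
+ // canBuildLocalHashMapBySize and muchSmaller).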
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint) => - def createBroadcastHashJoin(buildLeft: Boolean, buildRight: Boolean) = { - val wantToBuildLeft = canBuildLeft(joinType) && buildLeft - val wantToBuildRight = canBuildRight(joinType) && buildRight - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => - Seq(joins.BroadcastHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) + def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = { + getBroadcastBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { + buildSide => + Seq(joins.BroadcastHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) } } - def createShuffleHashJoin(buildLeft: Boolean, buildRight: Boolean) = { - val wantToBuildLeft = canBuildLeft(joinType) && buildLeft - val wantToBuildRight = canBuildRight(joinType) && buildRight - getBuildSide(wantToBuildLeft, wantToBuildRight, left, right).map { buildSide => - Seq(joins.ShuffledHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) + def createShuffleHashJoin(onlyLookingAtHint: Boolean) = { + getShuffleHashJoinBuildSide(left, right, joinType, hint, onlyLookingAtHint, conf).map { + buildSide => + Seq(joins.ShuffledHashJoinExec( + leftKeys, + rightKeys, + joinType, + buildSide, + condition, + planLater(left), + planLater(right))) } } @@ -293,14 +206,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastHashJoin( - canBroadcast(left) && !hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH)), - canBroadcast(right) && !hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH))) + createBroadcastHashJoin(false) .orElse { if (!conf.preferSortMergeJoin) { - createShuffleHashJoin( - canBuildLocalHashMap(left) && muchSmaller(left, right), - canBuildLocalHashMap(right) && muchSmaller(right, left)) + createShuffleHashJoin(false) } else { None } @@ -315,9 +224,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - createBroadcastHashJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint)) + createBroadcastHashJoin(true) .orElse { if (hintToSortMergeJoin(hint)) createSortMergeJoin() else None } - .orElse(createShuffleHashJoin(hintToShuffleHashLeft(hint), hintToShuffleHashRight(hint))) + .orElse(createShuffleHashJoin(true)) .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None } .getOrElse(createJoinWithoutHint()) @@ -374,7 +283,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def createJoinWithoutHint() = { - createBroadcastNLJoin(canBroadcast(left), canBroadcast(right)) + createBroadcastNLJoin(canBroadcastBySize(left, conf), canBroadcastBySize(right, conf)) .orElse(createCartesianProduct()) .getOrElse { // This join could be very slow or OOM diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index eb5fcd3b24227..bc924e6978ddc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -156,7 +156,7 @@ case class AdaptiveSparkPlanExec( var currentLogicalPlan = 
currentPhysicalPlan.logicalLink.get var result = createQueryStages(currentPhysicalPlan) val events = new LinkedBlockingQueue[StageMaterializationEvent]() - val errors = new mutable.ArrayBuffer[SparkException]() + val errors = new mutable.ArrayBuffer[Throwable]() var stagesToReplace = Seq.empty[QueryStageExec] while (!result.allChildStagesMaterialized) { currentPhysicalPlan = result.newPlan @@ -176,9 +176,7 @@ case class AdaptiveSparkPlanExec( }(AdaptiveSparkPlanExec.executionContext) } catch { case e: Throwable => - val ex = new SparkException( - s"Early failed query stage found: ${stage.treeString}", e) - cleanUpAndThrowException(Seq(ex), Some(stage.id)) + cleanUpAndThrowException(Seq(e), Some(stage.id)) } } } @@ -191,10 +189,9 @@ case class AdaptiveSparkPlanExec( events.drainTo(rem) (Seq(nextMsg) ++ rem.asScala).foreach { case StageSuccess(stage, res) => - stage.resultOption = Some(res) + stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => - errors.append( - new SparkException(s"Failed to materialize query stage: ${stage.treeString}.", ex)) + errors.append(ex) } // In case of errors, we cancel all running stages and throw exception. @@ -328,11 +325,11 @@ case class AdaptiveSparkPlanExec( context.stageCache.get(e.canonicalized) match { case Some(existingStage) if conf.exchangeReuseEnabled => val stage = reuseQueryStage(existingStage, e) - // This is a leaf stage and is not materialized yet even if the reused exchange may has - // been completed. It will trigger re-optimization later and stage materialization will - // finish in instant if the underlying exchange is already completed. + val isMaterialized = stage.resultOption.get().isDefined CreateStageResult( - newPlan = stage, allChildStagesMaterialized = false, newStages = Seq(stage)) + newPlan = stage, + allChildStagesMaterialized = isMaterialized, + newStages = if (isMaterialized) Seq.empty else Seq(stage)) case _ => val result = createQueryStages(e.child) @@ -349,10 +346,11 @@ case class AdaptiveSparkPlanExec( newStage = reuseQueryStage(queryStage, e) } } - - // We've created a new stage, which is obviously not ready yet. - CreateStageResult(newPlan = newStage, - allChildStagesMaterialized = false, newStages = Seq(newStage)) + val isMaterialized = newStage.resultOption.get().isDefined + CreateStageResult( + newPlan = newStage, + allChildStagesMaterialized = isMaterialized, + newStages = if (isMaterialized) Seq.empty else Seq(newStage)) } else { CreateStageResult(newPlan = newPlan, allChildStagesMaterialized = false, newStages = result.newStages) @@ -361,7 +359,7 @@ case class AdaptiveSparkPlanExec( case q: QueryStageExec => CreateStageResult(newPlan = q, - allChildStagesMaterialized = q.resultOption.isDefined, newStages = Seq.empty) + allChildStagesMaterialized = q.resultOption.get().isDefined, newStages = Seq.empty) case _ => if (plan.children.isEmpty) { @@ -537,31 +535,28 @@ case class AdaptiveSparkPlanExec( * materialization errors and stage cancellation errors. */ private def cleanUpAndThrowException( - errors: Seq[SparkException], + errors: Seq[Throwable], earlyFailedStage: Option[Int]): Unit = { - val runningStages = currentPhysicalPlan.collect { + currentPhysicalPlan.foreach { // earlyFailedStage is the stage which failed before calling doMaterialize, // so we should avoid calling cancel on it to re-trigger the failure again. 
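+ // Failures while cancelling a stage are only logged below so that they do not mask the
+ // original materialization error(s): a single error is rethrown as-is, and multiple errors
+ // are wrapped in one SparkException with the rest attached as suppressed exceptions.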
- case s: QueryStageExec if !earlyFailedStage.contains(s.id) => s - } - val cancelErrors = new mutable.ArrayBuffer[SparkException]() - try { - runningStages.foreach { s => + case s: QueryStageExec if !earlyFailedStage.contains(s.id) => try { s.cancel() } catch { case NonFatal(t) => - cancelErrors.append( - new SparkException(s"Failed to cancel query stage: ${s.treeString}", t)) + logError(s"Exception in cancelling query stage: ${s.treeString}", t) } - } - } finally { - val ex = new SparkException( - "Adaptive execution failed due to stage materialization failures.", errors.head) - errors.tail.foreach(ex.addSuppressed) - cancelErrors.foreach(ex.addSuppressed) - throw ex + case _ => + } + val e = if (errors.size == 1) { + errors.head + } else { + val se = new SparkException("Multiple failures in stage materialization.", errors.head) + errors.tail.foreach(se.addSuppressed) + se } + throw e } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala index 0f2868e41cc39..aba83b1337109 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/DemoteBroadcastHashJoin.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf case class DemoteBroadcastHashJoin(conf: SQLConf) extends Rule[LogicalPlan] { private def shouldDemote(plan: LogicalPlan): Boolean = plan match { - case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.resultOption.isDefined + case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.resultOption.get().isDefined && stage.mapStats.isDefined => val mapStats = stage.mapStats.get val partitionCnt = mapStats.bytesByPartitionId.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala index d60c3ca72f6f6..ac98342277bc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LogicalQueryStageStrategy.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, BuildLeft, BuildRight} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} /** * Strategy for plans containing [[LogicalQueryStage]] nodes: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index 5416fde222cb6..3620f27058af2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.adaptive 
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.internal.SQLConf /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index f414f854b92ae..4e83b4344fbf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.adaptive import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import scala.concurrent.{Future, Promise} @@ -25,6 +26,7 @@ import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning @@ -82,7 +84,7 @@ abstract class QueryStageExec extends LeafExecNode { /** * Compute the statistics of the query stage if executed, otherwise None. */ - def computeStats(): Option[Statistics] = resultOption.map { _ => + def computeStats(): Option[Statistics] = resultOption.get().map { _ => // Metrics `dataSize` are available in both `ShuffleExchangeExec` and `BroadcastExchangeExec`. val exchange = plan match { case r: ReusedExchangeExec => r.child @@ -94,7 +96,9 @@ abstract class QueryStageExec extends LeafExecNode { @transient @volatile - private[adaptive] var resultOption: Option[Any] = None + protected var _resultOption = new AtomicReference[Option[Any]](None) + + private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning @@ -147,14 +151,16 @@ case class ShuffleQueryStageExec( throw new IllegalStateException("wrong plan for shuffle stage:\n " + plan.treeString) } - override def doMaterialize(): Future[Any] = { + override def doMaterialize(): Future[Any] = attachTree(this, "execute") { shuffle.mapOutputStatisticsFuture } override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = { - ShuffleQueryStageExec( + val reuse = ShuffleQueryStageExec( newStageId, ReusedExchangeExec(newOutput, shuffle)) + reuse._resultOption = this._resultOption + reuse } override def cancel(): Unit = { @@ -171,8 +177,8 @@ case class ShuffleQueryStageExec( * this method returns None, as there is no map statistics. 
*/ def mapStats: Option[MapOutputStatistics] = { - assert(resultOption.isDefined, "ShuffleQueryStageExec should already be ready") - val stats = resultOption.get.asInstanceOf[MapOutputStatistics] + assert(resultOption.get().isDefined, "ShuffleQueryStageExec should already be ready") + val stats = resultOption.get().get.asInstanceOf[MapOutputStatistics] Option(stats) } } @@ -212,9 +218,11 @@ case class BroadcastQueryStageExec( } override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = { - BroadcastQueryStageExec( + val reuse = BroadcastQueryStageExec( newStageId, ReusedExchangeExec(newOutput, broadcast)) + reuse._resultOption = this._resultOption + reuse } override def cancel(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index eb091758910cd..cfc653a23840d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.dynamicpruning import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, BindReferences, DynamicPruningExpression, DynamicPruningSubquery, Expression, ListQuery, Literal, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.rules.Rule diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 08128d8f69dab..707ed1402d1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 888e7af7c07ed..52b476f9cf134 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ 
import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7f90a51c1f234..c7c3e1672f034 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ExplainUtils, RowIterator} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 755a63e545ef1..2b7cd65e7d96f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala deleted file mode 100644 index 134376628ae7f..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/package.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -/** - * Physical execution operators for join operations. 
- */ -package object joins { - - sealed abstract class BuildSide - - case object BuildRight extends BuildSide - - case object BuildLeft extends BuildSide - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 8c23f2cbb86ba..33539c01ee5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -203,11 +203,10 @@ private[ui] class ExecutionPagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, executionTag, "ID") + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override val dataSource = new ExecutionDataSource( - request, - parent, data, - basePath, currentTime, pageSize, sortColumn, @@ -222,11 +221,9 @@ private[ui] class ExecutionPagedTable( override def tableId: String = s"$executionTag-table" override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$executionTag.sort=$encodedSortColumn" + @@ -239,10 +236,8 @@ private[ui] class ExecutionPagedTable( override def pageNumberFormField: String = s"$executionTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$executionTag.sort=$encodedSortColumn&$executionTag.desc=$desc#$tableHeaderId" - } override def headers: Seq[Node] = { // Information for each header: title, sortable, tooltip @@ -348,7 +343,6 @@ private[ui] class ExecutionPagedTable( private[ui] class ExecutionTableRowData( - val submissionTime: Long, val duration: Long, val executionUIData: SQLExecutionUIData, val runningJobData: Seq[Int], @@ -357,10 +351,7 @@ private[ui] class ExecutionTableRowData( private[ui] class ExecutionDataSource( - request: HttpServletRequest, - parent: SQLTab, executionData: Seq[SQLExecutionUIData], - basePath: String, currentTime: Long, pageSize: Int, sortColumn: String, @@ -373,20 +364,13 @@ private[ui] class ExecutionDataSource( // in the table so that we can avoid creating duplicate contents during sorting the data private val data = executionData.map(executionRow).sorted(ordering(sortColumn, desc)) - private var _sliceExecutionIds: Set[Int] = _ - override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = { - val r = data.slice(from, to) - _sliceExecutionIds = r.map(_.executionUIData.executionId.toInt).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[ExecutionTableRowData] = data.slice(from, to) private def executionRow(executionUIData: SQLExecutionUIData): ExecutionTableRowData = { - val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()) - .getOrElse(currentTime) - submissionTime + .getOrElse(currentTime) - executionUIData.submissionTime val runningJobData = if (showRunningJobs) { executionUIData.jobs.filter { @@ -407,7 +391,6 @@ private[ui] class ExecutionDataSource( } else Seq.empty new ExecutionTableRowData( - submissionTime, duration, 
executionUIData, runningJobData, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5481337bf6cee..0cca3e7b47c56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1306,7 +1306,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) } + def struct(cols: Column*): Column = withExpr { CreateStruct.create(cols.map(_.expr)) } /** * Creates a new struct column that composes multiple input columns. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java index 5603cb988b8e7..af0a22b036030 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java @@ -18,6 +18,8 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; import java.time.Instant; import java.time.LocalDate; import java.util.*; @@ -210,6 +212,17 @@ private static Row createRecordSpark22000Row(Long index) { return new GenericRow(values); } + private static String timestampToString(Timestamp ts) { + String timestampString = String.valueOf(ts); + String formatted = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(ts); + + if (timestampString.length() > 19 && !timestampString.substring(19).equals(".0")) { + return formatted + timestampString.substring(19); + } else { + return formatted; + } + } + private static RecordSpark22000 createRecordSpark22000(Row recordRow) { RecordSpark22000 record = new RecordSpark22000(); record.setShortField(String.valueOf(recordRow.getShort(0))); @@ -219,7 +232,7 @@ private static RecordSpark22000 createRecordSpark22000(Row recordRow) { record.setDoubleField(String.valueOf(recordRow.getDouble(4))); record.setStringField(recordRow.getString(5)); record.setBooleanField(String.valueOf(recordRow.getBoolean(6))); - record.setTimestampField(String.valueOf(recordRow.getTimestamp(7))); + record.setTimestampField(timestampToString(recordRow.getTimestamp(7))); // This would figure out that null value will not become "null". 
record.setNullIntField(null); return record; diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 23173c8ba1f11..d245aa5a17345 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -2,7 +2,7 @@ ## Summary - Number of queries: 337 - Number of expressions that missing example: 34 - - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch + - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,struct,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch ## Schema of Built-in Functions | Class name | Function name or alias | Query example | Output schema | | ---------- | ---------------------- | ------------- | ------------- | @@ -79,6 +79,7 @@ | org.apache.spark.sql.catalyst.expressions.CreateArray | array | SELECT array(1, 2, 3) | struct> | | org.apache.spark.sql.catalyst.expressions.CreateMap | map | SELECT map(1.0, '2', 3.0, '4') | struct> | | org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | named_struct | SELECT named_struct("a", 1, "b", 2, "c", 3) | struct> | +| org.apache.spark.sql.catalyst.expressions.CreateNamedStruct | struct | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct> | | org.apache.spark.sql.catalyst.expressions.Cube | cube | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY cube(name, age) | struct | | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | N/A | N/A | @@ -157,12 +158,12 @@ | org.apache.spark.sql.catalyst.expressions.LessThanOrEqual | <= | SELECT 2 <= 2 | struct<(2 <= 2):boolean> | | org.apache.spark.sql.catalyst.expressions.Levenshtein | levenshtein | SELECT levenshtein('kitten', 'sitting') | struct | | org.apache.spark.sql.catalyst.expressions.Like | like | SELECT like('Spark', '_park') | struct | -| org.apache.spark.sql.catalyst.expressions.Log | ln | SELECT ln(1) | struct | +| org.apache.spark.sql.catalyst.expressions.Log | ln | SELECT ln(1) | struct | | org.apache.spark.sql.catalyst.expressions.Log10 | log10 | SELECT log10(10) | struct | | org.apache.spark.sql.catalyst.expressions.Log1p | log1p | SELECT log1p(0) | struct | | org.apache.spark.sql.catalyst.expressions.Log2 | log2 | SELECT log2(2) | struct | | org.apache.spark.sql.catalyst.expressions.Logarithm | log | SELECT log(10, 100) | struct | -| org.apache.spark.sql.catalyst.expressions.Lower | lcase | SELECT lcase('SparkSql') | struct | +| org.apache.spark.sql.catalyst.expressions.Lower | lcase | SELECT lcase('SparkSql') | struct | | org.apache.spark.sql.catalyst.expressions.Lower | lower | SELECT lower('SparkSql') | struct | | org.apache.spark.sql.catalyst.expressions.MakeDate | make_date | SELECT make_date(2013, 7, 15) | struct | | org.apache.spark.sql.catalyst.expressions.MakeInterval | make_interval | SELECT 
make_interval(100, 11, 1, 1, 12, 30, 01.001001) | struct | @@ -171,7 +172,7 @@ | org.apache.spark.sql.catalyst.expressions.MapEntries | map_entries | SELECT map_entries(map(1, 'a', 2, 'b')) | struct>> | | org.apache.spark.sql.catalyst.expressions.MapFilter | map_filter | SELECT map_filter(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v) | struct namedlambdavariable()), namedlambdavariable(), namedlambdavariable())):map> | | org.apache.spark.sql.catalyst.expressions.MapFromArrays | map_from_arrays | SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) | struct> | -| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> | +| org.apache.spark.sql.catalyst.expressions.MapFromEntries | map_from_entries | SELECT map_from_entries(array(struct(1, 'a'), struct(2, 'b'))) | struct> | | org.apache.spark.sql.catalyst.expressions.MapKeys | map_keys | SELECT map_keys(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapValues | map_values | SELECT map_values(map(1, 'a', 2, 'b')) | struct> | | org.apache.spark.sql.catalyst.expressions.MapZipWith | map_zip_with | SELECT map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)) | struct> | @@ -186,7 +187,6 @@ | org.apache.spark.sql.catalyst.expressions.Murmur3Hash | hash | SELECT hash('Spark', array(123), 2) | struct | | org.apache.spark.sql.catalyst.expressions.NTile | ntile | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.NaNvl | nanvl | SELECT nanvl(cast('NaN' as double), 123) | struct | -| org.apache.spark.sql.catalyst.expressions.NamedStruct | struct | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.NextDay | next_day | SELECT next_day('2015-01-14', 'TU') | struct | | org.apache.spark.sql.catalyst.expressions.Not | ! 
| N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Not | not | N/A | N/A | @@ -219,7 +219,7 @@ | org.apache.spark.sql.catalyst.expressions.Remainder | mod | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> | | org.apache.spark.sql.catalyst.expressions.Reverse | reverse | SELECT reverse('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.Right | right | SELECT right('Spark SQL', 3) | struct | -| org.apache.spark.sql.catalyst.expressions.Rint | rint | SELECT rint(12.3456) | struct | +| org.apache.spark.sql.catalyst.expressions.Rint | rint | SELECT rint(12.3456) | struct | | org.apache.spark.sql.catalyst.expressions.Rollup | rollup | SELECT name, age, count(*) FROM VALUES (2, 'Alice'), (5, 'Bob') people(age, name) GROUP BY rollup(name, age) | struct | | org.apache.spark.sql.catalyst.expressions.Round | round | SELECT round(2.5, 0) | struct | | org.apache.spark.sql.catalyst.expressions.RowNumber | row_number | N/A | N/A | @@ -251,7 +251,7 @@ | org.apache.spark.sql.catalyst.expressions.Stack | stack | SELECT stack(2, 1, 2, 3) | struct | | org.apache.spark.sql.catalyst.expressions.StringInstr | instr | SELECT instr('SparkSQL', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.StringLPad | lpad | SELECT lpad('hi', 5, '??') | struct | -| org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | +| org.apache.spark.sql.catalyst.expressions.StringLocate | position | SELECT position('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringLocate | locate | SELECT locate('bar', 'foobarbar') | struct | | org.apache.spark.sql.catalyst.expressions.StringRPad | rpad | SELECT rpad('hi', 5, '??') | struct | | org.apache.spark.sql.catalyst.expressions.StringRepeat | repeat | SELECT repeat('123', 2) | struct | diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index b507713a73d1f..d5c0acb40bb1e 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -55,7 +55,7 @@ struct -- !query select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) -- !query schema -struct +struct -- !query output 4 NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out index 3fcd132701a3f..d41d25280146b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-filter.sql.out @@ -272,7 +272,7 @@ struct= 0)):bigint> -- !query SELECT 'foo', MAX(STRUCT(a)) FILTER (WHERE b >= 1) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema -struct= 1)):struct> +struct= 1)):struct> -- !query output diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 7bfdd0ad53a95..50eb2a9f22f69 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -87,7 +87,7 @@ struct -- !query SELECT 'foo', MAX(STRUCT(a)) FROM testData WHERE a = 0 GROUP BY 1 -- !query schema -struct> +struct> -- !query output diff --git 
a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index e59b9d5b63a40..7b7aeb4ec7934 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -4654,7 +4654,7 @@ struct -- !query select ln(1.2345678e-28) -- !query schema -struct +struct -- !query output -64.26166165451762 @@ -4662,7 +4662,7 @@ struct -- !query select ln(0.0456789) -- !query schema -struct +struct -- !query output -3.0861187944847437 @@ -4670,7 +4670,7 @@ struct -- !query select ln(0.99949452) -- !query schema -struct +struct -- !query output -5.056077980832118E-4 @@ -4678,7 +4678,7 @@ struct -- !query select ln(1.00049687395) -- !query schema -struct +struct -- !query output 4.967505490136803E-4 @@ -4686,7 +4686,7 @@ struct -- !query select ln(1234.567890123456789) -- !query schema -struct +struct -- !query output 7.11847630129779 @@ -4694,7 +4694,7 @@ struct -- !query select ln(5.80397490724e5) -- !query schema -struct +struct -- !query output 13.271468476626518 @@ -4702,7 +4702,7 @@ struct -- !query select ln(9.342536355e34) -- !query schema -struct +struct -- !query output 80.52247093552418 diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 0d37c0d02e61f..20c31b140b009 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -55,7 +55,7 @@ struct -- !query select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null) -- !query schema -struct +struct -- !query output 4 NULL NULL diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index f294c5213d319..3b610edc47169 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -83,7 +83,7 @@ struct -- !query SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x -- !query schema -struct +struct -- !query output 1 delta 2 eta diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index ed7ab5a342c12..d046ff249379f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -85,7 +85,7 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 -- !query @@ -113,7 
+113,7 @@ FROM various_maps struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out index 6403406413db9..da5256f5c0453 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-group-by.sql.out @@ -87,7 +87,7 @@ struct> +struct> -- !query output diff --git a/sql/core/src/test/resources/test-data/before_1582_date_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_date_v2_4.snappy.parquet deleted file mode 100644 index 7d5cc12eefe04..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_date_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_date_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..edd61c9b9fec8 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_date_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..01f4887f5e994 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_date_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..c7e8d3926f63a Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..939e2b8088eb0 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_dict_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..88a94ac482052 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_6.snappy.parquet 
b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..68bfa33aac13f Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_plain_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_v2_4.snappy.parquet deleted file mode 100644 index 13254bd93a5e6..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_timestamp_int96_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4.snappy.parquet deleted file mode 100644 index 7d2b46e9bea41..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..62e6048354dc1 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..d7fdaa3e67212 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_micros_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4.snappy.parquet deleted file mode 100644 index e9825455c2015..0000000000000 Binary files a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4.snappy.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_5.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_5.snappy.parquet new file mode 100644 index 0000000000000..a7cef9e60f134 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_5.snappy.parquet differ diff --git a/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_6.snappy.parquet b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_6.snappy.parquet new file mode 100644 index 0000000000000..4c213f4540a73 Binary files /dev/null and b/sql/core/src/test/resources/test-data/before_1582_timestamp_millis_v2_4_6.snappy.parquet differ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index f68c416941266..234978b9ce176 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.log4j.Level -import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, EliminateResolvedHint} import org.apache.spark.sql.catalyst.plans.PlanTest import 
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala index a9f443be69cb2..956bd7861d99d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.File import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -56,8 +55,8 @@ abstract class MetadataCacheSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { df.count() } - assertExceptionMessage(e, "FileNotFoundException") - assertExceptionMessage(e, "recreating the Dataset/DataFrame involved") + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("recreating the Dataset/DataFrame involved")) } } } @@ -85,8 +84,8 @@ class MetadataCacheV1Suite extends MetadataCacheSuite { val e = intercept[SparkException] { sql("select count(*) from view_refresh").first() } - assertExceptionMessage(e, "FileNotFoundException") - assertExceptionMessage(e, "REFRESH") + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("REFRESH")) // Refresh and we should be able to read it again. spark.catalog.refreshTable("view_refresh") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index eca39f3f81726..5c35cedba9bab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -53,6 +53,7 @@ class QueryExecutionSuite extends SharedSparkSession { s"*(1) Range (0, $expected, step=1, splits=2)", "")) } + test("dumping query execution info to a file") { withTempDir { dir => val path = dir.getCanonicalPath + "/plans.txt" @@ -93,6 +94,25 @@ class QueryExecutionSuite extends SharedSparkSession { assert(exception.getMessage.contains("Illegal character in scheme name")) } + test("dumping query execution info to a file - explainMode=formatted") { + withTempDir { dir => + val path = dir.getCanonicalPath + "/plans.txt" + val df = spark.range(0, 10) + df.queryExecution.debug.toFile(path, explainMode = Option("formatted")) + assert(Source.fromFile(path).getLines.toList + .takeWhile(_ != "== Whole Stage Codegen ==").map(_.replaceAll("#\\d+", "#x")) == List( + "== Physical Plan ==", + s"* Range (1)", + "", + "", + s"(1) Range [codegen id : 1]", + "Output [1]: [id#xL]", + s"Arguments: Range (0, 10, step=1, splits=Some(2))", + "", + "")) + } + } + test("limit number of fields by sql config") { def relationPlans: String = { val ds = spark.createDataset(Seq(QueryExecutionTestRecord( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a80fc410f5033..207fae826134d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -24,11 +24,12 @@ import org.apache.log4j.Level import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{QueryTest, Row, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{ReusedSubqueryExec, ShuffledRowRDD, SparkPlan} import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec} -import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -740,7 +741,7 @@ class AdaptiveQueryExecSuite val error = intercept[Exception] { agged.count() } - assert(error.getCause().toString contains "Early failed query stage found") + assert(error.getCause().toString contains "Invalid bucket file") assert(error.getSuppressed.size === 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala index ddaeb57d31547..48f85ae76cd8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveTestUtils.scala @@ -69,25 +69,3 @@ trait DisableAdaptiveExecutionSuite extends SQLTestUtils { } } } - -object AdaptiveTestUtils { - def assertExceptionMessage(e: Exception, expected: String): Unit = { - val stringWriter = new StringWriter() - e.printStackTrace(new PrintWriter(stringWriter)) - val errorMsg = stringWriter.toString - assert(errorMsg.contains(expected)) - } - - def assertExceptionCause(t: Throwable, causeClass: Class[_]): Unit = { - var c = t.getCause - var foundCause = false - while (c != null && !foundCause) { - if (causeClass.isAssignableFrom(c.getClass)) { - foundCause = true - } else { - c = c.getCause - } - } - assert(foundCause, s"Can not find cause: $causeClass") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 10ad8acc68937..e4709e469dca3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1203,14 +1203,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("alter table: recover partitions (sequential)") { - withSQLConf(RDD_PARALLEL_LISTING_THRESHOLD.key -> "10") { + val oldRddParallelListingThreshold = spark.sparkContext.conf.get( + RDD_PARALLEL_LISTING_THRESHOLD) + try { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD.key, "10") testRecoverPartitions() + } finally { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD, oldRddParallelListingThreshold) } } test("alter table: recover partition (parallel)") { - withSQLConf(RDD_PARALLEL_LISTING_THRESHOLD.key -> "0") { + val oldRddParallelListingThreshold = spark.sparkContext.conf.get( + 
RDD_PARALLEL_LISTING_THRESHOLD) + try { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD.key, "0") testRecoverPartitions() + } finally { + spark.sparkContext.conf.set(RDD_PARALLEL_LISTING_THRESHOLD, oldRddParallelListingThreshold) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 1308f28d35b9c..899bd23c3a7d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.ExternalRDD -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.execution.datasources.{DataSource, InMemoryFileIndex, NoopCache} import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder import org.apache.spark.sql.internal.SQLConf @@ -2241,7 +2240,7 @@ abstract class JsonSuite extends QueryTest with SharedSparkSession with TestJson .count() } - assertExceptionMessage(exception, "Malformed records are detected in record parsing") + assert(exception.getMessage.contains("Malformed records are detected in record parsing")) } def checkEncoding(expectedEncoding: String, pathToJsonFiles: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 60f278b8e5bb0..9caf0c836f711 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -34,7 +34,6 @@ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -599,19 +598,19 @@ abstract class OrcQueryTest extends OrcTest { val e1 = intercept[SparkException] { testIgnoreCorruptFiles() } - assertExceptionMessage(e1, "Malformed ORC file") + assert(e1.getMessage.contains("Malformed ORC file")) val e2 = intercept[SparkException] { testIgnoreCorruptFilesWithoutSchemaInfer() } - assertExceptionMessage(e2, "Malformed ORC file") + assert(e2.getMessage.contains("Malformed ORC file")) val e3 = intercept[SparkException] { testAllCorruptFiles() } - assertExceptionMessage(e3, "Could not read footer for file") + assert(e3.getMessage.contains("Could not read footer for file")) val e4 = intercept[SparkException] { testAllCorruptFilesWithoutSchemaInfer() } - assertExceptionMessage(e4, "Malformed ORC file") + assert(e4.getMessage.contains("Malformed ORC file")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f075d04165697..79c32976f02ec 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.nio.file.{Files, Paths, StandardCopyOption} import java.sql.{Date, Timestamp} import java.time._ import java.util.Locale @@ -45,7 +46,7 @@ import org.apache.spark.{SPARK_VERSION_SHORT, SparkException, SparkUpgradeExcept import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -875,81 +876,152 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } + // It generates input files for the test below: + // "SPARK-31159: compatibility with Spark 2.4 in reading dates/timestamps" + ignore("SPARK-31806: generate test files for checking compatibility with Spark 2.4") { + val resourceDir = "sql/core/src/test/resources/test-data" + val version = "2_4_5" + val N = 8 + def save( + in: Seq[(String, String)], + t: String, + dstFile: String, + options: Map[String, String] = Map.empty): Unit = { + withTempDir { dir => + in.toDF("dict", "plain") + .select($"dict".cast(t), $"plain".cast(t)) + .repartition(1) + .write + .mode("overwrite") + .options(options) + .parquet(dir.getCanonicalPath) + Files.copy( + dir.listFiles().filter(_.getName.endsWith(".snappy.parquet")).head.toPath, + Paths.get(resourceDir, dstFile), + StandardCopyOption.REPLACE_EXISTING) + } + } + DateTimeTestUtils.withDefaultTimeZone(DateTimeTestUtils.LA) { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> DateTimeTestUtils.LA.getId) { + save( + (1 to N).map(i => ("1001-01-01", s"1001-01-0$i")), + "date", + s"before_1582_date_v$version.snappy.parquet") + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "TIMESTAMP_MILLIS") { + save( + (1 to N).map(i => ("1001-01-01 01:02:03.123", s"1001-01-0$i 01:02:03.123")), + "timestamp", + s"before_1582_timestamp_millis_v$version.snappy.parquet") + } + val usTs = (1 to N).map(i => ("1001-01-01 01:02:03.123456", s"1001-01-0$i 01:02:03.123456")) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "TIMESTAMP_MICROS") { + save(usTs, "timestamp", s"before_1582_timestamp_micros_v$version.snappy.parquet") + } + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96") { + // Comparing to other logical types, Parquet-MR chooses dictionary encoding for the + // INT96 logical type because it consumes less memory for small column cardinality. + // Huge parquet files doesn't make sense to place to the resource folder. That's why + // we explicitly set `parquet.enable.dictionary` and generate two files w/ and w/o + // dictionary encoding. 
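Aside: the `parquet.enable.dictionary` toggle used just below is an ordinary writer option, so the encoding choice can be reproduced outside the test harness. A minimal sketch, assuming a local SparkSession and illustrative /tmp paths (neither is part of the patch):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    // Sketch only: the session, data and output paths are illustrative.
    val spark = SparkSession.builder().master("local[2]").appName("dict-encoding-demo").getOrCreate()
    import spark.implicits._

    spark.conf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, "INT96")
    val df = Seq("2020-05-25 10:11:12.123456").toDF("ts").select($"ts".cast("timestamp"))

    // The option is forwarded to the Parquet writer, as in the `save` helper above.
    df.write.option("parquet.enable.dictionary", "false").mode("overwrite").parquet("/tmp/int96_plain")
    df.write.option("parquet.enable.dictionary", "true").mode("overwrite").parquet("/tmp/int96_dict")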
+ save( + usTs, + "timestamp", + s"before_1582_timestamp_int96_plain_v$version.snappy.parquet", + Map("parquet.enable.dictionary" -> "false")) + save( + usTs, + "timestamp", + s"before_1582_timestamp_int96_dict_v$version.snappy.parquet", + Map("parquet.enable.dictionary" -> "true")) + } + } + } + } + test("SPARK-31159: compatibility with Spark 2.4 in reading dates/timestamps") { + val N = 8 // test reading the existing 2.4 files and new 3.0 files (with rebase on/off) together. - def checkReadMixedFiles(fileName: String, dt: String, dataStr: String): Unit = { + def checkReadMixedFiles[T]( + fileName: String, + catalystType: String, + rowFunc: Int => (String, String), + toJavaType: String => T, + checkDefaultLegacyRead: String => Unit, + tsOutputType: String = "TIMESTAMP_MICROS"): Unit = { withTempPaths(2) { paths => paths.foreach(_.delete()) val path2_4 = getResourceParquetFilePath("test-data/" + fileName) val path3_0 = paths(0).getCanonicalPath val path3_0_rebase = paths(1).getCanonicalPath - if (dt == "date") { - val df = Seq(dataStr).toDF("str").select($"str".cast("date").as("date")) - + val df = Seq.tabulate(N)(rowFunc).toDF("dict", "plain") + .select($"dict".cast(catalystType), $"plain".cast(catalystType)) + withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> tsOutputType) { + checkDefaultLegacyRead(path2_4) // By default we should fail to write ancient datetime values. - var e = intercept[SparkException](df.write.parquet(path3_0)) + val e = intercept[SparkException](df.write.parquet(path3_0)) assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) - // By default we should fail to read ancient datetime values. - e = intercept[SparkException](spark.read.parquet(path2_4).collect()) - assert(e.getCause.isInstanceOf[SparkUpgradeException]) - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { df.write.mode("overwrite").parquet(path3_0) } withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { df.write.parquet(path3_0_rebase) } - - // For Parquet files written by Spark 3.0, we know the writer info and don't need the - // config to guide the rebase behavior. - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key -> LEGACY.toString) { - checkAnswer( - spark.read.format("parquet").load(path2_4, path3_0, path3_0_rebase), - 1.to(3).map(_ => Row(java.sql.Date.valueOf(dataStr)))) - } - } else { - val df = Seq(dataStr).toDF("str").select($"str".cast("timestamp").as("ts")) - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> dt) { - // By default we should fail to write ancient datetime values. - var e = intercept[SparkException](df.write.parquet(path3_0)) - assert(e.getCause.getCause.getCause.isInstanceOf[SparkUpgradeException]) - // By default we should fail to read ancient datetime values. - e = intercept[SparkException](spark.read.parquet(path2_4).collect()) - assert(e.getCause.isInstanceOf[SparkUpgradeException]) - - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> CORRECTED.toString) { - df.write.mode("overwrite").parquet(path3_0) - } - withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key -> LEGACY.toString) { - df.write.parquet(path3_0_rebase) - } - } - // For Parquet files written by Spark 3.0, we know the writer info and don't need the - // config to guide the rebase behavior. 
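For context on the rebase flags exercised in this hunk, a hedged write-side sketch (the SparkSession, data and paths are placeholders; the config constants are the same ones the test uses): writing pre-1582 values fails by default and succeeds once a rebase mode is chosen explicitly.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    import spark.implicits._

    // An "ancient" date, i.e. one before the Julian/Gregorian switch in 1582.
    val ancient = Seq("1001-01-01").toDF("str").select($"str".cast("date").as("date"))

    // Default: the write fails with a SparkUpgradeException wrapped in a SparkException.
    // ancient.write.parquet("/tmp/ancient_default")

    // Choosing a mode explicitly makes the write succeed:
    spark.conf.set(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key, "CORRECTED")
    ancient.write.mode("overwrite").parquet("/tmp/ancient_corrected")
    spark.conf.set(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_WRITE.key, "LEGACY")
    ancient.write.mode("overwrite").parquet("/tmp/ancient_legacy")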
- withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key -> LEGACY.toString) { - checkAnswer( - spark.read.format("parquet").load(path2_4, path3_0, path3_0_rebase), - 1.to(3).map(_ => Row(java.sql.Timestamp.valueOf(dataStr)))) - } + } + // For Parquet files written by Spark 3.0, we know the writer info and don't need the + // config to guide the rebase behavior. + withSQLConf(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key -> LEGACY.toString) { + checkAnswer( + spark.read.format("parquet").load(path2_4, path3_0, path3_0_rebase), + (0 until N).flatMap { i => + val (dictS, plainS) = rowFunc(i) + Seq.tabulate(3) { _ => + Row(toJavaType(dictS), toJavaType(plainS)) + } + }) } } } - - withAllParquetReaders { - checkReadMixedFiles("before_1582_date_v2_4.snappy.parquet", "date", "1001-01-01") - checkReadMixedFiles( - "before_1582_timestamp_micros_v2_4.snappy.parquet", - "TIMESTAMP_MICROS", - "1001-01-01 01:02:03.123456") - checkReadMixedFiles( - "before_1582_timestamp_millis_v2_4.snappy.parquet", - "TIMESTAMP_MILLIS", - "1001-01-01 01:02:03.123") - - // INT96 is a legacy timestamp format and we always rebase the seconds for it. - checkAnswer(readResourceParquetFile( - "test-data/before_1582_timestamp_int96_v2_4.snappy.parquet"), - Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"))) + def failInRead(path: String): Unit = { + val e = intercept[SparkException](spark.read.parquet(path).collect()) + assert(e.getCause.isInstanceOf[SparkUpgradeException]) + } + def successInRead(path: String): Unit = spark.read.parquet(path).collect() + Seq( + // By default we should fail to read ancient datetime values when parquet files don't + // contain Spark version. + "2_4_5" -> failInRead _, + "2_4_6" -> successInRead _).foreach { case (version, checkDefaultRead) => + withAllParquetReaders { + checkReadMixedFiles( + s"before_1582_date_v$version.snappy.parquet", + "date", + (i: Int) => ("1001-01-01", s"1001-01-0${i + 1}"), + java.sql.Date.valueOf, + checkDefaultRead) + checkReadMixedFiles( + s"before_1582_timestamp_micros_v$version.snappy.parquet", + "timestamp", + (i: Int) => ("1001-01-01 01:02:03.123456", s"1001-01-0${i + 1} 01:02:03.123456"), + java.sql.Timestamp.valueOf, + checkDefaultRead) + checkReadMixedFiles( + s"before_1582_timestamp_millis_v$version.snappy.parquet", + "timestamp", + (i: Int) => ("1001-01-01 01:02:03.123", s"1001-01-0${i + 1} 01:02:03.123"), + java.sql.Timestamp.valueOf, + checkDefaultRead, + tsOutputType = "TIMESTAMP_MILLIS") + // INT96 is a legacy timestamp format and we always rebase the seconds for it. 
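The read side mirrors this. A sketch under the same assumptions (an existing SparkSession; the path is a placeholder standing in for one of the 2.4.5 resource files added above): a file without a usable Spark version marker fails to read by default and becomes readable once the LEGACY rebase mode is requested.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    // Placeholder path; the test resolves the real one via getResourceParquetFilePath.
    val path2_4_5 = "/path/to/before_1582_date_v2_4_5.snappy.parquet"

    // Default: spark.read.parquet(path2_4_5).collect() throws a SparkException whose
    // cause is a SparkUpgradeException, matching `failInRead` above.
    spark.conf.set(SQLConf.LEGACY_PARQUET_REBASE_MODE_IN_READ.key, "LEGACY")
    spark.read.parquet(path2_4_5).show()  // ancient dates are now rebased on read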
+ Seq("plain", "dict").foreach { enc => + checkAnswer(readResourceParquetFile( + s"test-data/before_1582_timestamp_int96_${enc}_v$version.snappy.parquet"), + Seq.tabulate(N) { i => + Row( + java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"), + java.sql.Timestamp.valueOf(s"1001-01-0${i + 1} 01:02:03.123456")) + }) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 1be9308c06d8c..f7d5a899df1c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,9 +22,10 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ @@ -411,7 +412,7 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils val e = intercept[Exception] { testDf.collect() } - AdaptiveTestUtils.assertExceptionMessage(e, s"Could not execute broadcast in $timeout secs.") + assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs.")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 5490246baceea..554990413c28c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 08898f80034e6..44ab3f7d023d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys 
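Several suites in this patch replace the deleted `AdaptiveTestUtils.assertExceptionMessage` with a plain `assert(e.getMessage.contains(...))`. The behavioural difference is that the old helper searched the rendered stack trace (including causes), while the new check looks only at the top-level message. A sketch with illustrative helper names:

    import java.io.{PrintWriter, StringWriter}

    // New style, as in BroadcastJoinSuite above: only the top-level message is checked.
    def messageContains(e: Throwable, expected: String): Boolean =
      e.getMessage != null && e.getMessage.contains(expected)

    // Old style, mirroring the removed helper: the whole stack trace, including
    // nested causes, was rendered to a String and searched.
    def stackTraceContains(e: Throwable, expected: String): Boolean = {
      val sw = new StringWriter()
      e.printStackTrace(new PrintWriter(sw))
      sw.toString.contains(expected)
    }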
import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} @@ -133,7 +134,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -145,7 +146,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -157,7 +158,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildLeft), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } @@ -169,7 +170,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, BuildRight), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a5ade0d8d7508..879f282e4d05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index e18514c6f93f9..53f9757750735 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.expressions import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} +import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.HiveResult.hiveResultString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { @@ -156,4 +157,38 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { } } } + + test("Check whether SQL expressions should extend NullIntolerant") { + // Only check expressions extended from these expressions because these expressions are + // NullIntolerant by default. + val exprTypesToCheck = Seq(classOf[UnaryExpression], classOf[BinaryExpression], + classOf[TernaryExpression], classOf[QuaternaryExpression], classOf[SeptenaryExpression]) + + // Do not check these expressions, because these expressions extend NullIntolerant + // and override the eval method to avoid evaluating input1 if input2 is 0. + val ignoreSet = Set(classOf[IntegralDivide], classOf[Divide], classOf[Remainder], classOf[Pmod]) + + val candidateExprsToCheck = spark.sessionState.functionRegistry.listFunction() + .map(spark.sessionState.catalog.lookupFunctionInfo).map(_.getClassName) + .filterNot(c => ignoreSet.exists(_.getName.equals(c))) + .map(name => Utils.classForName(name)) + .filterNot(classOf[NonSQLExpression].isAssignableFrom) + + exprTypesToCheck.foreach { superClass => + candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz => + val isEvalOverrode = clazz.getMethod("eval", classOf[InternalRow]) != + superClass.getMethod("eval", classOf[InternalRow]) + val isNullIntolerantMixedIn = classOf[NullIntolerant].isAssignableFrom(clazz) + if (isEvalOverrode && isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should not extend ${classOf[NullIntolerant].getSimpleName}, " + + s"or add ${clazz.getName} in the ignoreSet of this test.") + } else if (!isEvalOverrode && !isNullIntolerantMixedIn) { + fail(s"${clazz.getName} should extend ${classOf[NullIntolerant].getSimpleName}.") + } else { + assert((!isEvalOverrode && isNullIntolerantMixedIn) || + (isEvalOverrode && !isNullIntolerantMixedIn)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index e153c7168dbf2..1d8303b9e7750 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, FileSourceScanExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.execution.datasources.BucketingUtils import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -771,7 +770,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.count() } - assertExceptionMessage(error, "Invalid bucket file") + assert(error.getCause().toString contains "Invalid bucket file") } } diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala 
b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index a01d5a44da714..b68563956c82c 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.LocalDate +import java.time.{Instant, LocalDate} import org.apache.orc.storage.common.`type`.HiveDecimal import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -26,7 +26,7 @@ import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.orc.storage.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -167,6 +167,8 @@ private[sql] object OrcFilters extends OrcFiltersBase { new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _: DateType if value.isInstanceOf[LocalDate] => toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) + case _: TimestampType if value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) case _ => value } diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index a1c325e7bb876..88b4b243b543a 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -245,29 +245,41 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - timestamp") { - val timeString = "2015-08-20 14:57:00" - val timestamps = (1 to 4).map { i => - val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 - new Timestamp(milliseconds) - } - withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + val input = Seq( + "1000-01-01 01:02:03", + "1582-10-01 00:11:22", + "1900-01-01 23:59:59", + "2020-05-25 10:11:12").map(Timestamp.valueOf) - checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) - - checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) - - checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(timestamps(0)) <=> $"_1", - PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(timestamps(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(0)) >= $"_1", 
PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + withOrcFile(input.map(Tuple1(_))) { path => + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + readFile(path) { implicit df => + val timestamps = input.map(Literal(_)) + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(timestamps(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } } } diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 445a52cece1c3..4b642080d25ad 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.LocalDate +import java.time.{Instant, LocalDate} import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{localDateToDays, toJavaDate} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -167,6 +167,8 @@ private[sql] object OrcFilters extends OrcFiltersBase { new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) case _: DateType if value.isInstanceOf[LocalDate] => toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) + case _: TimestampType if value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) case _ => value } diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 815af05beb002..2263179515a5f 100644 --- 
a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -246,29 +246,41 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - timestamp") { - val timeString = "2015-08-20 14:57:00" - val timestamps = (1 to 4).map { i => - val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600 - new Timestamp(milliseconds) - } - withOrcDataFrame(timestamps.map(Tuple1(_))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - - checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val input = Seq( + "1000-01-01 01:02:03", + "1582-10-01 00:11:22", + "1900-01-01 23:59:59", + "2020-05-25 10:11:12").map(Timestamp.valueOf) - checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + withOrcFile(input.map(Tuple1(_))) { path => + Seq(false, true).foreach { java8Api => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) { + readFile(path) { implicit df => + val timestamps = input.map(Literal(_)) + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate( - Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(timestamps(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" === timestamps(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> timestamps(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < timestamps(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > timestamps(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= timestamps(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= timestamps(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(timestamps(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(timestamps(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(timestamps(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(timestamps(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(timestamps(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 1f2d4b1b87773..8efbdb30c605c 100644 --- 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -175,19 +175,19 @@ private[ui] class SqlStatsPagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, sqlStatsTableTag, "Start Time") - override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) private val parameterPath = s"$basePath/$subPath/?${getParameterOtherTable(request, sqlStatsTableTag)}" + override val dataSource = new SqlStatsTableDataSource(data, pageSize, sortColumn, desc) + override def tableId: String = sqlStatsTableTag override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$sqlStatsTableTag.sort=$encodedSortColumn" + @@ -200,11 +200,9 @@ private[ui] class SqlStatsPagedTable( override def pageNumberFormField: String = s"$sqlStatsTableTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$sqlStatsTableTag.sort=$encodedSortColumn" + s"&$sqlStatsTableTag.desc=$desc#$sqlStatsTableTag" - } override def headers: Seq[Node] = { val sqlTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = @@ -307,19 +305,19 @@ private[ui] class SessionStatsPagedTable( private val (sortColumn, desc, pageSize) = getTableParameters(request, sessionStatsTableTag, "Start Time") - override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) + private val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) private val parameterPath = s"$basePath/$subPath/?${getParameterOtherTable(request, sessionStatsTableTag)}" + override val dataSource = new SessionStatsTableDataSource(data, pageSize, sortColumn, desc) + override def tableId: String = sessionStatsTableTag override def tableCssClass: String = - "table table-bordered table-sm table-striped " + - "table-head-clickable table-cell-width-limited" + "table table-bordered table-sm table-striped table-head-clickable table-cell-width-limited" override def pageLink(page: Int): String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) parameterPath + s"&$pageNumberFormField=$page" + s"&$sessionStatsTableTag.sort=$encodedSortColumn" + @@ -332,11 +330,9 @@ private[ui] class SessionStatsPagedTable( override def pageNumberFormField: String = s"$sessionStatsTableTag.page" - override def goButtonFormPath: String = { - val encodedSortColumn = URLEncoder.encode(sortColumn, UTF_8.name()) + override def goButtonFormPath: String = s"$parameterPath&$sessionStatsTableTag.sort=$encodedSortColumn" + s"&$sessionStatsTableTag.desc=$desc#$sessionStatsTableTag" - } override def headers: Seq[Node] = { val sessionTableHeadersAndTooltips: Seq[(String, Boolean, Option[String])] = @@ -370,108 +366,94 @@ private[ui] class SessionStatsPagedTable( } } - private[ui] class SqlStatsTableRow( +private[ui] class SqlStatsTableRow( val jobId: Seq[String], val duration: Long, val executionTime: Long, val executionInfo: 
ExecutionInfo, val detail: String) - private[ui] class SqlStatsTableDataSource( +private[ui] class SqlStatsTableDataSource( info: Seq[ExecutionInfo], pageSize: Int, sortColumn: String, desc: Boolean) extends PagedDataSource[SqlStatsTableRow](pageSize) { - // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in - // the table so that we can avoid creating duplicate contents during sorting the data - private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) - - private var _slicedStartTime: Set[Long] = null + // Convert ExecutionInfo to SqlStatsTableRow which contains the final contents to show in + // the table so that we can avoid creating duplicate contents during sorting the data + private val data = info.map(sqlStatsTableRow).sorted(ordering(sortColumn, desc)) - override def dataSize: Int = data.size + override def dataSize: Int = data.size - override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = { - val r = data.slice(from, to) - _slicedStartTime = r.map(_.executionInfo.startTimestamp).toSet - r - } + override def sliceData(from: Int, to: Int): Seq[SqlStatsTableRow] = data.slice(from, to) - private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { - val duration = executionInfo.totalTime(executionInfo.closeTimestamp) - val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) - val detail = Option(executionInfo.detail).filter(!_.isEmpty) - .getOrElse(executionInfo.executePlan) - val jobId = executionInfo.jobId.toSeq.sorted + private def sqlStatsTableRow(executionInfo: ExecutionInfo): SqlStatsTableRow = { + val duration = executionInfo.totalTime(executionInfo.closeTimestamp) + val executionTime = executionInfo.totalTime(executionInfo.finishTimestamp) + val detail = Option(executionInfo.detail).filter(!_.isEmpty) + .getOrElse(executionInfo.executePlan) + val jobId = executionInfo.jobId.toSeq.sorted - new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + new SqlStatsTableRow(jobId, duration, executionTime, executionInfo, detail) + } + /** + * Return Ordering according to sortColumn and desc. + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = { + val ordering: Ordering[SqlStatsTableRow] = sortColumn match { + case "User" => Ordering.by(_.executionInfo.userName) + case "JobID" => Ordering by (_.jobId.headOption) + case "GroupID" => Ordering.by(_.executionInfo.groupId) + case "Start Time" => Ordering.by(_.executionInfo.startTimestamp) + case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp) + case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp) + case "Execution Time" => Ordering.by(_.executionTime) + case "Duration" => Ordering.by(_.duration) + case "Statement" => Ordering.by(_.executionInfo.statement) + case "State" => Ordering.by(_.executionInfo.state) + case "Detail" => Ordering.by(_.detail) + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") } - - /** - * Return Ordering according to sortColumn and desc. 
-     */
-    private def ordering(sortColumn: String, desc: Boolean): Ordering[SqlStatsTableRow] = {
-      val ordering: Ordering[SqlStatsTableRow] = sortColumn match {
-        case "User" => Ordering.by(_.executionInfo.userName)
-        case "JobID" => Ordering by (_.jobId.headOption)
-        case "GroupID" => Ordering.by(_.executionInfo.groupId)
-        case "Start Time" => Ordering.by(_.executionInfo.startTimestamp)
-        case "Finish Time" => Ordering.by(_.executionInfo.finishTimestamp)
-        case "Close Time" => Ordering.by(_.executionInfo.closeTimestamp)
-        case "Execution Time" => Ordering.by(_.executionTime)
-        case "Duration" => Ordering.by(_.duration)
-        case "Statement" => Ordering.by(_.executionInfo.statement)
-        case "State" => Ordering.by(_.executionInfo.state)
-        case "Detail" => Ordering.by(_.detail)
-        case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
-      }
-      if (desc) {
-        ordering.reverse
-      } else {
-        ordering
-      }
+    if (desc) {
+      ordering.reverse
+    } else {
+      ordering
     }
-  }
+}
-  private[ui] class SessionStatsTableDataSource(
+private[ui] class SessionStatsTableDataSource(
     info: Seq[SessionInfo],
     pageSize: Int,
     sortColumn: String,
     desc: Boolean) extends PagedDataSource[SessionInfo](pageSize) {
 
-    // Sorting SessionInfo data
-    private val data = info.sorted(ordering(sortColumn, desc))
-
-    private var _slicedStartTime: Set[Long] = null
-
-    override def dataSize: Int = data.size
-
-    override def sliceData(from: Int, to: Int): Seq[SessionInfo] = {
-      val r = data.slice(from, to)
-      _slicedStartTime = r.map(_.startTimestamp).toSet
-      r
+  // Sorting SessionInfo data
+  private val data = info.sorted(ordering(sortColumn, desc))
+
+  override def dataSize: Int = data.size
+
+  override def sliceData(from: Int, to: Int): Seq[SessionInfo] = data.slice(from, to)
+
+  /**
+   * Return Ordering according to sortColumn and desc.
+   */
+  private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = {
+    val ordering: Ordering[SessionInfo] = sortColumn match {
+      case "User" => Ordering.by(_.userName)
+      case "IP" => Ordering.by(_.ip)
+      case "Session ID" => Ordering.by(_.sessionId)
+      case "Start Time" => Ordering by (_.startTimestamp)
+      case "Finish Time" => Ordering.by(_.finishTimestamp)
+      case "Duration" => Ordering.by(_.totalTime)
+      case "Total Execute" => Ordering.by(_.totalExecution)
+      case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
    }
-
-    /**
-     * Return Ordering according to sortColumn and desc.
-     */
-    private def ordering(sortColumn: String, desc: Boolean): Ordering[SessionInfo] = {
-      val ordering: Ordering[SessionInfo] = sortColumn match {
-        case "User" => Ordering.by(_.userName)
-        case "IP" => Ordering.by(_.ip)
-        case "Session ID" => Ordering.by(_.sessionId)
-        case "Start Time" => Ordering by (_.startTimestamp)
-        case "Finish Time" => Ordering.by(_.finishTimestamp)
-        case "Duration" => Ordering.by(_.totalTime)
-        case "Total Execute" => Ordering.by(_.totalExecution)
-        case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
-      }
-      if (desc) {
-        ordering.reverse
-      } else {
-        ordering
-      }
+    if (desc) {
+      ordering.reverse
+    } else {
+      ordering
    }
  }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala
index ce610098156f3..e002bc0117c8b 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SharedThriftServer.scala
@@ -19,29 +19,25 @@ package org.apache.spark.sql.hive.thriftserver
 
 import java.sql.{DriverManager, Statement}
 
+import scala.collection.JavaConverters._
 import scala.concurrent.duration._
-import scala.util.{Random, Try}
+import scala.util.Try
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.service.cli.thrift.ThriftCLIService
 
-import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.test.SharedSparkSession
 
 trait SharedThriftServer extends SharedSparkSession {
 
   private var hiveServer2: HiveThriftServer2 = _
+  private var serverPort: Int = 0
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    // Chooses a random port between 10000 and 19999
-    var listeningPort = 10000 + Random.nextInt(10000)
-
-    // Retries up to 3 times with different port numbers if the server fails to start
-    (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) =>
-      started.orElse {
-        listeningPort += 1
-        Try(startThriftServer(listeningPort, attempt))
-      }
+    (1 to 3).foldLeft(Try(startThriftServer(0))) { case (started, attempt) =>
+      started.orElse(Try(startThriftServer(attempt)))
     }.recover {
       case cause: Throwable =>
         throw cause
@@ -59,8 +55,7 @@ trait SharedThriftServer extends SharedSparkSession {
   protected def withJdbcStatement(fs: (Statement => Unit)*): Unit = {
     val user = System.getProperty("user.name")
-
-    val serverPort = hiveServer2.getHiveConf.get(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname)
+    require(serverPort != 0, "Failed to bind an actual port for HiveThriftServer2")
     val connections = fs.map { _ =>
       DriverManager.getConnection(s"jdbc:hive2://localhost:$serverPort", user, "") }
     val statements = connections.map(_.createStatement())
@@ -73,11 +68,19 @@ trait SharedThriftServer extends SharedSparkSession {
     }
   }
 
-  private def startThriftServer(port: Int, attempt: Int): Unit = {
-    logInfo(s"Trying to start HiveThriftServer2: port=$port, attempt=$attempt")
+  private def startThriftServer(attempt: Int): Unit = {
+    logInfo(s"Trying to start HiveThriftServer2: attempt=$attempt")
     val sqlContext = spark.newSession().sqlContext
-    sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, port.toString)
+    // Set HIVE_SERVER2_THRIFT_PORT to 0 so the server can pick any free port to use.
+    // This is much more robust than picking a random port ourselves ahead of time.
+    sqlContext.setConf(ConfVars.HIVE_SERVER2_THRIFT_PORT.varname, "0")
     hiveServer2 = HiveThriftServer2.startWithContext(sqlContext)
+    hiveServer2.getServices.asScala.foreach {
+      case t: ThriftCLIService if t.getPortNumber != 0 =>
+        serverPort = t.getPortNumber
+        logInfo(s"Started HiveThriftServer2: port=$serverPort, attempt=$attempt")
+      case _ =>
+    }
 
     // Wait for thrift server to be ready to serve the query, via executing simple query
     // till the query succeeds. See SPARK-30345 for more details.
diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java
index 21b8bf7de75ce..e1ee503b81209 100644
--- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java
+++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java
@@ -76,6 +76,10 @@ public void run() {
             keyStorePassword, sslVersionBlacklist);
       }
 
+      // In case HIVE_SERVER2_THRIFT_PORT or hive.server2.thrift.port is configured with 0,
+      // which represents any free port, we should set it to the actual bound port.
+      portNum = serverSocket.getServerSocket().getLocalPort();
+
       // Server args
       int maxMessageSize = hiveConf.getIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_MAX_MESSAGE_SIZE);
       int requestTimeout = (int) hiveConf.getTimeVar(
diff --git a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java
index 504e63dbc5e5e..1099a00b67eb7 100644
--- a/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java
+++ b/sql/hive-thriftserver/v1.2/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java
@@ -143,6 +143,9 @@ public void run() {
       // TODO: check defaults: maxTimeout, keepalive, maxBodySize, bodyRecieveDuration, etc.
       // Finally, start the server
       httpServer.start();
+      // In case HIVE_SERVER2_THRIFT_HTTP_PORT or hive.server2.thrift.http.port is configured
+      // with 0, which represents any free port, we should set it to the actual bound port.
+      portNum = connector.getLocalPort();
       String msg = "Started " + ThriftHttpCLIService.class.getSimpleName() + " in " + schemeName
           + " mode on port " + connector.getLocalPort()+ " path=" + httpPath + " with " + minWorkerThreads + "..."
+ maxWorkerThreads + " worker threads"; diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java index fc19c65daaf54..a7de9c0f3d0d2 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftBinaryCLIService.java @@ -77,6 +77,10 @@ public void run() { keyStorePassword, sslVersionBlacklist); } + // In case HIVE_SERVER2_THRIFT_PORT or hive.server2.thrift.port is configured with 0 which + // represents any free port, we should set it to the actual one + portNum = serverSocket.getServerSocket().getLocalPort(); + // Server args int maxMessageSize = hiveConf.getIntVar(HiveConf.ConfVars.HIVE_SERVER2_THRIFT_MAX_MESSAGE_SIZE); int requestTimeout = (int) hiveConf.getTimeVar( diff --git a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 08626e7eb146d..73d5f84476af0 100644 --- a/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/v2.3/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -144,6 +144,9 @@ public void run() { // TODO: check defaults: maxTimeout, keepalive, maxBodySize, bodyRecieveDuration, etc. // Finally, start the server httpServer.start(); + // In case HIVE_SERVER2_THRIFT_HTTP_PORT or hive.server2.thrift.http.port is configured with + // 0 which represents any free port, we should set it to the actual one + portNum = connector.getLocalPort(); String msg = "Started " + ThriftHttpCLIService.class.getSimpleName() + " in " + schemeName + " mode on port " + portNum + " path=" + httpPath + " with " + minWorkerThreads + "..." + maxWorkerThreads + " worker threads"; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 743cdbd6457d7..db8ebcd45f3eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -21,7 +21,6 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.execution.adaptive.AdaptiveTestUtils.assertExceptionMessage import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -100,7 +99,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi val e = intercept[SparkException] { sql("select * from test").count() } - assertExceptionMessage(e, "FileNotFoundException") + assert(e.getMessage.contains("FileNotFoundException")) // Test refreshing the cache. spark.catalog.refreshTable("test") @@ -115,7 +114,7 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi val e2 = intercept[SparkException] { sql("select * from test").count() } - assertExceptionMessage(e2, "FileNotFoundException") + assert(e.getMessage.contains("FileNotFoundException")) spark.catalog.refreshByPath(dir.getAbsolutePath) assert(sql("select * from test").count() == 3) }
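The thrift-service and test hunks above all rely on the same ephemeral-port pattern: configure port 0 so the OS chooses a free port, then read the actually bound port back instead of guessing one up front. A minimal, self-contained Scala sketch of that pattern, assuming a plain java.net.ServerSocket as a stand-in for the Thrift transports (the object name and output are illustrative only, not part of the patch):

import java.net.ServerSocket

object EphemeralPortSketch {
  def main(args: Array[String]): Unit = {
    // Bind to port 0 so the OS assigns any free port ...
    val socket = new ServerSocket(0)
    try {
      // ... then read the real port back from the bound socket, analogous to the
      // patch reading ThriftCLIService.getPortNumber / connector.getLocalPort().
      val boundPort = socket.getLocalPort
      println(s"OS assigned port $boundPort")
    } finally {
      socket.close()
    }
  }
}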
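The simplified table data sources in ThriftServerPage.scala sort the rows once up front and then serve each page as a plain slice, which is why the removed _slicedStartTime bookkeeping is no longer needed. A small sketch of that sort-then-slice shape, using a hypothetical Row class and column names in place of SqlStatsTableRow:

object SortThenSliceSketch {
  // Hypothetical row type standing in for SqlStatsTableRow / SessionInfo.
  case class Row(user: String, startTime: Long)

  // Same shape as the ordering helpers above: pick a base ordering by column
  // name, then reverse it when descending order is requested.
  private def ordering(sortColumn: String, desc: Boolean): Ordering[Row] = {
    val base: Ordering[Row] = sortColumn match {
      case "User" => Ordering.by(_.user)
      case "Start Time" => Ordering.by(_.startTime)
      case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn")
    }
    if (desc) base.reverse else base
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(Row("alice", 3L), Row("bob", 1L), Row("carol", 2L))
    // Sort once, then every page is just data.slice(from, to).
    val data = rows.sorted(ordering("Start Time", desc = false))
    println(data.slice(0, 2))
  }
}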