diff --git a/R/README.md b/R/README.md
index c808ca88f72dc..31174c73526f2 100644
--- a/R/README.md
+++ b/R/README.md
@@ -20,7 +20,7 @@ export R_HOME=/home/username/R
Build Spark with [Maven](https://spark.apache.org/docs/latest/building-spark.html#buildmvn) and include the `-Psparkr` profile to build the R package. For example, to use the default Hadoop versions you can run
```bash
-build/mvn -DskipTests -Psparkr package
+./build/mvn -DskipTests -Psparkr package
```
#### Running sparkR
diff --git a/README.md b/README.md
index 9759559e6cf6f..29777a5962bc2 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ This README file only contains basic setup instructions.
Spark is built using [Apache Maven](https://maven.apache.org/).
To build Spark and its example programs, run:
- build/mvn -DskipTests clean package
+ ./build/mvn -DskipTests clean package
(You do not need to do this if you downloaded a pre-built package.)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala
index f5beb403555e9..d0337b6e34962 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala
@@ -182,19 +182,19 @@ private[spark] class ExecutorMonitor(
if (updateExecutors) {
val activeShuffleIds = shuffleStages.map(_._2).toSeq
var needTimeoutUpdate = false
- val activatedExecs = new mutable.ArrayBuffer[String]()
+ val activatedExecs = new ExecutorIdCollector()
executors.asScala.foreach { case (id, exec) =>
if (!exec.hasActiveShuffle) {
exec.updateActiveShuffles(activeShuffleIds)
if (exec.hasActiveShuffle) {
needTimeoutUpdate = true
- activatedExecs += id
+ activatedExecs.add(id)
}
}
}
- logDebug(s"Activated executors ${activatedExecs.mkString(",")} due to shuffle data " +
- s"needed by new job ${event.jobId}.")
+ logDebug(s"Activated executors $activatedExecs due to shuffle data needed by new job " +
+ s"${event.jobId}.")
if (needTimeoutUpdate) {
nextTimeout.set(Long.MinValue)
@@ -233,18 +233,18 @@ private[spark] class ExecutorMonitor(
}
}
- val deactivatedExecs = new mutable.ArrayBuffer[String]()
+ val deactivatedExecs = new ExecutorIdCollector()
executors.asScala.foreach { case (id, exec) =>
if (exec.hasActiveShuffle) {
exec.updateActiveShuffles(activeShuffles)
if (!exec.hasActiveShuffle) {
- deactivatedExecs += id
+ deactivatedExecs.add(id)
}
}
}
- logDebug(s"Executors ${deactivatedExecs.mkString(",")} do not have active shuffle data " +
- s"after job ${event.jobId} finished.")
+ logDebug(s"Executors $deactivatedExecs do not have active shuffle data after job " +
+ s"${event.jobId} finished.")
}
jobToStageIDs.remove(event.jobId).foreach { stages =>
@@ -448,7 +448,8 @@ private[spark] class ExecutorMonitor(
} else {
idleTimeoutMs
}
- idleStart + timeout
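+ // SPARK-28455: "idleStart + timeout" can overflow to a negative value when the configured
+ // timeout is very large (e.g. Long.MaxValue); treat an overflowed deadline as "never times out".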
+ val deadline = idleStart + timeout
+ if (deadline >= 0) deadline else Long.MaxValue
} else {
Long.MaxValue
}
@@ -491,4 +492,22 @@ private[spark] class ExecutorMonitor(
private case class ShuffleCleanedEvent(id: Int) extends SparkListenerEvent {
override protected[spark] def logEvent: Boolean = false
}
+
+ /** Used to collect executor IDs for debug messages (and avoid too long messages). */
+ private class ExecutorIdCollector {
+ private val ids = if (log.isDebugEnabled) new mutable.ArrayBuffer[String]() else null
+ private var excess = 0
+
+ def add(id: String): Unit = if (log.isDebugEnabled) {
+ if (ids.size < 10) {
+ ids += id
+ } else {
+ excess += 1
+ }
+ }
+
+ override def toString(): String = {
+ ids.mkString(",") + (if (excess > 0) s" (and $excess more)" else "")
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 838fc82d2ee37..54f0f8e226791 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -641,18 +641,22 @@ private[ui] class TaskPagedTable(
        <td>{accumulatorsInfo(task)}</td>
      }}
      {if (hasInput(stage)) {
-        metricInfo(task) { m =>
-          val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead)
-          val records = m.inputMetrics.recordsRead
-          <td>{bytesRead} / {records}</td>
-        }
+        <td>{
+          metricInfo(task) { m =>
+            val bytesRead = Utils.bytesToString(m.inputMetrics.bytesRead)
+            val records = m.inputMetrics.recordsRead
+            Unparsed(s"$bytesRead / $records")
+          }
+        }</td>
      }}
      {if (hasOutput(stage)) {
-        metricInfo(task) { m =>
-          val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten)
-          val records = m.outputMetrics.recordsWritten
-          <td>{bytesWritten} / {records}</td>
-        }
+        <td>{
+          metricInfo(task) { m =>
+            val bytesWritten = Utils.bytesToString(m.outputMetrics.bytesWritten)
+            val records = m.outputMetrics.recordsWritten
+            Unparsed(s"$bytesWritten / $records")
+          }
+        }</td>
      }}
{if (hasShuffleRead(stage)) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala
index e11ee97469b00..6a25754fcbe5a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala
@@ -367,6 +367,26 @@ class ExecutorMonitorSuite extends SparkFunSuite {
assert(monitor.timedOutExecutors(idleDeadline).toSet === Set("1", "2"))
}
+ test("SPARK-28455: avoid overflow in timeout calculation") {
+ conf
+ .set(DYN_ALLOCATION_SHUFFLE_TIMEOUT, Long.MaxValue)
+ .set(DYN_ALLOCATION_SHUFFLE_TRACKING, true)
+ .set(SHUFFLE_SERVICE_ENABLED, false)
+ monitor = new ExecutorMonitor(conf, client, null, clock)
+
+ // Generate events that will make executor 1 be idle, while still holding shuffle data.
+ // The executor should not be eligible for removal since the timeout is basically "infinite".
+ val stage = stageInfo(1, shuffleId = 0)
+ monitor.onJobStart(SparkListenerJobStart(1, clock.getTimeMillis(), Seq(stage)))
+ clock.advance(1000L)
+ monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", null))
+ monitor.onTaskStart(SparkListenerTaskStart(1, 0, taskInfo("1", 1)))
+ monitor.onTaskEnd(SparkListenerTaskEnd(1, 0, "foo", Success, taskInfo("1", 1), null))
+ monitor.onJobEnd(SparkListenerJobEnd(1, clock.getTimeMillis(), JobSucceeded))
+
+ assert(monitor.timedOutExecutors(idleDeadline).isEmpty)
+ }
+
private def idleDeadline: Long = clock.getTimeMillis() + idleTimeoutMs + 1
private def storageDeadline: Long = clock.getTimeMillis() + storageTimeoutMs + 1
private def shuffleDeadline: Long = clock.getTimeMillis() + shuffleTimeoutMs + 1
diff --git a/docs/README.md b/docs/README.md
index 670e9e01130df..da531321aa5da 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -84,7 +84,7 @@ $ PRODUCTION=1 jekyll build
## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs)
-You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `$SPARK_HOME` directory.
+You can build just the Spark scaladoc and javadoc by running `./build/sbt unidoc` from the `$SPARK_HOME` directory.
Similarly, you can build just the PySpark docs by running `make html` from the
`$SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as
@@ -94,7 +94,7 @@ after [building Spark](https://github.com/apache/spark#building-spark) first.
When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various
Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a
-jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it
+jekyll plugin to run `./build/sbt unidoc` before building the site so if you haven't run it (recently) it
may take some time as it generates all of the scaladoc and javadoc using [Unidoc](https://github.com/sbt/sbt-unidoc).
The jekyll plugin also generates the PySpark docs using [Sphinx](http://sphinx-doc.org/), SparkR docs
using [roxygen2](https://cran.r-project.org/web/packages/roxygen2/index.html) and SQL docs
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index d4efb52e0fbba..769eed1e6f6b7 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -121,7 +121,7 @@ $ ./bin/docker-image-tool.sh -r -t my-tag -R ./kubernetes/dockerfiles/spa
To launch Spark Pi in cluster mode,
```bash
-$ bin/spark-submit \
+$ ./bin/spark-submit \
--master k8s://https://<k8s-apiserver-host>:<k8s-apiserver-port> \
--deploy-mode cluster \
--name spark-pi \
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 907f414e5dc4c..cf51620a700bc 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -212,7 +212,7 @@ protected (port 7077 by default).
By setting the Mesos proxy config property (requires mesos version >= 1.4), `--conf spark.mesos.proxy.baseURL=http://localhost:5050` when launching the dispatcher, the mesos sandbox URI for each driver is added to the mesos dispatcher UI.
-If you like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e: `bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` not yet supports multiple instances for HA.
+If you would like to run the `MesosClusterDispatcher` with Marathon, you need to run the `MesosClusterDispatcher` in the foreground (i.e., `./bin/spark-class org.apache.spark.deploy.mesos.MesosClusterDispatcher`). Note that the `MesosClusterDispatcher` does not yet support multiple instances for HA.
The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations.
For more information about these configurations please refer to the configurations [doc](configuration.html#deploy).
@@ -362,7 +362,7 @@ The External Shuffle Service to use is the Mesos Shuffle Service. It provides sh
on top of the Shuffle Service since Mesos doesn't yet support notifying another framework's
termination. To launch it, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all slave nodes, with `spark.shuffle.service.enabled` set to `true`.
-This can also be achieved through Marathon, using a unique host constraint, and the following command: `bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`.
+This can also be achieved through Marathon, using a unique host constraint, and the following command: `./bin/spark-class org.apache.spark.deploy.mesos.MesosExternalShuffleService`.
# Configuration
diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 3119ec004b2a1..c3502cbdea8e7 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -36,7 +36,7 @@ spark classpath. For example, to connect to postgres from the Spark Shell you wo
following command:
{% highlight bash %}
-bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9.4.1207.jar
+./bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9.4.1207.jar
{% endhighlight %}
Tables from the remote database can be loaded as a DataFrame or Spark SQL temporary view using
diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index f13d298674b2d..e9d99b66353e2 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -151,7 +151,9 @@ license: |
- Since Spark 3.0, substitution order of nested WITH clauses is changed and an inner CTE definition takes precedence over an outer. In version 2.4 and earlier, `WITH t AS (SELECT 1), t2 AS (WITH t AS (SELECT 2) SELECT * FROM t) SELECT * FROM t2` returns `1` while in version 3.0 it returns `2`. The previous behaviour can be restored by setting `spark.sql.legacy.ctePrecedence.enabled` to `true`.
- - Since Spark 3.0, the `add_months` function adjusts the resulting date to a last day of month only if it is invalid. For example, `select add_months(DATE'2019-01-31', 1)` results `2019-02-28`. In Spark version 2.4 and earlier, the resulting date is adjusted when it is invalid, or the original date is a last day of months. For example, adding a month to `2019-02-28` resultes in `2019-03-31`.
+  - Since Spark 3.0, the `add_months` function does not adjust the resulting date to the last day of the month if the original date is the last day of the month. For example, `select add_months(DATE'2019-02-28', 1)` results in `2019-03-28`. In Spark version 2.4 and earlier, the resulting date was adjusted whenever the original date was the last day of the month. For example, adding a month to `2019-02-28` results in `2019-03-31`.
+
+  - Since Spark 3.0, a 0-argument Java UDF is executed on the executor side identically to other UDFs. In Spark version 2.4 and earlier, a 0-argument Java UDF alone was executed on the driver side and its result was propagated to executors, which might have been more performant in some cases but led to inconsistent behavior and, in some cases, correctness issues.
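
As an illustrative sketch (not part of this change), a zero-argument UDF registered through the Java `UDF0` interface makes the difference visible: under 3.0 it runs on the executors like any other UDF, whereas 2.4 and earlier evaluated it on the driver and propagated the result. The session setup and the `random_uuid` name below are only for the example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.api.java.UDF0
import org.apache.spark.sql.types.StringType

object ZeroArgUdfExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf0-demo").getOrCreate()

    // A 0-argument Java UDF, registered with an explicit return type.
    spark.udf.register("random_uuid", new UDF0[String] {
      override def call(): String = java.util.UUID.randomUUID().toString
    }, StringType)

    // Since 3.0 the UDF is evaluated on the executor side for each row.
    spark.range(3).selectExpr("id", "random_uuid()").show(false)
    spark.stop()
  }
}
```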
## Upgrading from Spark SQL 2.4 to 2.4.1
diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md
index fd6d776045cd7..55acec53302e4 100644
--- a/docs/streaming-kinesis-integration.md
+++ b/docs/streaming-kinesis-integration.md
@@ -222,17 +222,17 @@ To run the example,
- bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
+ ./bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
- bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
+ ./bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
- bin/spark-submit --jars external/kinesis-asl/target/scala-*/\
+ ./bin/spark-submit --jars external/kinesis-asl/target/scala-*/\
spark-streaming-kinesis-asl-assembly_*.jar \
external/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
[Kinesis app name] [Kinesis stream name] [endpoint URL] [region name]
@@ -244,7 +244,7 @@ To run the example,
- To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer.
- bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10
+ ./bin/run-example streaming.KinesisWordProducerASL [Kinesis stream name] [endpoint URL] 1000 10
This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example.
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index fe3c60040d0a0..b0009e01703bf 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -388,6 +388,16 @@ The following configurations are optional:
  <td>streaming and batch</td>
  <td>Rate limit on maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume.</td>
</tr>
+<tr>
+  <td>minPartitions</td>
+  <td>int</td>
+  <td>none</td>
+  <td>streaming and batch</td>
+  <td>Minimum number of partitions to read from Kafka.
+  By default, Spark has a 1-1 mapping of topicPartitions to Spark partitions consuming from Kafka.
+  If you set this option to a value greater than your topicPartitions, Spark will divvy up large
+  Kafka partitions to smaller pieces.</td>
+</tr>
<tr>
  <td>groupIdPrefix</td>
  <td>string</td>
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 925e2cfe717c0..821225753320d 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -511,7 +511,7 @@ returned by `SparkSession.readStream()`. In [R](api/R/read.stream.html), with th
There are a few built-in sources.
- **File source** - Reads files written in a directory as a stream of data. Files will be processed in the order of file modification time. If `latestFirst` is set, order will be reversed. Supported file formats are text, CSV, JSON, ORC, Parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which in most file systems, can be achieved by file move operations.
- - **Kafka source** - Reads data from Kafka. It's compatible with Kafka broker versions 0.10.0 or higher. See the [Kafka Integration Guide](structured-streaming-kafka-0-10-integration.html) for more details.
+ - **Kafka source** - Reads data from Kafka. It's compatible with Kafka broker versions 0.10.0 or higher. See the [Kafka Integration Guide](structured-streaming-kafka-integration.html) for more details.
- **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees.
@@ -582,7 +582,7 @@ Here are the details of all the sources in Spark.
    <td><b>Kafka Source</b></td>
    <td>
-        See the <a href="structured-streaming-kafka-0-10-integration.html">Kafka Integration Guide</a>.
+        See the <a href="structured-streaming-kafka-integration.html">Kafka Integration Guide</a>.
    </td>
    <td>Yes</td>
    <td></td>
@@ -1835,7 +1835,7 @@ Here are the details of all the sinks in Spark.
    <td><b>Kafka Sink</b></td>
    <td>Append, Update, Complete</td>
-    <td>See the <a href="structured-streaming-kafka-0-10-integration.html">Kafka Integration Guide</a></td>
+    <td>See the <a href="structured-streaming-kafka-integration.html">Kafka Integration Guide</a></td>
    <td>Yes (at-least-once)</td>
    <td>More details in the <a href="structured-streaming-kafka-integration.html">Kafka Integration Guide</a></td>
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 924bf374c7370..a7c9e3fb7d329 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -1001,14 +1001,14 @@ abstract class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUti
sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir)
}.getMessage
assert(msg.contains("Cannot save interval data type into external storage.") ||
- msg.contains("AVRO data source does not support calendarinterval data type."))
+ msg.contains("AVRO data source does not support interval data type."))
msg = intercept[AnalysisException] {
spark.udf.register("testType", () => new IntervalData())
sql("select testType()").write.format("avro").mode("overwrite").save(tempDir)
}.getMessage
assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"avro data source does not support calendarinterval data type."))
+ .contains(s"avro data source does not support interval data type."))
}
}
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 462f88ff14a8d..89da9a1de6f74 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -206,4 +206,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
""".stripMargin.replaceAll("\n", " "))
assert(sql("select c1, c3 from queryOption").collect.toSet == expectedResult)
}
+
+ test("write byte as smallint") {
+ sqlContext.createDataFrame(Seq((1.toByte, 2.toShort)))
+ .write.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties)
+ val df = sqlContext.read.jdbc(jdbcUrl, "byte_to_smallint_test", new Properties)
+ val schema = df.schema
+ assert(schema.head.dataType == ShortType)
+ assert(schema(1).dataType == ShortType)
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getShort(0) === 1)
+ assert(rows(0).getShort(1) === 2)
+ }
}
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 6e950f968a65d..6e43d60bd03a3 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -603,6 +603,15 @@ class SparseVector @Since("2.0.0") (
private[spark] override def asBreeze: BV[Double] = new BSV[Double](indices, values, size)
+ override def apply(i: Int): Double = {
+ if (i < 0 || i >= size) {
+ throw new IndexOutOfBoundsException(s"Index $i out of bounds [0, $size)")
+ }
+
+ val j = util.Arrays.binarySearch(indices, i)
+ if (j < 0) 0.0 else values(j)
+ }
+
override def foreachActive(f: (Int, Double) => Unit): Unit = {
var i = 0
val localValuesSize = values.length
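
For reference, a minimal usage sketch of the `apply` added above (the size and values are arbitrary): a stored index returns its value, an unstored index inside the valid range returns 0.0, and an out-of-range index throws.

```scala
import org.apache.spark.ml.linalg.Vectors

// Sparse vector of size 5 with non-zero entries at indices 1 and 3.
val v = Vectors.sparse(5, Array(1, 3), Array(10.0, 30.0))

assert(v(1) == 10.0)  // stored index -> stored value
assert(v(2) == 0.0)   // unstored index inside [0, 5) -> implicit zero
// v(5) or v(-1) throws IndexOutOfBoundsException
```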
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 9cdf1944329b8..b754fad0c1796 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -785,6 +785,15 @@ class SparseVector @Since("1.0.0") (
private[spark] override def asBreeze: BV[Double] = new BSV[Double](indices, values, size)
+ override def apply(i: Int): Double = {
+ if (i < 0 || i >= size) {
+ throw new IndexOutOfBoundsException(s"Index $i out of bounds [0, $size)")
+ }
+
+ val j = util.Arrays.binarySearch(indices, i)
+ if (j < 0) 0.0 else values(j)
+ }
+
@Since("1.6.0")
override def foreachActive(f: (Int, Double) => Unit): Unit = {
var i = 0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
index a660492c7ae59..03afd29e47505 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -79,24 +80,24 @@ class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
val strictAccuracy = 2.0 / 7
val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
- assert(math.abs(metrics.precision(0.0) - precision0) < delta)
- assert(math.abs(metrics.precision(1.0) - precision1) < delta)
- assert(math.abs(metrics.precision(2.0) - precision2) < delta)
- assert(math.abs(metrics.recall(0.0) - recall0) < delta)
- assert(math.abs(metrics.recall(1.0) - recall1) < delta)
- assert(math.abs(metrics.recall(2.0) - recall2) < delta)
- assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
- assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
- assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
- assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
- assert(math.abs(metrics.microRecall - microRecallClass) < delta)
- assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
- assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
- assert(math.abs(metrics.recall - macroRecallDoc) < delta)
- assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
- assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
- assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
- assert(math.abs(metrics.accuracy - accuracy) < delta)
+ assert(metrics.precision(0.0) ~== precision0 absTol delta)
+ assert(metrics.precision(1.0) ~== precision1 absTol delta)
+ assert(metrics.precision(2.0) ~== precision2 absTol delta)
+ assert(metrics.recall(0.0) ~== recall0 absTol delta)
+ assert(metrics.recall(1.0) ~== recall1 absTol delta)
+ assert(metrics.recall(2.0) ~== recall2 absTol delta)
+ assert(metrics.f1Measure(0.0) ~== f1measure0 absTol delta)
+ assert(metrics.f1Measure(1.0) ~== f1measure1 absTol delta)
+ assert(metrics.f1Measure(2.0) ~== f1measure2 absTol delta)
+ assert(metrics.microPrecision ~== microPrecisionClass absTol delta)
+ assert(metrics.microRecall ~== microRecallClass absTol delta)
+ assert(metrics.microF1Measure ~== microF1MeasureClass absTol delta)
+ assert(metrics.precision ~== macroPrecisionDoc absTol delta)
+ assert(metrics.recall ~== macroRecallDoc absTol delta)
+ assert(metrics.f1Measure ~== macroF1MeasureDoc absTol delta)
+ assert(metrics.hammingLoss ~== hammingLoss absTol delta)
+ assert(metrics.subsetAccuracy ~== strictAccuracy absTol delta)
+ assert(metrics.accuracy ~== accuracy absTol delta)
assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
index dcb1f398b04b8..26a75699248d0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -63,7 +64,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
[1] 23
*/
assert(results1.size === 23)
- assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+ assert(results1.count(rule => rule.confidence ~= 1.0D absTol 1e-6) == 23)
val results2 = ar
.setMinConfidence(0)
@@ -84,7 +85,7 @@ class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
[1] 23
*/
assert(results2.size === 30)
- assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+ assert(results2.count(rule => rule.confidence ~= 1.0D absTol 1e-6) == 23)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 20bd2e5e0dc17..fa8f03be089ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -172,7 +173,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
.collect()
assert(rules.size === 23)
- assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+ assert(rules.count(rule => rule.confidence ~= 1.0D absTol 1e-6) == 23)
}
test("FP-Growth using Int type") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 566ce95be084a..cca4eb4e4260e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -22,6 +22,7 @@ import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -238,7 +239,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
for (i <- 0 until n; j <- i + 1 until n) {
val trueResult = gram(i, j) / scala.math.sqrt(gram(i, i) * gram(j, j))
- assert(math.abs(G(i, j) - trueResult) < 1e-6)
+ assert(G(i, j) ~== trueResult absTol 1e-6)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index e30ad159676ff..8011026e6fa65 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.random
import org.apache.commons.math3.special.Gamma
import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.StatCounter
-// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
class RandomDataGeneratorSuite extends SparkFunSuite {
def apiChecks(gen: RandomDataGenerator[Double]) {
@@ -61,8 +61,8 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
gen.setSeed(seed.toLong)
val sample = (0 until 100000).map { _ => gen.nextValue()}
val stats = new StatCounter(sample)
- assert(math.abs(stats.mean - mean) < epsilon)
- assert(math.abs(stats.stdev - stddev) < epsilon)
+ assert(stats.mean ~== mean absTol epsilon)
+ assert(stats.stdev ~== stddev absTol epsilon)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index f464d25c3fbda..9b4dc29d326a1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -23,14 +23,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.rdd.{RandomRDD, RandomRDDPartition}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.StatCounter
/*
* Note: avoid including APIs that do not set the seed for the RNG in unit tests
* in order to guarantee deterministic behavior.
- *
- * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
*/
class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable {
@@ -43,8 +42,8 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri
val stats = rdd.stats()
assert(expectedSize === stats.count)
assert(expectedNumPartitions === rdd.partitions.size)
- assert(math.abs(stats.mean - expectedMean) < epsilon)
- assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ assert(stats.mean ~== expectedMean absTol epsilon)
+ assert(stats.stdev ~== expectedStddev absTol epsilon)
}
// assume test RDDs are small
@@ -63,8 +62,8 @@ class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Seri
}}
assert(expectedRows === values.size / expectedColumns)
val stats = new StatCounter(values)
- assert(math.abs(stats.mean - expectedMean) < epsilon)
- assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ assert(stats.mean ~== expectedMean absTol epsilon)
+ assert(stats.stdev ~== expectedStddev absTol epsilon)
}
test("RandomRDD sizes") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index e32767edb17a8..4613f7fb6f400 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.random.RandomRDDs
import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
SpearmanCorrelation}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
@@ -57,15 +58,15 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
val expected = 0.6546537
val default = Statistics.corr(x, y)
val p1 = Statistics.corr(x, y, "pearson")
- assert(approxEqual(expected, default))
- assert(approxEqual(expected, p1))
+ assert(expected ~== default absTol 1e-6)
+ assert(expected ~== p1 absTol 1e-6)
// numPartitions >= size for input RDDs
for (numParts <- List(xData.size, xData.size * 2)) {
val x1 = sc.parallelize(xData, numParts)
val y1 = sc.parallelize(yData, numParts)
val p2 = Statistics.corr(x1, y1)
- assert(approxEqual(expected, p2))
+ assert(expected ~== p2 absTol 1e-6)
}
// RDD of zero variance
@@ -78,14 +79,14 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
val y = sc.parallelize(yData)
val expected = 0.5
val s1 = Statistics.corr(x, y, "spearman")
- assert(approxEqual(expected, s1))
+ assert(expected ~== s1 absTol 1e-6)
// numPartitions >= size for input RDDs
for (numParts <- List(xData.size, xData.size * 2)) {
val x1 = sc.parallelize(xData, numParts)
val y1 = sc.parallelize(yData, numParts)
val s2 = Statistics.corr(x1, y1, "spearman")
- assert(approxEqual(expected, s2))
+ assert(expected ~== s2 absTol 1e-6)
}
// RDD of zero variance => zero variance in ranks
@@ -141,14 +142,14 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log
val a = RandomRDDs.normalRDD(sc, 100000, 10).map(_ + 1000000000.0)
val b = RandomRDDs.normalRDD(sc, 100000, 10).map(_ + 1000000000.0)
val p = Statistics.corr(a, b, method = "pearson")
- assert(approxEqual(p, 0.0, 0.01))
+ assert(p ~== 0.0 absTol 0.01)
}
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
if (v1.isNaN) {
v2.isNaN
} else {
- math.abs(v1 - v2) <= threshold
+ v1 ~== v2 absTol threshold
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
index 5feccdf33681a..9cbb3d0024daa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala
@@ -21,6 +21,7 @@ import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
test("kernel density single sample") {
@@ -29,8 +30,8 @@ class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
val normal = new NormalDistribution(5.0, 3.0)
val acceptableErr = 1e-6
- assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr)
- assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr)
+ assert(densities(0) ~== normal.density(5.0) absTol acceptableErr)
+ assert(densities(1) ~== normal.density(6.0) absTol acceptableErr)
}
test("kernel density multiple samples") {
@@ -40,9 +41,9 @@ class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext {
val normal1 = new NormalDistribution(5.0, 3.0)
val normal2 = new NormalDistribution(10.0, 3.0)
val acceptableErr = 1e-6
- assert(math.abs(
- densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr)
- assert(math.abs(
- densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr)
+ assert(
+ densities(0) ~== ((normal1.density(5.0) + normal2.density(5.0)) / 2) absTol acceptableErr)
+ assert(
+ densities(1) ~== ((normal1.density(6.0) + normal2.density(6.0)) / 2) absTol acceptableErr)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index 1cc8f342021a0..d43e62bb65535 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.StatCounter
object EnsembleTestHelper {
@@ -43,8 +44,8 @@ object EnsembleTestHelper {
values ++= row
}
val stats = new StatCounter(values)
- assert(math.abs(stats.mean - expectedMean) < epsilon)
- assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ assert(stats.mean ~== expectedMean absTol epsilon)
+ assert(stats.stdev ~== expectedStddev absTol epsilon)
}
def validateClassifier(
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index c257ace02cfe9..ce6543952bf6d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -874,13 +874,6 @@ class TreeClassifierParams(object):
def __init__(self):
super(TreeClassifierParams, self).__init__()
- @since("1.6.0")
- def setImpurity(self, value):
- """
- Sets the value of :py:attr:`impurity`.
- """
- return self._set(impurity=value)
-
@since("1.6.0")
def getImpurity(self):
"""
@@ -1003,6 +996,49 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return DecisionTreeClassificationModel(java_model)
+ def setMaxDepth(self, value):
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value):
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value):
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value):
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value):
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value):
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
@inherit_doc
class DecisionTreeClassificationModel(DecisionTreeModel, JavaClassificationModel, JavaMLWritable,
@@ -1133,6 +1169,63 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return RandomForestClassificationModel(java_model)
+ def setMaxDepth(self, value):
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value):
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value):
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value):
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value):
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value):
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
+ @since("1.4.0")
+ def setNumTrees(self, value):
+ """
+ Sets the value of :py:attr:`numTrees`.
+ """
+ return self._set(numTrees=value)
+
+ @since("1.4.0")
+ def setSubsamplingRate(self, value):
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ return self._set(subsamplingRate=value)
+
@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
@@ -1317,6 +1410,49 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return GBTClassificationModel(java_model)
+ def setMaxDepth(self, value):
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value):
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value):
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value):
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value):
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value):
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
@since("1.4.0")
def setLossType(self, value):
"""
@@ -1324,6 +1460,13 @@ def setLossType(self, value):
"""
return self._set(lossType=value)
+ @since("1.4.0")
+ def setSubsamplingRate(self, value):
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ return self._set(subsamplingRate=value)
+
@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 6405b9fce7efb..56d6190723161 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -765,72 +765,36 @@ class DecisionTreeParams(Params):
def __init__(self):
super(DecisionTreeParams, self).__init__()
- def setMaxDepth(self, value):
- """
- Sets the value of :py:attr:`maxDepth`.
- """
- return self._set(maxDepth=value)
-
def getMaxDepth(self):
"""
Gets the value of maxDepth or its default value.
"""
return self.getOrDefault(self.maxDepth)
- def setMaxBins(self, value):
- """
- Sets the value of :py:attr:`maxBins`.
- """
- return self._set(maxBins=value)
-
def getMaxBins(self):
"""
Gets the value of maxBins or its default value.
"""
return self.getOrDefault(self.maxBins)
- def setMinInstancesPerNode(self, value):
- """
- Sets the value of :py:attr:`minInstancesPerNode`.
- """
- return self._set(minInstancesPerNode=value)
-
def getMinInstancesPerNode(self):
"""
Gets the value of minInstancesPerNode or its default value.
"""
return self.getOrDefault(self.minInstancesPerNode)
- def setMinInfoGain(self, value):
- """
- Sets the value of :py:attr:`minInfoGain`.
- """
- return self._set(minInfoGain=value)
-
def getMinInfoGain(self):
"""
Gets the value of minInfoGain or its default value.
"""
return self.getOrDefault(self.minInfoGain)
- def setMaxMemoryInMB(self, value):
- """
- Sets the value of :py:attr:`maxMemoryInMB`.
- """
- return self._set(maxMemoryInMB=value)
-
def getMaxMemoryInMB(self):
"""
Gets the value of maxMemoryInMB or its default value.
"""
return self.getOrDefault(self.maxMemoryInMB)
- def setCacheNodeIds(self, value):
- """
- Sets the value of :py:attr:`cacheNodeIds`.
- """
- return self._set(cacheNodeIds=value)
-
def getCacheNodeIds(self):
"""
Gets the value of cacheNodeIds or its default value.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 927cc77e201a5..349130f22fade 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -609,13 +609,6 @@ class TreeEnsembleParams(DecisionTreeParams):
def __init__(self):
super(TreeEnsembleParams, self).__init__()
- @since("1.4.0")
- def setSubsamplingRate(self, value):
- """
- Sets the value of :py:attr:`subsamplingRate`.
- """
- return self._set(subsamplingRate=value)
-
@since("1.4.0")
def getSubsamplingRate(self):
"""
@@ -623,15 +616,6 @@ def getSubsamplingRate(self):
"""
return self.getOrDefault(self.subsamplingRate)
- @since("1.4.0")
- def setFeatureSubsetStrategy(self, value):
- """
- Sets the value of :py:attr:`featureSubsetStrategy`.
-
- .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0.
- """
- return self._set(featureSubsetStrategy=value)
-
@since("1.4.0")
def getFeatureSubsetStrategy(self):
"""
@@ -655,13 +639,6 @@ class HasVarianceImpurity(Params):
def __init__(self):
super(HasVarianceImpurity, self).__init__()
- @since("1.4.0")
- def setImpurity(self, value):
- """
- Sets the value of :py:attr:`impurity`.
- """
- return self._set(impurity=value)
-
@since("1.4.0")
def getImpurity(self):
"""
@@ -685,13 +662,6 @@ class RandomForestParams(TreeEnsembleParams):
def __init__(self):
super(RandomForestParams, self).__init__()
- @since("1.4.0")
- def setNumTrees(self, value):
- """
- Sets the value of :py:attr:`numTrees`.
- """
- return self._set(numTrees=value)
-
@since("1.4.0")
def getNumTrees(self):
"""
@@ -843,6 +813,49 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return DecisionTreeRegressionModel(java_model)
+ def setMaxDepth(self, value):
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value):
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value):
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value):
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value):
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value):
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
@inherit_doc
class DecisionTreeModel(JavaModel, JavaPredictionModel):
@@ -1036,6 +1049,63 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return RandomForestRegressionModel(java_model)
+ def setMaxDepth(self, value):
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value):
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value):
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value):
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value):
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value):
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
+ @since("1.4.0")
+ def setNumTrees(self, value):
+ """
+ Sets the value of :py:attr:`numTrees`.
+ """
+ return self._set(numTrees=value)
+
+ @since("1.4.0")
+ def setSubsamplingRate(self, value):
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ return self._set(subsamplingRate=value)
+
@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
@@ -1180,6 +1250,49 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
def _create_model(self, java_model):
return GBTRegressionModel(java_model)
+ def setMaxDepth(self, value):
+ """
+ Sets the value of :py:attr:`maxDepth`.
+ """
+ return self._set(maxDepth=value)
+
+ def setMaxBins(self, value):
+ """
+ Sets the value of :py:attr:`maxBins`.
+ """
+ return self._set(maxBins=value)
+
+ def setMinInstancesPerNode(self, value):
+ """
+ Sets the value of :py:attr:`minInstancesPerNode`.
+ """
+ return self._set(minInstancesPerNode=value)
+
+ def setMinInfoGain(self, value):
+ """
+ Sets the value of :py:attr:`minInfoGain`.
+ """
+ return self._set(minInfoGain=value)
+
+ def setMaxMemoryInMB(self, value):
+ """
+ Sets the value of :py:attr:`maxMemoryInMB`.
+ """
+ return self._set(maxMemoryInMB=value)
+
+ def setCacheNodeIds(self, value):
+ """
+ Sets the value of :py:attr:`cacheNodeIds`.
+ """
+ return self._set(cacheNodeIds=value)
+
+ @since("1.4.0")
+ def setImpurity(self, value):
+ """
+ Sets the value of :py:attr:`impurity`.
+ """
+ return self._set(impurity=value)
+
@since("1.4.0")
def setLossType(self, value):
"""
@@ -1187,6 +1300,13 @@ def setLossType(self, value):
"""
return self._set(lossType=value)
+ @since("1.4.0")
+ def setSubsamplingRate(self, value):
+ """
+ Sets the value of :py:attr:`subsamplingRate`.
+ """
+ return self._set(subsamplingRate=value)
+
@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6bb7da6b2edb2..e531000f3295c 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -36,6 +36,7 @@
from pyspark.sql.types import StringType, DataType
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_udf
+from pyspark.sql.utils import to_str
# Note to developers: all of PySpark functions here take string as column names whenever possible.
# Namely, if columns are referred as arguments, they can be always both Column or string,
@@ -114,6 +115,10 @@ def _():
_.__doc__ = 'Window function: ' + doc
return _
+
+def _options_to_str(options):
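+    """Convert each option value with to_str() (e.g. True -> "true") before passing the dict to the JVM."""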
+ return {key: to_str(value) for (key, value) in options.items()}
+
_lit_doc = """
Creates a :class:`Column` of literal value.
@@ -2343,7 +2348,7 @@ def from_json(col, schema, options={}):
schema = schema.json()
elif isinstance(schema, Column):
schema = _to_java_column(schema)
- jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options)
+ jc = sc._jvm.functions.from_json(_to_java_column(col), schema, _options_to_str(options))
return Column(jc)
@@ -2384,7 +2389,7 @@ def to_json(col, options={}):
"""
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.to_json(_to_java_column(col), options)
+ jc = sc._jvm.functions.to_json(_to_java_column(col), _options_to_str(options))
return Column(jc)
@@ -2415,7 +2420,7 @@ def schema_of_json(json, options={}):
raise TypeError("schema argument should be a column or string")
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.schema_of_json(col, options)
+ jc = sc._jvm.functions.schema_of_json(col, _options_to_str(options))
return Column(jc)
@@ -2442,7 +2447,7 @@ def schema_of_csv(csv, options={}):
raise TypeError("schema argument should be a column or string")
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.schema_of_csv(col, options)
+ jc = sc._jvm.functions.schema_of_csv(col, _options_to_str(options))
return Column(jc)
@@ -2464,7 +2469,7 @@ def to_csv(col, options={}):
"""
sc = SparkContext._active_spark_context
- jc = sc._jvm.functions.to_csv(_to_java_column(col), options)
+ jc = sc._jvm.functions.to_csv(_to_java_column(col), _options_to_str(options))
return Column(jc)
@@ -2693,7 +2698,10 @@ def array_repeat(col, count):
[Row(r=[u'ab', u'ab', u'ab'])]
"""
sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
+ return Column(sc._jvm.functions.array_repeat(
+ _to_java_column(col),
+ _to_java_column(count) if isinstance(count, Column) else count
+ ))
@since(2.4)
@@ -2775,6 +2783,11 @@ def from_csv(col, schema, options={}):
>>> value = data[0][0]
>>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
[Row(csv=Row(_c0=1, _c1=2, _c2=3))]
+ >>> data = [(" abc",)]
+ >>> df = spark.createDataFrame(data, ("value",))
+ >>> options = {'ignoreLeadingWhiteSpace': True}
+ >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect()
+ [Row(csv=Row(s=u'abc'))]
"""
sc = SparkContext._active_spark_context
@@ -2785,7 +2798,7 @@ def from_csv(col, schema, options={}):
else:
raise TypeError("schema argument should be a column or string")
- jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options)
+ jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, _options_to_str(options))
return Column(jc)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index aa5bf635d1874..f9bc2ff72a505 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -27,23 +27,11 @@
from pyspark.sql.column import _to_seq
from pyspark.sql.types import *
from pyspark.sql import utils
+from pyspark.sql.utils import to_str
__all__ = ["DataFrameReader", "DataFrameWriter"]
-def to_str(value):
- """
- A wrapper over str(), but converts bool values to lower case strings.
- If None is given, just returns None, instead of converting it to string "None".
- """
- if isinstance(value, bool):
- return str(value).lower()
- elif value is None:
- return value
- else:
- return str(value)
-
-
class OptionUtils(object):
def _set_opts(self, schema=None, **options):
@@ -757,7 +745,7 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
self._jwrite.save(path)
@since(1.4)
- def insertInto(self, tableName, overwrite=False):
+ def insertInto(self, tableName, overwrite=None):
"""Inserts the content of the :class:`DataFrame` to the specified table.
It requires that the schema of the class:`DataFrame` is the same as the
@@ -765,7 +753,9 @@ def insertInto(self, tableName, overwrite=False):
Optionally overwriting any existing data.
"""
- self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName)
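+ # Only change the save mode when `overwrite` is passed explicitly, so that a mode set
+ # earlier via df.write.mode(...) is respected.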
+ if overwrite is not None:
+ self.mode("overwrite" if overwrite else "append")
+ self._jwrite.insertInto(tableName)
@since(1.4)
def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options):
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 7dfc757970091..64f2fd6a3919f 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -294,6 +294,16 @@ def test_input_file_name_reset_for_rdd(self):
for result in results:
self.assertEqual(result[0], '')
+ def test_array_repeat(self):
+ from pyspark.sql.functions import array_repeat, lit
+
+ df = self.spark.range(1)
+
+ self.assertEquals(
+ df.select(array_repeat("id", 3)).toDF("val").collect(),
+ df.select(array_repeat("id", lit(3))).toDF("val").collect(),
+ )
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index a708072489601..2530cc2ebf224 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -141,6 +141,27 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
.mode("overwrite").saveAsTable("pyspark_bucket"))
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+ def test_insert_into(self):
+ df = self.spark.createDataFrame([("a", 1), ("b", 2)], ["C1", "C2"])
+ with self.table("test_table"):
+ df.write.saveAsTable("test_table")
+ self.assertEqual(2, self.spark.sql("select * from test_table").count())
+
+ df.write.insertInto("test_table")
+ self.assertEqual(4, self.spark.sql("select * from test_table").count())
+
+ df.write.mode("overwrite").insertInto("test_table")
+ self.assertEqual(2, self.spark.sql("select * from test_table").count())
+
+ df.write.insertInto("test_table", True)
+ self.assertEqual(2, self.spark.sql("select * from test_table").count())
+
+ df.write.insertInto("test_table", False)
+ self.assertEqual(4, self.spark.sql("select * from test_table").count())
+
+ df.write.mode("overwrite").insertInto("test_table", False)
+ self.assertEqual(6, self.spark.sql("select * from test_table").count())
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index ca5e85bb3a9bb..c30cc1482750a 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -207,3 +207,16 @@ def call(self, jdf, batch_id):
class Java:
implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']
+
+
+def to_str(value):
+ """
+ A wrapper over str() that converts bool values to lower-case strings.
+ If None is given, returns None instead of converting it to the string "None".
+ """
+ if isinstance(value, bool):
+ return str(value).lower()
+ elif value is None:
+ return value
+ else:
+ return str(value)
diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md
index ea8286124a68c..d7ad35a175a61 100644
--- a/resource-managers/kubernetes/integration-tests/README.md
+++ b/resource-managers/kubernetes/integration-tests/README.md
@@ -11,7 +11,7 @@ is subject to change. Note that currently the integration tests only run with Ja
The simplest way to run the integration tests is to install and run Minikube, then run the following from this
directory:
- dev/dev-run-integration-tests.sh
+ ./dev/dev-run-integration-tests.sh
The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should
run with a minimum of 4 CPUs and 6G of memory:
@@ -62,11 +62,11 @@ By default, the test framework will build new Docker images on every test execut
and it is written to file at `target/imageTag.txt`. To reuse the images built in a previous run, or to use a Docker
image tag that you have built by other means already, pass the tag to the test script:
- dev/dev-run-integration-tests.sh --image-tag <tag>
+ ./dev/dev-run-integration-tests.sh --image-tag <tag>
For example, to reuse the images built by a previous run of the test framework:
- dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt)
+ ./dev/dev-run-integration-tests.sh --image-tag $(cat target/imageTag.txt)
### Customising the Image Names
@@ -74,11 +74,11 @@ If your image names do not follow the standard Spark naming convention - `spark`
If you use the same basic pattern but a different prefix for the name, e.g. `apache-spark`, you can just set `--base-image-name`, e.g.
- dev/dev-run-integration-tests.sh --base-image-name apache-spark
+ ./dev/dev-run-integration-tests.sh --base-image-name apache-spark
Alternatively, if you use completely custom names, you can set each individually via the `--jvm-image-name`, `--python-image-name` and `--r-image-name` arguments, e.g.
- dev/dev-run-integration-tests.sh --jvm-image-name jvm-spark --python-image-name pyspark --r-image-name sparkr
+ ./dev/dev-run-integration-tests.sh --jvm-image-name jvm-spark --python-image-name pyspark --r-image-name sparkr
## Spark Distribution Under Test
diff --git a/sql/README.md b/sql/README.md
index 70cc7c637b58d..f0ea848a41d09 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -9,4 +9,4 @@ Spark SQL is broken up into four subprojects:
- Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs.
- HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server.
-Running `sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`.
+Running `./sql/create-docs.sh` generates SQL documentation for built-in functions under `sql/site`.
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index d991e7cf7e898..0a142c29a16f3 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -113,6 +113,14 @@ statement
(AS? query)? #createHiveTable
| CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
LIKE source=tableIdentifier locationSpec? #createTableLike
+ | replaceTableHeader ('(' colTypeList ')')? tableProvider
+ ((OPTIONS options=tablePropertyList) |
+ (PARTITIONED BY partitioning=transformList) |
+ bucketSpec |
+ locationSpec |
+ (COMMENT comment=STRING) |
+ (TBLPROPERTIES tableProps=tablePropertyList))*
+ (AS? query)? #replaceTable
| ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS
(identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze
| ALTER TABLE multipartIdentifier
@@ -261,6 +269,10 @@ createTableHeader
: CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier
;
+replaceTableHeader
+ : (CREATE OR)? REPLACE TABLE multipartIdentifier
+ ;
+
bucketSpec
: CLUSTERED BY identifierList
(SORTED BY orderedIdentifierList)?
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/StagingTableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/StagingTableCatalog.java
new file mode 100644
index 0000000000000..fc055e91a6acf
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/StagingTableCatalog.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2;
+
+import java.util.Map;
+
+import org.apache.spark.sql.catalog.v2.expressions.Transform;
+import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
+import org.apache.spark.sql.sources.v2.StagedTable;
+import org.apache.spark.sql.sources.v2.SupportsWrite;
+import org.apache.spark.sql.sources.v2.writer.BatchWrite;
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+/**
+ * An optional mix-in for implementations of {@link TableCatalog} that support staging the
+ * creation of a table before committing the table's metadata along with its contents in
+ * CREATE TABLE AS SELECT or REPLACE TABLE AS SELECT operations.
+ *
+ * It is highly recommended to implement this interface whenever possible so that CREATE TABLE AS
+ * SELECT and REPLACE TABLE AS SELECT operations are atomic. For example, when one runs a REPLACE
+ * TABLE AS SELECT operation, if the catalog does not implement this trait, the planner will first
+ * drop the table via {@link TableCatalog#dropTable(Identifier)}, then create the table via
+ * {@link TableCatalog#createTable(Identifier, StructType, Transform[], Map)}, and then perform
+ * the write via {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)}. However, if the
+ * write operation fails, the catalog will have already dropped the table, and the planner cannot
+ * roll back the dropping of the table.
+ *
+ * If the catalog implements this plugin, the catalog can implement the methods to "stage" the
+ * creation and the replacement of a table. After the table's
+ * {@link BatchWrite#commit(WriterCommitMessage[])} is called,
+ * {@link StagedTable#commitStagedChanges()} is called, at which point the staged table can
+ * complete both the data write and the metadata swap operation atomically.
+ */
+public interface StagingTableCatalog extends TableCatalog {
+
+ /**
+ * Stage the creation of a table, preparing it to be committed into the metastore.
+ *
+ * When the table is committed, the contents of any writes performed by the Spark planner are
+ * committed along with the metadata about the table passed into this method's arguments. If the
+ * table exists when this method is called, the method should throw an exception accordingly. If
+ * another process concurrently creates the table before this table's staged changes are
+ * committed, an exception should be thrown by {@link StagedTable#commitStagedChanges()}.
+ *
+ * @param ident a table identifier
+ * @param schema the schema of the new table, as a struct type
+ * @param partitions transforms to use for partitioning data in the table
+ * @param properties a string map of table properties
+ * @return metadata for the new table
+ * @throws TableAlreadyExistsException If a table or view already exists for the identifier
+ * @throws UnsupportedOperationException If a requested partition transform is not supported
+ * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
+ */
+ StagedTable stageCreate(
+ Identifier ident,
+ StructType schema,
+ Transform[] partitions,
+ Map<String, String> properties) throws TableAlreadyExistsException, NoSuchNamespaceException;
+
+ /**
+ * Stage the replacement of a table, preparing it to be committed into the metastore when the
+ * returned table's {@link StagedTable#commitStagedChanges()} is called.
+ *
+ * When the table is committed, the contents of any writes performed by the Spark planner are
+ * committed along with the metadata about the table passed into this method's arguments. If the
+ * table exists, the metadata and the contents of this table replace the metadata and contents of
+ * the existing table. If a concurrent process commits changes to the table's data or metadata
+ * while the write is being performed but before the staged changes are committed, the catalog
+ * can decide whether to move forward with the table replacement anyway or abort the commit
+ * operation.
+ *
+ * If the table does not exist, committing the staged changes should fail with
+ * {@link NoSuchTableException}. This differs from the semantics of
+ * {@link #stageCreateOrReplace(Identifier, StructType, Transform[], Map)}, which should create
+ * the table in the data source if the table does not exist at the time of committing the
+ * operation.
+ *
+ * @param ident a table identifier
+ * @param schema the schema of the new table, as a struct type
+ * @param partitions transforms to use for partitioning data in the table
+ * @param properties a string map of table properties
+ * @return metadata for the new table
+ * @throws UnsupportedOperationException If a requested partition transform is not supported
+ * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
+ * @throws NoSuchTableException If the table does not exist
+ */
+ StagedTable stageReplace(
+ Identifier ident,
+ StructType schema,
+ Transform[] partitions,
+ Map<String, String> properties) throws NoSuchNamespaceException, NoSuchTableException;
+
+ /**
+ * Stage the creation or replacement of a table, preparing it to be committed into the metastore
+ * when the returned table's {@link StagedTable#commitStagedChanges()} is called.
+ *
+ * When the table is committed, the contents of any writes performed by the Spark planner are
+ * committed along with the metadata about the table passed into this method's arguments. If the
+ * table exists, the metadata and the contents of this table replace the metadata and contents of
+ * the existing table. If a concurrent process commits changes to the table's data or metadata
+ * while the write is being performed but before the staged changes are committed, the catalog
+ * can decide whether to move forward with the table replacement anyway or abort the commit
+ * operation.
+ *
+ * If the table does not exist when the changes are committed, the table should be created in the
+ * backing data source. This differs from the expected semantics of
+ * {@link #stageReplace(Identifier, StructType, Transform[], Map)}, which should fail when
+ * the staged changes are committed but the table doesn't exist at commit time.
+ *
+ * @param ident a table identifier
+ * @param schema the schema of the new table, as a struct type
+ * @param partitions transforms to use for partitioning data in the table
+ * @param properties a string map of table properties
+ * @return metadata for the new table
+ * @throws UnsupportedOperationException If a requested partition transform is not supported
+ * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional)
+ */
+ StagedTable stageCreateOrReplace(
+ Identifier ident,
+ StructType schema,
+ Transform[] partitions,
+ Map<String, String> properties) throws NoSuchNamespaceException;
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/StagedTable.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/StagedTable.java
new file mode 100644
index 0000000000000..b2baa93b146a5
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/StagedTable.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import java.util.Map;
+import org.apache.spark.sql.catalog.v2.Identifier;
+import org.apache.spark.sql.catalog.v2.StagingTableCatalog;
+import org.apache.spark.sql.catalog.v2.expressions.Transform;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+/**
+ * Represents a table that is staged to be committed to the metastore.
+ *
+ * This is used to implement atomic CREATE TABLE AS SELECT and REPLACE TABLE AS SELECT queries. The
+ * planner will create one of these via
+ * {@link StagingTableCatalog#stageCreate(Identifier, StructType, Transform[], Map)} or
+ * {@link StagingTableCatalog#stageReplace(Identifier, StructType, Transform[], Map)} to prepare the
+ * table for being written to. This table should usually implement {@link SupportsWrite}. A new
+ * writer will be constructed via {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)},
+ * and the write will be committed. The job concludes with a call to {@link #commitStagedChanges()},
+ * at which point implementations are expected to commit the table's metadata into the metastore
+ * along with the data that was written through the write builder this table created.
+ */
+public interface StagedTable extends Table {
+
+ /**
+ * Finalize the creation or replacement of this table.
+ */
+ void commitStagedChanges();
+
+ /**
+ * Abort the changes that were staged, both in metadata and from temporary outputs of this
+ * table's writers.
+ */
+ void abortStagedChanges();
+}
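Taken together, the two new interfaces describe a stage → write → commit/abort lifecycle. The following is a minimal Scala sketch of that driver-side flow, assuming a hypothetical `writeTo` callback in place of the real v2 write path; only the `stageCreate`, `commitStagedChanges`, and `abortStagedChanges` calls come from the interfaces added in this patch.

```scala
import org.apache.spark.sql.catalog.v2.{Identifier, StagingTableCatalog}
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.sources.v2.StagedTable
import org.apache.spark.sql.types.StructType

// Hypothetical helper showing how a planner-side node could drive an atomic CTAS.
def atomicCreateTableAsSelect(
    catalog: StagingTableCatalog,
    ident: Identifier,
    schema: StructType,
    partitions: Array[Transform],
    properties: java.util.Map[String, String],
    writeTo: StagedTable => Unit): Unit = {
  // Nothing becomes visible in the metastore until commitStagedChanges() succeeds.
  val staged = catalog.stageCreate(ident, schema, partitions, properties)
  try {
    writeTo(staged)                 // normally performed through SupportsWrite/BatchWrite
    staged.commitStagedChanges()    // publish metadata and data in one atomic step
  } catch {
    case e: Throwable =>
      staged.abortStagedChanges()   // drop staged metadata and any temporary outputs
      throw e
  }
}
```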
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 9853a4fcc2f9d..29d81c553ff61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -80,4 +80,10 @@ trait Encoder[T] extends Serializable {
* A ClassTag that can be used to construct an Array to contain a collection of `T`.
*/
def clsTag: ClassTag[T]
+
+ /**
+ * Create a copied [[Encoder]]. The implementation may just copy internal reusable fields to speed
+ * up the [[Encoder]] creation.
+ */
+ def makeCopy: Encoder[T]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 6020b068155fc..488252aa0c7b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -343,6 +343,9 @@ object CatalystTypeConverters {
private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
+
+ private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+
override def toCatalystImpl(scalaValue: Any): Decimal = {
val decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
@@ -353,7 +356,7 @@ object CatalystTypeConverters {
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ s"cannot be converted to ${dataType.catalogString}")
}
- decimal.toPrecision(dataType.precision, dataType.scale)
+ decimal.toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala
new file mode 100644
index 0000000000000..3036f7c21093f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.Identifier
+
+class CannotReplaceMissingTableException(
+ tableIdentifier: Identifier,
+ cause: Option[Throwable] = None)
+ extends AnalysisException(
+ s"Table $tableIdentifier cannot be replaced as it did not exist." +
+ s" Use CREATE OR REPLACE TABLE to create the table.", cause = cause)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c72400a8b72c2..3408b496d9d2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -414,6 +414,7 @@ object FunctionRegistry {
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
expression[TimeWindow]("window"),
+ expression[MakeDate]("make_date"),
// collection functions
expression[CreateArray]("array"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 2d646721f87a2..c6c1d3bfa6347 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -561,6 +561,8 @@ object CatalogTableType {
val EXTERNAL = new CatalogTableType("EXTERNAL")
val MANAGED = new CatalogTableType("MANAGED")
val VIEW = new CatalogTableType("VIEW")
+
+ val tableTypes = Seq(EXTERNAL, MANAGED, VIEW)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
index 1268fcffcfcd0..7d52847216cc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala
@@ -163,6 +163,11 @@ class CSVOptions(
val inputBufferSize = 128
+ /**
+ * The maximum length of error content reported in CSV parser/writer exception messages.
+ */
+ val maxErrorContentLength = 1000
+
val isCommentSet = this.comment != '\u0000'
val samplingRatio =
@@ -220,6 +225,7 @@ class CSVOptions(
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
+ writerSettings.setErrorContentLength(maxErrorContentLength)
writerSettings
}
@@ -246,6 +252,7 @@ class CSVOptions(
lineSeparatorInRead.foreach { _ =>
settings.setNormalizeLineEndingsWithinQuotes(!multiLine)
}
+ settings.setErrorContentLength(maxErrorContentLength)
settings
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index c97303be1d27c..bd499671d6441 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -382,6 +382,8 @@ case class ExpressionEncoder[T](
.map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")
override def toString: String = s"class[$schemaString]"
+
+ override def makeCopy: ExpressionEncoder[T] = copy()
}
// A dummy logical plan that can hold expressions and go through optimizer rules.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index f671ede21782a..5314821ea3a59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,7 +16,8 @@
*/
package org.apache.spark.sql.catalyst.expressions
-import java.util.{Comparator, TimeZone}
+import java.time.ZoneId
+import java.util.Comparator
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -2459,10 +2460,10 @@ case class Sequence(
new IntegralSequenceImpl(iType)(ct, iType.integral)
case TimestampType =>
- new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)
+ new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
case DateType =>
- new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone)
+ new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
}
override def eval(input: InternalRow): Any = {
@@ -2603,7 +2604,7 @@ object Sequence {
}
private class TemporalSequenceImpl[T: ClassTag]
- (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone)
+ (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
(implicit num: Integral[T]) extends SequenceImpl {
override val defaultStep: DefaultStep = new DefaultStep(
@@ -2642,7 +2643,7 @@ object Sequence {
while (t < exclusiveItem ^ stepSign < 0) {
arr(i) = fromLong(t / scale)
i += 1
- t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, timeZone)
+ t = timestampAddInterval(startMicros, i * stepMonths, i * stepMicros, zoneId)
}
// truncate array to the correct length
@@ -2668,7 +2669,7 @@ object Sequence {
val exclusiveItem = ctx.freshName("exclusiveItem")
val t = ctx.freshName("t")
val i = ctx.freshName("i")
- val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName)
+ val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val sequenceLengthCode =
s"""
@@ -2701,7 +2702,7 @@ object Sequence {
| $arr[$i] = ($elemType) ($t / ${scale}L);
| $i += 1;
| $t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
- | $startMicros, $i * $stepMonths, $i * $stepMicros, $genTimeZone);
+ | $startMicros, $i * $stepMonths, $i * $stepMicros, $zid);
| }
|
| if ($arr.length > $i) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index ccf6b36effa08..edb5382ae4437 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -996,14 +996,14 @@ case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[S
override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
- start.asInstanceOf[Long], itvl.months, itvl.microseconds, timeZone)
+ start.asInstanceOf[Long], itvl.months, itvl.microseconds, zoneId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceObj("timeZone", timeZone)
+ val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
- s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)"""
+ s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $zid)"""
})
}
}
@@ -1111,14 +1111,14 @@ case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[S
override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
- start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, timeZone)
+ start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, zoneId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val tz = ctx.addReferenceObj("timeZone", timeZone)
+ val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
- s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)"""
+ s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $zid)"""
})
}
}
@@ -1605,3 +1605,55 @@ private case class GetTimestamp(
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
}
+
+@ExpressionDescription(
+ usage = "_FUNC_(year, month, day) - Create date from year, month and day fields.",
+ arguments = """
+ Arguments:
+ * year - the year to represent, from 1 to 9999
+ * month - the month-of-year to represent, from 1 (January) to 12 (December)
+ * day - the day-of-month to represent, from 1 to 31
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(2013, 7, 15);
+ 2013-07-15
+ > SELECT _FUNC_(2019, 13, 1);
+ NULL
+ > SELECT _FUNC_(2019, 7, NULL);
+ NULL
+ > SELECT _FUNC_(2019, 2, 30);
+ NULL
+ """,
+ since = "3.0.0")
+case class MakeDate(year: Expression, month: Expression, day: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ override def children: Seq[Expression] = Seq(year, month, day)
+ override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType)
+ override def dataType: DataType = DateType
+ override def nullable: Boolean = true
+
+ override def nullSafeEval(year: Any, month: Any, day: Any): Any = {
+ try {
+ val ld = LocalDate.of(year.asInstanceOf[Int], month.asInstanceOf[Int], day.asInstanceOf[Int])
+ localDateToDays(ld)
+ } catch {
+ case _: java.time.DateTimeException => null
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ nullSafeCodeGen(ctx, ev, (year, month, day) => {
+ s"""
+ try {
+ ${ev.value} = $dtu.localDateToDays(java.time.LocalDate.of($year, $month, $day));
+ } catch (java.time.DateTimeException e) {
+ ${ev.isNull} = true;
+ }"""
+ })
+ }
+
+ override def prettyName: String = "make_date"
+}
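For reference, the documented examples can also be exercised through the SQL API; a minimal sketch assuming a running `SparkSession` named `spark`:

```scala
// Matches the first example in the usage string above.
spark.sql("SELECT make_date(2013, 7, 15) AS d").show()
// +----------+
// |         d|
// +----------+
// |2013-07-15|
// +----------+

// Out-of-range or null fields yield NULL rather than failing the query.
spark.sql("SELECT make_date(2019, 2, 30) AS d").collect()   // Array([null])
```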
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index e873f8ed1a21c..6dd2fa716e6bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1177,6 +1177,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
dataType match {
case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal]
+ // Overflow cannot happen, so no need to control nullOnOverflow
decimal.toPrecision(decimal.precision, s, mode)
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 95aefb6422d67..43a6006f9b5c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -54,7 +54,7 @@ object NestedColumnAliasing {
/**
* Return a replaced project list.
*/
- private def getNewProjectList(
+ def getNewProjectList(
projectList: Seq[NamedExpression],
nestedFieldToAlias: Map[ExtractValue, Alias]): Seq[NamedExpression] = {
projectList.map(_.transform {
@@ -66,7 +66,7 @@ object NestedColumnAliasing {
/**
* Return a plan with new children replaced with aliases.
*/
- private def replaceChildrenWithAliases(
+ def replaceChildrenWithAliases(
plan: LogicalPlan,
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
plan.withNewChildren(plan.children.map { plan =>
@@ -107,10 +107,10 @@ object NestedColumnAliasing {
* 1. ExtractValue -> Alias: A new alias is created for each nested field.
* 2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing it.
*/
- private def getAliasSubMap(projectList: Seq[NamedExpression])
+ def getAliasSubMap(exprList: Seq[Expression])
: Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
val (nestedFieldReferences, otherRootReferences) =
- projectList.flatMap(collectRootReferenceAndExtractValue).partition {
+ exprList.flatMap(collectRootReferenceAndExtractValue).partition {
case _: ExtractValue => true
case _ => false
}
@@ -155,4 +155,15 @@ object NestedColumnAliasing {
case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType)
case _ => 1 // UDT and others
}
+
+ /**
+ * This is a whitelist of `Generator`s for which nested field pruning is allowed.
+ */
+ def canPruneGenerator(g: Generator): Boolean = g match {
+ case _: Explode => true
+ case _: Stack => true
+ case _: PosExplode => true
+ case _: Inline => true
+ case _ => false
+ }
}
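A small sketch of the query shape this whitelist (together with the new `ColumnPruning` case below) targets. The schema and path are illustrative and a `SparkSession` named `spark` is assumed; with the flag enabled, the Generate's child only has to carry `c.e` instead of the whole struct `c`.

```scala
import org.apache.spark.sql.functions.explode

spark.conf.set("spark.sql.optimizer.expression.nestedPruning.enabled", "true")

// Illustrative input schema: a INT, c STRUCT<d: DOUBLE, e: ARRAY<INT>>
val df = spark.read.parquet("/tmp/nested_example")
df.select(df("a"), explode(df("c.e")).as("e")).explain(true)
```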
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c99d2c06fac63..206d09a6f79e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -487,7 +487,7 @@ object LimitPushDown extends Rule[LogicalPlan] {
* Union:
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
* safe to pushdown Filters and Projections through it. Filter pushdown is handled by another
- * rule PushDownPredicate. Once we add UNION DISTINCT, we will not be able to pushdown Projections.
+ * rule PushDownPredicates. Once we add UNION DISTINCT, we will not be able to push down Projections.
*/
object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper {
@@ -588,6 +588,24 @@ object ColumnPruning extends Rule[LogicalPlan] {
.map(_._2)
p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
+ // prune unrequired nested fields
+ case p @ Project(projectList, g: Generate) if SQLConf.get.nestedPruningOnExpressions &&
+ NestedColumnAliasing.canPruneGenerator(g.generator) =>
+ NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map {
+ case (nestedFieldToAlias, attrToAliases) =>
+ val newGenerator = g.generator.transform {
+ case f: ExtractValue if nestedFieldToAlias.contains(f) =>
+ nestedFieldToAlias(f).toAttribute
+ }.asInstanceOf[Generator]
+
+ // Defer updating `Generate.unrequiredChildIndex` to the next round of `ColumnPruning`.
+ val newGenerate = g.copy(generator = newGenerator)
+
+ val newChild = NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases)
+
+ Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
+ }.getOrElse(p)
+
// Eliminate unneeded attributes from right side of a Left Existence Join.
case j @ Join(_, right, LeftExistence(_), _, _) =>
j.copy(right = prunedChild(right, j.references))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d9f8b9a7203ff..a7a3b96ba726d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType}
+import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -1926,6 +1926,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
case ("decimal", precision :: scale :: Nil) =>
DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case ("interval", Nil) => CalendarIntervalType
case (dt, params) =>
val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt
throw new ParseException(s"DataType $dtStr is not supported.", ctx)
@@ -2127,6 +2128,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
(multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null)
}
+ /**
+ * Validate a REPLACE TABLE statement and return the [[TableHeader]].
+ */
+ override def visitReplaceTableHeader(
+ ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText)
+ (multipartIdentifier, false, false, false)
+ }
+
/**
* Parse a qualified name to a multipart name.
*/
@@ -2294,6 +2304,69 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
+ /**
+ * Replace a table, returning a [[ReplaceTableStatement]] or [[ReplaceTableAsSelectStatement]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * [CREATE OR] REPLACE TABLE [db_name.]table_name
+ * USING table_provider
+ * replace_table_clauses
+ * [[AS] select_statement];
+ *
+ * replace_table_clauses (order insensitive):
+ * [OPTIONS table_property_list]
+ * [PARTITIONED BY (col_name, transform(col_name), transform(constant, col_name), ...)]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ * }}}
+ */
+ override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) {
+ val (table, _, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader)
+ if (external) {
+ operationNotAllowed("REPLACE EXTERNAL TABLE ... USING", ctx)
+ }
+
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+
+ val schema = Option(ctx.colTypeList()).map(createSchema)
+ val partitioning: Seq[Transform] =
+ Option(ctx.partitioning).map(visitTransformList).getOrElse(Nil)
+ val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
+ val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+
+ val provider = ctx.tableProvider.qualifiedName.getText
+ val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec)
+ val comment = Option(ctx.comment).map(string)
+ val orCreate = ctx.replaceTableHeader().CREATE() != null
+
+ Option(ctx.query).map(plan) match {
+ case Some(_) if schema.isDefined =>
+ operationNotAllowed(
+ "Schema may not be specified in a Replace Table As Select (RTAS) statement",
+ ctx)
+
+ case Some(query) =>
+ ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties,
+ provider, options, location, comment, orCreate = orCreate)
+
+ case _ =>
+ ReplaceTableStatement(table, schema.getOrElse(new StructType), partitioning,
+ bucketSpec, properties, provider, options, location, comment, orCreate = orCreate)
+ }
+ }
+
/**
* Create a [[DropTableStatement]] command.
*/
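The accepted statement shapes, sketched as SQL submitted through a session. The catalog, table, and column names are illustrative and assume a configured v2 catalog.

```scala
// REPLACE TABLE AS SELECT: the schema comes from the query; fails if the table is missing.
spark.sql(
  """REPLACE TABLE testcat.db.events
    |USING parquet
    |PARTITIONED BY (id)
    |TBLPROPERTIES ('owner' = 'etl')
    |AS SELECT id, ts FROM source_events""".stripMargin)

// CREATE OR REPLACE TABLE: also creates the table when it does not exist yet.
spark.sql(
  """CREATE OR REPLACE TABLE testcat.db.events (id BIGINT, ts TIMESTAMP)
    |USING parquet""".stripMargin)
```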
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 2cb04c9ec70c5..2698ba282f962 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -441,6 +441,47 @@ case class CreateTableAsSelect(
}
}
+/**
+ * Replace a table with a v2 catalog.
+ *
+ * If the table does not exist, and orCreate is true, then it will be created.
+ * If the table does not exist, and orCreate is false, then an exception will be thrown.
+ *
+ * The persisted table will have no contents as a result of this operation.
+ */
+case class ReplaceTable(
+ catalog: TableCatalog,
+ tableName: Identifier,
+ tableSchema: StructType,
+ partitioning: Seq[Transform],
+ properties: Map[String, String],
+ orCreate: Boolean) extends Command
+
+/**
+ * Replaces a table from a select query with a v2 catalog.
+ *
+ * If the table does not exist, and orCreate is true, then it will be created.
+ * If the table does not exist, and orCreate is false, then an exception will be thrown.
+ */
+case class ReplaceTableAsSelect(
+ catalog: TableCatalog,
+ tableName: Identifier,
+ partitioning: Seq[Transform],
+ query: LogicalPlan,
+ properties: Map[String, String],
+ writeOptions: Map[String, String],
+ orCreate: Boolean) extends Command {
+
+ override def children: Seq[LogicalPlan] = Seq(query)
+
+ override lazy val resolved: Boolean = {
+ // the table schema is created from the query schema, so the only resolution needed is to check
+ // that the columns referenced by the table's partitioning exist in the query schema
+ val references = partitioning.flatMap(_.references).toSet
+ references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined)
+ }
+}
+
/**
* Append data to an existing table.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ReplaceTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ReplaceTableStatement.scala
new file mode 100644
index 0000000000000..2808892b089b9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ReplaceTableStatement.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.sql
+
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A REPLACE TABLE command, as parsed from SQL.
+ *
+ * If the table already exists when this command runs, executing the statement
+ * will replace the table's metadata and clear the underlying rows from the table.
+ */
+case class ReplaceTableStatement(
+ tableName: Seq[String],
+ tableSchema: StructType,
+ partitioning: Seq[Transform],
+ bucketSpec: Option[BucketSpec],
+ properties: Map[String, String],
+ provider: String,
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String],
+ orCreate: Boolean) extends ParsedStatement
+
+/**
+ * A REPLACE TABLE AS SELECT command, as parsed from SQL.
+ */
+case class ReplaceTableAsSelectStatement(
+ tableName: Seq[String],
+ asSelect: LogicalPlan,
+ partitioning: Seq[Transform],
+ bucketSpec: Option[BucketSpec],
+ properties: Map[String, String],
+ provider: String,
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String],
+ orCreate: Boolean) extends ParsedStatement {
+
+ override def children: Seq[LogicalPlan] = Seq(asSelect)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 0596dc00985a1..e79000d583506 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -287,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
mapChildren(_.transformDown(rule))
} else {
// If the transform function replaces this node with a new one, carry over the tags.
- afterRule.tags ++= this.tags
+ afterRule.copyTagsFrom(this)
afterRule.mapChildren(_.transformDown(rule))
}
}
@@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
}
// If the transform function replaces this node with a new one, carry over the tags.
- newNode.tags ++= this.tags
+ newNode.copyTagsFrom(this)
newNode
}
@@ -429,8 +429,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
private def makeCopy(
newArgs: Array[AnyRef],
allowEmptyArgs: Boolean): BaseType = attachTree(this, "makeCopy") {
+ val allCtors = getClass.getConstructors
+ if (newArgs.isEmpty && allCtors.isEmpty) {
+ // This is a singleton object which doesn't have any constructor. Just return `this` as we
+ // can't copy it.
+ return this
+ }
+
// Skip no-arg constructors that are just there for kryo.
- val ctors = getClass.getConstructors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
+ val ctors = allCtors.filter(allowEmptyArgs || _.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 1daf65a0c560c..10a7f9bd550e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.time._
-import java.time.Year.isLeap
-import java.time.temporal.IsoFields
+import java.time.temporal.{ChronoUnit, IsoFields}
import java.util.{Locale, TimeZone}
import java.util.concurrent.TimeUnit._
@@ -521,12 +520,12 @@ object DateTimeUtils {
start: SQLTimestamp,
months: Int,
microseconds: Long,
- timeZone: TimeZone): SQLTimestamp = {
- val days = millisToDays(MICROSECONDS.toMillis(start), timeZone)
- val newDays = dateAddMonths(days, months)
- start +
- MILLISECONDS.toMicros(daysToMillis(newDays, timeZone) - daysToMillis(days, timeZone)) +
- microseconds
+ zoneId: ZoneId): SQLTimestamp = {
+ val resultTimestamp = microsToInstant(start)
+ .atZone(zoneId)
+ .plusMonths(months)
+ .plus(microseconds, ChronoUnit.MICROS)
+ instantToMicros(resultTimestamp.toInstant)
}
/**
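The rewritten arithmetic can be reproduced standalone with java.time; the sketch below mirrors the new implementation (it is not the Spark utility itself): months are added with calendar-aware, zone-sensitive arithmetic first, then the remaining microseconds.

```scala
import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

def addInterval(startMicros: Long, months: Int, microseconds: Long, zoneId: ZoneId): Long = {
  val start = Instant.EPOCH.plus(startMicros, ChronoUnit.MICROS)
  val shifted = start.atZone(zoneId)
    .plusMonths(months)                     // calendar-aware, honors the zone's rules
    .plus(microseconds, ChronoUnit.MICROS)  // then the sub-month part of the interval
    .toInstant
  ChronoUnit.MICROS.between(Instant.EPOCH, shifted)
}

// e.g. adding one month across a DST transition keeps the local wall-clock time:
// addInterval(startMicros, months = 1, microseconds = 0, ZoneId.of("America/Los_Angeles"))
```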
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 57f5128fd4fbe..fbdb1c5f957d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1656,6 +1656,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val NESTED_PRUNING_ON_EXPRESSIONS =
+ buildConf("spark.sql.optimizer.expression.nestedPruning.enabled")
+ .internal()
+ .doc("Prune nested fields from expressions in an operator which are unnecessary in " +
+ "satisfying a query. Note that this optimization doesn't prune nested fields from " +
+ "physical data source scanning. For pruning nested fields from scanning, please use " +
+ "`spark.sql.optimizer.nestedSchemaPruning.enabled` config.")
+ .booleanConf
+ .createWithDefault(false)
+
val TOP_K_SORT_FALLBACK_THRESHOLD =
buildConf("spark.sql.execution.topKSortFallbackThreshold")
.internal()
@@ -2315,6 +2325,8 @@ class SQLConf extends Serializable with Logging {
def serializerNestedSchemaPruningEnabled: Boolean =
getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)
+ def nestedPruningOnExpressions: Boolean = getConf(NESTED_PRUNING_ON_EXPRESSIONS)
+
def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
index 8e297874a0d62..ea94cf626698a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
@@ -34,6 +34,8 @@ class CalendarIntervalType private() extends DataType {
override def defaultSize: Int = 16
+ override def simpleString: String = "interval"
+
private[spark] override def asNullable: CalendarIntervalType = this
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 1bf322af21799..a5d1a72d62d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -414,20 +414,12 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def floor: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
- val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
- if (res == null) {
- throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
- }
- res
+ toPrecision(newPrecision, 0, ROUND_FLOOR, nullOnOverflow = false)
}
def ceil: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
- val res = toPrecision(newPrecision, 0, ROUND_CEILING)
- if (res == null) {
- throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
- }
- res
+ toPrecision(newPrecision, 0, ROUND_CEILING, nullOnOverflow = false)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 725764755c626..4440ac9e281c4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -77,9 +77,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField))
assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField))
- assertError(Add('booleanField, 'booleanField), "requires (numeric or calendarinterval) type")
+ assertError(Add('booleanField, 'booleanField), "requires (numeric or interval) type")
assertError(Subtract('booleanField, 'booleanField),
- "requires (numeric or calendarinterval) type")
+ "requires (numeric or interval) type")
assertError(Multiply('booleanField, 'booleanField), "requires numeric type")
assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type")
assertError(Remainder('booleanField, 'booleanField), "requires numeric type")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f4feeca1d05ad..9380c7e3f5f72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -427,6 +427,10 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int")
testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int")
+ encodeDecodeTest("foo" -> 1L, "makeCopy") {
+ Encoders.product[(String, Long)].makeCopy.asInstanceOf[ExpressionEncoder[(String, Long)]]
+ }
+
private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = {
Seq(true, false).foreach { allowNullOnOverflow =>
testAndVerifyNotLeakingReflectionObjects(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 4e8322d3c55d7..b4110afd55057 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -918,4 +918,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
+
+ test("creating values of DateType via make_date") {
+ checkEvaluation(MakeDate(Literal(2013), Literal(7), Literal(15)), Date.valueOf("2013-7-15"))
+ checkEvaluation(MakeDate(Literal.create(null, IntegerType), Literal(7), Literal(15)), null)
+ checkEvaluation(MakeDate(Literal(2019), Literal.create(null, IntegerType), Literal(19)), null)
+ checkEvaluation(MakeDate(Literal(2019), Literal(7), Literal.create(null, IntegerType)), null)
+ checkEvaluation(MakeDate(Literal(Int.MaxValue), Literal(13), Literal(19)), null)
+ checkEvaluation(MakeDate(Literal(2019), Literal(13), Literal(19)), null)
+ checkEvaluation(MakeDate(Literal(2019), Literal(7), Literal(32)), null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index b190d6f5caa1c..f8400a590606a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -696,7 +696,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
val struct2 = Literal.create(null, schema2)
StructsToJson(Map.empty, struct2, gmtId).checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(msg) =>
- assert(msg.contains("Unable to convert column a of type calendarinterval to JSON"))
+ assert(msg.contains("Unable to convert column a of type interval to JSON"))
case _ => fail("from_json should not work on interval map value type.")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index df92fa3475bd9..981ef57c051fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -21,7 +21,8 @@ import java.util.Locale
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -54,4 +55,26 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}
+
+ test("SPARK-28369: honor nullOnOverflow config for ScalaUDF") {
+ withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
+ val udf = ScalaUDF(
+ (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
+ DecimalType.SYSTEM_DEFAULT,
+ Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
+ val e1 = intercept[ArithmeticException](udf.eval())
+ assert(e1.getMessage.contains("cannot be represented as Decimal"))
+ val e2 = intercept[SparkException] {
+ checkEvaluationWithUnsafeProjection(udf, null)
+ }
+ assert(e2.getCause.isInstanceOf[ArithmeticException])
+ }
+ withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
+ val udf = ScalaUDF(
+ (a: java.math.BigDecimal) => a.multiply(new java.math.BigDecimal(100)),
+ DecimalType.SYSTEM_DEFAULT,
+ Literal(BigDecimal("12345678901234567890.123")) :: Nil, false :: Nil)
+ checkEvaluation(udf, null)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 78ae131328644..75ff07637fccc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.reflect.runtime.universe.TypeTag
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -26,7 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{StringType, StructType}
class ColumnPruningSuite extends PlanTest {
@@ -101,6 +103,81 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("Nested column pruning for Generate") {
+ def runTest(
+ origGenerator: Generator,
+ replacedGenerator: Seq[String] => Generator,
+ aliasedExprs: Seq[String] => Seq[Expression],
+ unrequiredChildIndex: Seq[Int],
+ generatorOutputNames: Seq[String]) {
+ withSQLConf(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> "true") {
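+ // Enables pruning of nested fields that are referenced only through expressions.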
+ val structType = StructType.fromDDL("d double, e array, f double, g double, " +
+ "h array>")
+ val input = LocalRelation('a.int, 'b.int, 'c.struct(structType))
+ val generatorOutputs = generatorOutputNames.map(UnresolvedAttribute(_))
+
+ val selectedExprs = Seq(UnresolvedAttribute("a"), 'c.getField("d")) ++
+ generatorOutputs
+
+ val query =
+ input
+ .generate(origGenerator, outputNames = generatorOutputNames)
+ .select(selectedExprs: _*)
+ .analyze
+
+ val optimized = Optimize.execute(query)
+
+ val aliases = NestedColumnAliasingSuite.collectGeneratedAliases(optimized)
+
+ val selectedFields = UnresolvedAttribute("a") +: aliasedExprs(aliases)
+ val finalSelectedExprs = Seq(UnresolvedAttribute("a"), $"${aliases(0)}".as("c.d")) ++
+ generatorOutputs
+
+ val correctAnswer =
+ input
+ .select(selectedFields: _*)
+ .generate(replacedGenerator(aliases),
+ unrequiredChildIndex = unrequiredChildIndex,
+ outputNames = generatorOutputNames)
+ .select(finalSelectedExprs: _*)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+ }
+
+ runTest(
+ Explode('c.getField("e")),
+ aliases => Explode($"${aliases(1)}".as("c.e")),
+ aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))),
+ Seq(2),
+ Seq("explode")
+ )
+ runTest(Stack(2 :: 'c.getField("f") :: 'c.getField("g") :: Nil),
+ aliases => Stack(2 :: $"${aliases(1)}".as("c.f") :: $"${aliases(2)}".as("c.g") :: Nil),
+ aliases => Seq(
+ 'c.getField("d").as(aliases(0)),
+ 'c.getField("f").as(aliases(1)),
+ 'c.getField("g").as(aliases(2))),
+ Seq(2, 3),
+ Seq("stack")
+ )
+ runTest(
+ PosExplode('c.getField("e")),
+ aliases => PosExplode($"${aliases(1)}".as("c.e")),
+ aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))),
+ Seq(2),
+ Seq("pos", "explode")
+ )
+ runTest(
+ Inline('c.getField("h")),
+ aliases => Inline($"${aliases(1)}".as("c.h")),
+ aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("h").as(aliases(1))),
+ Seq(2),
+ Seq("h1", "h2")
+ )
+ }
+
test("Column pruning for Project on Sort") {
val input = LocalRelation('a.int, 'b.string, 'c.double)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index ab2bd6dff1265..2351d8321c5f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}
class NestedColumnAliasingSuite extends SchemaPruningTest {
+ import NestedColumnAliasingSuite._
+
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Nested column pruning", FixedPoint(100),
ColumnPruning,
@@ -264,9 +266,10 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
.analyze
comparePlans(optimized, expected)
}
+}
-
- private def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
+object NestedColumnAliasingSuite {
+ def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
val aliases = ArrayBuffer[String]()
query.transformAllExpressions {
case a @ Alias(_, name) if name.startsWith("_gen_alias_") =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index d008b3c78fac3..dd84170e26200 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.parser
import java.util.Locale
-import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
+import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType}
+import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
@@ -47,82 +47,71 @@ class DDLParserSuite extends AnalysisTest {
comparePlans(parsePlan(sql), expected, checkAnalysis = false)
}
- test("create table using - schema") {
- val sql = "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet"
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType()
- .add("a", IntegerType, nullable = true, "test")
- .add("b", StringType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ test("create/replace table using - schema") {
+ val createSql = "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet"
+ val replaceSql = "REPLACE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet"
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType()
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType)),
+ Seq.empty[Transform],
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None)
+
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet",
"no viable alternative at input")
}
- test("create table - with IF NOT EXISTS") {
+ test("create/replace table - with IF NOT EXISTS") {
val sql = "CREATE TABLE IF NOT EXISTS my_tab(a INT, b STRING) USING parquet"
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
- }
+ testCreateOrReplaceDdl(
+ sql,
+ TableSpec(
+ Seq("my_tab"),
+ Some(new StructType().add("a", IntegerType).add("b", StringType)),
+ Seq.empty[Transform],
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None),
+ expectedIfNotExists = true)
}
- test("create table - with partitioned by") {
- val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " +
+ test("create/replace table - with partitioned by") {
+ val createSql = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " +
"USING parquet PARTITIONED BY (a)"
-
- parsePlan(query) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType()
- .add("a", IntegerType, nullable = true, "test")
- .add("b", StringType))
- assert(create.partitioning == Seq(IdentityTransform(FieldReference("a"))))
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $query")
+ val replaceSql = "REPLACE TABLE my_tab(a INT comment 'test', b STRING) " +
+ "USING parquet PARTITIONED BY (a)"
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType()
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType)),
+ Seq(IdentityTransform(FieldReference("a"))),
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None)
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("create table - partitioned by transforms") {
- val sql =
+ test("create/replace table - partitioned by transforms") {
+ val createSql =
"""
|CREATE TABLE my_tab (a INT, b STRING, ts TIMESTAMP) USING parquet
|PARTITIONED BY (
@@ -135,154 +124,151 @@ class DDLParserSuite extends AnalysisTest {
| foo(a, "bar", 34))
""".stripMargin
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType()
- .add("a", IntegerType)
- .add("b", StringType)
- .add("ts", TimestampType))
- assert(create.partitioning == Seq(
- IdentityTransform(FieldReference("a")),
- BucketTransform(LiteralValue(16, IntegerType), Seq(FieldReference("b"))),
- YearsTransform(FieldReference("ts")),
- MonthsTransform(FieldReference("ts")),
- DaysTransform(FieldReference("ts")),
- HoursTransform(FieldReference("ts")),
- ApplyTransform("foo", Seq(
- FieldReference("a"),
- LiteralValue(UTF8String.fromString("bar"), StringType),
- LiteralValue(34, IntegerType)))))
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ val replaceSql =
+ """
+ |REPLACE TABLE my_tab (a INT, b STRING, ts TIMESTAMP) USING parquet
+ |PARTITIONED BY (
+ | a,
+ | bucket(16, b),
+ | years(ts),
+ | months(ts),
+ | days(ts),
+ | hours(ts),
+ | foo(a, "bar", 34))
+ """.stripMargin
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType()
+ .add("a", IntegerType)
+ .add("b", StringType)
+ .add("ts", TimestampType)),
+ Seq(
+ IdentityTransform(FieldReference("a")),
+ BucketTransform(LiteralValue(16, IntegerType), Seq(FieldReference("b"))),
+ YearsTransform(FieldReference("ts")),
+ MonthsTransform(FieldReference("ts")),
+ DaysTransform(FieldReference("ts")),
+ HoursTransform(FieldReference("ts")),
+ ApplyTransform("foo", Seq(
+ FieldReference("a"),
+ LiteralValue(UTF8String.fromString("bar"), StringType),
+ LiteralValue(34, IntegerType)))),
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None)
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("create table - with bucket") {
- val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " +
+ test("create/replace table - with bucket") {
+ val createSql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " +
"CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS"
- parsePlan(query) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.contains(BucketSpec(5, Seq("a"), Seq("b"))))
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $query")
+ val replaceSql = "REPLACE TABLE my_tab(a INT, b STRING) USING parquet " +
+ "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS"
+
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType().add("a", IntegerType).add("b", StringType)),
+ Seq.empty[Transform],
+ Some(BucketSpec(5, Seq("a"), Seq("b"))),
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None)
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("create table - with comment") {
- val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'"
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.contains("abc"))
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ test("create/replace table - with comment") {
+ val createSql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'"
+ val replaceSql = "REPLACE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'"
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType().add("a", IntegerType).add("b", StringType)),
+ Seq.empty[Transform],
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ Some("abc"))
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("create table - with table properties") {
- val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')"
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties == Map("test" -> "test"))
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ test("create/replace table - with table properties") {
+ val createSql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet" +
+ " TBLPROPERTIES('test' = 'test')"
+ val replaceSql = "REPLACE TABLE my_tab(a INT, b STRING) USING parquet" +
+ " TBLPROPERTIES('test' = 'test')"
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType().add("a", IntegerType).add("b", StringType)),
+ Seq.empty[Transform],
+ None,
+ Map("test" -> "test"),
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None)
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("create table - with location") {
- val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("my_tab"))
- assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.contains("/tmp/file"))
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ test("create/replace table - with location") {
+ val createSql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
+ val replaceSql = "REPLACE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
+ val expectedTableSpec = TableSpec(
+ Seq("my_tab"),
+ Some(new StructType().add("a", IntegerType).add("b", StringType)),
+ Seq.empty[Transform],
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ Some("/tmp/file"),
+ None)
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("create table - byte length literal table name") {
- val sql = "CREATE TABLE 1m.2g(a INT) USING parquet"
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("1m", "2g"))
- assert(create.tableSchema == new StructType().add("a", IntegerType))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ test("create/replace table - byte length literal table name") {
+ val createSql = "CREATE TABLE 1m.2g(a INT) USING parquet"
+ val replaceSql = "REPLACE TABLE 1m.2g(a INT) USING parquet"
+ val expectedTableSpec = TableSpec(
+ Seq("1m", "2g"),
+ Some(new StructType().add("a", IntegerType)),
+ Seq.empty[Transform],
+ None,
+ Map.empty[String, String],
+ "parquet",
+ Map.empty[String, String],
+ None,
+ None)
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = false)
}
}
- test("Duplicate clauses - create table") {
+ test("Duplicate clauses - create/replace table") {
def createTableHeader(duplicateClause: String): String = {
s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause"
}
+ def replaceTableHeader(duplicateClause: String): String = {
+ s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause"
+ }
+
intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"),
"Found duplicate clauses: TBLPROPERTIES")
intercept(createTableHeader("LOCATION '/tmp/file'"),
@@ -293,31 +279,44 @@ class DDLParserSuite extends AnalysisTest {
"Found duplicate clauses: CLUSTERED BY")
intercept(createTableHeader("PARTITIONED BY (b)"),
"Found duplicate clauses: PARTITIONED BY")
+
+ intercept(replaceTableHeader("TBLPROPERTIES('test' = 'test2')"),
+ "Found duplicate clauses: TBLPROPERTIES")
+ intercept(replaceTableHeader("LOCATION '/tmp/file'"),
+ "Found duplicate clauses: LOCATION")
+ intercept(replaceTableHeader("COMMENT 'a table'"),
+ "Found duplicate clauses: COMMENT")
+ intercept(replaceTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"),
+ "Found duplicate clauses: CLUSTERED BY")
+ intercept(replaceTableHeader("PARTITIONED BY (b)"),
+ "Found duplicate clauses: PARTITIONED BY")
}
test("support for other types in OPTIONS") {
- val sql =
+ val createSql =
"""
|CREATE TABLE table_name USING json
|OPTIONS (a 1, b 0.1, c TRUE)
""".stripMargin
-
- parsePlan(sql) match {
- case create: CreateTableStatement =>
- assert(create.tableName == Seq("table_name"))
- assert(create.tableSchema == new StructType)
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties.isEmpty)
- assert(create.provider == "json")
- assert(create.options == Map("a" -> "1", "b" -> "0.1", "c" -> "true"))
- assert(create.location.isEmpty)
- assert(create.comment.isEmpty)
- assert(!create.ifNotExists)
-
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
+ val replaceSql =
+ """
+ |REPLACE TABLE table_name USING json
+ |OPTIONS (a 1, b 0.1, c TRUE)
+ """.stripMargin
+ Seq(createSql, replaceSql).foreach { sql =>
+ testCreateOrReplaceDdl(
+ sql,
+ TableSpec(
+ Seq("table_name"),
+ Some(new StructType),
+ Seq.empty[Transform],
+ Option.empty[BucketSpec],
+ Map.empty[String, String],
+ "json",
+ Map("a" -> "1", "b" -> "0.1", "c" -> "true"),
+ None,
+ None),
+ expectedIfNotExists = false)
}
}
@@ -352,27 +351,28 @@ class DDLParserSuite extends AnalysisTest {
|AS SELECT * FROM src
""".stripMargin
- checkParsing(s1)
- checkParsing(s2)
- checkParsing(s3)
-
- def checkParsing(sql: String): Unit = {
- parsePlan(sql) match {
- case create: CreateTableAsSelectStatement =>
- assert(create.tableName == Seq("mydb", "page_view"))
- assert(create.partitioning.isEmpty)
- assert(create.bucketSpec.isEmpty)
- assert(create.properties == Map("p1" -> "v1", "p2" -> "v2"))
- assert(create.provider == "parquet")
- assert(create.options.isEmpty)
- assert(create.location.contains("/user/external/page_view"))
- assert(create.comment.contains("This is the staging page view table"))
- assert(create.ifNotExists)
+ val s4 =
+ """
+ |REPLACE TABLE mydb.page_view
+ |USING parquet
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableAsSelectStatement].getClass.getName} " +
- s"from query, got ${other.getClass.getName}: $sql")
- }
+ val expectedTableSpec = TableSpec(
+ Seq("mydb", "page_view"),
+ None,
+ Seq.empty[Transform],
+ None,
+ Map("p1" -> "v1", "p2" -> "v2"),
+ "parquet",
+ Map.empty[String, String],
+ Some("/user/external/page_view"),
+ Some("This is the staging page view table"))
+ Seq(s1, s2, s3, s4).foreach { sql =>
+ testCreateOrReplaceDdl(sql, expectedTableSpec, expectedIfNotExists = true)
}
}
@@ -403,6 +403,28 @@ class DDLParserSuite extends AnalysisTest {
parseCompare(s"DROP VIEW IF EXISTS view", DropViewStatement(Seq("view"), ifExists = true))
}
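+
+ // Parses the statement, verifies that the plan type matches the leading
+ // CREATE/REPLACE keyword, and compares the parsed table metadata against
+ // the expected TableSpec.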
+ private def testCreateOrReplaceDdl(
+ sqlStatement: String,
+ tableSpec: TableSpec,
+ expectedIfNotExists: Boolean) {
+ val parsedPlan = parsePlan(sqlStatement)
+ val newTableToken = sqlStatement.split(" ")(0).trim.toUpperCase(Locale.ROOT)
+ parsedPlan match {
+ case create: CreateTableStatement if newTableToken == "CREATE" =>
+ assert(create.ifNotExists == expectedIfNotExists)
+ case ctas: CreateTableAsSelectStatement if newTableToken == "CREATE" =>
+ assert(ctas.ifNotExists == expectedIfNotExists)
+ case replace: ReplaceTableStatement if newTableToken == "REPLACE" =>
+ case replace: ReplaceTableAsSelectStatement if newTableToken == "REPLACE" =>
+ case other =>
+ fail("First token in statement does not match the expected parsed plan; CREATE TABLE" +
+ " should create a CreateTableStatement, and REPLACE TABLE should create a" +
+ s" ReplaceTableStatement. Statement: $sqlStatement, plan type:" +
+ s" ${parsedPlan.getClass.getName}.")
+ }
+ assert(TableSpec(parsedPlan) === tableSpec)
+ }
+
// ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment);
// ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
test("alter view: alter view properties") {
@@ -593,4 +615,69 @@ class DDLParserSuite extends AnalysisTest {
Seq(Seq("x"), Seq("y"), Seq("a", "b", "c"))))
}
}
+
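+ // Normalized view of the parsed plan so that CREATE, CTAS, REPLACE, and
+ // RTAS statements can all be checked against a single expected value.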
+ private case class TableSpec(
+ name: Seq[String],
+ schema: Option[StructType],
+ partitioning: Seq[Transform],
+ bucketSpec: Option[BucketSpec],
+ properties: Map[String, String],
+ provider: String,
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String])
+
+ private object TableSpec {
+ def apply(plan: LogicalPlan): TableSpec = {
+ plan match {
+ case create: CreateTableStatement =>
+ TableSpec(
+ create.tableName,
+ Some(create.tableSchema),
+ create.partitioning,
+ create.bucketSpec,
+ create.properties,
+ create.provider,
+ create.options,
+ create.location,
+ create.comment)
+ case replace: ReplaceTableStatement =>
+ TableSpec(
+ replace.tableName,
+ Some(replace.tableSchema),
+ replace.partitioning,
+ replace.bucketSpec,
+ replace.properties,
+ replace.provider,
+ replace.options,
+ replace.location,
+ replace.comment)
+ case ctas: CreateTableAsSelectStatement =>
+ TableSpec(
+ ctas.tableName,
+ Some(ctas.asSelect).filter(_.resolved).map(_.schema),
+ ctas.partitioning,
+ ctas.bucketSpec,
+ ctas.properties,
+ ctas.provider,
+ ctas.options,
+ ctas.location,
+ ctas.comment)
+ case rtas: ReplaceTableAsSelectStatement =>
+ TableSpec(
+ rtas.tableName,
+ Some(rtas.asSelect).filter(_.resolved).map(_.schema),
+ rtas.partitioning,
+ rtas.bucketSpec,
+ rtas.properties,
+ rtas.provider,
+ rtas.options,
+ rtas.location,
+ rtas.comment)
+ case other =>
+ fail(s"Expected to parse Create, CTAS, Replace, or RTAS plan" +
+ s" from query, got ${other.getClass.getName}.")
+ }
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 2c491cd376edc..1a6286067a618 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -58,6 +58,7 @@ class DataTypeParserSuite extends SparkFunSuite {
checkDataType("varchAr(20)", StringType)
checkDataType("cHaR(27)", StringType)
checkDataType("BINARY", BinaryType)
+ checkDataType("interval", CalendarIntervalType)
checkDataType("array", ArrayType(DoubleType, true))
checkDataType("Array |