diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index aa72576c8f..816df79e50 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -19,8 +19,7 @@ }, "CANNOT_DECODE_URL" : { "message" : [ - "Cannot decode url : .", - "
" + "Cannot decode url : ." ], "sqlState" : "42000" }, diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d01de3b9ed..ab2bd1b780 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -772,6 +772,7 @@ private[spark] class Executor( uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { + cleanMDCForTask(taskName, mdcProperties) runningTasks.remove(taskId) if (taskStarted) { // This means the task was successfully deserialized, its stageId and stageAttemptId @@ -788,8 +789,6 @@ private[spark] class Executor( private def setMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = { try { - // make sure we run the task with the user-specified mdc properties only - MDC.clear() mdc.foreach { case (key, value) => MDC.put(key, value) } // avoid overriding the takName by the user MDC.put("mdc.taskName", taskName) @@ -798,6 +797,15 @@ private[spark] class Executor( } } + private def cleanMDCForTask(taskName: String, mdc: Seq[(String, String)]): Unit = { + try { + mdc.foreach { case (key, _) => MDC.remove(key) } + MDC.remove("mdc.taskName") + } catch { + case _: NoSuchFieldError => logInfo("MDC is not supported.") + } + } + /** * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally * sending a Thread.interrupt(), and monitoring the task until it finishes. @@ -897,6 +905,7 @@ private[spark] class Executor( } } } finally { + cleanMDCForTask(taskRunner.taskName, taskRunner.mdcProperties) // Clean up entries in the taskReaperForTask map. taskReaperForTask.synchronized { taskReaperForTask.get(taskId).foreach { taskReaperInMap => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9d1a56843c..07d3d3e077 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1956,6 +1956,13 @@ package object config { .intConf .createWithDefault(10) + private[spark] val RDD_LIMIT_INITIAL_NUM_PARTITIONS = + ConfigBuilder("spark.rdd.limit.initialNumPartitions") + .version("3.4.0") + .intConf + .checkValue(_ > 0, "value should be positive") + .createWithDefault(1) + private[spark] val RDD_LIMIT_SCALE_UP_FACTOR = ConfigBuilder("spark.rdd.limit.scaleUpFactor") .version("2.1.0") diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index d6379156cc..9f89c82db3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -25,6 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_LIMIT_SCALE_UP_FACTOR} import org.apache.spark.util.ThreadUtils /** @@ -72,6 +73,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val results = new ArrayBuffer[T] val totalParts = self.partitions.length + val scaleUpFactor = Math.max(self.conf.get(RDD_LIMIT_SCALE_UP_FACTOR), 2) + /* Recursively triggers jobs to scan partitions until either the requested number of elements are retrieved, or the partitions to scan are 
exhausted. @@ -84,18 +87,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1L + var numPartsToTry = self.conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS) if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. - // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. We also cap the estimation in the end. - if (results.size == 0) { - numPartsToTry = partsScanned * 4L + // If we didn't find any rows after the previous iteration, multiply by + // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need + // to try, but overestimate it by 50%. We also cap the estimation in the end. + if (results.isEmpty) { + numPartsToTry = partsScanned * scaleUpFactor } else { // the left side of max is >=1 whenever partsScanned >= 2 numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4L) + numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b7284d2512..461510b252 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -547,7 +547,6 @@ abstract class RDD[T: ClassTag]( s"Fraction must be nonnegative, but got ${fraction}") withScope { - require(fraction >= 0.0, "Negative fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) } else { @@ -1445,12 +1444,12 @@ abstract class RDD[T: ClassTag]( while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1L + var numPartsToTry = conf.get(RDD_LIMIT_INITIAL_NUM_PARTITIONS) val left = num - buf.size if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. - // Otherwise, interpolate the number of partitions we need to try, but overestimate - // it by 50%. We also cap the estimation in the end. + // If we didn't find any rows after the previous iteration, multiply by + // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need + // to try, but overestimate it by 50%. We also cap the estimation in the end. if (buf.isEmpty) { numPartsToTry = partsScanned * scaleUpFactor } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 6138611499..6bb5058f5e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -250,4 +250,13 @@ class BitSet(numBits: Int) extends Serializable { /** Return the number of longs it would take to hold numBits. 
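The spark.rdd.limit.initialNumPartitions setting introduced above controls how many partitions the first take()/takeAsync() job scans before the scale-up logic kicks in. A minimal PySpark sketch of how a user might tune it (the app name and values are illustrative, not part of the change):

from pyspark.sql import SparkSession

# Scan 32 partitions in the first job instead of the default 1, reducing the
# number of follow-up jobs when take() needs rows from many partitions.
spark = (
    SparkSession.builder
    .appName("initial-num-partitions-demo")               # illustrative name
    .config("spark.rdd.limit.initialNumPartitions", "32")
    .config("spark.rdd.limit.scaleUpFactor", "4")          # pre-existing companion setting
    .getOrCreate()
)

rdd = spark.sparkContext.parallelize(range(1000), 100)
print(rdd.take(50))  # typically launches fewer jobs than with the default of 1

spark.stop()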
*/ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 + + override def equals(other: Any): Boolean = other match { + case otherSet: BitSet => Arrays.equals(words, otherSet.words) + case _ => false + } + + override def hashCode(): Int = { + Arrays.hashCode(words) + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ccef00c8e9..c64573f7a0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.rdd import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream} import java.lang.management.ManagementFactory +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -32,8 +33,9 @@ import org.scalatest.concurrent.Eventually import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.internal.config.RDD_PARALLEL_LISTING_THRESHOLD +import org.apache.spark.internal.config.{RDD_LIMIT_INITIAL_NUM_PARTITIONS, RDD_PARALLEL_LISTING_THRESHOLD} import org.apache.spark.rdd.RDDSuiteUtils._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.util.{ThreadUtils, Utils} class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { @@ -1255,6 +1257,41 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { assert(numPartsPerLocation(locations(1)) > 0.4 * numCoalescedPartitions) } + test("SPARK-40211: customize initialNumPartitions for take") { + val totalElements = 100 + val numToTake = 50 + val rdd = sc.parallelize(0 to totalElements, totalElements) + import scala.language.reflectiveCalls + val jobCountListener = new SparkListener { + private var count: AtomicInteger = new AtomicInteger(0) + def getCount: Int = count.get + def reset(): Unit = count.set(0) + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + count.incrementAndGet() + } + } + sc.addSparkListener(jobCountListener) + // with default RDD_LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs + rdd.take(numToTake) + sc.listenerBus.waitUntilEmpty() + assert(jobCountListener.getCount > 1) + jobCountListener.reset() + rdd.takeAsync(numToTake).get() + sc.listenerBus.waitUntilEmpty() + assert(jobCountListener.getCount > 1) + + // setting RDD_LIMIT_INITIAL_NUM_PARTITIONS to large number(1000), expecting only 1 job + sc.conf.set(RDD_LIMIT_INITIAL_NUM_PARTITIONS, 1000) + jobCountListener.reset() + rdd.take(numToTake) + sc.listenerBus.waitUntilEmpty() + assert(jobCountListener.getCount == 1) + jobCountListener.reset() + rdd.takeAsync(numToTake).get() + sc.listenerBus.waitUntilEmpty() + assert(jobCountListener.getCount == 1) + } + // NOTE // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests // running after them and if they access sc those tests will fail as sc is already closed, because diff --git a/docs/sql-ref-functions-udf-hive.md b/docs/sql-ref-functions-udf-hive.md index 819c446c41..ed05902c09 100644 --- a/docs/sql-ref-functions-udf-hive.md +++ b/docs/sql-ref-functions-udf-hive.md @@ -52,10 +52,22 @@ SELECT testUDF(value) FROM t; | 2.0| | 3.0| +--------------+ + +-- Register `UDFSubstr` and use it in Spark SQL. +-- Note that, it can achieve better performance if the return types and method parameters use Java primitives. +-- e.g., UDFSubstr. 
The data processing method is UTF8String <-> Text <-> String. we can avoid UTF8String <-> Text. +CREATE TEMPORARY FUNCTION hive_substr AS 'org.apache.hadoop.hive.ql.udf.UDFSubstr'; + +select hive_substr('Spark SQL', 1, 5) as value; ++-----+ +|value| ++-----+ +|Spark| ++-----+ ``` -An example below uses [GenericUDTFExplode](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTFExplode.java) derived from [GenericUDTF](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDF.java). +An example below uses [GenericUDTFExplode](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTFExplode.java) derived from [GenericUDTF](https://github.com/apache/hive/blob/master/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDTF.java). ```sql -- Register `GenericUDTFExplode` and use it in Spark SQL diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index f4e0557ea7..1cded76057 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -213,8 +213,10 @@ true - src/hadoop-3/main/scala - src/hadoop-3/test/scala + src/hadoop-3/main/java + src/hadoop-3/test/java + src/hadoop-3/main/scala + src/hadoop-3/test/scala @@ -240,26 +242,28 @@ build-helper-maven-plugin - add-scala-sources + add-extra-sources generate-sources add-source - ${extra.source.dir} + ${extra.java.source.dir} + ${extra.scala.source.dir} - add-scala-test-sources + add-extra-test-sources generate-test-sources add-test-source - ${extra.testsource.dir} + ${extra.java.testsource.dir} + ${extra.scala.testsource.dir} diff --git a/hadoop-cloud/src/hadoop-3/test/java/org/apache/spark/internal/io/cloud/abortable/AbortableFileSystem.java b/hadoop-cloud/src/hadoop-3/test/java/org/apache/spark/internal/io/cloud/abortable/AbortableFileSystem.java new file mode 100644 index 0000000000..5c7f68f437 --- /dev/null +++ b/hadoop-cloud/src/hadoop-3/test/java/org/apache/spark/internal/io/cloud/abortable/AbortableFileSystem.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.internal.io.cloud.abortable; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.net.URI; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.hadoop.fs.*; +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.util.Progressable; + +public class AbortableFileSystem extends RawLocalFileSystem { + + public static String ABORTABLE_FS_SCHEME = "abortable"; + + @Override + public URI getUri() { + return URI.create(ABORTABLE_FS_SCHEME + ":///"); + } + + public FSDataOutputStream create(Path f, FsPermission permission, boolean overwrite, + int bufferSize, short replication, long blockSize, Progressable progress) throws IOException { + FSDataOutputStream out = this.create(f, overwrite, bufferSize, replication, blockSize, + progress, permission); + return out; + } + + private FSDataOutputStream create(Path f, boolean overwrite, int bufferSize, short replication, + long blockSize, Progressable progress, FsPermission permission) throws IOException { + if (this.exists(f) && !overwrite) { + throw new FileAlreadyExistsException("File already exists: " + f); + } else { + Path parent = f.getParent(); + if (parent != null && !this.mkdirs(parent)) { + throw new IOException("Mkdirs failed to create " + parent.toString()); + } else { + return new FSDataOutputStream(this.createOutputStreamWithMode(f, false, permission), null); + } + } + } + + @Override + protected OutputStream createOutputStreamWithMode(Path f, boolean append, + FsPermission permission) throws IOException { + return new AbortableOutputStream(f, append, permission); + } + + class AbortableOutputStream extends ByteArrayOutputStream + implements Abortable, StreamCapabilities { + + private final AtomicBoolean closed = new AtomicBoolean(false); + + private Path f; + + private boolean append; + + private FsPermission permission; + + AbortableOutputStream(Path f, boolean append, FsPermission permission) { + this.f = f; + this.append = append; + this.permission = permission; + } + + @Override + public void close() throws IOException { + if (closed.getAndSet(true)) { + return; + } + + OutputStream output = + AbortableFileSystem.super.createOutputStreamWithMode(f, append, permission); + writeTo(output); + output.close(); + } + + @Override + public AbortableResult abort() { + final boolean isAlreadyClosed = closed.getAndSet(true); + return new AbortableResult() { + public boolean alreadyClosed() { + return isAlreadyClosed; + } + + public IOException anyCleanupException() { + return null; + } + }; + } + + @Override + public boolean hasCapability(String capability) { + return capability == CommonPathCapabilities.ABORTABLE_STREAM; + } + } +} diff --git a/hadoop-cloud/src/hadoop-3/test/java/org/apache/spark/internal/io/cloud/abortable/AbstractAbortableFileSystem.java b/hadoop-cloud/src/hadoop-3/test/java/org/apache/spark/internal/io/cloud/abortable/AbstractAbortableFileSystem.java new file mode 100644 index 0000000000..57ede38a23 --- /dev/null +++ b/hadoop-cloud/src/hadoop-3/test/java/org/apache/spark/internal/io/cloud/abortable/AbstractAbortableFileSystem.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.internal.io.cloud.abortable; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.CommonPathCapabilities; +import org.apache.hadoop.fs.DelegateToFileSystem; +import org.apache.hadoop.fs.Path; + +public class AbstractAbortableFileSystem extends DelegateToFileSystem { + + public AbstractAbortableFileSystem( + URI theUri, + Configuration conf) throws IOException, URISyntaxException { + super(theUri, new AbortableFileSystem(), conf, AbortableFileSystem.ABORTABLE_FS_SCHEME, false); + } + + @Override + public boolean hasPathCapability(Path path, String capability) throws IOException { + if (capability == CommonPathCapabilities.ABORTABLE_STREAM) { + return true; + } else { + return super.hasPathCapability(path, capability); + } + } +} diff --git a/hadoop-cloud/src/hadoop-3/test/resources/log4j2.properties b/hadoop-cloud/src/hadoop-3/test/resources/log4j2.properties new file mode 100644 index 0000000000..01a9cafafa --- /dev/null +++ b/hadoop-cloud/src/hadoop-3/test/resources/log4j2.properties @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +rootLogger.level = info +rootLogger.appenderRef.file.ref = ${sys:test.appender:-File} + +appender.file.type = File +appender.file.name = File +appender.file.fileName = target/unit-tests.log +appender.file.append = true +appender.file.layout.type = PatternLayout +appender.file.layout.pattern = %d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Tests that launch java subprocesses can set the "test.appender" system property to +# "console" to avoid having the child process's logs overwrite the unit test's +# log file. 
+appender.console.type = Console +appender.console.name = STDERR +appender.console.target = SYSTEM_ERR +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %t: %m%n%ex + +# Ignore messages below warning level from Jetty, because it's a bit verbose +logger.jetty.name = org.spark_project.jetty +logger.jetty.level = warn diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b631f141a8..5fe463233a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1122,6 +1122,7 @@ def takeSample( Examples -------- + >>> import sys >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) 20 @@ -1129,12 +1130,19 @@ def takeSample( 5 >>> len(rdd.takeSample(False, 15, 3)) 10 + >>> sc.range(0, 10).takeSample(False, sys.maxsize) + Traceback (most recent call last): + ... + ValueError: Sample size cannot be greater than ... """ numStDev = 10.0 - + maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize)) if num < 0: raise ValueError("Sample size cannot be negative.") - elif num == 0: + elif num > maxSampleSize: + raise ValueError("Sample size cannot be greater than %d." % maxSampleSize) + + if num == 0 or self.getNumPartitions() == 0: return [] initialCount = self.count() @@ -1149,10 +1157,6 @@ def takeSample( rand.shuffle(samples) return samples - maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize)) - if num > maxSampleSize: - raise ValueError("Sample size cannot be greater than %d." % maxSampleSize) - fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement) samples = self.sample(withReplacement, fraction, seed).collect() diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fd7a7247fc..03c16db602 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1938,6 +1938,13 @@ def degrees(col: "ColumnOrName") -> Column: ------- :class:`~pyspark.sql.Column` angle in degrees, as if computed by `java.lang.Math.toDegrees()` + + Examples + -------- + >>> import math + >>> df = spark.range(1) + >>> df.select(degrees(lit(math.pi))).first() + Row(DEGREES(3.14159...)=180.0) """ return _invoke_function_over_columns("degrees", col) @@ -1958,6 +1965,12 @@ def radians(col: "ColumnOrName") -> Column: ------- :class:`~pyspark.sql.Column` angle in radians, as if computed by `java.lang.Math.toRadians()` + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(radians(lit(180))).first() + Row(RADIANS(180)=3.14159...) """ return _invoke_function_over_columns("radians", col) @@ -1996,6 +2009,12 @@ def atan2(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] in polar coordinates that corresponds to the point (`x`, `y`) in Cartesian coordinates, as if computed by `java.lang.Math.atan2()` + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(atan2(lit(1), lit(2))).first() + Row(ATAN2(1, 2)=0.46364...) """ return _invoke_binary_math_function("atan2", col1, col2) @@ -2020,6 +2039,24 @@ def hypot(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float] Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow. .. versionadded:: 1.4.0 + + Parameters + ---------- + col1 : str, :class:`~pyspark.sql.Column` or float + a leg. + col2 : str, :class:`~pyspark.sql.Column` or float + b leg. + + Returns + ------- + :class:`~pyspark.sql.Column` + length of the hypotenuse. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(hypot(lit(1), lit(2))).first() + Row(HYPOT(1, 2)=2.23606...) 
""" return _invoke_binary_math_function("hypot", col1, col2) @@ -2044,6 +2081,24 @@ def pow(col1: Union["ColumnOrName", float], col2: Union["ColumnOrName", float]) Returns the value of the first argument raised to the power of the second argument. .. versionadded:: 1.4.0 + + Parameters + ---------- + col1 : str, :class:`~pyspark.sql.Column` or float + the base number. + col2 : str, :class:`~pyspark.sql.Column` or float + the exponent number. + + Returns + ------- + :class:`~pyspark.sql.Column` + the base rased to the power the argument. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(pow(lit(3), lit(2))).first() + Row(POWER(3, 2)=9.0) """ return _invoke_binary_math_function("pow", col1, col2) @@ -2061,6 +2116,11 @@ def pmod(dividend: Union["ColumnOrName", float], divisor: Union["ColumnOrName", divisor : str, :class:`~pyspark.sql.Column` or float the column that contains divisor, or the specified divisor value + Returns + ------- + :class:`~pyspark.sql.Column` + positive value of dividend mod divisor. + Examples -------- >>> from pyspark.sql.functions import pmod @@ -2092,6 +2152,25 @@ def row_number() -> Column: Window function: returns a sequential number starting at 1 within a window partition. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for calculating row numbers. + + Examples + -------- + >>> from pyspark.sql import Window + >>> df = spark.range(3) + >>> w = Window.orderBy(df.id.desc()) + >>> df.withColumn("desc_order", row_number().over(w)).show() + +---+----------+ + | id|desc_order| + +---+----------+ + | 2| 1| + | 1| 2| + | 0| 3| + +---+----------+ """ return _invoke_function("row_number") @@ -2109,6 +2188,28 @@ def dense_rank() -> Column: This is equivalent to the DENSE_RANK function in SQL. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for calculating ranks. + + Examples + -------- + >>> from pyspark.sql import Window, types + >>> df = spark.createDataFrame([1, 1, 2, 3, 3, 4], types.IntegerType()) + >>> w = Window.orderBy("value") + >>> df.withColumn("drank", dense_rank().over(w)).show() + +-----+-----+ + |value|drank| + +-----+-----+ + | 1| 1| + | 1| 1| + | 2| 2| + | 3| 3| + | 3| 3| + | 4| 4| + +-----+-----+ """ return _invoke_function("dense_rank") @@ -2126,6 +2227,28 @@ def rank() -> Column: This is equivalent to the RANK function in SQL. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for calculating ranks. + + Examples + -------- + >>> from pyspark.sql import Window, types + >>> df = spark.createDataFrame([1, 1, 2, 3, 3, 4], types.IntegerType()) + >>> w = Window.orderBy("value") + >>> df.withColumn("drank", rank().over(w)).show() + +-----+-----+ + |value|drank| + +-----+-----+ + | 1| 1| + | 1| 1| + | 2| 3| + | 3| 4| + | 3| 4| + | 4| 6| + +-----+-----+ """ return _invoke_function("rank") @@ -2136,6 +2259,27 @@ def cume_dist() -> Column: i.e. the fraction of rows that are below the current row. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for calculating cumulative distribution. 
+ + Examples + -------- + >>> from pyspark.sql import Window, types + >>> df = spark.createDataFrame([1, 2, 3, 3, 4], types.IntegerType()) + >>> w = Window.orderBy("value") + >>> df.withColumn("cd", cume_dist().over(w)).show() + +-----+---+ + |value| cd| + +-----+---+ + | 1|0.2| + | 2|0.4| + | 3|0.8| + | 3|0.8| + | 4|1.0| + +-----+---+ """ return _invoke_function("cume_dist") @@ -2145,6 +2289,28 @@ def percent_rank() -> Column: Window function: returns the relative rank (i.e. percentile) of rows within a window partition. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for calculating relative rank. + + Examples + -------- + >>> from pyspark.sql import Window, types + >>> df = spark.createDataFrame([1, 1, 2, 3, 3, 4], types.IntegerType()) + >>> w = Window.orderBy("value") + >>> df.withColumn("pr", percent_rank().over(w)).show() + +-----+---+ + |value| pr| + +-----+---+ + | 1|0.0| + | 1|0.0| + | 2|0.4| + | 3|0.6| + | 3|0.6| + | 4|1.0| + +-----+---+ """ return _invoke_function("percent_rank") @@ -2189,6 +2355,25 @@ def broadcast(df: DataFrame) -> DataFrame: Marks a DataFrame as small enough for use in broadcast joins. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.DataFrame` + DataFrame marked as ready for broadcast join. + + Examples + -------- + >>> from pyspark.sql import types + >>> df = spark.createDataFrame([1, 2, 3, 3, 4], types.IntegerType()) + >>> df_small = spark.range(3) + >>> df_b = broadcast(df_small) + >>> df.join(df_b, df.value == df_small.id).show() + +-----+---+ + |value| id| + +-----+---+ + | 1| 1| + | 2| 2| + +-----+---+ """ sc = SparkContext._active_spark_context @@ -2201,6 +2386,16 @@ def coalesce(*cols: "ColumnOrName") -> Column: .. versionadded:: 1.4.0 + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + list of columns to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + value of the first column that is not null. + Examples -------- >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) @@ -2240,6 +2435,18 @@ def corr(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: .. versionadded:: 1.6.0 + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + first column to calculate correlation. + col1 : :class:`~pyspark.sql.Column` or str + second column to calculate correlation. + + Returns + ------- + :class:`~pyspark.sql.Column` + Pearson Correlation Coefficient of these two column values. + Examples -------- >>> a = range(20) @@ -2257,6 +2464,18 @@ def covar_pop(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: .. versionadded:: 2.0.0 + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + first column to calculate covariance. + col1 : :class:`~pyspark.sql.Column` or str + second column to calculate covariance. + + Returns + ------- + :class:`~pyspark.sql.Column` + covariance of these two column values. + Examples -------- >>> a = [1] * 10 @@ -2274,6 +2493,18 @@ def covar_samp(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: .. versionadded:: 2.0.0 + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + first column to calculate covariance. + col1 : :class:`~pyspark.sql.Column` or str + second column to calculate covariance. + + Returns + ------- + :class:`~pyspark.sql.Column` + sample covariance of these two column values. + Examples -------- >>> a = [1] * 10 @@ -2301,13 +2532,40 @@ def count_distinct(col: "ColumnOrName", *cols: "ColumnOrName") -> Column: .. 
versionadded:: 3.2.0 - Examples - -------- - >>> df.agg(count_distinct(df.age, df.name).alias('c')).collect() - [Row(c=2)] + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + first column to compute on. + cols : :class:`~pyspark.sql.Column` or str + other columns to compute on. - >>> df.agg(count_distinct("age", "name").alias('c')).collect() - [Row(c=2)] + Returns + ------- + :class:`~pyspark.sql.Column` + distinct values of these two column values. + + Examples + -------- + >>> from pyspark.sql import types + >>> df1 = spark.createDataFrame([1, 1, 3], types.IntegerType()) + >>> df2 = spark.createDataFrame([1, 2], types.IntegerType()) + >>> df1.join(df2).show() + +-----+-----+ + |value|value| + +-----+-----+ + | 1| 1| + | 1| 2| + | 1| 1| + | 1| 2| + | 3| 1| + | 3| 2| + +-----+-----+ + >>> df1.join(df2).select(count_distinct(df1.value, df2.value)).show() + +----------------------------+ + |count(DISTINCT value, value)| + +----------------------------+ + | 4| + +----------------------------+ """ sc = SparkContext._active_spark_context assert sc is not None and sc._jvm is not None @@ -2329,13 +2587,36 @@ def first(col: "ColumnOrName", ignorenulls: bool = False) -> Column: The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column to fetch first value for. + ignorenulls : :class:`~pyspark.sql.Column` or str + if first value is null then look for first non-null value. + + Returns + ------- + :class:`~pyspark.sql.Column` + first value of the group. + Examples -------- - >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) + >>> df = df.orderBy(df.age) >>> df.groupby("name").agg(first("age")).orderBy("name").show() +-----+----------+ | name|first(age)| +-----+----------+ + |Alice| null| + | Bob| 5| + +-----+----------+ + + Now, to ignore any nulls we needs to set ``ignorenulls`` to `True` + + >>> df.groupby("name").agg(first("age", ignorenulls=True)).orderBy("name").show() + +-----+----------+ + | name|first(age)| + +-----+----------+ |Alice| 2| | Bob| 5| +-----+----------+ @@ -2350,6 +2631,16 @@ def grouping(col: "ColumnOrName") -> Column: .. versionadded:: 2.0.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column to check if it's aggregated. + + Returns + ------- + :class:`~pyspark.sql.Column` + returns 1 for aggregated or 0 for not aggregated in the result set. + Examples -------- >>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show() @@ -2377,16 +2668,33 @@ def grouping_id(*cols: "ColumnOrName") -> Column: The list of columns should match with grouping columns exactly, or empty (means all the grouping columns). - Examples - -------- - >>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show() - +-----+-------------+--------+ - | name|grouping_id()|sum(age)| - +-----+-------------+--------+ - | null| 1| 7| - |Alice| 0| 2| - | Bob| 0| 5| - +-----+-------------+--------+ + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + columns to check for. + + Returns + ------- + :class:`~pyspark.sql.Column` + returns level of the grouping it relates to. + + Examples + -------- + >>> df = spark.createDataFrame([(1, "a", "a"), + ... (3, "a", "a"), + ... 
(4, "b", "c")], ["c1", "c2", "c3"]) + >>> df.cube("c2", "c3").agg(grouping_id(), sum("c1")).orderBy("c2", "c3").show() + +----+----+-------------+-------+ + | c2| c3|grouping_id()|sum(c1)| + +----+----+-------------+-------+ + |null|null| 3| 8| + |null| a| 2| 4| + |null| c| 2| 4| + | a|null| 1| 4| + | a| a| 0| 4| + | b|null| 1| 4| + | b| c| 0| 4| + +----+----+-------------+-------+ """ return _invoke_function_over_seq_of_columns("grouping_id", cols) @@ -2396,34 +2704,77 @@ def input_file_name() -> Column: Creates a string column for the file name of the current Spark task. .. versionadded:: 1.6.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + file names. + + Examples + -------- + >>> import os + >>> path = os.path.abspath(__file__) + >>> df = spark.read.text(path) + >>> df.select(input_file_name()).first() + Row(input_file_name()='file:///...') """ return _invoke_function("input_file_name") def isnan(col: "ColumnOrName") -> Column: - """An expression that returns true iff the column is NaN. + """An expression that returns true if the column is NaN. .. versionadded:: 1.6.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + True if value is NaN and False otherwise. + Examples -------- >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) - >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() - [Row(r1=False, r2=False), Row(r1=True, r2=True)] + >>> df.select("a", "b", isnan("a").alias("r1"), isnan(df.b).alias("r2")).show() + +---+---+-----+-----+ + | a| b| r1| r2| + +---+---+-----+-----+ + |1.0|NaN|false| true| + |NaN|2.0| true|false| + +---+---+-----+-----+ """ return _invoke_function_over_columns("isnan", col) def isnull(col: "ColumnOrName") -> Column: - """An expression that returns true iff the column is null. + """An expression that returns true if the column is null. .. versionadded:: 1.6.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + True if value is null and False otherwise. + Examples -------- >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b")) - >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() - [Row(r1=False, r2=False), Row(r1=True, r2=True)] + >>> df.select("a", "b", isnull("a").alias("r1"), isnull(df.b).alias("r2")).show() + +----+----+-----+-----+ + | a| b| r1| r2| + +----+----+-----+-----+ + | 1|null|false| true| + |null| 2| true|false| + +----+----+-----+-----+ """ return _invoke_function_over_columns("isnull", col) @@ -2440,6 +2791,40 @@ def last(col: "ColumnOrName", ignorenulls: bool = False) -> Column: ----- The function is non-deterministic because its results depends on the order of the rows which may be non-deterministic after a shuffle. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column to fetch last value for. + ignorenulls : :class:`~pyspark.sql.Column` or str + if last value is null then look for non-null value. + + Returns + ------- + :class:`~pyspark.sql.Column` + last value of the group. 
+ + Examples + -------- + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5), ("Alice", None)], ("name", "age")) + >>> df = df.orderBy(df.age.desc()) + >>> df.groupby("name").agg(last("age")).orderBy("name").show() + +-----+---------+ + | name|last(age)| + +-----+---------+ + |Alice| null| + | Bob| 5| + +-----+---------+ + + Now, to ignore any nulls we needs to set ``ignorenulls`` to `True` + + >>> df.groupby("name").agg(last("age", ignorenulls=True)).orderBy("name").show() + +-----+---------+ + | name|last(age)| + +-----+---------+ + |Alice| 2| + | Bob| 5| + +-----+---------+ """ return _invoke_function("last", _to_java_column(col), ignorenulls) @@ -2462,6 +2847,13 @@ def monotonically_increasing_id() -> Column: This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + Returns + ------- + :class:`~pyspark.sql.Column` + last value of the group. + + Examples + -------- >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) >>> df0.select(monotonically_increasing_id().alias('id')).collect() [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] @@ -2476,6 +2868,18 @@ def nanvl(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: .. versionadded:: 1.6.0 + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + first column to check. + col2 : :class:`~pyspark.sql.Column` or str + second column to return if first is NaN. + + Returns + ------- + :class:`~pyspark.sql.Column` + value from first column or second if first is NaN . + Examples -------- >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) @@ -2493,19 +2897,29 @@ def percentile_approx( """Returns the approximate `percentile` of the numeric column `col` which is the smallest value in the ordered `col` values (sorted from least to greatest) such that no more than `percentage` of `col` values is less than the value or equal to that value. - The value of percentage must be between 0.0 and 1.0. - - The accuracy parameter (default: 10000) - is a positive numeric literal which controls approximation accuracy at the cost of memory. - Higher value of accuracy yields better accuracy, 1.0/accuracy is the relative error - of the approximation. - When percentage is an array, each value of the percentage array must be between 0.0 and 1.0. - In this case, returns the approximate percentile array of column col - at the given percentage array. .. versionadded:: 3.1.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column. + percentage : :class:`~pyspark.sql.Column`, float, list of floats or tuple of floats + percentage in decimal (must be between 0.0 and 1.0). + When percentage is an array, each value of the percentage array must be between 0.0 and 1.0. + In this case, returns the approximate percentile array of column col + at the given percentage array. + accuracy : :class:`~pyspark.sql.Column` or float + is a positive numeric literal which controls approximation accuracy + at the cost of memory. Higher value of accuracy yields better accuracy, + 1.0/accuracy is the relative error of the approximation. (default: 10000). + + Returns + ------- + :class:`~pyspark.sql.Column` + approximate `percentile` of the numeric column. + Examples -------- >>> key = (col("id") % 3).alias("key") @@ -2559,6 +2973,16 @@ def rand(seed: Optional[int] = None) -> Column: ----- The function is non-deterministic in general case. 
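Complementing the percentile_approx parameter descriptions above, a short sketch (assuming an active SparkSession named spark; the column and accuracy values are illustrative) of the list-of-percentages form, which returns a single array column:

from pyspark.sql import SparkSession
from pyspark.sql.functions import percentile_approx, col

spark = SparkSession.builder.getOrCreate()
df = spark.range(100).withColumn("value", col("id") * 1.0)

# A list of percentages yields one array holding the approximate quartiles.
df.select(
    percentile_approx("value", [0.25, 0.5, 0.75], 10000).alias("quartiles")
).show(truncate=False)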
+ Parameters + ---------- + seed : int (default: None) + seed value for random generator. + + Returns + ------- + :class:`~pyspark.sql.Column` + random values. + Examples -------- >>> df.withColumn('rand', rand(seed=42) * 3).collect() @@ -2581,6 +3005,16 @@ def randn(seed: Optional[int] = None) -> Column: ----- The function is non-deterministic in general case. + Parameters + ---------- + seed : int (default: None) + seed value for random generator. + + Returns + ------- + :class:`~pyspark.sql.Column` + random values. + Examples -------- >>> df.withColumn('randn', randn(seed=42)).collect() @@ -2600,6 +3034,18 @@ def round(col: "ColumnOrName", scale: int = 0) -> Column: .. versionadded:: 1.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column to round. + scale : int optional default 0 + scale value. + + Returns + ------- + :class:`~pyspark.sql.Column` + rounded values. + Examples -------- >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() @@ -2615,6 +3061,18 @@ def bround(col: "ColumnOrName", scale: int = 0) -> Column: .. versionadded:: 2.0.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column to round. + scale : int optional default 0 + scale value. + + Returns + ------- + :class:`~pyspark.sql.Column` + rounded values. + Examples -------- >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() @@ -2640,6 +3098,18 @@ def shiftleft(col: "ColumnOrName", numBits: int) -> Column: .. versionadded:: 3.2.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column of values to shift. + numBits : int + number of bits to shift. + + Returns + ------- + :class:`~pyspark.sql.Column` + shifted value. + Examples -------- >>> spark.createDataFrame([(21,)], ['a']).select(shiftleft('a', 1).alias('r')).collect() @@ -2665,6 +3135,18 @@ def shiftright(col: "ColumnOrName", numBits: int) -> Column: .. versionadded:: 3.2.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column of values to shift. + numBits : int + number of bits to shift. + + Returns + ------- + :class:`~pyspark.sql.Column` + shifted values. + Examples -------- >>> spark.createDataFrame([(42,)], ['a']).select(shiftright('a', 1).alias('r')).collect() @@ -2690,6 +3172,18 @@ def shiftrightunsigned(col: "ColumnOrName", numBits: int) -> Column: .. versionadded:: 3.2.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column of values to shift. + numBits : int + number of bits to shift. + + Returns + ------- + :class:`~pyspark.sql.Column` + shifted value. + Examples -------- >>> df = spark.createDataFrame([(-42,)], ['a']) @@ -2708,6 +3202,11 @@ def spark_partition_id() -> Column: ----- This is non deterministic because it depends on data partitioning and task scheduling. + Returns + ------- + :class:`~pyspark.sql.Column` + partition id the record belongs to. + Examples -------- >>> df.repartition(1).select(spark_partition_id().alias("pid")).collect() @@ -2721,6 +3220,16 @@ def expr(str: str) -> Column: .. versionadded:: 1.5.0 + Parameters + ---------- + str : str + expression defined in string. + + Returns + ------- + :class:`~pyspark.sql.Column` + column representing the expression. + Examples -------- >>> df.select(expr("length(name)")).collect() @@ -2751,6 +3260,11 @@ def struct( cols : list, set, str or :class:`~pyspark.sql.Column` column names or :class:`~pyspark.sql.Column`\\s to contain in the output struct. 
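As a follow-up to the round and bround entries above, a small sketch (assuming an active SparkSession named spark) contrasting the two rounding modes on the same values:

from pyspark.sql import SparkSession
from pyspark.sql.functions import round, bround

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.5,), (2.5,), (3.5,)], ["a"])

# round uses HALF_UP (2.5 -> 3.0) while bround uses HALF_EVEN (2.5 -> 2.0).
df.select(
    "a",
    round("a", 0).alias("half_up"),
    bround("a", 0).alias("half_even"),
).show()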
+ Returns + ------- + :class:`~pyspark.sql.Column` + a struct type column of given columns. + Examples -------- >>> df.select(struct('age', 'name').alias("struct")).collect() @@ -2766,10 +3280,20 @@ def struct( def greatest(*cols: "ColumnOrName") -> Column: """ Returns the greatest value of the list of column names, skipping null values. - This function takes at least 2 parameters. It will return null iff all parameters are null. + This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + columns to check for gratest value. + + Returns + ------- + :class:`~pyspark.sql.Column` + gratest value. + Examples -------- >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) @@ -2784,7 +3308,7 @@ def greatest(*cols: "ColumnOrName") -> Column: def least(*cols: "ColumnOrName") -> Column: """ Returns the least value of the list of column names, skipping null values. - This function takes at least 2 parameters. It will return null iff all parameters are null. + This function takes at least 2 parameters. It will return null if all parameters are null. .. versionadded:: 1.5.0 @@ -2793,6 +3317,11 @@ def least(*cols: "ColumnOrName") -> Column: cols : :class:`~pyspark.sql.Column` or str column names or columns to be compared + Returns + ------- + :class:`~pyspark.sql.Column` + least value. + Examples -------- >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) @@ -2818,6 +3347,11 @@ def when(condition: Column, value: Any) -> Column: value : a literal value, or a :class:`~pyspark.sql.Column` expression. + Returns + ------- + :class:`~pyspark.sql.Column` + column representing when expression. + Examples -------- >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() @@ -2851,6 +3385,18 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non .. versionadded:: 1.5.0 + Parameters + ---------- + arg1 : :class:`~pyspark.sql.Column`, str or float + base number or actual number (in this case base is `e`) + arg2 : :class:`~pyspark.sql.Column`, str or float + number to calculate logariphm for. + + Returns + ------- + :class:`~pyspark.sql.Column` + logariphm of given value. + Examples -------- >>> df.select(log(10.0, df.age).alias('ten')).rdd.map(lambda l: str(l.ten)[:7]).collect() @@ -2870,6 +3416,16 @@ def log2(col: "ColumnOrName") -> Column: .. versionadded:: 1.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + a column to calculate logariphm for. + + Returns + ------- + :class:`~pyspark.sql.Column` + logariphm of given value. + Examples -------- >>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() @@ -2884,6 +3440,20 @@ def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column: .. versionadded:: 1.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + a column to convert base for. + fromBase: int + from base number. + toBase: int + to base number. + + Returns + ------- + :class:`~pyspark.sql.Column` + logariphm of given value. + Examples -------- >>> df = spark.createDataFrame([("010101",)], ['n']) @@ -2899,6 +3469,16 @@ def factorial(col: "ColumnOrName") -> Column: .. versionadded:: 1.5.0 + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + a column to calculate factorial for. + + Returns + ------- + :class:`~pyspark.sql.Column` + factorial of given value. 
+ Examples -------- >>> df = spark.createDataFrame([(5,)], ['n']) @@ -2925,10 +3505,65 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> ---------- col : :class:`~pyspark.sql.Column` or str name of column or expression - offset : int, optional + offset : int, optional default 1 number of row to extend default : optional default value + + Returns + ------- + :class:`~pyspark.sql.Column` + value before current row based on `offset`. + + Examples + -------- + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([("a", 1), + ... ("a", 2), + ... ("a", 3), + ... ("b", 8), + ... ("b", 2)], ["c1", "c2"]) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("previos_value", lag("c2").over(w)).show() + +---+---+-------------+ + | c1| c2|previos_value| + +---+---+-------------+ + | a| 1| null| + | a| 2| 1| + | a| 3| 2| + | b| 2| null| + | b| 8| 2| + +---+---+-------------+ + >>> df.withColumn("previos_value", lag("c2", 1, 0).over(w)).show() + +---+---+-------------+ + | c1| c2|previos_value| + +---+---+-------------+ + | a| 1| 0| + | a| 2| 1| + | a| 3| 2| + | b| 2| 0| + | b| 8| 2| + +---+---+-------------+ + >>> df.withColumn("previos_value", lag("c2", 2, -1).over(w)).show() + +---+---+-------------+ + | c1| c2|previos_value| + +---+---+-------------+ + | a| 1| -1| + | a| 2| -1| + | a| 3| 1| + | b| 2| -1| + | b| 8| -1| + +---+---+-------------+ """ return _invoke_function("lag", _to_java_column(col), offset, default) @@ -2947,10 +3582,65 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> ---------- col : :class:`~pyspark.sql.Column` or str name of column or expression - offset : int, optional + offset : int, optional default 1 number of row to extend default : optional default value + + Returns + ------- + :class:`~pyspark.sql.Column` + value after current row based on `offset`. + + Examples + -------- + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([("a", 1), + ... ("a", 2), + ... ("a", 3), + ... ("b", 8), + ... ("b", 2)], ["c1", "c2"]) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("next_value", lead("c2").over(w)).show() + +---+---+----------+ + | c1| c2|next_value| + +---+---+----------+ + | a| 1| 2| + | a| 2| 3| + | a| 3| null| + | b| 2| 8| + | b| 8| null| + +---+---+----------+ + >>> df.withColumn("next_value", lead("c2", 1, 0).over(w)).show() + +---+---+----------+ + | c1| c2|next_value| + +---+---+----------+ + | a| 1| 2| + | a| 2| 3| + | a| 3| 0| + | b| 2| 8| + | b| 8| 0| + +---+---+----------+ + >>> df.withColumn("next_value", lead("c2", 2, -1).over(w)).show() + +---+---+----------+ + | c1| c2|next_value| + +---+---+----------+ + | a| 1| 3| + | a| 2| -1| + | a| 3| -1| + | b| 2| -1| + | b| 8| -1| + +---+---+----------+ """ return _invoke_function("lead", _to_java_column(col), offset, default) @@ -2971,11 +3661,56 @@ def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = Fa ---------- col : :class:`~pyspark.sql.Column` or str name of column or expression - offset : int, optional + offset : int number of row to use as the value ignoreNulls : bool, optional indicates the Nth value should skip null in the determination of which row to use + + Returns + ------- + :class:`~pyspark.sql.Column` + value of nth row. 
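In addition to the nth_value examples that follow, the ignoreNulls flag documented above can be sketched as below (assuming an active SparkSession named spark; the sample data is illustrative):

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import nth_value

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a", None), ("a", 1), ("a", 2), ("b", 3)], ["c1", "c2"])
w = Window.partitionBy("c1").orderBy("c2")

# With ignoreNulls=True, null rows are skipped when picking the Nth value,
# so partition "a" reports 1 (its first non-null c2) once a non-null row
# has entered the window frame.
df.withColumn("first_non_null", nth_value("c2", 1, ignoreNulls=True).over(w)).show()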
+ + Examples + -------- + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([("a", 1), + ... ("a", 2), + ... ("a", 3), + ... ("b", 8), + ... ("b", 2)], ["c1", "c2"]) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("nth_value", nth_value("c2", 1).over(w)).show() + +---+---+---------+ + | c1| c2|nth_value| + +---+---+---------+ + | a| 1| 1| + | a| 2| 1| + | a| 3| 1| + | b| 2| 2| + | b| 8| 2| + +---+---+---------+ + >>> df.withColumn("nth_value", nth_value("c2", 2).over(w)).show() + +---+---+---------+ + | c1| c2|nth_value| + +---+---+---------+ + | a| 1| null| + | a| 2| 2| + | a| 3| 2| + | b| 2| null| + | b| 8| 8| + +---+---+---------+ """ return _invoke_function("nth_value", _to_java_column(col), offset, ignoreNulls) @@ -2995,6 +3730,41 @@ def ntile(n: int) -> Column: ---------- n : int an integer + + Returns + ------- + :class:`~pyspark.sql.Column` + portioned group id. + + Examples + -------- + >>> from pyspark.sql import Window + >>> df = spark.createDataFrame([("a", 1), + ... ("a", 2), + ... ("a", 3), + ... ("b", 8), + ... ("b", 2)], ["c1", "c2"]) + >>> df.show() + +---+---+ + | c1| c2| + +---+---+ + | a| 1| + | a| 2| + | a| 3| + | b| 8| + | b| 2| + +---+---+ + >>> w = Window.partitionBy("c1").orderBy("c2") + >>> df.withColumn("ntile", ntile(2).over(w)).show() + +---+---+-----+ + | c1| c2|ntile| + +---+---+-----+ + | a| 1| 1| + | a| 2| 1| + | a| 3| 2| + | b| 2| 1| + | b| 8| 2| + +---+---+-----+ """ return _invoke_function("ntile", int(n)) @@ -3008,6 +3778,21 @@ def current_date() -> Column: All calls of current_date within the same query return the same value. .. versionadded:: 1.5.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + current date. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(current_date()).show() # doctest: +SKIP + +--------------+ + |current_date()| + +--------------+ + | 2022-08-26| + +--------------+ """ return _invoke_function("current_date") @@ -3018,6 +3803,21 @@ def current_timestamp() -> Column: column. All calls of current_timestamp within the same query return the same value. .. versionadded:: 1.5.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + current date and time. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(current_timestamp()).show(truncate=False) # doctest: +SKIP + +-----------------------+ + |current_timestamp() | + +-----------------------+ + |2022-08-26 21:23:22.716| + +-----------------------+ """ return _invoke_function("current_timestamp") @@ -3030,16 +3830,20 @@ def localtimestamp() -> Column: .. versionadded:: 3.4.0 + Returns + ------- + :class:`~pyspark.sql.Column` + current local date and time. + Examples -------- - >>> from pyspark.sql.functions import localtimestamp - >>> df = spark.range(0, 100) - >>> df.select(localtimestamp()).distinct().show() - +--------------------+ - | localtimestamp()| - +--------------------+ - |20...-...-... 
...:...:...| - +--------------------+ + >>> df = spark.range(1) + >>> df.select(localtimestamp()).show(truncate=False) # doctest: +SKIP + +-----------------------+ + |localtimestamp() | + +-----------------------+ + |2022-08-26 21:28:34.639| + +-----------------------+ """ return _invoke_function("localtimestamp") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 669857b6a1..ae177efa05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.{Method, Modifier} import java.util import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean @@ -47,8 +46,7 @@ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} -import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, BoundFunction, ScalarFunction, UnboundFunction} -import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME +import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -377,7 +375,7 @@ class Analyzer(override val catalogManager: CatalogManager) _.containsPattern(BINARY_ARITHMETIC), ruleId) { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(BINARY_ARITHMETIC), ruleId) { - case a @ Add(l, r, f) if a.childrenResolved => (l.dataType, r.dataType) match { + case a @ Add(l, r, mode) if a.childrenResolved => (l.dataType, r.dataType) match { case (DateType, DayTimeIntervalType(DAY, DAY)) => DateAdd(l, ExtractANSIIntervalDays(r)) case (DateType, _: DayTimeIntervalType) => TimeAdd(Cast(l, TimestampType), r) case (DayTimeIntervalType(DAY, DAY), DateType) => DateAdd(r, ExtractANSIIntervalDays(l)) @@ -394,23 +392,25 @@ class Analyzer(override val catalogManager: CatalogManager) a.copy(left = Cast(a.left, a.right.dataType)) case (_: AnsiIntervalType, _: NullType) => a.copy(right = Cast(a.right, a.left.dataType)) - case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f) + case (DateType, CalendarIntervalType) => + DateAddInterval(l, r, ansiEnabled = mode == EvalMode.ANSI) case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType) - case (CalendarIntervalType, DateType) => DateAddInterval(r, l, ansiEnabled = f) + case (CalendarIntervalType, DateType) => + DateAddInterval(r, l, ansiEnabled = mode == EvalMode.ANSI) case (CalendarIntervalType | _: DayTimeIntervalType, _) => Cast(TimeAdd(r, l), r.dataType) case (DateType, dt) if dt != StringType => DateAdd(l, r) case (dt, DateType) if dt != StringType => DateAdd(r, l) case _ => a } - case s @ Subtract(l, r, f) if s.childrenResolved => (l.dataType, r.dataType) match { + case s @ Subtract(l, r, mode) if s.childrenResolved => (l.dataType, r.dataType) match { case 
(DateType, DayTimeIntervalType(DAY, DAY)) => - DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), f)) + DateAdd(l, UnaryMinus(ExtractANSIIntervalDays(r), mode == EvalMode.ANSI)) case (DateType, _: DayTimeIntervalType) => - DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, f))) + DatetimeSub(l, r, TimeAdd(Cast(l, TimestampType), UnaryMinus(r, mode == EvalMode.ANSI))) case (DateType, _: YearMonthIntervalType) => - DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, f))) + DatetimeSub(l, r, DateAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) case (TimestampType | TimestampNTZType, _: YearMonthIntervalType) => - DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f))) + DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, mode == EvalMode.ANSI))) case (CalendarIntervalType, CalendarIntervalType) | (_: DayTimeIntervalType, _: DayTimeIntervalType) => s case (_: NullType, _: AnsiIntervalType) => @@ -418,26 +418,27 @@ class Analyzer(override val catalogManager: CatalogManager) case (_: AnsiIntervalType, _: NullType) => s.copy(right = Cast(s.right, s.left.dataType)) case (DateType, CalendarIntervalType) => - DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f)) + DatetimeSub(l, r, DateAddInterval(l, + UnaryMinus(r, mode == EvalMode.ANSI), ansiEnabled = mode == EvalMode.ANSI)) case (_, CalendarIntervalType | _: DayTimeIntervalType) => - Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, f))), l.dataType) + Cast(DatetimeSub(l, r, TimeAdd(l, UnaryMinus(r, mode == EvalMode.ANSI))), l.dataType) case _ if AnyTimestampType.unapply(l) || AnyTimestampType.unapply(r) => SubtractTimestamps(l, r) case (_, DateType) => SubtractDates(l, r) case (DateType, dt) if dt != StringType => DateSub(l, r) case _ => s } - case m @ Multiply(l, r, f) if m.childrenResolved => (l.dataType, r.dataType) match { - case (CalendarIntervalType, _) => MultiplyInterval(l, r, f) - case (_, CalendarIntervalType) => MultiplyInterval(r, l, f) + case m @ Multiply(l, r, mode) if m.childrenResolved => (l.dataType, r.dataType) match { + case (CalendarIntervalType, _) => MultiplyInterval(l, r, mode == EvalMode.ANSI) + case (_, CalendarIntervalType) => MultiplyInterval(r, l, mode == EvalMode.ANSI) case (_: YearMonthIntervalType, _) => MultiplyYMInterval(l, r) case (_, _: YearMonthIntervalType) => MultiplyYMInterval(r, l) case (_: DayTimeIntervalType, _) => MultiplyDTInterval(l, r) case (_, _: DayTimeIntervalType) => MultiplyDTInterval(r, l) case _ => m } - case d @ Divide(l, r, f) if d.childrenResolved => (l.dataType, r.dataType) match { - case (CalendarIntervalType, _) => DivideInterval(l, r, f) + case d @ Divide(l, r, mode) if d.childrenResolved => (l.dataType, r.dataType) match { + case (CalendarIntervalType, _) => DivideInterval(l, r, mode == EvalMode.ANSI) case (_: YearMonthIntervalType, _) => DivideYMInterval(l, r) case (_: DayTimeIntervalType, _) => DivideDTInterval(l, r) case _ => d @@ -2385,33 +2386,7 @@ class Analyzer(override val catalogManager: CatalogManager) throw QueryCompilationErrors.functionWithUnsupportedSyntaxError( scalarFunc.name(), "IGNORE NULLS") } else { - val declaredInputTypes = scalarFunc.inputTypes().toSeq - val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass) - findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { - case Some(m) if Modifier.isStatic(m.getModifiers) => - StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), - MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes, - propagateNull = false, returnNullable = 
scalarFunc.isResultNullable, - isDeterministic = scalarFunc.isDeterministic) - case Some(_) => - val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) - Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), - arguments, methodInputTypes = declaredInputTypes, propagateNull = false, - returnNullable = scalarFunc.isResultNullable, - isDeterministic = scalarFunc.isDeterministic) - case _ => - // TODO: handle functions defined in Scala too - in Scala, even if a - // subclass do not override the default method in parent interface - // defined in Java, the method can still be found from - // `getDeclaredMethod`. - findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match { - case Some(_) => - ApplyFunctionExpression(scalarFunc, arguments) - case _ => - failAnalysis(s"ScalarFunction '${scalarFunc.name()}' neither implement" + - s" magic method nor override 'produceResult'") - } - } + V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) } } @@ -2426,23 +2401,6 @@ class Analyzer(override val catalogManager: CatalogManager) val aggregator = V2Aggregator(aggFunc, arguments) aggregator.toAggregateExpression(u.isDistinct, u.filter) } - - /** - * Check if the input `fn` implements the given `methodName` with parameter types specified - * via `argClasses`. - */ - private def findMethod( - fn: BoundFunction, - methodName: String, - argClasses: Seq[Class[_]]): Option[Method] = { - val cls = fn.getClass - try { - Some(cls.getDeclaredMethod(methodName, argClasses: _*)) - } catch { - case _: NoSuchMethodException => - None - } - } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala index c179c83bef..a23f4f6194 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, NumericType} case class TryEval(child: Expression) extends UnaryExpression with NullIntolerant { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -77,8 +77,13 @@ case class TryEval(child: Expression) extends UnaryExpression with NullIntoleran // scalastyle:on line.size.limit case class TryAdd(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Add(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Add(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. 
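+      // For non-numeric operands (e.g. date or interval arithmetic), the expression stays in
+      // ANSI mode and is wrapped in TryEval, which catches runtime errors and returns null.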
+ case _ => TryEval(Add(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_add" @@ -110,8 +115,13 @@ case class TryAdd(left: Expression, right: Expression, replacement: Expression) // scalastyle:on line.size.limit case class TryDivide(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Divide(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Divide(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. + case _ => TryEval(Divide(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_divide" @@ -144,8 +154,13 @@ case class TryDivide(left: Expression, right: Expression, replacement: Expressio group = "math_funcs") case class TrySubtract(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Subtract(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Subtract(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. + case _ => TryEval(Subtract(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_subtract" @@ -171,8 +186,13 @@ case class TrySubtract(left: Expression, right: Expression, replacement: Express group = "math_funcs") case class TryMultiply(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { - def this(left: Expression, right: Expression) = - this(left, right, TryEval(Multiply(left, right, failOnError = true))) + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Multiply(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. 
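+      // Same fallback as in TryAdd above: non-numeric operands keep the TryEval wrapper.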
+ case _ => TryEval(Multiply(left, right, EvalMode.ANSI)) + } + ) override def prettyName: String = "try_multiply" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index c252ea5ccf..64eb307bb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.reflect.{Method, Modifier} + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, SQLConfHelper} import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} import org.apache.spark.sql.connector.catalog.functions._ +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -52,8 +56,11 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { /** * Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst. 
*/ - def toCatalystOrdering(ordering: Array[V2SortOrder], query: LogicalPlan): Seq[SortOrder] = { - ordering.map(toCatalyst(_, query).asInstanceOf[SortOrder]) + def toCatalystOrdering( + ordering: Array[V2SortOrder], + query: LogicalPlan, + funCatalogOpt: Option[FunctionCatalog] = None): Seq[SortOrder] = { + ordering.map(toCatalyst(_, query, funCatalogOpt).asInstanceOf[SortOrder]) } def toCatalyst( @@ -143,4 +150,53 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { case V2NullOrdering.NULLS_FIRST => NullsFirst case V2NullOrdering.NULLS_LAST => NullsLast } + + def resolveScalarFunction( + scalarFunc: ScalarFunction[_], + arguments: Seq[Expression]): Expression = { + val declaredInputTypes = scalarFunc.inputTypes().toSeq + val argClasses = declaredInputTypes.map(ScalaReflection.dataTypeJavaClass) + findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { + case Some(m) if Modifier.isStatic(m.getModifiers) => + StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), + MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes, + propagateNull = false, returnNullable = scalarFunc.isResultNullable, + isDeterministic = scalarFunc.isDeterministic) + case Some(_) => + val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) + Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), + arguments, methodInputTypes = declaredInputTypes, propagateNull = false, + returnNullable = scalarFunc.isResultNullable, + isDeterministic = scalarFunc.isDeterministic) + case _ => + // TODO: handle functions defined in Scala too - in Scala, even if a + // subclass do not override the default method in parent interface + // defined in Java, the method can still be found from + // `getDeclaredMethod`. + findMethod(scalarFunc, "produceResult", Seq(classOf[InternalRow])) match { + case Some(_) => + ApplyFunctionExpression(scalarFunc, arguments) + case _ => + throw new AnalysisException(s"ScalarFunction '${scalarFunc.name()}'" + + s" neither implement magic method nor override 'produceResult'") + } + } + } + + /** + * Check if the input `fn` implements the given `methodName` with parameter types specified + * via `argClasses`. 
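+   * Returns None if the class of `fn` declares no such method.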
+ */ + private def findMethod( + fn: BoundFunction, + methodName: String, + argClasses: Seq[Class[_]]): Option[Method] = { + val cls = fn.getClass + try { + Some(cls.getDeclaredMethod(methodName, argClasses: _*)) + } catch { + case _: NoSuchMethodException => + None + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 36ffcd8f76..9bc2891ae5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -69,7 +69,7 @@ abstract class AverageBase protected def add(left: Expression, right: Expression): Expression = left.dataType match { case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType) - case _ => Add(left, right, useAnsiAdd) + case _ => Add(left, right, EvalMode.fromBoolean(useAnsiAdd)) } override lazy val aggBufferAttributes = sum :: count :: Nil @@ -103,7 +103,7 @@ abstract class AverageBase If(EqualTo(count, Literal(0L)), Literal(null, DayTimeIntervalType()), DivideDTInterval(sum, count)) case _ => - Divide(sum.cast(resultType), count.cast(resultType), failOnError = false) + Divide(sum.cast(resultType), count.cast(resultType), EvalMode.LEGACY) } protected def getUpdateExpressions: Seq[Expression] = Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 869a27c616..db8bec7c93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -65,7 +65,7 @@ abstract class SumBase(child: Expression) extends DeclarativeAggregate private def add(left: Expression, right: Expression): Expression = left.dataType match { case _: DecimalType => DecimalAddNoOverflowCheck(left, right, left.dataType) - case _ => Add(left, right, useAnsiAdd) + case _ => Add(left, right, EvalMode.fromBoolean(useAnsiAdd)) } override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 24ac685eac..45e0ec876d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -214,7 +214,14 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant with SupportQueryContext { - protected val failOnError: Boolean + protected val evalMode: EvalMode.Value + + protected def failOnError: Boolean = evalMode match { + // The TRY mode executes as if it would fail on errors, except that it would capture the errors + // and return null results. 
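+    // When evalMode is TRY, the eval() and nullSafeCodeGen() overrides below catch the raised
+    // errors and return null instead of propagating the exception.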
+ case EvalMode.ANSI | EvalMode.TRY => true + case _ => false + } override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match { case (l: DecimalType, r: DecimalType) if inputType.acceptsType(l) && inputType.acceptsType(r) => @@ -240,11 +247,11 @@ abstract class BinaryArithmetic extends BinaryOperator s"${getClass.getSimpleName} must override `resultDecimalType`.") } - override def nullable: Boolean = super.nullable || { + override def nullable: Boolean = super.nullable || evalMode == EvalMode.TRY || { if (left.dataType.isInstanceOf[DecimalType]) { // For decimal arithmetic, we may return null even if both inputs are not null, if overflow // happens and this `failOnError` flag is false. - !failOnError + evalMode != EvalMode.ANSI } else { // For non-decimal arithmetic, the calculation always return non-null result when inputs are // not null. If overflow happens, we return either the overflowed value or fail. @@ -349,6 +356,49 @@ abstract class BinaryArithmetic extends BinaryOperator """.stripMargin }) } + + override def nullSafeCodeGen( + ctx: CodegenContext, + ev: ExprCode, + f: (String, String) => String): ExprCode = { + if (evalMode == EvalMode.TRY) { + val tryBlock: (String, String) => String = (eval1, eval2) => { + s""" + |try { + | ${f(eval1, eval2)} + |} catch (Exception e) { + | ${ev.isNull} = true; + |} + |""".stripMargin + } + super.nullSafeCodeGen(ctx, ev, tryBlock) + } else { + super.nullSafeCodeGen(ctx, ev, f) + } + } + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + if (evalMode == EvalMode.TRY) { + try { + nullSafeEval(value1, value2) + } catch { + case _: Exception => + null + } + } else { + nullSafeEval(value1, value2) + } + } + } + } } object BinaryArithmetic { @@ -367,9 +417,10 @@ object BinaryArithmetic { case class Add( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -436,9 +487,10 @@ case class Add( case class Subtract( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -511,9 +563,10 @@ case class Subtract( case class Multiply( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = NumericType @@ -698,9 +751,14 @@ trait DivModLike extends 
BinaryArithmetic { case class Divide( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { + + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + // `try_divide` has exactly the same behavior as the legacy divide, so here it only executes + // the error code path when `evalMode` is `ANSI`. + protected override def failOnError: Boolean = evalMode == EvalMode.ANSI override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) @@ -762,9 +820,10 @@ case class Divide( case class IntegralDivide( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = this(left, right, + EvalMode.fromSQLConf(SQLConf.get)) override def checkDivideOverflow: Boolean = left.dataType match { case LongType if failOnError => true @@ -835,9 +894,10 @@ case class IntegralDivide( case class Remainder( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends DivModLike { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def inputType: AbstractDataType = NumericType @@ -912,9 +972,10 @@ case class Remainder( case class Pmod( left: Expression, right: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic { - def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) + def this(left: Expression, right: Expression) = + this(left, right, EvalMode.fromSQLConf(SQLConf.get)) override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 57ab9e2773..a178500fba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types._ group = "bitwise_funcs") case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { - protected override val failOnError: Boolean = false + protected override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -77,7 +77,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme group = "bitwise_funcs") case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - protected override val failOnError: Boolean = false + protected override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -116,7 +116,7 @@ case class BitwiseOr(left: 
Expression, right: Expression) extends BinaryArithmet group = "bitwise_funcs") case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { - protected override val failOnError: Boolean = false + protected override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 2b9885743e..de5fde27f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -114,7 +114,7 @@ object UrlCodec { UTF8String.fromString(URLDecoder.decode(src.toString, enc.toString)) } catch { case e: IllegalArgumentException => - throw QueryExecutionErrors.illegalUrlError(src, e) + throw QueryExecutionErrors.illegalUrlError(src) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 8b9663f173..834e0e6b21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMo import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction} -import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED, LEGACY_CTE_PRECEDENCE_POLICY} @@ -476,20 +475,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { s"CalendarIntervalType, but got ${dt}") } - def viewOutputNumberMismatchQueryColumnNamesError( - output: Seq[Attribute], queryColumnNames: Seq[String]): Throwable = { - new AnalysisException( - s"The view output ${output.mkString("[", ",", "]")} doesn't have the same" + - "number of columns with the query column names " + - s"${queryColumnNames.mkString("[", ",", "]")}") - } - - def attributeNotFoundError(colName: String, child: LogicalPlan): Throwable = { - new AnalysisException( - s"Attribute with name '$colName' is not found in " + - s"'${child.output.map(_.name).mkString("(", ",", ")")}'") - } - def functionUndefinedError(name: FunctionIdentifier): Throwable = { new AnalysisException(s"undefined function $name") } @@ -583,10 +568,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { s"'${db.head}' != '${v1TableName.database.get}'") } - def sqlOnlySupportedWithV1TablesError(sql: String): Throwable = { - new AnalysisException(s"$sql is only supported with v1 tables.") - } - def cannotCreateTableWithBothProviderAndSerdeError( provider: Option[String], maybeSerdeInfo: Option[SerdeInfo]): Throwable = { new AnalysisException( @@ -1453,10 +1434,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { new AnalysisException("Cannot use interval type in the table schema.") } - def cannotPartitionByNestedColumnError(reference: NamedReference): Throwable = { - new 
AnalysisException(s"Cannot partition by nested column: $reference") - } - def missingCatalogAbilityError(plugin: CatalogPlugin, ability: String): Throwable = { new AnalysisException(s"Catalog ${plugin.name} does not support $ability") } @@ -1530,12 +1507,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { new AnalysisException(msg) } - def lookupFunctionInNonFunctionCatalogError( - ident: Identifier, catalog: CatalogPlugin): Throwable = { - new AnalysisException(s"Trying to lookup function '$ident' in " + - s"catalog '${catalog.name()}', but it is not a FunctionCatalog.") - } - def functionCannotProcessInputError( unbound: UnboundFunction, arguments: Seq[Expression], @@ -1724,10 +1695,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { s"Supported interval fields: ${supportedIds.mkString(", ")}.") } - def invalidYearMonthIntervalType(startFieldName: String, endFieldName: String): Throwable = { - new AnalysisException(s"'interval $startFieldName to $endFieldName' is invalid.") - } - def configRemovedInVersionError( configName: String, version: String, @@ -2235,18 +2202,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { new AnalysisException(s"Boundary end is not a valid integer: $end") } - def databaseDoesNotExistError(dbName: String): Throwable = { - new AnalysisException(s"Database '$dbName' does not exist.") - } - - def tableDoesNotExistInDatabaseError(tableName: String, dbName: String): Throwable = { - new AnalysisException(s"Table '$tableName' does not exist in database '$dbName'.") - } - - def tableOrViewNotFoundInDatabaseError(tableName: String, dbName: String): Throwable = { - new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") - } - def tableOrViewNotFound(ident: Seq[String]): Throwable = { new AnalysisException(s"Table or view '${ident.quoted}' not found") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 035918b6f4..d8d6139e91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -21,7 +21,6 @@ import java.io.{FileNotFoundException, IOException} import java.lang.reflect.InvocationTargetException import java.net.{URISyntaxException, URL} import java.sql.{SQLException, SQLFeatureNotSupportedException} -import java.text.{ParseException => JavaParseException} import java.time.{DateTimeException, LocalDate} import java.time.temporal.ChronoField import java.util.ConcurrentModificationException @@ -265,12 +264,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { new DateTimeException(newMessage, e.getCause) } - def ansiParseError(e: JavaParseException): JavaParseException = { - val newMessage = s"${e.getMessage}. " + - s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error." - new JavaParseException(newMessage, e.getErrorOffset) - } - def ansiIllegalArgumentError(message: String): IllegalArgumentException = { val newMessage = s"$message. If necessary set ${SQLConf.ANSI_ENABLED.key} " + s"to false to bypass this error." 
@@ -332,10 +325,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error.", e) } - def illegalUrlError(url: UTF8String, e: IllegalArgumentException): + def illegalUrlError(url: UTF8String): Throwable with SparkThrowable = { new SparkIllegalArgumentException(errorClass = "CANNOT_DECODE_URL", - messageParameters = Array(url.toString, e.getMessage) + messageParameters = Array(url.toString) ) } @@ -369,10 +362,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { s"Cannot generate $codeType code for incomparable type: ${dataType.catalogString}") } - def cannotGenerateCodeForUnsupportedTypeError(dataType: DataType): Throwable = { - new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - def cannotInterpolateClassIntoCodeBlockError(arg: Any): Throwable = { new IllegalArgumentException( s"Can not interpolate ${arg.getClass.getName} into code block.") @@ -1023,18 +1012,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { new SparkException(s"Failed to merge fields '$leftName' and '$rightName'. ${e.getMessage}") } - def cannotMergeDecimalTypesWithIncompatiblePrecisionAndScaleError( - leftPrecision: Int, rightPrecision: Int, leftScale: Int, rightScale: Int): Throwable = { - new SparkException("Failed to merge decimal types with incompatible " + - s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") - } - - def cannotMergeDecimalTypesWithIncompatiblePrecisionError( - leftPrecision: Int, rightPrecision: Int): Throwable = { - new SparkException("Failed to merge decimal types with incompatible " + - s"precision $leftPrecision and $rightPrecision") - } - def cannotMergeDecimalTypesWithIncompatibleScaleError( leftScale: Int, rightScale: Int): Throwable = { new SparkException("Failed to merge decimal types with incompatible " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 2045c59933..4fe01ac760 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics} import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -48,6 +48,10 @@ case class DataSourceV2Relation( import DataSourceV2Implicits._ + lazy val funCatalog: Option[FunctionCatalog] = catalog.collect { + case c: FunctionCatalog => c + } + override lazy val 
metadataOutput: Seq[AttributeReference] = table match { case hasMeta: SupportsMetadataColumns => val resolve = conf.resolver diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index eb7a6a9105..de25c19a26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -526,6 +526,16 @@ object SQLConf { .checkValue(_ >= 1, "The shuffle hash join factor cannot be negative.") .createWithDefault(3) + val LIMIT_INITIAL_NUM_PARTITIONS = buildConf("spark.sql.limit.initialNumPartitions") + .internal() + .doc("Initial number of partitions to try when executing a take on a query. Higher values " + + "lead to more partitions read. Lower values might lead to longer execution times as more" + + "jobs will be run") + .version("3.4.0") + .intConf + .checkValue(_ > 0, "value should be positive") + .createWithDefault(1) + val LIMIT_SCALE_UP_FACTOR = buildConf("spark.sql.limit.scaleUpFactor") .internal() .doc("Minimal increase rate in number of partitions between attempts when executing a take " + @@ -4316,6 +4326,8 @@ class SQLConf extends Serializable with Logging { def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS) + def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) def advancedPartitionPredicatePushdownEnabled: Boolean = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 2bfa072a13..63862ee355 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -95,7 +95,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Add(maxValue, maxValue, failOnError = true) + val expr = Add(maxValue, maxValue, EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -180,7 +180,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Subtract(minValue, maxValue, failOnError = true) + val expr = Subtract(minValue, maxValue, EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -219,7 +219,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Multiply(maxValue, maxValue, failOnError = true) + val expr = Multiply(maxValue, maxValue, EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -264,7 +264,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), failOnError = true) + val expr = Divide(Literal(1234.5, DoubleType), Literal(0.0, DoubleType), EvalMode.ANSI) 
checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -320,7 +320,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper withOrigin(o) { val expr = IntegralDivide( - Literal(Long.MinValue, LongType), Literal(right, LongType), failOnError = true) + Literal(Long.MinValue, LongType), Literal(right, LongType), EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expr, EmptyRow, query) } } @@ -367,7 +367,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper stopIndex = Some(7 + query.length -1), sqlText = Some(s"select $query")) withOrigin(o) { - val expression = exprBuilder(Literal(1L, LongType), Literal(0L, LongType), true) + val expression = exprBuilder(Literal(1L, LongType), Literal(0L, LongType), EvalMode.ANSI) checkExceptionInExpression[ArithmeticException](expression, EmptyRow, query) } } @@ -760,24 +760,24 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("SPARK-34677: exact add and subtract of day-time and year-month intervals") { - Seq(true, false).foreach { failOnError => + Seq(EvalMode.ANSI, EvalMode.LEGACY).foreach { evalMode => checkExceptionInExpression[ArithmeticException]( UnaryMinus( Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType()), - failOnError), + evalMode == EvalMode.ANSI), "overflow") checkExceptionInExpression[ArithmeticException]( Subtract( Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType()), Literal.create(Period.ofMonths(10), YearMonthIntervalType()), - failOnError + evalMode ), "overflow") checkExceptionInExpression[ArithmeticException]( Add( Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType()), Literal.create(Period.ofMonths(10), YearMonthIntervalType()), - failOnError + evalMode ), "overflow") @@ -785,14 +785,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Subtract( Literal.create(Duration.ofDays(-106751991), DayTimeIntervalType()), Literal.create(Duration.ofDays(10), DayTimeIntervalType()), - failOnError + evalMode ), "overflow") checkExceptionInExpression[ArithmeticException]( Add( Literal.create(Duration.ofDays(106751991), DayTimeIntervalType()), Literal.create(Duration.ofDays(10), DayTimeIntervalType()), - failOnError + evalMode ), "overflow") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 229e698fb2..9a6caea59b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1483,7 +1483,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(a0, Literal(0)), null) }.getMessage.contains("SQL array indices start at 1") intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) } - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { checkEvaluation(ElementAt(a0, Literal(4)), null) checkEvaluation(ElementAt(a0, Literal(-4)), null) } @@ -1512,7 +1512,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + 
withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { checkEvaluation(ElementAt(m0, Literal("d")), null) checkEvaluation(ElementAt(m1, Literal("a")), null) } @@ -1529,7 +1529,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper MapType(BinaryType, StringType)) val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) } @@ -1537,22 +1537,24 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) // test defaultValueOutOfBound - val delimiter = Literal.create(".", StringType) - val str = StringSplitSQL(Literal.create("11.12.13", StringType), delimiter) - val outOfBoundValue = Some(Literal.create("", StringType)) - - checkEvaluation(ElementAt(str, Literal(3), outOfBoundValue), UTF8String.fromString("13")) - checkEvaluation(ElementAt(str, Literal(1), outOfBoundValue), UTF8String.fromString("11")) - checkEvaluation(ElementAt(str, Literal(10), outOfBoundValue), UTF8String.fromString("")) - checkEvaluation(ElementAt(str, Literal(-10), outOfBoundValue), UTF8String.fromString("")) - - checkEvaluation(ElementAt(StringSplitSQL(Literal.create(null, StringType), delimiter), - Literal(1), outOfBoundValue), null) - checkEvaluation(ElementAt(StringSplitSQL(Literal.create("11.12.13", StringType), - Literal.create(null, StringType)), Literal(1), outOfBoundValue), null) - - checkExceptionInExpression[Exception]( - ElementAt(str, Literal(0), outOfBoundValue), "The index 0 is invalid") + withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + val delimiter = Literal.create(".", StringType) + val str = StringSplitSQL(Literal.create("11.12.13", StringType), delimiter) + val outOfBoundValue = Some(Literal.create("", StringType)) + + checkEvaluation(ElementAt(str, Literal(3), outOfBoundValue), UTF8String.fromString("13")) + checkEvaluation(ElementAt(str, Literal(1), outOfBoundValue), UTF8String.fromString("11")) + checkEvaluation(ElementAt(str, Literal(10), outOfBoundValue), UTF8String.fromString("")) + checkEvaluation(ElementAt(str, Literal(-10), outOfBoundValue), UTF8String.fromString("")) + + checkEvaluation(ElementAt(StringSplitSQL(Literal.create(null, StringType), delimiter), + Literal(1), outOfBoundValue), null) + checkEvaluation(ElementAt(StringSplitSQL(Literal.create("11.12.13", StringType), + Literal.create(null, StringType)), Literal(1), outOfBoundValue), null) + + checkExceptionInExpression[Exception]( + ElementAt(str, Literal(0), outOfBoundValue), "The index 0 is invalid") + } } test("correctly handles ElementAt nullability for arrays") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala index 4dc7f87d19..9ead075663 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala @@ -104,7 +104,7 @@ class TryCastThrowExceptionSuite extends SparkFunSuite with ExpressionEvalHelper // The method checkExceptionInExpression is overridden in TryCastSuite, so here we have a // new test suite for testing exceptions from the child of 
`try_cast()`. test("TryCast should not catch the exception from it's child") { - val child = Divide(Literal(1.0), Literal(0.0), failOnError = true) + val child = Divide(Literal(1.0), Literal(0.0), EvalMode.ANSI) checkExceptionInExpression[Exception]( Cast(child, StringType, None, EvalMode.TRY), "Division by zero") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala index 1eccd46d96..780a2692e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala @@ -28,7 +28,7 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Add(left, right, failOnError = true)) + val input = Add(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } @@ -41,7 +41,7 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Divide(left, right, failOnError = true)) + val input = Divide(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } @@ -54,7 +54,7 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Subtract(left, right, failOnError = true)) + val input = Subtract(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } @@ -67,8 +67,24 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { ).foreach { case (a, b, expected) => val left = Literal(a) val right = Literal(b) - val input = TryEval(Multiply(left, right, failOnError = true)) + val input = Multiply(left, right, EvalMode.TRY) checkEvaluation(input, expected) } } + + test("Throw exceptions from children") { + val failingChild = Divide(Literal(1.0), Literal(0.0), EvalMode.ANSI) + Seq( + Add(failingChild, Literal(1.0), EvalMode.TRY), + Add(Literal(1.0), failingChild, EvalMode.TRY), + Subtract(failingChild, Literal(1.0), EvalMode.TRY), + Subtract(Literal(1.0), failingChild, EvalMode.TRY), + Multiply(failingChild, Literal(1.0), EvalMode.TRY), + Multiply(Literal(1.0), failingChild, EvalMode.TRY), + Divide(failingChild, Literal(1.0), EvalMode.TRY), + Divide(Literal(1.0), failingChild, EvalMode.TRY) + ).foreach { expr => + checkExceptionInExpression[ArithmeticException](expr, "DIVIDE_BY_ZERO") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 749c8791da..947a5e9f38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -35,11 +35,11 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { private def canTranslate(b: BinaryOperator) = b match { case _: BinaryComparison => true case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true - case add: Add => add.failOnError - case sub: Subtract => sub.failOnError - case mul: Multiply => mul.failOnError - case div: Divide => div.failOnError - case r: Remainder => r.failOnError + case add: Add => add.evalMode == EvalMode.ANSI + case sub: Subtract 
=> sub.evalMode == EvalMode.ANSI + case mul: Multiply => mul.evalMode == EvalMode.ANSI + case div: Divide => div.evalMode == EvalMode.ANSI + case r: Remainder => r.evalMode == EvalMode.ANSI case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 50a309d443..a56732fdc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -469,7 +469,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (n == 0) { return new Array[InternalRow](0) } - + val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2) + // TODO: refactor and reuse the code from RDD's take() val childRDD = getByteArrayRdd(n, takeFromEnd) val buf = if (takeFromEnd) new ListBuffer[InternalRow] else new ArrayBuffer[InternalRow] @@ -478,12 +479,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ while (buf.length < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1L + var numPartsToTry = conf.limitInitialNumPartitions if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. - // Otherwise, interpolate the number of partitions we need to try, but overestimate - // it by 50%. We also cap the estimation in the end. - val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2) + // If we didn't find any rows after the previous iteration, multiply by + // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need + // to try, but overestimate it by 50%. We also cap the estimation in the end. if (buf.isEmpty) { numPartsToTry = partsScanned * limitScaleUpFactor } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index ebb6ee3852..1d89e56eeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -207,7 +207,7 @@ class ObjectAggregationIterator( if (sortBased) { aggBufferIterator = sortBasedAggregationStore.destructiveIterator() } else { - aggBufferIterator = hashMap.iterator + aggBufferIterator = hashMap.destructiveIterator() } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala index 9f2cf84a6d..6aede04b06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala @@ -45,7 +45,11 @@ class ObjectAggregationMap() { def size: Int = hashMap.size() - def iterator: Iterator[AggregationBufferEntry] = { + /** + * Returns a destructive iterator of AggregationBufferEntry. + * Notice: it is illegal to call any method after `destructiveIterator()` has been called. 
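+   * Each entry is removed from the underlying hash map as the iterator consumes it.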
+ */ + def destructiveIterator(): Iterator[AggregationBufferEntry] = { val iter = hashMap.entrySet().iterator() new Iterator[AggregationBufferEntry] { @@ -54,6 +58,7 @@ class ObjectAggregationMap() { } override def next(): AggregationBufferEntry = { val entry = iter.next() + iter.remove() new AggregationBufferEntry(entry.getKey, entry.getValue) } } @@ -77,7 +82,7 @@ class ObjectAggregationMap() { null ) - val mapIterator = iterator + val mapIterator = destructiveIterator() val unsafeAggBufferProjection = UnsafeProjection.create(aggBufferAttributes.map(_.dataType).toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 3cc69656bb..d498960692 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -530,9 +530,12 @@ object PartitioningUtils extends SQLConfHelper { case _ if value == DEFAULT_PARTITION_NAME => null case NullType => null case StringType => UTF8String.fromString(unescapePathName(value)) - case ByteType | ShortType | IntegerType => Integer.parseInt(value) + case ByteType => Integer.parseInt(value).toByte + case ShortType => Integer.parseInt(value).toShort + case IntegerType => Integer.parseInt(value) case LongType => JLong.parseLong(value) - case FloatType | DoubleType => JDouble.parseDouble(value) + case FloatType => JDouble.parseDouble(value).toFloat + case DoubleType => JDouble.parseDouble(value) case _: DecimalType => Literal(new JBigDecimal(value)).value case DateType => Cast(Literal(value), DateType, Some(zoneId.getId)).eval() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala index 07ede81988..b0b0d7bbc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala @@ -17,22 +17,33 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder, TransformExpression, V2ExpressionUtils} import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RebalancePartitions, RepartitionByExpression, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.FunctionCatalog +import org.apache.spark.sql.connector.catalog.functions.ScalarFunction import org.apache.spark.sql.connector.distributions._ import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write} import org.apache.spark.sql.errors.QueryCompilationErrors object DistributionAndOrderingUtils { - def prepareQuery(write: Write, query: LogicalPlan): LogicalPlan = write match { + def prepareQuery( + write: Write, + query: LogicalPlan, + funCatalogOpt: Option[FunctionCatalog]): LogicalPlan = write match { case write: RequiresDistributionAndOrdering => val numPartitions = write.requiredNumPartitions() val distribution = write.requiredDistribution 
match {
-        case d: OrderedDistribution => toCatalystOrdering(d.ordering(), query)
-        case d: ClusteredDistribution => d.clustering.map(e => toCatalyst(e, query)).toSeq
+        case d: OrderedDistribution =>
+          toCatalystOrdering(d.ordering(), query, funCatalogOpt)
+            .map(e => resolveTransformExpression(e).asInstanceOf[SortOrder])
+        case d: ClusteredDistribution =>
+          d.clustering.map(e => toCatalyst(e, query, funCatalogOpt))
+            .map(e => resolveTransformExpression(e)).toSeq
         case _: UnspecifiedDistribution => Seq.empty[Expression]
       }
@@ -53,16 +64,33 @@ object DistributionAndOrderingUtils {
         query
       }
-      val ordering = toCatalystOrdering(write.requiredOrdering, query)
+      val ordering = toCatalystOrdering(write.requiredOrdering, query, funCatalogOpt)
       val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
-        Sort(ordering, global = false, queryWithDistribution)
+        Sort(
+          ordering.map(e => resolveTransformExpression(e).asInstanceOf[SortOrder]),
+          global = false,
+          queryWithDistribution)
       } else {
         queryWithDistribution
       }
-      queryWithDistributionAndOrdering
-
+      // Apply typeCoercionRules since the converted expression from TransformExpression
+      // implements ImplicitCastInputTypes
+      typeCoercionRules.foldLeft(queryWithDistributionAndOrdering)((plan, rule) => rule(plan))
     case _ =>
       query
   }
+
+  private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+  }
+
+  private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
+    AnsiTypeCoercion.typeCoercionRules
+  } else {
+    TypeCoercion.typeCoercionRules
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 21503fda53..9b6f993286 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -83,15 +83,14 @@ trait FileScan extends Scan
   protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
   private lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
-    val output = readSchema().toAttributes
     val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => a.name -> a).toMap
-    val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> a).toMap
     val normalizedPartitionFilters = ExpressionSet(partitionFilters.map(
-      QueryPlan.normalizeExpressions(_,
-        output.map(a => partitionFilterAttributes.getOrElse(a.name, a)))))
+      QueryPlan.normalizeExpressions(_, fileIndex.partitionSchema.toAttributes
+        .map(a => partitionFilterAttributes.getOrElse(a.name, a)))))
+    val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> a).toMap
     val normalizedDataFilters = ExpressionSet(dataFilters.map(
-      QueryPlan.normalizeExpressions(_,
-        output.map(a => dataFiltersAttributes.getOrElse(a.name, a)))))
+      QueryPlan.normalizeExpressions(_, dataSchema.toAttributes
+        .map(a => dataFiltersAttributes.getOrElse(a.name, a)))))
     (normalizedPartitionFilters, normalizedDataFilters)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
index 7ea1ca8c24..8ab0dc7072 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.catalog.FunctionCatalog
 import org.apache.spark.sql.connector.read.{SupportsReportOrdering, SupportsReportPartitioning}
 import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, UnknownPartitioning}
 import org.apache.spark.util.collection.Utils.sequenceToOption
@@ -41,14 +40,9 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe
   private def partitioning(plan: LogicalPlan) = plan.transformDown {
     case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) =>
-      val funCatalogOpt = relation.catalog.flatMap {
-        case c: FunctionCatalog => Some(c)
-        case _ => None
-      }
-
       val catalystPartitioning = scan.outputPartitioning() match {
         case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map(
-          V2ExpressionUtils.toCatalystOpt(_, relation, funCatalogOpt)))
+          V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog)))
         case _: UnknownPartitioning => None
         case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " +
           "type: " + p.getClass.getSimpleName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 2d47d94ff1..afdcf2c870 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -43,7 +43,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
       val writeBuilder = newWriteBuilder(r.table, options, query.schema)
       val write = writeBuilder.build()
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       a.copy(write = Some(write), query = newQuery)
     case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
@@ -67,7 +67,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
         throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
       }
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       o.copy(write = Some(write), query = newQuery)
     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
@@ -79,7 +79,7 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
         case _ =>
           throw QueryExecutionErrors.dynamicPartitionOverwriteUnsupportedByTableError(table)
       }
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       o.copy(write = Some(write), query = newQuery)
     case WriteToMicroBatchDataSource(
@@ -89,14 +89,15 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
       val write = buildWriteForMicroBatch(table, writeBuilder, outputMode)
       val microBatchWrite = new MicroBatchWrite(batchId, write.toStreaming)
       val customMetrics = write.supportedCustomMetrics.toSeq
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val funCatalogOpt = relation.flatMap(_.funCatalog)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, funCatalogOpt)
       WriteToDataSourceV2(relation, microBatchWrite, newQuery, customMetrics)
     case rd @ ReplaceData(r: DataSourceV2Relation, _, query, _, None) =>
       val rowSchema = StructType.fromAttributes(rd.dataInput)
       val writeBuilder = newWriteBuilder(r.table, Map.empty, rowSchema)
       val write = writeBuilder.build()
-      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query)
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog)
       // project away any metadata columns that could be used for distribution and ordering
       rd.copy(write = Some(write), query = Project(rd.dataInput, newQuery))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
index 586680f550..55907b6701 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql
@@ -4,6 +4,9 @@ SELECT try_add(2147483647, 1);
 SELECT try_add(-2147483648, -1);
 SELECT try_add(9223372036854775807L, 1);
 SELECT try_add(-9223372036854775808L, -1);
+SELECT try_add(1, (2147483647 + 1));
+SELECT try_add(1L, (9223372036854775807L + 1L));
+SELECT try_add(1, 1.0 / 0.0);
 -- Date + Integer
 SELECT try_add(date'2021-01-01', 1);
@@ -32,6 +35,9 @@ SELECT try_add(interval 106751991 day, interval 3 day);
 SELECT try_divide(1, 0.5);
 SELECT try_divide(1, 0);
 SELECT try_divide(0, 0);
+SELECT try_divide(1, (2147483647 + 1));
+SELECT try_divide(1L, (9223372036854775807L + 1L));
+SELECT try_divide(1, 1.0 / 0.0);
 -- Interval / Numeric
 SELECT try_divide(interval 2 year, 2);
@@ -47,6 +53,9 @@ SELECT try_subtract(2147483647, -1);
 SELECT try_subtract(-2147483648, 1);
 SELECT try_subtract(9223372036854775807L, -1);
 SELECT try_subtract(-9223372036854775808L, 1);
+SELECT try_subtract(1, (2147483647 + 1));
+SELECT try_subtract(1L, (9223372036854775807L + 1L));
+SELECT try_subtract(1, 1.0 / 0.0);
 -- Interval - Interval
 SELECT try_subtract(interval 2 year, interval 3 year);
@@ -60,6 +69,9 @@ SELECT try_multiply(2147483647, -2);
 SELECT try_multiply(-2147483648, 2);
 SELECT try_multiply(9223372036854775807L, 2);
 SELECT try_multiply(-9223372036854775808L, -2);
+SELECT try_multiply(1, (2147483647 + 1));
+SELECT try_multiply(1L, (9223372036854775807L + 1L));
+SELECT try_multiply(1, 1.0 / 0.0);
 -- Interval * Numeric
 SELECT try_multiply(interval 2 year, 2);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
index 8622b97a20..914ee064c5 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out
@@ -39,6 +39,76 @@ struct
 NULL
+-- !query
+SELECT try_add(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 20,
+    "stopIndex" : 33,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_add(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 21,
+    "stopIndex" : 45,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_add(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 19,
+    "stopIndex" : 27,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_add(date'2021-01-01', 1)
 -- !query schema
@@ -184,6 +254,76 @@ struct
 NULL
+-- !query
+SELECT try_divide(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 23,
+    "stopIndex" : 36,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_divide(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 24,
+    "stopIndex" : 48,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_divide(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 22,
+    "stopIndex" : 30,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query schema
@@ -272,6 +412,76 @@ struct
 NULL
+-- !query
+SELECT try_subtract(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 25,
+    "stopIndex" : 38,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_subtract(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 26,
+    "stopIndex" : 50,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_subtract(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 24,
+    "stopIndex" : 32,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_subtract(interval 2 year, interval 3 year)
 -- !query schema
@@ -344,6 +554,76 @@ struct
 NULL
+-- !query
+SELECT try_multiply(1, (2147483647 + 1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "integer overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 25,
+    "stopIndex" : 38,
+    "fragment" : "2147483647 + 1"
+  } ]
+}
+
+
+-- !query
+SELECT try_multiply(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "message" : "long overflow",
+    "alternative" : " Use 'try_add' to tolerate overflow and return NULL instead.",
+    "config" : "spark.sql.ansi.enabled"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 26,
+    "stopIndex" : 50,
+    "fragment" : "9223372036854775807L + 1L"
+  } ]
+}
+
+
+-- !query
+SELECT try_multiply(1, 1.0 / 0.0)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "DIVIDE_BY_ZERO",
+  "sqlState" : "22012",
+  "messageParameters" : {
+    "config" : "\"spark.sql.ansi.enabled\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 24,
+    "stopIndex" : 32,
+    "fragment" : "1.0 / 0.0"
+  } ]
+}
+
+
 -- !query
 SELECT try_multiply(interval 2 year, 2)
 -- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
index 8622b97a20..50bbafedd0 100644
--- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out
@@ -39,6 +39,30 @@ struct
 NULL
+-- !query
+SELECT try_add(1, (2147483647 + 1))
+-- !query schema
+struct
+-- !query output
+-2147483647
+
+
+-- !query
+SELECT try_add(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct
+-- !query output
+-9223372036854775807
+
+
+-- !query
+SELECT try_add(1, 1.0 / 0.0)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_add(date'2021-01-01', 1)
 -- !query schema
@@ -184,6 +208,30 @@ struct
 NULL
+-- !query
+SELECT try_divide(1, (2147483647 + 1))
+-- !query schema
+struct
+-- !query output
+-4.6566128730773926E-10
+
+
+-- !query
+SELECT try_divide(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct
+-- !query output
+-1.0842021724855044E-19
+
+
+-- !query
+SELECT try_divide(1, 1.0 / 0.0)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_divide(interval 2 year, 2)
 -- !query schema
@@ -272,6 +320,30 @@ struct
 NULL
+-- !query
+SELECT try_subtract(1, (2147483647 + 1))
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_subtract(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_subtract(1, 1.0 / 0.0)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_subtract(interval 2 year, interval 3 year)
 -- !query schema
@@ -344,6 +416,30 @@ struct
 NULL
+-- !query
+SELECT try_multiply(1, (2147483647 + 1))
+-- !query schema
+struct
+-- !query output
+-2147483648
+
+
+-- !query
+SELECT try_multiply(1L, (9223372036854775807L + 1L))
+-- !query schema
+struct
+-- !query output
+-9223372036854775808
+
+
+-- !query
+SELECT try_multiply(1, 1.0 / 0.0)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
 -- !query
 SELECT try_multiply(interval 2 year, 2)
 -- !query schema
diff --git a/sql/core/src/test/resources/sql-tests/results/url-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/url-functions.sql.out
index 44f4682e27..748904e9b2 100644
--- a/sql/core/src/test/resources/sql-tests/results/url-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/url-functions.sql.out
@@ -105,8 +105,7 @@ org.apache.spark.SparkIllegalArgumentException
   "errorClass" : "CANNOT_DECODE_URL",
   "sqlState" : "42000",
   "messageParameters" : {
-    "url" : "http%3A%2F%2spark.apache.org",
-    "details" : "URLDecoder: Illegal hex characters in escape (%) pattern - For input string: \"2s\""
+    "url" : "http%3A%2F%2spark.apache.org"
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
index 36989efbe8..9c442456ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala
@@ -17,8 +17,11 @@
 package org.apache.spark.sql
+import java.util.concurrent.atomic.AtomicInteger
+
 import org.apache.commons.math3.stat.inference.ChiSquareTest
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -68,4 +71,42 @@ class ConfigBehaviorSuite extends QueryTest with SharedSparkSession {
     }
   }
+  test("SPARK-40211: customize initialNumPartitions for take") {
+    val totalElements = 100
+    val numToTake = 50
+    import scala.language.reflectiveCalls
+    val jobCountListener = new SparkListener {
+      private var count: AtomicInteger = new AtomicInteger(0)
+      def getCount: Int = count.get
+      def reset(): Unit = count.set(0)
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        count.incrementAndGet()
+      }
+    }
+    spark.sparkContext.addSparkListener(jobCountListener)
+    val df = spark.range(0, totalElements, 1, totalElements)
+
+    // with default LIMIT_INITIAL_NUM_PARTITIONS = 1, expecting multiple jobs
+    df.take(numToTake)
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+    jobCountListener.reset()
+    df.tail(numToTake)
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+    assert(jobCountListener.getCount > 1)
+
+    // setting LIMIT_INITIAL_NUM_PARTITIONS to a large number (1000), expecting only 1 job
+
+    withSQLConf(SQLConf.LIMIT_INITIAL_NUM_PARTITIONS.key -> "1000") {
+      jobCountListener.reset()
+      df.take(numToTake)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobCountListener.getCount == 1)
+      jobCountListener.reset()
+      df.tail(numToTake)
+      spark.sparkContext.listenerBus.waitUntilEmpty()
+      assert(jobCountListener.getCount == 1)
+    }
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cbd65ede05..0854c6ba45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4003,6 +4003,61 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     }
   }
+  test("SPARK-40247: Fix BitSet equals") {
+    withTable("td") {
+      testData
+        .withColumn("bucket", $"key" % 3)
+        .write
+        .mode(SaveMode.Overwrite)
+        .bucketBy(2, "bucket")
+        .format("parquet")
+        .saveAsTable("td")
+      val df = sql(
+        """
+          |SELECT t1.key, t2.key, t3.key
+          |FROM td AS t1
+          |JOIN td AS t2 ON t2.key = t1.key
+          |JOIN td AS t3 ON t3.key = t2.key
+          |WHERE t1.bucket = 1 AND t2.bucket = 1 AND t3.bucket = 1
+          |""".stripMargin)
+      df.collect()
+      val reusedExchanges = collect(df.queryExecution.executedPlan) {
+        case r: ReusedExchangeExec => r
+      }
+      assert(reusedExchanges.size == 1)
+    }
+  }
+
+  test("SPARK-40245: Fix FileScan canonicalization when partition or data filter columns are not " +
+    "read") {
+    withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        spark.range(5)
+          .withColumn("p", $"id" % 2)
+          .write
+          .mode("overwrite")
+          .partitionBy("p")
+          .parquet(path.toString)
+        withTempView("t") {
+          spark.read.parquet(path.toString).createOrReplaceTempView("t")
+          val df = sql(
+            """
+              |SELECT t1.id, t2.id, t3.id
+              |FROM t AS t1
+              |JOIN t AS t2 ON t2.id = t1.id
+              |JOIN t AS t3 ON t3.id = t2.id
+              |WHERE t1.p = 1 AND t2.p = 1 AND t3.p = 1
+              |""".stripMargin)
+          df.collect()
+          val reusedExchanges = collect(df.queryExecution.executedPlan) {
+            case r: ReusedExchangeExec => r
+          }
+          assert(reusedExchanges.size == 1)
+        }
+      }
+    }
+  }
+
   test("SPARK-35331: Fix resolving original expression in RepartitionByExpression after aliased") {
     Seq("CLUSTER", "DISTRIBUTE").foreach { keyword =>
       Seq("a", "substr(a, 0, 3)").foreach { expr =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 102c971d6f..bcdb66bab3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -735,7 +735,7 @@ class BrokenColumnarAdd(
     left: ColumnarExpression,
     right: ColumnarExpression,
     failOnError: Boolean = false)
-  extends Add(left, right, failOnError) with ColumnarExpression {
+  extends Add(left, right, EvalMode.fromBoolean(failOnError)) with ColumnarExpression {
   override def supportsColumnar(): Boolean = left.supportsColumnar && right.supportsColumnar
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
index 26baec90f3..7966add773 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.connector
 import java.util.Collections
 import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal}
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning}
 import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BucketFunction, StringSelfFunction, UnboundBucketFunction, UnboundStringSelfFunction}
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
-import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder}
+import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.expressions.LogicalExpressions._
 import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
@@ -36,13 +38,21 @@ import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger}
 import org.apache.spark.sql.test.SQLTestData.TestData
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.QueryExecutionListener
 class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase {
   import testImplicits._
+  before {
+    Seq(UnboundBucketFunction, UnboundStringSelfFunction).foreach { f =>
+      catalog.createFunction(Identifier.of(Array.empty, f.name()), f)
+    }
+  }
+
   after {
+    catalog.clearTables()
+    catalog.clearFunctions()
     spark.sessionState.catalogManager.reset()
   }
@@ -987,6 +997,95 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
     }
   }
+  test("clustered distribution and local sort contains v2 function: append") {
+    checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("append")
+  }
+
+  test("clustered distribution and local sort contains v2 function: overwrite") {
+    checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwrite")
+  }
+
+  test("clustered distribution and local sort contains v2 function: overwriteDynamic") {
+    checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases("overwriteDynamic")
+  }
+
+  test("clustered distribution and local sort contains v2 function with numPartitions: append") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("append", Some(10))
+  }
+
+  test("clustered distribution and local sort contains v2 function with numPartitions: " +
+    "overwrite") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("overwrite", Some(10))
+  }
+
+  test("clustered distribution and local sort contains v2 function with numPartitions: " +
+    "overwriteDynamic") {
+    checkClusteredDistributionAndLocalSortContainsV2Function("overwriteDynamic", Some(10))
+  }
+
+  private def checkClusteredDistributionAndLocalSortContainsV2FunctionInVariousCases(
+      cmd: String): Unit = {
+    Seq(true, false).foreach { distributionStrictlyRequired =>
+      Seq(true, false).foreach { dataSkewed =>
+        Seq(true, false).foreach { coalesce =>
+          checkClusteredDistributionAndLocalSortContainsV2Function(
+            cmd, None, distributionStrictlyRequired, dataSkewed, coalesce)
+        }
+      }
+    }
+  }
+
+  private def checkClusteredDistributionAndLocalSortContainsV2Function(
+      command: String,
+      targetNumPartitions: Option[Int] = None,
+      distributionStrictlyRequired: Boolean = true,
+      dataSkewed: Boolean = false,
+      coalesce: Boolean = false): Unit = {
+    val tableOrdering = Array[SortOrder](
+      sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST),
+      sort(
+        BucketTransform(LiteralValue(10, IntegerType), Seq(FieldReference("id"))),
+        SortDirection.DESCENDING,
+        NullOrdering.NULLS_FIRST)
+    )
+    val tableDistribution = Distributions.clustered(Array(
+      ApplyTransform("string_self", Seq(FieldReference("data")))))
+
+    val writeOrdering = Seq(
+      catalyst.expressions.SortOrder(
+        attr("data"),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
+      ),
+      catalyst.expressions.SortOrder(
+        ApplyFunctionExpression(BucketFunction, Seq(Literal(10), Cast(attr("id"), LongType))),
+        catalyst.expressions.Descending,
+        catalyst.expressions.NullsFirst,
+        Seq.empty
+      )
+    )
+
+    val writePartitioningExprs = Seq(
+      ApplyFunctionExpression(StringSelfFunction, Seq(attr("data"))))
+    val writePartitioning = if (!coalesce) {
+      clusteredWritePartitioning(writePartitioningExprs, targetNumPartitions)
+    } else {
+      clusteredWritePartitioning(writePartitioningExprs, Some(1))
+    }
+
+    checkWriteRequirements(
+      tableDistribution,
+      tableOrdering,
+      targetNumPartitions,
+      expectedWritePartitioning = writePartitioning,
+      expectedWriteOrdering = writeOrdering,
+      writeCommand = command,
+      distributionStrictlyRequired = distributionStrictlyRequired,
+      dataSkewed = dataSkewed,
+      coalesce = coalesce)
+  }
+
   // scalastyle:off argcount
   private def checkWriteRequirements(
       tableDistribution: Distribution,
@@ -1209,12 +1308,20 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase
       if (skewSplit) {
        assert(actualPartitioning.numPartitions > conf.numShufflePartitions)
      } else {
-        assert(actualPartitioning == expectedPartitioning, "partitioning must match")
+        (actualPartitioning, expectedPartitioning) match {
+          case (actual: catalyst.expressions.Expression, expected: catalyst.expressions.Expression) =>
+            assert(actual semanticEquals expected, "partitioning must match")
+          case (actual, expected) =>
+            assert(actual == expected, "partitioning must match")
+        }
      }
      val actualOrdering = plan.outputOrdering
      val expectedOrdering = ordering.map(resolveAttrs(_, plan))
-     assert(actualOrdering == expectedOrdering, "ordering must match")
+     assert(actualOrdering.length == expectedOrdering.length)
+     (actualOrdering zip expectedOrdering).foreach { case (actual, expected) =>
+       assert(actual semanticEquals expected, "ordering must match")
+     }
   }
   // executes a write operation and keeps the executed physical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 1994874d32..9277e8d059 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,7 +16,9 @@ */
 package org.apache.spark.sql.connector.catalog.functions
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 object UnboundYearsFunction extends UnboundFunction {
   override def bind(inputType: StructType): BoundFunction = {
@@ -70,9 +72,30 @@ object UnboundBucketFunction extends UnboundFunction {
   override def name(): String = "bucket"
 }
-object BucketFunction extends BoundFunction {
-  override def inputTypes(): Array[DataType] = Array(IntegerType, IntegerType)
+object BucketFunction extends ScalarFunction[Int] {
+  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
   override def resultType(): DataType = IntegerType
   override def name(): String = "bucket"
   override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): Int = {
+    (input.getLong(1) % input.getInt(0)).toInt
+  }
+}
+
+object UnboundStringSelfFunction extends UnboundFunction {
+  override def bind(inputType: StructType): BoundFunction = StringSelfFunction
+  override def description(): String = name()
+  override def name(): String = "string_self"
+}
+
+object StringSelfFunction extends ScalarFunction[UTF8String] {
+  override def inputTypes(): Array[DataType] = Array(StringType)
+  override def resultType(): DataType = StringType
+  override def name(): String = "string_self"
+  override def canonicalName(): String = name()
+  override def toString: String = name()
+  override def produceResult(input: InternalRow): UTF8String = {
+    input.getUTF8String(0)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
index 38f016c2b6..d9c3848d3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala
@@ -24,7 +24,7 @@ import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression}
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, EvalMode, Expression}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryCatalog}
@@ -104,7 +104,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark {
       left: Expression,
       right: Expression,
       override val nullable: Boolean) extends BinaryArithmetic {
-    override protected val failOnError: Boolean = false
+    protected override val evalMode: EvalMode.Value = EvalMode.LEGACY
     override def inputType: AbstractDataType = NumericType
     override def symbol: String = "+"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
index a672a3fb1b..4a0c88be42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
@@ -124,7 +124,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte
   def createSortedAggBufferIterator(
       hashMap: ObjectAggregationMap): KVIterator[UnsafeRow, UnsafeRow] = {
-    val sortedIterator = hashMap.iterator.toList.sortBy(_.groupingKey.getInt(0)).iterator
+    val sortedIterator = hashMap.destructiveIterator().toList.sortBy(_.groupingKey.getInt(0))
+      .iterator
     new KVIterator[UnsafeRow, UnsafeRow] {
       var key: UnsafeRow = null
       var value: UnsafeRow = null
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index fb5595322f..6151e1d7cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -1095,6 +1095,23 @@ abstract class ParquetPartitionDiscoverySuite
       checkAnswer(readback, Row(0, "AA") :: Row(1, "-0") :: Nil)
     }
   }
+
+  test("SPARK-40212: SparkSQL castPartValue does not properly handle byte, short, float") {
+    withTempDir { dir =>
+      val data = Seq[(Int, Byte, Short, Float)](
+        (1, 2, 3, 4.0f)
+      )
+      data.toDF("a", "b", "c", "d")
+        .write
+        .mode("overwrite")
+        .partitionBy("b", "c", "d")
+        .parquet(dir.getCanonicalPath)
+      val res = spark.read
+        .schema("a INT, b BYTE, c SHORT, d FLOAT")
+        .parquet(dir.getCanonicalPath)
+      checkAnswer(res, Seq(Row(1, 2, 3, 4.0f)))
+    }
+  }
 }
 class ParquetV1PartitionDiscoverySuite extends ParquetPartitionDiscoverySuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 101315ccb7..106802a54c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -228,11 +228,6 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
     // Do not check these expressions, because these expressions override the eval method
     val ignoreSet = Set(
-      // Extend NullIntolerant and avoid evaluating input1 if input2 is 0
-      classOf[IntegralDivide],
-      classOf[Divide],
-      classOf[Remainder],
-      classOf[Pmod],
       // Throws an exception, even if input is null
       classOf[RaiseError]
     )
@@ -242,6 +237,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
       .filterNot(c => ignoreSet.exists(_.getName.equals(c)))
       .map(name => Utils.classForName(name))
      .filterNot(classOf[NonSQLExpression].isAssignableFrom)
+      // BinaryArithmetic overrides the eval method
+      .filterNot(classOf[BinaryArithmetic].isAssignableFrom)
     exprTypesToCheck.foreach { superClass =>
       candidateExprsToCheck.filter(superClass.isAssignableFrom).foreach { clazz =>