diff --git a/appveyor.yml b/appveyor.yml index 1a2aef0d3b..fdb247d5d4 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -28,6 +28,7 @@ only_commits: files: - appveyor.yml - dev/appveyor-install-dependencies.ps1 + - build/spark-build-info.ps1 - R/ - sql/core/src/main/scala/org/apache/spark/sql/api/r/ - core/src/main/scala/org/apache/spark/api/r/ diff --git a/build/spark-build-info.ps1 b/build/spark-build-info.ps1 new file mode 100644 index 0000000000..43db882334 --- /dev/null +++ b/build/spark-build-info.ps1 @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script generates the build info for Spark and places it into the spark-version-info.properties file. +# Arguments: +# ResourceDir - The target directory where the properties file will be created. [./core/target/extra-resources] +# SparkVersion - The current version of Spark + +param( + # The resource directory. + [Parameter(Position = 0)] + [String] + $ResourceDir, + + # The Spark version. + [Parameter(Position = 1)] + [String] + $SparkVersion +) + +$null = New-Item -Type Directory -Force $ResourceDir +$SparkBuildInfoPath = $ResourceDir.TrimEnd('\').TrimEnd('/') + '\spark-version-info.properties' + +$SparkBuildInfoContent = +"version=$SparkVersion +user=$($Env:USERNAME) +revision=$(git rev-parse HEAD) +branch=$(git rev-parse --abbrev-ref HEAD) +date=$([DateTime]::UtcNow | Get-Date -UFormat +%Y-%m-%dT%H:%M:%SZ) +url=$(git config --get remote.origin.url)" + +Set-Content -Path $SparkBuildInfoPath -Value $SparkBuildInfoContent diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 4f070f02a1..cc4657efe3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -179,4 +179,18 @@ public static PooledByteBufAllocator createPooledByteBufAllocator( allowCache ? PooledByteBufAllocator.defaultUseCacheForAllThreads() : false ); } + + /** + * The ByteBuf allocator prefers to allocate direct ByteBufs if and only if both Spark allows creating + * direct ByteBufs and Netty enables directBufferPreferred.
+ */ + public static boolean preferDirectBufs(TransportConf conf) { + boolean allowDirectBufs; + if (conf.sharedByteBufAllocators()) { + allowDirectBufs = conf.preferDirectBufsForSharedByteBufAllocators(); + } else { + allowDirectBufs = conf.preferDirectBufs(); + } + return allowDirectBufs && PlatformDependent.directBufferPreferred(); + } } diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a4073969db..a77732bb8b 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -136,6 +136,11 @@ shade + + + + + diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index d0f38c1242..763b9abe4f 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -70,10 +70,6 @@ case class AvroScan( override def hashCode(): Int = super.hashCode() - override def description(): String = { - super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") - } - override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } diff --git a/connector/connect/README.md b/connector/connect/README.md index c009ff8d29..e4753eef0c 100644 --- a/connector/connect/README.md +++ b/connector/connect/README.md @@ -32,7 +32,7 @@ for example, compiling `connect` module on CentOS 6 or CentOS 7 which the defaul specifying the user-defined `protoc` and `protoc-gen-grpc-java` binary files as follows: ```bash -export CONNECT_PROTOC_EXEC_PATH=/path-to-protoc-exe +export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe export CONNECT_PLUGIN_EXEC_PATH=/path-to-protoc-gen-grpc-java-exe ./build/mvn -Phive -Puser-defined-protoc clean package ``` @@ -40,7 +40,7 @@ export CONNECT_PLUGIN_EXEC_PATH=/path-to-protoc-gen-grpc-java-exe or ```bash -export CONNECT_PROTOC_EXEC_PATH=/path-to-protoc-exe +export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe export CONNECT_PLUGIN_EXEC_PATH=/path-to-protoc-gen-grpc-java-exe ./build/sbt -Puser-defined-protoc clean package ``` @@ -82,7 +82,7 @@ To use the release version of Spark Connect: ```bash # Run all Spark Connect Python tests as a module. -./python/run-tests --module pyspark-connect +./python/run-tests --module pyspark-connect --parallelism 1 ``` diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml index 555afd5bc4..2d80d8215b 100644 --- a/connector/connect/common/pom.xml +++ b/connector/connect/common/pom.xml @@ -193,7 +193,7 @@ user-defined-protoc - ${env.CONNECT_PROTOC_EXEC_PATH} + ${env.SPARK_PROTOC_EXEC_PATH} ${env.CONNECT_PLUGIN_EXEC_PATH} @@ -203,7 +203,7 @@ protobuf-maven-plugin 0.6.1 - ${connect.protoc.executable.path} + ${spark.protoc.executable.path} grpc-java ${connect.plugin.executable.path} src/main/protobuf diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index ec4490e845..6c0facbfee 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -43,7 +43,11 @@ message Expression { Expression expr = 1; // (Required) the data type that the expr to be casted to. 
- DataType cast_to_type = 2; + oneof cast_to_type { + DataType type = 2; + // If this is set, Server will use Catalyst parser to parse this string to DataType. + string type_str = 3; + } } message Literal { diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index ece8767c06..20b067ddb4 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -20,6 +20,7 @@ syntax = 'proto3'; package spark.connect; import "spark/connect/expressions.proto"; +import "spark/connect/types.proto"; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; @@ -54,6 +55,7 @@ message Relation { Tail tail = 22; WithColumns with_columns = 23; Hint hint = 24; + Unpivot unpivot = 25; // NA functions NAFill fill_na = 90; @@ -304,6 +306,17 @@ message LocalRelation { // Local collection data serialized into Arrow IPC streaming format which contains // the schema of the data. bytes data = 1; + + // (Optional) The user-provided schema. + // + // The Server side will update the column names and data types according to this schema. + oneof schema { + + DataType datatype = 2; + + // Server will use Catalyst parser to parse this string to DataType. + string datatype_str = 3; + } } // Relation of type [[Sample]] that samples a fraction of the dataset. @@ -570,3 +583,21 @@ message Hint { // (Optional) Hint parameters. repeated Expression.Literal parameters = 3; } + +// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set. +message Unpivot { + // (Required) The input relation. + Relation input = 1; + + // (Required) Id columns. + repeated Expression ids = 2; + + // (Optional) Value columns to unpivot. + repeated Expression values = 3; + + // (Required) Name of the variable column. + string variable_column_name = 4; + + // (Required) Name of the value column.
+ string value_column_name = 5; +} diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml index 43bb10e7f5..4c21d70098 100644 --- a/connector/connect/server/pom.xml +++ b/connector/connect/server/pom.xml @@ -55,6 +55,12 @@ org.apache.spark spark-connect-common_${scala.binary.version} ${project.version} + + + com.google.guava + guava + + org.apache.spark @@ -106,6 +112,12 @@ spark-tags_${scala.binary.version} ${project.version} provided + + + com.google.guava + guava + + + + com.google.guava + guava + ${guava.version} + compile + + + com.google.guava + failureaccess + ${guava.failureaccess.version} + com.google.protobuf protobuf-java ${protobuf.version} compile + + io.grpc + grpc-netty + ${io.grpc.version} + + + io.grpc + grpc-protobuf + ${io.grpc.version} + + + io.grpc + grpc-services + ${io.grpc.version} + + + io.grpc + grpc-stub + ${io.grpc.version} + + + io.netty + netty-codec-http2 + ${netty.version} + provided + + + io.netty + netty-handler-proxy + ${netty.version} + provided + + + io.netty + netty-transport-native-unix-common + ${netty.version} + provided + + + org.apache.tomcat + annotations-api + ${tomcat.annotations.api.version} + provided + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index a17e9784ec..60fdd96401 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -38,8 +38,10 @@ private[spark] object Connect { val CONNECT_GRPC_ARROW_MAX_BATCH_SIZE = ConfigBuilder("spark.connect.grpc.arrow.maxBatchSize") - .doc("When using Apache Arrow, limit the maximum size of one arrow batch that " + - "can be sent from server side to client side.") + .doc( + "When using Apache Arrow, limit the maximum size of one arrow batch that " + + "can be sent from server side to client side. 
Currently, we conservatively use 70% " + + "of it because the size is not accurate but estimated.") .version("3.4.0") .bytesConf(ByteUnit.MiB) .createWithDefaultString("4m") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 44baf40781..545c2aaaf0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -96,7 +96,17 @@ package object dsl { Expression.Cast .newBuilder() .setExpr(expr) - .setCastToType(dataType)) + .setType(dataType)) + .build() + + def cast(dataType: String): Expression = + Expression + .newBuilder() + .setCast( + Expression.Cast + .newBuilder() + .setExpr(expr) + .setTypeStr(dataType)) .build() } @@ -709,6 +719,53 @@ package object dsl { .build() } + def unpivot( + ids: Seq[Expression], + values: Seq[Expression], + variableColumnName: String, + valueColumnName: String): Relation = { + Relation + .newBuilder() + .setUnpivot( + Unpivot + .newBuilder() + .setInput(logicalPlan) + .addAllIds(ids.asJava) + .addAllValues(values.asJava) + .setVariableColumnName(variableColumnName) + .setValueColumnName(valueColumnName)) + .build() + } + + def unpivot( + ids: Seq[Expression], + variableColumnName: String, + valueColumnName: String): Relation = { + Relation + .newBuilder() + .setUnpivot( + Unpivot + .newBuilder() + .setInput(logicalPlan) + .addAllIds(ids.asJava) + .setVariableColumnName(variableColumnName) + .setValueColumnName(valueColumnName)) + .build() + } + + def melt( + ids: Seq[Expression], + values: Seq[Expression], + variableColumnName: String, + valueColumnName: String): Relation = + unpivot(ids, values, variableColumnName, valueColumnName) + + def melt( + ids: Seq[Expression], + variableColumnName: String, + valueColumnName: String): Relation = + unpivot(ids, variableColumnName, valueColumnName) + private def createSetOperation( left: Relation, right: Relation, diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index af5d9abc51..ba5ceed452 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.{logical, Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} -import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -95,6 +95,7 @@ class SparkConnectPlanner(session: SparkSession) { 
transformRenameColumnsByNameToNameMap(rel.getRenameColumnsByNameToNameMap) case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns) case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint) + case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -309,6 +310,34 @@ class SparkConnectPlanner(session: SparkSession) { UnresolvedHint(rel.getName, params, transformRelation(rel.getInput)) } + private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = { + val ids = rel.getIdsList.asScala.toArray.map { expr => + Column(transformExpression(expr)) + } + + if (rel.getValuesList.isEmpty) { + Unpivot( + Some(ids.map(_.named)), + None, + None, + rel.getVariableColumnName, + Seq(rel.getValueColumnName), + transformRelation(rel.getInput)) + } else { + val values = rel.getValuesList.asScala.toArray.map { expr => + Column(transformExpression(expr)) + } + + Unpivot( + Some(ids.map(_.named)), + Some(values.map(v => Seq(v.named))), + None, + rel.getVariableColumnName, + Seq(rel.getValueColumnName), + transformRelation(rel.getInput)) + } + } + private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = { if (!rel.hasInput) { throw InvalidPlanInput("Deduplicate needs a plan input") @@ -340,6 +369,21 @@ class SparkConnectPlanner(session: SparkSession) { } } + private def parseDatatypeString(sqlText: String): DataType = { + val parser = session.sessionState.sqlParser + try { + parser.parseTableSchema(sqlText) + } catch { + case _: ParseException => + try { + parser.parseDataType(sqlText) + } catch { + case _: ParseException => + parser.parseDataType(s"struct<${sqlText.trim}>") + } + } + } + private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = { val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator( Iterator(rel.getData.toByteArray), @@ -349,7 +393,28 @@ class SparkConnectPlanner(session: SparkSession) { } val attributes = structType.toAttributes val proj = UnsafeProjection.create(attributes, attributes) - new logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq) + val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq) + + if (!rel.hasDatatype && !rel.hasDatatypeStr) { + return relation + } + + val schemaType = if (rel.hasDatatype) { + DataTypeProtoConverter.toCatalystType(rel.getDatatype) + } else { + parseDatatypeString(rel.getDatatypeStr) + } + + val schemaStruct = schemaType match { + case s: StructType => s + case d => StructType(Seq(StructField("value", d))) + } + + Dataset + .ofRows(session, logicalPlan = relation) + .toDF(schemaStruct.names: _*) + .to(schemaStruct) + .logicalPlan } private def transformReadRel(rel: proto.Read): LogicalPlan = { @@ -518,9 +583,16 @@ class SparkConnectPlanner(session: SparkSession) { } private def transformCast(cast: proto.Expression.Cast): Expression = { - Cast( - transformExpression(cast.getExpr), - DataTypeProtoConverter.toCatalystType(cast.getCastToType)) + cast.getCastToTypeCase match { + case proto.Expression.Cast.CastToTypeCase.TYPE => + Cast( + transformExpression(cast.getExpr), + DataTypeProtoConverter.toCatalystType(cast.getType)) + case _ => + Cast( + transformExpression(cast.getExpr), + session.sessionState.sqlParser.parseDataType(cast.getTypeStr)) + } } private def transformSetOperation(u: 
proto.SetOperation): LogicalPlan = { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala index dc8254c47f..7c8ee6209a 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverterSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.connect.planner -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystValue, toConnectProtoValue} -class LiteralValueProtoConverterSuite extends AnyFunSuite { +class LiteralValueProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite test("basic proto value and catalyst value conversion") { val values = Array(null, true, 1.toByte, 1.toShort, 1, 1L, 1.1d, 1.1f, "spark") diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 6d36ea9a63..8611ba45f7 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -550,12 +550,44 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { connectTestRelation.select("id".protoAttr.cast( proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build())), sparkTestRelation.select(col("id").cast(StringType))) + + comparePlans( + connectTestRelation.select("id".protoAttr.cast("string")), + sparkTestRelation.select(col("id").cast("string"))) } test("Test Hint") { comparePlans(connectTestRelation.hint("COALESCE", 3), sparkTestRelation.hint("COALESCE", 3)) } + test("Test Unpivot") { + val connectPlan0 = + connectTestRelation.unpivot(Seq("id".protoAttr), Seq("name".protoAttr), "variable", "value") + val sparkPlan0 = + sparkTestRelation.unpivot(Array(Column("id")), Array(Column("name")), "variable", "value") + comparePlans(connectPlan0, sparkPlan0) + + val connectPlan1 = + connectTestRelation.unpivot(Seq("id".protoAttr), "variable", "value") + val sparkPlan1 = + sparkTestRelation.unpivot(Array(Column("id")), "variable", "value") + comparePlans(connectPlan1, sparkPlan1) + } + + test("Test Melt") { + val connectPlan0 = + connectTestRelation.melt(Seq("id".protoAttr), Seq("name".protoAttr), "variable", "value") + val sparkPlan0 = + sparkTestRelation.melt(Array(Column("id")), Array(Column("name")), "variable", "value") + comparePlans(connectPlan0, sparkPlan0) + + val connectPlan1 = + connectTestRelation.melt(Seq("id".protoAttr), "variable", "value") + val sparkPlan1 = + sparkTestRelation.melt(Array(Column("id")), "variable", "value") + comparePlans(connectPlan1, sparkPlan1) + } + private def createLocalRelationProtoByAttributeReferences( attrs: Seq[AttributeReference]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 2bde50d01e..bc202b1b83 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -26,9 +26,9 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mysql:5.7.36): + * To run this test suite for a specific version (e.g., mysql:8.0.31): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.0.31 * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.MySQLIntegrationSuite" * }}} @@ -36,7 +36,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 2562ee78ec..d3229ba50e 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:14.0): + * To run this test suite for a specific version (e.g., postgres:15.1): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:14.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:14.0-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index c46a845a74..4debe24754 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:14.0): + * To run this test suite for a 
specific version (e.g., postgres:15.1): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:14.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 * ./build/sbt -Pdocker-integration-tests "testOnly *PostgresKrbIntegrationSuite" * }}} */ @@ -37,7 +37,7 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "postgres.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:14.0") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 6e76b74c7d..072fdbb3f3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mysql:5.7.36): + * To run this test suite for a specific version (e.g., mysql:8.0.31): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.0.31 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLIntegrationSuite" * }}} */ @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index d8dee61d70..b73e2b8fd2 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -28,16 +28,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., mysql:5.7.36): + * To run this test suite for a specific version (e.g., mysql:8.0.31): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:8.0.31 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLNamespaceSuite" * }}} */ @DockerTest class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) diff --git 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 1ff7527c97..db3a80ffea 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:14.0): + * To run this test suite for a specific version (e.g., postgres:15.1): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:14.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresIntegrationSuite" * }}} */ @@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:14.0-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index 33190103d6..8c52571775 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -26,16 +26,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:14.0): + * To run this test suite for a specific version (e.g., postgres:15.1): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:14.0 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:15.1 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:14.0-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:15.1-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index 50e79e03a7..b2500a2dbf 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -62,7 +62,8 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte Map.empty[String, String] } catalog.createNamespace(Array("foo"), 
commentMap.asJava) - assert(catalog.listNamespaces() === listNamespaces(Array("foo"))) + assert(catalog.listNamespaces().map(_.toSet).toSet === + listNamespaces(Array("foo")).map(_.toSet).toSet) assert(catalog.listNamespaces(Array("foo")) === Array()) assert(catalog.namespaceExists(Array("foo")) === true) diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 77bc658a1e..a371d25899 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -85,8 +85,6 @@ private[kafka010] class KafkaMicroBatchStream( private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) - private var endPartitionOffsets: KafkaSourceOffset = _ - private var latestPartitionOffsets: PartitionOffsetMap = _ private var allDataForTriggerAvailableNow: PartitionOffsetMap = _ @@ -114,7 +112,7 @@ private[kafka010] class KafkaMicroBatchStream( } override def reportLatestOffset(): Offset = { - KafkaSourceOffset(latestPartitionOffsets) + Option(KafkaSourceOffset(latestPartitionOffsets)).filterNot(_.partitionToOffsets.isEmpty).orNull } override def latestOffset(): Offset = { @@ -163,8 +161,7 @@ private[kafka010] class KafkaMicroBatchStream( }.getOrElse(latestPartitionOffsets) } - endPartitionOffsets = KafkaSourceOffset(offsets) - endPartitionOffsets + Option(KafkaSourceOffset(offsets)).filterNot(_.partitionToOffsets.isEmpty).orNull } /** Checks if we need to skip this trigger based on minOffsetsPerTrigger & maxTriggerDelay */ diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index a7840ef105..a9ee5b6462 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -177,7 +177,7 @@ private[kafka010] class KafkaSource( kafkaReader.fetchLatestOffsets(currentOffsets) } - latestPartitionOffsets = Some(latest) + latestPartitionOffsets = if (latest.isEmpty) None else Some(latest) val limits: Seq[ReadLimit] = limit match { case rows: CompositeReadLimit => rows.getReadLimits @@ -213,7 +213,7 @@ private[kafka010] class KafkaSource( } currentPartitionOffsets = Some(offsets) logDebug(s"GetOffset: ${offsets.toSeq.map(_.toString).sorted}") - KafkaSourceOffset(offsets) + Option(KafkaSourceOffset(offsets)).filterNot(_.partitionToOffsets.isEmpty).orNull } /** Checks if we need to skip this trigger based on minOffsetsPerTrigger & maxTriggerDelay */ diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index af66ecd21c..bf0e72cd32 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -627,6 +627,45 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { ) } + test("SPARK-41375: empty partitions should not record to latest offset") { + val topicPrefix = newTopic() + val topic = topicPrefix + 
"-good" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("kafka.request.timeout.ms", "3000") + .option("kafka.default.api.timeout.ms", "3000") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + true + }, + AssertOnQuery { q => + val latestOffset: Option[(Long, OffsetSeq)] = q.offsetLog.getLatest + latestOffset.exists { offset => + !offset._2.offsets.exists(_.exists(_.json == "{}")) + } + } + ) + } + test("subscribe topic by pattern with topic recreation between batches") { val topicPrefix = newTopic() val topic = topicPrefix + "-good" diff --git a/connector/protobuf/README.md b/connector/protobuf/README.md index 4fc2895049..9dd0a2457d 100644 --- a/connector/protobuf/README.md +++ b/connector/protobuf/README.md @@ -21,15 +21,14 @@ for example, compiling `protobuf` module on CentOS 6 or CentOS 7 which the defau specifying the user-defined `protoc` binary files as follows: ```bash -export PROTOBUF_PROTOC_EXEC_PATH=/path-to-protoc-exe +export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe ./build/mvn -Phive -Puser-defined-protoc clean package ``` or ```bash -export PROTOBUF_PROTOC_EXEC_PATH=/path-to-protoc-exe -export CONNECT_PLUGIN_EXEC_PATH=/path-to-protoc-gen-grpc-java-exe +export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe ./build/sbt -Puser-defined-protoc clean package ``` diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index 3036fcbf25..56f222c401 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -107,6 +107,14 @@ + + + *:* + + google/protobuf/** + + + @@ -147,7 +155,7 @@ user-defined-protoc - ${env.PROTOBUF_PROTOC_EXEC_PATH} + ${env.SPARK_PROTOC_EXEC_PATH} @@ -165,7 +173,7 @@ com.google.protobuf:protoc:${protobuf.version} ${protobuf.version} - ${protobuf.protoc.executable.path} + ${spark.protoc.executable.path} src/test/resources/protobuf diff --git a/core/pom.xml b/core/pom.xml index a9b40acf5a..fb032064ed 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -558,12 +558,42 @@ maven-antrun-plugin + choose-shell-and-script + validate + + run + + + true + + + + + + + + + + + + Shell to use for generating spark-version-info.properties file = + ${shell} + + Script to use for generating spark-version-info.properties file = + ${spark-build-info-script} + + + + + + generate-spark-build-info generate-resources - - + + @@ -629,10 +659,34 @@ true + org.spark-project.spark:unused + org.eclipse.jetty:jetty-io + org.eclipse.jetty:jetty-http + org.eclipse.jetty:jetty-proxy + org.eclipse.jetty:jetty-client + org.eclipse.jetty:jetty-continuation + org.eclipse.jetty:jetty-servlet + org.eclipse.jetty:jetty-servlets + org.eclipse.jetty:jetty-plus + org.eclipse.jetty:jetty-security + org.eclipse.jetty:jetty-util + org.eclipse.jetty:jetty-server + com.google.guava:guava com.google.protobuf:* + + org.eclipse.jetty + ${spark.shade.packageName}.jetty + + org.eclipse.jetty.** + + + + com.google.common + ${spark.shade.packageName}.guava + 
com.google.protobuf ${spark.shade.packageName}.spark-core.protobuf @@ -643,26 +697,6 @@ - - com.github.os72 - protoc-jar-maven-plugin - ${protoc-jar-maven-plugin.version} - - - generate-sources - - run - - - com.google.protobuf:protoc:${protobuf.version} - ${protobuf.version} - - src/main/protobuf - - - - - @@ -713,6 +747,69 @@ + + default-protoc + + + !skipDefaultProtoc + + + + + + com.github.os72 + protoc-jar-maven-plugin + ${protoc-jar-maven-plugin.version} + + + generate-sources + + run + + + com.google.protobuf:protoc:${protobuf.version} + ${protobuf.version} + + src/main/protobuf + + + + + + + + + + user-defined-protoc + + ${env.SPARK_PROTOC_EXEC_PATH} + + + + + com.github.os72 + protoc-jar-maven-plugin + ${protoc-jar-maven-plugin.version} + + + generate-sources + + run + + + com.google.protobuf:protoc:${protobuf.version} + ${protobuf.version} + ${spark.protoc.executable.path} + + src/main/protobuf + + + + + + + + diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 19ab5ada2b..25362d5893 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -5,6 +5,12 @@ ], "sqlState" : "42000" }, + "AMBIGUOUS_LATERAL_COLUMN_ALIAS" : { + "message" : [ + "Lateral column alias is ambiguous and has matches." + ], + "sqlState" : "42000" + }, "AMBIGUOUS_REFERENCE" : { "message" : [ "Reference is ambiguous, could be: ." @@ -109,6 +115,11 @@ "The column already exists. Consider to choose another name or rename the existing column." ] }, + "COLUMN_NOT_FOUND" : { + "message" : [ + "The column cannot be found. Verify the spelling and correctness of the column name according to the SQL config ." + ] + }, "CONCURRENT_QUERY" : { "message" : [ "Another instance of this query was just started by a concurrent session." @@ -813,6 +824,12 @@ } } }, + "INVALID_TYPED_LITERAL" : { + "message" : [ + "The value of the typed literal is invalid: ." + ], + "sqlState" : "42000" + }, "INVALID_WHERE_CONDITION" : { "message" : [ "The WHERE condition contains invalid expressions: .", @@ -932,7 +949,7 @@ }, "NUM_COLUMNS_MISMATCH" : { "message" : [ - " can only be performed on tables with the same number of columns, but the first table has columns and the table has columns." + " can only be performed on inputs with the same number of columns, but the first input has columns and the input has columns." ] }, "ORDER_BY_POS_OUT_OF_RANGE" : { @@ -1257,6 +1274,11 @@ "AES- with the padding by the function." ] }, + "ANALYZE_UNCACHED_TEMP_VIEW" : { + "message" : [ + "The ANALYZE TABLE FOR COLUMNS command can operate on temporary views that have been cached already. Consider to cache the view ." + ] + }, "CATALOG_OPERATION" : { "message" : [ "Catalog does not support ." @@ -1599,16 +1621,6 @@ "Function trim doesn't support with type . Please use BOTH, LEADING or TRAILING as trim type." ] }, - "_LEGACY_ERROR_TEMP_0019" : { - "message" : [ - "Cannot parse the value: ." - ] - }, - "_LEGACY_ERROR_TEMP_0020" : { - "message" : [ - "Cannot parse the INTERVAL value: ." - ] - }, "_LEGACY_ERROR_TEMP_0022" : { "message" : [ "." @@ -2091,11 +2103,6 @@ " does not support nested column: ." ] }, - "_LEGACY_ERROR_TEMP_1061" : { - "message" : [ - "Column does not exist." - ] - }, "_LEGACY_ERROR_TEMP_1065" : { "message" : [ "`` is not a valid name for tables/databases. Valid names only contain alphabet characters, numbers and _." @@ -2868,11 +2875,6 @@ "Partition spec is invalid. 
The spec () must match the partition spec () defined in table ''." ] }, - "_LEGACY_ERROR_TEMP_1234" : { - "message" : [ - "Temporary view is not cached for analyzing columns." - ] - }, "_LEGACY_ERROR_TEMP_1235" : { "message" : [ "Column in table is of type , and Spark does not support statistics collection on this column type." diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index fe738f4149..825d9ce779 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -25,7 +25,7 @@ import scala.concurrent.Future import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config.Network -import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv} +import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -65,7 +65,7 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) * Lives in the driver to receive heartbeats from executors.. */ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) - extends SparkListener with IsolatedRpcEndpoint with Logging { + extends SparkListener with IsolatedThreadSafeRpcEndpoint with Logging { def this(sc: SparkContext) = { this(sc, new SystemClock) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 783cf47df1..73acfedd8b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.lang.reflect.{InvocationTargetException, UndeclaredThrowableException} import java.net.{URI, URL} +import java.nio.file.Files import java.security.PrivilegedExceptionAction import java.text.ParseException import java.util.{ServiceLoader, UUID} @@ -383,43 +384,55 @@ private[spark] class SparkSubmit extends Logging { }.orNull if (isKubernetesClusterModeDriver) { - // Replace with the downloaded local jar path to avoid propagating hadoop compatible uris. - // Executors will get the jars from the Spark file server. - // Explicitly download the related files here - args.jars = localJars - val filesLocalFiles = Option(args.files).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf) - }.orNull - val archiveLocalFiles = Option(args.archives).map { uris => + // SPARK-33748: this mimics the behaviour of Yarn cluster mode. If the driver is running + // in cluster mode, the archives should be available in the driver's current working + // directory too. + // SPARK-33782 : This downloads all the files , jars , archiveFiles and pyfiles to current + // working directory + def downloadResourcesToCurrentDirectory(uris: String, isArchive: Boolean = false): + String = { val resolvedUris = Utils.stringToSeq(uris).map(Utils.resolveURI) - val localArchives = downloadFileList( + val localResources = downloadFileList( resolvedUris.map( UriBuilder.fromUri(_).fragment(null).build().toString).mkString(","), targetDir, sparkConf, hadoopConf) - - // SPARK-33748: this mimics the behaviour of Yarn cluster mode. 
If the driver is running - // in cluster mode, the archives should be available in the driver's current working - // directory too. - Utils.stringToSeq(localArchives).map(Utils.resolveURI).zip(resolvedUris).map { - case (localArchive, resolvedUri) => - val source = new File(localArchive.getPath) + Utils.stringToSeq(localResources).map(Utils.resolveURI).zip(resolvedUris).map { + case (localResources, resolvedUri) => + val source = new File(localResources.getPath) val dest = new File( ".", if (resolvedUri.getFragment != null) resolvedUri.getFragment else source.getName) logInfo( - s"Unpacking an archive $resolvedUri " + + s"Files $resolvedUri " + s"from ${source.getAbsolutePath} to ${dest.getAbsolutePath}") Utils.deleteRecursively(dest) - Utils.unpack(source, dest) - + if (isArchive) { + Utils.unpack(source, dest) + } else { + Files.copy(source.toPath, dest.toPath) + } // Keep the URIs of local files with the given fragments. UriBuilder.fromUri( - localArchive).fragment(resolvedUri.getFragment).build().toString + localResources).fragment(resolvedUri.getFragment).build().toString }.mkString(",") + } + + val filesLocalFiles = Option(args.files).map { + downloadResourcesToCurrentDirectory(_) + }.orNull + val jarsLocalJars = Option(args.jars).map { + downloadResourcesToCurrentDirectory(_) + }.orNull + val archiveLocalFiles = Option(args.archives).map { + downloadResourcesToCurrentDirectory(_, true) + }.orNull + val pyLocalFiles = Option(args.pyFiles).map { + downloadResourcesToCurrentDirectory(_) }.orNull args.files = filesLocalFiles args.archives = archiveLocalFiles - args.pyFiles = localPyFiles + args.pyFiles = pyLocalFiles + args.jars = jarsLocalJars } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a94e63656e..d8f33a0612 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -35,6 +35,8 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.util.NettyUtils import org.apache.spark.resource.ResourceInformation import org.apache.spark.resource.ResourceProfile import org.apache.spark.resource.ResourceProfile._ @@ -54,7 +56,7 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv, resourcesFileOpt: Option[String], resourceProfile: ResourceProfile) - extends IsolatedRpcEndpoint with ExecutorBackend with Logging { + extends IsolatedThreadSafeRpcEndpoint with ExecutorBackend with Logging { import CoarseGrainedExecutorBackend._ @@ -85,7 +87,8 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) try { - if (PlatformDependent.directBufferPreferred() && + val shuffleClientTransportConf = SparkTransportConf.fromSparkConf(env.conf, "shuffle") + if (NettyUtils.preferDirectBufs(shuffleClientTransportConf) && PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) { throw new SparkException(s"Netty direct memory should at least be bigger than " + s"'${MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key}', but got " + diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala 
b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala index 657842c620..6ba6713b69 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala @@ -47,11 +47,22 @@ object SparkHadoopWriterUtils { * @return a job ID */ def createJobID(time: Date, id: Int): JobID = { + val jobTrackerID = createJobTrackerID(time) + createJobID(jobTrackerID, id) + } + + /** + * Create a job ID. + * + * @param jobTrackerID unique job track id + * @param id job number + * @return a job ID + */ + def createJobID(jobTrackerID: String, id: Int): JobID = { if (id < 0) { throw new IllegalArgumentException("Job number is negative") } - val jobtrackerID = createJobTrackerID(time) - new JobID(jobtrackerID, id) + new JobID(jobTrackerID, id) } /** diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala index 9a59b6bf67..989ef8f2ed 100644 --- a/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala @@ -19,14 +19,14 @@ package org.apache.spark.internal.plugin import org.apache.spark.api.plugin.DriverPlugin import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv} +import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv} case class PluginMessage(pluginName: String, message: AnyRef) private class PluginEndpoint( plugins: Map[String, DriverPlugin], override val rpcEnv: RpcEnv) - extends IsolatedRpcEndpoint with Logging { + extends IsolatedThreadSafeRpcEndpoint with Logging { override def receive: PartialFunction[Any, Unit] = { case PluginMessage(pluginName, message) => diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 4728759e7f..627f17f886 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -153,12 +153,25 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint { /** - * How many threads to use for delivering messages. By default, use a single thread. + * How many threads to use for delivering messages. * * Note that requesting more than one thread means that the endpoint should be able to handle * messages arriving from many threads at once, and all the things that entails (including * messages being delivered to the endpoint out of order). */ - def threadCount(): Int = 1 + def threadCount(): Int + +} + +/** + * An endpoint that uses a dedicated thread pool for delivering messages and + * ensured to be thread-safe. + */ +private[spark] trait IsolatedThreadSafeRpcEndpoint extends IsolatedRpcEndpoint { + + /** + * Limit the threadCount to 1 so that messages are ensured to be handled in a thread-safe way. 
+ */ + final def threadCount(): Int = 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 225dd1d75b..2d3cf2ebc4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -127,7 +127,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp ThreadUtils.newDaemonSingleThreadScheduledExecutor("cleanup-decommission-execs") } - class DriverEndpoint extends IsolatedRpcEndpoint with Logging { + class DriverEndpoint extends IsolatedThreadSafeRpcEndpoint with Logging { override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index ea028dfd11..287bf2165c 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -74,7 +74,7 @@ private[spark] class AppStatusListener( private val liveStages = new ConcurrentHashMap[(Int, Int), LiveStage]() private val liveJobs = new HashMap[Int, LiveJob]() private[spark] val liveExecutors = new HashMap[String, LiveExecutor]() - private val deadExecutors = new HashMap[String, LiveExecutor]() + private[spark] val deadExecutors = new HashMap[String, LiveExecutor]() private val liveTasks = new HashMap[Long, LiveTask]() private val liveRDDs = new HashMap[Int, LiveRDD]() private val pools = new HashMap[String, SchedulerPool]() @@ -674,22 +674,30 @@ private[spark] class AppStatusListener( delta }.orNull - val (completedDelta, failedDelta, killedDelta) = event.reason match { + // SPARK-41187: For `SparkListenerTaskEnd` with `Resubmitted` reason, which is raised by + // executor loss, it can lead to negative `LiveStage.activeTasks` since there's no + // corresponding `SparkListenerTaskStart` event for each of them. The negative activeTasks + // will make the stage always remain in the live stage list as it can never meet the + // condition activeTasks == 0. This in turn causes the dead executor to never be retained + // if that live stage's submissionTime is less than the dead executor's removeTime.
+ val (completedDelta, failedDelta, killedDelta, activeDelta) = event.reason match { case Success => - (1, 0, 0) + (1, 0, 0, 1) case _: TaskKilled => - (0, 0, 1) + (0, 0, 1, 1) case _: TaskCommitDenied => - (0, 0, 1) + (0, 0, 1, 1) + case _ @ Resubmitted => + (0, 1, 0, 0) case _ => - (0, 1, 0) + (0, 1, 0, 1) } Option(liveStages.get((event.stageId, event.stageAttemptId))).foreach { stage => if (metricsDelta != null) { stage.metrics = LiveEntityHelpers.addMetrics(stage.metrics, metricsDelta) } - stage.activeTasks -= 1 + stage.activeTasks -= activeDelta stage.completedTasks += completedDelta if (completedDelta > 0) { stage.completedIndices.add(event.taskInfo.index) @@ -699,7 +707,7 @@ private[spark] class AppStatusListener( if (killedDelta > 0) { stage.killedSummary = killedTasksSummary(event.reason, stage.killedSummary) } - stage.activeTasksPerExecutor(event.taskInfo.executorId) -= 1 + stage.activeTasksPerExecutor(event.taskInfo.executorId) -= activeDelta stage.peakExecutorMetrics.compareAndUpdatePeakValues(event.taskExecutorMetrics) stage.executorSummary(event.taskInfo.executorId).peakExecutorMetrics @@ -718,7 +726,7 @@ private[spark] class AppStatusListener( // Store both stage ID and task index in a single long variable for tracking at job level. val taskIndex = (event.stageId.toLong << Integer.SIZE) | event.taskInfo.index stage.jobs.foreach { job => - job.activeTasks -= 1 + job.activeTasks -= activeDelta job.completedTasks += completedDelta if (completedDelta > 0) { job.completedIndices.add(taskIndex) @@ -774,7 +782,7 @@ private[spark] class AppStatusListener( } liveExecutors.get(event.taskInfo.executorId).foreach { exec => - exec.activeTasks -= 1 + exec.activeTasks -= activeDelta exec.completedTasks += completedDelta exec.failedTasks += failedDelta exec.totalDuration += event.taskInfo.duration diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d5fde96b14..1067ee1556 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -637,9 +637,14 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. 
logInfo(s"BlockManager $blockManagerId re-registering with master") - master.registerBlockManager(blockManagerId, diskBlockManager.localDirsString, maxOnHeapMemory, - maxOffHeapMemory, storageEndpoint) - reportAllBlocks() + val id = master.registerBlockManager(blockManagerId, diskBlockManager.localDirsString, + maxOnHeapMemory, maxOffHeapMemory, storageEndpoint, isReRegister = true) + if (id.executorId != BlockManagerId.INVALID_EXECUTOR_ID) { + reportAllBlocks() + } else { + logError("Exiting executor due to block manager re-registration failure") + System.exit(-1) + } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c6a4457d8f..12e416bbb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -147,4 +147,6 @@ private[spark] object BlockManagerId { } private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger" + + private[spark] val INVALID_EXECUTOR_ID = "invalid" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 40008e6afb..0ee3dc249d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -74,11 +74,25 @@ class BlockManagerMaster( localDirs: Array[String], maxOnHeapMemSize: Long, maxOffHeapMemSize: Long, - storageEndpoint: RpcEndpointRef): BlockManagerId = { + storageEndpoint: RpcEndpointRef, + isReRegister: Boolean = false): BlockManagerId = { logInfo(s"Registering BlockManager $id") val updatedId = driverEndpoint.askSync[BlockManagerId]( - RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint)) - logInfo(s"Registered BlockManager $updatedId") + RegisterBlockManager( + id, + localDirs, + maxOnHeapMemSize, + maxOffHeapMemSize, + storageEndpoint, + isReRegister + ) + ) + if (updatedId.executorId == BlockManagerId.INVALID_EXECUTOR_ID) { + assert(isReRegister, "Got invalid executor id from non re-register case") + logInfo(s"Re-register BlockManager $id failed") + } else { + logInfo(s"Registered BlockManager $updatedId") + } updatedId } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index adeb507941..d30272c51b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -33,7 +33,7 @@ import org.apache.spark.{MapOutputTrackerMaster, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.shuffle.ExternalBlockStoreClient -import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend} import org.apache.spark.shuffle.ShuffleManager @@ -41,8 +41,8 @@ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** - * BlockManagerMasterEndpoint is an [[IsolatedRpcEndpoint]] on the master node to 
track statuses - * of all the storage endpoints' block managers. + * BlockManagerMasterEndpoint is an [[IsolatedThreadSafeRpcEndpoint]] on the master node to + * track statuses of all the storage endpoints' block managers. */ private[spark] class BlockManagerMasterEndpoint( @@ -55,7 +55,7 @@ class BlockManagerMasterEndpoint( mapOutputTracker: MapOutputTrackerMaster, shuffleManager: ShuffleManager, isDriver: Boolean) - extends IsolatedRpcEndpoint with Logging { + extends IsolatedThreadSafeRpcEndpoint with Logging { // Mapping from executor id to the block manager's local disk directories. private val executorIdToLocalDirs = @@ -117,8 +117,10 @@ class BlockManagerMasterEndpoint( RpcUtils.makeDriverRef(CoarseGrainedSchedulerBackend.ENDPOINT_NAME, conf, rpcEnv) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint) => - context.reply(register(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint)) + case RegisterBlockManager( + id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint, isReRegister) => + context.reply( + register(id, localDirs, maxOnHeapMemSize, maxOffHeapMemSize, endpoint, isReRegister)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -572,7 +574,8 @@ class BlockManagerMasterEndpoint( localDirs: Array[String], maxOnHeapMemSize: Long, maxOffHeapMemSize: Long, - storageEndpoint: RpcEndpointRef): BlockManagerId = { + storageEndpoint: RpcEndpointRef, + isReRegister: Boolean): BlockManagerId = { // the dummy id is not expected to contain the topology information. // we get that info here and respond back with a more fleshed out block manager id val id = BlockManagerId( @@ -583,7 +586,12 @@ class BlockManagerMasterEndpoint( val time = System.currentTimeMillis() executorIdToLocalDirs.put(id.executorId, localDirs) - if (!blockManagerInfo.contains(id)) { + // SPARK-41360: For the block manager re-registration, we should only allow it when + // the executor is recognized as active by the scheduler backend. Otherwise, this kind + // of re-registration from the terminating/stopped executor is meaningless and harmful. + lazy val isExecutorAlive = + driverEndpoint.askSync[Boolean](CoarseGrainedClusterMessages.IsExecutorAlive(id.executorId)) + if (!blockManagerInfo.contains(id) && (!isReRegister || isExecutorAlive)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) @@ -616,10 +624,29 @@ class BlockManagerMasterEndpoint( if (pushBasedShuffleEnabled) { addMergerLocation(id) } + listenerBus.post(SparkListenerBlockManagerAdded(time, id, + maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, - Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) - id + val updatedId = if (isReRegister && !isExecutorAlive) { + assert(!blockManagerInfo.contains(id), + "BlockManager re-registration shouldn't succeed when the executor is lost") + + logInfo(s"BlockManager ($id) re-registration is rejected since " + + s"the executor (${id.executorId}) has been lost") + + // Use "invalid" as the return executor id to indicate the block manager that + // re-registration failed. It's a bit hacky but fine since the returned block + // manager id won't be accessed in the case of re-registration. 
And we'll use + // this "invalid" executor id to print better logs and avoid blocks reporting. + BlockManagerId( + BlockManagerId.INVALID_EXECUTOR_ID, + id.host, + id.port, + id.topologyInfo) + } else { + id + } + updatedId } private def updateShuffleBlockInfo(blockId: BlockId, blockManagerId: BlockManagerId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index afe416a55e..e047b61fcb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -63,7 +63,8 @@ private[spark] object BlockManagerMessages { localDirs: Array[String], maxOnHeapMemSize: Long, maxOffHeapMemSize: Long, - sender: RpcEndpointRef) + sender: RpcEndpointRef, + isReRegister: Boolean) extends ToBlockManagerMaster case class UpdateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala index 54a72568b1..71c7a4de4c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.{MapOutputTracker, SparkEnv} import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv} +import org.apache.spark.rpc.{IsolatedThreadSafeRpcEndpoint, RpcCallContext, RpcEnv} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -34,7 +34,7 @@ class BlockManagerStorageEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends IsolatedRpcEndpoint with Logging { + extends IsolatedThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-storage-async-thread-pool", 100) diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 7a08de9c18..27198039fd 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -32,7 +32,7 @@ import org.apache.logging.log4j.core.{LogEvent, Logger, LoggerContext} import org.apache.logging.log4j.core.appender.AbstractAppender import org.apache.logging.log4j.core.config.Property import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, Failed, Outcome} -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.deploy.LocalSparkCluster import org.apache.spark.internal.Logging @@ -64,7 +64,7 @@ import org.apache.spark.util.{AccumulatorContext, Utils} * } */ abstract class SparkFunSuite - extends AnyFunSuite + extends AnyFunSuite // scalastyle:ignore funsuite with BeforeAndAfterAll with BeforeAndAfterEach with ThreadAudit diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index c15ae9504c..64703b0b04 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -147,6 +147,18 @@ class SparkThrowableSuite extends SparkFunSuite { 
assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap) } + test("Error class names should contain only capital letters, numbers and underscores") { + val allowedChars = "[A-Z0-9_]*" + errorReader.errorInfoMap.foreach { e => + assert(e._1.matches(allowedChars), s"Error class: ${e._1} is invalid") + e._2.subClass.map { s => + s.keys.foreach { k => + assert(k.matches(allowedChars), s"Error sub-class: $k is invalid") + } + } + } + } + test("Check if error class is missing") { val ex1 = intercept[SparkException] { getMessage("", Map.empty[String, String]) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 6bd3a49576..76311d0ab1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -486,6 +486,41 @@ class SparkSubmitSuite conf.get("spark.kubernetes.driver.container.image") should be ("bar") } + test("SPARK-33782: handles k8s files download to current directory") { + val clArgs = Seq( + "--deploy-mode", "client", + "--proxy-user", "test.user", + "--master", "k8s://host:port", + "--executor-memory", "5g", + "--class", "org.SomeClass", + "--driver-memory", "4g", + "--conf", "spark.kubernetes.namespace=spark", + "--conf", "spark.kubernetes.driver.container.image=bar", + "--conf", "spark.kubernetes.submitInDriver=true", + "--files", "src/test/resources/test_metrics_config.properties", + "--py-files", "src/test/resources/test_metrics_system.properties", + "--archives", "src/test/resources/log4j2.properties", + "--jars", "src/test/resources/TestUDTF.jar", + "/home/thejar.jar", + "arg1") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, conf, mainClass) = submit.prepareSubmitEnvironment(appArgs) + conf.get("spark.master") should be ("k8s://https://host:port") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.driver.memory") should be ("4g") + conf.get("spark.kubernetes.namespace") should be ("spark") + conf.get("spark.kubernetes.driver.container.image") should be ("bar") + + Files.exists(Paths.get("test_metrics_config.properties")) should be (true) + Files.exists(Paths.get("test_metrics_system.properties")) should be (true) + Files.exists(Paths.get("log4j2.properties")) should be (true) + Files.exists(Paths.get("TestUDTF.jar")) should be (true) + Files.delete(Paths.get("test_metrics_config.properties")) + Files.delete(Paths.get("test_metrics_system.properties")) + Files.delete(Paths.get("log4j2.properties")) + Files.delete(Paths.get("TestUDTF.jar")) + } + /** * Helper function for testing main class resolution on remote JAR files. 
* diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index c70dde79b3..6e5eb77322 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -962,7 +962,8 @@ abstract class RpcEnvSuite extends SparkFunSuite { val singleThreadedEnv = createRpcEnv( new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) try { - val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint { + val blockingEndpoint = singleThreadedEnv + .setupEndpoint("blocking", new IsolatedThreadSafeRpcEndpoint { override val rpcEnv: RpcEnv = singleThreadedEnv override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 24a8a6844f..5d0c25aa86 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -1849,6 +1849,68 @@ abstract class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter checkInfoPopulated(listener, logUrlMap, processId) } + test("SPARK-41187: Stage should be removed from liveStages to avoid deadExecutors accumulated") { + + val listener = new AppStatusListener(store, conf, true) + + listener.onExecutorAdded(createExecutorAddedEvent(1)) + listener.onExecutorAdded(createExecutorAddedEvent(2)) + val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) + + time += 1 + stage.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + val tasks = createTasks(2, Array("1", "2")) + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) + } + + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(0), new ExecutorMetrics, null)) + + // executor lost, success task will be resubmitted + time += 1 + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Resubmitted, tasks(0), new ExecutorMetrics, null)) + + // executor lost, running task will be failed and rerun + time += 1 + tasks(1).markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + ExecutorLostFailure("1", true, Some("Lost executor")), tasks(1), new ExecutorMetrics, + null)) + + tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stage.stageId, stage.attemptNumber, task)) + } + + time += 1 + tasks(0).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(0), new ExecutorMetrics, null)) + + time += 1 + tasks(1).markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stage.stageId, stage.attemptNumber, "taskType", + Success, tasks(1), new ExecutorMetrics, null)) + + listener.onStageCompleted(SparkListenerStageCompleted(stage)) + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded )) + + time += 1 + 
listener.onExecutorRemoved(SparkListenerExecutorRemoved(time, "1", "Test")) + time += 1 + listener.onExecutorRemoved(SparkListenerExecutorRemoved(time, "2", "Test")) + + assert(listener.deadExecutors.size === 0) + } + private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptNumber) private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index c8914761b9..842b66193f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -295,7 +295,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe eventually(timeout(5.seconds)) { // make sure both bm1 and bm2 are registered at driver side BlockManagerMaster verify(master, times(2)) - .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any()) + .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any(), mc.any()) assert(driverEndpoint.askSync[Boolean]( CoarseGrainedClusterMessages.IsExecutorAlive(bm1Id.executorId))) assert(driverEndpoint.askSync[Boolean]( @@ -361,6 +361,44 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe master.removeShuffle(0, true) } + test("SPARK-41360: Avoid block manager re-registration if the executor has been lost") { + // Set up a DriverEndpoint which always returns isExecutorAlive=false + rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ENDPOINT_NAME, + new RpcEndpoint { + override val rpcEnv: RpcEnv = BlockManagerSuite.this.rpcEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case CoarseGrainedClusterMessages.RegisterExecutor(executorId, _, _, _, _, _, _, _) => + context.reply(true) + case CoarseGrainedClusterMessages.IsExecutorAlive(executorId) => + // always return false + context.reply(false) + } + } + ) + + // Set up a block manager endpoint and endpoint reference + val bmRef = rpcEnv.setupEndpoint(s"bm-0", new RpcEndpoint { + override val rpcEnv: RpcEnv = BlockManagerSuite.this.rpcEnv + + private def reply[T](context: RpcCallContext, response: T): Unit = { + context.reply(response) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RemoveRdd(_) => reply(context, 1) + case RemoveBroadcast(_, _) => reply(context, 1) + case RemoveShuffle(_) => reply(context, true) + } + }) + val bmId = BlockManagerId(s"exec-0", "localhost", 1234, None) + // Register the block manager with isReRegister = true + val updatedId = master.registerBlockManager( + bmId, Array.empty, 2000, 0, bmRef, isReRegister = true) + // The re-registration should fail since the executor is considered as dead by DriverEndpoint + assert(updatedId.executorId === BlockManagerId.INVALID_EXECUTOR_ID) + } + test("StorageLevel object caching") { val level1 = StorageLevel(false, false, false, 3) // this should return the same object as level1 @@ -669,6 +707,22 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) + // Set up a DriverEndpoint which simulates the executor is alive (required by SPARK-41360) + rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ENDPOINT_NAME, + new RpcEndpoint { + override val rpcEnv: RpcEnv = BlockManagerSuite.this.rpcEnv + + override def receiveAndReply(context: RpcCallContext): 
PartialFunction[Any, Unit] = { + case CoarseGrainedClusterMessages.IsExecutorAlive(executorId) => + if (executorId == store.blockManagerId.executorId) { + context.reply(true) + } else { + context.reply(false) + } + } + } + ) + store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) assert(master.getLocations("a1").size > 0, "master was not told about a1") @@ -2207,7 +2261,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe }.getMessage assert(e.contains("TimeoutException")) verify(master, times(0)) - .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any()) + .registerBlockManager(mc.any(), mc.any(), mc.any(), mc.any(), mc.any(), mc.any()) server.close() } } diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index ad7a8a1a4c..ae7cc9d592 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -101,12 +101,11 @@ hive-shims-common/2.3.9//hive-shims-common-2.3.9.jar hive-shims-scheduler/2.3.9//hive-shims-scheduler-2.3.9.jar hive-shims/2.3.9//hive-shims-2.3.9.jar hive-storage-api/2.7.3//hive-storage-api-2.7.3.jar -hive-vector-code-gen/2.3.9//hive-vector-code-gen-2.3.9.jar hk2-api/2.6.1//hk2-api-2.6.1.jar hk2-locator/2.6.1//hk2-locator-2.6.1.jar hk2-utils/2.6.1//hk2-utils-2.6.1.jar htrace-core/3.1.0-incubating//htrace-core-3.1.0-incubating.jar -httpclient/4.5.13//httpclient-4.5.13.jar +httpclient/4.5.14//httpclient-4.5.14.jar httpcore/4.4.14//httpcore-4.4.14.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.1//ivy-2.5.1.jar @@ -261,7 +260,6 @@ threeten-extra/1.7.1//threeten-extra-1.7.1.jar tink/1.7.0//tink-1.7.0.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar -velocity/1.5//velocity-1.5.jar xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar xercesImpl/2.12.2//xercesImpl-2.12.2.jar xml-apis/1.4.01//xml-apis-1.4.01.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index cac2e9f305..f70abedd34 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -89,11 +89,10 @@ hive-shims-common/2.3.9//hive-shims-common-2.3.9.jar hive-shims-scheduler/2.3.9//hive-shims-scheduler-2.3.9.jar hive-shims/2.3.9//hive-shims-2.3.9.jar hive-storage-api/2.7.3//hive-storage-api-2.7.3.jar -hive-vector-code-gen/2.3.9//hive-vector-code-gen-2.3.9.jar hk2-api/2.6.1//hk2-api-2.6.1.jar hk2-locator/2.6.1//hk2-locator-2.6.1.jar hk2-utils/2.6.1//hk2-utils-2.6.1.jar -httpclient/4.5.13//httpclient-4.5.13.jar +httpclient/4.5.14//httpclient-4.5.14.jar httpcore/4.4.14//httpcore-4.4.14.jar ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar @@ -248,7 +247,6 @@ threeten-extra/1.7.1//threeten-extra-1.7.1.jar tink/1.7.0//tink-1.7.0.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar -velocity/1.5//velocity-1.5.jar wildfly-openssl/1.0.7.Final//wildfly-openssl-1.0.7.Final.jar xbean-asm9-shaded/4.22//xbean-asm9-shaded-4.22.jar xz/1.9//xz-1.9.jar diff --git a/dev/lint-scala b/dev/lint-scala index ea3b98464b..48ecf57ef4 100755 --- a/dev/lint-scala +++ b/dev/lint-scala @@ -29,14 +29,14 @@ ERRORS=$(./build/mvn \ -Dscalafmt.skip=false \ -Dscalafmt.validateOnly=true \ -Dscalafmt.changedOnly=false \ - -pl connector/connect \ + -pl connector/connect/server \ 2>&1 | grep -e "^Requires formatting" \ ) if test ! 
-z "$ERRORS"; then echo -e "The scalafmt check failed on connector/connect at following occurrences:\n\n$ERRORS\n" echo "Before submitting your change, please make sure to format your code using the following command:" - echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=fase -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect" + echo "./build/mvn -Pscala-2.12 scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl connector/connect/server" exit 1 else echo -e "Scalafmt checks passed." diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py index 37e023aaa6..a3df038b55 100755 --- a/dev/sparktestsupport/utils.py +++ b/dev/sparktestsupport/utils.py @@ -34,19 +34,22 @@ def determine_modules_for_files(filenames): Given a list of filenames, return the set of modules that contain those files. If a file is not associated with a more specific submodule, then this method will consider that file to belong to the 'root' module. `.github` directory is counted only in GitHub Actions, - and `appveyor.yml` is always ignored because this file is dedicated only to AppVeyor builds. + and `appveyor.yml` is always ignored because this file is dedicated only to AppVeyor builds, + and `README.md` is always ignored too. >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"])) ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] - >>> [x.name for x in determine_modules_for_files(["appveyor.yml"])] + >>> [x.name for x in determine_modules_for_files(["appveyor.yml", "sql/README.md"])] [] """ changed_modules = set() for filename in filenames: if filename in ("appveyor.yml",): continue + if filename.endswith("README.md"): + continue if ("GITHUB_ACTIONS" not in os.environ) and filename.startswith(".github"): continue matched_at_least_one_module = False diff --git a/dev/tox.ini b/dev/tox.ini index f44cbe54dd..15c93832c2 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -36,7 +36,8 @@ per-file-ignores = python/pyspark/resource/tests/*.py: F403, python/pyspark/sql/tests/*.py: F403, python/pyspark/streaming/tests/*.py: F403, - python/pyspark/tests/*.py: F403 + python/pyspark/tests/*.py: F403, + python/pyspark/testing/*: F401 exclude = */target/*, docs/.local_ruby_bundle/, diff --git a/docs/building-spark.md b/docs/building-spark.md index 3e1ec771da..9b115f1ad9 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -317,3 +317,21 @@ To build and run tests on IPv6-only environment, the following configurations ar export MAVEN_OPTS="-Djava.net.preferIPv6Addresses=true" export SBT_OPTS="-Djava.net.preferIPv6Addresses=true" export SERIAL_SBT_TESTS=1 + +### Building with user-defined `protoc` + +When the user cannot use the official `protoc` binary files to build the `core` module in the compilation environment, for example, compiling `core` module on CentOS 6 or CentOS 7 which the default `glibc` version is less than 2.14, we can try to compile and test by specifying the user-defined `protoc` binary files as follows: + +```bash +export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe +./build/mvn -Puser-defined-protoc -DskipDefaultProtoc clean package +``` + +or + +```bash +export SPARK_PROTOC_EXEC_PATH=/path-to-protoc-exe +./build/sbt -Puser-defined-protoc clean package +``` + +The user-defined `protoc` binary files can be produced in the user's compilation environment by source code compilation, for compilation 
steps, please refer to [protobuf](https://github.com/protocolbuffers/protobuf). diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 95be32a819..711e828bd8 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -43,7 +43,14 @@ best fitting the original data points. which uses an approach to [parallelizing isotonic regression](https://doi.org/10.1007/978-3-642-99789-1_10). The training input is an RDD of tuples of three double values that represent -label, feature and weight in this order. Additionally, IsotonicRegression algorithm has one +label, feature and weight in this order. In case there are multiple tuples with +the same feature then these tuples are aggregated into a single tuple as follows: + +* Aggregated label is the weighted average of all labels. +* Aggregated feature is the unique feature value. +* Aggregated weight is the sum of all weights. + +Additionally, IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is isotonic (monotonically increasing) or antitonic (monotonically decreasing). @@ -53,17 +60,12 @@ labels for both known and unknown features. The result of isotonic regression is treated as piecewise linear function. The rules for prediction therefore are: * If the prediction input exactly matches a training feature - then associated prediction is returned. In case there are multiple predictions with the same - feature then one of them is returned. Which one is undefined - (same as java.util.Arrays.binarySearch). + then associated prediction is returned. * If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. - In case there are multiple predictions with the same feature - then the lowest or highest is returned respectively. * If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the - predictions of the two closest features. In case there are multiple values - with the same feature then the same rules as in previous point are used. + predictions of the two closest features. ### Examples diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 08580a77eb..21c81c508e 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -204,6 +204,26 @@ When this property is set, it's highly recommended to make it unique across all Use the exact prefix `spark.kubernetes.authenticate` for Kubernetes authentication parameters in client mode. +## IPv4 and IPv6 + +Starting with 3.4.0, Spark supports additionally IPv6-only environment via +[IPv4/IPv6 dual-stack network](https://kubernetes.io/docs/concepts/services-networking/dual-stack/) +feature which enables the allocation of both IPv4 and IPv6 addresses to Pods and Services. +According to the K8s cluster capability, `spark.kubernetes.driver.service.ipFamilyPolicy` and +`spark.kubernetes.driver.service.ipFamilies` can be one of `SingleStack`, `PreferDualStack`, +and `RequireDualStack` and one of `IPv4`, `IPv6`, `IPv4,IPv6`, and `IPv6,IPv4` respectively. +By default, Spark uses `spark.kubernetes.driver.service.ipFamilyPolicy=SingleStack` and +`spark.kubernetes.driver.service.ipFamilies=IPv4`. + +To use only `IPv6`, you can submit your jobs with the following. +```bash +... 
+ --conf spark.kubernetes.driver.service.ipFamilies=IPv6 \ +``` + +In `DualStack` environment, you may need `java.net.preferIPv6Addresses=true` for JVM +and `SPARK_PREFER_IPV6=true` for Python additionally to use `IPv6`. + ## Dependency Management If your application's dependencies are all hosted in remote locations like HDFS or HTTP servers, they may be referred to @@ -1418,7 +1438,8 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.driver.service.ipFamilyPolicy SingleStack - K8s IP Family Policy for Driver Service. + K8s IP Family Policy for Driver Service. Valid values are + SingleStack, PreferDualStack, and RequireDualStack. 3.4.0 @@ -1426,7 +1447,8 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.driver.service.ipFamilies IPv4 - A list of IP families for K8s Driver Service. + A list of IP families for K8s Driver Service. Valid values are + IPv4 and IPv6. 3.4.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 649f9816e6..fbf0dc9c35 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.mllib.regression import java.io.Serializable @@ -272,8 +271,8 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @param input RDD of tuples (label, feature, weight) where label is dependent variable * for which we calculate isotonic regression, feature is independent variable * and weight represents number of measures with default 1. - * If multiple labels share the same feature value then they are ordered before - * the algorithm is executed. + * If multiple labels share the same feature value then they are aggregated using + * the weighted average before the algorithm is executed. * @return Isotonic regression model. */ @Since("1.3.0") @@ -298,8 +297,8 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @param input JavaRDD of tuples (label, feature, weight) where label is dependent variable * for which we calculate isotonic regression, feature is independent variable * and weight represents number of measures with default 1. - * If multiple labels share the same feature value then they are ordered before - * the algorithm is executed. + * If multiple labels share the same feature value then they are aggregated using + * the weighted average before the algorithm is executed. * @return Isotonic regression model. */ @Since("1.3.0") @@ -307,6 +306,58 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) } + /** + * Aggregates points of duplicate feature values into a single point using as label the weighted + * average of the labels of the points with duplicate feature values. All points for a unique + * feature value are aggregated as: + * + * - Aggregated label is the weighted average of all labels. + * - Aggregated feature is the unique feature value. + * - Aggregated weight is the sum of all weights. + * + * @param input Input data of tuples (label, feature, weight). Weights must be non-negative. 
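+   *              For example, the duplicate-feature points (1.0, 0.6, 1.0) and (0.0, 0.6, 1.0)
+   *              are aggregated into the single point (0.5, 0.6, 2.0).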
+ * @return Points with unique feature values. + */ + private[regression] def makeUnique( + input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + + val cleanInput = input.filter { case (y, x, weight) => + require( + weight >= 0.0, + s"Negative weight at point ($y, $x, $weight). Weights must be non-negative") + weight > 0 + } + + if (cleanInput.length <= 1) { + cleanInput + } else { + val pointsAccumulator = new IsotonicRegression.PointsAccumulator + + // Go through input points, merging all points with equal feature values into a single point. + // Equality of features is defined by shouldAccumulate method. The label of the accumulated + // points is the weighted average of the labels of all points of equal feature value. + + // Initialize with first point + pointsAccumulator := cleanInput.head + // Accumulate the rest + cleanInput.tail.foreach { case point @ (_, feature, _) => + if (pointsAccumulator.shouldAccumulate(feature)) { + // Still on a duplicate feature, accumulate + pointsAccumulator += point + } else { + // A new unique feature encountered: + // - append the last accumulated point to unique features output + pointsAccumulator.appendToOutput() + // - and reset + pointsAccumulator := point + } + } + // Append the last accumulated point to unique features output + pointsAccumulator.appendToOutput() + pointsAccumulator.getOutput + } + } + /** * Performs a pool adjacent violators algorithm (PAV). Implements the algorithm originally * described in [1], using the formulation from [2, 3]. Uses an array to keep track of start @@ -322,35 +373,27 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * functions subject to simple chain constraints." SIAM Journal on Optimization 10.3 (2000): * 658-672. * - * @param input Input data of tuples (label, feature, weight). Weights must - be non-negative. + * @param cleanUniqueInput Input data of tuples(label, feature, weight).Features must be unique + * and weights must be non-negative. * @return Result tuples (label, feature, weight) where labels were updated * to form a monotone sequence as per isotonic regression definition. */ private def poolAdjacentViolators( - input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + cleanUniqueInput: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = { - val cleanInput = input.filter{ case (y, x, weight) => - require( - weight >= 0.0, - s"Negative weight at point ($y, $x, $weight). Weights must be non-negative" - ) - weight > 0 - } - - if (cleanInput.isEmpty) { + if (cleanUniqueInput.isEmpty) { return Array.empty } // Keeps track of the start and end indices of the blocks. if [i, j] is a valid block from // cleanInput(i) to cleanInput(j) (inclusive), then blockBounds(i) = j and blockBounds(j) = i // Initially, each data point is its own block. - val blockBounds = Array.range(0, cleanInput.length) + val blockBounds = Array.range(0, cleanUniqueInput.length) // Keep track of the sum of weights and sum of weight * y for each block. weights(start) // gives the values for the block. Entries that are not at the start of a block // are meaningless. 
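+    // (Because each block's start entry holds (sum of weights, sum of weight * y), the mean
+    // label of the block starting at `start` is weights(start)._2 / weights(start)._1; that
+    // ratio is what the average(i) comparisons further down evaluate when merging blocks.)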
- val weights: Array[(Double, Double)] = cleanInput.map { case (y, _, weight) => + val weights: Array[(Double, Double)] = cleanUniqueInput.map { case (y, _, weight) => (weight, weight * y) } @@ -392,10 +435,10 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali // Merge on >= instead of > because it eliminates adjacent blocks with the same average, and we // want to compress our output as much as possible. Both give correct results. var i = 0 - while (nextBlock(i) < cleanInput.length) { + while (nextBlock(i) < cleanUniqueInput.length) { if (average(i) >= average(nextBlock(i))) { merge(i, nextBlock(i)) - while((i > 0) && (average(prevBlock(i)) >= average(i))) { + while ((i > 0) && (average(prevBlock(i)) >= average(i))) { i = merge(prevBlock(i), i) } } else { @@ -406,15 +449,15 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali // construct the output by walking through the blocks in order val output = ArrayBuffer.empty[(Double, Double, Double)] i = 0 - while (i < cleanInput.length) { + while (i < cleanUniqueInput.length) { // If block size is > 1, a point at the start and end of the block, // each receiving half the weight. Otherwise, a single point with // all the weight. - if (cleanInput(blockEnd(i))._2 > cleanInput(i)._2) { - output += ((average(i), cleanInput(i)._2, weights(i)._1 / 2)) - output += ((average(i), cleanInput(blockEnd(i))._2, weights(i)._1 / 2)) + if (cleanUniqueInput(blockEnd(i))._2 > cleanUniqueInput(i)._2) { + output += ((average(i), cleanUniqueInput(i)._2, weights(i)._1 / 2)) + output += ((average(i), cleanUniqueInput(blockEnd(i))._2, weights(i)._1 / 2)) } else { - output += ((average(i), cleanInput(i)._2, weights(i)._1)) + output += ((average(i), cleanUniqueInput(i)._2, weights(i)._1)) } i = nextBlock(i) } @@ -434,12 +477,58 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = { val keyedInput = input.keyBy(_._2) val parallelStepResult = keyedInput + // Points with same or adjacent features must collocate within the same partition. .partitionBy(new RangePartitioner(keyedInput.getNumPartitions, keyedInput)) .values - .mapPartitions(p => Iterator(p.toArray.sortBy(x => (x._2, x._1)))) + // Lexicographically sort points by features. + .mapPartitions(p => Iterator(p.toArray.sortBy(_._2))) + // Aggregate points with equal features into a single point. + .map(makeUnique) .flatMap(poolAdjacentViolators) .collect() - .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering. + // Sort again because collect() doesn't promise ordering. + .sortBy(_._2) poolAdjacentViolators(parallelStepResult) } } + +object IsotonicRegression { + /** + * Utility class, holds a buffer of all points with unique features so far, and performs + * weighted sum accumulation of points. Hides these details for better readability of the + * main algorithm. + */ + class PointsAccumulator { + private val output = ArrayBuffer[(Double, Double, Double)]() + private var (currentLabel: Double, currentFeature: Double, currentWeight: Double) = + (0d, 0d, 0d) + + /** Whether or not this feature exactly equals the current accumulated feature. */ + @inline def shouldAccumulate(feature: Double): Boolean = currentFeature == feature + + /** Resets the current value of the point accumulator using the provided point. 
*/ + @inline def :=(point: (Double, Double, Double)): Unit = { + val (label, feature, weight) = point + currentLabel = label * weight + currentFeature = feature + currentWeight = weight + } + + /** Accumulates the provided point into the current value of the point accumulator. */ + @inline def +=(point: (Double, Double, Double)): Unit = { + val (label, _, weight) = point + currentLabel += label * weight + currentWeight += weight + } + + /** Appends the current value of the point accumulator to the output. */ + @inline def appendToOutput(): Unit = + output += (( + currentLabel / currentWeight, + currentFeature, + currentWeight)) + + /** Returns all accumulated points so far. */ + @inline def getOutput: Array[(Double, Double, Double)] = output.toArray + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 8066900dfa..a206e922e5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -24,6 +24,24 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils +/** + * Tests can be verified through the following python snippet: + * + * {{{ + * from sklearn.isotonic import IsotonicRegression + * + * def test(x, y, x_test, isotonic=True): + * ir = IsotonicRegression(out_of_bounds='clip', increasing=isotonic).fit(x, y) + * y_test = ir.predict(x_test) + * + * def print_array(label, a): + * print(f"{label}: [{', '.join([str(i) for i in a])}]") + * + * print_array("boundaries", ir.X_thresholds_) + * print_array("predictions", ir.y_thresholds_) + * print_array("y_test", y_test) + * }}} + */ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { @@ -44,8 +62,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w labels: Seq[Double], weights: Seq[Double], isotonic: Boolean): IsotonicRegressionModel = { - val trainRDD = sc.parallelize(generateIsotonicInput(labels, weights)).cache() - new IsotonicRegression().setIsotonic(isotonic).run(trainRDD) + runIsotonicRegressionOnInput(generateIsotonicInput(labels, weights), isotonic) } private def runIsotonicRegression( @@ -54,17 +71,37 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic) } + private def runIsotonicRegression( + labels: Seq[Double], + features: Seq[Double], + weights: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + runIsotonicRegressionOnInput( + labels.indices.map(i => (labels(i), features(i), weights(i))), + isotonic) + } + + private def runIsotonicRegressionOnInput( + input: Seq[(Double, Double, Double)], + isotonic: Boolean, + slices: Int = sc.defaultParallelism): IsotonicRegressionModel = { + val trainRDD = sc.parallelize(input, slices).cache() + new IsotonicRegression().setIsotonic(isotonic).run(trainRDD) + } + test("increasing isotonic regression") { /* The following result could be re-produced with sklearn. 
- > from sklearn.isotonic import IsotonicRegression - > x = range(9) - > y = [1, 2, 3, 1, 6, 17, 16, 17, 18] - > ir = IsotonicRegression(x, y) - > print ir.predict(x) + > test( + > x = range(9), + > y = [1, 2, 3, 1, 6, 17, 16, 17, 18], + > x_test = range(9) + > ) - array([ 1. , 2. , 2. , 2. , 6. , 16.5, 16.5, 17. , 18. ]) + boundaries: [0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] + predictions: [1.0, 2.0, 2.0, 6.0, 16.5, 16.5, 17.0, 18.0] + y_test: [1.0, 2.0, 2.0, 2.0, 6.0, 16.5, 16.5, 17.0, 18.0] */ val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true) @@ -142,9 +179,9 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w } test("isotonic regression with unordered input") { - val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache() + val model = + runIsotonicRegressionOnInput(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, true, 2) - val model = new IsotonicRegression().run(trainRDD) assert(model.predictions === Array(1, 2, 3, 4, 5)) } @@ -159,7 +196,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1), true) assert(model.boundaries === Array(0, 1, 2, 4)) - assert(model.predictions.map(round) === Array(1, 2, 3.3/1.2, 3.3/1.2)) + assert(model.predictions.map(round) === Array(1, 2, 3.3 / 1.2, 3.3 / 1.2)) } test("weighted isotonic regression with negative weights") { @@ -176,16 +213,31 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w } test("SPARK-16426 isotonic regression with duplicate features that produce NaNs") { - val trainRDD = sc.parallelize(Seq[(Double, Double, Double)]((2, 1, 1), (1, 1, 1), (0, 2, 1), - (1, 2, 1), (0.5, 3, 1), (0, 3, 1)), - 2) - - val model = new IsotonicRegression().run(trainRDD) + val model = runIsotonicRegressionOnInput( + Seq((2, 1, 1), (1, 1, 1), (0, 2, 1), (1, 2, 1), (0.5, 3, 1), (0, 3, 1)), + true, + 2) assert(model.boundaries === Array(1.0, 3.0)) assert(model.predictions === Array(0.75, 0.75)) } + test("SPARK-41008 isotonic regression with duplicate features differs from sklearn") { + val model = runIsotonicRegressionOnInput( + Seq((1, 0.6, 1), (0, 0.6, 1), + (0, 1.0 / 3, 1), (1, 1.0 / 3, 1), (0, 1.0 / 3, 1), + (1, 0.2, 1), (0, 0.2, 1), (0, 0.2, 1), (0, 0.2, 1)), + true, + 2) + + assert(model.boundaries === Array(0.2, 1.0 / 3, 0.6)) + assert(model.predictions === Array(0.25, 1.0 / 3, 0.5)) + + assert(model.predict(0.6) === 0.5) + assert(model.predict(1.0 / 3) === 1.0 / 3) + assert(model.predict(0.2) === 0.25) + } + test("isotonic regression prediction") { val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) @@ -194,32 +246,38 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w assert(model.predict(0.5) === 1.5) assert(model.predict(0.75) === 1.75) assert(model.predict(1) === 2) - assert(model.predict(2) === 10d/3) - assert(model.predict(9) === 10d/3) + assert(model.predict(2) === 10.0 / 3) + assert(model.predict(9) === 10.0 / 3) } test("isotonic regression prediction with duplicate features") { - val trainRDD = sc.parallelize( - Seq[(Double, Double, Double)]( - (2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache() - val model = new IsotonicRegression().run(trainRDD) - - assert(model.predict(0) === 1) - assert(model.predict(1.5) === 2) - assert(model.predict(2.5) === 4.5) - assert(model.predict(4) === 6) + val model = runIsotonicRegressionOnInput( + Seq((2, 1, 1), (1, 
1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), + true, + 2) + + assert(model.boundaries === Array(1.0, 2.0, 3.0)) + assert(model.predictions === Array(1.5, 3.0, 5.5)) + + assert(model.predict(0) === 1.5) + assert(model.predict(1.5) === 2.25) + assert(model.predict(2.5) === 4.25) + assert(model.predict(4) === 5.5) } test("antitonic regression prediction with duplicate features") { - val trainRDD = sc.parallelize( - Seq[(Double, Double, Double)]( - (5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache() - val model = new IsotonicRegression().setIsotonic(false).run(trainRDD) - - assert(model.predict(0) === 6) - assert(model.predict(1.5) === 4.5) - assert(model.predict(2.5) === 2) - assert(model.predict(4) === 1) + val model = runIsotonicRegressionOnInput( + Seq((5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), + false, + 2) + + assert(model.boundaries === Array(1.0, 2.0, 3.0)) + assert(model.predictions === Array(5.5, 3.0, 1.5)) + + assert(model.predict(0) === 5.5) + assert(model.predict(1.5) === 4.25) + assert(model.predict(2.5) === 2.25) + assert(model.predict(4) === 1.5) } test("isotonic regression RDD prediction") { @@ -227,7 +285,7 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache() val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2) - assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3)) + assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0 / 3, 10.0 / 3)) } test("antitonic regression prediction") { @@ -270,4 +328,63 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = false) } } + + test("makeUnique: handle duplicate features") { + val regressor = new IsotonicRegression() + import regressor.makeUnique + + // Note: input must be lexicographically sorted by feature + + // empty + assert(makeUnique(Array.empty) === Array.empty) + + // single + assert(makeUnique(Array((1.0, 1.0, 1.0))) === Array((1.0, 1.0, 1.0))) + + // two and duplicate + assert(makeUnique(Array((1.0, 1.0, 1.0), (1.0, 1.0, 1.0))) === Array((1.0, 1.0, 2.0))) + + // two and unique + assert( + makeUnique(Array((1.0, 1.0, 1.0), (1.0, 2.0, 1.0))) === + Array((1.0, 1.0, 1.0), (1.0, 2.0, 1.0))) + + // generic with duplicates + assert( + makeUnique( + Array( + (10.0, 1.0, 1.0), (20.0, 1.0, 1.0), + (10.0, 2.0, 1.0), (20.0, 2.0, 1.0), (30.0, 2.0, 1.0), + (10.0, 3.0, 1.0) + )) === Array((15.0, 1.0, 2.0), (20.0, 2.0, 3.0), (10.0, 3.0, 1.0))) + + // generic unique + assert( + makeUnique(Array((10.0, 1.0, 1.0), (10.0, 2.0, 1.0), (10.0, 3.0, 1.0))) === Array( + (10.0, 1.0, 1.0), + (10.0, 2.0, 1.0), + (10.0, 3.0, 1.0))) + + // generic with duplicates and non-uniform weights + assert( + makeUnique( + Array( + (10.0, 1.0, 0.3), (20.0, 1.0, 0.7), + (10.0, 2.0, 0.3), (20.0, 2.0, 0.3), (30.0, 2.0, 0.4), + (10.0, 3.0, 1.0) + )) === Array( + (10.0 * 0.3 + 20.0 * 0.7, 1.0, 1.0), + (10.0 * 0.3 + 20.0 * 0.3 + 30.0 * 0.4, 2.0, 1.0), + (10.0, 3.0, 1.0))) + + // don't handle tiny representation errors + // e.g. 
infinitely adjacent doubles are already unique + val adjacentDoubles = { + // i-th next representable double to 1.0 is java.lang.Double.longBitsToDouble(base + i) + val base = java.lang.Double.doubleToRawLongBits(1.0) + (0 until 10).map(i => java.lang.Double.longBitsToDouble(base + i)) + .map((1.0, _, 1.0)).toArray + } + assert(makeUnique(adjacentDoubles) === adjacentDoubles) + } } diff --git a/pom.xml b/pom.xml index b2e5979f46..da7c8eccfc 100644 --- a/pom.xml +++ b/pom.xml @@ -123,7 +123,7 @@ 2.5.0 - 3.21.9 + 3.21.11 3.11.4 ${hadoop.version} 3.6.3 @@ -161,7 +161,7 @@ 0.12.8 hadoop3-2.2.7 - 4.5.13 + 4.5.14 4.4.14 3.6.1 @@ -175,7 +175,7 @@ errors building different Hadoop versions. See: SPARK-36547, SPARK-38394. --> - 4.7.2 + 4.8.0 true true @@ -2042,6 +2042,10 @@ ${hive.group} hive-ant + + ${hive.group} + hive-vector-code-gen + ${hive.group} @@ -3229,17 +3233,6 @@ org.spark-project.spark:unused - org.eclipse.jetty:jetty-io - org.eclipse.jetty:jetty-http - org.eclipse.jetty:jetty-proxy - org.eclipse.jetty:jetty-client - org.eclipse.jetty:jetty-continuation - org.eclipse.jetty:jetty-servlet - org.eclipse.jetty:jetty-servlets - org.eclipse.jetty:jetty-plus - org.eclipse.jetty:jetty-security - org.eclipse.jetty:jetty-util - org.eclipse.jetty:jetty-server com.google.guava:guava org.jpmml:* diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eed79d1f20..7ec4ef37a0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -123,7 +123,13 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"), // [SPARK-41180][SQL] Reuse INVALID_SCHEMA instead of _LEGACY_ERROR_TEMP_1227 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.parseTypeWithFallback") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.types.DataType.parseTypeWithFallback"), + + // [SPARK-41360][CORE] Avoid BlockManager re-registration if the executor has been lost + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockManagerMessages#RegisterBlockManager.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockManagerMessages#RegisterBlockManager.this"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.BlockManagerMessages$RegisterBlockManager$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockManagerMessages#RegisterBlockManager.apply") ) // Defulat exclude rules diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e6a39714e6..556f8528ea 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -87,7 +87,7 @@ object BuildCommons { // Google Protobuf version used for generating the protobuf. // SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`. - val protoVersion = "3.21.9" + val protoVersion = "3.21.11" // GRPC version used for Spark Connect. 
val gprcVersion = "1.47.0" } @@ -112,15 +112,13 @@ object SparkBuild extends PomBuild { sys.props.put("test.jdwp.enabled", "true") } if (profiles.contains("user-defined-protoc")) { - val connectProtocExecPath = Properties.envOrNone("CONNECT_PROTOC_EXEC_PATH") + val sparkProtocExecPath = Properties.envOrNone("SPARK_PROTOC_EXEC_PATH") val connectPluginExecPath = Properties.envOrNone("CONNECT_PLUGIN_EXEC_PATH") - val protobufProtocExecPath = Properties.envOrNone("PROTOBUF_PROTOC_EXEC_PATH") - if (connectProtocExecPath.isDefined && connectPluginExecPath.isDefined) { - sys.props.put("connect.protoc.executable.path", connectProtocExecPath.get) - sys.props.put("connect.plugin.executable.path", connectPluginExecPath.get) + if (sparkProtocExecPath.isDefined) { + sys.props.put("spark.protoc.executable.path", sparkProtocExecPath.get) } - if (protobufProtocExecPath.isDefined) { - sys.props.put("protobuf.protoc.executable.path", protobufProtocExecPath.get) + if (connectPluginExecPath.isDefined) { + sys.props.put("connect.plugin.executable.path", connectPluginExecPath.get) } } profiles @@ -644,7 +642,16 @@ object Core { val propsFile = baseDirectory.value / "target" / "extra-resources" / "spark-version-info.properties" Seq(propsFile) }.taskValue - ) + ) ++ { + val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path") + if (sparkProtocExecPath.isDefined) { + Seq( + PB.protocExecutable := file(sparkProtocExecPath.get) + ) + } else { + Seq.empty + } + } } object SparkConnectCommon { @@ -709,15 +716,15 @@ object SparkConnectCommon { case _ => MergeStrategy.first } ) ++ { - val connectProtocExecPath = sys.props.get("connect.protoc.executable.path") + val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path") val connectPluginExecPath = sys.props.get("connect.plugin.executable.path") - if (connectProtocExecPath.isDefined && connectPluginExecPath.isDefined) { + if (sparkProtocExecPath.isDefined && connectPluginExecPath.isDefined) { Seq( (Compile / PB.targets) := Seq( PB.gens.java -> (Compile / sourceManaged).value, PB.gens.plugin(name = "grpc-java", path = connectPluginExecPath.get) -> (Compile / sourceManaged).value ), - PB.protocExecutable := file(connectProtocExecPath.get) + PB.protocExecutable := file(sparkProtocExecPath.get) ) } else { Seq( @@ -867,10 +874,10 @@ object SparkProtobuf { case _ => MergeStrategy.first }, ) ++ { - val protobufProtocExecPath = sys.props.get("protobuf.protoc.executable.path") - if (protobufProtocExecPath.isDefined) { + val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path") + if (sparkProtocExecPath.isDefined) { Seq( - PB.protocExecutable := file(protobufProtocExecPath.get) + PB.protocExecutable := file(sparkProtocExecPath.get) ) } else { Seq.empty diff --git a/python/mypy.ini b/python/mypy.ini index 927254d3b3..603647bd3c 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -26,7 +26,7 @@ warn_redundant_casts = True [mypy-pyspark.sql.connect.proto.*] ignore_errors = True -; Allow untyped def in internal modules and tests +; Allow untyped def in internal modules [mypy-pyspark.daemon] disallow_untyped_defs = False @@ -46,33 +46,18 @@ disallow_untyped_defs = False [mypy-pyspark.join] disallow_untyped_defs = False -[mypy-pyspark.ml.tests.*] -disallow_untyped_defs = False - -[mypy-pyspark.mllib.tests.*] -disallow_untyped_defs = False - [mypy-pyspark.rddsampler] disallow_untyped_defs = False -[mypy-pyspark.resource.tests.*] -disallow_untyped_defs = False - [mypy-pyspark.serializers] disallow_untyped_defs = False [mypy-pyspark.shuffle] 
disallow_untyped_defs = False -[mypy-pyspark.streaming.tests.*] -disallow_untyped_defs = False - [mypy-pyspark.streaming.util] disallow_untyped_defs = False -[mypy-pyspark.sql.tests.*] -disallow_untyped_defs = False - [mypy-pyspark.sql.pandas.serializers] disallow_untyped_defs = False @@ -88,20 +73,37 @@ disallow_untyped_defs = False [mypy-pyspark.pandas.usage_logging.*] disallow_untyped_defs = False -[mypy-pyspark.pandas.tests.*] +[mypy-pyspark.traceback_utils] disallow_untyped_defs = False -[mypy-pyspark.tests.*] +[mypy-pyspark.worker] disallow_untyped_defs = False -[mypy-pyspark.testing.*] -disallow_untyped_defs = False +; Ignore errors in tests -[mypy-pyspark.traceback_utils] -disallow_untyped_defs = False +[mypy-pyspark.ml.tests.*] +ignore_errors = True -[mypy-pyspark.worker] -disallow_untyped_defs = False +[mypy-pyspark.mllib.tests.*] +ignore_errors = True + +[mypy-pyspark.resource.tests.*] +ignore_errors = True + +[mypy-pyspark.streaming.tests.*] +ignore_errors = True + +[mypy-pyspark.sql.tests.*] +ignore_errors = True + +[mypy-pyspark.pandas.tests.*] +ignore_errors = True + +[mypy-pyspark.tests.*] +ignore_errors = True + +[mypy-pyspark.testing.*] +ignore_errors = True ; Allow non-strict optional for pyspark.pandas @@ -145,6 +147,9 @@ ignore_missing_imports = True [mypy-google.protobuf.*] ignore_missing_imports = True +[mypy-grpc.*] +ignore_missing_imports = True + ; Ignore errors for proto generated code [mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto] ignore_errors = True diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py index e677e79cec..accdddb29c 100644 --- a/python/pyspark/ml/tests/test_algorithms.py +++ b/python/pyspark/ml/tests/test_algorithms.py @@ -413,7 +413,7 @@ def test_linear_regression_with_huber_loss(self): from pyspark.ml.tests.test_algorithms import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_base.py b/python/pyspark/ml/tests/test_base.py index b95b8fbdd5..6c3c51d1c0 100644 --- a/python/pyspark/ml/tests/test_base.py +++ b/python/pyspark/ml/tests/test_base.py @@ -88,7 +88,7 @@ def testDefaultFitMultiple(self): from pyspark.ml.tests.test_base import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py index d2fd369624..3c5ae3fbe7 100644 --- a/python/pyspark/ml/tests/test_evaluation.py +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -69,7 +69,7 @@ def test_clustering_evaluator_with_cosine_distance(self): from pyspark.ml.tests.test_evaluation import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 6cf3175865..0051d47ae3 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -393,7 +393,7 @@ def test_apply_binary_term_freqs(self): from pyspark.ml.tests.test_feature import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_image.py b/python/pyspark/ml/tests/test_image.py index 8a155ab56a..86fa46c324 100644 --- a/python/pyspark/ml/tests/test_image.py +++ b/python/pyspark/ml/tests/test_image.py @@ -74,7 +74,7 @@ def test_read_images(self): from pyspark.ml.tests.test_image import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index a6e9f4e752..6632d100ea 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -401,7 +401,7 @@ def test_infer_schema(self): from pyspark.ml.tests.test_linalg import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_param.py b/python/pyspark/ml/tests/test_param.py index 64ed2f6dbe..8df50a5963 100644 --- a/python/pyspark/ml/tests/test_param.py +++ b/python/pyspark/ml/tests/test_param.py @@ -433,7 +433,7 @@ def test_java_params(self): from pyspark.ml.tests.test_param import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_persistence.py b/python/pyspark/ml/tests/test_persistence.py index 0b54540f06..406180d9a6 100644 --- a/python/pyspark/ml/tests/test_persistence.py +++ b/python/pyspark/ml/tests/test_persistence.py @@ -538,7 +538,7 @@ def test_save_and_load_on_nested_list_params(self): from pyspark.ml.tests.test_persistence import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py index 1f73fdd344..afc900cec4 100644 --- a/python/pyspark/ml/tests/test_pipeline.py +++ b/python/pyspark/ml/tests/test_pipeline.py @@ -63,7 +63,7 @@ def doTransform(pipeline): from pyspark.ml.tests.test_pipeline import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_stat.py b/python/pyspark/ml/tests/test_stat.py index 16ce1bc7da..6bab41b567 100644 --- a/python/pyspark/ml/tests/test_stat.py +++ b/python/pyspark/ml/tests/test_stat.py @@ -44,7 +44,7 @@ def test_chisquaretest(self): from pyspark.ml.tests.test_stat import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 27d9c182cf..5704d71867 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -486,7 +486,7 @@ def test_kmeans_summary(self): from pyspark.ml.tests.test_training_summary import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git 
a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py index c4273f36d7..d9a5c51fd5 100644 --- a/python/pyspark/ml/tests/test_tuning.py +++ b/python/pyspark/ml/tests/test_tuning.py @@ -1027,7 +1027,7 @@ def test_copy(self): from pyspark.ml.tests.test_tuning import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_util.py b/python/pyspark/ml/tests/test_util.py index 4d5c6a4727..55c973831b 100644 --- a/python/pyspark/ml/tests/test_util.py +++ b/python/pyspark/ml/tests/test_util.py @@ -77,7 +77,7 @@ def _check_uid_set_equal(stages, expected_stages): from pyspark.ml.tests.test_util import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py index 02ce6f3192..33d93c02ac 100644 --- a/python/pyspark/ml/tests/test_wrapper.py +++ b/python/pyspark/ml/tests/test_wrapper.py @@ -130,7 +130,7 @@ def test_new_java_array(self): from pyspark.ml.tests.test_wrapper import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index 8882242259..6a9be99ecd 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -338,7 +338,7 @@ def test_fpgrowth(self): from pyspark.mllib.tests.test_algorithms import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/mllib/tests/test_feature.py b/python/pyspark/mllib/tests/test_feature.py index 080a2bf1f5..ca06f39da2 100644 --- a/python/pyspark/mllib/tests/test_feature.py +++ b/python/pyspark/mllib/tests/test_feature.py @@ -184,7 +184,7 @@ def test_pca(self): from pyspark.mllib.tests.test_feature import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index 007f42d3c2..d137c88836 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -672,7 +672,7 @@ def test_regression(self): from pyspark.mllib.tests.test_linalg import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/mllib/tests/test_stat.py b/python/pyspark/mllib/tests/test_stat.py index 7a33d773d1..cef1294ada 100644 --- a/python/pyspark/mllib/tests/test_stat.py +++ b/python/pyspark/mllib/tests/test_stat.py @@ -198,7 +198,7 @@ def test_R_implementation_equivalence(self): from pyspark.mllib.tests.test_stat import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git 
a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py index 779fff7090..5a06742ba7 100644 --- a/python/pyspark/mllib/tests/test_streaming_algorithms.py +++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py @@ -463,7 +463,7 @@ def condition(): from pyspark.mllib.tests.test_streaming_algorithms import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/mllib/tests/test_util.py b/python/pyspark/mllib/tests/test_util.py index aad1349c71..28a53af0aa 100644 --- a/python/pyspark/mllib/tests/test_util.py +++ b/python/pyspark/mllib/tests/test_util.py @@ -100,7 +100,7 @@ def test_to_java_object_rdd(self): # SPARK-6660 from pyspark.mllib.tests.test_util import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_base.py b/python/pyspark/pandas/tests/data_type_ops/test_base.py index db4724b982..9b40d15db6 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_base.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_base.py @@ -95,7 +95,7 @@ def test_bool_ext_ops(self): from pyspark.pandas.tests.data_type_ops.test_base import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py index 7135800bd9..6eca20d2db 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py @@ -212,7 +212,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_binary_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py index 7376120226..ad7ead6316 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py @@ -813,7 +813,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_boolean_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py index 992e3ed70f..41e6c4885d 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py @@ -550,7 +550,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_categorical_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py 
b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py index bbdf837ce2..2b85e7bb26 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py @@ -356,7 +356,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_complex_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py index b457ab2cc8..2fe8a4c688 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py @@ -235,7 +235,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_date_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py index de9c6acb2c..55d06c07cd 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py @@ -248,7 +248,7 @@ def setUpClass(cls): from pyspark.pandas.tests.data_type_ops.test_datetime_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py index fc6cdd1a43..44ea159f2a 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py @@ -165,7 +165,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_null_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py index cb678ff585..22d4e8d8ff 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py @@ -694,7 +694,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_num_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py index cc448dc42d..cf785f1ebb 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py @@ -342,7 +342,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_string_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py 
b/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py index eeaba4d277..3889520ad8 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_timedelta_ops.py @@ -207,7 +207,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_timedelta_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py index 81767af76f..beebc1f320 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py @@ -180,7 +180,7 @@ def test_ge(self): from pyspark.pandas.tests.data_type_ops.test_udt_ops import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 9ca31923d5..0e2c640979 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -2590,7 +2590,7 @@ def test_multi_index_nunique(self): from pyspark.pandas.tests.indexes.test_base import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index ba737eb520..10c822a3ca 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -459,7 +459,7 @@ def test_map(self): from pyspark.pandas.tests.indexes.test_category import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/indexes/test_datetime.py b/python/pyspark/pandas/tests/indexes/test_datetime.py index f715518743..8f8e283f3a 100644 --- a/python/pyspark/pandas/tests/indexes/test_datetime.py +++ b/python/pyspark/pandas/tests/indexes/test_datetime.py @@ -254,7 +254,7 @@ def test_map(self): from pyspark.pandas.tests.indexes.test_datetime import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/indexes/test_timedelta.py b/python/pyspark/pandas/tests/indexes/test_timedelta.py index b191ff8bfb..654f5ee3a0 100644 --- a/python/pyspark/pandas/tests/indexes/test_timedelta.py +++ b/python/pyspark/pandas/tests/indexes/test_timedelta.py @@ -110,7 +110,7 @@ def test_properties(self): from pyspark.pandas.tests.indexes.test_timedelta import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot.py b/python/pyspark/pandas/tests/plot/test_frame_plot.py index 5d265ff2ee..817ea896e7 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot.py +++ 
b/python/pyspark/pandas/tests/plot/test_frame_plot.py @@ -158,7 +158,7 @@ def check_box_multi_columns(psdf): from pyspark.pandas.tests.plot.test_frame_plot import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py index bb400996e2..7c63371098 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_matplotlib.py @@ -477,7 +477,7 @@ def check_kde_plot(pdf, psdf, *args, **kwargs): from pyspark.pandas.tests.plot.test_frame_plot_matplotlib import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index d169326b7b..f7cf1fc349 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -273,7 +273,7 @@ def test_kde_plot(self): from pyspark.pandas.tests.plot.test_frame_plot_plotly import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/plot/test_series_plot.py b/python/pyspark/pandas/tests/plot/test_series_plot.py index f3d4ef553b..fab04bac21 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot.py @@ -94,7 +94,7 @@ def check_box_summary(psdf, pdf): from pyspark.pandas.tests.plot.test_series_plot import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py index 680eee13de..c17290c44b 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py @@ -397,7 +397,7 @@ def test_single_value_hist(self): from pyspark.pandas.tests.plot.test_series_plot_matplotlib import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py index 8a50b1829d..7bd612c1a8 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot_plotly.py @@ -235,7 +235,7 @@ def test_kde_plot(self): from pyspark.pandas.tests.plot.test_series_plot_plotly import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py index 99f315a43a..d5a660a66e 100644 --- a/python/pyspark/pandas/tests/test_categorical.py +++ 
b/python/pyspark/pandas/tests/test_categorical.py @@ -436,7 +436,7 @@ def test_groupby_transform_without_shortcut(self): pdf, psdf = self.df_pair - def identity(x) -> ps.Series[psdf.b.dtype]: # type: ignore[name-defined] + def identity(x) -> ps.Series[psdf.b.dtype]: return x self.assert_eq( @@ -796,7 +796,7 @@ def test_set_categories(self): from pyspark.pandas.tests.test_categorical import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_config.py b/python/pyspark/pandas/tests/test_config.py index d3900e216c..c1c2299240 100644 --- a/python/pyspark/pandas/tests/test_config.py +++ b/python/pyspark/pandas/tests/test_config.py @@ -148,7 +148,7 @@ def test_dir_options(self): from pyspark.pandas.tests.test_config import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_csv.py b/python/pyspark/pandas/tests/test_csv.py index 6bdc989c5d..a94125e648 100644 --- a/python/pyspark/pandas/tests/test_csv.py +++ b/python/pyspark/pandas/tests/test_csv.py @@ -435,7 +435,7 @@ def test_to_csv_with_partition_cols(self): from pyspark.pandas.tests.test_csv import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 4e80c680b6..1b06d321e1 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -7074,6 +7074,10 @@ def test_cov(self): psdf = ps.from_pandas(pdf) self.assert_eq(pdf.cov(), psdf.cov()) + @unittest.skipIf( + LooseVersion(pd.__version__) < LooseVersion("1.3.0"), + "pandas support `Styler.to_latex` since 1.3.0", + ) def test_style(self): # Currently, the `style` function returns a pandas object `Styler` as it is, # processing only the number of rows declared in `compute.max_rows`. 
@@ -7102,7 +7106,7 @@ def check_style(): from pyspark.pandas.tests.test_dataframe import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_dataframe_conversion.py b/python/pyspark/pandas/tests/test_dataframe_conversion.py index 4e4c9ac2e7..67ff40e9f1 100644 --- a/python/pyspark/pandas/tests/test_dataframe_conversion.py +++ b/python/pyspark/pandas/tests/test_dataframe_conversion.py @@ -262,7 +262,7 @@ def test_from_records(self): from pyspark.pandas.tests.test_dataframe_conversion import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_dataframe_spark_io.py b/python/pyspark/pandas/tests/test_dataframe_spark_io.py index dd83070a16..9904ff032d 100644 --- a/python/pyspark/pandas/tests/test_dataframe_spark_io.py +++ b/python/pyspark/pandas/tests/test_dataframe_spark_io.py @@ -475,7 +475,7 @@ def test_orc_write(self): from pyspark.pandas.tests.test_dataframe_spark_io import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_default_index.py b/python/pyspark/pandas/tests/test_default_index.py index dcb120aee4..ddd9e29662 100644 --- a/python/pyspark/pandas/tests/test_default_index.py +++ b/python/pyspark/pandas/tests/test_default_index.py @@ -97,7 +97,7 @@ def test_index_distributed_sequence_cleanup(self): from pyspark.pandas.tests.test_default_index import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_ewm.py b/python/pyspark/pandas/tests/test_ewm.py index 3ce0bd4507..4d3c98572d 100644 --- a/python/pyspark/pandas/tests/test_ewm.py +++ b/python/pyspark/pandas/tests/test_ewm.py @@ -422,7 +422,7 @@ def test_groupby_ewm_func(self): from pyspark.pandas.tests.test_ewm import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py index 77ced41eb8..d712f03f7d 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/test_expanding.py @@ -241,7 +241,7 @@ def test_groupby_expanding_kurt(self): from pyspark.pandas.tests.test_expanding import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_extension.py b/python/pyspark/pandas/tests/test_extension.py index dd2d08dded..5d4b5dfa76 100644 --- a/python/pyspark/pandas/tests/test_extension.py +++ b/python/pyspark/pandas/tests/test_extension.py @@ -140,7 +140,7 @@ def __init__(self, data): from pyspark.pandas.tests.test_extension import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git 
a/python/pyspark/pandas/tests/test_frame_spark.py b/python/pyspark/pandas/tests/test_frame_spark.py index 9b47ceca7a..df090b74d9 100644 --- a/python/pyspark/pandas/tests/test_frame_spark.py +++ b/python/pyspark/pandas/tests/test_frame_spark.py @@ -148,7 +148,7 @@ def test_local_checkpoint(self): from pyspark.pandas.tests.test_frame_spark import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py index d476302205..72e0e47aed 100644 --- a/python/pyspark/pandas/tests/test_generic_functions.py +++ b/python/pyspark/pandas/tests/test_generic_functions.py @@ -222,7 +222,7 @@ def test_prod_precision(self): from pyspark.pandas.tests.test_generic_functions import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index a203f77717..1c940e3abf 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -2334,7 +2334,7 @@ def add_max2( def test_apply_negative(self): def func(_) -> ps.Series[int]: - return pd.Series([1]) # type: ignore[return-value] + return pd.Series([1]) with self.assertRaisesRegex(TypeError, "Series as a return type hint at frame groupby"): ps.range(10).groupby("id").apply(func) @@ -3242,7 +3242,7 @@ def test_getitem(self): from pyspark.pandas.tests.test_groupby import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py index c939a69929..9d52c41274 100644 --- a/python/pyspark/pandas/tests/test_indexing.py +++ b/python/pyspark/pandas/tests/test_indexing.py @@ -1327,7 +1327,7 @@ def test_index_operator_int(self): from pyspark.pandas.tests.test_indexing import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_indexops_spark.py b/python/pyspark/pandas/tests/test_indexops_spark.py index 275ef77f71..8b0b5c87c9 100644 --- a/python/pyspark/pandas/tests/test_indexops_spark.py +++ b/python/pyspark/pandas/tests/test_indexops_spark.py @@ -68,7 +68,7 @@ def test_series_apply_negative(self): from pyspark.pandas.tests.test_indexops_spark import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_internal.py b/python/pyspark/pandas/tests/test_internal.py index 2ace222ed6..30a4bdcb66 100644 --- a/python/pyspark/pandas/tests/test_internal.py +++ b/python/pyspark/pandas/tests/test_internal.py @@ -112,7 +112,7 @@ def test_from_pandas(self): from pyspark.pandas.tests.test_internal import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git 
a/python/pyspark/pandas/tests/test_namespace.py b/python/pyspark/pandas/tests/test_namespace.py index 8f73c65846..c0bda11d98 100644 --- a/python/pyspark/pandas/tests/test_namespace.py +++ b/python/pyspark/pandas/tests/test_namespace.py @@ -621,7 +621,7 @@ def test_missing(self): from pyspark.pandas.tests.test_namespace import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_numpy_compat.py b/python/pyspark/pandas/tests/test_numpy_compat.py index d16d9996ec..fc6e332782 100644 --- a/python/pyspark/pandas/tests/test_numpy_compat.py +++ b/python/pyspark/pandas/tests/test_numpy_compat.py @@ -188,7 +188,7 @@ def test_np_spark_compat_frame(self): from pyspark.pandas.tests.test_numpy_compat import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py index 71c393dcf3..734e2545d1 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py @@ -2141,7 +2141,7 @@ def test_series_eq(self): from pyspark.pandas.tests.test_ops_on_diff_frames import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py index 69621e4930..1bc1ab4772 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby.py @@ -630,7 +630,7 @@ def test_fillna(self): from pyspark.pandas.tests.test_ops_on_diff_frames_groupby import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py index 08f17745df..072a83d294 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_expanding.py @@ -99,7 +99,7 @@ def test_groupby_expanding_var(self): from pyspark.pandas.tests.test_ops_on_diff_frames_groupby_expanding import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py index 04ea448d80..e9a42e79ab 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames_groupby_rolling.py @@ -99,7 +99,7 @@ def test_groupby_rolling_var(self): from pyspark.pandas.tests.test_ops_on_diff_frames_groupby_rolling import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: 
diff --git a/python/pyspark/pandas/tests/test_repr.py b/python/pyspark/pandas/tests/test_repr.py index 271ed0a14c..d1ba46e63f 100644 --- a/python/pyspark/pandas/tests/test_repr.py +++ b/python/pyspark/pandas/tests/test_repr.py @@ -178,7 +178,7 @@ def test_repr_float_index(self): from pyspark.pandas.tests.test_repr import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_resample.py b/python/pyspark/pandas/tests/test_resample.py index 56106940f1..3b494e05e7 100644 --- a/python/pyspark/pandas/tests/test_resample.py +++ b/python/pyspark/pandas/tests/test_resample.py @@ -295,7 +295,7 @@ def test_resample_on(self): from pyspark.pandas.tests.test_resample import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_reshape.py b/python/pyspark/pandas/tests/test_reshape.py index 30550a9fba..a7574a5388 100644 --- a/python/pyspark/pandas/tests/test_reshape.py +++ b/python/pyspark/pandas/tests/test_reshape.py @@ -483,7 +483,7 @@ def test_merge_asof(self): from pyspark.pandas.tests.test_reshape import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py index be21bf16d4..6c31073d3f 100644 --- a/python/pyspark/pandas/tests/test_rolling.py +++ b/python/pyspark/pandas/tests/test_rolling.py @@ -242,7 +242,7 @@ def test_groupby_rolling_kurt(self): from pyspark.pandas.tests.test_rolling import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_scalars.py b/python/pyspark/pandas/tests/test_scalars.py index 0c8aa8508f..00900dbdd9 100644 --- a/python/pyspark/pandas/tests/test_scalars.py +++ b/python/pyspark/pandas/tests/test_scalars.py @@ -47,7 +47,7 @@ def test_missing(self): from pyspark.pandas.tests.test_scalars import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index e47f716ecf..46a687b36c 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -3392,7 +3392,7 @@ def test_series_stat_fail(self): from pyspark.pandas.tests.test_series import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_series_conversion.py b/python/pyspark/pandas/tests/test_series_conversion.py index bc83fdacbe..79c2f1ff30 100644 --- a/python/pyspark/pandas/tests/test_series_conversion.py +++ b/python/pyspark/pandas/tests/test_series_conversion.py @@ -68,7 +68,7 @@ def test_to_latex(self): from pyspark.pandas.tests.test_series_conversion import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = 
xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_series_datetime.py b/python/pyspark/pandas/tests/test_series_datetime.py index 1fe078e972..1c392644ed 100644 --- a/python/pyspark/pandas/tests/test_series_datetime.py +++ b/python/pyspark/pandas/tests/test_series_datetime.py @@ -287,7 +287,7 @@ def test_unsupported_type(self): from pyspark.pandas.tests.test_series_datetime import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/test_series_string.py index 0b778583e7..f82f57981f 100644 --- a/python/pyspark/pandas/tests/test_series_string.py +++ b/python/pyspark/pandas/tests/test_series_string.py @@ -336,7 +336,7 @@ def test_string_get_dummies(self): from pyspark.pandas.tests.test_series_string import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_spark_functions.py b/python/pyspark/pandas/tests/test_spark_functions.py index c18dc30240..4da20f754d 100644 --- a/python/pyspark/pandas/tests/test_spark_functions.py +++ b/python/pyspark/pandas/tests/test_spark_functions.py @@ -34,7 +34,7 @@ def test_repeat(self): from pyspark.pandas.tests.test_spark_functions import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py index 5a5d6d484b..4d4afb8882 100644 --- a/python/pyspark/pandas/tests/test_sql.py +++ b/python/pyspark/pandas/tests/test_sql.py @@ -100,7 +100,7 @@ def test_sql_with_pandas_on_spark_objects(self): from pyspark.pandas.tests.test_sql import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py index 4fb08ee69e..fa7cff8f3c 100644 --- a/python/pyspark/pandas/tests/test_stats.py +++ b/python/pyspark/pandas/tests/test_stats.py @@ -554,7 +554,7 @@ def test_numeric_only_unsupported(self): from pyspark.pandas.tests.test_stats import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py index 1bc5c8cfdd..a5f2b2dc2b 100644 --- a/python/pyspark/pandas/tests/test_typedef.py +++ b/python/pyspark/pandas/tests/test_typedef.py @@ -133,7 +133,7 @@ def func() -> pd.DataFrame[np.float_]: pdf = pd.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]}) - def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined] + def func() -> pd.DataFrame[pdf.dtypes]: pass expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())]) @@ -143,14 +143,14 @@ def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined] pdf = pd.DataFrame({"a": [1, 2, 3], "b": pd.Categorical(["a", "b", "c"])}) - def func() -> pd.Series[pdf.b.dtype]: # type: 
ignore[name-defined] + def func() -> pd.Series[pdf.b.dtype]: pass inferred = infer_return_type(func) self.assertEqual(inferred.dtype, CategoricalDtype(categories=["a", "b", "c"])) self.assertEqual(inferred.spark_type, LongType()) - def func() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined] + def func() -> pd.DataFrame[pdf.dtypes]: pass expected = StructType([StructField("c0", LongType()), StructField("c1", LongType())]) @@ -246,7 +246,7 @@ def f() -> 'pd.DataFrame["a" : float : 1, "b":str:2]': # noqa: F405 pdf = pd.DataFrame({"a": ["a", 2, None]}) def try_infer_return_type(): - def f() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined] + def f() -> pd.DataFrame[pdf.dtypes]: pass infer_return_type(f) @@ -254,7 +254,7 @@ def f() -> pd.DataFrame[pdf.dtypes]: # type: ignore[name-defined] self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type) def try_infer_return_type(): - def f() -> pd.Series[pdf.a.dtype]: # type: ignore[name-defined] + def f() -> pd.Series[pdf.a.dtype]: pass infer_return_type(f) @@ -293,7 +293,7 @@ def f() -> 'ps.DataFrame["a" : np.float_ : 1, "b":str:2]': # noqa: F405 pdf = pd.DataFrame({"a": ["a", 2, None]}) def try_infer_return_type(): - def f() -> ps.DataFrame[pdf.dtypes]: # type: ignore[name-defined] + def f() -> ps.DataFrame[pdf.dtypes]: pass infer_return_type(f) @@ -301,7 +301,7 @@ def f() -> ps.DataFrame[pdf.dtypes]: # type: ignore[name-defined] self.assertRaisesRegex(TypeError, "object.*not understood", try_infer_return_type) def try_infer_return_type(): - def f() -> ps.Series[pdf.a.dtype]: # type: ignore[name-defined] + def f() -> ps.Series[pdf.a.dtype]: pass infer_return_type(f) @@ -439,7 +439,7 @@ def test_as_spark_type_extension_float_dtypes(self): from pyspark.pandas.tests.test_typedef import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_utils.py b/python/pyspark/pandas/tests/test_utils.py index 11f560c6f5..cfbcb5ba0a 100644 --- a/python/pyspark/pandas/tests/test_utils.py +++ b/python/pyspark/pandas/tests/test_utils.py @@ -121,7 +121,7 @@ def lazy_prop(self): from pyspark.pandas.tests.test_utils import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/pandas/tests/test_window.py b/python/pyspark/pandas/tests/test_window.py index 49779566c9..d8bc2775fa 100644 --- a/python/pyspark/pandas/tests/test_window.py +++ b/python/pyspark/pandas/tests/test_window.py @@ -453,7 +453,7 @@ def test_missing_groupby(self): from pyspark.pandas.tests.test_window import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/resource/tests/test_resources.py b/python/pyspark/resource/tests/test_resources.py index b6babf3c6c..81a4ea4f1d 100644 --- a/python/pyspark/resource/tests/test_resources.py +++ b/python/pyspark/resource/tests/test_resources.py @@ -75,7 +75,7 @@ def assert_request_contents(exec_reqs, task_reqs): from pyspark.resource.tests.test_resources import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git 
a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 35c3397de5..da03110c32 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -78,9 +78,12 @@ def _get_local_dirs(sub): path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") dirs = path.split(",") if len(dirs) > 1: - # different order in different processes and instances - rnd = random.Random(os.getpid() + id(dirs)) - random.shuffle(dirs, rnd.random) + if sys.version_info < (3, 11): + # different order in different processes and instances + rnd = random.Random(os.getpid() + id(dirs)) + random.shuffle(dirs, rnd.random) + else: + random.shuffle(dirs) return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs] diff --git a/python/pyspark/sql/connect/__init__.py b/python/pyspark/sql/connect/__init__.py index 3df96963f9..4a98368c81 100644 --- a/python/pyspark/sql/connect/__init__.py +++ b/python/pyspark/sql/connect/__init__.py @@ -18,5 +18,13 @@ """Currently Spark Connect is very experimental and the APIs to interact with Spark through this API are can be changed at any time without warning.""" - from pyspark.sql.connect.dataframe import DataFrame # noqa: F401 +from pyspark.sql.pandas.utils import ( + require_minimum_pandas_version, + require_minimum_pyarrow_version, + require_minimum_grpc_version, +) + +require_minimum_pandas_version() +require_minimum_pyarrow_version() +require_minimum_grpc_version() diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 745ca79fda..c4c74f5d6c 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -20,36 +20,19 @@ import uuid from typing import Iterable, Optional, Any, Union, List, Tuple, Dict -import grpc # type: ignore +import grpc import pandas import pyarrow as pa import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib +import pyspark.sql.connect.types as types import pyspark.sql.types from pyspark import cloudpickle from pyspark.sql.types import ( DataType, - ByteType, - ShortType, - IntegerType, - FloatType, - DateType, - TimestampType, - DayTimeIntervalType, - MapType, - StringType, - CharType, - VarcharType, StructType, StructField, - ArrayType, - DoubleType, - LongType, - DecimalType, - BinaryType, - BooleanType, - NullType, ) @@ -350,73 +333,7 @@ def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame": return self._execute_and_fetch(req) def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType: - if schema.HasField("null"): - return NullType() - elif schema.HasField("boolean"): - return BooleanType() - elif schema.HasField("binary"): - return BinaryType() - elif schema.HasField("byte"): - return ByteType() - elif schema.HasField("short"): - return ShortType() - elif schema.HasField("integer"): - return IntegerType() - elif schema.HasField("long"): - return LongType() - elif schema.HasField("float"): - return FloatType() - elif schema.HasField("double"): - return DoubleType() - elif schema.HasField("decimal"): - p = schema.decimal.precision if schema.decimal.HasField("precision") else 10 - s = schema.decimal.scale if schema.decimal.HasField("scale") else 0 - return DecimalType(precision=p, scale=s) - elif schema.HasField("string"): - return StringType() - elif schema.HasField("char"): - return CharType(schema.char.length) - elif schema.HasField("var_char"): - return VarcharType(schema.var_char.length) - elif schema.HasField("date"): - return DateType() - elif schema.HasField("timestamp"): - return TimestampType() - elif 
schema.HasField("day_time_interval"): - start: Optional[int] = ( - schema.day_time_interval.start_field - if schema.day_time_interval.HasField("start_field") - else None - ) - end: Optional[int] = ( - schema.day_time_interval.end_field - if schema.day_time_interval.HasField("end_field") - else None - ) - return DayTimeIntervalType(startField=start, endField=end) - elif schema.HasField("array"): - return ArrayType( - self._proto_schema_to_pyspark_schema(schema.array.element_type), - schema.array.contains_null, - ) - elif schema.HasField("struct"): - fields = [ - StructField( - f.name, - self._proto_schema_to_pyspark_schema(f.data_type), - f.nullable, - ) - for f in schema.struct.fields - ] - return StructType(fields) - elif schema.HasField("map"): - return MapType( - self._proto_schema_to_pyspark_schema(schema.map.key_type), - self._proto_schema_to_pyspark_schema(schema.map.value_type), - schema.map.value_contains_null, - ) - else: - raise Exception(f"Unsupported data type {schema}") + return types.proto_schema_to_pyspark_data_type(schema) def schema(self, plan: pb2.Plan) -> StructType: proto_schema = self._analyze(plan).schema diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index e864f6c93e..58d4e3dc41 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -21,18 +21,21 @@ import decimal import datetime -from pyspark.sql.types import TimestampType, DayTimeIntervalType, DateType +from pyspark.sql.types import TimestampType, DayTimeIntervalType, DataType, DateType import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.types import pyspark_types_to_proto_types if TYPE_CHECKING: - from pyspark.sql.connect._typing import ColumnOrName + from pyspark.sql.connect._typing import ColumnOrName, PrimitiveType from pyspark.sql.connect.client import SparkConnectClient import pyspark.sql.connect.proto as proto -# TODO(SPARK-41329): solve the circular import between _typing and this class -# if we want to reuse _type.PrimitiveType -PrimitiveType = Union[bool, float, int, str] + +JVM_INT_MIN = -(1 << 31) +JVM_INT_MAX = (1 << 31) - 1 +JVM_LONG_MIN = -(1 << 63) +JVM_LONG_MAX = (1 << 63) - 1 def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]: @@ -183,7 +186,12 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": elif isinstance(self._value, bool): expr.literal.boolean = bool(self._value) elif isinstance(self._value, int): - expr.literal.long = int(self._value) + if JVM_INT_MIN <= self._value <= JVM_INT_MAX: + expr.literal.integer = int(self._value) + elif JVM_LONG_MIN <= self._value <= JVM_LONG_MAX: + expr.literal.long = int(self._value) + else: + raise ValueError(f"integer {self._value} out of bounds") elif isinstance(self._value, float): expr.literal.double = float(self._value) elif isinstance(self._value, str): @@ -355,6 +363,29 @@ def __repr__(self) -> str: return f"{self._name}({', '.join([str(arg) for arg in self._args])})" +class CastExpression(Expression): + def __init__( + self, + col: "Column", + data_type: Union[DataType, str], + ) -> None: + super().__init__() + self._col = col + self._data_type = data_type + + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: + fun = proto.Expression() + fun.cast.expr.CopyFrom(self._col.to_plan(session)) + if isinstance(self._data_type, str): + fun.cast.type_str = self._data_type + else: + fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type)) + return fun + + def __repr__(self) -> str: + 
return f"({self._col} ({self._data_type}))" + + class Column: """ A column in a DataFrame. Column can refer to different things based on the @@ -530,7 +561,7 @@ def __ne__( # type: ignore[override] return _func_op("not")(_bin_op("==")(self, other)) # string methods - def contains(self, other: Union[PrimitiveType, "Column"]) -> "Column": + def contains(self, other: Union["PrimitiveType", "Column"]) -> "Column": """ Contains the other element. Returns a boolean :class:`Column` based on a string match. @@ -674,6 +705,9 @@ def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) - >>> df.select(df.name.substr(1, 3).alias("col")).collect() [Row(col='Ali'), Row(col='Bob')] """ + from pyspark.sql.connect.function_builder import functions as F + from pyspark.sql.connect.functions import lit + if type(startPos) != type(length): raise TypeError( "startPos and length must be the same type. " @@ -682,17 +716,16 @@ def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) - length_t=type(length), ) ) - from pyspark.sql.connect.function_builder import functions as F if isinstance(length, int): - length_exp = self._lit(length) + length_exp = lit(length) elif isinstance(length, Column): length_exp = length else: raise TypeError("Unsupported type for substr().") if isinstance(startPos, int): - start_exp = self._lit(startPos) + start_exp = lit(startPos) else: start_exp = startPos @@ -702,8 +735,11 @@ def __eq__(self, other: Any) -> "Column": # type: ignore[override] """Returns a binary expression with the current column as the left side and the other expression as the right side. """ + from pyspark.sql.connect._typing import PrimitiveType + from pyspark.sql.connect.functions import lit + if isinstance(other, get_args(PrimitiveType)): - other = self._lit(other) + other = lit(other) return scalar_function("==", self, other) def to_plan(self, session: "SparkConnectClient") -> proto.Expression: @@ -733,10 +769,63 @@ def desc_nulls_last(self) -> "Column": def name(self) -> str: return self._expr.name() - # TODO(SPARK-41329): solve the circular import between functions.py and - # this class if we want to reuse functions.lit - def _lit(self, x: Any) -> "Column": - return Column(LiteralExpression(x)) + def cast(self, dataType: Union[DataType, str]) -> "Column": + """ + Casts the column into type ``dataType``. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + dataType : :class:`DataType` or str + a DataType or Python string literal with a DDL-formatted string + to use when parsing the column to the same type. + + Returns + ------- + :class:`Column` + Column representing whether each element of Column is cast into new type. 
+ """ + if isinstance(dataType, (DataType, str)): + return Column(CastExpression(col=self, data_type=dataType)) + else: + raise TypeError("unexpected type: %s" % type(dataType)) def __repr__(self) -> str: return "Column<'%s'>" % self._expr.__repr__() + + def otherwise(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("otherwise() is not yet implemented.") + + def over(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("over() is not yet implemented.") + + def isin(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("isin() is not yet implemented.") + + def when(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("when() is not yet implemented.") + + def getItem(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("getItem() is not yet implemented.") + + def astype(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("astype() is not yet implemented.") + + def between(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("between() is not yet implemented.") + + def getField(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("getField() is not yet implemented.") + + def withField(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("withField() is not yet implemented.") + + def dropFields(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError("dropFields() is not yet implemented.") + + def __getitem__(self, k: Any) -> None: + raise NotImplementedError("apply() - __getitem__ is not yet implemented.") + + def __iter__(self) -> None: + raise TypeError("Column is not iterable") diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index f268dc431b..08d48bb11f 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -824,6 +824,57 @@ def withColumn(self, colName: str, col: Column) -> "DataFrame": session=self._session, ) + def unpivot( + self, + ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]], + values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]], + variableColumnName: str, + valueColumnName: str, + ) -> "DataFrame": + """ + Returns a new :class:`DataFrame` by unpivot a DataFrame from wide format to long format, + optionally leaving identifier columns set. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + ids : list + Id columns. + values : list, optional + Value columns to unpivot. + variableColumnName : str + Name of the variable column. + valueColumnName : str + Name of the value column. + + Returns + ------- + :class:`DataFrame` + """ + + def to_jcols( + cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]] + ) -> List["ColumnOrName"]: + if cols is None: + lst = [] + elif isinstance(cols, tuple): + lst = list(cols) + elif isinstance(cols, list): + lst = cols + else: + lst = [cols] + return lst + + return DataFrame.withPlan( + plan.Unpivot( + self._plan, to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName + ), + self._session, + ) + + melt = unpivot + def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: """ Prints the first ``n`` rows to the console. 
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 8b36647ae5..dccb6d6e0c 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -23,7 +23,7 @@ SQLExpression, ) -from typing import Any, TYPE_CHECKING, Union, List, Optional, Tuple +from typing import Any, TYPE_CHECKING, Union, List, overload, Optional, Tuple if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName @@ -90,7 +90,10 @@ def col(col: str) -> Column: def lit(col: Any) -> Column: - return Column(LiteralExpression(col)) + if isinstance(col, Column): + return col + else: + return Column(LiteralExpression(col)) # def bitwiseNOT(col: "ColumnOrName") -> Column: @@ -3208,136 +3211,235 @@ def variance(col: "ColumnOrName") -> Column: return var_samp(col) -# String/Binary functions +# Collection Functions -def upper(col: "ColumnOrName") -> Column: +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def aggregate( +# col: "ColumnOrName", +# initialValue: "ColumnOrName", +# merge: Callable[[Column, Column], Column], +# finish: Optional[Callable[[Column], Column]] = None, +# ) -> Column: +# """ +# Applies a binary operator to an initial state and all elements in the array, +# and reduces this to a single state. The final state is converted into the final result +# by applying a finish function. +# +# Both functions can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# initialValue : :class:`~pyspark.sql.Column` or str +# initial value. Name of column or expression +# merge : function +# a binary function ``(acc: Column, x: Column) -> Column...`` returning expression +# of the same type as ``zero`` +# finish : function +# an optional unary function ``(x: Column) -> Column: ...`` +# used to convert accumulated value. +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# final value after aggregate function is applied. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values")) +# >>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show() +# +----+ +# | sum| +# +----+ +# |42.0| +# +----+ +# +# >>> def merge(acc, x): +# ... count = acc.count + 1 +# ... sum = acc.sum + x +# ... return struct(count.alias("count"), sum.alias("sum")) +# >>> df.select( +# ... aggregate( +# ... "values", +# ... struct(lit(0).alias("count"), lit(0.0).alias("sum")), +# ... merge, +# ... lambda acc: acc.sum / acc.count, +# ... ).alias("mean") +# ... ).show() +# +----+ +# |mean| +# +----+ +# | 8.4| +# +----+ +# """ +# if finish is not None: +# return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], +# [merge, finish]) +# +# else: +# return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], +# [merge]) + + +def array(*cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column: + """Creates a new array column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + column names or :class:`~pyspark.sql.Column`\\s that have + the same data type. + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of array type. 
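
# A hedged, standalone sketch of the lit() change in the functions.py hunk above:
# a Column argument now passes through unchanged instead of being re-wrapped in a
# LiteralExpression. The two classes below are simplified stand-ins, not the real
# Spark Connect Column/LiteralExpression.
class LiteralExpression:
    def __init__(self, value):
        self.value = value

class Column:
    def __init__(self, expr):
        self.expr = expr

def lit(value):
    if isinstance(value, Column):
        return value                              # pass-through, mirroring the patch
    return Column(LiteralExpression(value))

c = Column(LiteralExpression(5))
assert lit(c) is c                                # idempotent for Column inputs
assert isinstance(lit(5).expr, LiteralExpression) # scalars still become literals
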
+ + Examples + -------- + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array('age', 'age').alias("col")).printSchema() + root + |-- col: array (nullable = false) + | |-- element: long (containsNull = true) """ - Converts a string expression to upper case. + if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)): + cols = cols[0] # type: ignore[assignment] + return _invoke_function_over_columns("array", *cols) # type: ignore[arg-type] + + +def array_contains(col: "ColumnOrName", value: Any) -> Column: + """ + Collection function: returns null if the array is null, true if the array contains the + given value, and false otherwise. .. versionadded:: 3.4.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str - target column to work on. + name of column containing array + value : + value or column to check for in array Returns ------- :class:`~pyspark.sql.Column` - upper case values. + a column of Boolean type. Examples -------- - >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") - >>> df.select(upper("value")).show() - +------------+ - |upper(value)| - +------------+ - | SPARK| - | PYSPARK| - | PANDAS API| - +------------+ + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df.select(array_contains(df.data, "a")).collect() + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] + >>> df.select(array_contains(df.data, lit("a"))).collect() + [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] """ - return _invoke_function_over_columns("upper", col) + return _invoke_function("array_contains", _to_col(col), lit(value)) -def lower(col: "ColumnOrName") -> Column: +def array_distinct(col: "ColumnOrName") -> Column: """ - Converts a string expression to lower case. + Collection function: removes duplicate values from the array. .. versionadded:: 3.4.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str - target column to work on. + name of column or expression Returns ------- :class:`~pyspark.sql.Column` - lower case values. + an array of unique values. Examples -------- - >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") - >>> df.select(lower("value")).show() - +------------+ - |lower(value)| - +------------+ - | spark| - | pyspark| - | pandas api| - +------------+ + >>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data']) + >>> df.select(array_distinct(df.data)).collect() + [Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])] """ - return _invoke_function_over_columns("lower", col) + return _invoke_function_over_columns("array_distinct", col) -def ascii(col: "ColumnOrName") -> Column: +def array_except(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """ - Computes the numeric value of the first character of the string column. + Collection function: returns an array of the elements in col1 but not in col2, + without duplicates. .. versionadded:: 3.4.0 Parameters ---------- - col : :class:`~pyspark.sql.Column` or str - target column to work on. + col1 : :class:`~pyspark.sql.Column` or str + name of column containing array + col2 : :class:`~pyspark.sql.Column` or str + name of column containing array Returns ------- :class:`~pyspark.sql.Column` - numeric value. 
+ an array of values from first array that are not in the second. Examples -------- - >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") - >>> df.select(ascii("value")).show() - +------------+ - |ascii(value)| - +------------+ - | 83| - | 80| - | 80| - +------------+ + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_except(df.c1, df.c2)).collect() + [Row(array_except(c1, c2)=['b'])] """ - return _invoke_function_over_columns("ascii", col) + return _invoke_function_over_columns("array_except", col1, col2) -def base64(col: "ColumnOrName") -> Column: +def array_intersect(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """ - Computes the BASE64 encoding of a binary column and returns it as a string column. + Collection function: returns an array of the elements in the intersection of col1 and col2, + without duplicates. .. versionadded:: 3.4.0 Parameters ---------- - col : :class:`~pyspark.sql.Column` or str - target column to work on. + col1 : :class:`~pyspark.sql.Column` or str + name of column containing array + col2 : :class:`~pyspark.sql.Column` or str + name of column containing array Returns ------- :class:`~pyspark.sql.Column` - BASE64 encoding of string value. + an array of values in the intersection of two arrays. Examples -------- - >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") - >>> df.select(base64("value")).show() - +----------------+ - | base64(value)| - +----------------+ - | U3Bhcms=| - | UHlTcGFyaw==| - |UGFuZGFzIEFQSQ==| - +----------------+ + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_intersect(df.c1, df.c2)).collect() + [Row(array_intersect(c1, c2)=['a', 'c'])] """ - return _invoke_function_over_columns("base64", col) + return _invoke_function_over_columns("array_intersect", col1, col2) -def unbase64(col: "ColumnOrName") -> Column: +def array_join( + col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None +) -> Column: """ - Decodes a BASE64 encoded string column and returns it as a binary column. + Concatenates the elements of `column` using the `delimiter`. Null values are replaced with + `null_replacement` if set, otherwise they are ignored. .. versionadded:: 3.4.0 @@ -3345,210 +3447,3510 @@ def unbase64(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str target column to work on. + delimiter : str + delimiter used to concatenate elements + null_replacement : str, optional + if set then null values will be replaced by this value Returns ------- :class:`~pyspark.sql.Column` - encoded string value. + a column of string type. Concatenated values. Examples -------- - >>> df = spark.createDataFrame(["U3Bhcms=", - ... "UHlTcGFyaw==", - ... 
"UGFuZGFzIEFQSQ=="], "STRING") - >>> df.select(unbase64("value")).show() - +--------------------+ - | unbase64(value)| - +--------------------+ - | [53 70 61 72 6B]| - |[50 79 53 70 61 7...| - |[50 61 6E 64 61 7...| - +--------------------+ + >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data']) + >>> df.select(array_join(df.data, ",").alias("joined")).collect() + [Row(joined='a,b,c'), Row(joined='a')] + >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect() + [Row(joined='a,b,c'), Row(joined='a,NULL')] + """ + if null_replacement is None: + return _invoke_function("array_join", _to_col(col), lit(delimiter)) + else: + return _invoke_function("array_join", _to_col(col), lit(delimiter), lit(null_replacement)) + + +def array_max(col: "ColumnOrName") -> Column: """ - return _invoke_function_over_columns("unbase64", col) + Collection function: returns the maximum value of the array. + .. versionadded:: 3.4.0 -def ltrim(col: "ColumnOrName") -> Column: + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + maximum value of an array. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] """ - Trim the spaces from left end for the specified string value. + return _invoke_function_over_columns("array_max", col) + + +def array_min(col: "ColumnOrName") -> Column: + """ + Collection function: returns the minimum value of the array. .. versionadded:: 3.4.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str - target column to work on. + name of column or expression Returns ------- :class:`~pyspark.sql.Column` - left trimmed values. + minimum value of array. Examples -------- - >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") - >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() - +-------+------+ - | r|length| - +-------+------+ - | Spark| 5| - |Spark | 7| - | Spark| 5| - +-------+------+ + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_min(df.data).alias('min')).collect() + [Row(min=1), Row(min=-1)] """ - return _invoke_function_over_columns("ltrim", col) + return _invoke_function_over_columns("array_min", col) -def rtrim(col: "ColumnOrName") -> Column: +def array_position(col: "ColumnOrName", value: Any) -> Column: """ - Trim the spaces from right end for the specified string value. + Collection function: Locates the position of the first occurrence of the given value + in the given array. Returns null if either of the arguments are null. .. versionadded:: 3.4.0 + Notes + ----- + The position is not zero based, but 1 based index. Returns 0 if the given + value could not be found in the array. + Parameters ---------- col : :class:`~pyspark.sql.Column` or str target column to work on. + value : Any + value to look for. Returns ------- :class:`~pyspark.sql.Column` - right trimmed values. + position of the value in the given array if found and 0 otherwise. 
Examples -------- - >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") - >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show() - +--------+------+ - | r|length| - +--------+------+ - | Spark| 8| - | Spark| 5| - | Spark| 6| - +--------+------+ + >>> df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data']) + >>> df.select(array_position(df.data, "a")).collect() + [Row(array_position(data, a)=3), Row(array_position(data, a)=0)] """ - return _invoke_function_over_columns("rtrim", col) + return _invoke_function("array_position", _to_col(col), lit(value)) -def trim(col: "ColumnOrName") -> Column: +def array_remove(col: "ColumnOrName", element: Any) -> Column: """ - Trim the spaces from both ends for the specified string column. + Collection function: Remove all elements that equal to element from the given array. .. versionadded:: 3.4.0 Parameters ---------- col : :class:`~pyspark.sql.Column` or str - target column to work on. + name of column containing array + element : + element to be removed from the array Returns ------- :class:`~pyspark.sql.Column` - trimmed values from both sides. + an array excluding given value. Examples -------- - >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") - >>> df.select(trim("value").alias("r")).withColumn("length", length("r")).show() - +-----+------+ - | r|length| - +-----+------+ - |Spark| 5| - |Spark| 5| - |Spark| 5| - +-----+------+ + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] """ - return _invoke_function_over_columns("trim", col) + return _invoke_function("array_remove", _to_col(col), lit(element)) -def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: +def array_repeat(col: "ColumnOrName", count: Union["ColumnOrName", int]) -> Column: """ - Concatenates multiple input string columns together into a single string column, - using the given separator. + Collection function: creates an array containing a column repeated count times. .. versionadded:: 3.4.0 Parameters ---------- - sep : str - words separator. - cols : :class:`~pyspark.sql.Column` or str - list of columns to work on. + col : :class:`~pyspark.sql.Column` or str + column name or column that contains the element to be repeated + count : :class:`~pyspark.sql.Column` or str or int + column name, column, or int containing the number of times to repeat the first argument Returns ------- :class:`~pyspark.sql.Column` - string of concatenated words. + an array of repeated elements. 
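
# Pure-Python analogs of three array helpers documented in the hunk above
# (array_remove, array_distinct, array_repeat); the *_py names are hypothetical
# and only illustrate the documented semantics.
def array_remove_py(data, element):
    return [x for x in data if x != element]      # drop every matching element

def array_distinct_py(data):
    return list(dict.fromkeys(data))              # dedupe, keeping first-occurrence order

def array_repeat_py(value, count):
    return [value] * count

assert array_remove_py([1, 2, 3, 1, 1], 1) == [2, 3]
assert array_distinct_py([1, 2, 3, 2]) == [1, 2, 3]
assert array_repeat_py("ab", 3) == ["ab", "ab", "ab"]
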
Examples -------- - >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) - >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() - [Row(s='abcd-123')] + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=['ab', 'ab', 'ab'])] """ - return _invoke_function("concat_ws", lit(sep), *[_to_col(c) for c in cols]) + _count = lit(count) if isinstance(count, int) else _to_col(count) + + return _invoke_function("array_repeat", _to_col(col), _count) -# TODO: enable with SPARK-41402 -# def decode(col: "ColumnOrName", charset: str) -> Column: +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def array_sort( +# col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None +# ) -> Column: # """ -# Computes the first argument into a string from a binary using the provided character set -# (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). +# Collection function: sorts the input array in ascending order. The elements of the input array +# must be orderable. Null elements will be placed at the end of the returned array. # -# .. versionadded:: 3.4.0 +# .. versionadded:: 2.4.0 +# .. versionchanged:: 3.4.0 +# Can take a `comparator` function. # # Parameters # ---------- # col : :class:`~pyspark.sql.Column` or str -# target column to work on. -# charset : str -# charset to use to decode to. +# name of column or expression +# comparator : callable, optional +# A binary ``(Column, Column) -> Column: ...``. +# The comparator will take two +# arguments representing two elements of the array. It returns a negative integer, 0, or a +# positive integer as the first element is less than, equal to, or greater than the second +# element. If the comparator function returns null, the function will fail and raise an +# error. # # Returns # ------- # :class:`~pyspark.sql.Column` -# the column for computed results. +# sorted array. # # Examples # -------- -# >>> df = spark.createDataFrame([('abcd',)], ['a']) -# >>> df.select(decode("a", "UTF-8")).show() -# +----------------------+ -# |stringdecode(a, UTF-8)| -# +----------------------+ -# | abcd| -# +----------------------+ +# >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) +# >>> df.select(array_sort(df.data).alias('r')).collect() +# [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])] +# >>> df = spark.createDataFrame([(["foo", "foobar", None, "bar"],),(["foo"],),([],)], ['data']) +# >>> df.select(array_sort( +# ... "data", +# ... lambda x, y: when(x.isNull() | y.isNull(), lit(0)).otherwise(length(y) - length(x)) +# ... ).alias("r")).collect() +# [Row(r=['foobar', 'foo', None, 'bar']), Row(r=['foo']), Row(r=[])] # """ -# return _invoke_function("decode", _to_col(col), lit(charset)) +# if comparator is None: +# return _invoke_function_over_columns("array_sort", col) +# else: +# return _invoke_higher_order_function("ArraySort", [col], [comparator]) -def encode(col: "ColumnOrName", charset: str) -> Column: +def array_union(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: """ - Computes the first argument into a binary from a string using the provided character set - (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + Collection function: returns an array of the elements in the union of col1 and col2, + without duplicates. .. versionadded:: 3.4.0 Parameters ---------- - col : :class:`~pyspark.sql.Column` or str - target column to work on. 
- charset : str - charset to use to encode. + col1 : :class:`~pyspark.sql.Column` or str + name of column containing array + col2 : :class:`~pyspark.sql.Column` or str + name of column containing array Returns ------- :class:`~pyspark.sql.Column` - the column for computed results. + an array of values in union of two arrays. Examples -------- - >>> df = spark.createDataFrame([('abcd',)], ['c']) - >>> df.select(encode("c", "UTF-8")).show() - +----------------+ - |encode(c, UTF-8)| - +----------------+ - | [61 62 63 64]| - +----------------+ + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2=["c", "d", "a", "f"])]) + >>> df.select(array_union(df.c1, df.c2)).collect() + [Row(array_union(c1, c2)=['b', 'a', 'c', 'd', 'f'])] """ - return _invoke_function("encode", _to_col(col), lit(charset)) + return _invoke_function_over_columns("array_union", col1, col2) + + +def arrays_overlap(a1: "ColumnOrName", a2: "ColumnOrName") -> Column: + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of Boolean type. + + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + return _invoke_function_over_columns("arrays_overlap", a1, a2) + + +def arrays_zip(*cols: "ColumnOrName") -> Column: + """ + Collection function: Returns a merged array of structs in which the N-th struct contains all + N-th values of input arrays. If one of the arrays is shorter than others then + resulting struct type value will be a `null` for missing elements. + + .. versionadded:: 2.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + columns of arrays to be merged. + + Returns + ------- + :class:`~pyspark.sql.Column` + merged array of entries. + + Examples + -------- + >>> from pyspark.sql.functions import arrays_zip + >>> df = spark.createDataFrame([(([1, 2, 3], [2, 4, 6], [3, 6]))], ['vals1', 'vals2', 'vals3']) + >>> df = df.select(arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')) + >>> df.show(truncate=False) + +------------------------------------+ + |zipped | + +------------------------------------+ + |[{1, 2, 3}, {2, 4, 6}, {3, 6, null}]| + +------------------------------------+ + >>> df.printSchema() + root + |-- zipped: array (nullable = true) + | |-- element: struct (containsNull = false) + | | |-- vals1: long (nullable = true) + | | |-- vals2: long (nullable = true) + | | |-- vals3: long (nullable = true) + """ + return _invoke_function_over_columns("arrays_zip", *cols) + + +def concat(*cols: "ColumnOrName") -> Column: + """ + Concatenates multiple input columns together into a single column. + The function works with strings, numeric, binary and compatible array columns. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + target column or columns to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + concatenated values. Type of the `Column` depends on input columns' type. 
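
# A dependency-free analog of the arrays_zip padding rule documented above: the N-th
# output struct collects the N-th element of every input array, and shorter arrays
# contribute null (None) for the missing positions. zip_arrays is a hypothetical helper.
from itertools import zip_longest

def zip_arrays(*arrays):
    return [tuple(entry) for entry in zip_longest(*arrays, fillvalue=None)]

assert zip_arrays([1, 2, 3], [2, 4, 6], [3, 6]) == [(1, 2, 3), (2, 4, 6), (3, 6, None)]
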
+ + See Also + -------- + :meth:`pyspark.sql.functions.array_join` : to concatenate string columns with delimiter + + Examples + -------- + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df = df.select(concat(df.s, df.d).alias('s')) + >>> df.collect() + [Row(s='abcd123')] + >>> df + DataFrame[s: string] + + >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) + >>> df = df.select(concat(df.a, df.b, df.c).alias("arr")) + >>> df.collect() + [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] + >>> df + DataFrame[arr: array] + """ + return _invoke_function_over_columns("concat", *cols) + + +def create_map( + *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]] +) -> Column: + """Creates a new map column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + column names or :class:`~pyspark.sql.Column`\\s that are + grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...). + + Examples + -------- + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) + >>> df.select(create_map('name', 'age').alias("map")).collect() + [Row(map={'Alice': 2}), Row(map={'Bob': 5})] + >>> df.select(create_map([df.name, df.age]).alias("map")).collect() + [Row(map={'Alice': 2}), Row(map={'Bob': 5})] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)): + cols = cols[0] # type: ignore[assignment] + return _invoke_function_over_columns("map", *cols) # type: ignore[arg-type] + + +def element_at(col: "ColumnOrName", extraction: Any) -> Column: + """ + Collection function: Returns element of array at given index in `extraction` if col is array. + Returns value for the given key in `extraction` if col is map. If position is negative + then location of the element will start from end, if number is outside the + array boundaries then None will be returned. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array or map + extraction : + index to check for in array or key to check for in map + + Returns + ------- + :class:`~pyspark.sql.Column` + value at given position. + + Notes + ----- + The position is not zero based, but 1 based index. + + See Also + -------- + :meth:`get` + + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b", "c"],)], ['data']) + >>> df.select(element_at(df.data, 1)).collect() + [Row(element_at(data, 1)='a')] + >>> df.select(element_at(df.data, -1)).collect() + [Row(element_at(data, -1)='c')] + + >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},)], ['data']) + >>> df.select(element_at(df.data, lit("a"))).collect() + [Row(element_at(data, a)=1.0)] + """ + return _invoke_function("element_at", _to_col(col), lit(extraction)) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: +# """ +# Returns whether a predicate holds for one or more elements in the array. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# ``(x: Column) -> Column: ...`` returning the Boolean expression. +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). 
+# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# True if "any" element of an array evaluates to True when passed as an argument to +# given function and False otherwise. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, [1, 2, 3, 4]), (2, [3, -1, 0])],("key", "values")) +# >>> df.select(exists("values", lambda x: x < 0).alias("any_negative")).show() +# +------------+ +# |any_negative| +# +------------+ +# | false| +# | true| +# +------------+ +# """ +# return _invoke_higher_order_function("ArrayExists", [col], [f]) + + +def explode(col: "ColumnOrName") -> Column: + """ + Returns a new row for each element in the given array or map. + Uses the default column name `col` for elements in the array and + `key` and `value` for elements in the map unless specified otherwise. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + one row per array item or map key value. + + See Also + -------- + :meth:`pyspark.functions.posexplode` + :meth:`pyspark.functions.explode_outer` + :meth:`pyspark.functions.posexplode_outer` + + Examples + -------- + >>> from pyspark.sql import Row + >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + return _invoke_function_over_columns("explode", col) + + +def explode_outer(col: "ColumnOrName") -> Column: + """ + Returns a new row for each element in the given array or map. + Unlike explode, if the array/map is null or empty then null is produced. + Uses the default column name `col` for elements in the array and + `key` and `value` for elements in the map unless specified otherwise. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + one row per array item or map key value. + + Examples + -------- + >>> df = spark.createDataFrame( + ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], + ... ("id", "an_array", "a_map") + ... ) + >>> df.select("id", "an_array", explode_outer("a_map")).show() + +---+----------+----+-----+ + | id| an_array| key|value| + +---+----------+----+-----+ + | 1|[foo, bar]| x| 1.0| + | 2| []|null| null| + | 3| null|null| null| + +---+----------+----+-----+ + + >>> df.select("id", "a_map", explode_outer("an_array")).show() + +---+----------+----+ + | id| a_map| col| + +---+----------+----+ + | 1|{x -> 1.0}| foo| + | 1|{x -> 1.0}| bar| + | 2| {}|null| + | 3| null|null| + +---+----------+----+ + """ + return _invoke_function_over_columns("explode_outer", col) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def filter( +# col: "ColumnOrName", +# f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], +# ) -> Column: +# """ +# Returns an array of elements for which a predicate holds in a given array. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# A function that returns the Boolean expression. 
+# Can take one of the following forms: +# +# - Unary ``(x: Column) -> Column: ...`` +# - Binary ``(x: Column, i: Column) -> Column...``, where the second argument is +# a 0-based index of the element. +# +# and can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# filtered array of elements where given function evaluated to True +# when passed as an argument. +# +# Examples +# -------- +# >>> df = spark.createDataFrame( +# ... [(1, ["2018-09-20", "2019-02-03", "2019-07-01", "2020-06-01"])], +# ... ("key", "values") +# ... ) +# >>> def after_second_quarter(x): +# ... return month(to_date(x)) > 6 +# >>> df.select( +# ... filter("values", after_second_quarter).alias("after_second_quarter") +# ... ).show(truncate=False) +# +------------------------+ +# |after_second_quarter | +# +------------------------+ +# |[2018-09-20, 2019-07-01]| +# +------------------------+ +# """ +# return _invoke_higher_order_function("ArrayFilter", [col], [f]) + + +def flatten(col: "ColumnOrName") -> Column: + """ + Collection function: creates a single array from an array of arrays. + If a structure of nested arrays is deeper than two levels, + only one level of nesting is removed. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + flattened array. + + Examples + -------- + >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data']) + >>> df.show(truncate=False) + +------------------------+ + |data | + +------------------------+ + |[[1, 2, 3], [4, 5], [6]]| + |[null, [4, 5]] | + +------------------------+ + >>> df.select(flatten(df.data).alias('r')).show() + +------------------+ + | r| + +------------------+ + |[1, 2, 3, 4, 5, 6]| + | null| + +------------------+ + """ + return _invoke_function_over_columns("flatten", col) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def forall(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: +# """ +# Returns whether a predicate holds for every element in the array. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# ``(x: Column) -> Column: ...`` returning the Boolean expression. +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# True if "all" elements of an array evaluates to True when passed as an argument to +# given function and False otherwise. +# +# Examples +# -------- +# >>> df = spark.createDataFrame( +# ... [(1, ["bar"]), (2, ["foo", "bar"]), (3, ["foobar", "foo"])], +# ... ("key", "values") +# ... 
) +# >>> df.select(forall("values", lambda x: x.rlike("foo")).alias("all_foo")).show() +# +-------+ +# |all_foo| +# +-------+ +# | false| +# | false| +# | true| +# +-------+ +# """ +# return _invoke_higher_order_function("ArrayForAll", [col], [f]) + + +# TODO: support options +def from_csv( + col: "ColumnOrName", + schema: Union[Column, str], +) -> Column: + """ + Parses a column containing a CSV string to a row with the specified schema. + Returns `null`, in the case of an unparseable string. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + a column or column name in CSV format + schema :class:`~pyspark.sql.Column` or str + a column, or Python string literal with schema in DDL format, to use + when parsing the CSV column. + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of parsed CSV values + + Examples + -------- + >>> data = [("1,2,3",)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> df.select(from_csv(df.value, "a INT, b INT, c INT").alias("csv")).collect() + [Row(csv=Row(a=1, b=2, c=3))] + >>> value = data[0][0] + >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect() + [Row(csv=Row(_c0=1, _c1=2, _c2=3))] + >>> data = [(" abc",)] + >>> df = spark.createDataFrame(data, ("value",)) + >>> options = {'ignoreLeadingWhiteSpace': True} + >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect() + [Row(csv=Row(s='abc'))] + """ + + if isinstance(schema, Column): + _schema = schema + elif isinstance(schema, str): + _schema = lit(schema) + else: + raise TypeError(f"schema should be a Column or str, but got {type(schema).__name__}") + + return _invoke_function("from_csv", _to_col(col), _schema) + + +# TODO: 1, support ArrayType and StructType schema; 2, support options +def from_json( + col: "ColumnOrName", + schema: Union[Column, str], +) -> Column: + """ + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` with + the specified schema. Returns `null`, in the case of an unparseable string. + + .. versionadded:: 2.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + a column or column name in JSON format + schema :class:`~pyspark.sql.Column` or str + a column, or Python string literal with schema in DDL format, to use when + parsing the JSON column. + + Returns + ------- + :class:`~pyspark.sql.Column` + a new column of complex type from given JSON object. 
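
# A hedged usage sketch of the two schema forms accepted by from_json above: a DDL
# string or a Column such as the one produced by schema_of_json. Written against
# classic PySpark (which the Connect API mirrors) and assuming a working local Spark
# installation; the session and data here are illustrative only.
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, schema_of_json, lit

spark = SparkSession.builder.master("local[1]").appName("from_json_demo").getOrCreate()
df = spark.createDataFrame([(1, '{"a": 1}')], ("key", "value"))

df.select(from_json("value", "a INT").alias("json")).show()
df.select(from_json("value", schema_of_json(lit('{"a": 0}'))).alias("json")).show()

spark.stop()
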
+ + Examples + -------- + >>> from pyspark.sql.types import * + >>> data = [(1, '''{"a": 1}''')] + >>> schema = StructType([StructField("a", IntegerType())]) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "a INT").alias("json")).collect() + [Row(json=Row(a=1))] + >>> df.select(from_json(df.value, "MAP").alias("json")).collect() + [Row(json={'a': 1})] + >>> data = [(1, '''[{"a": 1}]''')] + >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[Row(a=1)])] + >>> schema = schema_of_json(lit('''{"a": 0}''')) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=Row(a=None))] + >>> data = [(1, '''[1, 2, 3]''')] + >>> schema = ArrayType(IntegerType()) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[1, 2, 3])] + """ + + if isinstance(schema, Column): + _schema = schema + elif isinstance(schema, str): + _schema = lit(schema) + else: + raise TypeError(f"schema should be a Column or str, but got {type(schema).__name__}") + + return _invoke_function("from_json", _to_col(col), _schema) + + +def get(col: "ColumnOrName", index: Union["ColumnOrName", int]) -> Column: + """ + Collection function: Returns element of array at given (0-based) index. + If the index points outside of the array boundaries, then this function + returns NULL. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing array + index : :class:`~pyspark.sql.Column` or str or int + index to check for in array + + Returns + ------- + :class:`~pyspark.sql.Column` + value at given position. + + Notes + ----- + The position is not 1 based, but 0 based index. + + See Also + -------- + :meth:`element_at` + + Examples + -------- + >>> df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index']) + >>> df.select(get(df.data, 1)).show() + +------------+ + |get(data, 1)| + +------------+ + | b| + +------------+ + + >>> df.select(get(df.data, -1)).show() + +-------------+ + |get(data, -1)| + +-------------+ + | null| + +-------------+ + + >>> df.select(get(df.data, 3)).show() + +------------+ + |get(data, 3)| + +------------+ + | null| + +------------+ + + >>> df.select(get(df.data, "index")).show() + +----------------+ + |get(data, index)| + +----------------+ + | b| + +----------------+ + + >>> df.select(get(df.data, col("index") - 1)).show() + +----------------------+ + |get(data, (index - 1))| + +----------------------+ + | a| + +----------------------+ + """ + index = lit(index) if isinstance(index, int) else index + + return _invoke_function_over_columns("get", col, index) + + +def get_json_object(col: "ColumnOrName", path: str) -> Column: + """ + Extracts json object from a json string based on json `path` specified, and returns json string + of the extracted json object. It will return null if the input json string is invalid. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + string column in json format + path : str + path to the json object to extract + + Returns + ------- + :class:`~pyspark.sql.Column` + string representation of given JSON object value. 
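
# Pure-Python analogs of the two indexing rules documented above: get() is 0-based and
# yields null (None) out of bounds, while element_at() is 1-based and accepts negative
# positions counted from the end. The *_py helpers are hypothetical.
def get_py(data, index):
    return data[index] if 0 <= index < len(data) else None

def element_at_py(data, pos):
    if pos == 0:
        raise ValueError("SQL array indices start at 1")
    if abs(pos) > len(data):
        return None                                # outside the array boundaries
    return data[pos - 1] if pos > 0 else data[pos]

data = ["a", "b", "c"]
assert get_py(data, 1) == "b" and get_py(data, 3) is None
assert element_at_py(data, 1) == "a" and element_at_py(data, -1) == "c"
assert element_at_py(data, 4) is None
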
+ + Examples + -------- + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = spark.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \\ + ... get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)] + """ + return _invoke_function("get_json_object", _to_col(col), lit(path)) + + +def inline(col: "ColumnOrName") -> Column: + """ + Explodes an array of structs into a table. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column of values to explode. + + Returns + ------- + :class:`~pyspark.sql.Column` + generator expression with the inline exploded result. + + See Also + -------- + :meth:`explode` + + Examples + -------- + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([Row(structlist=[Row(a=1, b=2), Row(a=3, b=4)])]) + >>> df.select(inline(df.structlist)).show() + +---+---+ + | a| b| + +---+---+ + | 1| 2| + | 3| 4| + +---+---+ + """ + return _invoke_function_over_columns("inline", col) + + +def inline_outer(col: "ColumnOrName") -> Column: + """ + Explodes an array of structs into a table. + Unlike inline, if the array is null or empty then null is produced for each nested column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column of values to explode. + + Returns + ------- + :class:`~pyspark.sql.Column` + generator expression with the inline exploded result. + + See Also + -------- + :meth:`explode_outer` + :meth:`inline` + + Examples + -------- + >>> from pyspark.sql import Row + >>> df = spark.createDataFrame([ + ... Row(id=1, structlist=[Row(a=1, b=2), Row(a=3, b=4)]), + ... Row(id=2, structlist=[]) + ... ]) + >>> df.select('id', inline_outer(df.structlist)).show() + +---+----+----+ + | id| a| b| + +---+----+----+ + | 1| 1| 2| + | 1| 3| 4| + | 2|null|null| + +---+----+----+ + """ + return _invoke_function_over_columns("inline_outer", col) + + +def json_tuple(col: "ColumnOrName", *fields: str) -> Column: + """Creates a new row for a json column according to the given field names. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + string column in json format + fields : str + a field or fields to extract + + Returns + ------- + :class:`~pyspark.sql.Column` + a new row for each given field value from json object + + Examples + -------- + >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] + >>> df = spark.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() + [Row(key='1', c0='value1', c1='value2'), Row(key='2', c0='value12', c1=None)] + """ + + return _invoke_function("json_tuple", _to_col(col), *[lit(field) for field in fields]) + + +def map_concat( + *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]] +) -> Column: + """Returns the union of all the given maps. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + column names or :class:`~pyspark.sql.Column`\\s + + Returns + ------- + :class:`~pyspark.sql.Column` + a map of merged entries from other maps. 
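
# array(), create_map(), map_concat() and struct() in this hunk all accept either
# varargs or one flat list/set/tuple, normalised with the same two-line idiom. A
# standalone sketch of that normalisation; flatten_cols is a hypothetical name.
def flatten_cols(*cols):
    if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
        cols = cols[0]
    return list(cols)

assert flatten_cols("age", "name") == ["age", "name"]
assert flatten_cols(["age", "name"]) == ["age", "name"]
assert flatten_cols(("age",)) == ["age"]
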
+ + Examples + -------- + >>> from pyspark.sql.functions import map_concat + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2") + >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) + +------------------------+ + |map3 | + +------------------------+ + |{1 -> a, 2 -> b, 3 -> c}| + +------------------------+ + """ + if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)): + cols = cols[0] # type: ignore[assignment] + return _invoke_function_over_columns("map_concat", *cols) # type: ignore[arg-type] + + +def map_contains_key(col: "ColumnOrName", value: Any) -> Column: + """ + Returns true if the map contains the key. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + value : + a literal value + + Returns + ------- + :class:`~pyspark.sql.Column` + True if key is in the map and False otherwise. + + Examples + -------- + >>> from pyspark.sql.functions import map_contains_key + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_contains_key("data", 1)).show() + +---------------------------------+ + |array_contains(map_keys(data), 1)| + +---------------------------------+ + | true| + +---------------------------------+ + >>> df.select(map_contains_key("data", -1)).show() + +----------------------------------+ + |array_contains(map_keys(data), -1)| + +----------------------------------+ + | false| + +----------------------------------+ + """ + return array_contains(map_keys(col), lit(value)) + + +def map_entries(col: "ColumnOrName") -> Column: + """ + Collection function: Returns an unordered array of all entries in the given map. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + ar array of key value pairs as a struct type + + Examples + -------- + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df = df.select(map_entries("data").alias("entries")) + >>> df.show() + +----------------+ + | entries| + +----------------+ + |[{1, a}, {2, b}]| + +----------------+ + >>> df.printSchema() + root + |-- entries: array (nullable = false) + | |-- element: struct (containsNull = false) + | | |-- key: integer (nullable = false) + | | |-- value: string (nullable = false) + """ + return _invoke_function_over_columns("map_entries", col) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def map_filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: +# """ +# Returns a map whose key-value pairs satisfy a predicate. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# a binary function ``(k: Column, v: Column) -> Column...`` +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# filtered map. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, {"foo": 42.0, "bar": 1.0, "baz": 32.0})], ("id", "data")) +# >>> df.select(map_filter( +# ... "data", lambda _, v: v > 30.0).alias("data_filtered") +# ... 
).show(truncate=False) +# +--------------------------+ +# |data_filtered | +# +--------------------------+ +# |{baz -> 32.0, foo -> 42.0}| +# +--------------------------+ +# """ +# return _invoke_higher_order_function("MapFilter", [col], [f]) + + +def map_from_arrays(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: + """Creates a new map from two arrays. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or str + name of column containing a set of keys. All elements should not be null + col2 : :class:`~pyspark.sql.Column` or str + name of column containing a set of values + + Returns + ------- + :class:`~pyspark.sql.Column` + a column of map type. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v']) + >>> df = df.select(map_from_arrays(df.k, df.v).alias("col")) + >>> df.show() + +----------------+ + | col| + +----------------+ + |{2 -> a, 5 -> b}| + +----------------+ + >>> df.printSchema() + root + |-- col: map (nullable = true) + | |-- key: long + | |-- value: string (valueContainsNull = true) + """ + return _invoke_function_over_columns("map_from_arrays", col1, col2) + + +def map_from_entries(col: "ColumnOrName") -> Column: + """ + Collection function: Converts an array of entries (key value struct types) to a map + of values. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + a map created from the given array of entries. + + Examples + -------- + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |{1 -> a, 2 -> b}| + +----------------+ + """ + return _invoke_function_over_columns("map_from_entries", col) + + +def map_keys(col: "ColumnOrName") -> Column: + """ + Collection function: Returns an unordered array containing the keys of the map. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + keys of the map as an array. + + Examples + -------- + >>> from pyspark.sql.functions import map_keys + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_keys("data").alias("keys")).show() + +------+ + | keys| + +------+ + |[1, 2]| + +------+ + """ + return _invoke_function_over_columns("map_keys", col) + + +def map_values(col: "ColumnOrName") -> Column: + """ + Collection function: Returns an unordered array containing the values of the map. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + values of the map as an array. 
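
# map_contains_key above is built by composition rather than a dedicated expression:
# array_contains(map_keys(col), lit(value)). The same composition works in classic
# PySpark, shown here assuming a working local Spark installation.
from pyspark.sql import SparkSession
from pyspark.sql.functions import array_contains, map_keys, lit

spark = SparkSession.builder.master("local[1]").appName("map_contains_key_demo").getOrCreate()
df = spark.sql("SELECT map(1, 'a', 2, 'b') AS data")

df.select(array_contains(map_keys("data"), lit(1)).alias("has_key_1")).show()
df.select(array_contains(map_keys("data"), lit(-1)).alias("has_key_minus_1")).show()

spark.stop()
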
+ + Examples + -------- + >>> from pyspark.sql.functions import map_values + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_values("data").alias("values")).show() + +------+ + |values| + +------+ + |[a, b]| + +------+ + """ + return _invoke_function_over_columns("map_values", col) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def map_zip_with( +# col1: "ColumnOrName", +# col2: "ColumnOrName", +# f: Callable[[Column, Column, Column], Column], +# ) -> Column: +# """ +# Merge two given maps, key-wise into a single map using a function. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col1 : :class:`~pyspark.sql.Column` or str +# name of the first column or expression +# col2 : :class:`~pyspark.sql.Column` or str +# name of the second column or expression +# f : function +# a ternary function ``(k: Column, v1: Column, v2: Column) -> Column...`` +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# zipped map where entries are calculated by applying given function to each +# pair of arguments. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([ +# ... (1, {"IT": 24.0, "SALES": 12.00}, {"IT": 2.0, "SALES": 1.4})], +# ... ("id", "base", "ratio") +# ... ) +# >>> df.select(map_zip_with( +# ... "base", "ratio", lambda k, v1, v2: round(v1 * v2, 2)).alias("updated_data") +# ... ).show(truncate=False) +# +---------------------------+ +# |updated_data | +# +---------------------------+ +# |{SALES -> 16.8, IT -> 48.0}| +# +---------------------------+ +# """ +# return _invoke_higher_order_function("MapZipWith", [col1, col2], [f]) + + +def posexplode(col: "ColumnOrName") -> Column: + """ + Returns a new row for each element with position in the given array or map. + Uses the default column name `pos` for position, and `col` for elements in the + array and `key` and `value` for elements in the map unless specified otherwise. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + one row per array item or map key value including positions as a separate column. + + Examples + -------- + >>> from pyspark.sql import Row + >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(posexplode(eDF.intlist)).collect() + [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)] + + >>> eDF.select(posexplode(eDF.mapfield)).show() + +---+---+-----+ + |pos|key|value| + +---+---+-----+ + | 0| a| b| + +---+---+-----+ + """ + return _invoke_function_over_columns("posexplode", col) + + +def posexplode_outer(col: "ColumnOrName") -> Column: + """ + Returns a new row for each element with position in the given array or map. + Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced. + Uses the default column name `pos` for position, and `col` for elements in the + array and `key` and `value` for elements in the map unless specified otherwise. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + one row per array item or map key value including positions as a separate column. 
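
# Dependency-free analogs of the generator semantics documented above: explode drops
# rows whose array is null or empty, while the *_outer variants emit one padding row
# of nulls (with a null position for posexplode_outer). The *_py names are hypothetical.
def explode_py(values):
    for v in (values or []):
        yield v

def posexplode_outer_py(values):
    if not values:
        yield (None, None)              # null/empty input -> a single (pos, col) null row
    else:
        for i, v in enumerate(values):
            yield (i, v)

assert list(explode_py([1, 2, 3])) == [1, 2, 3]
assert list(explode_py(None)) == []
assert list(posexplode_outer_py(["foo", "bar"])) == [(0, "foo"), (1, "bar")]
assert list(posexplode_outer_py([])) == [(None, None)]
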
+ + Examples + -------- + >>> df = spark.createDataFrame( + ... [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)], + ... ("id", "an_array", "a_map") + ... ) + >>> df.select("id", "an_array", posexplode_outer("a_map")).show() + +---+----------+----+----+-----+ + | id| an_array| pos| key|value| + +---+----------+----+----+-----+ + | 1|[foo, bar]| 0| x| 1.0| + | 2| []|null|null| null| + | 3| null|null|null| null| + +---+----------+----+----+-----+ + >>> df.select("id", "a_map", posexplode_outer("an_array")).show() + +---+----------+----+----+ + | id| a_map| pos| col| + +---+----------+----+----+ + | 1|{x -> 1.0}| 0| foo| + | 1|{x -> 1.0}| 1| bar| + | 2| {}|null|null| + | 3| null|null|null| + +---+----------+----+----+ + """ + return _invoke_function_over_columns("posexplode_outer", col) + + +def reverse(col: "ColumnOrName") -> Column: + """ + Collection function: returns a reversed string or an array with reverse order of elements. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + + Returns + ------- + :class:`~pyspark.sql.Column` + array of elements in reverse order. + + Examples + -------- + >>> df = spark.createDataFrame([('Spark SQL',)], ['data']) + >>> df.select(reverse(df.data).alias('s')).collect() + [Row(s='LQS krapS')] + >>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data']) + >>> df.select(reverse(df.data).alias('r')).collect() + [Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])] + """ + return _invoke_function_over_columns("reverse", col) + + +# TODO(SPARK-41493): Support options +def schema_of_csv(csv: "ColumnOrName") -> Column: + """ + Parses a CSV string and infers its schema in DDL format. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + csv : :class:`~pyspark.sql.Column` or str + a CSV string or a foldable string column containing a CSV string. + + Returns + ------- + :class:`~pyspark.sql.Column` + a string representation of a :class:`StructType` parsed from given CSV. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(schema_of_csv(lit('1|a'), {'sep':'|'}).alias("csv")).collect() + [Row(csv='STRUCT<_c0: INT, _c1: STRING>')] + >>> df.select(schema_of_csv('1|a', {'sep':'|'}).alias("csv")).collect() + [Row(csv='STRUCT<_c0: INT, _c1: STRING>')] + """ + + if isinstance(csv, Column): + _csv = csv + elif isinstance(csv, str): + _csv = lit(csv) + else: + raise TypeError(f"csv should be a Column or str, but got {type(csv).__name__}") + + return _invoke_function("schema_of_csv", _csv) + + +# TODO(SPARK-41494): Support options +def schema_of_json(json: "ColumnOrName") -> Column: + """ + Parses a JSON string and infers its schema in DDL format. + + .. versionadded:: 2.4.0 + + Parameters + ---------- + json : :class:`~pyspark.sql.Column` or str + a JSON string or a foldable string column containing a JSON string. + + Returns + ------- + :class:`~pyspark.sql.Column` + a string representation of a :class:`StructType` parsed from given JSON. 
+
+    Examples
+    --------
+    >>> df = spark.range(1)
+    >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
+    [Row(json='STRUCT<a: BIGINT>')]
+    >>> schema = schema_of_json('{a: 1}', {'allowUnquotedFieldNames':'true'})
+    >>> df.select(schema.alias("json")).collect()
+    [Row(json='STRUCT<a: BIGINT>')]
+    """
+
+    if isinstance(json, Column):
+        _json = json
+    elif isinstance(json, str):
+        _json = lit(json)
+    else:
+        raise TypeError(f"json should be a Column or str, but got {type(json).__name__}")
+
+    return _invoke_function("schema_of_json", _json)
+
+
+def shuffle(col: "ColumnOrName") -> Column:
+    """
+    Collection function: Generates a random permutation of the given array.
+
+    .. versionadded:: 3.4.0
+
+    Notes
+    -----
+    The function is non-deterministic.
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column or expression
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        an array of elements in random order.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([1, 20, 3, 5],), ([1, 20, None, 3],)], ['data'])
+    >>> df.select(shuffle(df.data).alias('s')).collect()  # doctest: +SKIP
+    [Row(s=[3, 1, 5, 20]), Row(s=[20, None, 3, 1])]
+    """
+    return _invoke_function_over_columns("shuffle", col)
+
+
+def size(col: "ColumnOrName") -> Column:
+    """
+    Collection function: returns the length of the array or map stored in the column.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        name of column or expression
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        length of the array/map.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
+    >>> df.select(size(df.data)).collect()
+    [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
+    """
+    return _invoke_function_over_columns("size", col)
+
+
+def slice(
+    col: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]
+) -> Column:
+    """
+    Collection function: returns an array containing all the elements in `x` from index `start`
+    (array indices start at 1, or from the end if `start` is negative) with the specified `length`.
+
+    .. versionadded:: 3.4.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column name or column containing the array to be sliced
+    start : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the starting index
+    length : :class:`~pyspark.sql.Column` or str or int
+        column name, column, or int containing the length of the slice
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        a column of array type. Subset of array.
+
+    Examples
+    --------
+    >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
+    >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
+    [Row(sliced=[2, 3]), Row(sliced=[5])]
+    """
+    if isinstance(start, Column):
+        _start = start
+    elif isinstance(start, int):
+        _start = lit(start)
+    else:
+        raise TypeError(f"start should be a Column or int, but got {type(start).__name__}")
+
+    if isinstance(length, Column):
+        _length = length
+    elif isinstance(length, int):
+        _length = lit(length)
+    else:
+        raise TypeError(f"length should be a Column or int, but got {type(length).__name__}")
+
+    return _invoke_function("slice", _to_col(col), _start, _length)
+
+
+def sort_array(col: "ColumnOrName", asc: bool = True) -> Column:
+    """
+    Collection function: sorts the input array in ascending or descending order according
+    to the natural ordering of the array elements.
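Editor's note: because the `slice` wrapper above lifts plain Python ints to literal Columns, the bounds can also be driven by other columns. A small usage sketch in the classic `pyspark.sql.functions` API, which these Connect wrappers mirror; an active SparkSession named `spark` is assumed:

```python
# Assumes an active SparkSession `spark`; classic pyspark.sql API shown.
from pyspark.sql.functions import slice, col, lit

df = spark.createDataFrame([([1, 2, 3, 4], 2), ([5, 6], 1)], ["xs", "n"])
# Per-row slice length taken from the `n` column (cast to int, since the
# inferred type is bigint); the start index is the literal 1.
df.select(slice(col("xs"), lit(1), col("n").cast("int")).alias("head_n")).show()
# +------+
# |head_n|
# +------+
# |[1, 2]|
# |   [5]|
# +------+
```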
Null elements will be placed at the beginning + of the returned array in ascending order or at the end of the returned array in descending + order. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column or expression + asc : bool, optional + whether to sort in ascending or descending order. If `asc` is True (default) + then ascending and if False then descending. + + Returns + ------- + :class:`~pyspark.sql.Column` + sorted array. + + Examples + -------- + >>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data']) + >>> df.select(sort_array(df.data).alias('r')).collect() + [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])] + >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() + [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])] + """ + return _invoke_function("sort_array", _to_col(col), lit(asc)) + + +def struct( + *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]] +) -> Column: + """Creates a new struct column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + cols : list, set, str or :class:`~pyspark.sql.Column` + column names or :class:`~pyspark.sql.Column`\\s to contain in the output struct. + + Returns + ------- + :class:`~pyspark.sql.Column` + a struct type column of given columns. + + Examples + -------- + >>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ("name", "age")) + >>> df.select(struct('age', 'name').alias("struct")).collect() + [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))] + >>> df.select(struct([df.age, df.name]).alias("struct")).collect() + [Row(struct=Row(age=2, name='Alice')), Row(struct=Row(age=5, name='Bob'))] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)): + cols = cols[0] # type: ignore[assignment] + return _invoke_function_over_columns("struct", *cols) # type: ignore[arg-type] + + +# TODO(SPARK-41493): Support options +def to_csv(col: "ColumnOrName") -> Column: + """ + Converts a column containing a :class:`StructType` into a CSV string. + Throws an exception, in the case of an unsupported type. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing a struct. + + Returns + ------- + :class:`~pyspark.sql.Column` + a CSV string converted from given :class:`StructType`. + + Examples + -------- + >>> from pyspark.sql import Row + >>> data = [(1, Row(age=2, name='Alice'))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_csv(df.value).alias("csv")).collect() + [Row(csv='2,Alice')] + """ + + return _invoke_function("to_csv", _to_col(col)) + + +# TODO(SPARK-41494): Support options +def to_json(col: "ColumnOrName") -> Column: + """ + Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType` + into a JSON string. Throws an exception, in the case of an unsupported type. + + .. versionadded:: 2.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + name of column containing a struct, an array or a map. + + Returns + ------- + :class:`~pyspark.sql.Column` + JSON object as string column. 
+ + Examples + -------- + >>> from pyspark.sql import Row + >>> from pyspark.sql.types import * + >>> data = [(1, Row(age=2, name='Alice'))] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json='{"age":2,"name":"Alice"}')] + >>> data = [(1, [Row(age=2, name='Alice'), Row(age=3, name='Bob')])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json='[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] + >>> data = [(1, {"name": "Alice"})] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json='{"name":"Alice"}')] + >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json='[{"name":"Alice"},{"name":"Bob"}]')] + >>> data = [(1, ["Alice", "Bob"])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json='["Alice","Bob"]')] + """ + + return _invoke_function("to_json", _to_col(col)) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def transform( +# col: "ColumnOrName", +# f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], +# ) -> Column: +# """ +# Returns an array of elements after applying a transformation to each element in +# the input array. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# a function that is applied to each element of the input array. +# Can take one of the following forms: +# +# - Unary ``(x: Column) -> Column: ...`` +# - Binary ``(x: Column, i: Column) -> Column...``, where the second argument is +# a 0-based index of the element. +# +# and can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# a new array of transformed elements. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values")) +# >>> df.select(transform("values", lambda x: x * 2).alias("doubled")).show() +# +------------+ +# | doubled| +# +------------+ +# |[2, 4, 6, 8]| +# +------------+ +# +# >>> def alternate(x, i): +# ... return when(i % 2 == 0, x).otherwise(-x) +# >>> df.select(transform("values", alternate).alias("alternated")).show() +# +--------------+ +# | alternated| +# +--------------+ +# |[1, -2, 3, -4]| +# +--------------+ +# """ +# return _invoke_higher_order_function("ArrayTransform", [col], [f]) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def transform_keys(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: +# """ +# Applies a function to every key-value pair in a map and returns +# a map with the results of those applications as the new keys for the pairs. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# a binary function ``(k: Column, v: Column) -> Column...`` +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. 
+# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# a new map of enties where new keys were calculated by applying given function to +# each key value argument. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, {"foo": -2.0, "bar": 2.0})], ("id", "data")) +# >>> df.select(transform_keys( +# ... "data", lambda k, _: upper(k)).alias("data_upper") +# ... ).show(truncate=False) +# +-------------------------+ +# |data_upper | +# +-------------------------+ +# |{BAR -> 2.0, FOO -> -2.0}| +# +-------------------------+ +# """ +# return _invoke_higher_order_function("TransformKeys", [col], [f]) + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def transform_values(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Column: +# """ +# Applies a function to every key-value pair in a map and returns +# a map with the results of those applications as the new values for the pairs. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# col : :class:`~pyspark.sql.Column` or str +# name of column or expression +# f : function +# a binary function ``(k: Column, v: Column) -> Column...`` +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# a new map of enties where new values were calculated by applying given function to +# each key value argument. +# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, {"IT": 10.0, "SALES": 2.0, "OPS": 24.0})], ("id", "data")) +# >>> df.select(transform_values( +# ... "data", lambda k, v: when(k.isin("IT", "OPS"), v + 10.0).otherwise(v) +# ... ).alias("new_data")).show(truncate=False) +# +---------------------------------------+ +# |new_data | +# +---------------------------------------+ +# |{OPS -> 34.0, IT -> 20.0, SALES -> 2.0}| +# +---------------------------------------+ +# """ +# return _invoke_higher_order_function("TransformValues", [col], [f]) + + +# TODO(SPARK-41434): need to support LambdaFunction Expression first +# def zip_with( +# left: "ColumnOrName", +# right: "ColumnOrName", +# f: Callable[[Column, Column], Column], +# ) -> Column: +# """ +# Merge two given arrays, element-wise, into a single array using a function. +# If one array is shorter, nulls are appended at the end to match the length of the longer +# array, before applying the function. +# +# .. versionadded:: 3.1.0 +# +# Parameters +# ---------- +# left : :class:`~pyspark.sql.Column` or str +# name of the first column or expression +# right : :class:`~pyspark.sql.Column` or str +# name of the second column or expression +# f : function +# a binary function ``(x1: Column, x2: Column) -> Column...`` +# Can use methods of :class:`~pyspark.sql.Column`, functions defined in +# :py:mod:`pyspark.sql.functions` and Scala ``UserDefinedFunctions``. +# Python ``UserDefinedFunctions`` are not supported +# (`SPARK-27052 `__). +# +# Returns +# ------- +# :class:`~pyspark.sql.Column` +# array of calculated values derived by applying given function to each pair of arguments. 
+# +# Examples +# -------- +# >>> df = spark.createDataFrame([(1, [1, 3, 5, 8], [0, 2, 4, 6])], ("id", "xs", "ys")) +# >>> df.select(zip_with("xs", "ys", lambda x, y: x ** y).alias("powers")).show(truncate=False) +# +---------------------------+ +# |powers | +# +---------------------------+ +# |[1.0, 9.0, 625.0, 262144.0]| +# +---------------------------+ +# +# >>> df = spark.createDataFrame([(1, ["foo", "bar"], [1, 2, 3])], ("id", "xs", "ys")) +# >>> df.select(zip_with("xs", "ys", lambda x, y: concat_ws("_", x, y)).alias("xs_ys")).show() +# +-----------------+ +# | xs_ys| +# +-----------------+ +# |[foo_1, bar_2, 3]| +# +-----------------+ +# """ +# return _invoke_higher_order_function("ZipWith", [left, right], [f]) + + +# String/Binary functions + + +def upper(col: "ColumnOrName") -> Column: + """ + Converts a string expression to upper case. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + upper case values. + + Examples + -------- + >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") + >>> df.select(upper("value")).show() + +------------+ + |upper(value)| + +------------+ + | SPARK| + | PYSPARK| + | PANDAS API| + +------------+ + """ + return _invoke_function_over_columns("upper", col) + + +def lower(col: "ColumnOrName") -> Column: + """ + Converts a string expression to lower case. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + lower case values. + + Examples + -------- + >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") + >>> df.select(lower("value")).show() + +------------+ + |lower(value)| + +------------+ + | spark| + | pyspark| + | pandas api| + +------------+ + """ + return _invoke_function_over_columns("lower", col) + + +def ascii(col: "ColumnOrName") -> Column: + """ + Computes the numeric value of the first character of the string column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + numeric value. + + Examples + -------- + >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") + >>> df.select(ascii("value")).show() + +------------+ + |ascii(value)| + +------------+ + | 83| + | 80| + | 80| + +------------+ + """ + return _invoke_function_over_columns("ascii", col) + + +def base64(col: "ColumnOrName") -> Column: + """ + Computes the BASE64 encoding of a binary column and returns it as a string column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + BASE64 encoding of string value. + + Examples + -------- + >>> df = spark.createDataFrame(["Spark", "PySpark", "Pandas API"], "STRING") + >>> df.select(base64("value")).show() + +----------------+ + | base64(value)| + +----------------+ + | U3Bhcms=| + | UHlTcGFyaw==| + |UGFuZGFzIEFQSQ==| + +----------------+ + """ + return _invoke_function_over_columns("base64", col) + + +def unbase64(col: "ColumnOrName") -> Column: + """ + Decodes a BASE64 encoded string column and returns it as a binary column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. 
+ + Returns + ------- + :class:`~pyspark.sql.Column` + encoded string value. + + Examples + -------- + >>> df = spark.createDataFrame(["U3Bhcms=", + ... "UHlTcGFyaw==", + ... "UGFuZGFzIEFQSQ=="], "STRING") + >>> df.select(unbase64("value")).show() + +--------------------+ + | unbase64(value)| + +--------------------+ + | [53 70 61 72 6B]| + |[50 79 53 70 61 7...| + |[50 61 6E 64 61 7...| + +--------------------+ + """ + return _invoke_function_over_columns("unbase64", col) + + +def ltrim(col: "ColumnOrName") -> Column: + """ + Trim the spaces from left end for the specified string value. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + left trimmed values. + + Examples + -------- + >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") + >>> df.select(ltrim("value").alias("r")).withColumn("length", length("r")).show() + +-------+------+ + | r|length| + +-------+------+ + | Spark| 5| + |Spark | 7| + | Spark| 5| + +-------+------+ + """ + return _invoke_function_over_columns("ltrim", col) + + +def rtrim(col: "ColumnOrName") -> Column: + """ + Trim the spaces from right end for the specified string value. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + right trimmed values. + + Examples + -------- + >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") + >>> df.select(rtrim("value").alias("r")).withColumn("length", length("r")).show() + +--------+------+ + | r|length| + +--------+------+ + | Spark| 8| + | Spark| 5| + | Spark| 6| + +--------+------+ + """ + return _invoke_function_over_columns("rtrim", col) + + +def trim(col: "ColumnOrName") -> Column: + """ + Trim the spaces from both ends for the specified string column. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + trimmed values from both sides. + + Examples + -------- + >>> df = spark.createDataFrame([" Spark", "Spark ", " Spark"], "STRING") + >>> df.select(trim("value").alias("r")).withColumn("length", length("r")).show() + +-----+------+ + | r|length| + +-----+------+ + |Spark| 5| + |Spark| 5| + |Spark| 5| + +-----+------+ + """ + return _invoke_function_over_columns("trim", col) + + +def concat_ws(sep: str, *cols: "ColumnOrName") -> Column: + """ + Concatenates multiple input string columns together into a single string column, + using the given separator. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + sep : str + words separator. + cols : :class:`~pyspark.sql.Column` or str + list of columns to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + string of concatenated words. + + Examples + -------- + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() + [Row(s='abcd-123')] + """ + return _invoke_function("concat_ws", lit(sep), *[_to_col(c) for c in cols]) + + +def decode(col: "ColumnOrName", charset: str) -> Column: + """ + Computes the first argument into a string from a binary using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + + .. 
versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + charset : str + charset to use to decode to. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> df = spark.createDataFrame([('abcd',)], ['a']) + >>> df.select(decode("a", "UTF-8")).show() + +----------------+ + |decode(a, UTF-8)| + +----------------+ + | abcd| + +----------------+ + """ + return _invoke_function("decode", _to_col(col), lit(charset)) + + +def encode(col: "ColumnOrName", charset: str) -> Column: + """ + Computes the first argument into a binary from a string using the provided character set + (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to work on. + charset : str + charset to use to encode. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> df = spark.createDataFrame([('abcd',)], ['c']) + >>> df.select(encode("c", "UTF-8")).show() + +----------------+ + |encode(c, UTF-8)| + +----------------+ + | [61 62 63 64]| + +----------------+ + """ + return _invoke_function("encode", _to_col(col), lit(charset)) + + +# Date/Timestamp functions +# TODO(SPARK-41283): Resolve dtypes inconsistencies for: +# to_timestamp, from_utc_timestamp, to_utc_timestamp, +# timestamp_seconds, current_timestamp, date_trunc + + +def current_date() -> Column: + """ + Returns the current date at the start of query evaluation as a :class:`DateType` column. + All calls of current_date within the same query return the same value. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + current date. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(current_date()).show() # doctest: +SKIP + +--------------+ + |current_date()| + +--------------+ + | 2022-08-26| + +--------------+ + """ + return _invoke_function("current_date") + + +def current_timestamp() -> Column: + """ + Returns the current timestamp at the start of query evaluation as a :class:`TimestampType` + column. All calls of current_timestamp within the same query return the same value. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + current date and time. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(current_timestamp()).show(truncate=False) # doctest: +SKIP + +-----------------------+ + |current_timestamp() | + +-----------------------+ + |2022-08-26 21:23:22.716| + +-----------------------+ + """ + return _invoke_function("current_timestamp") + + +def localtimestamp() -> Column: + """ + Returns the current timestamp without time zone at the start of query evaluation + as a timestamp without time zone column. All calls of localtimestamp within the + same query return the same value. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`~pyspark.sql.Column` + current local date and time. 
+ + Examples + -------- + >>> df = spark.range(1) + >>> df.select(localtimestamp()).show(truncate=False) # doctest: +SKIP + +-----------------------+ + |localtimestamp() | + +-----------------------+ + |2022-08-26 21:28:34.639| + +-----------------------+ + """ + return _invoke_function("localtimestamp") + + +def date_format(date: "ColumnOrName", format: str) -> Column: + """ + Converts a date/timestamp/string to a value of string in the format specified by the date + format given by the second argument. + + A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All + pattern letters of `datetime pattern`_. can be used. + + .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + .. versionadded:: 3.4.0 + + Notes + ----- + Whenever possible, use specialized functions like `year`. + + Parameters + ---------- + date : :class:`~pyspark.sql.Column` or str + input column of values to format. + format: str + format to use to represent datetime values. + + Returns + ------- + :class:`~pyspark.sql.Column` + string value representing formatted datetime. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(date_format('dt', 'MM/dd/yyy').alias('date')).collect() + [Row(date='04/08/2015')] + """ + return _invoke_function("date_format", _to_col(date), lit(format)) + + +def year(col: "ColumnOrName") -> Column: + """ + Extract the year of a given date/timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + year part of the date/timestamp as integer. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(year('dt').alias('year')).collect() + [Row(year=2015)] + """ + return _invoke_function_over_columns("year", col) + + +def quarter(col: "ColumnOrName") -> Column: + """ + Extract the quarter of a given date/timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + quarter of the date/timestamp as integer. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(quarter('dt').alias('quarter')).collect() + [Row(quarter=2)] + """ + return _invoke_function_over_columns("quarter", col) + + +def month(col: "ColumnOrName") -> Column: + """ + Extract the month of a given date/timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + month part of the date/timestamp as integer. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(month('dt').alias('month')).collect() + [Row(month=4)] + """ + return _invoke_function_over_columns("month", col) + + +def dayofweek(col: "ColumnOrName") -> Column: + """ + Extract the day of the week of a given date/timestamp as integer. + Ranges from 1 for a Sunday through to 7 for a Saturday + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + day of the week for given date/timestamp as integer. 
+ + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofweek('dt').alias('day')).collect() + [Row(day=4)] + """ + return _invoke_function_over_columns("dayofweek", col) + + +def dayofmonth(col: "ColumnOrName") -> Column: + """ + Extract the day of the month of a given date/timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + day of the month for given date/timestamp as integer. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofmonth('dt').alias('day')).collect() + [Row(day=8)] + """ + return _invoke_function_over_columns("dayofmonth", col) + + +def dayofyear(col: "ColumnOrName") -> Column: + """ + Extract the day of the year of a given date/timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + day of the year for given date/timestamp as integer. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(dayofyear('dt').alias('day')).collect() + [Row(day=98)] + """ + return _invoke_function_over_columns("dayofyear", col) + + +def hour(col: "ColumnOrName") -> Column: + """ + Extract the hours of a given timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + hour part of the timestamp as integer. + + Examples + -------- + >>> import datetime + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df.select(hour('ts').alias('hour')).collect() + [Row(hour=13)] + """ + return _invoke_function_over_columns("hour", col) + + +def minute(col: "ColumnOrName") -> Column: + """ + Extract the minutes of a given timestamp as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + minutes part of the timestamp as integer. + + Examples + -------- + >>> import datetime + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df.select(minute('ts').alias('minute')).collect() + [Row(minute=8)] + """ + return _invoke_function_over_columns("minute", col) + + +def second(col: "ColumnOrName") -> Column: + """ + Extract the seconds of a given date as integer. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target date/timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + `seconds` part of the timestamp as integer. + + Examples + -------- + >>> import datetime + >>> df = spark.createDataFrame([(datetime.datetime(2015, 4, 8, 13, 8, 15),)], ['ts']) + >>> df.select(second('ts').alias('second')).collect() + [Row(second=15)] + """ + return _invoke_function_over_columns("second", col) + + +def weekofyear(col: "ColumnOrName") -> Column: + """ + Extract the week number of a given date as integer. + A week is considered to start on a Monday and week 1 is the first week with more than 3 days, + as defined by ISO 8601 + + .. 
versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target timestamp column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + `week` of the year for given date as integer. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> df.select(weekofyear(df.dt).alias('week')).collect() + [Row(week=15)] + """ + return _invoke_function_over_columns("weekofyear", col) + + +def make_date(year: "ColumnOrName", month: "ColumnOrName", day: "ColumnOrName") -> Column: + """ + Returns a column with a date built from the year, month and day columns. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + year : :class:`~pyspark.sql.Column` or str + The year to build the date + month : :class:`~pyspark.sql.Column` or str + The month to build the date + day : :class:`~pyspark.sql.Column` or str + The day to build the date + + Returns + ------- + :class:`~pyspark.sql.Column` + a date built from given parts. + + Examples + -------- + >>> df = spark.createDataFrame([(2020, 6, 26)], ['Y', 'M', 'D']) + >>> df.select(make_date(df.Y, df.M, df.D).alias("datefield")).collect() + [Row(datefield=datetime.date(2020, 6, 26))] + """ + return _invoke_function_over_columns("make_date", year, month, day) + + +def date_add(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: + """ + Returns the date that is `days` days after `start`. If `days` is a negative value + then these amount of days will be deducted from `start`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + start : :class:`~pyspark.sql.Column` or str + date column to work on. + days : :class:`~pyspark.sql.Column` or str or int + how many days after the given date to calculate. + Accepts negative value as well to calculate backwards in time. + + Returns + ------- + :class:`~pyspark.sql.Column` + a date after/before given number of days. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'add']) + >>> df.select(date_add(df.dt, 1).alias('next_date')).collect() + [Row(next_date=datetime.date(2015, 4, 9))] + >>> df.select(date_add(df.dt, df.add.cast('integer')).alias('next_date')).collect() + [Row(next_date=datetime.date(2015, 4, 10))] + >>> df.select(date_add('dt', -1).alias('prev_date')).collect() + [Row(prev_date=datetime.date(2015, 4, 7))] + """ + days = lit(days) if isinstance(days, int) else days + return _invoke_function_over_columns("date_add", start, days) + + +def date_sub(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: + """ + Returns the date that is `days` days before `start`. If `days` is a negative value + then these amount of days will be added to `start`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + start : :class:`~pyspark.sql.Column` or str + date column to work on. + days : :class:`~pyspark.sql.Column` or str or int + how many days before the given date to calculate. + Accepts negative value as well to calculate forward in time. + + Returns + ------- + :class:`~pyspark.sql.Column` + a date before/after given number of days. 
+ + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'sub']) + >>> df.select(date_sub(df.dt, 1).alias('prev_date')).collect() + [Row(prev_date=datetime.date(2015, 4, 7))] + >>> df.select(date_sub(df.dt, df.sub.cast('integer')).alias('prev_date')).collect() + [Row(prev_date=datetime.date(2015, 4, 6))] + >>> df.select(date_sub('dt', -1).alias('next_date')).collect() + [Row(next_date=datetime.date(2015, 4, 9))] + """ + days = lit(days) if isinstance(days, int) else days + return _invoke_function_over_columns("date_sub", start, days) + + +def datediff(end: "ColumnOrName", start: "ColumnOrName") -> Column: + """ + Returns the number of days from `start` to `end`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + end : :class:`~pyspark.sql.Column` or str + to date column to work on. + start : :class:`~pyspark.sql.Column` or str + from date column to work on. + + Returns + ------- + :class:`~pyspark.sql.Column` + difference in days between two dates. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() + [Row(diff=32)] + """ + return _invoke_function_over_columns("datediff", end, start) + + +def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Column: + """ + Returns the date that is `months` months after `start`. If `months` is a negative value + then these amount of months will be deducted from the `start`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + start : :class:`~pyspark.sql.Column` or str + date column to work on. + months : :class:`~pyspark.sql.Column` or str or int + how many months after the given date to calculate. + Accepts negative value as well to calculate backwards. + + Returns + ------- + :class:`~pyspark.sql.Column` + a date after/before given number of months. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-04-08', 2)], ['dt', 'add']) + >>> df.select(add_months(df.dt, 1).alias('next_month')).collect() + [Row(next_month=datetime.date(2015, 5, 8))] + >>> df.select(add_months(df.dt, df.add.cast('integer')).alias('next_month')).collect() + [Row(next_month=datetime.date(2015, 6, 8))] + >>> df.select(add_months('dt', -2).alias('prev_month')).collect() + [Row(prev_month=datetime.date(2015, 2, 8))] + """ + months = lit(months) if isinstance(months, int) else months + return _invoke_function_over_columns("add_months", start, months) + + +def months_between(date1: "ColumnOrName", date2: "ColumnOrName", roundOff: bool = True) -> Column: + """ + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + A whole number is returned if both inputs have the same day of month or both are the last day + of their respective months. Otherwise, the difference is calculated assuming 31 days per month. + The result is rounded off to 8 digits unless `roundOff` is set to `False`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + date1 : :class:`~pyspark.sql.Column` or str + first date column. + date2 : :class:`~pyspark.sql.Column` or str + second date column. + roundOff : bool, optional + whether to round (to 8 digits) the final value or not (default: True). + + Returns + ------- + :class:`~pyspark.sql.Column` + number of months between two dates. 
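Editor's note: as a sanity check on the rules just stated, the doctest value shown below can be reproduced by hand: four whole calendar months, minus a fractional day difference computed over the assumed 31-day month (plain Python arithmetic, not Spark):

```python
# Reproduces months_between('1997-02-28 10:30:00', '1996-10-30') by hand.
whole_months = (1997 - 1996) * 12 + (2 - 10)     # 4 calendar months apart
day1 = 28 + (10 * 60 + 30) / (24 * 60)           # 28.4375: day of month plus time of day
day2 = 30
months = whole_months + (day1 - day2) / 31       # 3.9495967741935485
print(round(months, 8))                          # 3.94959677
```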
+ + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) + >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() + [Row(months=3.94959677)] + >>> df.select(months_between(df.date1, df.date2, False).alias('months')).collect() + [Row(months=3.9495967741935485)] + """ + return _invoke_function("months_between", _to_col(date1), _to_col(date2), lit(roundOff)) + + +def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: + """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.DateType` + using the optionally specified format. Specify formats according to `datetime pattern`_. + By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format + is omitted. Equivalent to ``col.cast("date")``. + + .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + input column of values to convert. + format: str, optional + format to use to convert date values. + + Returns + ------- + :class:`~pyspark.sql.Column` + date value as :class:`pyspark.sql.types.DateType` type. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + if format is None: + return _invoke_function_over_columns("to_date", col) + else: + return _invoke_function("to_date", _to_col(col), lit(format)) + + +@overload +def to_timestamp(col: "ColumnOrName") -> Column: + ... + + +@overload +def to_timestamp(col: "ColumnOrName", format: str) -> Column: + ... + + +def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: + """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.TimestampType` + using the optionally specified format. Specify formats according to `datetime pattern`_. + By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format + is omitted. Equivalent to ``col.cast("timestamp")``. + + .. _datetime pattern: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column values to convert. + format: str, optional + format to use to convert timestamp values. + + Returns + ------- + :class:`~pyspark.sql.Column` + timestamp value as :class:`pyspark.sql.types.TimestampType` type. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_timestamp(df.t).alias('dt')).collect() + [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] + + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect() + [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))] + """ + if format is None: + return _invoke_function_over_columns("to_timestamp", col) + else: + return _invoke_function("to_timestamp", _to_col(col), lit(format)) + + +def trunc(date: "ColumnOrName", format: str) -> Column: + """ + Returns date truncated to the unit specified by the format. + + .. 
versionadded:: 3.4.0 + + Parameters + ---------- + date : :class:`~pyspark.sql.Column` or str + input column of values to truncate. + format : str + 'year', 'yyyy', 'yy' to truncate by year, + or 'month', 'mon', 'mm' to truncate by month + Other options are: 'week', 'quarter' + + Returns + ------- + :class:`~pyspark.sql.Column` + truncated date. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) + >>> df.select(trunc(df.d, 'year').alias('year')).collect() + [Row(year=datetime.date(1997, 1, 1))] + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + [Row(month=datetime.date(1997, 2, 1))] + """ + return _invoke_function("trunc", _to_col(date), lit(format)) + + +def date_trunc(format: str, timestamp: "ColumnOrName") -> Column: + """ + Returns timestamp truncated to the unit specified by the format. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + format : str + 'year', 'yyyy', 'yy' to truncate by year, + 'month', 'mon', 'mm' to truncate by month, + 'day', 'dd' to truncate by day, + Other options are: + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'week', 'quarter' + timestamp : :class:`~pyspark.sql.Column` or str + input column of values to truncate. + + Returns + ------- + :class:`~pyspark.sql.Column` + truncated timestamp. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 05:02:11',)], ['t']) + >>> df.select(date_trunc('year', df.t).alias('year')).collect() + [Row(year=datetime.datetime(1997, 1, 1, 0, 0))] + >>> df.select(date_trunc('mon', df.t).alias('month')).collect() + [Row(month=datetime.datetime(1997, 2, 1, 0, 0))] + """ + return _invoke_function("date_trunc", lit(format), _to_col(timestamp)) + + +def next_day(date: "ColumnOrName", dayOfWeek: str) -> Column: + """ + Returns the first date which is later than the value of the date column + based on second `week day` argument. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + date : :class:`~pyspark.sql.Column` or str + target column to compute on. + dayOfWeek : str + day of the week, case-insensitive, accepts: + "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun" + + Returns + ------- + :class:`~pyspark.sql.Column` + the column of computed results. + + Examples + -------- + >>> df = spark.createDataFrame([('2015-07-27',)], ['d']) + >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() + [Row(date=datetime.date(2015, 8, 2))] + """ + return _invoke_function("next_day", _to_col(date), lit(dayOfWeek)) + + +def last_day(date: "ColumnOrName") -> Column: + """ + Returns the last day of the month which the given date belongs to. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + date : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + last day of the month. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-10',)], ['d']) + >>> df.select(last_day(df.d).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + return _invoke_function_over_columns("last_day", date) + + +def from_unixtime(timestamp: "ColumnOrName", format: str = "yyyy-MM-dd HH:mm:ss") -> Column: + """ + Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string + representing the timestamp of that moment in the current system time zone in the given + format. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + timestamp : :class:`~pyspark.sql.Column` or str + column of unix time values. 
+ format : str, optional + format to use to convert to (default: yyyy-MM-dd HH:mm:ss) + + Returns + ------- + :class:`~pyspark.sql.Column` + formatted timestamp as string. + + Examples + -------- + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> time_df = spark.createDataFrame([(1428476400,)], ['unix_time']) + >>> time_df.select(from_unixtime('unix_time').alias('ts')).collect() + [Row(ts='2015-04-08 00:00:00')] + >>> spark.conf.unset("spark.sql.session.timeZone") + """ + return _invoke_function("from_unixtime", _to_col(timestamp), lit(format)) + + +@overload +def unix_timestamp(timestamp: "ColumnOrName", format: str = ...) -> Column: + ... + + +@overload +def unix_timestamp() -> Column: + ... + + +def unix_timestamp( + timestamp: Optional["ColumnOrName"] = None, format: str = "yyyy-MM-dd HH:mm:ss" +) -> Column: + """ + Convert time string with given pattern ('yyyy-MM-dd HH:mm:ss', by default) + to Unix time stamp (in seconds), using the default timezone and the default + locale, returns null if failed. + + if `timestamp` is None, then it returns current timestamp. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + timestamp : :class:`~pyspark.sql.Column` or str, optional + timestamps of string values. + format : str, optional + alternative format to use for converting (default: yyyy-MM-dd HH:mm:ss). + + Returns + ------- + :class:`~pyspark.sql.Column` + unix time as long integer. + + Examples + -------- + >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") + >>> time_df = spark.createDataFrame([('2015-04-08',)], ['dt']) + >>> time_df.select(unix_timestamp('dt', 'yyyy-MM-dd').alias('unix_time')).collect() + [Row(unix_time=1428476400)] + >>> spark.conf.unset("spark.sql.session.timeZone") + """ + if timestamp is None: + return _invoke_function("unix_timestamp") + return _invoke_function("unix_timestamp", _to_col(timestamp), lit(format)) + + +def from_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: + """ + This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in UTC, and + renders that timestamp as a timestamp in the given time zone. + + However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + timezone-agnostic. So in Spark this function just shift the timestamp value from UTC timezone to + the given timezone. + + This function may return confusing result if the input is a string with timezone, e.g. + '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + according to the timezone in the string, and finally display the result by converting the + timestamp to string according to the session local timezone. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + timestamp : :class:`~pyspark.sql.Column` or str + the column that contains timestamps + tz : :class:`~pyspark.sql.Column` or str + A string detailing the time zone ID that the input should be adjusted to. It should + be in the format of either region-based zone IDs or zone offsets. Region IDs must + have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in + the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are + supported as aliases of '+00:00'. Other short names are not recommended to use + because they can be ambiguous. + `tz` can also take a :class:`~pyspark.sql.Column` containing timezone ID strings. 
+ + Returns + ------- + :class:`~pyspark.sql.Column` + timestamp value represented in given timezone. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(from_utc_timestamp(df.ts, "PST").alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 2, 30))] + >>> df.select(from_utc_timestamp(df.ts, df.tz).alias('local_time')).collect() + [Row(local_time=datetime.datetime(1997, 2, 28, 19, 30))] + """ + if isinstance(tz, str): + tz = lit(tz) + return _invoke_function_over_columns("from_utc_timestamp", timestamp, tz) + + +def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: + """ + This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function + takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given + timezone, and renders that timestamp as a timestamp in UTC. + + However, timestamp in Spark represents number of microseconds from the Unix epoch, which is not + timezone-agnostic. So in Spark this function just shift the timestamp value from the given + timezone to UTC timezone. + + This function may return confusing result if the input is a string with timezone, e.g. + '2018-03-13T06:18:23+00:00'. The reason is that, Spark firstly cast the string to timestamp + according to the timezone in the string, and finally display the result by converting the + timestamp to string according to the session local timezone. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + timestamp : :class:`~pyspark.sql.Column` or str + the column that contains timestamps + tz : :class:`~pyspark.sql.Column` or str + A string detailing the time zone ID that the input should be adjusted to. It should + be in the format of either region-based zone IDs or zone offsets. Region IDs must + have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in + the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are + supported as aliases of '+00:00'. Other short names are not recommended to use + because they can be ambiguous. + `tz` can also take a :class:`~pyspark.sql.Column` containing timezone ID strings. + + Returns + ------- + :class:`~pyspark.sql.Column` + timestamp value represented in UTC timezone. + + Examples + -------- + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', 'JST')], ['ts', 'tz']) + >>> df.select(to_utc_timestamp(df.ts, "PST").alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 18, 30))] + >>> df.select(to_utc_timestamp(df.ts, df.tz).alias('utc_time')).collect() + [Row(utc_time=datetime.datetime(1997, 2, 28, 1, 30))] + """ + if isinstance(tz, str): + tz = lit(tz) + return _invoke_function_over_columns("to_utc_timestamp", timestamp, tz) + + +def timestamp_seconds(col: "ColumnOrName") -> Column: + """ + Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) + to a timestamp. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + unix time values. + + Returns + ------- + :class:`~pyspark.sql.Column` + converted timestamp value. 
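Editor's note: the "shift" described for `from_utc_timestamp` / `to_utc_timestamp` can be illustrated with plain Python datetimes, reproducing the doctest values above (standard library only; IANA zone names are used in place of the "PST"/"JST" abbreviations):

```python
# Not Spark code: illustrates interpreting a timezone-agnostic value as UTC
# and rendering it in a target zone, and the reverse direction.
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

naive = datetime(1997, 2, 28, 10, 30)            # timezone-agnostic input value

# from_utc_timestamp: read the value as UTC, render it in the given zone.
as_utc = naive.replace(tzinfo=timezone.utc)
print(as_utc.astimezone(ZoneInfo("America/Los_Angeles")))  # 1997-02-28 02:30:00-08:00
print(as_utc.astimezone(ZoneInfo("Asia/Tokyo")))           # 1997-02-28 19:30:00+09:00

# to_utc_timestamp: read the value in the given zone, render it in UTC.
as_pst = naive.replace(tzinfo=ZoneInfo("America/Los_Angeles"))
print(as_pst.astimezone(timezone.utc))                     # 1997-02-28 18:30:00+00:00
```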
+ + Examples + -------- + >>> from pyspark.sql.functions import timestamp_seconds + >>> spark.conf.set("spark.sql.session.timeZone", "UTC") + >>> time_df = spark.createDataFrame([(1230219000,)], ['unix_time']) + >>> time_df.select(timestamp_seconds(time_df.unix_time).alias('ts')).show() + +-------------------+ + | ts| + +-------------------+ + |2008-12-25 15:30:00| + +-------------------+ + >>> time_df.select(timestamp_seconds('unix_time').alias('ts')).printSchema() + root + |-- ts: timestamp (nullable = true) + >>> spark.conf.unset("spark.sql.session.timeZone") + """ + + return _invoke_function_over_columns("timestamp_seconds", col) + + +# Misc Functions + + +def assert_true(col: "ColumnOrName", errMsg: Optional[Union[Column, str]] = None) -> Column: + """ + Returns `null` if the input column is `true`; throws an exception + with the provided error message otherwise. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + column name or column that represents the input column to test + errMsg : :class:`~pyspark.sql.Column` or str, optional + A Python string literal or column containing the error message + + Returns + ------- + :class:`~pyspark.sql.Column` + `null` if the input column is `true` otherwise throws an error with specified message. + + Examples + -------- + >>> df = spark.createDataFrame([(0,1)], ['a', 'b']) + >>> df.select(assert_true(df.a < df.b).alias('r')).collect() + [Row(r=None)] + >>> df.select(assert_true(df.a < df.b, df.a).alias('r')).collect() + [Row(r=None)] + >>> df.select(assert_true(df.a < df.b, 'error').alias('r')).collect() + [Row(r=None)] + >>> df.select(assert_true(df.a > df.b, 'My error msg').alias('r')).collect() # doctest: +SKIP + ... + java.lang.RuntimeException: My error msg + ... + """ + if errMsg is None: + return _invoke_function_over_columns("assert_true", col) + if not isinstance(errMsg, (str, Column)): + raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg))) + + _err_msg = lit(errMsg) if isinstance(errMsg, str) else _to_col(errMsg) + + return _invoke_function("assert_true", _to_col(col), _err_msg) + + +def raise_error(errMsg: Union[Column, str]) -> Column: + """ + Throws an exception with the provided error message. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + errMsg : :class:`~pyspark.sql.Column` or str + A Python string literal or column containing the error message + + Returns + ------- + :class:`~pyspark.sql.Column` + throws an error with specified message. + + Examples + -------- + >>> df = spark.range(1) + >>> df.select(raise_error("My error message")).show() # doctest: +SKIP + ... + java.lang.RuntimeException: My error message + ... + """ + if not isinstance(errMsg, (str, Column)): + raise TypeError("errMsg should be a Column or a str, got {}".format(type(errMsg))) + + _err_msg = lit(errMsg) if isinstance(errMsg, str) else _to_col(errMsg) + + return _invoke_function("raise_error", _err_msg) + + +def crc32(col: "ColumnOrName") -> Column: + """ + Calculates the cyclic redundancy check value (CRC32) of a binary column and + returns the value as a bigint. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + .. 
versionadded:: 3.4.0 + + Examples + -------- + >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() + [Row(crc32=2743272264)] + """ + return _invoke_function_over_columns("crc32", col) + + +def hash(*cols: "ColumnOrName") -> Column: + """Calculates the hash code of given columns, and returns the result as an int column. + + .. versionadded:: 2.0.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + one or more columns to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + hash value as int column. + + Examples + -------- + >>> df = spark.createDataFrame([('ABC', 'DEF')], ['c1', 'c2']) + + Hash for one column + + >>> df.select(hash('c1').alias('hash')).show() + +----------+ + | hash| + +----------+ + |-757602832| + +----------+ + + Two or more columns + + >>> df.select(hash('c1', 'c2').alias('hash')).show() + +---------+ + | hash| + +---------+ + |599895104| + +---------+ + """ + return _invoke_function_over_columns("hash", *cols) + + +def xxhash64(*cols: "ColumnOrName") -> Column: + """Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm, + and returns the result as a long column. The hash computation uses an initial seed of 42. + + .. versionadded:: 3.0.0 + + Parameters + ---------- + cols : :class:`~pyspark.sql.Column` or str + one or more columns to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + hash value as long column. + + Examples + -------- + >>> df = spark.createDataFrame([('ABC', 'DEF')], ['c1', 'c2']) + + Hash for one column + + >>> df.select(xxhash64('c1').alias('hash')).show() + +-------------------+ + | hash| + +-------------------+ + |4105715581806190027| + +-------------------+ + + Two or more columns + + >>> df.select(xxhash64('c1', 'c2').alias('hash')).show() + +-------------------+ + | hash| + +-------------------+ + |3233247871021311208| + +-------------------+ + """ + return _invoke_function_over_columns("xxhash64", *cols) + + +def md5(col: "ColumnOrName") -> Column: + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash='902fbdd2b1df0c4f70b4a5d23525e932')] + """ + return _invoke_function_over_columns("md5", col) + + +def sha1(col: "ColumnOrName") -> Column: + """Returns the hex string result of SHA-1. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash='3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + return _invoke_function_over_columns("sha1", col) + + +def sha2(col: "ColumnOrName", numBits: int) -> Column: + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + .. 
versionadded:: 3.4.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + numBits : int + the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> df = spark.createDataFrame([["Alice"], ["Bob"]], ["name"]) + >>> df.withColumn("sha2", sha2(df.name, 256)).show(truncate=False) + +-----+----------------------------------------------------------------+ + |name |sha2 | + +-----+----------------------------------------------------------------+ + |Alice|3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043| + |Bob |cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961| + +-----+----------------------------------------------------------------+ + """ + return _invoke_function("sha2", _to_col(col), lit(numBits)) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 2de0dbb40c..e8b6d79943 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -17,11 +17,13 @@ from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict import functools -import pandas import pyarrow as pa + +from pyspark.sql.types import DataType + import pyspark.sql.connect.proto as proto from pyspark.sql.connect.column import Column, SortOrder, ColumnReference - +from pyspark.sql.connect.types import pyspark_types_to_proto_types if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName @@ -167,21 +169,34 @@ def _repr_html_(self) -> str: class LocalRelation(LogicalPlan): - """Creates a LocalRelation plan object based on a Pandas DataFrame.""" + """Creates a LocalRelation plan object based on a PyArrow Table.""" - def __init__(self, pdf: "pandas.DataFrame") -> None: + def __init__( + self, + table: "pa.Table", + schema: Optional[Union[DataType, str]] = None, + ) -> None: super().__init__(None) - self._pdf = pdf + assert table is not None and isinstance(table, pa.Table) + self._table = table + + if schema is not None: + assert isinstance(schema, (DataType, str)) + self._schema = schema def plan(self, session: "SparkConnectClient") -> proto.Relation: sink = pa.BufferOutputStream() - table = pa.Table.from_pandas(self._pdf) - with pa.ipc.new_stream(sink, table.schema) as writer: - for b in table.to_batches(): + with pa.ipc.new_stream(sink, self._table.schema) as writer: + for b in self._table.to_batches(): writer.write_batch(b) plan = proto.Relation() plan.local_relation.data = sink.getvalue().to_pybytes() + if self._schema is not None: + if isinstance(self._schema, DataType): + plan.local_relation.datatype.CopyFrom(pyspark_types_to_proto_types(self._schema)) + elif isinstance(self._schema, str): + plan.local_relation.datatype_str = self._schema return plan def print(self, indent: int = 0) -> str: @@ -984,6 +999,65 @@ def _repr_html_(self) -> str: """ +class Unpivot(LogicalPlan): + """Logical plan object for a unpivot operation.""" + + def __init__( + self, + child: Optional["LogicalPlan"], + ids: List["ColumnOrName"], + values: List["ColumnOrName"], + variable_column_name: str, + value_column_name: str, + ) -> None: + super().__init__(child) + self.ids = ids + self.values = values + self.variable_column_name = variable_column_name + self.value_column_name = value_column_name + + def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> 
proto.Expression:
+        if isinstance(col, Column):
+            return col.to_plan(session)
+        else:
+            return self.unresolved_attr(col)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.unpivot.input.CopyFrom(self._child.plan(session))
+        plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
+        plan.unpivot.values.extend([self.col_to_expr(x, session) for x in self.values])
+        plan.unpivot.variable_column_name = self.variable_column_name
+        plan.unpivot.value_column_name = self.value_column_name
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else ""
+        return (
+            f"{' ' * indent}"
+            f"<Unpivot ids={self.ids}, values={self.values}, "
+            f"variable_column_name={self.variable_column_name}, "
+            f"value_column_name={self.value_column_name}>"
+            f"\n{c_buf}"
+        )
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+           <li>
+              <b>Unpivot</b><br />
+              ids: {self.ids}
+              values: {self.values}
+              variable_column_name: {self.variable_column_name}
+              value_column_name: {self.value_column_name}
+              {self._child._repr_html_() if self._child is not None else ""}
+           </li>
+        </ul>
+ """ + + class NAFill(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any] diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 8510216324..91c57a9ef2 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xd2\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1ap\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x39\n\x0c\x63\x61st_to_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\ncastToType\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 
\x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xf4\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xb2\x0b\n\x07Literal\x12\x14\n\x04null\x18\x01 \x01(\x08H\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 
\x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x42\n\x06struct\x18\x17 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x12\x39\n\x03map\x18\x18 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x1a\n\x08nullable\x18\x32 \x01(\x08R\x08nullable\x12\x38\n\x18type_variation_reference\x18\x33 \x01(\rR\x16typeVariationReference\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x43\n\x06Struct\x12\x39\n\x06\x66ields\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06\x66ields\x1a\x42\n\x05\x41rray\x12\x39\n\x06values\x18\x01 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\xbd\x01\n\x03Map\x12@\n\x05pairs\x18\x01 \x03(\x0b\x32*.spark.connect.Expression.Literal.Map.PairR\x05pairs\x1at\n\x04Pair\x12\x33\n\x03key\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05valueB\x0e\n\x0cliteral_type\x1a\x46\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1a(\n\x0eUnresolvedStar\x12\x16\n\x06target\x18\x01 \x03(\tR\x06target\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadataB\x0b\n\texpr_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -197,31 +197,31 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 78 - _EXPRESSION._serialized_end = 2720 - _EXPRESSION_CAST._serialized_start = 639 - _EXPRESSION_CAST._serialized_end = 751 - _EXPRESSION_LITERAL._serialized_start = 754 - _EXPRESSION_LITERAL._serialized_end = 2212 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1650 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1767 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1769 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1867 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 1869 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 1936 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 1938 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 2004 - _EXPRESSION_LITERAL_MAP._serialized_start = 2007 - _EXPRESSION_LITERAL_MAP._serialized_end = 2196 - 
_EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2080 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2196 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2214 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2284 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2287 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2491 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2493 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2543 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2545 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2585 - _EXPRESSION_ALIAS._serialized_start = 2587 - _EXPRESSION_ALIAS._serialized_end = 2707 + _EXPRESSION._serialized_end = 2754 + _EXPRESSION_CAST._serialized_start = 640 + _EXPRESSION_CAST._serialized_end = 785 + _EXPRESSION_LITERAL._serialized_start = 788 + _EXPRESSION_LITERAL._serialized_end = 2246 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1684 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1801 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1803 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1901 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 1903 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 1970 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 1972 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 2038 + _EXPRESSION_LITERAL_MAP._serialized_start = 2041 + _EXPRESSION_LITERAL_MAP._serialized_end = 2230 + _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2114 + _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2230 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2248 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2318 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2321 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2525 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2527 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2577 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2579 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2619 + _EXPRESSION_ALIAS._serialized_start = 2621 + _EXPRESSION_ALIAS._serialized_end = 2741 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index c1034a8636..2c486f62a9 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -60,27 +60,51 @@ class Expression(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor EXPR_FIELD_NUMBER: builtins.int - CAST_TO_TYPE_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + TYPE_STR_FIELD_NUMBER: builtins.int @property def expr(self) -> global___Expression: """(Required) the expression to be casted.""" @property - def cast_to_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: - """(Required) the data type that the expr to be casted to.""" + def type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... + type_str: builtins.str + """If this is set, Server will use Catalyst parser to parse this string to DataType.""" def __init__( self, *, expr: global___Expression | None = ..., - cast_to_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + type_str: builtins.str = ..., ) -> None: ... 
def HasField( self, - field_name: typing_extensions.Literal["cast_to_type", b"cast_to_type", "expr", b"expr"], + field_name: typing_extensions.Literal[ + "cast_to_type", + b"cast_to_type", + "expr", + b"expr", + "type", + b"type", + "type_str", + b"type_str", + ], ) -> builtins.bool: ... def ClearField( self, - field_name: typing_extensions.Literal["cast_to_type", b"cast_to_type", "expr", b"expr"], + field_name: typing_extensions.Literal[ + "cast_to_type", + b"cast_to_type", + "expr", + b"expr", + "type", + b"type", + "type_str", + b"type_str", + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["cast_to_type", b"cast_to_type"] + ) -> typing_extensions.Literal["type", "type_str"] | None: ... class Literal(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 06cf18417d..d1651d0b72 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -30,10 +30,11 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2 +from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8b\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 
\x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd7\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 
\x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"#\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 
\x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x83\x01\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x45\n\x0ename_expr_list\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x0cnameExprList"\x8c\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x41\n\nparameters\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\nparametersB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto"\xbf\x0e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r 
\x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xd7\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 
\x03(\tR\x0cusingColumns"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"\x89\x01\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x35\n\x08\x64\x61tatype\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x08\x64\x61tatype\x12#\n\x0c\x64\x61tatype_str\x18\x03 \x01(\tH\x00R\x0b\x64\x61tatypeStrB\x08\n\x06schema"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 
\x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x83\x01\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x45\n\x0ename_expr_list\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x0cnameExprList"\x8c\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 
\x01(\tR\x04name\x12\x41\n\nparameters\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\nparameters"\xf6\x01\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12\x31\n\x06values\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06values\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnNameB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -77,6 +78,7 @@ ) _WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"] _HINT = DESCRIPTOR.message_types_by_name["Hint"] +_UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"] _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"] _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"] _SORT_SORTDIRECTION = _SORT.enum_types_by_name["SortDirection"] @@ -493,6 +495,17 @@ ) _sym_db.RegisterMessage(Hint) +Unpivot = _reflection.GeneratedProtocolMessageType( + "Unpivot", + (_message.Message,), + { + "DESCRIPTOR": _UNPIVOT, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Unpivot) + }, +) +_sym_db.RegisterMessage(Unpivot) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -501,88 +514,90 @@ _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001" _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001" - _RELATION._serialized_start = 82 - _RELATION._serialized_end = 1885 - _UNKNOWN._serialized_start = 1887 - _UNKNOWN._serialized_end = 1896 - _RELATIONCOMMON._serialized_start = 1898 - _RELATIONCOMMON._serialized_end = 1947 - _SQL._serialized_start = 1949 - _SQL._serialized_end = 1976 - _READ._serialized_start = 1979 - _READ._serialized_end = 2405 - _READ_NAMEDTABLE._serialized_start = 2121 - _READ_NAMEDTABLE._serialized_end = 2182 - _READ_DATASOURCE._serialized_start = 2185 - _READ_DATASOURCE._serialized_end = 2392 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2323 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2381 - _PROJECT._serialized_start = 2407 - _PROJECT._serialized_end = 2524 - _FILTER._serialized_start = 2526 - _FILTER._serialized_end = 2638 - _JOIN._serialized_start = 2641 - _JOIN._serialized_end = 3112 - _JOIN_JOINTYPE._serialized_start = 2904 - _JOIN_JOINTYPE._serialized_end = 3112 - _SETOPERATION._serialized_start = 3115 - _SETOPERATION._serialized_end = 3511 - _SETOPERATION_SETOPTYPE._serialized_start = 3374 - _SETOPERATION_SETOPTYPE._serialized_end = 3488 - _LIMIT._serialized_start = 3513 - _LIMIT._serialized_end = 3589 - _OFFSET._serialized_start = 3591 - _OFFSET._serialized_end = 3670 - _TAIL._serialized_start = 3672 - _TAIL._serialized_end = 3747 - _AGGREGATE._serialized_start = 3750 - _AGGREGATE._serialized_end = 3960 - _SORT._serialized_start = 3963 - _SORT._serialized_end = 4513 - _SORT_SORTFIELD._serialized_start = 4117 - _SORT_SORTFIELD._serialized_end = 4305 - _SORT_SORTDIRECTION._serialized_start = 4307 - _SORT_SORTDIRECTION._serialized_end = 4415 - _SORT_SORTNULLS._serialized_start = 4417 - _SORT_SORTNULLS._serialized_end = 4499 - _DROP._serialized_start = 4515 - _DROP._serialized_end = 4615 - _DEDUPLICATE._serialized_start = 4618 - _DEDUPLICATE._serialized_end = 4789 - _LOCALRELATION._serialized_start = 4791 - _LOCALRELATION._serialized_end = 4826 - _SAMPLE._serialized_start = 4829 - _SAMPLE._serialized_end = 5053 - 
_RANGE._serialized_start = 5056 - _RANGE._serialized_end = 5201 - _SUBQUERYALIAS._serialized_start = 5203 - _SUBQUERYALIAS._serialized_end = 5317 - _REPARTITION._serialized_start = 5320 - _REPARTITION._serialized_end = 5462 - _SHOWSTRING._serialized_start = 5465 - _SHOWSTRING._serialized_end = 5606 - _STATSUMMARY._serialized_start = 5608 - _STATSUMMARY._serialized_end = 5700 - _STATDESCRIBE._serialized_start = 5702 - _STATDESCRIBE._serialized_end = 5783 - _STATCROSSTAB._serialized_start = 5785 - _STATCROSSTAB._serialized_end = 5886 - _NAFILL._serialized_start = 5889 - _NAFILL._serialized_end = 6023 - _NADROP._serialized_start = 6026 - _NADROP._serialized_end = 6160 - _NAREPLACE._serialized_start = 6163 - _NAREPLACE._serialized_end = 6459 - _NAREPLACE_REPLACEMENT._serialized_start = 6318 - _NAREPLACE_REPLACEMENT._serialized_end = 6459 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6461 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6575 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6578 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 6837 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6770 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 6837 - _WITHCOLUMNS._serialized_start = 6840 - _WITHCOLUMNS._serialized_end = 6971 - _HINT._serialized_start = 6974 - _HINT._serialized_end = 7114 + _RELATION._serialized_start = 109 + _RELATION._serialized_end = 1964 + _UNKNOWN._serialized_start = 1966 + _UNKNOWN._serialized_end = 1975 + _RELATIONCOMMON._serialized_start = 1977 + _RELATIONCOMMON._serialized_end = 2026 + _SQL._serialized_start = 2028 + _SQL._serialized_end = 2055 + _READ._serialized_start = 2058 + _READ._serialized_end = 2484 + _READ_NAMEDTABLE._serialized_start = 2200 + _READ_NAMEDTABLE._serialized_end = 2261 + _READ_DATASOURCE._serialized_start = 2264 + _READ_DATASOURCE._serialized_end = 2471 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2402 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2460 + _PROJECT._serialized_start = 2486 + _PROJECT._serialized_end = 2603 + _FILTER._serialized_start = 2605 + _FILTER._serialized_end = 2717 + _JOIN._serialized_start = 2720 + _JOIN._serialized_end = 3191 + _JOIN_JOINTYPE._serialized_start = 2983 + _JOIN_JOINTYPE._serialized_end = 3191 + _SETOPERATION._serialized_start = 3194 + _SETOPERATION._serialized_end = 3590 + _SETOPERATION_SETOPTYPE._serialized_start = 3453 + _SETOPERATION_SETOPTYPE._serialized_end = 3567 + _LIMIT._serialized_start = 3592 + _LIMIT._serialized_end = 3668 + _OFFSET._serialized_start = 3670 + _OFFSET._serialized_end = 3749 + _TAIL._serialized_start = 3751 + _TAIL._serialized_end = 3826 + _AGGREGATE._serialized_start = 3829 + _AGGREGATE._serialized_end = 4039 + _SORT._serialized_start = 4042 + _SORT._serialized_end = 4592 + _SORT_SORTFIELD._serialized_start = 4196 + _SORT_SORTFIELD._serialized_end = 4384 + _SORT_SORTDIRECTION._serialized_start = 4386 + _SORT_SORTDIRECTION._serialized_end = 4494 + _SORT_SORTNULLS._serialized_start = 4496 + _SORT_SORTNULLS._serialized_end = 4578 + _DROP._serialized_start = 4594 + _DROP._serialized_end = 4694 + _DEDUPLICATE._serialized_start = 4697 + _DEDUPLICATE._serialized_end = 4868 + _LOCALRELATION._serialized_start = 4871 + _LOCALRELATION._serialized_end = 5008 + _SAMPLE._serialized_start = 5011 + _SAMPLE._serialized_end = 5235 + _RANGE._serialized_start = 5238 + _RANGE._serialized_end = 5383 + _SUBQUERYALIAS._serialized_start = 5385 + _SUBQUERYALIAS._serialized_end = 5499 + _REPARTITION._serialized_start = 5502 + 
_REPARTITION._serialized_end = 5644 + _SHOWSTRING._serialized_start = 5647 + _SHOWSTRING._serialized_end = 5788 + _STATSUMMARY._serialized_start = 5790 + _STATSUMMARY._serialized_end = 5882 + _STATDESCRIBE._serialized_start = 5884 + _STATDESCRIBE._serialized_end = 5965 + _STATCROSSTAB._serialized_start = 5967 + _STATCROSSTAB._serialized_end = 6068 + _NAFILL._serialized_start = 6071 + _NAFILL._serialized_end = 6205 + _NADROP._serialized_start = 6208 + _NADROP._serialized_end = 6342 + _NAREPLACE._serialized_start = 6345 + _NAREPLACE._serialized_end = 6641 + _NAREPLACE_REPLACEMENT._serialized_start = 6500 + _NAREPLACE_REPLACEMENT._serialized_end = 6641 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 6643 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 6757 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 6760 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7019 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 6952 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7019 + _WITHCOLUMNS._serialized_start = 7022 + _WITHCOLUMNS._serialized_end = 7153 + _HINT._serialized_start = 7156 + _HINT._serialized_end = 7296 + _UNPIVOT._serialized_start = 7299 + _UNPIVOT._serialized_end = 7545 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index f133661368..e942a63629 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -40,6 +40,7 @@ import google.protobuf.internal.containers import google.protobuf.internal.enum_type_wrapper import google.protobuf.message import pyspark.sql.connect.proto.expressions_pb2 +import pyspark.sql.connect.proto.types_pb2 import sys import typing @@ -83,6 +84,7 @@ class Relation(google.protobuf.message.Message): TAIL_FIELD_NUMBER: builtins.int WITH_COLUMNS_FIELD_NUMBER: builtins.int HINT_FIELD_NUMBER: builtins.int + UNPIVOT_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int DROP_NA_FIELD_NUMBER: builtins.int REPLACE_FIELD_NUMBER: builtins.int @@ -139,6 +141,8 @@ class Relation(google.protobuf.message.Message): @property def hint(self) -> global___Hint: ... @property + def unpivot(self) -> global___Unpivot: ... 
+ @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -181,6 +185,7 @@ class Relation(google.protobuf.message.Message): tail: global___Tail | None = ..., with_columns: global___WithColumns | None = ..., hint: global___Hint | None = ..., + unpivot: global___Unpivot | None = ..., fill_na: global___NAFill | None = ..., drop_na: global___NADrop | None = ..., replace: global___NAReplace | None = ..., @@ -254,6 +259,8 @@ class Relation(google.protobuf.message.Message): b"tail", "unknown", b"unknown", + "unpivot", + b"unpivot", "with_columns", b"with_columns", ], @@ -323,6 +330,8 @@ class Relation(google.protobuf.message.Message): b"tail", "unknown", b"unknown", + "unpivot", + b"unpivot", "with_columns", b"with_columns", ], @@ -353,6 +362,7 @@ class Relation(google.protobuf.message.Message): "tail", "with_columns", "hint", + "unpivot", "fill_na", "drop_na", "replace", @@ -1159,16 +1169,45 @@ class LocalRelation(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DATA_FIELD_NUMBER: builtins.int + DATATYPE_FIELD_NUMBER: builtins.int + DATATYPE_STR_FIELD_NUMBER: builtins.int data: builtins.bytes """Local collection data serialized into Arrow IPC streaming format which contains the schema of the data. """ + @property + def datatype(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... + datatype_str: builtins.str + """Server will use Catalyst parser to parse this string to DataType.""" def __init__( self, *, data: builtins.bytes = ..., + datatype: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + datatype_str: builtins.str = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["data", b"data"]) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "datatype", b"datatype", "datatype_str", b"datatype_str", "schema", b"schema" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "data", + b"data", + "datatype", + b"datatype", + "datatype_str", + b"datatype_str", + "schema", + b"schema", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["schema", b"schema"] + ) -> typing_extensions.Literal["datatype", "datatype_str"] | None: ... global___LocalRelation = LocalRelation @@ -1963,3 +2002,66 @@ class Hint(google.protobuf.message.Message): ) -> None: ... 
global___Hint = Hint + +class Unpivot(google.protobuf.message.Message): + """Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + IDS_FIELD_NUMBER: builtins.int + VALUES_FIELD_NUMBER: builtins.int + VARIABLE_COLUMN_NAME_FIELD_NUMBER: builtins.int + VALUE_COLUMN_NAME_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """(Required) The input relation.""" + @property + def ids( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Required) Id columns.""" + @property + def values( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Optional) Value columns to unpivot.""" + variable_column_name: builtins.str + """(Required) Name of the variable column.""" + value_column_name: builtins.str + """(Required) Name of the value column.""" + def __init__( + self, + *, + input: global___Relation | None = ..., + ids: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] + | None = ..., + values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] + | None = ..., + variable_column_name: builtins.str = ..., + value_column_name: builtins.str = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "ids", + b"ids", + "input", + b"input", + "value_column_name", + b"value_column_name", + "values", + b"values", + "variable_column_name", + b"variable_column_name", + ], + ) -> None: ... + +global___Unpivot = Unpivot diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 45239a2fa2..778509bcf7 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -50,7 +50,7 @@ def _set_opts( self.option(k, v) # type: ignore[attr-defined] -class DataFrameReader: +class DataFrameReader(OptionUtils): """ TODO(SPARK-40539) Achieve parity with PySpark. """ @@ -164,7 +164,6 @@ def load( return self._df(plan) def _df(self, plan: LogicalPlan) -> "DataFrame": - # The import is needed here to avoid circular import issues. 
from pyspark.sql.connect.dataframe import DataFrame return DataFrame.withPlan(plan, self._client) @@ -172,6 +171,164 @@ def _df(self, plan: LogicalPlan) -> "DataFrame": def table(self, tableName: str) -> "DataFrame": return self._df(Read(tableName)) + def json( + self, + path: str, + schema: Optional[str] = None, + primitivesAsString: Optional[Union[bool, str]] = None, + prefersDecimal: Optional[Union[bool, str]] = None, + allowComments: Optional[Union[bool, str]] = None, + allowUnquotedFieldNames: Optional[Union[bool, str]] = None, + allowSingleQuotes: Optional[Union[bool, str]] = None, + allowNumericLeadingZero: Optional[Union[bool, str]] = None, + allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None, + mode: Optional[str] = None, + columnNameOfCorruptRecord: Optional[str] = None, + dateFormat: Optional[str] = None, + timestampFormat: Optional[str] = None, + multiLine: Optional[Union[bool, str]] = None, + allowUnquotedControlChars: Optional[Union[bool, str]] = None, + lineSep: Optional[str] = None, + samplingRatio: Optional[Union[float, str]] = None, + dropFieldIfAllNull: Optional[Union[bool, str]] = None, + encoding: Optional[str] = None, + locale: Optional[str] = None, + pathGlobFilter: Optional[Union[bool, str]] = None, + recursiveFileLookup: Optional[Union[bool, str]] = None, + modifiedBefore: Optional[Union[bool, str]] = None, + modifiedAfter: Optional[Union[bool, str]] = None, + allowNonNumericNumbers: Optional[Union[bool, str]] = None, + ) -> "DataFrame": + """ + Loads JSON files and returns the results as a :class:`DataFrame`. + + `JSON Lines `_ (newline-delimited JSON) is supported by default. + For JSON (one record per file), set the ``multiLine`` parameter to ``true``. + + If the ``schema`` parameter is not specified, this function goes + through the input once to determine the input schema. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + path : str + string represents path to the JSON dataset + schema : str, optional + a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). + + Other Parameters + ---------------- + Extra options + For the extra options, refer to + `Data Source Option `_ + for the version you use. + + .. # noqa + + Examples + -------- + Write a DataFrame into a JSON file and read it back. + + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as d: + ... # Write a DataFrame into a JSON file + ... spark.createDataFrame( + ... [{"age": 100, "name": "Hyukjin Kwon"}] + ... ).write.mode("overwrite").format("json").save(d) + ... + ... # Read the JSON file as a DataFrame. + ... 
spark.read.json(d).show() + +---+------------+ + |age| name| + +---+------------+ + |100|Hyukjin Kwon| + +---+------------+ + """ + self._set_opts( + primitivesAsString=primitivesAsString, + prefersDecimal=prefersDecimal, + allowComments=allowComments, + allowUnquotedFieldNames=allowUnquotedFieldNames, + allowSingleQuotes=allowSingleQuotes, + allowNumericLeadingZero=allowNumericLeadingZero, + allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, + mode=mode, + columnNameOfCorruptRecord=columnNameOfCorruptRecord, + dateFormat=dateFormat, + timestampFormat=timestampFormat, + multiLine=multiLine, + allowUnquotedControlChars=allowUnquotedControlChars, + lineSep=lineSep, + samplingRatio=samplingRatio, + dropFieldIfAllNull=dropFieldIfAllNull, + encoding=encoding, + locale=locale, + pathGlobFilter=pathGlobFilter, + recursiveFileLookup=recursiveFileLookup, + modifiedBefore=modifiedBefore, + modifiedAfter=modifiedAfter, + allowNonNumericNumbers=allowNonNumericNumbers, + ) + return self.load(path=path, format="json", schema=schema) + + def parquet(self, path: str, **options: "OptionalPrimitiveType") -> "DataFrame": + """ + Loads Parquet files, returning the result as a :class:`DataFrame`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + path : str + + Other Parameters + ---------------- + **options + For the extra options, refer to + `Data Source Option `_ + for the version you use. + + .. # noqa + + Examples + -------- + Write a DataFrame into a Parquet file and read it back. + + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as d: + ... # Write a DataFrame into a Parquet file + ... spark.createDataFrame( + ... [{"age": 100, "name": "Hyukjin Kwon"}] + ... ).write.mode("overwrite").format("parquet").save(d) + ... + ... # Read the Parquet file as a DataFrame. + ... spark.read.parquet(d).show() + +---+------------+ + |age| name| + +---+------------+ + |100|Hyukjin Kwon| + +---+------------+ + """ + mergeSchema = options.get("mergeSchema", None) + pathGlobFilter = options.get("pathGlobFilter", None) + modifiedBefore = options.get("modifiedBefore", None) + modifiedAfter = options.get("modifiedAfter", None) + recursiveFileLookup = options.get("recursiveFileLookup", None) + datetimeRebaseMode = options.get("datetimeRebaseMode", None) + int96RebaseMode = options.get("int96RebaseMode", None) + self._set_opts( + mergeSchema=mergeSchema, + pathGlobFilter=pathGlobFilter, + recursiveFileLookup=recursiveFileLookup, + modifiedBefore=modifiedBefore, + modifiedAfter=modifiedAfter, + datetimeRebaseMode=datetimeRebaseMode, + int96RebaseMode=int96RebaseMode, + ) + + return self.load(path=path, format="parquet") + class DataFrameWriter(OptionUtils): """ diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 28aebbdecb..0a3d03110f 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -16,17 +16,35 @@ # from threading import RLock -from typing import Optional, Any, Union, Dict, cast, overload +from collections.abc import Sized + +import numpy as np import pandas as pd +import pyarrow as pa + +from pyspark.sql.types import DataType, StructType -import pyspark.sql.types from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.plan import SQL, Range +from pyspark.sql.connect.plan import SQL, Range, LocalRelation from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.utils import to_str -from . 
import plan -from ._typing import OptionalPrimitiveType + +from typing import ( + Optional, + Any, + Union, + Dict, + List, + Tuple, + cast, + overload, + Iterable, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from pyspark.sql.connect._typing import OptionalPrimitiveType # TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped @@ -240,7 +258,11 @@ def read(self) -> "DataFrameReader": """ return DataFrameReader(self) - def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame": + def createDataFrame( + self, + data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]], + schema: Optional[Union[StructType, str, List[str], Tuple[str, ...]]] = None, + ) -> "DataFrame": """ Creates a :class:`DataFrame` from a :class:`pandas.DataFrame`. @@ -249,7 +271,15 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame": Parameters ---------- - data : :class:`pandas.DataFrame` + data : :class:`pandas.DataFrame` or :class:`list`, or :class:`numpy.ndarray`. + schema : :class:`pyspark.sql.types.DataType`, str or list, optional + + When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must + match the real data, or an exception will be thrown at runtime. If the given schema is + not :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` as its only field, and the field name will be + "value". Each record will also be wrapped into a tuple, which can be converted to row + later. Returns ------- @@ -264,9 +294,71 @@ def createDataFrame(self, data: "pd.DataFrame") -> "DataFrame": """ assert data is not None - if len(data) == 0: + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") + if isinstance(data, Sized) and len(data) == 0: raise ValueError("Input data cannot be empty") - return DataFrame.withPlan(plan.LocalRelation(data), self) + + _schema: Optional[StructType] = None + _schema_str: Optional[str] = None + _cols: Optional[List[str]] = None + + if isinstance(schema, StructType): + _schema = schema + + elif isinstance(schema, str): + _schema_str = schema + + elif isinstance(schema, (list, tuple)): + # Must re-encode any unicode strings to be consistent with StructField names + _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema] + + # Create the Pandas DataFrame + if isinstance(data, pd.DataFrame): + pdf = data + + elif isinstance(data, np.ndarray): + # `data` of numpy.ndarray type will be converted to a pandas DataFrame, + if data.ndim not in [1, 2]: + raise ValueError("NumPy array input should be of 1 or 2 dimensions.") + + pdf = pd.DataFrame(data) + + if _cols is None: + if data.ndim == 1 or data.shape[1] == 1: + _cols = ["value"] + else: + _cols = ["_%s" % i for i in range(1, data.shape[1] + 1)] + + else: + pdf = pd.DataFrame(list(data)) + + if _cols is None: + _cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)] + + # Validate number of columns + num_cols = pdf.shape[1] + if _schema is not None and len(_schema.fields) != num_cols: + raise ValueError( + f"Length mismatch: Expected axis has {num_cols} elements, " + f"new values have {len(_schema.fields)} elements" + ) + elif _cols is not None and len(_cols) != num_cols: + raise ValueError( + f"Length mismatch: Expected axis has {num_cols} elements, " + f"new values have {len(_cols)} elements" + ) + + table = pa.Table.from_pandas(pdf) + + if _schema is not None: + return DataFrame.withPlan(LocalRelation(table, schema=_schema), self) + elif _schema_str is not None: + return 
DataFrame.withPlan(LocalRelation(table, schema=_schema_str), self) + elif _cols is not None and len(_cols) > 0: + return DataFrame.withPlan(LocalRelation(table), self).toDF(*_cols) + else: + return DataFrame.withPlan(LocalRelation(table), self) @property def client(self) -> "SparkConnectClient": @@ -279,9 +371,7 @@ def client(self) -> "SparkConnectClient": """ return self._client - def register_udf( - self, function: Any, return_type: Union[str, pyspark.sql.types.DataType] - ) -> str: + def register_udf(self, function: Any, return_type: Union[str, DataType]) -> str: return self._client.register_udf(function, return_type) def sql(self, sql_string: str) -> "DataFrame": diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py new file mode 100644 index 0000000000..55f5953660 --- /dev/null +++ b/python/pyspark/sql/connect/types.py @@ -0,0 +1,143 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +import pyspark.sql.connect.proto as pb2 +from pyspark.sql.types import ( + DataType, + ByteType, + ShortType, + IntegerType, + FloatType, + DateType, + TimestampType, + DayTimeIntervalType, + MapType, + StringType, + CharType, + VarcharType, + StructType, + StructField, + ArrayType, + DoubleType, + LongType, + DecimalType, + BinaryType, + BooleanType, + NullType, +) + + +def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType: + ret = pb2.DataType() + if isinstance(data_type, StringType): + ret.string.CopyFrom(pb2.DataType.String()) + elif isinstance(data_type, BooleanType): + ret.boolean.CopyFrom(pb2.DataType.Boolean()) + elif isinstance(data_type, BinaryType): + ret.binary.CopyFrom(pb2.DataType.Binary()) + elif isinstance(data_type, ByteType): + ret.byte.CopyFrom(pb2.DataType.Byte()) + elif isinstance(data_type, ShortType): + ret.short.CopyFrom(pb2.DataType.Short()) + elif isinstance(data_type, IntegerType): + ret.integer.CopyFrom(pb2.DataType.Integer()) + elif isinstance(data_type, LongType): + ret.long.CopyFrom(pb2.DataType.Long()) + elif isinstance(data_type, FloatType): + ret.float.CopyFrom(pb2.DataType.Float()) + elif isinstance(data_type, DoubleType): + ret.double.CopyFrom(pb2.DataType.Double()) + elif isinstance(data_type, DecimalType): + ret.decimal.CopyFrom(pb2.DataType.Decimal()) + elif isinstance(data_type, DayTimeIntervalType): + ret.day_time_interval.start_field = data_type.startField + ret.day_time_interval.end_field = data_type.endField + else: + raise Exception(f"Unsupported data type {data_type}") + return ret + + +def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType: + if schema.HasField("null"): + return NullType() + elif schema.HasField("boolean"): + return BooleanType() + elif schema.HasField("binary"): + return BinaryType() + elif 
schema.HasField("byte"): + return ByteType() + elif schema.HasField("short"): + return ShortType() + elif schema.HasField("integer"): + return IntegerType() + elif schema.HasField("long"): + return LongType() + elif schema.HasField("float"): + return FloatType() + elif schema.HasField("double"): + return DoubleType() + elif schema.HasField("decimal"): + p = schema.decimal.precision if schema.decimal.HasField("precision") else 10 + s = schema.decimal.scale if schema.decimal.HasField("scale") else 0 + return DecimalType(precision=p, scale=s) + elif schema.HasField("string"): + return StringType() + elif schema.HasField("char"): + return CharType(schema.char.length) + elif schema.HasField("var_char"): + return VarcharType(schema.var_char.length) + elif schema.HasField("date"): + return DateType() + elif schema.HasField("timestamp"): + return TimestampType() + elif schema.HasField("day_time_interval"): + start: Optional[int] = ( + schema.day_time_interval.start_field + if schema.day_time_interval.HasField("start_field") + else None + ) + end: Optional[int] = ( + schema.day_time_interval.end_field + if schema.day_time_interval.HasField("end_field") + else None + ) + return DayTimeIntervalType(startField=start, endField=end) + elif schema.HasField("array"): + return ArrayType( + proto_schema_to_pyspark_data_type(schema.array.element_type), + schema.array.contains_null, + ) + elif schema.HasField("struct"): + fields = [ + StructField( + f.name, + proto_schema_to_pyspark_data_type(f.data_type), + f.nullable, + ) + for f in schema.struct.fields + ] + return StructType(fields) + elif schema.HasField("map"): + return MapType( + proto_schema_to_pyspark_data_type(schema.map.key_type), + proto_schema_to_pyspark_data_type(schema.map.value_type), + schema.map.value_contains_null, + ) + else: + raise Exception(f"Unsupported data type {schema}") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 9746196dc9..de540c6249 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -5560,11 +5560,11 @@ def decode(col: "ColumnOrName", charset: str) -> Column: -------- >>> df = spark.createDataFrame([('abcd',)], ['a']) >>> df.select(decode("a", "UTF-8")).show() - +----------------------+ - |stringdecode(a, UTF-8)| - +----------------------+ - | abcd| - +----------------------+ + +----------------+ + |decode(a, UTF-8)| + +----------------+ + | abcd| + +----------------+ """ return _invoke_function("decode", _to_java_column(col), charset) @@ -8036,7 +8036,7 @@ def sequence( def from_csv( col: "ColumnOrName", - schema: Union[StructType, Column, str], + schema: Union[Column, str], options: Optional[Dict[str, str]] = None, ) -> Column: """ diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py index c51a90ca57..ab62b955b6 100644 --- a/python/pyspark/sql/pandas/utils.py +++ b/python/pyspark/sql/pandas/utils.py @@ -73,6 +73,25 @@ def require_minimum_pyarrow_version() -> None: ) +def require_minimum_grpc_version() -> None: + """Raise ImportError if minimum version of grpc is not installed""" + minimum_pandas_version = "1.48.1" + + from distutils.version import LooseVersion + + try: + import grpc + except ImportError as error: + raise ImportError( + "grpc >= %s must be installed; however, " "it was not found." % minimum_pandas_version + ) from error + if LooseVersion(grpc.__version__) < LooseVersion(minimum_pandas_version): + raise ImportError( + "Pandas >= %s must be installed; however, " + "your version was %s." 
% (minimum_pandas_version, grpc.__version__) + ) + + def pyarrow_version_less_than_minimum(minimum_pyarrow_version: str) -> bool: """Return False if the installed pyarrow version is less than minimum_pyarrow_version or if pyarrow is not installed.""" diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 98150731c2..6dabbaedff 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -14,38 +14,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any import unittest import shutil import tempfile -import grpc # type: ignore - -from pyspark.sql.connect.column import Column -from pyspark.testing.sqlutils import have_pandas, SQLTestUtils - -if have_pandas: - import pandas - +from pyspark.testing.sqlutils import SQLTestUtils from pyspark.sql import SparkSession, Row from pyspark.sql.types import StructType, StructField, LongType, StringType +import pyspark.sql.functions +from pyspark.testing.utils import ReusedPySparkTestCase +from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +from pyspark.testing.pandasutils import PandasOnSparkTestCase -if have_pandas: +if should_test_connect: + import grpc + import pandas as pd + import numpy as np from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.sql.connect.client import ChannelBuilder + from pyspark.sql.connect.column import Column from pyspark.sql.connect.dataframe import DataFrame as CDataFrame from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit, col - from pyspark.testing.pandasutils import PandasOnSparkTestCase -else: - from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase # type: ignore -from pyspark.sql.dataframe import DataFrame -import pyspark.sql.functions -from pyspark.testing.connectutils import should_test_connect, connect_requirement_message -from pyspark.testing.utils import ReusedPySparkTestCase - - -import tempfile @unittest.skipIf(not should_test_connect, connect_requirement_message) @@ -53,15 +43,8 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT """Parent test fixture class for all Spark Connect related test cases.""" - if have_pandas: - connect: RemoteSparkSession - tbl_name: str - tbl_name_empty: str - df_text: "DataFrame" - spark: SparkSession - @classmethod - def setUpClass(cls: Any): + def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) cls.hive_available = True @@ -82,12 +65,12 @@ def setUpClass(cls: Any): cls.spark_connect_load_test_data() @classmethod - def tearDownClass(cls: Any) -> None: + def tearDownClass(cls): cls.spark_connect_clean_up_test_data() ReusedPySparkTestCase.tearDownClass() @classmethod - def spark_connect_load_test_data(cls: Any): + def spark_connect_load_test_data(cls): # Setup Remote Spark Session cls.connect = RemoteSparkSession.builder.remote().getOrCreate() df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"]) @@ -108,7 +91,7 @@ def spark_connect_load_test_data(cls: Any): empty_df.write.saveAsTable(cls.tbl_name_empty) @classmethod - def spark_connect_clean_up_test_data(cls: Any) -> None: + def spark_connect_clean_up_test_data(cls): cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name)) cls.spark.sql("DROP 
TABLE IF EXISTS {}".format(cls.tbl_name2)) cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty)) @@ -121,6 +104,35 @@ def test_simple_read(self): # Check that the limit is applied self.assertEqual(len(data.index), 10) + def test_json(self): + with tempfile.TemporaryDirectory() as d: + # Write a DataFrame into a JSON file + self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode( + "overwrite" + ).format("json").save(d) + # Read the JSON file as a DataFrame. + self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas()) + self.assert_eq( + self.connect.read.json(path=d, schema="age INT, name STRING").toPandas(), + self.spark.read.json(path=d, schema="age INT, name STRING").toPandas(), + ) + self.assert_eq( + self.connect.read.json(path=d, primitivesAsString=True).toPandas(), + self.spark.read.json(path=d, primitivesAsString=True).toPandas(), + ) + + def test_paruqet(self): + # SPARK-41445: Implement DataFrameReader.paruqet + with tempfile.TemporaryDirectory() as d: + # Write a DataFrame into a JSON file + self.spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}]).write.mode( + "overwrite" + ).format("parquet").save(d) + # Read the Parquet file as a DataFrame. + self.assert_eq( + self.connect.read.parquet(d).toPandas(), self.spark.read.parquet(d).toPandas() + ) + def test_join_condition_column_list_columns(self): left_connect_df = self.connect.read.table(self.tbl_name) right_connect_df = self.connect.read.table(self.tbl_name2) @@ -183,7 +195,7 @@ def conv_udf(x) -> str: def test_with_local_data(self): """SPARK-41114: Test creating a dataframe using local data""" - pdf = pandas.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) df = self.connect.createDataFrame(pdf) rows = df.filter(df.a == lit(3)).collect() self.assertTrue(len(rows) == 1) @@ -191,10 +203,94 @@ def test_with_local_data(self): self.assertEqual(rows[0][1], "c") # Check correct behavior for empty DataFrame - pdf = pandas.DataFrame({"a": []}) + pdf = pd.DataFrame({"a": []}) with self.assertRaises(ValueError): self.connect.createDataFrame(pdf) + def test_with_local_ndarray(self): + """SPARK-41446: Test creating a dataframe using local list""" + data = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + + sdf = self.spark.createDataFrame(data) + cdf = self.connect.createDataFrame(data) + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) + + # TODO: add cases for StructType after 'pyspark_types_to_proto_types' support StructType + for schema in [ + "struct", + "col1 int, col2 int, col3 int, col4 int", + "col1 int, col2 long, col3 string, col4 long", + "col1 int, col2 string, col3 short, col4 long", + ["a", "b", "c", "d"], + ("x1", "x2", "x3", "x4"), + ]: + sdf = self.spark.createDataFrame(data, schema=schema) + cdf = self.connect.createDataFrame(data, schema=schema) + + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) + + with self.assertRaisesRegex( + ValueError, + "Length mismatch: Expected axis has 4 elements, new values have 5 elements", + ): + self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) + + with self.assertRaises(grpc.RpcError): + self.connect.createDataFrame( + data, "col1 magic_type, col2 int, col3 int, col4 int" + ).show() + + with self.assertRaises(grpc.RpcError): + self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show() + + def test_with_local_list(self): + """SPARK-41446: Test creating a 
dataframe using local list""" + data = [[1, 2, 3, 4]] + + sdf = self.spark.createDataFrame(data) + cdf = self.connect.createDataFrame(data) + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) + + for schema in [ + "struct", + "col1 int, col2 int, col3 int, col4 int", + "col1 int, col2 long, col3 string, col4 long", + "col1 int, col2 string, col3 short, col4 long", + ["a", "b", "c", "d"], + ("x1", "x2", "x3", "x4"), + ]: + sdf = self.spark.createDataFrame(data, schema=schema) + cdf = self.connect.createDataFrame(data, schema=schema) + + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) + + with self.assertRaisesRegex( + ValueError, + "Length mismatch: Expected axis has 4 elements, new values have 5 elements", + ): + self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) + + with self.assertRaises(grpc.RpcError): + self.connect.createDataFrame( + data, "col1 magic_type, col2 int, col3 int, col4 int" + ).show() + + with self.assertRaises(grpc.RpcError): + self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show() + + def test_with_atom_type(self): + for data in [[(1), (2), (3)], [1, 2, 3]]: + for schema in ["long", "int", "short"]: + sdf = self.spark.createDataFrame(data, schema=schema) + cdf = self.connect.createDataFrame(data, schema=schema) + + self.assertEqual(sdf.schema, cdf.schema) + self.assert_eq(sdf.toPandas(), cdf.toPandas()) + def test_simple_explain_string(self): df = self.connect.read.table(self.tbl_name).limit(10) result = df._explain_string() @@ -687,6 +783,29 @@ def test_replace(self): """Cannot resolve column name "x" among (a, b, c)""", str(context.exception) ) + def test_unpivot(self): + self.assert_eq( + self.connect.read.table(self.tbl_name) + .filter("id > 3") + .unpivot(["id"], ["name"], "variable", "value") + .toPandas(), + self.spark.read.table(self.tbl_name) + .filter("id > 3") + .unpivot(["id"], ["name"], "variable", "value") + .toPandas(), + ) + + self.assert_eq( + self.connect.read.table(self.tbl_name) + .filter("id > 3") + .unpivot("id", None, "variable", "value") + .toPandas(), + self.spark.read.table(self.tbl_name) + .filter("id > 3") + .unpivot("id", None, "variable", "value") + .toPandas(), + ) + def test_with_columns(self): # SPARK-41256: test withColumn(s). self.assert_eq( @@ -955,7 +1074,7 @@ def test_metadata(self): from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 try: - import xmlrunner # type: ignore + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 106ab609bf..e670123199 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -14,13 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase -from pyspark.testing.sqlutils import have_pandas +from pyspark.sql.types import StringType +from pyspark.sql.types import ( + ByteType, + ShortType, + IntegerType, + FloatType, + DayTimeIntervalType, + StringType, + DoubleType, + LongType, + DecimalType, + BinaryType, + BooleanType, +) +from pyspark.testing.connectutils import should_test_connect -if have_pandas: +if should_test_connect: + import pandas as pd from pyspark.sql.connect.functions import lit - import pandas class SparkConnectTests(SparkConnectSQLTestCase): @@ -74,11 +87,116 @@ def test_columns(self): def test_simple_binary_expressions(self): """Test complex expression""" df = self.connect.read.table(self.tbl_name) - pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas() - self.assertEqual(len(pd.index), 4) + pdf = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas() + self.assertEqual(len(pdf.index), 4) + + res = pd.DataFrame(data={"id": [0, 30, 60, 90]}) + self.assert_(pdf.equals(res), f"{pdf.to_string()} != {res.to_string()}") + + def test_literal_integers(self): + cdf = self.connect.range(0, 1) + sdf = self.spark.range(0, 1) + + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + from pyspark.sql.connect.column import JVM_INT_MIN, JVM_INT_MAX, JVM_LONG_MIN, JVM_LONG_MAX + + cdf1 = cdf.select( + CF.lit(0), + CF.lit(1), + CF.lit(-1), + CF.lit(JVM_INT_MAX), + CF.lit(JVM_INT_MIN), + CF.lit(JVM_INT_MAX + 1), + CF.lit(JVM_INT_MIN - 1), + CF.lit(JVM_LONG_MAX), + CF.lit(JVM_LONG_MIN), + CF.lit(JVM_LONG_MAX - 1), + CF.lit(JVM_LONG_MIN + 1), + ) + + sdf1 = sdf.select( + SF.lit(0), + SF.lit(1), + SF.lit(-1), + SF.lit(JVM_INT_MAX), + SF.lit(JVM_INT_MIN), + SF.lit(JVM_INT_MAX + 1), + SF.lit(JVM_INT_MIN - 1), + SF.lit(JVM_LONG_MAX), + SF.lit(JVM_LONG_MIN), + SF.lit(JVM_LONG_MAX - 1), + SF.lit(JVM_LONG_MIN + 1), + ) + + self.assertEqual(cdf1.schema, sdf1.schema) + self.assert_eq(cdf1.toPandas(), sdf1.toPandas()) + + with self.assertRaisesRegex( + ValueError, + "integer 9223372036854775808 out of bounds", + ): + cdf.select(CF.lit(JVM_LONG_MAX + 1)).show() + + with self.assertRaisesRegex( + ValueError, + "integer -9223372036854775809 out of bounds", + ): + cdf.select(CF.lit(JVM_LONG_MIN - 1)).show() + + def test_cast(self): + # SPARK-41412: test basic Column.cast + df = self.connect.read.table(self.tbl_name) + df2 = self.spark.read.table(self.tbl_name) + + self.assert_eq( + df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas() + ) + + # Test if the arguments can be passed properly. + # Do not need to check individual behaviour for the ANSI mode thoroughly. + with self.sql_conf({"spark.sql.ansi.enabled": False}): + for x in [ + StringType(), + BinaryType(), + ShortType(), + IntegerType(), + LongType(), + FloatType(), + DoubleType(), + ByteType(), + DecimalType(10, 2), + BooleanType(), + DayTimeIntervalType(), + ]: + self.assert_eq( + df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas() + ) + + def test_unsupported_functions(self): + # SPARK-41225: Disable unsupported functions. 
+ c = self.connect.range(1).id + for f in ( + "otherwise", + "over", + "isin", + "when", + "getItem", + "astype", + "between", + "getField", + "withField", + "dropFields", + ): + with self.assertRaises(NotImplementedError): + getattr(c, f)() + + with self.assertRaises(NotImplementedError): + c["a"] - res = pandas.DataFrame(data={"id": [0, 30, 60, 90]}) - self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}") + with self.assertRaises(TypeError): + for x in c: + pass if __name__ == "__main__": @@ -86,7 +204,7 @@ def test_simple_binary_expressions(self): from pyspark.sql.tests.connect.test_connect_column import * # noqa: F401 try: - import xmlrunner # type: ignore + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 09e47657eb..d74473e725 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -15,22 +15,24 @@ # limitations under the License. # import uuid -from typing import cast import unittest import decimal import datetime -from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +from pyspark.testing.connectutils import ( + PlanOnlyTestFixture, + should_test_connect, + connect_requirement_message, +) -if have_pandas: +if should_test_connect: from pyspark.sql.connect.proto import Expression as ProtoExpression import pyspark.sql.connect.plan as p from pyspark.sql.connect.column import Column import pyspark.sql.connect.functions as fun -@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) +@unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): def test_simple_column_expressions(self): df = self.connect.with_plan(p.Read("table")) @@ -68,7 +70,7 @@ def test_map_literal(self): map_lit_p = map_lit.to_plan(None) self.assertEqual(2, len(map_lit_p.literal.map.pairs)) self.assertEqual("this", map_lit_p.literal.map.pairs[0].key.string) - self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.long) + self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.integer) val = {"this": fun.lit("is"), 12: [12, 32, 43]} map_lit = fun.lit(val) @@ -89,7 +91,10 @@ def test_column_literals(self): self.assertIsNotNone(fun.lit(10).to_plan(None)) plan = fun.lit(10).to_plan(None) - self.assertIs(plan.literal.long, 10) + self.assertIs(plan.literal.integer, 10) + + plan = fun.lit(1 << 33).to_plan(None) + self.assertEqual(plan.literal.long, 1 << 33) def test_numeric_literal_types(self): int_lit = fun.lit(10) @@ -167,13 +172,13 @@ def test_tuple_to_literal(self): p2 = fun.lit(t2).to_plan(None) self.assertIsNotNone(p2) self.assertTrue(p2.literal.HasField("struct")) - self.assertEqual(p2.literal.struct.fields[0].long, 1) + self.assertEqual(p2.literal.struct.fields[0].integer, 1) self.assertEqual(p2.literal.struct.fields[1].string, "xyz") p3 = fun.lit(t3).to_plan(None) self.assertIsNotNone(p3) self.assertTrue(p3.literal.HasField("struct")) - self.assertEqual(p3.literal.struct.fields[0].long, 1) + self.assertEqual(p3.literal.struct.fields[0].integer, 1) self.assertEqual(p3.literal.struct.fields[1].string, "abc") self.assertEqual(p3.literal.struct.fields[2].struct.fields[0].double, 3.5) 
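
# Standalone illustration of the literal mapping these assertions rely on: Python ints
# within the JVM int range are encoded as `integer` literals, while larger values fall
# back to `long` (the bounds are exposed as JVM_INT_MAX/JVM_LONG_MAX in connect.column).
from pyspark.sql.connect import functions as CF

small_plan = CF.lit(10).to_plan(None)
big_plan = CF.lit(1 << 33).to_plan(None)
assert small_plan.literal.integer == 10
assert big_plan.literal.long == (1 << 33)
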
self.assertEqual(p3.literal.struct.fields[2].struct.fields[1].boolean, True) @@ -207,7 +212,7 @@ def test_column_expressions(self): lit_fun = expr_plan.unresolved_function.arguments[1] self.assertIsInstance(lit_fun, ProtoExpression) self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal) - self.assertEqual(lit_fun.literal.long, 10) + self.assertEqual(lit_fun.literal.integer, 10) mod_fun = expr_plan.unresolved_function.arguments[0] self.assertIsInstance(mod_fun, ProtoExpression) @@ -228,7 +233,7 @@ def test_column_expressions(self): from pyspark.sql.tests.connect.test_connect_column_expressions import * # noqa: F401 try: - import xmlrunner # type: ignore + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index ee3a927708..ee5d2d49d9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -14,22 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any import unittest import tempfile -from pyspark.testing.sqlutils import have_pandas, SQLTestUtils - from pyspark.sql import SparkSession - -if have_pandas: - from pyspark.sql.connect.session import SparkSession as RemoteSparkSession - from pyspark.testing.pandasutils import PandasOnSparkTestCase -else: - from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase # type: ignore -from pyspark.sql.dataframe import DataFrame +from pyspark.testing.pandasutils import PandasOnSparkTestCase from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils + +if should_test_connect: + import grpc + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession @unittest.skipIf(not should_test_connect, connect_requirement_message) @@ -37,15 +33,8 @@ class SparkConnectFuncTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQL """Parent test fixture class for all Spark Connect related test cases.""" - if have_pandas: - connect: RemoteSparkSession - tbl_name: str - tbl_name_empty: str - df_text: "DataFrame" - spark: SparkSession - @classmethod - def setUpClass(cls: Any): + def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) cls.hive_available = True @@ -55,7 +44,7 @@ def setUpClass(cls: Any): cls.connect = RemoteSparkSession.builder.remote().getOrCreate() @classmethod - def tearDownClass(cls: Any) -> None: + def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() @@ -63,6 +52,24 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" + def compare_by_show(self, df1, df2): + from pyspark.sql.dataframe import DataFrame as SDF + from pyspark.sql.connect.dataframe import DataFrame as CDF + + assert isinstance(df1, (SDF, CDF)) + if isinstance(df1, SDF): + str1 = df1._jdf.showString(20, 20, False) + else: + str1 = df1._show_string(20, 20, False) + + assert isinstance(df2, (SDF, CDF)) + if isinstance(df2, SDF): + str2 = df2._jdf.showString(20, 20, False) + else: + str2 = df2._show_string(20, 20, False) + + self.assertEqual(str1, str2) + def test_normal_functions(self): 
from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF @@ -428,6 +435,513 @@ def test_aggregation_functions(self): .toPandas(), ) + def test_collection_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'), + (ARRAY('x', NULL), NULL, ARRAY(1, 3), 3, 4, 'x'), + (NULL, ARRAY(-1, -2, -3), Array(), 5, 6, NULL) + AS tab(a, b, c, d, e, f) + """ + # +---------+------------+------------+---+---+----+ + # | a| b| c| d| e| f| + # +---------+------------+------------+---+---+----+ + # | [a, ab]| [1, 2, 3]|[1, null, 3]| 1| 2| a| + # |[x, null]| null| [1, 3]| 3| 4| x| + # | null|[-1, -2, -3]| []| 5| 6|null| + # +---------+------------+------------+---+---+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + for cfunc, sfunc in [ + (CF.array_distinct, SF.array_distinct), + (CF.array_max, SF.array_max), + (CF.array_min, SF.array_min), + (CF.reverse, SF.reverse), + (CF.size, SF.size), + ]: + self.assert_eq( + cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(), + sdf.select(sfunc("a"), sfunc(sdf.b)).toPandas(), + ) + + for cfunc, sfunc in [ + (CF.array_except, SF.array_except), + (CF.array_intersect, SF.array_intersect), + (CF.array_union, SF.array_union), + (CF.arrays_overlap, SF.arrays_overlap), + ]: + self.assert_eq( + cdf.select(cfunc("b", cdf.c)).toPandas(), + sdf.select(sfunc("b", sdf.c)).toPandas(), + ) + + for cfunc, sfunc in [ + (CF.array_position, SF.array_position), + (CF.array_remove, SF.array_remove), + ]: + self.assert_eq( + cdf.select(cfunc(cdf.a, "ab")).toPandas(), + sdf.select(sfunc(sdf.a, "ab")).toPandas(), + ) + + # test array + self.assert_eq( + cdf.select(CF.array(cdf.d, "e")).toPandas(), + sdf.select(SF.array(sdf.d, "e")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.array(cdf.d, "e", CF.lit(99))).toPandas(), + sdf.select(SF.array(sdf.d, "e", SF.lit(99))).toPandas(), + ) + + # test array_contains + self.assert_eq( + cdf.select(CF.array_contains(cdf.a, "ab")).toPandas(), + sdf.select(SF.array_contains(sdf.a, "ab")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.array_contains(cdf.a, cdf.f)).toPandas(), + sdf.select(SF.array_contains(sdf.a, sdf.f)).toPandas(), + ) + + # test array_join + self.assert_eq( + cdf.select( + CF.array_join(cdf.a, ","), CF.array_join("b", ":"), CF.array_join("c", "~") + ).toPandas(), + sdf.select( + SF.array_join(sdf.a, ","), SF.array_join("b", ":"), SF.array_join("c", "~") + ).toPandas(), + ) + self.assert_eq( + cdf.select( + CF.array_join(cdf.a, ",", "_null_"), + CF.array_join("b", ":", ".null."), + CF.array_join("c", "~", "NULL"), + ).toPandas(), + sdf.select( + SF.array_join(sdf.a, ",", "_null_"), + SF.array_join("b", ":", ".null."), + SF.array_join("c", "~", "NULL"), + ).toPandas(), + ) + + # test array_repeat + self.assert_eq( + cdf.select(CF.array_repeat(cdf.f, "d")).toPandas(), + sdf.select(SF.array_repeat(sdf.f, "d")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.array_repeat("f", cdf.d)).toPandas(), + sdf.select(SF.array_repeat("f", sdf.d)).toPandas(), + ) + # TODO: Make Literal contains DataType + # Cannot resolve "array_repeat(f, 3)" due to data type mismatch: + # Parameter 2 requires the "INT" type, however "3" has the type "BIGINT". 
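# A possible interim workaround (untested sketch): pass the count as an explicit INT
# column so the parameter types line up, e.g.
#     cdf.select(CF.array_repeat("f", CF.lit(3).cast("int"))).toPandas()
# compared against sdf.select(SF.array_repeat("f", 3)).toPandas() on the vanilla side.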
+ # self.assert_eq( + # cdf.select(CF.array_repeat("f", 3)).toPandas(), + # sdf.select(SF.array_repeat("f", 3)).toPandas(), + # ) + + # test arrays_zip + # TODO: Make toPandas support complex nested types like Array + # DataFrame.iloc[:, 0] (column name="arrays_zip(b, c)") values are different (66.66667 %) + # [index]: [0, 1, 2] + # [left]: [[{'b': 1, 'c': 1.0}, {'b': 2, 'c': None}, {'b': 3, 'c': 3.0}], None, + # [{'b': -1, 'c': None}, {'b': -2, 'c': None}, {'b': -3, 'c': None}]] + # [right]: [[(1, 1), (2, None), (3, 3)], None, [(-1, None), (-2, None), (-3, None)]] + self.compare_by_show( + cdf.select(CF.arrays_zip(cdf.b, "c")), + sdf.select(SF.arrays_zip(sdf.b, "c")), + ) + + # test concat + self.assert_eq( + cdf.select(CF.concat("d", cdf.e, CF.lit(-1))).toPandas(), + sdf.select(SF.concat("d", sdf.e, SF.lit(-1))).toPandas(), + ) + + # test create_map + self.compare_by_show( + cdf.select(CF.create_map(cdf.d, cdf.e)), sdf.select(SF.create_map(sdf.d, sdf.e)) + ) + self.compare_by_show( + cdf.select(CF.create_map(cdf.d, "e", "e", CF.lit(1))), + sdf.select(SF.create_map(sdf.d, "e", "e", SF.lit(1))), + ) + + # test element_at + self.assert_eq( + cdf.select(CF.element_at("a", 1)).toPandas(), + sdf.select(SF.element_at("a", 1)).toPandas(), + ) + self.assert_eq( + cdf.select(CF.element_at(cdf.a, 1)).toPandas(), + sdf.select(SF.element_at(sdf.a, 1)).toPandas(), + ) + + # test get + self.assert_eq( + cdf.select(CF.get("a", 1)).toPandas(), + sdf.select(SF.get("a", 1)).toPandas(), + ) + self.assert_eq( + cdf.select(CF.get(cdf.a, 1)).toPandas(), + sdf.select(SF.get(sdf.a, 1)).toPandas(), + ) + + # test shuffle + # Can not compare the values due to the random permutation + self.assertEqual( + cdf.select(CF.shuffle(cdf.a), CF.shuffle("b")).count(), + sdf.select(SF.shuffle(sdf.a), SF.shuffle("b")).count(), + ) + + # test slice + self.assert_eq( + cdf.select(CF.slice(cdf.a, 1, 2), CF.slice("c", 2, 3)).toPandas(), + sdf.select(SF.slice(sdf.a, 1, 2), SF.slice("c", 2, 3)).toPandas(), + ) + + # test sort_array + self.assert_eq( + cdf.select(CF.sort_array(cdf.a, True), CF.sort_array("c", False)).toPandas(), + sdf.select(SF.sort_array(sdf.a, True), SF.sort_array("c", False)).toPandas(), + ) + + # test struct + self.compare_by_show( + cdf.select(CF.struct(cdf.a, "d", "e", cdf.f)), + sdf.select(SF.struct(sdf.a, "d", "e", sdf.f)), + ) + + def test_map_collection_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + (MAP('a', 'ab'), MAP('x', 'ab'), MAP(1, 2, 3, 4), 1, 'a', ARRAY(1, 2), ARRAY('X', 'Y')), + (MAP('x', 'yz'), MAP('c', NULL), NULL, 2, 'x', ARRAY(3, 4), ARRAY('A', 'B')), + (MAP('c', 'de'), NULL, MAP(-1, NULL, -3, -4), -3, 'c', NULL, ARRAY('Z')) + AS tab(a, b, c, e, f, g, h) + """ + # +---------+-----------+----------------------+---+---+------+------+ + # | a| b| c| e| f| g| h| + # +---------+-----------+----------------------+---+---+------+------+ + # |{a -> ab}| {x -> ab}| {1 -> 2, 3 -> 4}| 1| a|[1, 2]|[X, Y]| + # |{x -> yz}|{c -> null}| null| 2| x|[3, 4]|[A, B]| + # |{c -> de}| null|{-1 -> null, -3 -> -4}| -3| c| null| [Z]| + # +---------+-----------+----------------------+---+---+------+------+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test map_concat + self.compare_by_show( + cdf.select(CF.map_concat(cdf.a, "b")), + sdf.select(SF.map_concat(sdf.a, "b")), + ) + + # test map_contains_key + self.compare_by_show( + cdf.select(CF.map_contains_key(cdf.a, "a"), 
CF.map_contains_key("c", 3)), + sdf.select(SF.map_contains_key(sdf.a, "a"), SF.map_contains_key("c", 3)), + ) + + # test map_entries + self.compare_by_show( + cdf.select(CF.map_entries(cdf.a), CF.map_entries("b")), + sdf.select(SF.map_entries(sdf.a), SF.map_entries("b")), + ) + + # test map_from_arrays + self.compare_by_show( + cdf.select(CF.map_from_arrays(cdf.g, "h")), + sdf.select(SF.map_from_arrays(sdf.g, "h")), + ) + + # test map_keys and map_values + self.compare_by_show( + cdf.select(CF.map_keys(cdf.a), CF.map_values("b")), + sdf.select(SF.map_keys(sdf.a), SF.map_values("b")), + ) + + # test size + self.assert_eq( + cdf.select(CF.size(cdf.a), CF.size("c")).toPandas(), + sdf.select(SF.size(sdf.a), SF.size("c")).toPandas(), + ) + + def test_generator_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), + MAP(1, 2, 3, 4), 1, FLOAT(2.0), 3), + (ARRAY('x', NULL), NULL, ARRAY(1, 3), + NULL, 3, FLOAT(4.0), 5), + (NULL, ARRAY(-1, -2, -3), Array(), + MAP(-1, NULL, -3, -4), 7, FLOAT('NAN'), 9) + AS tab(a, b, c, d, e, f, g) + """ + # +---------+------------+------------+----------------------+---+---+---+ + # | a| b| c| d| e| f| g| + # +---------+------------+------------+----------------------+---+---+---+ + # | [a, ab]| [1, 2, 3]|[1, null, 3]| {1 -> 2, 3 -> 4}| 1|2.0| 3| + # |[x, null]| null| [1, 3]| null| 3|4.0| 5| + # | null|[-1, -2, -3]| []|{-1 -> null, -3 -> -4}| 7|NaN| 9| + # +---------+------------+------------+----------------------+---+---+---+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test explode with arrays + self.assert_eq( + cdf.select(CF.explode(cdf.a), CF.col("b")).toPandas(), + sdf.select(SF.explode(sdf.a), SF.col("b")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.explode("a"), "b").toPandas(), + sdf.select(SF.explode("a"), "b").toPandas(), + ) + # test explode with maps + self.assert_eq( + cdf.select(CF.explode(cdf.d), CF.col("c")).toPandas(), + sdf.select(SF.explode(sdf.d), SF.col("c")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.explode("d"), "c").toPandas(), + sdf.select(SF.explode("d"), "c").toPandas(), + ) + + # test explode_outer with arrays + self.assert_eq( + cdf.select(CF.explode_outer(cdf.a), CF.col("b")).toPandas(), + sdf.select(SF.explode_outer(sdf.a), SF.col("b")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.explode_outer("a"), "b").toPandas(), + sdf.select(SF.explode_outer("a"), "b").toPandas(), + ) + # test explode_outer with maps + self.assert_eq( + cdf.select(CF.explode_outer(cdf.d), CF.col("c")).toPandas(), + sdf.select(SF.explode_outer(sdf.d), SF.col("c")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.explode_outer("d"), "c").toPandas(), + sdf.select(SF.explode_outer("d"), "c").toPandas(), + ) + + # test flatten + self.assert_eq( + cdf.select(CF.flatten(CF.array("b", cdf.c)), CF.col("b")).toPandas(), + sdf.select(SF.flatten(SF.array("b", sdf.c)), SF.col("b")).toPandas(), + ) + + # test inline + self.assert_eq( + cdf.select(CF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X")) + .select(CF.inline("X")) + .toPandas(), + sdf.select(SF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X")) + .select(SF.inline("X")) + .toPandas(), + ) + + # test inline_outer + self.assert_eq( + cdf.select(CF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, f))").alias("X")) + .select(CF.inline_outer("X")) + .toPandas(), + sdf.select(SF.expr("ARRAY(STRUCT(e, f), STRUCT(g AS e, 
f))").alias("X")) + .select(SF.inline_outer("X")) + .toPandas(), + ) + + # test posexplode with arrays + self.assert_eq( + cdf.select(CF.posexplode(cdf.a), CF.col("b")).toPandas(), + sdf.select(SF.posexplode(sdf.a), SF.col("b")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.posexplode("a"), "b").toPandas(), + sdf.select(SF.posexplode("a"), "b").toPandas(), + ) + # test posexplode with maps + self.assert_eq( + cdf.select(CF.posexplode(cdf.d), CF.col("c")).toPandas(), + sdf.select(SF.posexplode(sdf.d), SF.col("c")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.posexplode("d"), "c").toPandas(), + sdf.select(SF.posexplode("d"), "c").toPandas(), + ) + + # test posexplode_outer with arrays + self.assert_eq( + cdf.select(CF.posexplode_outer(cdf.a), CF.col("b")).toPandas(), + sdf.select(SF.posexplode_outer(sdf.a), SF.col("b")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.posexplode_outer("a"), "b").toPandas(), + sdf.select(SF.posexplode_outer("a"), "b").toPandas(), + ) + # test posexplode_outer with maps + self.assert_eq( + cdf.select(CF.posexplode_outer(cdf.d), CF.col("c")).toPandas(), + sdf.select(SF.posexplode_outer(sdf.d), SF.col("c")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.posexplode_outer("d"), "c").toPandas(), + sdf.select(SF.posexplode_outer("d"), "c").toPandas(), + ) + + def test_csv_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + ('1,2,3', 'a,b,5.0'), + ('3,4,5', 'x,y,6.0') + AS tab(a, b) + """ + # +-----+-------+ + # | a| b| + # +-----+-------+ + # |1,2,3|a,b,5.0| + # |3,4,5|x,y,6.0| + # +-----+-------+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test from_csv + self.compare_by_show( + cdf.select( + CF.from_csv(cdf.a, "a INT, b INT, c INT"), + CF.from_csv("b", "x STRING, y STRING, z DOUBLE"), + ), + sdf.select( + SF.from_csv(sdf.a, "a INT, b INT, c INT"), + SF.from_csv("b", "x STRING, y STRING, z DOUBLE"), + ), + ) + self.compare_by_show( + cdf.select( + CF.from_csv(cdf.a, CF.lit("a INT, b INT, c INT")), + CF.from_csv("b", CF.lit("x STRING, y STRING, z DOUBLE")), + ), + sdf.select( + SF.from_csv(sdf.a, SF.lit("a INT, b INT, c INT")), + SF.from_csv("b", SF.lit("x STRING, y STRING, z DOUBLE")), + ), + ) + + # test schema_of_csv + self.assert_eq( + cdf.select(CF.schema_of_csv(CF.lit('{"a": 0}'))).toPandas(), + sdf.select(SF.schema_of_csv(SF.lit('{"a": 0}'))).toPandas(), + ) + + # test to_csv + self.compare_by_show( + cdf.select(CF.to_csv(CF.struct(CF.lit("a"), CF.lit("b")))), + sdf.select(SF.to_csv(SF.struct(SF.lit("a"), SF.lit("b")))), + ) + + def test_json_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + ('{"a": 1}', '[1, 2, 3]', '{"f1": "value1", "f2": "value2"}'), + ('{"a": 0}', '[4, 5, 6]', '{"f1": "value12"}') + AS tab(a, b, c) + """ + # +--------+---------+--------------------------------+ + # | a| b| c| + # +--------+---------+--------------------------------+ + # |{"a": 1}|[1, 2, 3]|{"f1": "value1", "f2": "value2"}| + # |{"a": 0}|[4, 5, 6]| {"f1": "value12"}| + # +--------+---------+--------------------------------+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test from_json + for schema in [ + "a INT", + "MAP", + # StructType([StructField("a", IntegerType())]), + # ArrayType(StructType([StructField("a", IntegerType())])), + ]: + self.compare_by_show( + cdf.select(CF.from_json(cdf.a, schema)), + 
sdf.select(SF.from_json(sdf.a, schema)), + ) + self.compare_by_show( + cdf.select(CF.from_json("a", schema)), + sdf.select(SF.from_json("a", schema)), + ) + + for schema in [ + "ARRAY", + # ArrayType(IntegerType()), + ]: + self.compare_by_show( + cdf.select(CF.from_json(cdf.b, schema)), + sdf.select(SF.from_json(sdf.b, schema)), + ) + self.compare_by_show( + cdf.select(CF.from_json("b", schema)), + sdf.select(SF.from_json("b", schema)), + ) + + # test get_json_object + self.assert_eq( + cdf.select( + CF.get_json_object("c", "$.f1"), + CF.get_json_object(cdf.c, "$.f2"), + ).toPandas(), + sdf.select( + SF.get_json_object("c", "$.f1"), + SF.get_json_object(sdf.c, "$.f2"), + ).toPandas(), + ) + + # test json_tuple + self.assert_eq( + cdf.select(CF.json_tuple("c", "f1", "f2")).toPandas(), + sdf.select(SF.json_tuple("c", "f1", "f2")).toPandas(), + ) + self.assert_eq( + cdf.select(CF.json_tuple(cdf.c, "f1", "f2")).toPandas(), + sdf.select(SF.json_tuple(sdf.c, "f1", "f2")).toPandas(), + ) + + # test schema_of_json + self.assert_eq( + cdf.select(CF.schema_of_json(CF.lit('{"a": 0}'))).toPandas(), + sdf.select(SF.schema_of_json(SF.lit('{"a": 0}'))).toPandas(), + ) + + # test to_json + self.compare_by_show( + cdf.select(CF.to_json(CF.struct(CF.lit("a"), CF.lit("b")))), + sdf.select(SF.to_json(SF.struct(SF.lit("a"), SF.lit("b")))), + ) + def test_string_functions(self): from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF @@ -456,6 +970,7 @@ def test_string_functions(self): (CF.ltrim, SF.ltrim), (CF.rtrim, SF.rtrim), (CF.trim, SF.trim), + (CF.reverse, SF.reverse), ]: self.assert_eq( cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(), @@ -467,28 +982,229 @@ def test_string_functions(self): sdf.select(SF.concat_ws("-", sdf.a, "c")).toPandas(), ) - # Disable the test for "decode" because of inconsistent column names, - # as shown below - # - # >>> sdf.select(SF.decode("c", "UTF-8")).toPandas() - # stringdecode(c, UTF-8) - # 0 None - # 1 ab - # >>> cdf.select(CF.decode("c", "UTF-8")).toPandas() - # decode(c, UTF-8) - # 0 None - # 1 ab - # - # self.assert_eq( - # cdf.select(CF.decode("c", "UTF-8")).toPandas(), - # sdf.select(SF.decode("c", "UTF-8")).toPandas(), - # ) + self.assert_eq( + cdf.select(CF.decode("c", "UTF-8")).toPandas(), + sdf.select(SF.decode("c", "UTF-8")).toPandas(), + ) self.assert_eq( cdf.select(CF.encode("c", "UTF-8")).toPandas(), sdf.select(SF.encode("c", "UTF-8")).toPandas(), ) + # TODO(SPARK-41283): To compare toPandas for test cases with dtypes marked + def test_date_ts_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT * FROM VALUES + ('1997/02/28 10:30:00', '2023/03/01 06:00:00', 'JST', 1428476400, 2020, 12, 6), + ('2000/01/01 04:30:05', '2020/05/01 12:15:00', 'PST', 1403892395, 2022, 12, 6) + AS tab(ts1, ts2, tz, seconds, Y, M, D) + """ + # +-------------------+-------------------+---+----------+----+---+---+ + # | ts1| ts2| tz| seconds| Y| M| D| + # +-------------------+-------------------+---+----------+----+---+---+ + # |1997/02/28 10:30:00|2023/03/01 06:00:00|JST|1428476400|2020| 12| 6| + # |2000/01/01 04:30:05|2020/05/01 12:15:00|PST|1403892395|2022| 12| 6| + # +-------------------+-------------------+---+----------+----+---+---+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # With no parameters + for cfunc, sfunc in [ + (CF.current_date, SF.current_date), + ]: + self.assert_eq( + cdf.select(cfunc()).toPandas(), + 
sdf.select(sfunc()).toPandas(), + ) + + # current_timestamp + # [left]: datetime64[ns, America/Los_Angeles] + # [right]: datetime64[ns] + # TODO: compare the return values after resolving dtypes difference + self.assertEqual( + cdf.select(CF.current_timestamp()).count(), + sdf.select(SF.current_timestamp()).count(), + ) + + # localtimestamp + s_pdf0 = sdf.select(SF.localtimestamp()).toPandas() + c_pdf = cdf.select(CF.localtimestamp()).toPandas() + s_pdf1 = sdf.select(SF.localtimestamp()).toPandas() + self.assert_eq(s_pdf0 < c_pdf, c_pdf < s_pdf1) + + # With only column parameter + for cfunc, sfunc in [ + (CF.year, SF.year), + (CF.quarter, SF.quarter), + (CF.month, SF.month), + (CF.dayofweek, SF.dayofweek), + (CF.dayofmonth, SF.dayofmonth), + (CF.dayofyear, SF.dayofyear), + (CF.hour, SF.hour), + (CF.minute, SF.minute), + (CF.second, SF.second), + (CF.weekofyear, SF.weekofyear), + (CF.last_day, SF.last_day), + (CF.unix_timestamp, SF.unix_timestamp), + ]: + self.assert_eq( + cdf.select(cfunc(cdf.ts1)).toPandas(), + sdf.select(sfunc(sdf.ts1)).toPandas(), + ) + + # With format parameter + for cfunc, sfunc in [ + (CF.date_format, SF.date_format), + (CF.to_date, SF.to_date), + ]: + self.assert_eq( + cdf.select(cfunc(cdf.ts1, format="yyyy-MM-dd")).toPandas(), + sdf.select(sfunc(sdf.ts1, format="yyyy-MM-dd")).toPandas(), + ) + self.compare_by_show( + # [left]: datetime64[ns, America/Los_Angeles] + # [right]: datetime64[ns] + cdf.select(CF.to_timestamp(cdf.ts1, format="yyyy-MM-dd")), + sdf.select(SF.to_timestamp(sdf.ts1, format="yyyy-MM-dd")), + ) + + # With tz parameter + for cfunc, sfunc in [ + (CF.from_utc_timestamp, SF.from_utc_timestamp), + (CF.to_utc_timestamp, SF.to_utc_timestamp), + # [left]: datetime64[ns, America/Los_Angeles] + # [right]: datetime64[ns] + ]: + self.compare_by_show( + cdf.select(cfunc(cdf.ts1, tz=cdf.tz)), + sdf.select(sfunc(sdf.ts1, tz=sdf.tz)), + ) + + # With numeric parameter + for cfunc, sfunc in [ + (CF.date_add, SF.date_add), + (CF.date_sub, SF.date_sub), + (CF.add_months, SF.add_months), + ]: + self.assert_eq( + cdf.select(cfunc(cdf.ts1, cdf.D)).toPandas(), + sdf.select(sfunc(sdf.ts1, sdf.D)).toPandas(), + ) + + # With another timestamp as parameter + for cfunc, sfunc in [ + (CF.datediff, SF.datediff), + (CF.months_between, SF.months_between), + ]: + self.assert_eq( + cdf.select(cfunc(cdf.ts1, cdf.ts2)).toPandas(), + sdf.select(sfunc(sdf.ts1, sdf.ts2)).toPandas(), + ) + + # With seconds parameter + self.compare_by_show( + # [left]: datetime64[ns, America/Los_Angeles] + # [right]: datetime64[ns] + cdf.select(CF.timestamp_seconds(cdf.seconds)), + sdf.select(SF.timestamp_seconds(sdf.seconds)), + ) + + # make_date + self.assert_eq( + cdf.select(CF.make_date(cdf.Y, cdf.M, cdf.D)).toPandas(), + sdf.select(SF.make_date(sdf.Y, sdf.M, sdf.D)).toPandas(), + ) + + # date_trunc + self.compare_by_show( + # [left]: datetime64[ns, America/Los_Angeles] + # [right]: datetime64[ns] + cdf.select(CF.date_trunc("day", cdf.ts1)), + sdf.select(SF.date_trunc("day", sdf.ts1)), + ) + + # trunc + self.assert_eq( + cdf.select(CF.trunc(cdf.ts1, "year")).toPandas(), + sdf.select(SF.trunc(sdf.ts1, "year")).toPandas(), + ) + + # next_day + self.assert_eq( + cdf.select(CF.next_day(cdf.ts1, "Mon")).toPandas(), + sdf.select(SF.next_day(sdf.ts1, "Mon")).toPandas(), + ) + + def test_misc_functions(self): + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF + + query = """ + SELECT a, b, c, BINARY(c) as d FROM VALUES + (0, float("NAN"), 'x'), (1, NULL, 'y'), (1, 
2.1, 'z'), (0, 0.5, NULL) + AS tab(a, b, c) + """ + # +---+----+----+----+ + # | a| b| c| d| + # +---+----+----+----+ + # | 0| NaN| x|[78]| + # | 1|null| y|[79]| + # | 1| 2.1| z|[7A]| + # | 0| 0.5|null|null| + # +---+----+----+----+ + + cdf = self.connect.sql(query) + sdf = self.spark.sql(query) + + # test assert_true + with self.assertRaises(grpc.RpcError): + cdf.select(CF.assert_true(cdf.a > 0, "a should be positive!")).show() + + # test raise_error + with self.assertRaises(grpc.RpcError): + cdf.select(CF.raise_error("a should be positive!")).show() + + # test crc32 + self.assert_eq( + cdf.select(CF.crc32(cdf.d)).toPandas(), + sdf.select(SF.crc32(sdf.d)).toPandas(), + ) + + # test hash + self.assert_eq( + cdf.select(CF.hash(cdf.a, "b", cdf.c)).toPandas(), + sdf.select(SF.hash(sdf.a, "b", sdf.c)).toPandas(), + ) + + # test xxhash64 + self.assert_eq( + cdf.select(CF.xxhash64(cdf.a, "b", cdf.c)).toPandas(), + sdf.select(SF.xxhash64(sdf.a, "b", sdf.c)).toPandas(), + ) + + # test md5 + self.assert_eq( + cdf.select(CF.md5(cdf.d), CF.md5("c")).toPandas(), + sdf.select(SF.md5(sdf.d), SF.md5("c")).toPandas(), + ) + + # test sha1 + self.assert_eq( + cdf.select(CF.sha1(cdf.d), CF.sha1("c")).toPandas(), + sdf.select(SF.sha1(sdf.d), SF.sha1("c")).toPandas(), + ) + + # test sha2 + self.assert_eq( + cdf.select(CF.sha2(cdf.c, 256), CF.sha2("d", 512)).toPandas(), + sdf.select(SF.sha2(sdf.c, 256), SF.sha2("d", 512)).toPandas(), + ) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_function import * # noqa: F401 diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index b9695eea78..e0cd54195f 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -14,13 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import cast import unittest -from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +from pyspark.testing.connectutils import ( + PlanOnlyTestFixture, + should_test_connect, + connect_requirement_message, +) -if have_pandas: +if should_test_connect: import pyspark.sql.connect.proto as proto from pyspark.sql.connect.column import Column from pyspark.sql.connect.plan import WriteOperation @@ -29,7 +31,7 @@ from pyspark.sql.types import StringType -@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) +@unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" @@ -168,6 +170,64 @@ def test_replace(self): self.assertEqual(plan.root.replace.replacements[1].old_value.string, "Bob") self.assertEqual(plan.root.replace.replacements[1].new_value.string, "B") + def test_unpivot(self): + df = self.connect.readTable(table_name=self.tbl_name) + + plan = ( + df.filter(df.col_name > 3) + .unpivot(["id"], ["name"], "variable", "value") + ._plan.to_proto(self.connect) + ) + self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids)) + self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id") + self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values)) + self.assertEqual( + plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name" + ) + self.assertEqual(plan.root.unpivot.variable_column_name, "variable") + self.assertEqual(plan.root.unpivot.value_column_name, "value") + + plan = ( + df.filter(df.col_name > 3) + .unpivot(["id"], None, "variable", "value") + ._plan.to_proto(self.connect) + ) + self.assertTrue(len(plan.root.unpivot.ids) == 1) + self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids)) + self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id") + self.assertTrue(len(plan.root.unpivot.values) == 0) + self.assertEqual(plan.root.unpivot.variable_column_name, "variable") + self.assertEqual(plan.root.unpivot.value_column_name, "value") + + def test_melt(self): + df = self.connect.readTable(table_name=self.tbl_name) + + plan = ( + df.filter(df.col_name > 3) + .melt(["id"], ["name"], "variable", "value") + ._plan.to_proto(self.connect) + ) + self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids)) + self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id") + self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values)) + self.assertEqual( + plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name" + ) + self.assertEqual(plan.root.unpivot.variable_column_name, "variable") + self.assertEqual(plan.root.unpivot.value_column_name, "value") + + plan = ( + df.filter(df.col_name > 3) + .melt(["id"], [], "variable", "value") + ._plan.to_proto(self.connect) + ) + self.assertTrue(len(plan.root.unpivot.ids) == 1) + self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids)) + self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id") + self.assertTrue(len(plan.root.unpivot.values) == 0) + self.assertEqual(plan.root.unpivot.variable_column_name, "variable") + self.assertEqual(plan.root.unpivot.value_column_name, 
"value") + def test_summary(self): df = self.connect.readTable(table_name=self.tbl_name) plan = df.filter(df.col_name > 3).summary()._plan.to_proto(self.connect) @@ -379,7 +439,7 @@ def test_simple_udf(self): self.assertIsNotNone(u) expr = u("ThisCol", "ThatCol", "OtherCol") self.assertTrue(isinstance(expr, Column)) - self.assertTrue(isinstance(cast(Column, expr)._expr, UserDefinedFunction)) + self.assertTrue(isinstance(expr._expr, UserDefinedFunction)) u_plan = expr.to_plan(self.connect) self.assertIsNotNone(u_plan) diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py index 01d1819fdc..7f8153f7fc 100644 --- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py +++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py @@ -14,19 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import cast import unittest -from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +from pyspark.testing.connectutils import ( + PlanOnlyTestFixture, + should_test_connect, + connect_requirement_message, +) -if have_pandas: +if should_test_connect: from pyspark.sql.connect.functions import col from pyspark.sql.connect.plan import Read import pyspark.sql.connect.proto as proto -@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) +@unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectToProtoSuite(PlanOnlyTestFixture): def test_select_with_columns_and_strings(self): df = self.connect.with_plan(Read("table")) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py index b3f4c7331d..c3557f4eb5 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py @@ -407,7 +407,7 @@ def merge_pandas(lft, rgt): from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py index 7f27671cfe..0044ae3c72 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py @@ -744,7 +744,7 @@ def my_pandas_udf(pdf): from pyspark.sql.tests.pandas.test_pandas_grouped_map import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py index e75148e524..655f0bf151 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py @@ -240,7 +240,7 @@ def assert_test(): from pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index 7f996ca55a..243cc36c67 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -207,7 +207,7 @@ def func(iterator): from pyspark.sql.tests.pandas.test_pandas_map import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 077db2971e..d6d861edb3 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -296,7 +296,7 @@ def noop(s: pd.Series) -> pd.Series: from pyspark.sql.tests.pandas.test_pandas_udf import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py index aa844fc5fd..155695f497 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py @@ -551,7 +551,7 @@ def mean(x): from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index 6580f839a8..a5b3bfc164 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -1333,7 +1333,7 @@ def udf(x): from pyspark.sql.tests.pandas.test_pandas_udf_scalar import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py index 8c77ed4b77..3cdf83e2d0 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints.py @@ -238,7 +238,7 @@ def test_scalar_udf_type_hint(self): df = self.spark.range(10).selectExpr("id", "id as v") def plus_one(v: Union[pd.Series, pd.DataFrame]) -> pd.Series: - return v + 1 # type: ignore[return-value] + return v + 1 plus_one = pandas_udf("long")(plus_one) actual = df.select(plus_one(df.v).alias("plus_one")) @@ -360,7 +360,7 @@ def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd. 
from pyspark.sql.tests.pandas.test_pandas_udf_typehints import * # noqa: #401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py index a6d3bd608d..9b6751564c 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_typehints_with_future_annotations.py @@ -241,7 +241,7 @@ def test_scalar_udf_type_hint(self): df = self.spark.range(10).selectExpr("id", "id as v") def plus_one(v: Union[pd.Series, pd.DataFrame]) -> pd.Series: - return v + 1 # type: ignore[return-value] + return v + 1 plus_one = pandas_udf("long")(plus_one) actual = df.select(plus_one(df.v).alias("plus_one")) @@ -367,7 +367,7 @@ def func(col: "Union[pd.Series, pd.DataFrame]", *, col2: "pd.DataFrame") -> "pd. from pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations import * # noqa: #401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py index 07e10a58d2..596742a23b 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py @@ -398,7 +398,7 @@ def test_bounded_mixed(self): from pyspark.sql.tests.pandas.test_pandas_udf_window import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index f170787ff7..a67e493a7c 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -653,7 +653,7 @@ def test_streaming_write_to_table(self): from pyspark.sql.tests.streaming.test_streaming import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index de34565254..c6667e2517 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -299,7 +299,7 @@ def onQueryTerminated(self, event): from pyspark.sql.tests.streaming.test_streaming_listener import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py index a4c948fea3..6166cc5dcc 100644 --- a/python/pyspark/sql/tests/test_arrow_map.py +++ b/python/pyspark/sql/tests/test_arrow_map.py @@ -35,7 +35,7 @@ @unittest.skipIf( not have_pandas or not have_pyarrow, - pandas_requirement_message or pyarrow_requirement_message, # type: ignore[arg-type] + pandas_requirement_message or 
pyarrow_requirement_message, ) class MapInArrowTests(ReusedSQLTestCase): @classmethod @@ -130,7 +130,7 @@ def test_self_join(self): from pyspark.sql.tests.test_arrow_map import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_catalog.py b/python/pyspark/sql/tests/test_catalog.py index 24cd67251a..2eccfab72f 100644 --- a/python/pyspark/sql/tests/test_catalog.py +++ b/python/pyspark/sql/tests/test_catalog.py @@ -398,7 +398,7 @@ def test_refresh_table(self): from pyspark.sql.tests.test_catalog import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 2c4730fd81..236fb1b539 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -192,7 +192,7 @@ def test_drop_fields(self): from pyspark.sql.tests.test_column import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py index 4ea160818d..a8fa59c036 100644 --- a/python/pyspark/sql/tests/test_conf.py +++ b/python/pyspark/sql/tests/test_conf.py @@ -48,7 +48,7 @@ def test_conf(self): from pyspark.sql.tests.test_conf import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index 508a829975..b381833314 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -193,7 +193,7 @@ def test_get_or_create(self): from pyspark.sql.tests.test_context import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index 30c1855622..80ab8a3316 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -198,7 +198,7 @@ def test_ignore_column_of_all_nulls(self): from pyspark.sql.tests.test_datasources import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 55ef012b6d..94cb3c4f1e 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -1156,7 +1156,7 @@ def test_map_functions(self): from pyspark.sql.tests.test_functions import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py index 19f1a0148b..19e1228d25 100644 --- a/python/pyspark/sql/tests/test_group.py +++ 
b/python/pyspark/sql/tests/test_group.py @@ -41,7 +41,7 @@ def test_aggregator(self): from pyspark.sql.tests.test_group import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py index d182bafd8b..22a0e92e81 100644 --- a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py +++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py @@ -60,7 +60,7 @@ def test_pandas(col1): from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 2e1bdb4424..4aa24fc2be 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -217,7 +217,7 @@ def test_partitioning_functions(self): from pyspark.sql.tests.test_readwriter import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index 4e9d347da9..e8017cfd38 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -145,7 +145,7 @@ def test_bytes_as_binary_type(self): from pyspark.sql.tests.test_serde import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index 80c05e1a3c..dacaff4d2d 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -404,7 +404,7 @@ def test_use_custom_class_for_extensions(self): from pyspark.sql.tests.test_session import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index b1d2eccea4..bc7aafe5f0 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -1405,7 +1405,7 @@ def test_row_without_field_sorting(self): from pyspark.sql.tests.test_types import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py index 954fe9f24a..080d88788b 100644 --- a/python/pyspark/sql/tests/test_utils.py +++ b/python/pyspark/sql/tests/test_utils.py @@ -81,7 +81,7 @@ def test_get_error_class_state(self): from pyspark.sql.tests.test_utils import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/streaming/tests/test_context.py b/python/pyspark/streaming/tests/test_context.py index 1e2c153176..1afcc90b9e 100644 --- 
a/python/pyspark/streaming/tests/test_context.py +++ b/python/pyspark/streaming/tests/test_context.py @@ -176,7 +176,7 @@ def test_await_termination_or_timeout(self): from pyspark.streaming.tests.test_context import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/streaming/tests/test_dstream.py b/python/pyspark/streaming/tests/test_dstream.py index a52d08a1b1..d37e64affb 100644 --- a/python/pyspark/streaming/tests/test_dstream.py +++ b/python/pyspark/streaming/tests/test_dstream.py @@ -698,7 +698,7 @@ def check_output(n): from pyspark.streaming.tests.test_dstream import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/streaming/tests/test_kinesis.py b/python/pyspark/streaming/tests/test_kinesis.py index 7b09f5b8f5..7efd7a7d0c 100644 --- a/python/pyspark/streaming/tests/test_kinesis.py +++ b/python/pyspark/streaming/tests/test_kinesis.py @@ -110,7 +110,7 @@ def get_output(_, rdd): from pyspark.streaming.tests.test_kinesis import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/streaming/tests/test_listener.py b/python/pyspark/streaming/tests/test_listener.py index f881b2d201..aeec278b38 100644 --- a/python/pyspark/streaming/tests/test_listener.py +++ b/python/pyspark/streaming/tests/test_listener.py @@ -152,7 +152,7 @@ def func(dstream): from pyspark.streaming.tests.test_listener import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 1979b6eb72..efc118b572 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -14,52 +14,66 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import typing import os -from typing import Any, Dict, Optional import functools import unittest -from pyspark.testing.sqlutils import have_pandas +from pyspark.testing.sqlutils import ( + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) -if have_pandas: + +grpc_requirement_message = None +try: + import grpc +except ImportError as e: + grpc_requirement_message = str(e) +have_grpc = grpc_requirement_message is None + +connect_not_compiled_message = None +if have_pandas and have_pyarrow and have_grpc: from pyspark.sql.connect import DataFrame from pyspark.sql.connect.plan import Read, Range, SQL from pyspark.testing.utils import search_jar - from pyspark.sql.connect.plan import LogicalPlan from pyspark.sql.connect.session import SparkSession connect_jar = search_jar("connector/connect/server", "spark-connect-assembly-", "spark-connect") + existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + jars_args = "--jars %s" % connect_jar + plugin_args = "--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin" + os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args]) else: - connect_jar = None - - -if connect_jar is None: - connect_requirement_message = ( + connect_not_compiled_message = ( "Skipping all Spark Connect Python tests as the optional Spark Connect project was " "not compiled into a JAR. To run these tests, you need to build Spark with " "'build/sbt package' or 'build/mvn package' before running this test." ) -else: - existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") - jars_args = "--jars %s" % connect_jar - plugin_args = "--conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin" - os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args]) - connect_requirement_message = None # type: ignore -should_test_connect = connect_requirement_message is None and have_pandas + +connect_requirement_message = ( + pandas_requirement_message + or pyarrow_requirement_message + or grpc_requirement_message + or connect_not_compiled_message +) +should_test_connect: str = typing.cast(str, connect_requirement_message is None) class MockRemoteSession: - def __init__(self) -> None: - self.hooks: Dict[str, Any] = {} + def __init__(self): + self.hooks = {} - def set_hook(self, name: str, hook: Any) -> None: + def set_hook(self, name, hook): self.hooks[name] = hook - def drop_hook(self, name: str) -> None: + def drop_hook(self, name): self.hooks.pop(name) - def __getattr__(self, item: str) -> Any: + def __getattr__(self, item): if item not in self.hooks: raise LookupError(f"{item} is not defined as a method hook in MockRemoteSession") return functools.partial(self.hooks[item]) @@ -67,43 +81,36 @@ def __getattr__(self, item: str) -> Any: @unittest.skipIf(not should_test_connect, connect_requirement_message) class PlanOnlyTestFixture(unittest.TestCase): - - connect: "MockRemoteSession" - if have_pandas: - session: SparkSession - @classmethod - def _read_table(cls, table_name: str) -> "DataFrame": - return DataFrame.withPlan(Read(table_name), cls.connect) # type: ignore + def _read_table(cls, table_name): + return DataFrame.withPlan(Read(table_name), cls.connect) @classmethod - def _udf_mock(cls, *args, **kwargs) -> str: + def _udf_mock(cls, *args, **kwargs): return "internal_name" @classmethod def _session_range( cls, - start: int, - end: int, - step: int = 1, - num_partitions: Optional[int] = None, - ) -> "DataFrame": - return DataFrame.withPlan( - 
Range(start, end, step, num_partitions), cls.connect # type: ignore - ) + start, + end, + step=1, + num_partitions=None, + ): + return DataFrame.withPlan(Range(start, end, step, num_partitions), cls.connect) @classmethod - def _session_sql(cls, query: str) -> "DataFrame": - return DataFrame.withPlan(SQL(query), cls.connect) # type: ignore + def _session_sql(cls, query): + return DataFrame.withPlan(SQL(query), cls.connect) if have_pandas: @classmethod - def _with_plan(cls, plan: LogicalPlan) -> "DataFrame": - return DataFrame.withPlan(plan, cls.connect) # type: ignore + def _with_plan(cls, plan): + return DataFrame.withPlan(plan, cls.connect) @classmethod - def setUpClass(cls: Any) -> None: + def setUpClass(cls): cls.connect = MockRemoteSession() cls.session = SparkSession.builder.remote().getOrCreate() cls.tbl_name = "test_connect_plan_only_table_1" @@ -115,7 +122,7 @@ def setUpClass(cls: Any) -> None: cls.connect.set_hook("with_plan", cls._with_plan) @classmethod - def tearDownClass(cls: Any) -> None: + def tearDownClass(cls): cls.connect.drop_hook("register_udf") cls.connect.drop_hook("readTable") cls.connect.drop_hook("range") diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py index ad2f74e8af..6a828f1002 100644 --- a/python/pyspark/testing/pandasutils.py +++ b/python/pyspark/testing/pandasutils.py @@ -22,11 +22,6 @@ from contextlib import contextmanager from distutils.version import LooseVersion -import pandas as pd -from pandas.api.types import is_list_like # type: ignore[attr-defined] -from pandas.core.dtypes.common import is_numeric_dtype -from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal - from pyspark import pandas as ps from pyspark.pandas.frame import DataFrame from pyspark.pandas.indexes import Index @@ -36,7 +31,7 @@ tabulate_requirement_message = None try: - from tabulate import tabulate # noqa: F401 + from tabulate import tabulate except ImportError as e: # If tabulate requirement is not satisfied, skip related tests. tabulate_requirement_message = str(e) @@ -44,7 +39,7 @@ matplotlib_requirement_message = None try: - import matplotlib # noqa: F401 + import matplotlib except ImportError as e: # If matplotlib requirement is not satisfied, skip related tests. matplotlib_requirement_message = str(e) @@ -52,7 +47,7 @@ plotly_requirement_message = None try: - import plotly # noqa: F401 + import plotly except ImportError as e: # If plotly requirement is not satisfied, skip related tests. plotly_requirement_message = str(e) @@ -72,6 +67,10 @@ def convert_str_to_lambda(self, func): return lambda x: getattr(x, func)() def assertPandasEqual(self, left, right, check_exact=True): + import pandas as pd + from pandas.core.dtypes.common import is_numeric_dtype + from pandas.testing import assert_frame_equal, assert_index_equal, assert_series_equal + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): try: if LooseVersion(pd.__version__) >= LooseVersion("1.1"): @@ -157,6 +156,8 @@ def assertPandasAlmostEqual(self, left, right): - Compare floats rounding to the number of decimal places, 7 after dropping missing values (NaN, NaT, None) """ + import pandas as pd + if isinstance(left, pd.DataFrame) and isinstance(right, pd.DataFrame): msg = ( "DataFrames are not almost equal: " @@ -217,6 +218,9 @@ def assert_eq(self, left, right, check_exact=True, almost=False): :param almost: if this is enabled, the comparison is delegated to `unittest`'s `assertAlmostEqual`. See its documentation for more details. 
""" + import pandas as pd + from pandas.api.types import is_list_like + lobj = self._to_pandas(left) robj = self._to_pandas(right) if isinstance(lobj, (pd.DataFrame, pd.Series, pd.Index)): diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py index ff9f3e4f16..79b6b4fa91 100644 --- a/python/pyspark/tests/test_appsubmit.py +++ b/python/pyspark/tests/test_appsubmit.py @@ -298,7 +298,7 @@ def test_user_configuration(self): from pyspark.tests.test_appsubmit import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py index bc4587ffa6..90d3caa736 100644 --- a/python/pyspark/tests/test_broadcast.py +++ b/python/pyspark/tests/test_broadcast.py @@ -188,7 +188,7 @@ def random_bytes(n): from pyspark.tests.test_broadcast import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_conf.py b/python/pyspark/tests/test_conf.py index 6a7c7a05a9..cc9ff82909 100644 --- a/python/pyspark/tests/test_conf.py +++ b/python/pyspark/tests/test_conf.py @@ -36,7 +36,7 @@ def test_memory_conf(self): from pyspark.tests.test_conf import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py index 1b63869562..d819656f3b 100644 --- a/python/pyspark/tests/test_context.py +++ b/python/pyspark/tests/test_context.py @@ -97,7 +97,7 @@ def test_add_py_file(self): # this job fails due to `userlibrary` not being on the Python path: # disable logging in log4j temporarily def func(x): - from userlibrary import UserClass # type: ignore + from userlibrary import UserClass return UserClass().hello() @@ -145,7 +145,7 @@ def test_add_egg_file_locally(self): # To ensure that we're actually testing addPyFile's effects, check that # this fails due to `userlibrary` not being on the Python path: def func(): - from userlib import UserClass # type: ignore[import] + from userlib import UserClass UserClass() @@ -159,7 +159,7 @@ def func(): def test_overwrite_system_module(self): self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) - import SimpleHTTPServer # type: ignore[import] + import SimpleHTTPServer self.assertEqual("My Server", SimpleHTTPServer.__name__) @@ -338,7 +338,7 @@ def tearDown(self): from pyspark.tests.test_context import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_daemon.py b/python/pyspark/tests/test_daemon.py index d4cb90c4e8..22196e5369 100644 --- a/python/pyspark/tests/test_daemon.py +++ b/python/pyspark/tests/test_daemon.py @@ -81,7 +81,7 @@ def test_termination_sigterm(self): from pyspark.tests.test_daemon import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_install_spark.py 
b/python/pyspark/tests/test_install_spark.py index cd1c424a85..6f39a09ae1 100644 --- a/python/pyspark/tests/test_install_spark.py +++ b/python/pyspark/tests/test_install_spark.py @@ -142,7 +142,7 @@ def test_checked_versions(self): from pyspark.tests.test_install_spark import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_join.py b/python/pyspark/tests/test_join.py index ce4c6e5dfe..de1c260696 100644 --- a/python/pyspark/tests/test_join.py +++ b/python/pyspark/tests/test_join.py @@ -61,7 +61,7 @@ def test_narrow_dependency_in_join(self): from pyspark.tests.test_join import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py index cdb75e5b6a..7bd7debe6e 100644 --- a/python/pyspark/tests/test_memory_profiler.py +++ b/python/pyspark/tests/test_memory_profiler.py @@ -33,7 +33,7 @@ @unittest.skipIf(not has_memory_profiler, "Must have memory-profiler installed.") -@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore +@unittest.skipIf(not have_pandas, pandas_requirement_message) class MemoryProfilerTests(PySparkTestCase): def setUp(self): self._old_sys_path = list(sys.path) @@ -156,7 +156,7 @@ def map(pdfs: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]: from pyspark.tests.test_memory_profiler import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py index 2874e09853..dd291b8a0c 100644 --- a/python/pyspark/tests/test_pin_thread.py +++ b/python/pyspark/tests/test_pin_thread.py @@ -171,7 +171,7 @@ def get_outer_local_prop(): from pyspark.tests.test_pin_thread import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_profiler.py b/python/pyspark/tests/test_profiler.py index 8a078d36b4..1db33b59b8 100644 --- a/python/pyspark/tests/test_profiler.py +++ b/python/pyspark/tests/test_profiler.py @@ -155,7 +155,7 @@ def plus_one(v): from pyspark.tests.test_profiler import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index 23e41d6c03..752b5d5599 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -931,7 +931,7 @@ def run_job(job_group, index): from pyspark.tests.test_rdd import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_rddbarrier.py b/python/pyspark/tests/test_rddbarrier.py index 18d618e3e1..dd3d2d6b36 100644 --- a/python/pyspark/tests/test_rddbarrier.py +++ b/python/pyspark/tests/test_rddbarrier.py @@ -44,7 +44,7 @@ def f(index, iterator): from 
pyspark.tests.test_rddbarrier import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_rddsampler.py b/python/pyspark/tests/test_rddsampler.py index b504c4ab98..b98f2668cd 100644 --- a/python/pyspark/tests/test_rddsampler.py +++ b/python/pyspark/tests/test_rddsampler.py @@ -58,7 +58,7 @@ def test_rdd_stratified_sampler_func(self): from pyspark.tests.test_rddsampler import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py index d7086c4bce..73f1025635 100644 --- a/python/pyspark/tests/test_readwrite.py +++ b/python/pyspark/tests/test_readwrite.py @@ -360,7 +360,7 @@ def test_malformed_RDD(self): from pyspark.tests.test_readwrite import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py index 0a89861a26..230723e105 100644 --- a/python/pyspark/tests/test_serializers.py +++ b/python/pyspark/tests/test_serializers.py @@ -108,7 +108,7 @@ def __getattr__(self, item): def test_pickling_file_handles(self): # to be corrected with SPARK-11160 try: - import xmlrunner # type: ignore[import] # noqa: F401 + import xmlrunner # noqa: F401 except ImportError: ser = CloudPickleSerializer() out1 = sys.stderr diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py index fb11a84f8a..4fb73607a2 100644 --- a/python/pyspark/tests/test_shuffle.py +++ b/python/pyspark/tests/test_shuffle.py @@ -259,7 +259,7 @@ def test_external_sort_in_rdd(self): from pyspark.tests.test_shuffle import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_statcounter.py b/python/pyspark/tests/test_statcounter.py index b10fe7cd91..747f42e67b 100644 --- a/python/pyspark/tests/test_statcounter.py +++ b/python/pyspark/tests/test_statcounter.py @@ -122,7 +122,7 @@ def test_merge_stats_with_self(self): from pyspark.tests.test_statcounter import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index b90a788ae2..5d410aa57e 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -342,7 +342,7 @@ def tearDown(self): from pyspark.tests.test_taskcontext import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py index 0ba9a5852e..77f06721b1 100644 --- a/python/pyspark/tests/test_util.py +++ b/python/pyspark/tests/test_util.py @@ -89,7 +89,7 @@ def test_find_spark_home(self): from pyspark.tests.test_util import * # noqa: F401 try: - import xmlrunner # type: 
ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 06ada8f81d..703690bf7f 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -263,7 +263,7 @@ def conf(cls): from pyspark.tests.test_worker import * # noqa: F401 try: - import xmlrunner # type: ignore[import] + import xmlrunner testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) except ImportError: diff --git a/python/setup.py b/python/setup.py index af102f2308..65db3912ef 100644 --- a/python/setup.py +++ b/python/setup.py @@ -282,6 +282,7 @@ def run(self): 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', 'Typing :: Typed'], diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile index e4d62cf45f..3a5b96dc12 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/R/Dockerfile @@ -25,7 +25,7 @@ USER 0 RUN mkdir ${SPARK_HOME}/R -# Install R 4.0.4 (http://cloud.r-project.org/bin/linux/debian/) +# Install R 4.1.2 (http://cloud.r-project.org/bin/linux/debian/) RUN \ apt-get update && \ apt install -y r-base r-base-dev && \ diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index 17c1d117f5..82e4de9bad 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -27,7 +27,7 @@ To run tests with Hadoop 2.x instead of Hadoop 3.x, use `--hadoop-profile`. ./dev/dev-run-integration-tests.sh --hadoop-profile hadoop-2 -The minimum tested version of Minikube is 1.18.0. The kube-dns addon must be enabled. Minikube should +The minimum tested version of Minikube is 1.28.0. The kube-dns addon must be enabled. Minikube should run with a minimum of 4 CPUs and 6G of memory: minikube start --cpus 4 --memory 6144 @@ -46,7 +46,7 @@ default this is set to `minikube`, the available backends are their prerequisite ### `minikube` -Uses the local `minikube` cluster, this requires that `minikube` 1.18.0 or greater be installed and that it be allocated +Uses the local `minikube` cluster, this requires that `minikube` 1.28.0 or greater be installed and that it be allocated at least 4 CPUs and 6GB memory (some users have reported success with as few as 3 CPUs and 4GB memory). The tests will check if `minikube` is started and abort early if it isn't currently running. 
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala index 755feb9aca..70a849c37e 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/backend/minikube/Minikube.scala @@ -48,9 +48,9 @@ private[spark] object Minikube extends Logging { versionArrayOpt match { case Some(Array(x, y, z)) => - if (Ordering.Tuple3[Int, Int, Int].lt((x, y, z), (1, 18, 0))) { + if (Ordering.Tuple3[Int, Int, Int].lt((x, y, z), (1, 28, 0))) { assert(false, s"Unsupported Minikube version is detected: $minikubeVersionString." + - "For integration testing Minikube version 1.18.0 or greater is expected.") + "For integration testing Minikube version 1.28.0 or greater is expected.") } case _ => assert(false, s"Unexpected version format detected in `$minikubeVersionString`." + diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 69dd72720a..9815fa6df8 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -822,6 +822,7 @@ private[spark] class ApplicationMaster( case Shutdown(code) => exitCode = code shutdown = true + allocator.setShutdown(true) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ee1d10c204..4980d7e184 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -199,6 +199,8 @@ private[yarn] class YarnAllocator( } } + @volatile private var shutdown = false + // The default profile is always present so we need to initialize the datastructures keyed by // ResourceProfile id to ensure its present if things start running before a request for // executors could add it. This approach is easier then going and special casing everywhere. @@ -215,6 +217,8 @@ private[yarn] class YarnAllocator( initDefaultProfile() + def setShutdown(shutdown: Boolean): Unit = this.shutdown = shutdown + def getNumExecutorsRunning: Int = synchronized { runningExecutorsPerResourceProfileId.values.map(_.size).sum } @@ -835,6 +839,8 @@ private[yarn] class YarnAllocator( // now I think its ok as none of the containers are expected to exit. 
val exitStatus = completedContainer.getExitStatus val (exitCausedByApp, containerExitReason) = exitStatus match { + case _ if shutdown => + (false, s"Executor for container $containerId exited after Application shutdown.") case ContainerExitStatus.SUCCESS => (false, s"Executor for container $containerId exited because of a YARN event (e.g., " + "preemption) and not because of an error in the running job.") diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 6e6d840604..717c620f5c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -162,7 +162,7 @@ private[spark] class YarnClientSchedulerBackend( */ override def stop(exitCode: Int): Unit = { assert(client != null, "Attempted to stop this scheduler before starting it!") - yarnSchedulerEndpoint.handleClientModeDriverStop(exitCode) + yarnSchedulerEndpoint.signalDriverStop(exitCode) if (monitorThread != null) { monitorThread.stopMonitor() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index e70a78d3c4..3728c33228 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -35,6 +35,11 @@ private[spark] class YarnClusterSchedulerBackend( startBindings() } + override def stop(exitCode: Int): Unit = { + yarnSchedulerEndpoint.signalDriverStop(exitCode) + super.stop() + } + override def getDriverLogUrls: Option[Map[String, String]] = { YarnContainerInfoHelper.getLogUrls(sc.hadoopConfiguration, container = None) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 572c16d9e9..34848a7f3d 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -319,7 +319,7 @@ private[spark] abstract class YarnSchedulerBackend( removeExecutorMessage.foreach { message => driverEndpoint.send(message) } } - private[cluster] def handleClientModeDriverStop(exitCode: Int): Unit = { + private[cluster] def signalDriverStop(exitCode: Int): Unit = { amEndpoint match { case Some(am) => am.send(Shutdown(exitCode)) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 5a80aa9c61..a5ca382fb4 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -693,6 +693,28 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers { .updateBlacklist(hosts.slice(10, 11).asJava, Collections.emptyList()) } + test("SPARK-39601 YarnAllocator should not count 
executor failure after shutdown") { + val (handler, _) = createAllocator() + handler.updateResourceRequests() + handler.getNumExecutorsFailed should be(0) + + val failedBeforeShutdown = createContainer("host1") + val failedAfterShutdown = createContainer("host2") + handler.handleAllocatedContainers(Seq(failedBeforeShutdown, failedAfterShutdown)) + + val failedBeforeShutdownStatus = ContainerStatus.newInstance( + failedBeforeShutdown.getId, ContainerState.COMPLETE, "Failed", -1) + val failedAfterShutdownStatus = ContainerStatus.newInstance( + failedAfterShutdown.getId, ContainerState.COMPLETE, "Failed", -1) + + handler.processCompletedContainers(Seq(failedBeforeShutdownStatus)) + handler.getNumExecutorsFailed should be(1) + + handler.setShutdown(true) + handler.processCompletedContainers(Seq(failedAfterShutdownStatus)) + handler.getNumExecutorsFailed should be(1) + } + test("SPARK-28577#YarnAllocator.resource.memory should include offHeapSize " + "when offHeapEnabled is true.") { val originalOffHeapEnabled = sparkConf.get(MEMORY_OFFHEAP_ENABLED) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 4e2da27569..f34b5d55e4 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -135,9 +135,9 @@ This file is divided into 3 sections: - + - ^FunSuite[A-Za-z]*$ + ^AnyFunSuite[A-Za-z]*$ Tests must extend org.apache.spark.SparkFunSuite instead. diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index 41adbda7b1..38f52901aa 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -18,7 +18,7 @@ lexer grammar SqlBaseLexer; @members { /** - * When true, parser should throw ParseExcetion for unclosed bracketed comment. + * When true, parser should throw ParseException for unclosed bracketed comment. 
*/ public boolean has_unclosed_bracketed_comment = false; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java index 270b750259..5afc869d68 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/Statistics.java @@ -17,8 +17,8 @@ package org.apache.spark.sql.connector.read; +import java.util.HashMap; import java.util.Map; -import java.util.Optional; import java.util.OptionalLong; import org.apache.spark.annotation.Evolving; @@ -35,7 +35,7 @@ public interface Statistics { OptionalLong sizeInBytes(); OptionalLong numRows(); - default Optional<Map<NamedReference, ColumnStatistics>> columnStats() { - return Optional.empty(); + default Map<NamedReference, ColumnStatistics> columnStats() { + return new HashMap<>(); } } diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index c55c542d95..504b65e3db 100644 --- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -49,7 +49,8 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) override def contains(k: Attribute): Boolean = get(k).isDefined - override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv + override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = + AttributeMap(baseMap.values.toMap + kv) override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 3d5d6471d2..ac6149f3ac 100644 --- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -49,6 +49,9 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)]) override def contains(k: Attribute): Boolean = get(k).isDefined + override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] = + AttributeMap(baseMap.values.toMap + kv) + override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] = baseMap.values.toMap + (key -> value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3f806137ba..0423420444 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -288,6 +288,8 @@ class Analyzer(override val catalogManager: CatalogManager) AddMetadataColumns :: DeduplicateRelations :: ResolveReferences :: + WrapLateralColumnAliasReference :: + ResolveLateralColumnAliasReference :: ResolveExpressionsWithNamePlaceholders :: ResolveDeserializer :: ResolveNewInstance :: @@ -1672,7 +1674,7 @@ class Analyzer(override val catalogManager: CatalogManager) // Only Project and Aggregate can host star expressions. case u @ (_: Project | _: Aggregate) => Try(s.expand(u.children.head, resolver)) match { - case Success(expanded) => expanded.map(wrapOuterReference) + case Success(expanded) => expanded.map(wrapOuterReference(_)) case Failure(_) => throw e } // Do not use the outer plan to resolve the star expression @@ -1761,6 +1763,117 @@ class Analyzer(override val catalogManager: CatalogManager) } } + /** + * The first phase to resolve lateral column alias. See comments in + * [[ResolveLateralColumnAliasReference]] for more detailed explanation. + */ + object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { + import ResolveLateralColumnAliasReference.AliasEntry + + private def insertIntoAliasMap( + a: Alias, + idx: Int, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + + /** + * Use the given lateral alias to resolve the unresolved attribute with the name parts. + * + * Construct a dummy plan with the given lateral alias as project list, use the output of the + * plan to resolve. + * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. + */ + private def resolveByLateralAlias( + nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { + val resolvedAttr = resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(nameParts), + plan = LocalRelation(Seq(lateralAlias.toAttribute)), + throws = false + ).asInstanceOf[NamedExpression] + if (resolvedAttr.resolved) { + Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) + } else { + None + } + } + + /** + * Recognize all the attributes in the given expression that reference lateral column aliases + * by looking up the alias map. Resolve these attributes and replace by wrapping with + * [[LateralColumnAliasReference]]. 
+ * + * @param currentPlan Because lateral alias has lower resolution priority than table columns, + * the current plan is needed to first try resolving the attribute by its + * children + */ + private def wrapLCARef( + e: NamedExpression, + currentPlan: LogicalPlan, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) + case _ => u + } + case o: OuterReference + if aliasMap.contains( + o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .map(_.head) + .getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o + .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } + }.asInstanceOf[NamedExpression] + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { + case p @ Project(projectList, _) if p.childrenResolved + && !ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias] + // Insert the LCA-resolved alias instead of the unresolved one into map. If it is + // resolved, it can be referenced as LCA by later expressions (chaining). + // Unresolved Alias is also added to the map to perform ambiguous name check, but + // only resolved alias can be LCA. 
+ aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped + case (e, _) => + wrapLCARef(e, p, aliasMap) + } + p.copy(projectList = newProjectList) + } + } + } + } + private def containsDeserializer(exprs: Seq[Expression]): Boolean = { exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } @@ -2143,7 +2256,7 @@ class Analyzer(override val catalogManager: CatalogManager) case u @ UnresolvedAttribute(nameParts) => withPosition(u) { try { AnalysisContext.get.outerPlan.get.resolveChildren(nameParts, resolver) match { - case Some(resolved) => wrapOuterReference(resolved) + case Some(resolved) => wrapOuterReference(resolved, Some(nameParts)) case None => u } } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 12dac5c632..e7e153a319 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, Decorrela import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_WINDOW_EXPRESSION +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_WINDOW_EXPRESSION} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} @@ -552,7 +552,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "NUM_COLUMNS_MISMATCH", messageParameters = Map( "operator" -> toSQLStmt(operator.nodeName), - "refNumColumns" -> ref.length.toString, + "firstNumColumns" -> ref.length.toString, "invalidOrdinalNum" -> ordinalNumber(ti + 1), "invalidNumColumns" -> child.output.length.toString)) } @@ -565,7 +565,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB e.failAnalysis( errorClass = "_LEGACY_ERROR_TEMP_2430", messageParameters = Map( - "operator" -> operator.nodeName, + "operator" -> toSQLStmt(operator.nodeName), "ci" -> ordinalNumber(ci), "ti" -> ordinalNumber(ti + 1), "dt1" -> dt1.catalogString, @@ -638,6 +638,16 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB case UnresolvedWindowExpression(_, windowSpec) => throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name) }) + // This should not happen, resolved Project or Aggregate should restore or resolve + // all lateral column alias references. Add check for extra safe. 
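For illustration only (not part of the patch): with the `spark.sql.lateralColumnAlias.enableImplicitResolution` flag that this change set adds to SQLConf, an alias defined earlier in a SELECT list can be referenced by later items whenever no real table column shadows it. A minimal, self-contained sketch; the application, view, and column names are made up:

```scala
import org.apache.spark.sql.SparkSession

object LateralColumnAliasSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("lateral-column-alias-sketch")
      .getOrCreate()
    import spark.implicits._

    // Flag added by this patch; it defaults to false.
    spark.conf.set("spark.sql.lateralColumnAlias.enableImplicitResolution", "true")

    Seq(("alice", 1000L), ("bob", 2000L)).toDF("name", "salary")
      .createOrReplaceTempView("employees")

    // `doubled` is not a column of `employees`, so it resolves to the alias defined
    // earlier in the same SELECT list: phase 1 wraps the reference in
    // LateralColumnAliasReference, phase 2 pushes `salary * 2 AS doubled` into an
    // inner Project and rewrites the outer expression.
    spark.sql(
      "SELECT salary * 2 AS doubled, doubled + 1 AS doubled_plus_one FROM employees"
    ).show()

    // Defining the same alias name twice and then referencing it would instead hit
    // the AMBIGUOUS_LATERAL_COLUMN_ALIAS error introduced by this patch.

    spark.stop()
  }
}
```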
+ projectList.foreach(_.transformDownWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if p.resolved => + throw SparkException.internalError("Resolved Project should not contain " + + s"any LateralColumnAliasReference.\nDebugging information: plan: $p", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) + }) case j: Join if !j.duplicateResolved => val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) @@ -714,6 +724,19 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB "operator" -> other.nodeName, "invalidExprSqls" -> invalidExprSqls.mkString(", "))) + // This should not happen, resolved Project or Aggregate should restore or resolve + // all lateral column alias references. Add check for extra safe. + case agg @ Aggregate(_, aggList, _) + if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved => + aggList.foreach(_.transformDownWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference => + throw SparkException.internalError("Resolved Aggregate should not contain " + + s"any LateralColumnAliasReference.\nDebugging information: plan: $agg", + context = lcaRef.origin.getQueryContext, + summary = lcaRef.origin.context.summary) + }) + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala new file mode 100644 index 0000000000..2ca187b95f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE +import org.apache.spark.sql.internal.SQLConf + +/** + * This rule is the second phase to resolve lateral column alias. + * + * Resolve lateral column alias, which references the alias defined previously in the SELECT list. + * Plan-wise, it handles two types of operators: Project and Aggregate. 
+ * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve + * the attributes referencing these aliases + * - in Aggregate TODO. + * + * The whole process is generally divided into two phases: + * 1) recognize resolved lateral alias, wrap the attributes referencing them with + * [[LateralColumnAliasReference]] + * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. + * For Project, it further resolves the attributes and push down the referenced lateral aliases. + * For Aggregate, TODO + * + * Example for Project: + * Before rewrite: + * Project [age AS a, 'a + 1] + * +- Child + * + * After phase 1: + * Project [age AS a, lateralalias(a) + 1] + * +- Child + * + * After phase 2: + * Project [a, a + 1] + * +- Project [child output, age AS a] + * +- Child + * + * Example for Aggregate TODO + * + * + * The name resolution priority: + * local table column > local lateral column alias > outer reference + * + * Because lateral column alias has higher resolution priority than outer reference, it will try + * to resolve an [[OuterReference]] using lateral column alias in phase 1, similar as an + * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with + * [[LateralColumnAliasReference]]. + */ +object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { + case class AliasEntry(alias: Alias, index: Int) + + /** + * A tag to store the nameParts from the original unresolved attribute. + * It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back + * to [[LateralColumnAliasReference]]. + */ + val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + // phase 2: unwrap + plan.resolveOperatorsUpWithPruning( + _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE), ruleId) { + case p @ Project(projectList, child) if p.resolved + && projectList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + var aliasMap = AttributeMap.empty[AliasEntry] + val referencedAliases = collection.mutable.Set.empty[AliasEntry] + def unwrapLCAReference(e: NamedExpression): NamedExpression = { + e.transformWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + case lcaRef: LateralColumnAliasReference if aliasMap.contains(lcaRef.a) => + val aliasEntry = aliasMap.get(lcaRef.a).get + // If there is no chaining of lateral column alias reference, push down the alias + // and unwrap the LateralColumnAliasReference to the NamedExpression inside + // If there is chaining, don't resolve and save to future rounds + if (!aliasEntry.alias.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + referencedAliases += aliasEntry + lcaRef.ne + } else { + lcaRef + } + case lcaRef: LateralColumnAliasReference if !aliasMap.contains(lcaRef.a) => + // It shouldn't happen, but restore to unresolved attribute to be safe. 
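The chaining comments above deserve a concrete case. A hedged sketch, reusing the SparkSession and `employees` view from the previous example; column names are illustrative:

```scala
// `tripled` references the lateral alias `doubled`, and `final_amount` references
// `tripled`, so unwrapping proceeds over multiple rounds, pushing one layer of
// aliases into an inner Project each time.
spark.sql(
  """SELECT salary * 2       AS doubled,
    |       doubled + salary AS tripled,
    |       tripled + 1      AS final_amount
    |FROM employees""".stripMargin
).show()
```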
+ UnresolvedAttribute(lcaRef.nameParts) + }.asInstanceOf[NamedExpression] + } + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaResolved = unwrapLCAReference(a) + // Insert the original alias instead of rewritten one to detect chained LCA + aliasMap += (a.toAttribute -> AliasEntry(a, idx)) + lcaResolved + case (e, _) => + unwrapLCAReference(e) + } + + if (referencedAliases.isEmpty) { + p + } else { + val outerProjectList = collection.mutable.Seq(newProjectList: _*) + val innerProjectList = + collection.mutable.ArrayBuffer(child.output.map(_.asInstanceOf[NamedExpression]): _*) + referencedAliases.foreach { case AliasEntry(alias: Alias, idx) => + outerProjectList.update(idx, alias.toAttribute) + innerProjectList += alias + } + p.copy( + projectList = outerProjectList.toSeq, + child = Project(innerProjectList.toSeq, child) + ) + } + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 23152adc0c..d7a3952e63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -294,6 +294,9 @@ object Cast extends QueryErrorsBase { case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true case (f, t) if legalNumericPrecedence(f, t) => true case (DateType, TimestampType) => true + case (DateType, TimestampNTZType) => true + case (TimestampNTZType, TimestampType) => true + case (TimestampType, TimestampNTZType) => true case (_: AtomicType, StringType) => true case (_: CalendarIntervalType, StringType) => true case (NullType, _) => true @@ -507,7 +510,7 @@ case class Cast( final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST) def ansiEnabled: Boolean = { - evalMode == EvalMode.ANSI || evalMode == EvalMode.TRY + evalMode == EvalMode.ANSI || (evalMode == EvalMode.TRY && !canUseLegacyCastForTryCast) } // Whether this expression is used for `try_cast()`. @@ -1267,7 +1270,7 @@ case class Cast( } private def cast(from: DataType, to: DataType): Any => Any = { - if (!isTryCast) { + if (!isTryCast || canUseLegacyCastForTryCast) { castInternal(from, to) } else { (input: Any) => @@ -1280,6 +1283,20 @@ case class Cast( } } + // Whether Spark SQL can evaluation the try_cast as the legacy cast, so that no `try...catch` + // is needed and the performance can be faster. 
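To make the comment above concrete, a small sketch (assuming a SparkSession named `spark`): for string to fractional and string to datetime casts, the non-ANSI cast already returns NULL on malformed input, so `try_cast` can take that code path directly and skip the generated try/catch.

```scala
// Both expressions evaluate to NULL; with canUseLegacyCastForTryCast they run the
// legacy cast directly instead of wrapping the ANSI cast in a try/catch.
spark.sql(
  "SELECT try_cast('abc' AS DOUBLE) AS bad_double, try_cast('2016-33-11' AS DATE) AS bad_date"
).show()
```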
+ private lazy val canUseLegacyCastForTryCast: Boolean = { + if (!child.resolved) { + false + } else { + (child.dataType, dataType) match { + case (StringType, _: FractionalType) => true + case (StringType, _: DatetimeType) => true + case _ => false + } + } + } + protected[this] lazy val cast: Any => Any = cast(child.dataType, dataType) protected override def nullSafeEval(input: Any): Any = cast(input) @@ -1345,7 +1362,7 @@ case class Cast( protected[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue, result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = { val javaType = JavaCode.javaType(resultType) - val castCodeWithTryCatchIfNeeded = if (!isTryCast) { + val castCodeWithTryCatchIfNeeded = if (!isTryCast || canUseLegacyCastForTryCast) { s"${cast(input, result, resultIsNull)}" } else { s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 2bbde304c2..330d66a21b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -142,6 +142,21 @@ class EquivalentExpressions { case _ => Nil } + private def supportedExpression(e: Expression) = { + !e.exists { + // `LambdaVariable` is usually used as a loop variable and `NamedLambdaVariable` is used in + // higher-order functions, which can't be evaluated ahead of the execution. + case _: LambdaVariable => true + case _: NamedLambdaVariable => true + + // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, + // can cause error like NPE. + case _: PlanExpression[_] => Utils.isInRunningSparkTask + + case _ => false + } + } + /** * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. @@ -149,21 +164,16 @@ class EquivalentExpressions { def addExprTree( expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { - updateExprTree(expr, map) + if (supportedExpression(expr)) { + updateExprTree(expr, map) + } } private def updateExprTree( expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, useCount: Int = 1): Unit = { - val skip = useCount == 0 || - expr.isInstanceOf[LeafExpression] || - // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the - // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. - expr.exists(_.isInstanceOf[LambdaVariable]) || - // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, - // can cause error like NPE. - (expr.exists(_.isInstanceOf[PlanExpression[_]]) && Utils.isInRunningSparkTask) + val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] if (!skip && !updateExprInMap(expr, map, useCount)) { val uc = useCount.signum @@ -177,7 +187,11 @@ class EquivalentExpressions { * equivalent expressions. */ def getExprState(e: Expression): Option[ExpressionStats] = { - equivalenceMap.get(ExpressionEquals(e)) + if (supportedExpression(e)) { + equivalenceMap.get(ExpressionEquals(e)) + } else { + None + } } // Exposed for testing. 
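For context on what `supportedExpression` filters out, a hedged sketch reusing the `employees` view from the earlier examples: a lambda variable is bound per element at runtime, so subexpressions that mention it cannot be evaluated once ahead of execution, while ordinary repeated subexpressions still can be candidates for elimination.

```scala
// `x` is a lambda variable bound per array element, so the repeated `x * x` is not
// a candidate for ahead-of-time subexpression elimination, whereas the repeated
// `salary * 2` outside the lambda can still be considered.
spark.sql(
  """SELECT transform(array(1, 2, 3), x -> x * x + x * x) AS squares_doubled,
    |       salary * 2 + salary * 2                       AS four_times
    |FROM employees""".stripMargin
).show()
```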
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 91c9457af7..4e129e96d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType /** @@ -72,7 +73,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case (e, i) => val writer = InternalRow.getWriter(i, e.dataType) - if (!e.nullable) { + if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) { (v: Any) => writer(mutableRow, v) } else { (v: Any) => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index e8bad46e84..3e89dfe39c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -171,8 +171,8 @@ object CurDateExpressionBuilder extends ExpressionBuilder { if (expressions.isEmpty) { CurrentDate() } else { - throw QueryCompilationErrors.invalidFunctionArgumentNumberError( - Seq.empty, funcName, expressions.length) + throw QueryCompilationErrors.invalidFunctionArgumentsError( + funcName, "0", expressions.length) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8dd28e9aaa..0f5239be6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -428,6 +428,39 @@ case class OuterReference(e: NamedExpression) final override val nodePatterns: Seq[TreePattern] = Seq(OUTER_REFERENCE) } +/** + * A placeholder used to hold a [[NamedExpression]] that has been temporarily resolved as the + * reference to a lateral column alias. + * + * This is created and removed by Analyzer rule [[ResolveLateralColumnAlias]]. + * There should be no [[LateralColumnAliasReference]] beyond analyzer: if the plan passes all + * analysis check, then all [[LateralColumnAliasReference]] should already be removed. + * + * @param ne the resolved [[NamedExpression]] by lateral column alias + * @param nameParts the named parts of the original [[UnresolvedAttribute]]. Used to restore back + * to [[UnresolvedAttribute]] when needed + * @param a the attribute of referenced lateral column alias. 
Used to match alias when unwrapping + * and resolving LateralColumnAliasReference + */ +case class LateralColumnAliasReference(ne: NamedExpression, nameParts: Seq[String], a: Attribute) + extends LeafExpression with NamedExpression with Unevaluable { + assert(ne.resolved) + override def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = ne.exprId + override def qualifier: Seq[String] = ne.qualifier + override def toAttribute: Attribute = ne.toAttribute + override def newInstance(): NamedExpression = + LateralColumnAliasReference(ne.newInstance(), nameParts, a) + + override def nullable: Boolean = ne.nullable + override def dataType: DataType = ne.dataType + override def prettyName: String = "lateralAliasReference" + override def sql: String = s"$prettyName($name)" + + final override val nodePatterns: Seq[TreePattern] = Seq(LATERAL_COLUMN_ALIAS_REFERENCE) +} + object VirtualColumn { // The attribute name used by Hive, which has different result than Spark, deprecated. val hiveGroupingIdName: String = "grouping__id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index f5f86bfac1..2d4f0438db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -224,17 +224,21 @@ case class TryToNumber(left: Expression, right: Expression) group = "string_funcs") case class ToCharacter(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - private lazy val numberFormat = right.eval().toString.toUpperCase(Locale.ROOT) - private lazy val numberFormatter = new ToNumberParser(numberFormat, true) + private lazy val numberFormatter = { + val value = right.eval() + if (value != null) { + new ToNumberParser(value.toString.toUpperCase(Locale.ROOT), true) + } else { + null + } + } override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { - if (right.foldable) { - numberFormatter.checkInputDataTypes() - } else { + if (!right.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( @@ -243,6 +247,10 @@ case class ToCharacter(left: Expression, right: Expression) "inputExpr" -> toSQLExpr(right) ) ) + } else if (numberFormatter == null) { + TypeCheckResult.TypeCheckSuccess + } else { + numberFormatter.checkInputDataTypes() } } else { inputTypeCheck @@ -260,7 +268,7 @@ case class ToCharacter(left: Expression, right: Expression) val result = code""" |${eval.code} - |boolean ${ev.isNull} = ${eval.isNull}; + |boolean ${ev.isNull} = ${eval.isNull} || ($builder == null); |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $builder.format(${eval.value}); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3a1db2ce1b..56010219b0 100755 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2552,6 +2552,8 @@ case class StringDecode(bin: Expression, charset: Expression) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): StringDecode = copy(bin = newLeft, charset = newRight) + + override def prettyName: String = "decode" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index e7384dac2d..b510893f37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, LogicalPlan} @@ -158,8 +159,12 @@ object SubExprUtils extends PredicateHelper { /** * Wrap attributes in the expression with [[OuterReference]]s. */ - def wrapOuterReference[E <: Expression](e: E): E = { - e.transform { case a: Attribute => OuterReference(a) }.asInstanceOf[E] + def wrapOuterReference[E <: Expression](e: E, nameParts: Option[Seq[String]] = None): E = { + e.transform { case a: Attribute => + val o = OuterReference(a) + nameParts.map(o.setTagValue(NAME_PARTS_FROM_UNRESOLVED_ATTR, _)) + o + }.asInstanceOf[E] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0ad185bef1..da25702e1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2417,7 +2417,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit IntervalUtils.stringToInterval(UTF8String.fromString(value)) } catch { case e: IllegalArgumentException => - val ex = QueryParsingErrors.cannotParseIntervalValueError(value, ctx) + val ex = QueryParsingErrors.cannotParseValueTypeError(valueType, value, ctx) ex.setStackTrace(e.getStackTrace) throw ex } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 60586e4166..878ad91c08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -342,7 +342,7 @@ case class Intersect( right: LogicalPlan, isAll: Boolean) extends SetOperation(left, right) { - override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) " All" else "" ) final override val nodePatterns: Seq[TreePattern] = Seq(INTERSECT) @@ -372,7 +372,7 @@ case class Except( left: LogicalPlan, right: LogicalPlan, isAll: 
Boolean) extends SetOperation(left, right) { - override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" ) + override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) " All" else "" ) /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index f6bef88ab8..efafd3cfbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -77,6 +77,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$WrapLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" :: "org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" :: "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: @@ -88,6 +89,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveHints$ResolveJoinStrategyHints" :: "org.apache.spark.sql.catalyst.analysis.ResolveInlineTables" :: "org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables" :: + "org.apache.spark.sql.catalyst.analysis.ResolveLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: "org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 8fca9ec60c..1a8ad7c7d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -58,6 +58,7 @@ object TreePattern extends Enumeration { val JSON_TO_STRUCT: Value = Value val LAMBDA_FUNCTION: Value = Value val LAMBDA_VARIABLE: Value = Value + val LATERAL_COLUMN_ALIAS_REFERENCE: Value = Value val LATERAL_SUBQUERY: Value = Value val LIKE_FAMLIY: Value = Value val LIST_SUBQUERY: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ed08e33829..b329f6689d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -795,12 +795,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { "column" -> quoted)) } - def columnDoesNotExistError(colName: String): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1061", - messageParameters = Map("colName" -> colName)) - } - def renameTempViewToExistingViewError(newName: String): Throwable = { new TableAlreadyExistsException(newName) } @@ -2281,14 +2275,22 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { messageParameters = Map("columnName" -> toSQLId(columnName))) } + def 
columnNotFoundError(colName: String): Throwable = { + new AnalysisException( + errorClass = "COLUMN_NOT_FOUND", + messageParameters = Map( + "colName" -> toSQLId(colName), + "caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key))) + } + def noSuchTableError(db: String, table: String): Throwable = { new NoSuchTableException(db = db, table = table) } def tempViewNotCachedForAnalyzingColumnsError(tableIdent: TableIdentifier): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1234", - messageParameters = Map("tableIdent" -> tableIdent.toString)) + errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + messageParameters = Map("viewName" -> toSQLId(tableIdent.toString))) } def columnTypeNotSupportStatisticsCollectionError( @@ -3397,4 +3399,23 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { cause = Option(other)) } } + + def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + messageParameters = Map( + "name" -> toSQLId(name), + "n" -> numOfMatches.toString + ) + ) + } + def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + messageParameters = Map( + "name" -> toSQLId(nameParts), + "n" -> numOfMatches.toString + ) + ) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 018e9a12e0..aef95a538a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -211,15 +211,11 @@ private[sql] object QueryParsingErrors extends QueryErrorsBase { def cannotParseValueTypeError( valueType: String, value: String, ctx: TypeConstructorContext): Throwable = { new ParseException( - errorClass = "_LEGACY_ERROR_TEMP_0019", - messageParameters = Map("valueType" -> valueType, "value" -> value), - ctx) - } - - def cannotParseIntervalValueError(value: String, ctx: TypeConstructorContext): Throwable = { - new ParseException( - errorClass = "_LEGACY_ERROR_TEMP_0020", - messageParameters = Map("value" -> value), + errorClass = "INVALID_TYPED_LITERAL", + messageParameters = Map( + "valueType" -> toSQLType(valueType), + "value" -> toSQLValue(value, StringType) + ), ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 0d7ce1388d..3db2ec6b8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -223,8 +223,8 @@ object DataSourceV2Relation { } var colStats: Seq[(Attribute, ColumnStat)] = Seq.empty[(Attribute, ColumnStat)] - if (v2Statistics.columnStats().isPresent) { - val v2ColumnStat = v2Statistics.columnStats().get() + if (!v2Statistics.columnStats().isEmpty) { + val v2ColumnStat = v2Statistics.columnStats() val keys = v2ColumnStat.keySet() keys.forEach(key => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84d78f365a..575775a0f5 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4027,6 +4027,17 @@ object SQLConf { .checkValues(ErrorMessageFormat.values.map(_.toString)) .createWithDefault(ErrorMessageFormat.PRETTY.toString) + val LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED = + buildConf("spark.sql.lateralColumnAlias.enableImplicitResolution") + .internal() + .doc("Enable resolving implicit lateral column alias defined in the same SELECT list. For " + + "example, with this conf turned on, for query `SELECT 1 AS a, a + 1` the `a` in `a + 1` " + + "can be resolved as the previously defined `1 AS a`. But note that table column has " + + "higher resolution priority than the lateral column alias.") + .version("3.4.0") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 68b3d5c844..bad85ca417 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -720,6 +720,15 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(Cast.canUpCast(from, to)) } } + + { + assert(Cast.canUpCast(DateType, TimestampType)) + assert(Cast.canUpCast(DateType, TimestampNTZType)) + assert(Cast.canUpCast(TimestampType, TimestampNTZType)) + assert(Cast.canUpCast(TimestampNTZType, TimestampType)) + assert(!Cast.canUpCast(TimestampType, DateType)) + assert(!Cast.canUpCast(TimestampNTZType, DateType)) + } } test("SPARK-40389: canUpCast: return false if casting decimal to integral types can cause" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala index 0f01bfbb89..e3f1128381 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala @@ -65,6 +65,68 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow) } + def testRows( + bufferSchema: StructType, + buffer: InternalRow, + scalaRows: Seq[Seq[Any]]): Unit = { + val bufferTypes = bufferSchema.map(_.dataType).toArray + val proj = createMutableProjection(bufferTypes) + + scalaRows.foreach { scalaRow => + val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map { + case (v, dataType) => CatalystTypeConverters.createToCatalystConverter(dataType)(v) + }) + val projRow = proj.target(buffer)(inputRow) + assert(SafeProjection.create(bufferTypes)(projRow) === inputRow) + } + } + + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, 
scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = UnsafeProjection.create(bufferSchema) + .apply(new GenericInternalRow(bufferSchema.length)) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (high precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(27, 2), nullable = true), + StructField("dec2", DecimalType(27, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + + testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal (low precision)") { + val bufferSchema = StructType(Array( + StructField("dec1", DecimalType(10, 2), nullable = true), + StructField("dec2", DecimalType(10, 2), nullable = true))) + val buffer = new GenericInternalRow(bufferSchema.length) + val scalaRows = Seq( + Seq(null, null), + Seq(BigDecimal(77.77), BigDecimal(245.00))) + testRows(bufferSchema, buffer, scalaRows) + } + testBothCodegenAndInterpreted("variable-length types") { val proj = createMutableProjection(variableLengthTypes) val scalaValues = Seq("abc", BigDecimal(10), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f0b320db3a..8be732a52c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1256,6 +1256,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("SPARK-41452: ToCharacter: null format string") { + // if null format, to_number should return null + val toCharacterExpr = ToCharacter(Literal(Decimal(454)), Literal(null, StringType)) + assert(toCharacterExpr.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess) + checkEvaluation(toCharacterExpr, null) + } + test("ToBinary: fails analysis if fmt is not foldable") { val wrongFmt = AttributeReference("invalidFormat", StringType)() val toBinaryExpr = ToBinary(Literal("abc"), Some(wrongFmt)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 884e782736..760b8630f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -521,8 +521,12 @@ class ExpressionParserSuite extends AnalysisTest { Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) checkError( exception = parseException("timestamP_LTZ '2016-33-11 20:54:00.000'"), - errorClass = "_LEGACY_ERROR_TEMP_0019", - parameters = Map("valueType" -> "TIMESTAMP_LTZ", "value" -> "2016-33-11 20:54:00.000"), + errorClass = "INVALID_TYPED_LITERAL", + sqlState = "42000", + parameters = Map( + "valueType" -> 
"\"TIMESTAMP_LTZ\"", + "value" -> "'2016-33-11 20:54:00.000'" + ), context = ExpectedContext( fragment = "timestamP_LTZ '2016-33-11 20:54:00.000'", start = 0, @@ -533,8 +537,12 @@ class ExpressionParserSuite extends AnalysisTest { Literal(LocalDateTime.parse("2016-03-11T20:54:00.000"))) checkError( exception = parseException("tImEstAmp_Ntz '2016-33-11 20:54:00.000'"), - errorClass = "_LEGACY_ERROR_TEMP_0019", - parameters = Map("valueType" -> "TIMESTAMP_NTZ", "value" -> "2016-33-11 20:54:00.000"), + errorClass = "INVALID_TYPED_LITERAL", + sqlState = "42000", + parameters = Map( + "valueType" -> "\"TIMESTAMP_NTZ\"", + "value" -> "'2016-33-11 20:54:00.000'" + ), context = ExpectedContext( fragment = "tImEstAmp_Ntz '2016-33-11 20:54:00.000'", start = 0, @@ -545,8 +553,9 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) checkError( exception = parseException("DAtE 'mar 11 2016'"), - errorClass = "_LEGACY_ERROR_TEMP_0019", - parameters = Map("valueType" -> "DATE", "value" -> "mar 11 2016"), + errorClass = "INVALID_TYPED_LITERAL", + sqlState = "42000", + parameters = Map("valueType" -> "\"DATE\"", "value" -> "'mar 11 2016'"), context = ExpectedContext( fragment = "DAtE 'mar 11 2016'", start = 0, @@ -557,8 +566,9 @@ class ExpressionParserSuite extends AnalysisTest { Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) checkError( exception = parseException("timestamP '2016-33-11 20:54:00.000'"), - errorClass = "_LEGACY_ERROR_TEMP_0019", - parameters = Map("valueType" -> "TIMESTAMP", "value" -> "2016-33-11 20:54:00.000"), + errorClass = "INVALID_TYPED_LITERAL", + sqlState = "42000", + parameters = Map("valueType" -> "\"TIMESTAMP\"", "value" -> "'2016-33-11 20:54:00.000'"), context = ExpectedContext( fragment = "timestamP '2016-33-11 20:54:00.000'", start = 0, @@ -571,8 +581,9 @@ class ExpressionParserSuite extends AnalysisTest { checkError( exception = parseException("timestamP '2016-33-11 20:54:00.000'"), - errorClass = "_LEGACY_ERROR_TEMP_0019", - parameters = Map("valueType" -> "TIMESTAMP", "value" -> "2016-33-11 20:54:00.000"), + errorClass = "INVALID_TYPED_LITERAL", + sqlState = "42000", + parameters = Map("valueType" -> "\"TIMESTAMP\"", "value" -> "'2016-33-11 20:54:00.000'"), context = ExpectedContext( fragment = "timestamP '2016-33-11 20:54:00.000'", start = 0, @@ -591,8 +602,11 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("INTERVAL '1 year 2 month'", ymIntervalLiteral) checkError( exception = parseException("Interval 'interval 1 yearsss 2 monthsss'"), - errorClass = "_LEGACY_ERROR_TEMP_0020", - parameters = Map("value" -> "interval 1 yearsss 2 monthsss"), + errorClass = "INVALID_TYPED_LITERAL", + parameters = Map( + "valueType" -> "\"INTERVAL\"", + "value" -> "'interval 1 yearsss 2 monthsss'" + ), context = ExpectedContext( fragment = "Interval 'interval 1 yearsss 2 monthsss'", start = 0, @@ -605,8 +619,11 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("INTERVAL '1 day 2 hour 3 minute 4.005006 second'", dtIntervalLiteral) checkError( exception = parseException("Interval 'interval 1 daysss 2 hoursss'"), - errorClass = "_LEGACY_ERROR_TEMP_0020", - parameters = Map("value" -> "interval 1 daysss 2 hoursss"), + errorClass = "INVALID_TYPED_LITERAL", + parameters = Map( + "valueType" -> "\"INTERVAL\"", + "value" -> "'interval 1 daysss 2 hoursss'" + ), context = ExpectedContext( fragment = "Interval 'interval 1 daysss 2 hoursss'", start = 0, @@ -628,8 +645,11 @@ class 
ExpressionParserSuite extends AnalysisTest { assertEqual("INTERVAL '3 month 1 hour'", intervalLiteral) checkError( exception = parseException("Interval 'interval 3 monthsss 1 hoursss'"), - errorClass = "_LEGACY_ERROR_TEMP_0020", - parameters = Map("value" -> "interval 3 monthsss 1 hoursss"), + errorClass = "INVALID_TYPED_LITERAL", + parameters = Map( + "valueType" -> "\"INTERVAL\"", + "value" -> "'interval 3 monthsss 1 hoursss'" + ), context = ExpectedContext( fragment = "Interval 'interval 3 monthsss 1 hoursss'", start = 0, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index dd255290f3..5874480468 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector.catalog import java.time.{Instant, ZoneId} import java.time.temporal.ChronoUnit import java.util -import java.util.{Optional, OptionalLong} +import java.util.OptionalLong import scala.collection.mutable @@ -277,7 +277,7 @@ abstract class InMemoryBaseTable( case class InMemoryStats( sizeInBytes: OptionalLong, numRows: OptionalLong, - override val columnStats: Optional[util.Map[NamedReference, ColumnStatistics]]) + override val columnStats: util.Map[NamedReference, ColumnStatistics]) extends Statistics case class InMemoryColumnStats( @@ -298,7 +298,7 @@ abstract class InMemoryBaseTable( override def estimateStatistics(): Statistics = { if (data.isEmpty) { - return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L), Optional.empty()) + return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L), new util.HashMap()) } val inputPartitions = data.map(_.asInstanceOf[BufferedRows]) @@ -331,7 +331,7 @@ abstract class InMemoryBaseTable( val colNames = tableSchema.fields.map(_.name) var i = 0 for (col <- colNames) { - val fieldReference = FieldReference(col) + val fieldReference = FieldReference.column(col) val colStats = InMemoryColumnStats( OptionalLong.of(colValueSets(i).size()), OptionalLong.of(numOfNulls(i))) @@ -339,7 +339,7 @@ abstract class InMemoryBaseTable( i = i + 1 } - InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows), Optional.of(map)) + InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows), map) } override def outputPartitioning(): Partitioning = { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 4efa06a781..f89c10155a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -20,6 +20,8 @@ import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDateTime; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -146,6 +148,9 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(StandardCharsets.UTF_8); dst.appendByteArray(b, 0, b.length); + } else if (t == DataTypes.BinaryType) { + byte[] b = (byte[]) o; + dst.appendByteArray(b, 0, b.length); } else if (t 
instanceof DecimalType) { DecimalType dt = (DecimalType) t; Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale()); @@ -165,7 +170,11 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) dst.getChild(1).appendInt(c.days); dst.getChild(2).appendLong(c.microseconds); } else if (t instanceof DateType) { - dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); + dst.appendInt(DateTimeUtils.fromJavaDate((Date) o)); + } else if (t instanceof TimestampType) { + dst.appendLong(DateTimeUtils.fromJavaTimestamp((Timestamp) o)); + } else if (t instanceof TimestampNTZType) { + dst.appendLong(DateTimeUtils.localDateTimeToMicros((LocalDateTime) o)); } else { throw new UnsupportedOperationException("Type " + t); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 4afcf5b751..7b2d501584 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -155,7 +155,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) case DescribeColumn(ResolvedV1TableIdentifier(ident), column, isExtended, output) => column match { case u: UnresolvedAttribute => - throw QueryCompilationErrors.columnDoesNotExistError(u.name) + throw QueryCompilationErrors.columnNotFoundError(u.name) case a: Attribute => DescribeColumnCommand(ident, a.qualifier :+ a.name, isExtended, output) case Alias(child, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 88bba7f5ec..d821b127e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -86,7 +86,7 @@ case class AnalyzeColumnCommand( } else { columnNames.get.map { col => val exprOption = relation.output.find(attr => conf.resolver(attr.name, col)) - exprOption.getOrElse(throw QueryCompilationErrors.columnDoesNotExistError(col)) + exprOption.getOrElse(throw QueryCompilationErrors.columnNotFoundError(col)) } } // Make sure the column types are supported for stats gathering. 
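A hedged sketch of the user-facing effect of switching to `columnNotFoundError` (the table name and column are illustrative, and a SparkSession named `spark` is assumed): describing or analyzing a nonexistent column now reports the `COLUMN_NOT_FOUND` error class with a hint about `spark.sql.caseSensitive`, instead of the legacy `_LEGACY_ERROR_TEMP_1061` message.

```scala
import org.apache.spark.sql.AnalysisException

// A throwaway table; name and schema are made up for the example.
spark.sql("CREATE TABLE IF NOT EXISTS people (id INT, name STRING) USING parquet")
try {
  spark.sql("DESCRIBE TABLE people no_such_column")
} catch {
  case e: AnalysisException =>
    // Expected to carry the COLUMN_NOT_FOUND error class, naming the missing column
    // and pointing at the spark.sql.caseSensitive configuration.
    println(e.getMessage)
}
```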
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 6b089a13d4..5e733ad9e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -764,7 +764,7 @@ case class DescribeColumnCommand( val colName = UnresolvedAttribute(colNameParts).name val field = { relation.resolve(colNameParts, resolver).getOrElse { - throw QueryCompilationErrors.columnDoesNotExistError(colName) + throw QueryCompilationErrors.columnNotFoundError(colName) } } if (!field.isInstanceOf[Attribute]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 43b18a3b2d..ce721cd522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -201,14 +201,14 @@ object FileFormatWriter extends Logging { rdd } - val jobIdInstant = new Date().getTime + val jobTrackerID = SparkHadoopWriterUtils.createJobTrackerID(new Date()) val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( rddWithNonEmptyPartitions, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, - jobIdInstant = jobIdInstant, + jobTrackerID = jobTrackerID, sparkStageId = taskContext.stageId(), sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE, @@ -244,7 +244,7 @@ object FileFormatWriter extends Logging { /** Writes data out in a single Spark task. 
*/ private def executeTask( description: WriteJobDescription, - jobIdInstant: Long, + jobTrackerID: String, sparkStageId: Int, sparkPartitionId: Int, sparkAttemptNumber: Int, @@ -252,7 +252,7 @@ object FileFormatWriter extends Logging { iterator: Iterator[InternalRow], concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]): WriteTaskResult = { - val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId) + val jobId = SparkHadoopWriterUtils.createJobID(jobTrackerID, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index d827e83623..ea13e2deac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataSingleWri case class FileWriterFactory ( description: WriteJobDescription, committer: FileCommitProtocol) extends DataWriterFactory { + + private val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = { val taskAttemptContext = createTaskAttemptContext(partitionId) committer.setupTask(taskAttemptContext) @@ -40,7 +43,6 @@ case class FileWriterFactory ( } private def createTaskAttemptContext(partitionId: Int): TaskAttemptContextImpl = { - val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) val taskId = new TaskID(jobId, TaskType.MAP, partitionId) val taskAttemptId = new TaskAttemptID(taskId, 0) // Set up the configuration object diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index d81223b48a..734f8165af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -91,10 +91,6 @@ case class CSVScan( override def hashCode(): Int = super.hashCode() - override def description(): String = { - super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") - } - override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index 9ab367136f..c9a3a6f5e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -91,7 +91,7 @@ case class JsonScan( override def hashCode(): Int = super.hashCode() - override def description(): String = { - super.description() + ", PushedFilters: " + pushedFilters.mkString("[", ", ", "]") + override def getMetaData(): Map[String, String] = { + super.getMetaData() ++ Map("PushedFilters" -> pushedFilters.mkString("[", ", ", "]")) } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index ccb9ca9c6b..072ab26774 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -92,12 +92,6 @@ case class OrcScan( ("[]", "[]") } - override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + - ", PushedAggregation: " + pushedAggregationsStr + - ", PushedGroupBy: " + pushedGroupByStr - } - override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ Map("PushedAggregation" -> pushedAggregationsStr) ++ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index ff0b38880f..619a8fe66e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -127,12 +127,6 @@ case class ParquetScan( ("[]", "[]") } - override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + - ", PushedAggregation: " + pushedAggregationsStr + - ", PushedGroupBy: " + pushedGroupByStr - } - override def getMetaData(): Map[String, String] = { super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ Map("PushedAggregation" -> pushedAggregationsStr) ++ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index a613a39b2b..b31d0b9989 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -55,8 +55,10 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { case o: SQLMetric => - if (_value < 0) _value = 0 - if (o.value > 0) _value += o.value + if (o.value > 0) { + if (_value < 0) _value = 0 + _value += o.value + } case _ => throw QueryExecutionErrors.cannotMergeClassWithOtherClassError( this.getClass.getName, other.getClass.getName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 5acd20f49d..425fc02e31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -133,7 +133,8 @@ class RocksDB( if (conf.resetStatsOnLoad) { nativeStats.reset } - // reset resources to prevent side-effects from previous loaded version + // reset resources to prevent side-effects from previous loaded version if it was not cleaned + // up correctly closePrefixScanIterators() resetWriteBatch() logInfo(s"Loaded $version") @@ -319,6 +320,10 @@ class RocksDB( } finally { db.continueBackgroundWork() silentDeleteRecursively(checkpointDir, s"committing $newVersion") + // reset resources as either 1) we already pushed the changes and 
it has been committed or + // 2) commit has failed and the current version is "invalidated". + closePrefixScanIterators() + resetWriteBatch() release() } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/date.sql b/sql/core/src/test/resources/sql-tests/inputs/date.sql index ab57c7c754..163855069f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/date.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/date.sql @@ -19,6 +19,7 @@ select date'2021-4294967297-11'; select current_date = current_date; -- under ANSI mode, `current_date` can't be a function name. select current_date() = current_date(); +select curdate(1); -- conversions between date and unix_date (number of days from epoch) select DATE_FROM_UNIX_DATE(0), DATE_FROM_UNIX_DATE(1000), DATE_FROM_UNIX_DATE(null); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index 9ddbaec4f9..a292e3f0ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -22,10 +22,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2020-01-01中文", - "valueType" : "DATE" + "value" : "'2020-01-01中文'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -82,10 +83,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "015", - "valueType" : "DATE" + "value" : "'015'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -104,10 +106,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-4294967297-11", - "valueType" : "DATE" + "value" : "'2021-4294967297-11'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -135,6 +138,29 @@ struct<(current_date() = current_date()):boolean> true +-- !query +select curdate(1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS", + "messageParameters" : { + "actualNum" : "1", + "expectedNum" : "0", + "functionName" : "`curdate`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 17, + "fragment" : "curdate(1)" + } ] +} + + -- !query select DATE_FROM_UNIX_DATE(0), DATE_FROM_UNIX_DATE(1000), DATE_FROM_UNIX_DATE(null) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 9d298fe350..493d8769ad 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -2398,9 +2398,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "中文 interval 1 day" + "value" : "'中文 interval 1 day'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { 
"objectType" : "", @@ -2419,9 +2421,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "interval中文 1 day" + "value" : "'interval中文 1 day'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2440,9 +2444,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "interval 1中文day" + "value" : "'interval 1中文day'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2579,9 +2585,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "+" + "value" : "'+'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2600,9 +2608,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "+." + "value" : "'+.'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2621,9 +2631,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1" + "value" : "'1'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2642,9 +2654,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1.2" + "value" : "'1.2'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2663,9 +2677,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "- 2" + "value" : "'- 2'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2684,9 +2700,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1 day -" + "value" : "'1 day -'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2705,9 +2723,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1 day 1" + "value" : "'1 day 1'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index 108cfd19de..f878d350e9 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ 
b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -390,10 +390,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "mar 11 2016", - "valueType" : "DATE" + "value" : "'mar 11 2016'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -420,10 +421,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2016-33-11 20:54:00.000", - "valueType" : "TIMESTAMP" + "value" : "'2016-33-11 20:54:00.000'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out index de7f9c753e..b371e8a224 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/timestamp.sql.out @@ -14,10 +14,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2019-01-01中文", - "valueType" : "TIMESTAMP" + "value" : "'2019-01-01中文'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -36,10 +37,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "4294967297", - "valueType" : "TIMESTAMP" + "value" : "'4294967297'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -58,10 +60,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-01-01T12:30:4294967297.123456", - "valueType" : "TIMESTAMP" + "value" : "'2021-01-01T12:30:4294967297.123456'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/date.sql.out b/sql/core/src/test/resources/sql-tests/results/date.sql.out index 9e427adb05..2cc1b70be7 100644 --- a/sql/core/src/test/resources/sql-tests/results/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/date.sql.out @@ -22,10 +22,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2020-01-01中文", - "valueType" : "DATE" + "value" : "'2020-01-01中文'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -68,10 +69,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "015", - "valueType" : "DATE" + "value" : "'015'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -90,10 +92,11 @@ struct<> -- !query output 
org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-4294967297-11", - "valueType" : "DATE" + "value" : "'2021-4294967297-11'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -121,6 +124,29 @@ struct<(current_date() = current_date()):boolean> true +-- !query +select curdate(1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS", + "messageParameters" : { + "actualNum" : "1", + "expectedNum" : "0", + "functionName" : "`curdate`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 17, + "fragment" : "curdate(1)" + } ] +} + + -- !query select DATE_FROM_UNIX_DATE(0), DATE_FROM_UNIX_DATE(1000), DATE_FROM_UNIX_DATE(null) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out index 3c3a70acd1..e057b22df0 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out @@ -22,10 +22,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2020-01-01中文", - "valueType" : "DATE" + "value" : "'2020-01-01中文'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -68,10 +69,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "015", - "valueType" : "DATE" + "value" : "'015'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -90,10 +92,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-4294967297-11", - "valueType" : "DATE" + "value" : "'2021-4294967297-11'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -121,6 +124,29 @@ struct<(current_date() = current_date()):boolean> true +-- !query +select curdate(1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS", + "messageParameters" : { + "actualNum" : "1", + "expectedNum" : "0", + "functionName" : "`curdate`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 17, + "fragment" : "curdate(1)" + } ] +} + + -- !query select DATE_FROM_UNIX_DATE(0), DATE_FROM_UNIX_DATE(1000), DATE_FROM_UNIX_DATE(null) -- !query schema @@ -1074,10 +1100,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2019-01-01中文", - "valueType" : "TIMESTAMP" + "value" : "'2019-01-01中文'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -1096,10 +1123,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : 
"_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "4294967297", - "valueType" : "TIMESTAMP" + "value" : "'4294967297'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -1118,10 +1146,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-01-01T12:30:4294967297.123456", - "valueType" : "TIMESTAMP" + "value" : "'2021-01-01T12:30:4294967297.123456'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out index b21dc8c62b..cd71644204 100644 --- a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -145,7 +145,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "array", "dt2" : "int", "hint" : "", - "operator" : "ExceptAll", + "operator" : "EXCEPT ALL", "ti" : "second" }, "queryContext" : [ { @@ -230,10 +230,10 @@ org.apache.spark.sql.AnalysisException { "errorClass" : "NUM_COLUMNS_MISMATCH", "messageParameters" : { + "firstNumColumns" : "1", "invalidNumColumns" : "2", "invalidOrdinalNum" : "second", - "operator" : "EXCEPTALL", - "refNumColumns" : "1" + "operator" : "EXCEPT ALL" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out index cb3541e569..48c4ce4583 100644 --- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -102,7 +102,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "array", "dt2" : "int", "hint" : "", - "operator" : "IntersectAll", + "operator" : "INTERSECT ALL", "ti" : "second" }, "queryContext" : [ { @@ -126,10 +126,10 @@ org.apache.spark.sql.AnalysisException { "errorClass" : "NUM_COLUMNS_MISMATCH", "messageParameters" : { + "firstNumColumns" : "1", "invalidNumColumns" : "2", "invalidOrdinalNum" : "second", - "operator" : "INTERSECTALL", - "refNumColumns" : "1" + "operator" : "INTERSECT ALL" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 716ea9335c..690f3da0f9 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -2211,9 +2211,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "中文 interval 1 day" + "value" : "'中文 interval 1 day'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2232,9 +2234,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "interval中文 1 day" + "value" : "'interval中文 1 day'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2253,9 +2257,11 @@ 
struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "interval 1中文day" + "value" : "'interval 1中文day'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2392,9 +2398,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "+" + "value" : "'+'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2413,9 +2421,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "+." + "value" : "'+.'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2434,9 +2444,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1" + "value" : "'1'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2455,9 +2467,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1.2" + "value" : "'1.2'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2476,9 +2490,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "- 2" + "value" : "'- 2'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2497,9 +2513,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1 day -" + "value" : "'1 day -'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", @@ -2518,9 +2536,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0020", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1 day 1" + "value" : "'1 day 1'", + "valueType" : "\"INTERVAL\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index 108cfd19de..f878d350e9 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -390,10 +390,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "mar 11 2016", - "valueType" : "DATE" + "value" : "'mar 11 2016'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -420,10 +421,11 @@ struct<> -- !query output 
org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2016-33-11 20:54:00.000", - "valueType" : "TIMESTAMP" + "value" : "'2016-33-11 20:54:00.000'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out index 1103aff05d..f0718f1a64 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out @@ -199,10 +199,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 Jan 08", - "valueType" : "DATE" + "value" : "'1999 Jan 08'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -221,10 +222,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 08 Jan", - "valueType" : "DATE" + "value" : "'1999 08 Jan'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -259,10 +261,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 01 08", - "valueType" : "DATE" + "value" : "'1999 01 08'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -281,10 +284,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 08 01", - "valueType" : "DATE" + "value" : "'1999 08 01'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -311,10 +315,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 Jan 08", - "valueType" : "DATE" + "value" : "'1999 Jan 08'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -333,10 +338,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 08 Jan", - "valueType" : "DATE" + "value" : "'1999 08 Jan'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -371,10 +377,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 01 08", - "valueType" : "DATE" + "value" : "'1999 01 08'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -393,10 +400,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + 
"errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 08 01", - "valueType" : "DATE" + "value" : "'1999 08 01'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -431,10 +439,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 Jan 08", - "valueType" : "DATE" + "value" : "'1999 Jan 08'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -453,10 +462,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 08 Jan", - "valueType" : "DATE" + "value" : "'1999 08 Jan'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -491,10 +501,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 01 08", - "valueType" : "DATE" + "value" : "'1999 01 08'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", @@ -513,10 +524,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "1999 08 01", - "valueType" : "DATE" + "value" : "'1999 08 01'", + "valueType" : "\"DATE\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out index b17dcdf323..affe34d545 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestamp.sql.out @@ -14,10 +14,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2019-01-01中文", - "valueType" : "TIMESTAMP" + "value" : "'2019-01-01中文'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -36,10 +37,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "4294967297", - "valueType" : "TIMESTAMP" + "value" : "'4294967297'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -58,10 +60,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-01-01T12:30:4294967297.123456", - "valueType" : "TIMESTAMP" + "value" : "'2021-01-01T12:30:4294967297.123456'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index a326e009af..e4583e0d7e 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -14,10 +14,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2019-01-01中文", - "valueType" : "TIMESTAMP" + "value" : "'2019-01-01中文'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -36,10 +37,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "4294967297", - "valueType" : "TIMESTAMP" + "value" : "'4294967297'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -58,10 +60,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-01-01T12:30:4294967297.123456", - "valueType" : "TIMESTAMP" + "value" : "'2021-01-01T12:30:4294967297.123456'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out index 2427356000..c174ff4853 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out @@ -14,10 +14,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2019-01-01中文", - "valueType" : "TIMESTAMP" + "value" : "'2019-01-01中文'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -36,10 +37,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "4294967297", - "valueType" : "TIMESTAMP" + "value" : "'4294967297'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", @@ -58,10 +60,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0019", + "errorClass" : "INVALID_TYPED_LITERAL", + "sqlState" : "42000", "messageParameters" : { - "value" : "2021-01-01T12:30:4294967297.123456", - "valueType" : "TIMESTAMP" + "value" : "'2021-01-01T12:30:4294967297.123456'", + "valueType" : "\"TIMESTAMP\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out index 27abeaf859..3f40a18181 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out @@ -92,7 +92,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "tinyint", "hint" : "", - 
"operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -118,7 +118,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "tinyint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -144,7 +144,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "tinyint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -170,7 +170,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "tinyint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -268,7 +268,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "smallint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -294,7 +294,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "smallint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -320,7 +320,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "smallint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -346,7 +346,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "smallint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -444,7 +444,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "int", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -470,7 +470,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "int", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -496,7 +496,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "int", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -522,7 +522,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "int", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -620,7 +620,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "bigint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -646,7 +646,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "bigint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -672,7 +672,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "bigint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -698,7 +698,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "bigint", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -796,7 +796,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "float", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -822,7 +822,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "float", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -848,7 +848,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "float", "hint" : "", - "operator" : 
"Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -874,7 +874,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "float", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -972,7 +972,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "double", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -998,7 +998,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "double", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1024,7 +1024,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "double", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1050,7 +1050,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "double", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1148,7 +1148,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "decimal(10,0)", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1174,7 +1174,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "decimal(10,0)", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1200,7 +1200,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "decimal(10,0)", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1226,7 +1226,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "decimal(10,0)", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1324,7 +1324,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "string", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1350,7 +1350,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "string", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1394,7 +1394,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "tinyint", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1420,7 +1420,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "smallint", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1446,7 +1446,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "int", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1472,7 +1472,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "bigint", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1498,7 +1498,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "float", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1524,7 +1524,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "double", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1550,7 +1550,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "decimal(10,0)", "dt2" 
: "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1576,7 +1576,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "string", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1611,7 +1611,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1637,7 +1637,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1663,7 +1663,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "binary", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1689,7 +1689,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "tinyint", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1715,7 +1715,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "smallint", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1741,7 +1741,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "int", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1767,7 +1767,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "bigint", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1793,7 +1793,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "float", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1819,7 +1819,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "double", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1845,7 +1845,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "decimal(10,0)", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1871,7 +1871,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "string", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1897,7 +1897,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1931,7 +1931,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "timestamp", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1957,7 +1957,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "date", "dt2" : "boolean", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -1983,7 +1983,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "tinyint", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2009,7 +2009,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "smallint", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2035,7 +2035,7 @@ 
org.apache.spark.sql.AnalysisException "dt1" : "int", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2061,7 +2061,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "bigint", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2087,7 +2087,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "float", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2113,7 +2113,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "double", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2139,7 +2139,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "decimal(10,0)", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2174,7 +2174,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2200,7 +2200,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "timestamp", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2244,7 +2244,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "tinyint", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2270,7 +2270,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "smallint", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2296,7 +2296,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "int", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2322,7 +2322,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "bigint", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2348,7 +2348,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "float", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2374,7 +2374,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "double", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2400,7 +2400,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "decimal(10,0)", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2435,7 +2435,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "binary", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { @@ -2461,7 +2461,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "boolean", "dt2" : "date", "hint" : "", - "operator" : "Union", + "operator" : "UNION", "ti" : "second" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out index 5b2754944b..ac1b1ac417 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out @@ -145,7 +145,7 @@ 
org.apache.spark.sql.AnalysisException "dt1" : "array", "dt2" : "int", "hint" : "", - "operator" : "ExceptAll", + "operator" : "EXCEPT ALL", "ti" : "second" }, "queryContext" : [ { @@ -230,10 +230,10 @@ org.apache.spark.sql.AnalysisException { "errorClass" : "NUM_COLUMNS_MISMATCH", "messageParameters" : { + "firstNumColumns" : "1", "invalidNumColumns" : "2", "invalidOrdinalNum" : "second", - "operator" : "EXCEPTALL", - "refNumColumns" : "1" + "operator" : "EXCEPT ALL" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out index 11bcc0f6bd..d359698366 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out @@ -102,7 +102,7 @@ org.apache.spark.sql.AnalysisException "dt1" : "array", "dt2" : "int", "hint" : "", - "operator" : "IntersectAll", + "operator" : "INTERSECT ALL", "ti" : "second" }, "queryContext" : [ { @@ -126,10 +126,10 @@ org.apache.spark.sql.AnalysisException { "errorClass" : "NUM_COLUMNS_MISMATCH", "messageParameters" : { + "firstNumColumns" : "1", "invalidNumColumns" : "2", "invalidOrdinalNum" : "second", - "operator" : "INTERSECTALL", - "refNumColumns" : "1" + "operator" : "INTERSECT ALL" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index dda9390f4b..994dfb4d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -530,9 +530,10 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { errorClass = "NUM_COLUMNS_MISMATCH", parameters = Map( "operator" -> "UNION", - "refNumColumns" -> "2", + "firstNumColumns" -> "2", "invalidOrdinalNum" -> "second", - "invalidNumColumns" -> "3")) + "invalidNumColumns" -> "3") + ) df1 = Seq((1, 2, 3)).toDF("a", "b", "c") df2 = Seq((4, 5, 6)).toDF("a", "c", "d") @@ -1011,7 +1012,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val errMsg = intercept[AnalysisException] { df1.unionByName(df2) }.getMessage - assert(errMsg.contains("Union can only be performed on tables with" + + assert(errMsg.contains("UNION can only be performed on tables with" + " compatible column types." 
+ " The third column of the second table is struct>" + " type which is not compatible with struct> at the same" + @@ -1095,7 +1096,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val err = intercept[AnalysisException](df7.union(df8).collect()) assert(err.message - .contains("Union can only be performed on tables with compatible column types")) + .contains("UNION can only be performed on tables with compatible column types")) } test("SPARK-36546: Add unionByName support to arrays of structs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 0740e1b2bd..028ab8ea14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -49,6 +49,19 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { sql("""SELECT CURDATE()""").collect().head.getDate(0)) val d4 = DateTimeUtils.currentDate(ZoneId.systemDefault()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 <= d4 && d4 - d0 <= 1) + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT CURDATE(1)") + }, + errorClass = "WRONG_NUM_ARGS", + parameters = Map( + "functionName" -> "`curdate`", + "expectedNum" -> "0", + "actualNum" -> "1" + ), + context = ExpectedContext("", "", 7, 16, "CURDATE(1)") + ) } test("function current_timestamp and now") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 50c3b8fbf4..b5353455dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -462,17 +462,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite withTempDir { dir => Seq("parquet", "orc", "csv", "json").foreach { fmt => val basePath = dir.getCanonicalPath + "/" + fmt - val pushFilterMaps = Map ( - "parquet" -> - "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", - "orc" -> - "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", - "csv" -> - "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", - "json" -> - "|remove_marker" - ) - val expected_plan_fragment1 = + + val expectedPlanFragment = s""" |\\(1\\) BatchScan $fmt file:$basePath |Output \\[2\\]: \\[value#x, id#x\\] @@ -480,9 +471,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite |Format: $fmt |Location: InMemoryFileIndex\\([0-9]+ paths\\)\\[.*\\] |PartitionFilters: \\[isnotnull\\(id#x\\), \\(id#x > 1\\)\\] - ${pushFilterMaps.get(fmt).get} + |PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\] |ReadSchema: struct\\ - |""".stripMargin.replaceAll("\nremove_marker", "").trim + |""".stripMargin.trim spark.range(10) .select(col("id"), col("id").as("value")) @@ -500,7 +491,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite .format(fmt) .load(basePath).where($"id" > 1 && $"value" > 2) val normalizedOutput = getNormalizedExplain(df, FormattedMode) - assert(expected_plan_fragment1.r.findAllMatchIn(normalizedOutput).length == 1) + assert(expectedPlanFragment.r.findAllMatchIn(normalizedOutput).length == 1) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala new 
file mode 100644 index 0000000000..abeb3bb784 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { + protected val testTable: String = "employee" + + override def beforeAll(): Unit = { + super.beforeAll() + sql( + s""" + |CREATE TABLE $testTable ( + | dept INTEGER, + | name String, + | salary INTEGER, + | bonus INTEGER, + | properties STRUCT) + |USING orc + |""".stripMargin) + sql( + s""" + |INSERT INTO $testTable VALUES + | (1, 'amy', 10000, 1000, named_struct('joinYear', 2019, 'mostRecentEmployer', 'A')), + | (2, 'alex', 12000, 1200, named_struct('joinYear', 2017, 'mostRecentEmployer', 'A')), + | (1, 'cathy', 9000, 1200, named_struct('joinYear', 2020, 'mostRecentEmployer', 'B')), + | (2, 'david', 10000, 1300, named_struct('joinYear', 2019, 'mostRecentEmployer', 'C')), + | (6, 'jen', 12000, 1200, named_struct('joinYear', 2018, 'mostRecentEmployer', 'D')) + |""".stripMargin) + } + + override def afterAll(): Unit = { + try { + sql(s"DROP TABLE IF EXISTS $testTable") + } finally { + super.afterAll() + } + } + + val lcaEnabled: Boolean = true + // by default the tests in this suites run with LCA on + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { + testFun + } + } + } + // mark special testcases test both LCA on and off + protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*)(testFun) + } + + private def withLCAOff(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { + f + } + } + private def withLCAOn(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { + f + } + } + + testOnAndOff("Lateral alias basics - Project") { + def checkAnswerWhenOnAndExceptionWhenOff(query: String, expectedAnswerLCAOn: Row): Unit = { + withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } + withLCAOff { + assert(intercept[AnalysisException]{ sql(query) } + .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + } + + checkAnswerWhenOnAndExceptionWhenOff( + s"select dept as d, d + 1 as e from $testTable where name = 'amy'", + 
Row(1, 2)) + + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'", + Row(20000, 21000)) + checkAnswerWhenOnAndExceptionWhenOff( + s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + + s" where name = 'amy'", + Row(20000, 22000)) + + checkAnswerWhenOnAndExceptionWhenOff( + "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + + s"new_income from $testTable where name = 'amy'", + Row(20000, 23000)) + + // should referring to the previously defined LCA + checkAnswerWhenOnAndExceptionWhenOff( + s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'", + Row(18000, 18000, 10000) + ) + } + + test("Duplicated lateral alias names - Project") { + def checkDuplicatedAliasErrorHelper(query: String, parameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] {sql(query)}, + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + sqlState = "42000", + parameters = parameters + ) + } + + // Has duplicated names but not referenced is fine + checkAnswer( + sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 1200) + ) + checkAnswer( + sql(s"SELECT salary AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(12000, 12000, 10000) + ) + checkAnswer( + sql(s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'"), + Row(18000, 18000, 10000) + ) + checkAnswer( + sql(s"SELECT salary + 1000 AS new_salary, new_salary * 1.0 AS new_salary " + + s"FROM $testTable WHERE name = 'jen'"), + Row(13000, 13000.0)) + + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, 10000 AS d, d + 1 FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, salary * 1.5 AS d, d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary AS d, d + 1 AS d, d + 1 AS d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT salary * 1.5 AS d, d, bonus * 1.5 AS d, d + d FROM $testTable", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + + checkAnswer( + sql( + s""" + |SELECT salary * 1.5 AS salary, salary, 10000 AS salary, salary + |FROM $testTable + |WHERE name = 'jen' + |""".stripMargin), + Row(18000, 12000, 10000, 12000) + ) + } + + test("Lateral alias conflicts with table column - Project") { + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'"), + Row(20000, 22000, 11000, 22000)) + + checkAnswer( + sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row(2022), 2019)) + + checkAnswer( + sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row("someone"), "amy")) + } + + testOnAndOff("Lateral 
alias conflicts with OuterReference - Project") { + // an attribute can both be resolved as LCA and OuterReference + val query1 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, id + 1 AS id2)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { checkAnswer(sql(query1), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query1), Seq.empty) } + + // an attribute can only be resolved as LCA + val query2 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id1, id1 + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { + assert(intercept[AnalysisException] { sql(query2) } + .getErrorClass == "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") + } + withLCAOn { checkAnswer(sql(query2), Seq.empty) } + + // an attribute should only be resolved as OuterReference + val query3 = + s""" + |SELECT * + |FROM range(1, 7) outer_table + |WHERE ( + | SELECT id2 + | FROM (SELECT 1 AS id, outer_table.id + 1 AS id2)) > 5 + |""".stripMargin + withLCAOff { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + withLCAOn { checkAnswer(sql(query3), Row(5) :: Row(6) :: Nil) } + + // a bit complex subquery that the id + 1 is first wrapped with OuterReference + // test if lca rule strips the OuterReference and resolves to lateral alias + val query4 = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 + |ORDER BY id + |""".stripMargin + withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. + withLCAOn { + val analyzedPlan = sql(query4).queryExecution.analyzed + assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) + // but running it triggers exception + // checkAnswer(sql(query4), Range(1, 7).map(Row(_))) + } + } + // TODO: more tests on LCA in subquery + + test("Lateral alias of a complex type - Project") { + checkAnswer( + sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), + Row(Row(1), 2, 3)) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), + Row(Row(Row(1)), 2) + ) + + checkAnswer( + sql("SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1"), + Row(Seq(1, 2, 3), 2, 3) + ) + checkAnswer( + sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar"), + Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101) + ) + checkAnswer( + sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar"), + Row(Seq(Row(1), Row(2)), 2) + ) + + checkAnswer( + sql("SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1"), + Row(Map("a" -> 1, "b" -> 2), 2, 3) + ) + } + + test("Lateral alias reference attribute further be used by upper plan - Project") { + // this is out of the scope of lateral alias project functionality requirements, but naturally + // supported by the current design + checkAnswer( + sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + + s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), + Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil + ) + } + + test("Lateral alias chaining - Project") { + checkAnswer( + sql( + s""" + |SELECT bonus * 1.1 AS new_bonus, salary + new_bonus AS new_base, + | new_base * 1.1 AS new_total, new_total - new_base AS r, + | new_total - r + |FROM $testTable WHERE name = 'cathy' + |""".stripMargin), + Row(1320, 10320, 11352, 1032, 10320) + ) + + checkAnswer( + sql("SELECT 1 AS a, a + 1 AS b, b - 1, b + 1 AS c, c + 1 
AS d, d - a AS e, e + 1"), + Row(1, 2, 1, 3, 4, 3, 4) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d1e6a5df16..efcb501d95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2690,10 +2690,24 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))"), Row(Row(1)) :: Row(Row(2)) :: Nil) - val m2 = intercept[AnalysisException] { - sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") - }.message - assert(m2.contains("Except can only be performed on tables with compatible column types")) + checkError( + exception = intercept[AnalysisException] { + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + }, + errorClass = "_LEGACY_ERROR_TEMP_2430", + parameters = Map( + "operator" -> "EXCEPT", + "dt1" -> "struct", + "dt2" -> "struct", + "hint" -> "", + "ci" -> "first", + "ti" -> "second" + ), + context = ExpectedContext( + fragment = "SELECT struct(1 a) EXCEPT (SELECT struct(2 A))", + start = 0, + stop = 45) + ) withTable("t", "S") { sql("CREATE TABLE t(c struct) USING parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index d4077274d5..95d9245c57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -124,16 +124,29 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared Seq(ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3)))).toDF().write.saveAsTable(tableName) // Test unsupported data types - val err1 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") - } - assert(err1.message.contains("does not support statistics collection")) + checkError( + exception = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS data") + }, + errorClass = "_LEGACY_ERROR_TEMP_1235", + parameters = Map( + "name" -> "data", + "tableIdent" -> "`spark_catalog`.`default`.`column_stats_test1`", + "dataType" -> "ArrayType(IntegerType,true)" + ) + ) // Test invalid columns - val err2 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column") - } - assert(err2.message.contains("does not exist")) + checkError( + exception = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS some_random_column") + }, + errorClass = "COLUMN_NOT_FOUND", + parameters = Map( + "colName" -> "`some_random_column`", + "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" + ) + ) } } @@ -581,10 +594,13 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared withTempView("tempView") { // Analyzes in a temporary view sql("CREATE TEMPORARY VIEW tempView AS SELECT 1 id") - val errMsg = intercept[AnalysisException] { - sql("ANALYZE TABLE tempView COMPUTE STATISTICS FOR COLUMNS id") - }.getMessage - assert(errMsg.contains("Temporary view `tempView` is not cached for analyzing columns")) + checkError( + exception = intercept[AnalysisException] { + sql("ANALYZE TABLE tempView COMPUTE STATISTICS FOR COLUMNS id") + }, + errorClass = 
"UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + parameters = Map("viewName" -> "`tempView`") + ) // Cache the view then analyze it sql("CACHE TABLE tempView") @@ -604,11 +620,13 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared ExpectedContext(s"$globalTempDB.gTempView", 14, 13 + s"$globalTempDB.gTempView".length)) // Analyzes in a global temporary view sql("CREATE GLOBAL TEMP VIEW gTempView AS SELECT 1 id") - val errMsg2 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $globalTempDB.gTempView COMPUTE STATISTICS FOR COLUMNS id") - }.getMessage - assert(errMsg2.contains( - s"Temporary view `$globalTempDB`.`gTempView` is not cached for analyzing columns")) + checkError( + exception = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $globalTempDB.gTempView COMPUTE STATISTICS FOR COLUMNS id") + }, + errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + parameters = Map("viewName" -> "`global_temp`.`gTempView`") + ) // Cache the view then analyze it sql(s"CACHE TABLE $globalTempDB.gTempView") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index d99c170fae..03b42a760e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchNamespaceException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ResolveDefaultColumns} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME @@ -47,7 +49,7 @@ import org.apache.spark.unsafe.types.UTF8String abstract class DataSourceV2SQLSuite extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true) - with DeleteFromTests with DatasourceV2SQLBase { + with DeleteFromTests with DatasourceV2SQLBase with StatsEstimationTestBase { protected val v2Source = classOf[FakeV2Provider].getName override protected val v2Format = v2Source @@ -2779,17 +2781,16 @@ class DataSourceV2SQLSuiteV1Filter " (4, null), (5, 'test5')") val df = spark.sql("select * from testcat.test") + val expectedColumnStats = Seq( + "id" -> ColumnStat(Some(5), None, None, Some(0), None, None, None, 2), + "data" -> ColumnStat(Some(3), None, None, Some(3), None, None, None, 2)) df.queryExecution.optimizedPlan.collect { case scan: DataSourceV2ScanRelation => val stats = scan.stats assert(stats.sizeInBytes == 200) assert(stats.rowCount.get == 5) - val colStats = stats.attributeStats.values.toArray - assert(colStats.length == 2) - assert(colStats(0).distinctCount.get == 3) - assert(colStats(0).nullCount.get == 3) - assert(colStats(1).distinctCount.get == 5) - assert(colStats(1).nullCount.get == 0) + assert(stats.attributeStats == + toAttributeMap(expectedColumnStats, df.queryExecution.optimizedPlan)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala 
index d3b43059d3..833d9c3c7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -218,10 +218,13 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e4.contains( s"$viewName is a temp view. 'ANALYZE TABLE' expects a table or permanent view.")) - val e5 = intercept[AnalysisException] { - sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") - }.getMessage - assert(e5.contains(s"Temporary view `$viewName` is not cached for analyzing columns.")) + checkError( + exception = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $viewName COMPUTE STATISTICS FOR COLUMNS id") + }, + errorClass = "UNSUPPORTED_FEATURE.ANALYZE_UNCACHED_TEMP_VIEW", + parameters = Map("viewName" -> "`testView`") + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 81bce35a58..88baf76ba7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -2156,8 +2156,15 @@ class AdaptiveQueryExecSuite assert(aqeReads.length == 2) aqeReads.foreach { c => val stats = c.child.asInstanceOf[QueryStageExec].getRuntimeStatistics - assert(stats.sizeInBytes >= 0) + val rowCount = stats.rowCount.get assert(stats.rowCount.get >= 0) + if (rowCount == 0) { + // For empty relation, the query stage doesn't serialize any bytes. + // The SQLMetric keeps initial value. + assert(stats.sizeInBytes == -1) + } else { + assert(stats.sizeInBytes > 0) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala index 9bee8d38c0..84da38f509 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/DescribeTableSuite.scala @@ -59,10 +59,16 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase |CREATE TABLE $tbl |(key int COMMENT 'column_comment', col struct) |$defaultUsing""".stripMargin) - val errMsg = intercept[AnalysisException] { - sql(s"DESC $tbl key1").collect() - }.getMessage - assert(errMsg === "Column key1 does not exist.") + checkError( + exception = intercept[AnalysisException] { + sql(s"DESC $tbl key1").collect() + }, + errorClass = "COLUMN_NOT_FOUND", + parameters = Map( + "colName" -> "`key1`", + "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" + ) + ) } } @@ -79,10 +85,16 @@ trait DescribeTableSuiteBase extends command.DescribeTableSuiteBase withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { withNamespaceAndTable("ns", "tbl") { tbl => sql(s"CREATE TABLE $tbl (key int COMMENT 'comment1') $defaultUsing") - val errMsg = intercept[AnalysisException] { - sql(s"DESC $tbl KEY").collect() - }.getMessage - assert(errMsg === "Column KEY does not exist.") + checkError( + exception = intercept[AnalysisException] { + sql(s"DESC $tbl KEY").collect() + }, + errorClass = "COLUMN_NOT_FOUND", + parameters = Map( + "colName" -> "`KEY`", + "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" + ) + ) } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala index b7d0a7fc30..739f4c440b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/CommonFileDataSourceSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources -import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.sql.{Dataset, Encoders, FakeFileSystemRequiringDSOption, SparkSession} import org.apache.spark.sql.catalyst.plans.SQLHelper @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper * The tests that are not applicable to all file-based data sources should be placed to * [[org.apache.spark.sql.FileBasedDataSourceSuite]]. */ -trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite => +trait CommonFileDataSourceSuite extends SQLHelper { + self: AnyFunSuite => // scalastyle:ignore funsuite protected def spark: SparkSession protected def dataSourceFormat: String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala index a68d9b951b..e8fae210fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -347,7 +347,8 @@ trait FileSourceAggregatePushDownSuite spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test") Seq("false", "true").foreach { enableVectorizedReader => withSQLConf(aggPushDownEnabledKey -> "true", - vectorizedReaderEnabledKey -> enableVectorizedReader) { + vectorizedReaderEnabledKey -> enableVectorizedReader, + SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") { val testMinWithAllTypes = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 07b35713fe..1f20fb62d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -785,7 +785,11 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils testMetricsInSparkPlanOperator(exchanges.head, Map("dataSize" -> 3200, "shuffleRecordsWritten" -> 100)) - testMetricsInSparkPlanOperator(exchanges(1), Map("dataSize" -> 0, "shuffleRecordsWritten" -> 0)) + // `testData2.filter($"b" === 0)` is an empty relation. + // The exchange doesn't serialize any bytes. + // The SQLMetric keeps initial value. 
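// Aside on the SQLMetricsSuite change around this point (a minimal sketch, not Spark's actual
// SQLMetric class): the expected dataSize is -1 rather than 0 because size metrics start from a
// sentinel initial value, so "never updated" (an exchange over an empty relation serializes
// nothing) stays distinguishable from "updated with zero bytes". The class below is hypothetical.
object SentinelMetricSketch extends App {
  final class SizeMetric {
    private var current: Long = -1L   // sentinel: no bytes ever recorded
    def add(bytes: Long): Unit = {
      if (current < 0) current = 0L   // the first update clears the sentinel
      current += bytes
    }
    def value: Long = current
  }

  val untouched = new SizeMetric      // empty relation: add() is never called
  assert(untouched.value == -1L)      // keeps its initial value, matching the -1 expected below

  val zeroed = new SizeMetric
  zeroed.add(0L)
  assert(zeroed.value == 0L)          // an explicit zero update differs from "never updated"
}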
+ testMetricsInSparkPlanOperator(exchanges(1), + Map("dataSize" -> -1, "shuffleRecordsWritten" -> 0)) } test("Add numRows to metric of BroadcastExchangeExec") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 00f9c7b8c0..dd426b8e92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -116,7 +116,11 @@ class RocksDBSuite extends SparkFunSuite { withDB(remoteDir, conf = conf) { db => // Generate versions without cleaning up for (version <- 1 to 50) { - db.put(version.toString, version.toString) // update "1" -> "1", "2" -> "2", ... + if (version > 1) { + // remove keys we wrote in previous iteration to ensure compaction happens + db.remove((version - 1).toString) + } + db.put(version.toString, version.toString) db.commit() } @@ -132,7 +136,7 @@ class RocksDBSuite extends SparkFunSuite { versionsPresent.foreach { version => db.load(version) val data = db.iterator().map(toStr).toSet - assert(data === (1L to version).map(_.toString).map(x => x -> x).toSet) + assert(data === Set((version.toString, version.toString))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0395798d9e..8ee6da10b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.vectorized import java.nio.ByteBuffer import java.nio.ByteOrder import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} import java.time.LocalDateTime import java.util -import java.util.NoSuchElementException import scala.collection.JavaConverters._ import scala.collection.mutable @@ -1379,12 +1379,26 @@ class ColumnarBatchSuite extends SparkFunSuite { "Seed = " + seed) case DoubleType => assert(doubleEquals(r1.getDouble(ordinal), r2.getDouble(ordinal)), "Seed = " + seed) + case DateType => + assert(r1.getInt(ordinal) == DateTimeUtils.fromJavaDate(r2.getDate(ordinal)), + "Seed = " + seed) + case TimestampType => + assert(r1.getLong(ordinal) == + DateTimeUtils.fromJavaTimestamp(r2.getTimestamp(ordinal)), + "Seed = " + seed) + case TimestampNTZType => + assert(r1.getLong(ordinal) == + DateTimeUtils.localDateTimeToMicros(r2.getAs[LocalDateTime](ordinal)), + "Seed = " + seed) case t: DecimalType => val d1 = r1.getDecimal(ordinal, t.precision, t.scale).toBigDecimal val d2 = r2.getDecimal(ordinal) assert(d1.compare(d2) == 0, "Seed = " + seed) case StringType => assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed) + case BinaryType => + assert(r1.getBinary(ordinal) sameElements r2.getAs[Array[Byte]](ordinal), + "Seed = " + seed) case CalendarIntervalType => assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval]) case ArrayType(childType, n) => @@ -1406,6 +1420,50 @@ class ColumnarBatchSuite extends SparkFunSuite { "Seed = " + seed) i += 1 } + case StringType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val s1 = a1(i).asInstanceOf[UTF8String].toString + val s2 = 
a2(i).asInstanceOf[String] + assert(s1 === s2, "Seed = " + seed) + } + i += 1 + } + case DateType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val i1 = a1(i).asInstanceOf[Int] + val i2 = DateTimeUtils.fromJavaDate(a2(i).asInstanceOf[Date]) + assert(i1 === i2, "Seed = " + seed) + } + i += 1 + } + case TimestampType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val i1 = a1(i).asInstanceOf[Long] + val i2 = DateTimeUtils.fromJavaTimestamp(a2(i).asInstanceOf[Timestamp]) + assert(i1 === i2, "Seed = " + seed) + } + i += 1 + } + case TimestampNTZType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val i1 = a1(i).asInstanceOf[Long] + val i2 = DateTimeUtils.localDateTimeToMicros(a2(i).asInstanceOf[LocalDateTime]) + assert(i1 === i2, "Seed = " + seed) + } + i += 1 + } case t: DecimalType => var i = 0 while (i < a1.length) { @@ -1457,12 +1515,12 @@ class ColumnarBatchSuite extends SparkFunSuite { * results. */ def testRandomRows(flatSchema: Boolean, numFields: Int): Unit = { - // TODO: Figure out why StringType doesn't work on jenkins. val types = Array( BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType, DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal, DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2), - new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType) + new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType, + DateType, StringType, BinaryType, TimestampType, TimestampNTZType) val seed = System.nanoTime() val NUM_ROWS = 200 val NUM_ITERS = 1000 diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java index 8ee606be31..32cc42f008 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java @@ -242,7 +242,7 @@ public boolean isStopped() { } /** This is where the log message will go to */ - private final CharArrayWriter writer = new CharArrayWriter(); + private final CharArrayWriter writer; private static StringLayout getLayout(boolean isVerbose, StringLayout lo) { if (isVerbose) { @@ -276,12 +276,19 @@ private static StringLayout initLayout(OperationLog.LoggingLevel loggingMode) { return getLayout(isVerbose, layout); } - public LogDivertAppender(OperationManager operationManager, + public static LogDivertAppender create(OperationManager operationManager, OperationLog.LoggingLevel loggingMode) { + CharArrayWriter writer = new CharArrayWriter(); + return new LogDivertAppender(operationManager, loggingMode, writer); + } + + private LogDivertAppender(OperationManager operationManager, + OperationLog.LoggingLevel loggingMode, CharArrayWriter writer) { super("LogDivertAppender", initLayout(loggingMode), null, false, true, Property.EMPTY_ARRAY, - new WriterManager(new CharArrayWriter(), "LogDivertAppender", + new WriterManager(writer, "LogDivertAppender", initLayout(loggingMode), true)); + this.writer = writer; this.isVerbose = (loggingMode == OperationLog.LoggingLevel.VERBOSE); this.operationManager = operationManager; addFilter(new 
NameFilter(loggingMode, operationManager)); @@ -301,7 +308,7 @@ public void append(LogEvent event) { isVerbose = isCurrModeVerbose; } } - + super.append(event); // That should've gone into our writer. Notify the LogContext. String logOutput = writer.toString(); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java index 6ee48186e7..bb68c84049 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java @@ -82,7 +82,7 @@ public synchronized void stop() { private void initOperationLogCapture(String loggingMode) { // Register another Appender (with the same layout) that talks to us. - Appender ap = new LogDivertAppender(this, OperationLog.getLoggingLevel(loggingMode)); + Appender ap = LogDivertAppender.create(this, OperationLog.getLoggingLevel(loggingMode)); ((org.apache.logging.log4j.core.Logger)org.apache.logging.log4j.LogManager.getRootLogger()).addAppender(ap); ap.start(); } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 4d71ce0e49..9304074e86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -108,7 +108,8 @@ package object client { "org.apache.calcite.avatica:avatica", "com.fasterxml.jackson.core:*", "org.apache.curator:*", - "org.pentaho:pentaho-aggdesigner-algorithm")) + "org.pentaho:pentaho-aggdesigner-algorithm", + "org.apache.hive:hive-vector-code-gen")) // Since Hive 3.0, HookUtils uses org.apache.logging.log4j.util.Strings // Since HIVE-14496, Hive.java uses calcite-core @@ -117,7 +118,8 @@ package object client { "org.apache.derby:derby:10.14.1.0"), exclusions = Seq("org.apache.calcite:calcite-druid", "org.apache.curator:*", - "org.pentaho:pentaho-aggdesigner-algorithm")) + "org.pentaho:pentaho-aggdesigner-algorithm", + "org.apache.hive:hive-vector-code-gen")) // Since Hive 3.0, HookUtils uses org.apache.logging.log4j.util.Strings // Since HIVE-14496, Hive.java uses calcite-core @@ -126,7 +128,8 @@ package object client { "org.apache.derby:derby:10.14.1.0"), exclusions = Seq("org.apache.calcite:calcite-druid", "org.apache.curator:*", - "org.pentaho:pentaho-aggdesigner-algorithm")) + "org.pentaho:pentaho-aggdesigner-algorithm", + "org.apache.hive:hive-vector-code-gen")) val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0, v2_1, v2_2, v2_3, v3_0, v3_1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 4b69a01834..a03120ca44 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.metrics.source.HiveCatalogMetrics -import org.apache.spark.sql._ +import org.apache.spark.sql.{AnalysisException, _} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, 
CatalogStatistics, HiveTableRelation} @@ -43,7 +43,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils - class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("size estimation for relations is based on row size * number of rows") { @@ -582,6 +581,24 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("analyze not found column") { + val tableName = "analyzeTable" + withTable(tableName) { + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + checkError( + exception = intercept[AnalysisException] { + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS fakeColumn") + }, + errorClass = "COLUMN_NOT_FOUND", + parameters = Map( + "colName" -> "`fakeColumn`", + "caseSensitiveConfig" -> "\"spark.sql.caseSensitive\"" + ) + ) + } + } + test("analyze non-existent partition") { def assertAnalysisException(analyzeCommand: String, errorMessage: String): Unit = {