diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 899fa6139a..6811d6c2b4 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -338,6 +338,7 @@ jobs:
               org.apache.comet.CometCsvExpressionSuite
               org.apache.comet.CometJsonExpressionSuite
               org.apache.comet.CometDateTimeUtilsSuite
+              org.apache.comet.SparkErrorConverterSuite
               org.apache.comet.expressions.conditional.CometIfSuite
               org.apache.comet.expressions.conditional.CometCoalesceSuite
               org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 53001b04e6..8362a6cfba 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -213,6 +213,7 @@ jobs:
               org.apache.comet.CometJsonExpressionSuite
               org.apache.comet.CometCsvExpressionSuite
               org.apache.comet.CometDateTimeUtilsSuite
+              org.apache.comet.SparkErrorConverterSuite
               org.apache.comet.expressions.conditional.CometIfSuite
               org.apache.comet.expressions.conditional.CometCoalesceSuite
               org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 6eee3f5bc0..46aed78b70 100644
--- a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -44,6 +44,24 @@ trait ShimSparkErrorConverter {
   private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
     context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
 
+  /** Parses a FLOAT string, matching inf/infinity/nan case-insensitively. */
+  private def parseFloatLiteral(value: String): Float =
+    value.toLowerCase(java.util.Locale.ROOT) match { // ROOT: never the JVM default locale
+      case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+      case "-inf" | "-infinity" => Float.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Float.NaN
+      case _ => value.toFloat
+    }
+
+  /** Parses a DOUBLE string, matching inf/infinity/nan (also infd/nand) case-insensitively. */
+  private def parseDoubleLiteral(value: String): Double =
+    value.toLowerCase(java.util.Locale.ROOT) match { // ROOT: never the JVM default locale
+      case "inf" | "+inf" | "infinity" | "+infinity" | "infd" => Double.PositiveInfinity
+      case "-inf" | "-infinity" | "-infd" => Double.NegativeInfinity
+      case "nan" | "+nan" | "-nan" | "nand" | "-nand" => Double.NaN
+      case _ => value.toDouble
+    }
+
   def convertErrorType(
       errorType: String,
       errorClass: String,
@@ -207,8 +225,8 @@ trait ShimSparkErrorConverter {
           case LongType =>
             val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
             cleanStr.toLong
-          case FloatType => valueStr.toFloat
-          case DoubleType => valueStr.toDouble
+          case FloatType => parseFloatLiteral(valueStr)
+          case DoubleType => parseDoubleLiteral(valueStr)
           case StringType => UTF8String.fromString(valueStr)
           case _ => valueStr
         }
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 75316c51ef..ba30aa6924 100644
--- a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -44,6 +44,24 @@ trait ShimSparkErrorConverter {
   private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
     context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
 
+  /** Parses a FLOAT string, matching inf/infinity/nan case-insensitively. */
+  private def parseFloatLiteral(value: String): Float =
+    value.toLowerCase(java.util.Locale.ROOT) match { // ROOT: never the JVM default locale
+      case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+      case "-inf" | "-infinity" => Float.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Float.NaN
+      case _ => value.toFloat
+    }
+
+  /** Parses a DOUBLE string, matching inf/infinity/nan (also infd/nand) case-insensitively. */
+  private def parseDoubleLiteral(value: String): Double =
+    value.toLowerCase(java.util.Locale.ROOT) match { // ROOT: never the JVM default locale
+      case "inf" | "+inf" | "infinity" | "+infinity" | "infd" => Double.PositiveInfinity
+      case "-inf" | "-infinity" | "-infd" => Double.NegativeInfinity
+      case "nan" | "+nan" | "-nan" | "nand" | "-nand" => Double.NaN
+      case _ => value.toDouble
+    }
+
   def convertErrorType(
       errorType: String,
       errorClass: String,
@@ -205,8 +223,8 @@ trait ShimSparkErrorConverter {
           case LongType =>
             val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
             cleanStr.toLong
-          case FloatType => valueStr.toFloat
-          case DoubleType => valueStr.toDouble
+          case FloatType => parseFloatLiteral(valueStr)
+          case DoubleType => parseDoubleLiteral(valueStr)
           case StringType => UTF8String.fromString(valueStr)
           case _ => valueStr
         }
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index fc13a58a41..2ed07d0f3e 100644
--- a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -37,6 +37,24 @@ object ShimSparkErrorConverter {
  */
 trait ShimSparkErrorConverter {
 
+  /** Parses a FLOAT string, matching inf/infinity/nan case-insensitively. */
+  private def parseFloatLiteral(value: String): Float =
+    value.toLowerCase(java.util.Locale.ROOT) match { // ROOT: never the JVM default locale
+      case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+      case "-inf" | "-infinity" => Float.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Float.NaN
+      case _ => value.toFloat
+    }
+
+  /** Parses a DOUBLE string, matching inf/infinity/nan (also infd/nand) case-insensitively. */
+  private def parseDoubleLiteral(value: String): Double =
+    value.toLowerCase(java.util.Locale.ROOT) match { // ROOT: never the JVM default locale
+      case "inf" | "+inf" | "infinity" | "+infinity" | "infd" => Double.PositiveInfinity
+      case "-inf" | "-infinity" | "-infd" => Double.NegativeInfinity
+      case "nan" | "+nan" | "-nan" | "nand" | "-nand" => Double.NaN
+      case _ => value.toDouble
+    }
+
   /**
    * Convert error type string and parameters to appropriate Spark exception. Version-specific
    * implementations call the correct QueryExecutionErrors.* methods.
@@ -213,8 +231,8 @@ trait ShimSparkErrorConverter {
             // Strip "L" suffix for BIGINT literals
             val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
             cleanStr.toLong
-          case FloatType => valueStr.toFloat
-          case DoubleType => valueStr.toDouble
+          case FloatType => parseFloatLiteral(valueStr)
+          case DoubleType => parseDoubleLiteral(valueStr)
           case StringType => UTF8String.fromString(valueStr)
           case _ => valueStr // Fallback to string
         }
diff --git a/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala b/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
new file mode 100644
index 0000000000..6530c5da7a
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.scalatest.funsuite.AnyFunSuite
+
+/**
+ * Verifies that `SparkErrorConverter.convertErrorType` turns a CastOverFlow error whose
+ * `value` parameter is a non-finite FLOAT or DOUBLE spelling ("inf", "-inf", "nan") into an
+ * exception other than `NumberFormatException`.
+ */
+class SparkErrorConverterSuite extends AnyFunSuite {
+  /**
+   * Converts a CastOverFlow error for `value` of source type `fromType` (target type INT) and
+   * returns the resulting throwable; the test fails when the converter yields nothing.
+   */
+  private def castOverflowError(fromType: String, value: String): Throwable = {
+    SparkErrorConverter
+      .convertErrorType(
+        "CastOverFlow",
+        "CAST_OVERFLOW",
+        Map("fromType" -> fromType, "toType" -> "INT", "value" -> value),
+        Array.empty,
+        null)
+      .getOrElse(fail("Expected CastOverFlow to be converted to a Spark exception"))
+  }
+
+  test("CastOverFlow conversion handles float inf") {
+    val err = castOverflowError("FLOAT", "inf")
+    assert(!err.isInstanceOf[NumberFormatException])
+    assert(err.getMessage.contains("Infinity"))
+    // contains("Infinity") alone is also satisfied by "-Infinity"; pin the sign explicitly.
+    assert(!err.getMessage.contains("-Infinity"))
+  }
+
+  test("CastOverFlow conversion handles float -inf") {
+    val err = castOverflowError("FLOAT", "-inf")
+    assert(!err.isInstanceOf[NumberFormatException])
+    assert(err.getMessage.contains("-Infinity"))
+  }
+
+  test("CastOverFlow conversion handles double inf") {
+    val err = castOverflowError("DOUBLE", "inf")
+    assert(!err.isInstanceOf[NumberFormatException])
+    assert(err.getMessage.contains("Infinity"))
+    // contains("Infinity") alone is also satisfied by "-Infinity"; pin the sign explicitly.
+    assert(!err.getMessage.contains("-Infinity"))
+  }
+
+  test("CastOverFlow conversion handles double -inf") {
+    val err = castOverflowError("DOUBLE", "-inf")
+    assert(!err.isInstanceOf[NumberFormatException])
+    assert(err.getMessage.contains("-Infinity"))
+  }
+
+  test("CastOverFlow conversion handles float nan") {
+    val err = castOverflowError("FLOAT", "nan")
+    assert(!err.isInstanceOf[NumberFormatException])
+    assert(err.getMessage.toLowerCase.contains("nan"))
+  }
+
+  test("CastOverFlow conversion handles double nan") {
+    val err = castOverflowError("DOUBLE", "nan")
+    assert(!err.isInstanceOf[NumberFormatException])
+    assert(err.getMessage.toLowerCase.contains("nan"))
+  }
+}