diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 0745ddc09911..23c2315400ff 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -257,11 +257,17 @@ class SparkConnectResultSet(
   override def getCharacterStream(columnLabel: String): Reader =
     throw new SQLFeatureNotSupportedException
 
-  override def getBigDecimal(columnIndex: Int): java.math.BigDecimal =
-    throw new SQLFeatureNotSupportedException
+  override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
+    if (currentRow.isNullAt(columnIndex - 1)) {
+      _wasNull = true
+      return null
+    }
+    _wasNull = false
+    currentRow.getDecimal(columnIndex - 1)
+  }
 
   override def getBigDecimal(columnLabel: String): java.math.BigDecimal =
-    throw new SQLFeatureNotSupportedException
+    getBigDecimal(findColumn(columnLabel))
 
   override def isBeforeFirst: Boolean = {
     checkOpen()
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala
index 55e3d29c99a5..c2b27128caa7 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.client.jdbc.util
 
 import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat,
   Long => JLong, Short => JShort}
+import java.math.{BigDecimal => JBigDecimal}
 import java.sql.{Array => _, _}
 
 import org.apache.spark.sql.types._
@@ -34,6 +35,7 @@ private[jdbc] object JdbcTypeUtils {
     case FloatType => Types.FLOAT
     case DoubleType => Types.DOUBLE
     case StringType => Types.VARCHAR
+    case _: DecimalType => Types.DECIMAL
     case other =>
       throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
   }
@@ -48,12 +50,14 @@ private[jdbc] object JdbcTypeUtils {
     case FloatType => classOf[JFloat].getName
     case DoubleType => classOf[JDouble].getName
     case StringType => classOf[String].getName
+    case _: DecimalType => classOf[JBigDecimal].getName
     case other =>
       throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
   }
 
   def isSigned(field: StructField): Boolean = field.dataType match {
-    case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
+    case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+        _: DecimalType => true
     case NullType | BooleanType | StringType => false
     case other =>
       throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
@@ -69,6 +73,7 @@ private[jdbc] object JdbcTypeUtils {
     case FloatType => 7
     case DoubleType => 15
     case StringType => 255
+    case DecimalType.Fixed(p, _) => p
     case other =>
       throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
   }
@@ -77,6 +82,7 @@ private[jdbc] object JdbcTypeUtils {
     case FloatType => 7
     case DoubleType => 15
     case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | StringType => 0
+    case DecimalType.Fixed(_, s) => s
     case other =>
       throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
   }
@@ -90,6 +96,12 @@ private[jdbc] object JdbcTypeUtils {
     case DoubleType => 24
     case StringType => getPrecision(field)
+    // precision + negative sign + leading zero + decimal point, like DECIMAL(5,5) = -0.12345
+    case DecimalType.Fixed(p, s) if p == s => p + 3
+    // precision + negative sign, like DECIMAL(5,0) = -12345
+    case DecimalType.Fixed(p, s) if s == 0 => p + 1
+    // precision + negative sign + decimal point, like DECIMAL(5,2) = -123.45
+    case DecimalType.Fixed(p, _) => p + 2
     case other =>
       throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
   }
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
index 619b279310eb..089c1d7fdf0d 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -215,4 +215,37 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
       assert(metaData.getColumnDisplaySize(1) === 255)
     }
   }
+
+  test("get decimal type") {
+    Seq(
+      ("123.45", 37, 2, 39),
+      ("-0.12345", 5, 5, 8),
+      ("-0.12345", 6, 5, 8),
+      ("-123.45", 5, 2, 7),
+      ("12345", 5, 0, 6),
+      ("-12345", 5, 0, 6)
+    ).foreach {
+      case (value, precision, scale, expectedColumnDisplaySize) =>
+        val decimalType = s"DECIMAL($precision,$scale)"
+        withExecuteQuery(s"SELECT cast('$value' as $decimalType)") { rs =>
+          assert(rs.next())
+          assert(rs.getBigDecimal(1) === new java.math.BigDecimal(value))
+          assert(!rs.wasNull)
+          assert(!rs.next())
+
+          val metaData = rs.getMetaData
+          assert(metaData.getColumnCount === 1)
+          assert(metaData.getColumnName(1) === s"CAST($value AS $decimalType)")
+          assert(metaData.getColumnLabel(1) === s"CAST($value AS $decimalType)")
+          assert(metaData.getColumnType(1) === Types.DECIMAL)
+          assert(metaData.getColumnTypeName(1) === decimalType)
+          assert(metaData.getColumnClassName(1) === "java.math.BigDecimal")
+          assert(metaData.isSigned(1) === true)
+          assert(metaData.getPrecision(1) === precision)
+          assert(metaData.getScale(1) === scale)
+          assert(metaData.getColumnDisplaySize(1) === expectedColumnDisplaySize)
+          assert(metaData.getColumnDisplaySize(1) >= value.size)
+        }
+    }
+  }
 }
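
For reference, a minimal usage sketch of what this patch enables from plain JDBC. The connection URL, host, and port are assumptions for illustration, not part of the patch; the display-size arithmetic mirrors the `getColumnDisplaySize` rules added above (for DECIMAL(5,2): precision + sign + decimal point = 7).

```scala
// Hypothetical usage sketch; "jdbc:sc://localhost:15002" is an assumed
// Spark Connect endpoint, not something defined by this patch.
import java.sql.DriverManager

val conn = DriverManager.getConnection("jdbc:sc://localhost:15002")
try {
  val rs = conn.createStatement()
    .executeQuery("SELECT CAST('-123.45' AS DECIMAL(5,2))")
  assert(rs.next())
  // getBigDecimal(columnIndex) now returns java.math.BigDecimal
  // instead of throwing SQLFeatureNotSupportedException.
  val v = rs.getBigDecimal(1)
  assert(!rs.wasNull)
  // Metadata reflects the new DecimalType handling in JdbcTypeUtils:
  // DECIMAL(5,2) -> display size 5 (precision) + 1 (sign) + 1 (point) = 7.
  assert(rs.getMetaData.getColumnType(1) == java.sql.Types.DECIMAL)
  assert(rs.getMetaData.getColumnDisplaySize(1) == 7)
} finally {
  conn.close()
}
```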