From fb9951dd99d8d387d424a19f5ced001bcfe03cf6 Mon Sep 17 00:00:00 2001
From: cty123
Date: Mon, 10 Nov 2025 20:27:05 -0500
Subject: [PATCH 01/12] add input validation to the getter functions.

---
 .../client/jdbc/SparkConnectResultSet.scala   | 87 ++++++-------
 .../jdbc/SparkConnectJdbcDataTypeSuite.scala  | 57 +++++++++++-
 2 files changed, 83 insertions(+), 61 deletions(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 23c2315400ff..8dde19b25a62 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -76,6 +76,23 @@
     }
   }
 
+  private[jdbc] def getField[T](columnIndex: Int)(get: Int => T): Option[T] = {
+    checkOpen()
+    if (columnIndex < 0 || columnIndex >= currentRow.length) {
+      throw new SQLException(s"The column index is out of range: $columnIndex, " +
+        s"number of columns: ${currentRow.length}.")
+    }
+
+    Option(currentRow.get(columnIndex)) match {
+      case Some(rawField) =>
+        _wasNull = false
+        Some(get(columnIndex))
+      case None =>
+        _wasNull = true
+        None
+    }
+  }
+
   override def findColumn(columnLabel: String): Int = {
     sparkResult.schema.getFieldIndex(columnLabel) match {
       case Some(i) => i + 1
@@ -85,75 +102,35 @@
   }
 
   override def getString(columnIndex: Int): String = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return null
-    }
-    _wasNull = false
-    String.valueOf(currentRow.get(columnIndex - 1))
+    getField(columnIndex - 1) { idx => String.valueOf(currentRow.get(idx)) }.orNull
   }
 
   override def getBoolean(columnIndex: Int): Boolean = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return false
-    }
-    _wasNull = false
-    currentRow.getBoolean(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getBoolean(idx) }.getOrElse(false)
   }
 
   override def getByte(columnIndex: Int): Byte = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return 0.toByte
-    }
-    _wasNull = false
-    currentRow.getByte(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getByte(idx) }.getOrElse(0)
   }
 
   override def getShort(columnIndex: Int): Short = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return 0.toShort
-    }
-    _wasNull = false
-    currentRow.getShort(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getShort(idx) }.getOrElse(0)
   }
 
   override def getInt(columnIndex: Int): Int = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return 0
-    }
-    _wasNull = false
-    currentRow.getInt(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getInt(idx) }.getOrElse(0)
   }
 
   override def getLong(columnIndex: Int): Long = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return 0L
-    }
-    _wasNull = false
-    currentRow.getLong(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getLong(idx) }.getOrElse(0)
   }
 
   override def getFloat(columnIndex: Int): Float = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return 0.toFloat
-    }
-    _wasNull = false
-    currentRow.getFloat(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getFloat(idx) }.getOrElse(0)
   }
 
   override def getDouble(columnIndex: Int): Double = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return 0.toDouble
-    }
-    _wasNull = false
-    currentRow.getDouble(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getDouble(idx) }.getOrElse(0)
   }
 
   override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal =
@@ -240,12 +217,7 @@ class SparkConnectResultSet(
   }
 
   override def getObject(columnIndex: Int): AnyRef = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return null
-    }
-    _wasNull = false
-    currentRow.get(columnIndex - 1).asInstanceOf[AnyRef]
+    getField(columnIndex - 1) { idx => currentRow.get(idx).asInstanceOf[AnyRef] }.orNull
   }
 
   override def getObject(columnLabel: String): AnyRef =
@@ -258,12 +230,7 @@ class SparkConnectResultSet(
     throw new SQLFeatureNotSupportedException
 
   override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
-    if (currentRow.isNullAt(columnIndex - 1)) {
-      _wasNull = true
-      return null
-    }
-    _wasNull = false
-    currentRow.getDecimal(columnIndex - 1)
+    getField(columnIndex - 1) { idx => currentRow.getDecimal(idx) }.orNull
   }
 
   override def getBigDecimal(columnLabel: String): java.math.BigDecimal =
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
index 089c1d7fdf0d..4b463fa82699 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.connect.client.jdbc
 
-import java.sql.Types
+import java.sql.{ResultSet, SQLException, Types}
 
 import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
@@ -248,4 +248,59 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
       }
     }
   }
+
+  test("getter functions column index out of bounds") {
+    Seq(
+      ("'foo'", (rs: ResultSet) => rs.getString(999)),
+      ("true", (rs: ResultSet) => rs.getBoolean(999)),
+      ("cast(1 as byte)", (rs: ResultSet) => rs.getByte(999)),
+      ("cast(1 as short)", (rs: ResultSet) => rs.getShort(999)),
+      ("cast(1 as int)", (rs: ResultSet) => rs.getInt(999)),
+      ("cast(1 as bigint)", (rs: ResultSet) => rs.getLong(999)),
+      ("cast(1 as float)", (rs: ResultSet) => rs.getFloat(999)),
+      ("cast(1 as double)", (rs: ResultSet) => rs.getDouble(999)),
+      ("cast(1 as DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999))
+    ).foreach {
+      case (query, getter) =>
+        withExecuteQuery(s"SELECT $query") { rs =>
+          assert(rs.next())
+          withClue("SQLException is not thrown when the result set index goes out of bounds") {
+            intercept[SQLException] {
+              getter(rs)
+            }
+          }
+        }
+    }
+  }
+
+  test("getter functions called after statement closed") {
+    Seq(
+      ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"),
+      ("true", (rs: ResultSet) => rs.getBoolean(1), true),
+      ("cast(1 as byte)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
+      ("cast(1 as short)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
+      ("cast(1 as int)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
+      ("cast(1 as bigint)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
+      ("cast(1 as float)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
+      ("cast(1 as double)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1), + new java.math.BigDecimal("1.00000")) + ).foreach { + case (query, getter, value) => + var resultSet: Option[ResultSet] = None + withExecuteQuery(s"SELECT $query") { rs => + assert(rs.next()) + assert(getter(rs) === value) + assert(!rs.wasNull) + resultSet = Some(rs) + } + assert(resultSet.isDefined) + withClue( + "SQLException is not thrown when result set is used after JDBC statement is closed") { + intercept[SQLException] { + getter(resultSet.get) + } + } + } + } } From 96926ff26553abac541c60e8b8e4074ed3d129a0 Mon Sep 17 00:00:00 2001 From: cty123 Date: Mon, 10 Nov 2025 20:43:36 -0500 Subject: [PATCH 02/12] reduce an unused variable --- .../spark/sql/connect/client/jdbc/SparkConnectResultSet.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala index 8dde19b25a62..6927aed92658 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala @@ -84,7 +84,7 @@ class SparkConnectResultSet( } Option(currentRow.get(columnIndex)) match { - case Some(rawField) => + case Some(_) => _wasNull = false Some(get(columnIndex)) case None => From 758eadf32300aa8c020e2a14e3eb1dafc3867985 Mon Sep 17 00:00:00 2001 From: cty123 Date: Mon, 10 Nov 2025 22:00:38 -0500 Subject: [PATCH 03/12] use `isNullAt` to check for null. --- .../connect/client/jdbc/SparkConnectResultSet.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala index 6927aed92658..7a4f48bf11ce 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala @@ -83,13 +83,12 @@ class SparkConnectResultSet( s"number of columns: ${currentRow.length}.") } - Option(currentRow.get(columnIndex)) match { - case Some(_) => - _wasNull = false - Some(get(columnIndex)) - case None => - _wasNull = true - None + if (currentRow.isNullAt(columnIndex)) { + _wasNull = true + None + } else { + _wasNull = false + Some(get(columnIndex)) } } From f068c80707b1cfe7083128e47d06d342ca1ea94d Mon Sep 17 00:00:00 2001 From: cty Date: Mon, 10 Nov 2025 22:47:31 -0500 Subject: [PATCH 04/12] Update sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala Co-authored-by: Cheng Pan --- .../sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala index 4b463fa82699..252c9603f262 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala +++ 
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -286,7 +286,7 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
       ("cast(1 as DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
         new java.math.BigDecimal("1.00000"))
     ).foreach {
-      case (query, getter, value) =>
+      case (query, getter, expectedValue) =>
         var resultSet: Option[ResultSet] = None
         withExecuteQuery(s"SELECT $query") { rs =>
           assert(rs.next())

From 05f59d3d4d6803bb53a94ffd0120c0f9d79994c1 Mon Sep 17 00:00:00 2001
From: cty123
Date: Mon, 10 Nov 2025 23:05:03 -0500
Subject: [PATCH 05/12] rename the `getField` function to `getColumnValue` and
 apply small code-style fixes.

---
 .../client/jdbc/SparkConnectResultSet.scala   | 22 ++++-----
 .../jdbc/SparkConnectJdbcDataTypeSuite.scala  | 46 +++++++++----------
 2 files changed, 33 insertions(+), 35 deletions(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 7a4f48bf11ce..8617a28b9c7b 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -76,7 +76,7 @@ class SparkConnectResultSet(
     }
   }
 
-  private[jdbc] def getField[T](columnIndex: Int)(get: Int => T): Option[T] = {
+  private[jdbc] def getColumnValue[T](columnIndex: Int)(get: Int => T): Option[T] = {
     checkOpen()
     if (columnIndex < 0 || columnIndex >= currentRow.length) {
       throw new SQLException(s"The column index is out of range: $columnIndex, " +
@@ -101,35 +101,35 @@ class SparkConnectResultSet(
   }
 
   override def getString(columnIndex: Int): String = {
-    getField(columnIndex - 1) { idx => String.valueOf(currentRow.get(idx)) }.orNull
+    getColumnValue(columnIndex - 1) { idx => String.valueOf(currentRow.get(idx)) }.orNull
   }
 
   override def getBoolean(columnIndex: Int): Boolean = {
-    getField(columnIndex - 1) { idx => currentRow.getBoolean(idx) }.getOrElse(false)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getBoolean(idx) }.getOrElse(false)
   }
 
   override def getByte(columnIndex: Int): Byte = {
-    getField(columnIndex - 1) { idx => currentRow.getByte(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getByte(idx) }.getOrElse(0)
   }
 
   override def getShort(columnIndex: Int): Short = {
-    getField(columnIndex - 1) { idx => currentRow.getShort(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getShort(idx) }.getOrElse(0)
   }
 
   override def getInt(columnIndex: Int): Int = {
-    getField(columnIndex - 1) { idx => currentRow.getInt(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getInt(idx) }.getOrElse(0)
   }
 
   override def getLong(columnIndex: Int): Long = {
-    getField(columnIndex - 1) { idx => currentRow.getLong(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getLong(idx) }.getOrElse(0)
   }
 
   override def getFloat(columnIndex: Int): Float = {
-    getField(columnIndex - 1) { idx => currentRow.getFloat(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getFloat(idx) }.getOrElse(0)
   }
 
   override def getDouble(columnIndex: Int): Double = {
-    getField(columnIndex - 1) { idx => currentRow.getDouble(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1) { idx => currentRow.getDouble(idx) }.getOrElse(0)
   }
 
   override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal =
@@ -216,7 +216,7 @@ class SparkConnectResultSet(
   }
 
   override def getObject(columnIndex: Int): AnyRef = {
-    getField(columnIndex - 1) { idx => currentRow.get(idx).asInstanceOf[AnyRef] }.orNull
+    getColumnValue(columnIndex - 1) { idx => currentRow.get(idx).asInstanceOf[AnyRef] }.orNull
   }
 
   override def getObject(columnLabel: String): AnyRef =
@@ -229,7 +229,7 @@ class SparkConnectResultSet(
     throw new SQLFeatureNotSupportedException
 
   override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
-    getField(columnIndex - 1) { idx => currentRow.getDecimal(idx) }.orNull
+    getColumnValue(columnIndex - 1) { idx => currentRow.getDecimal(idx) }.orNull
   }
 
   override def getBigDecimal(columnLabel: String): java.math.BigDecimal =
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
index 252c9603f262..69b469830991 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -253,22 +253,22 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
     Seq(
       ("'foo'", (rs: ResultSet) => rs.getString(999)),
       ("true", (rs: ResultSet) => rs.getBoolean(999)),
-      ("cast(1 as byte)", (rs: ResultSet) => rs.getByte(999)),
-      ("cast(1 as short)", (rs: ResultSet) => rs.getShort(999)),
-      ("cast(1 as int)", (rs: ResultSet) => rs.getInt(999)),
-      ("cast(1 as bigint)", (rs: ResultSet) => rs.getLong(999)),
-      ("cast(1 as float)", (rs: ResultSet) => rs.getFloat(999)),
-      ("cast(1 as double)", (rs: ResultSet) => rs.getDouble(999)),
-      ("cast(1 as DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999))
+      ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)),
+      ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)),
+      ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)),
+      ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)),
+      ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)),
+      ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)),
+      ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999))
     ).foreach {
      case (query, getter) =>
        withExecuteQuery(s"SELECT $query") { rs =>
          assert(rs.next())
-          withClue("SQLException is not thrown when the result set index goes out of bounds") {
-            intercept[SQLException] {
-              getter(rs)
-            }
+          val exception = intercept[SQLException] {
+            getter(rs)
          }
+          assert(exception.getMessage() ===
+            "The column index is out of range: 998, number of columns: 1.")
        }
    }
  }
@@ -277,30 +277,28 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
     Seq(
       ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"),
       ("true", (rs: ResultSet) => rs.getBoolean(1), true),
-      ("cast(1 as byte)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
-      ("cast(1 as short)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
-      ("cast(1 as int)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
-      ("cast(1 as bigint)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
-      ("cast(1 as float)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
-      ("cast(1 as double)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
-      ("cast(1 as DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
+      ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
+      ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
+      ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
+      ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
+      ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
+      ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
+      ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
         new java.math.BigDecimal("1.00000"))
     ).foreach {
       case (query, getter, expectedValue) =>
         var resultSet: Option[ResultSet] = None
         withExecuteQuery(s"SELECT $query") { rs =>
           assert(rs.next())
-          assert(getter(rs) === value)
+          assert(getter(rs) === expectedValue)
           assert(!rs.wasNull)
           resultSet = Some(rs)
         }
         assert(resultSet.isDefined)
-        withClue(
-          "SQLException is not thrown when the result set is used after the JDBC statement is closed") {
-          intercept[SQLException] {
-            getter(resultSet.get)
-          }
+        val exception = intercept[SQLException] {
+          getter(resultSet.get)
         }
+        assert(exception.getMessage() === "JDBC Statement is closed.")
     }
   }
 }

From 22edc57e6cde9e755ffc455e20fe5eb7717bf0d6 Mon Sep 17 00:00:00 2001
From: cty123
Date: Mon, 10 Nov 2025 23:15:56 -0500
Subject: [PATCH 06/12] optimize default value passing for the `getColumnValue`
 function.

---
 .../client/jdbc/SparkConnectResultSet.scala   | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 8617a28b9c7b..6c42cc901337 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -76,7 +76,7 @@ class SparkConnectResultSet(
     }
   }
 
-  private[jdbc] def getColumnValue[T](columnIndex: Int)(get: Int => T): Option[T] = {
+  private[jdbc] def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
     checkOpen()
     if (columnIndex < 0 || columnIndex >= currentRow.length) {
       throw new SQLException(s"The column index is out of range: $columnIndex, " +
@@ -85,10 +85,10 @@ class SparkConnectResultSet(
 
     if (currentRow.isNullAt(columnIndex)) {
       _wasNull = true
-      None
+      defaultVal
     } else {
       _wasNull = false
-      Some(get(columnIndex))
+      getter(columnIndex)
     }
   }
 
@@ -101,35 +101,35 @@ class SparkConnectResultSet(
   }
 
   override def getString(columnIndex: Int): String = {
-    getColumnValue(columnIndex - 1) { idx => String.valueOf(currentRow.get(idx)) }.orNull
+    getColumnValue(columnIndex - 1, null: String) { idx => String.valueOf(currentRow.get(idx)) }
   }
 
   override def getBoolean(columnIndex: Int): Boolean = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getBoolean(idx) }.getOrElse(false)
+    getColumnValue(columnIndex - 1, false) { idx => currentRow.getBoolean(idx) }
   }
 
   override def getByte(columnIndex: Int): Byte = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getByte(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1, 0.toByte) { idx => currentRow.getByte(idx) }
   }
 
   override def getShort(columnIndex: Int): Short = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getShort(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1, 0.toShort) { idx => currentRow.getShort(idx) }
   }
 
   override def getInt(columnIndex: Int): Int = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getInt(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1, 0.toInt) { idx => currentRow.getInt(idx) }
   }
 
   override def getLong(columnIndex: Int): Long = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getLong(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1, 0.toLong) { idx => currentRow.getLong(idx) }
   }
 
   override def getFloat(columnIndex: Int): Float = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getFloat(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1, 0.toFloat) { idx => currentRow.getFloat(idx) }
   }
 
   override def getDouble(columnIndex: Int): Double = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getDouble(idx) }.getOrElse(0)
+    getColumnValue(columnIndex - 1, 0.toDouble) { idx => currentRow.getDouble(idx) }
   }
 
   override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal =
@@ -216,7 +216,8 @@ class SparkConnectResultSet(
   }
 
   override def getObject(columnIndex: Int): AnyRef = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.get(idx).asInstanceOf[AnyRef] }.orNull
+    getColumnValue(columnIndex - 1, null: AnyRef) { idx =>
+      currentRow.get(idx).asInstanceOf[AnyRef] }
   }
 
   override def getObject(columnLabel: String): AnyRef =
@@ -229,7 +230,8 @@ class SparkConnectResultSet(
     throw new SQLFeatureNotSupportedException
 
   override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
-    getColumnValue(columnIndex - 1) { idx => currentRow.getDecimal(idx) }.orNull
+    getColumnValue(columnIndex - 1, null: java.math.BigDecimal) { idx =>
+      currentRow.getDecimal(idx) }
   }
 
   override def getBigDecimal(columnLabel: String): java.math.BigDecimal =

From f2e41c254ec2a6927da70e863d36dc8f88386a26 Mon Sep 17 00:00:00 2001
From: cty
Date: Tue, 11 Nov 2025 00:17:03 -0500
Subject: [PATCH 07/12] Update
 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala

Co-authored-by: Cheng Pan
---
 .../spark/sql/connect/client/jdbc/SparkConnectResultSet.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 6c42cc901337..c7483a5a8099 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -217,7 +217,8 @@ class SparkConnectResultSet(
 
   override def getObject(columnIndex: Int): AnyRef = {
     getColumnValue(columnIndex - 1, null: AnyRef) { idx =>
-      currentRow.get(idx).asInstanceOf[AnyRef] }
+      currentRow.get(idx).asInstanceOf[AnyRef]
+    }
   }
 
   override def getObject(columnLabel: String): AnyRef =

From f2e41c254ec2a6927da70e863d36dc8f88386a26 Mon Sep 17 00:00:00 2001
From: cty
Date: Tue, 11 Nov 2025 00:17:09 -0500
Subject: [PATCH 08/12] Update
 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala

Co-authored-by: Cheng Pan
---
 .../spark/sql/connect/client/jdbc/SparkConnectResultSet.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index c7483a5a8099..aca79a3ce00b 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -232,7 +232,8 @@ class SparkConnectResultSet(
 
   override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
     getColumnValue(columnIndex - 1, null: java.math.BigDecimal) { idx =>
-      currentRow.getDecimal(idx) }
+      currentRow.getDecimal(idx)
+    }
   }
 
   override def getBigDecimal(columnLabel: String): java.math.BigDecimal =

From bcfa8788f9c99ec336013dd7a8f4254e1a54f474 Mon Sep 17 00:00:00 2001
From: cty123
Date: Tue, 11 Nov 2025 00:27:10 -0500
Subject: [PATCH 09/12] move the 1-based index conversion into the
 `getColumnValue` function.

---
 .../client/jdbc/SparkConnectResultSet.scala   | 26 ++++++++++---------
 .../jdbc/SparkConnectJdbcDataTypeSuite.scala  |  2 +-
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index aca79a3ce00b..773371b66113 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -76,10 +76,12 @@ class SparkConnectResultSet(
     }
   }
 
-  private[jdbc] def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
+  private[jdbc] def getColumnValue[T](index: Int, defaultVal: T)(getter: Int => T): T = {
     checkOpen()
+    // the passed index value is 1-indexed, but the underlying array is 0-indexed
+    val columnIndex = index - 1
     if (columnIndex < 0 || columnIndex >= currentRow.length) {
-      throw new SQLException(s"The column index is out of range: $columnIndex, " +
+      throw new SQLException(s"The column index is out of range: $index, " +
         s"number of columns: ${currentRow.length}.")
     }
 
@@ -101,35 +103,35 @@ class SparkConnectResultSet(
   }
 
   override def getString(columnIndex: Int): String = {
-    getColumnValue(columnIndex - 1, null: String) { idx => String.valueOf(currentRow.get(idx)) }
+    getColumnValue(columnIndex, null: String) { idx => String.valueOf(currentRow.get(idx)) }
   }
 
   override def getBoolean(columnIndex: Int): Boolean = {
-    getColumnValue(columnIndex - 1, false) { idx => currentRow.getBoolean(idx) }
+    getColumnValue(columnIndex, false) { idx => currentRow.getBoolean(idx) }
   }
 
   override def getByte(columnIndex: Int): Byte = {
-    getColumnValue(columnIndex - 1, 0.toByte) { idx => currentRow.getByte(idx) }
+    getColumnValue(columnIndex, 0.toByte) { idx => currentRow.getByte(idx) }
   }
 
   override def getShort(columnIndex: Int): Short = {
-    getColumnValue(columnIndex - 1, 0.toShort) { idx => currentRow.getShort(idx) }
+    getColumnValue(columnIndex, 0.toShort) { idx => currentRow.getShort(idx) }
   }
 
   override def getInt(columnIndex: Int): Int = {
-    getColumnValue(columnIndex - 1, 0.toInt) { idx => currentRow.getInt(idx) }
+    getColumnValue(columnIndex, 0.toInt) { idx => currentRow.getInt(idx) }
   }
 
   override def getLong(columnIndex: Int): Long = {
-    getColumnValue(columnIndex - 1, 0.toLong) { idx => currentRow.getLong(idx) }
+    getColumnValue(columnIndex, 0.toLong) { idx => currentRow.getLong(idx) }
   }
 
   override def getFloat(columnIndex: Int): Float = {
-    getColumnValue(columnIndex - 1, 0.toFloat) { idx => currentRow.getFloat(idx) }
+    getColumnValue(columnIndex, 0.toFloat) { idx => currentRow.getFloat(idx) }
   }
 
   override def getDouble(columnIndex: Int): Double = {
-    getColumnValue(columnIndex - 1, 0.toDouble) { idx => currentRow.getDouble(idx) }
+    getColumnValue(columnIndex, 0.toDouble) { idx => currentRow.getDouble(idx) }
   }
 
   override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal =
@@ -216,7 +218,7 @@ class SparkConnectResultSet(
   }
 
   override def getObject(columnIndex: Int): AnyRef = {
-    getColumnValue(columnIndex - 1, null: AnyRef) { idx =>
+    getColumnValue(columnIndex, null: AnyRef) { idx =>
       currentRow.get(idx).asInstanceOf[AnyRef]
     }
   }
@@ -231,7 +233,7 @@ class SparkConnectResultSet(
     throw new SQLFeatureNotSupportedException
 
   override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
-    getColumnValue(columnIndex - 1, null: java.math.BigDecimal) { idx =>
+    getColumnValue(columnIndex, null: java.math.BigDecimal) { idx =>
       currentRow.getDecimal(idx)
     }
   }
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
index 69b469830991..217142287b13 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -268,7 +268,7 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
             getter(rs)
           }
           assert(exception.getMessage() ===
-            "The column index is out of range: 998, number of columns: 1.")
+            "The column index is out of range: 999, number of columns: 1.")
         }
     }
   }

From ed8f0c04f18c060d1fbbcd2acf3fe2cc49b4a3f2 Mon Sep 17 00:00:00 2001
From: cty123
Date: Tue, 11 Nov 2025 00:58:11 -0500
Subject: [PATCH 10/12] rename column index variable

---
 .../connect/client/jdbc/SparkConnectResultSet.scala | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 773371b66113..247edf1a2510 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -76,21 +76,21 @@ class SparkConnectResultSet(
     }
   }
 
-  private[jdbc] def getColumnValue[T](index: Int, defaultVal: T)(getter: Int => T): T = {
+  private[jdbc] def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
     checkOpen()
     // the passed index value is 1-indexed, but the underlying array is 0-indexed
-    val columnIndex = index - 1
-    if (columnIndex < 0 || columnIndex >= currentRow.length) {
-      throw new SQLException(s"The column index is out of range: $index, " +
+    val index = columnIndex - 1
+    if (index < 0 || index >= currentRow.length) {
+      throw new SQLException(s"The column index is out of range: $columnIndex, " +
         s"number of columns: ${currentRow.length}.")
     }
 
-    if (currentRow.isNullAt(columnIndex)) {
+    if (currentRow.isNullAt(index)) {
       _wasNull = true
       defaultVal
     } else {
       _wasNull = false
-      getter(columnIndex)
+      getter(index)
     }
   }

From 3b506c9e1c6f7d1d6a37676ad130949812340499 Mon Sep 17 00:00:00 2001
From: cty
Date: Tue, 11 Nov 2025 10:12:28 -0500
Subject: [PATCH 11/12] Update
 sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala

Co-authored-by: Cheng Pan
---
 .../spark/sql/connect/client/jdbc/SparkConnectResultSet.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 247edf1a2510..c1234be9e9b8 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -119,7 +119,7 @@ class SparkConnectResultSet(
   }
 
   override def getInt(columnIndex: Int): Int = {
-    getColumnValue(columnIndex, 0.toInt) { idx => currentRow.getInt(idx) }
+    getColumnValue(columnIndex, 0) { idx => currentRow.getInt(idx) }
   }
 
   override def getLong(columnIndex: Int): Long = {

From eecc9dfbb0bd0a5894e2040a226170aed6f025c8 Mon Sep 17 00:00:00 2001
From: cty123
Date: Tue, 11 Nov 2025 10:15:19 -0500
Subject: [PATCH 12/12] remove the private[jdbc] qualifier from the helper
 function

---
 .../spark/sql/connect/client/jdbc/SparkConnectResultSet.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index c1234be9e9b8..ff02cd73dcc6 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -76,7 +76,7 @@ class SparkConnectResultSet(
     }
   }
 
-  private[jdbc] def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
+  private def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
     checkOpen()
     // the passed index value is 1-indexed, but the underlying array is 0-indexed
    val index = columnIndex - 1
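
Note on the end state: after patch 12, every typed getter funnels through the private getColumnValue helper, which verifies the statement is still open, converts the 1-based JDBC index to the row's 0-based index, rejects out-of-range indexes, and records the null state. The following caller-side sketch illustrates those semantics under stated assumptions: the jdbc:sc:// URL, host, and port are placeholders for a reachable Spark Connect server, and the queries are illustrative rather than taken from the test suite.

  import java.sql.{DriverManager, SQLException}

  object GetterValidationSketch {
    def main(args: Array[String]): Unit = {
      // assumed Spark Connect JDBC endpoint; adjust for a real server
      val conn = DriverManager.getConnection("jdbc:sc://localhost:15002")
      val stmt = conn.createStatement()
      val rs = stmt.executeQuery("SELECT cast(null AS INT) AS c")
      assert(rs.next())

      // a SQL NULL comes back as the type's default, with the wasNull flag set
      assert(rs.getInt(1) == 0 && rs.wasNull())

      // an out-of-range 1-based index fails fast with a SQLException,
      // e.g. "The column index is out of range: 2, number of columns: 1."
      try rs.getInt(2) catch { case e: SQLException => println(e.getMessage) }

      stmt.close()
      // after the statement is closed, checkOpen() rejects every getter,
      // e.g. "JDBC Statement is closed."
      try rs.getInt(1) catch { case e: SQLException => println(e.getMessage) }
      conn.close()
    }
  }

Passing the default value into getColumnValue (patch 06) instead of returning an Option keeps the primitive getters free of boxing while still recording the null state that wasNull() must report, per the java.sql.ResultSet contract.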