diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala index 23c2315400ff..ff02cd73dcc6 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala @@ -76,6 +76,24 @@ class SparkConnectResultSet( } } + private def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = { + checkOpen() + // the passed index value is 1-indexed, but the underlying array is 0-indexed + val index = columnIndex - 1 + if (index < 0 || index >= currentRow.length) { + throw new SQLException(s"The column index is out of range: $columnIndex, " + + s"number of columns: ${currentRow.length}.") + } + + if (currentRow.isNullAt(index)) { + _wasNull = true + defaultVal + } else { + _wasNull = false + getter(index) + } + } + override def findColumn(columnLabel: String): Int = { sparkResult.schema.getFieldIndex(columnLabel) match { case Some(i) => i + 1 @@ -85,75 +103,35 @@ class SparkConnectResultSet( } override def getString(columnIndex: Int): String = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return null - } - _wasNull = false - String.valueOf(currentRow.get(columnIndex - 1)) + getColumnValue(columnIndex, null: String) { idx => String.valueOf(currentRow.get(idx)) } } override def getBoolean(columnIndex: Int): Boolean = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return false - } - _wasNull = false - currentRow.getBoolean(columnIndex - 1) + getColumnValue(columnIndex, false) { idx => currentRow.getBoolean(idx) } } override def getByte(columnIndex: Int): Byte = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return 0.toByte - } - _wasNull = false - currentRow.getByte(columnIndex - 1) + getColumnValue(columnIndex, 0.toByte) { idx => currentRow.getByte(idx) } } override def getShort(columnIndex: Int): Short = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return 0.toShort - } - _wasNull = false - currentRow.getShort(columnIndex - 1) + getColumnValue(columnIndex, 0.toShort) { idx => currentRow.getShort(idx) } } override def getInt(columnIndex: Int): Int = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return 0 - } - _wasNull = false - currentRow.getInt(columnIndex - 1) + getColumnValue(columnIndex, 0) { idx => currentRow.getInt(idx) } } override def getLong(columnIndex: Int): Long = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return 0L - } - _wasNull = false - currentRow.getLong(columnIndex - 1) + getColumnValue(columnIndex, 0.toLong) { idx => currentRow.getLong(idx) } } override def getFloat(columnIndex: Int): Float = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return 0.toFloat - } - _wasNull = false - currentRow.getFloat(columnIndex - 1) + getColumnValue(columnIndex, 0.toFloat) { idx => currentRow.getFloat(idx) } } override def getDouble(columnIndex: Int): Double = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return 0.toDouble - } - _wasNull = false - currentRow.getDouble(columnIndex - 1) + getColumnValue(columnIndex, 0.toDouble) { idx => currentRow.getDouble(idx) } } override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal = @@ -240,12 +218,9 @@ class SparkConnectResultSet( } override def getObject(columnIndex: Int): AnyRef = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return null + getColumnValue(columnIndex, null: AnyRef) { idx => + currentRow.get(idx).asInstanceOf[AnyRef] } - _wasNull = false - currentRow.get(columnIndex - 1).asInstanceOf[AnyRef] } override def getObject(columnLabel: String): AnyRef = @@ -258,12 +233,9 @@ class SparkConnectResultSet( throw new SQLFeatureNotSupportedException override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = { - if (currentRow.isNullAt(columnIndex - 1)) { - _wasNull = true - return null + getColumnValue(columnIndex, null: java.math.BigDecimal) { idx => + currentRow.getDecimal(idx) } - _wasNull = false - currentRow.getDecimal(columnIndex - 1) } override def getBigDecimal(columnLabel: String): java.math.BigDecimal = diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala index 089c1d7fdf0d..217142287b13 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connect.client.jdbc -import java.sql.Types +import java.sql.{ResultSet, SQLException, Types} import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession} @@ -248,4 +248,57 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess } } } + + test("getter functions column index out of bound") { + Seq( + ("'foo'", (rs: ResultSet) => rs.getString(999)), + ("true", (rs: ResultSet) => rs.getBoolean(999)), + ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)), + ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)), + ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)), + ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)), + ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)), + ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)), + ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999)) + ).foreach { + case (query, getter) => + withExecuteQuery(s"SELECT $query") { rs => + assert(rs.next()) + val exception = intercept[SQLException] { + getter(rs) + } + assert(exception.getMessage() === + "The column index is out of range: 999, number of columns: 1.") + } + } + } + + test("getter functions called after statement closed") { + Seq( + ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"), + ("true", (rs: ResultSet) => rs.getBoolean(1), true), + ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte), + ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort), + ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt), + ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong), + ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat), + ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble), + ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1), + new java.math.BigDecimal("1.00000")) + ).foreach { + case (query, getter, expectedValue) => + var resultSet: Option[ResultSet] = None + withExecuteQuery(s"SELECT $query") { rs => + assert(rs.next()) + assert(getter(rs) === expectedValue) + assert(!rs.wasNull) + resultSet = Some(rs) + } + assert(resultSet.isDefined) + val exception = intercept[SQLException] { + getter(resultSet.get) + } + assert(exception.getMessage() === "JDBC Statement is closed.") + } + } }