Skip to content

Commit

Permalink
Add nullability inference support to dataframe-jdbc (#672)
Browse files Browse the repository at this point in the history
* Add nullability inference support to dataframe-jdbc

This update adds a new parameter `inferNullability` for various dataframe reading functions including `readSqlTable`, `readSqlQuery`, and `readResultSet`. It allows better control over how column nullability should be inferred. The `h2Test` file has been adjusted to test this new feature.

* Refactor SQL query requirement message for readability

The code format of the requirement message for SQL query validation in the readJdbc.kt file has been improved. The change enhances readability by wrapping the requirement message into its own block, splitting the long string into two separate lines rather than extending it across one long line.

* Update inferNullability from Infer to Boolean

This commit changes the inferNullability parameter from 'Infer' to a Boolean type in functions of readJdbc.kt and adjusts related function calls in h2Test.kt. Now, inferNullability takes a Boolean value with 'true' indicating Inference and 'false' meaning no inference, making it more intuitive and easier to use.

* Remove unnecessary inferNulls call in readJdbc

The Infer.Nulls call was redundant and has been removed from the readJdbc.kt file. This simplifies the code of reading JDBC in the DataFrame-JDBC module, without altering functionality.
  • Loading branch information
zaleslaw committed Apr 29, 2024
1 parent 7de6022 commit ae1692d
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.Infer
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
import org.jetbrains.kotlinx.dataframe.io.db.DbType
Expand Down Expand Up @@ -105,15 +106,17 @@ public data class DatabaseConfiguration(val url: String, val user: String = "",
* @param [dbConfig] the configuration for the database, including URL, user, and password.
* @param [tableName] the name of the table to read data from.
* @param [limit] the maximum number of rows to retrieve from the table.
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the DataFrame containing the data from the SQL table.
*/
public fun DataFrame.Companion.readSqlTable(
dbConfig: DatabaseConfiguration,
tableName: String,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return readSqlTable(connection, tableName, limit)
return readSqlTable(connection, tableName, limit, inferNullability)
}
}

Expand All @@ -123,14 +126,16 @@ public fun DataFrame.Companion.readSqlTable(
* @param [connection] the database connection to read tables from.
* @param [tableName] the name of the table to read data from.
* @param [limit] the maximum number of rows to retrieve from the table.
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the DataFrame containing the data from the SQL table.
*
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.readSqlTable(
connection: Connection,
tableName: String,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
var preparedQuery = "SELECT * FROM $tableName"
if (limit > 0) preparedQuery += " LIMIT $limit"
Expand All @@ -145,7 +150,7 @@ public fun DataFrame.Companion.readSqlTable(
preparedQuery
).use { rs ->
val tableColumns = getTableColumnsMetadata(rs)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
}
}
}
Expand All @@ -159,15 +164,17 @@ public fun DataFrame.Companion.readSqlTable(
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @param [sqlQuery] the SQL query to execute.
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the DataFrame containing the result of the SQL query.
*/
public fun DataFrame.Companion.readSqlQuery(
dbConfig: DatabaseConfiguration,
sqlQuery: String,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return readSqlQuery(connection, sqlQuery, limit)
return readSqlQuery(connection, sqlQuery, limit, inferNullability)
}
}

Expand All @@ -180,16 +187,21 @@ public fun DataFrame.Companion.readSqlQuery(
* @param [connection] the database connection to execute the SQL query.
* @param [sqlQuery] the SQL query to execute.
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the DataFrame containing the result of the SQL query.
*
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.readSqlQuery(
connection: Connection,
sqlQuery: String,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " }
require(isValid(sqlQuery)) {
"SQL query should start from SELECT and contain one query for reading data without any manipulation. " +
"Also it should not contain any separators like `;`."
}

val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
Expand All @@ -202,12 +214,12 @@ public fun DataFrame.Companion.readSqlQuery(
connection.createStatement().use { st ->
st.executeQuery(internalSqlQuery).use { rs ->
val tableColumns = getTableColumnsMetadata(rs)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
}
}
}

/** SQL-query is accepted only if it starts from SELECT */
/** SQL query is accepted only if it starts from SELECT */
private fun isValid(sqlQuery: String): Boolean {
val normalizedSqlQuery = sqlQuery.trim().uppercase()

Expand All @@ -221,15 +233,17 @@ private fun isValid(sqlQuery: String): Boolean {
* @param [resultSet] the [ResultSet] containing the data to read.
* @param [dbType] the type of database that the [ResultSet] belongs to.
* @param [limit] the maximum number of rows to read from the [ResultSet].
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the DataFrame generated from the [ResultSet] data.
*/
public fun DataFrame.Companion.readResultSet(
resultSet: ResultSet,
dbType: DbType,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
val tableColumns = getTableColumnsMetadata(resultSet)
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit)
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability)
}

/**
Expand All @@ -238,33 +252,38 @@ public fun DataFrame.Companion.readResultSet(
* @param [resultSet] the [ResultSet] containing the data to read.
* @param [connection] the connection to the database (it's required to extract the database type).
* @param [limit] the maximum number of rows to read from the [ResultSet].
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the DataFrame generated from the [ResultSet] data.
*/
public fun DataFrame.Companion.readResultSet(
resultSet: ResultSet,
connection: Connection,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)

return readResultSet(resultSet, dbType, limit)
return readResultSet(resultSet, dbType, limit, inferNullability)
}

/**
* Reads all tables from the given database using the provided database configuration and limit.
*
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @param [limit] the maximum number of rows to read from each table.
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
*/
public fun DataFrame.Companion.readAllSqlTables(
dbConfig: DatabaseConfiguration,
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): List<AnyFrame> {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return readAllSqlTables(connection, catalogue, limit)
return readAllSqlTables(connection, catalogue, limit, inferNullability)
}
}

Expand All @@ -273,14 +292,17 @@ public fun DataFrame.Companion.readAllSqlTables(
*
* @param [connection] the database connection to read tables from.
* @param [limit] the maximum number of rows to read from each table.
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
* @param [inferNullability] if `true` (the default), each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
*
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.readAllSqlTables(
connection: Connection,
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): List<AnyFrame> {
val metaData = connection.metaData
val url = connection.metaData.url
Expand All @@ -304,7 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
// could be Dialect/Database specific
logger.debug { "Reading table: $tableName" }

val dataFrame = readSqlTable(connection, tableName, limit)
val dataFrame = readSqlTable(connection, tableName, limit, inferNullability)
dataFrames += dataFrame
logger.debug { "Finished reading table: $tableName" }
}
Expand Down Expand Up @@ -450,7 +472,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection):
val dbType = extractDBTypeFromUrl(url)

val tableTypes = arrayOf("TABLE")
// exclude system and other tables without data
// exclude system tables and other tables without data
val tables = metaData.getTables(null, null, null, tableTypes)

val dataFrameSchemas = mutableListOf<DataFrameSchema>()
Expand Down Expand Up @@ -561,13 +583,15 @@ private fun manageColumnNameDuplication(columnNameCounter: MutableMap<String, In
* @param [rs] the ResultSet object containing the data to be fetched and converted.
* @param [dbType] the type of the database.
* @param [limit] the maximum number of rows to fetch and convert.
* @param [inferNullability] if `true`, each column's nullability is inferred from the values actually read; if `false`, no inference is performed.
* @return the [AnyFrame] containing the fetched and converted data.
*/
private fun fetchAndConvertDataFromResultSet(
tableColumns: MutableList<TableColumnMetadata>,
rs: ResultSet,
dbType: DbType,
limit: Int
limit: Int,
inferNullability: Boolean,
): AnyFrame {
val data = List(tableColumns.size) { mutableListOf<Any?>() }

Expand Down Expand Up @@ -596,6 +620,7 @@ private fun fetchAndConvertDataFromResultSet(
DataColumn.createValueColumn(
name = tableColumns[index].name,
values = values,
infer = convertNullabilityInference(inferNullability),
type = kotlinTypesForSqlColumns[index]!!
)
}.toDataFrame()
Expand All @@ -605,6 +630,8 @@ private fun fetchAndConvertDataFromResultSet(
return dataFrame
}

/**
 * Translates the Boolean `inferNullability` flag into the [Infer] mode handed to the column factory.
 *
 * @param [inferNullability] `true` selects [Infer.Nulls], `false` selects [Infer.None].
 * @return the [Infer] mode matching the flag.
 */
private fun convertNullabilityInference(inferNullability: Boolean): Infer =
    when (inferNullability) {
        true -> Infer.Nulls
        false -> Infer.None
    }

private fun extractNewRowFromResultSetAndAddToData(
tableColumns: MutableList<TableColumnMetadata>,
data: List<MutableList<Any?>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ import org.h2.jdbc.JdbcSQLSyntaxErrorException
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.select
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.db.H2
import org.junit.AfterClass
import org.junit.BeforeClass
Expand Down Expand Up @@ -677,4 +674,115 @@ class JdbcTest {
saleDataSchema1.columns.size shouldBe 3
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
}

/**
 * Checks how the `inferNullability` flag affects the column types produced by `readSqlTable`,
 * `readSqlQuery` and `readResultSet`, and compares them with the types reported by the matching
 * `getSchemaFor...` functions.
 *
 * `TestTable1` declares `name` and `surname` without `NOT NULL`, but only `surname` receives a NULL value.
 * The assertions expect `name` to be `String` when inference is enabled (the default) and `String?`
 * when it is disabled.
 */
@Test
fun `infer nullability`() {
    // Runs a single statement and closes it immediately, so no Statement object is leaked.
    fun execute(sql: String) {
        connection.createStatement().use { st -> st.execute(sql) }
    }

    // prepare tables and data
    @Language("SQL")
    val createTestTable1Query = """
        CREATE TABLE TestTable1 (
            id INT PRIMARY KEY,
            name VARCHAR(50),
            surname VARCHAR(50),
            age INT NOT NULL
        )
        """

    execute(createTestTable1Query)

    // The table is dropped in `finally`, so a failing assertion cannot leave it behind in the database.
    try {
        execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
        execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
        execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
        execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")

        // start testing `readSqlTable` method

        // with default inferNullability: Boolean = true
        val tableName = "TestTable1"
        val df = DataFrame.readSqlTable(connection, tableName)
        df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
        df.schema().columns["name"]!!.type shouldBe typeOf<String>()
        df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
        dataSchema.columns.size shouldBe 4
        dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
        dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
        dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
        dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

        // with inferNullability: Boolean = false
        val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
        df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()
        // <=== stays nullable: without inference the type is not narrowed, although the column holds no nulls
        df1.schema().columns["name"]!!.type shouldBe typeOf<String?>()
        df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        // end testing `readSqlTable` method

        // start testing `readSqlQuery` method

        // with default inferNullability: Boolean = true
        @Language("SQL")
        val sqlQuery = """
            SELECT name, surname, age FROM TestTable1
        """.trimIndent()

        val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
        df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
        df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
        dataSchema2.columns.size shouldBe 3
        dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
        dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
        dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()

        // with inferNullability: Boolean = false
        val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)
        // <=== stays nullable: without inference the type is not narrowed, although the column holds no nulls
        df3.schema().columns["name"]!!.type shouldBe typeOf<String?>()
        df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
        df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()

        // end testing `readSqlQuery` method

        // start testing `readResultSet` method

        // A scrollable ResultSet is requested because the same `rs` is rewound with `beforeFirst()` and read again.
        connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
            @Language("SQL")
            val selectStatement = "SELECT * FROM TestTable1"

            st.executeQuery(selectStatement).use { rs ->
                // with default inferNullability: Boolean = true
                val df4 = DataFrame.readResultSet(rs, H2)
                df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
                df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
                df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
                df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()

                rs.beforeFirst()

                val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2)
                dataSchema3.columns.size shouldBe 4
                dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
                dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
                dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
                dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()

                // with inferNullability: Boolean = false
                rs.beforeFirst()

                val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false)
                df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()
                // <=== stays nullable: without inference the type is not narrowed, although the column holds no nulls
                df5.schema().columns["name"]!!.type shouldBe typeOf<String?>()
                df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
                df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
            }
        }
        // end testing `readResultSet` method
    } finally {
        execute("DROP TABLE TestTable1")
    }
}
}

0 comments on commit ae1692d

Please sign in to comment.