Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nullability inference support to dataframe-jdbc #672

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.Infer
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
import org.jetbrains.kotlinx.dataframe.io.db.DbType
Expand Down Expand Up @@ -105,15 +106,17 @@ public data class DatabaseConfiguration(val url: String, val user: String = "",
* @param [dbConfig] the configuration for the database, including URL, user, and password.
* @param [tableName] the name of the table to read data from.
* @param [limit] the maximum number of rows to retrieve from the table.
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the DataFrame containing the data from the SQL table.
*/
public fun DataFrame.Companion.readSqlTable(
    dbConfig: DatabaseConfiguration,
    tableName: String,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // Open a dedicated connection for this call; `use` closes it even when reading fails.
    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
        // Delegate to the Connection-based overload, forwarding the nullability inference flag.
        return readSqlTable(connection, tableName, limit, inferNullability)
    }
}

Expand All @@ -123,14 +126,16 @@ public fun DataFrame.Companion.readSqlTable(
* @param [connection] the database connection to read tables from.
* @param [tableName] the name of the table to read data from.
* @param [limit] the maximum number of rows to retrieve from the table.
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the DataFrame containing the data from the SQL table.
*
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.readSqlTable(
connection: Connection,
tableName: String,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
var preparedQuery = "SELECT * FROM $tableName"
if (limit > 0) preparedQuery += " LIMIT $limit"
Expand All @@ -145,7 +150,7 @@ public fun DataFrame.Companion.readSqlTable(
preparedQuery
).use { rs ->
val tableColumns = getTableColumnsMetadata(rs)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
}
}
}
Expand All @@ -159,15 +164,17 @@ public fun DataFrame.Companion.readSqlTable(
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @param [sqlQuery] the SQL query to execute.
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the DataFrame containing the result of the SQL query.
*/
public fun DataFrame.Companion.readSqlQuery(
    dbConfig: DatabaseConfiguration,
    sqlQuery: String,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // Open a dedicated connection for this call; `use` closes it even when the query fails.
    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
        // Delegate to the Connection-based overload, forwarding the nullability inference flag.
        return readSqlQuery(connection, sqlQuery, limit, inferNullability)
    }
}

Expand All @@ -180,16 +187,21 @@ public fun DataFrame.Companion.readSqlQuery(
* @param [connection] the database connection to execute the SQL query.
* @param [sqlQuery] the SQL query to execute.
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the DataFrame containing the result of the SQL query.
*
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.readSqlQuery(
connection: Connection,
sqlQuery: String,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): AnyFrame {
require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " }
require(isValid(sqlQuery)) {
"SQL query should start from SELECT and contain one query for reading data without any manipulation. " +
"Also it should not contain any separators like `;`."
}

val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
Expand All @@ -202,12 +214,12 @@ public fun DataFrame.Companion.readSqlQuery(
connection.createStatement().use { st ->
st.executeQuery(internalSqlQuery).use { rs ->
val tableColumns = getTableColumnsMetadata(rs)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT)
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
}
}
}

/** SQL-query is accepted only if it starts from SELECT */
/** SQL query is accepted only if it starts from SELECT */
private fun isValid(sqlQuery: String): Boolean {
val normalizedSqlQuery = sqlQuery.trim().uppercase()

Expand All @@ -221,15 +233,17 @@ private fun isValid(sqlQuery: String): Boolean {
* @param [resultSet] the [ResultSet] containing the data to read.
* @param [dbType] the type of database that the [ResultSet] belongs to.
* @param [limit] the maximum number of rows to read from the [ResultSet].
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the DataFrame generated from the [ResultSet] data.
*/
public fun DataFrame.Companion.readResultSet(
    resultSet: ResultSet,
    dbType: DbType,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // Column metadata is collected first; it is passed on together with the ResultSet for conversion.
    val tableColumns = getTableColumnsMetadata(resultSet)
    return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability)
}

/**
Expand All @@ -238,33 +252,38 @@ public fun DataFrame.Companion.readResultSet(
* @param [resultSet] the [ResultSet] containing the data to read.
* @param [connection] the connection to the database (it's required to extract the database type).
* @param [limit] the maximum number of rows to read from the [ResultSet].
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the DataFrame generated from the [ResultSet] data.
*/
public fun DataFrame.Companion.readResultSet(
    resultSet: ResultSet,
    connection: Connection,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): AnyFrame {
    // The database type is derived from the JDBC URL reported by the connection metadata.
    val url = connection.metaData.url
    val dbType = extractDBTypeFromUrl(url)

    // Delegate to the DbType-based overload, forwarding the nullability inference flag.
    return readResultSet(resultSet, dbType, limit, inferNullability)
}

/**
 * Reads all tables from the given database using the provided database configuration and limit.
 *
 * A connection is opened from [dbConfig] for the duration of the call and closed afterwards;
 * the actual reading is delegated to the overload that accepts a [Connection].
 *
 * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
 * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
 * @param [limit] the maximum number of rows to read from each table.
 * @param [inferNullability] if `true`, column nullability is inferred from the values actually read;
 * if `false`, the nullability of the declared column type is kept as is.
 * @return a list of [AnyFrame] objects representing the non-system tables from the database.
 */
public fun DataFrame.Companion.readAllSqlTables(
    dbConfig: DatabaseConfiguration,
    catalogue: String? = null,
    limit: Int = DEFAULT_LIMIT,
    inferNullability: Boolean = true,
): List<AnyFrame> {
    // `use` closes the connection even when reading one of the tables fails.
    DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
        return readAllSqlTables(connection, catalogue, limit, inferNullability)
    }
}

Expand All @@ -273,14 +292,17 @@ public fun DataFrame.Companion.readAllSqlTables(
*
* @param [connection] the database connection to read tables from.
* @param [limit] the maximum number of rows to read from each table.
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
*
* @see DriverManager.getConnection
*/
public fun DataFrame.Companion.readAllSqlTables(
connection: Connection,
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): List<AnyFrame> {
val metaData = connection.metaData
val url = connection.metaData.url
Expand All @@ -304,7 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
// could be Dialect/Database specific
logger.debug { "Reading table: $tableName" }

val dataFrame = readSqlTable(connection, tableName, limit)
val dataFrame = readSqlTable(connection, tableName, limit, inferNullability)
dataFrames += dataFrame
logger.debug { "Finished reading table: $tableName" }
}
Expand Down Expand Up @@ -450,7 +472,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection):
val dbType = extractDBTypeFromUrl(url)

val tableTypes = arrayOf("TABLE")
// exclude system tables and other tables without data
val tables = metaData.getTables(null, null, null, tableTypes)

val dataFrameSchemas = mutableListOf<DataFrameSchema>()
Expand Down Expand Up @@ -561,13 +583,15 @@ private fun manageColumnNameDuplication(columnNameCounter: MutableMap<String, In
* @param [rs] the ResultSet object containing the data to be fetched and converted.
* @param [dbType] the type of the database.
* @param [limit] the maximum number of rows to fetch and convert.
* @param [inferNullability] if `true`, column nullability is inferred from the values actually read; if `false`, the nullability of the declared column type is kept as is.
* @return the [AnyFrame] built from the fetched and converted data.
*/
private fun fetchAndConvertDataFromResultSet(
tableColumns: MutableList<TableColumnMetadata>,
rs: ResultSet,
dbType: DbType,
limit: Int
limit: Int,
inferNullability: Boolean,
): AnyFrame {
val data = List(tableColumns.size) { mutableListOf<Any?>() }

Expand Down Expand Up @@ -596,6 +620,7 @@ private fun fetchAndConvertDataFromResultSet(
DataColumn.createValueColumn(
name = tableColumns[index].name,
values = values,
infer = convertNullabilityInference(inferNullability),
type = kotlinTypesForSqlColumns[index]!!
)
}.toDataFrame()
Expand All @@ -605,6 +630,8 @@ private fun fetchAndConvertDataFromResultSet(
return dataFrame
}

/** Maps the boolean `inferNullability` flag onto the [Infer] mode passed to column creation. */
private fun convertNullabilityInference(inferNullability: Boolean) = when (inferNullability) {
    true -> Infer.Nulls
    false -> Infer.None
}

private fun extractNewRowFromResultSetAndAddToData(
tableColumns: MutableList<TableColumnMetadata>,
data: List<MutableList<Any?>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ import org.h2.jdbc.JdbcSQLSyntaxErrorException
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.api.select
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.db.H2
import org.junit.AfterClass
import org.junit.BeforeClass
Expand Down Expand Up @@ -677,4 +674,115 @@ class JdbcTest {
saleDataSchema1.columns.size shouldBe 3
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
}

// Verifies the `inferNullability` flag of `readSqlTable`, `readSqlQuery` and `readResultSet`:
// `name` is a nullable VARCHAR column that holds no NULL values, `surname` holds one NULL.
// With the default (true) `name` is expected to be `String`; with `false` it stays `String?`,
// matching what the schema-only functions report.
@Test
fun `infer nullability`() {
// prepare tables and data
@Language("SQL")
val createTestTable1Query = """
CREATE TABLE TestTable1 (
id INT PRIMARY KEY,
name VARCHAR(50),
surname VARCHAR(50),
age INT NOT NULL
)
"""

connection.createStatement().execute(createTestTable1Query)

connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
// the only row with a NULL: it makes `surname` nullable even when nullability is inferred from data
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")

// start testing `readSqlTable` method

// with default inferNullability: Boolean = true
val tableName = "TestTable1"
val df = DataFrame.readSqlTable(connection, tableName)
df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df.schema().columns["name"]!!.type shouldBe typeOf<String>()
df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// the schema-only call reads no rows, so `name` is reported as nullable
val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
dataSchema.columns.size shouldBe 4
dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df1.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== stays nullable although it contains no nulls: nullability is not inferred from the data
df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSqlTable` method

// start testing `readSqlQuery` method

// with default inferNullability: Boolean = true
@Language("SQL")
val sqlQuery = """
SELECT name, surname, age FROM TestTable1
""".trimIndent()

val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()

val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
dataSchema2.columns.size shouldBe 3
dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)
df3.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== stays nullable although it contains no nulls: nullability is not inferred from the data
df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// end testing `readSqlQuery` method

// start testing `readResultSet` method

// a scrollable statement is requested so that the same ResultSet can be rewound and read several times
connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
@Language("SQL")
val selectStatement = "SELECT * FROM TestTable1"

st.executeQuery(selectStatement).use { rs ->
// with default inferNullability: Boolean = true
val df4 = DataFrame.readResultSet(rs, H2)
df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()

// rewind the cursor before the ResultSet is used again
rs.beforeFirst()

val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2)
dataSchema3.columns.size shouldBe 4
dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()

// with inferNullability: Boolean = false
rs.beforeFirst()

val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false)
df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()
df5.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== stays nullable although it contains no nulls: nullability is not inferred from the data
df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
}
}
// end testing `readResultSet` method

// clean up so the table does not leak into other tests sharing the connection
connection.createStatement().execute("DROP TABLE TestTable1")
}
}