diff --git a/dataframe-jdbc/build.gradle.kts b/dataframe-jdbc/build.gradle.kts index bac7c997b9..0b2ab2d5e9 100644 --- a/dataframe-jdbc/build.gradle.kts +++ b/dataframe-jdbc/build.gradle.kts @@ -33,6 +33,7 @@ dependencies { testImplementation(libs.kotestAssertions) { exclude("org.jetbrains.kotlin", "kotlin-stdlib-jdk8") } + testImplementation(libs.hikari) } kotlinPublications { diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt index a95ec1e3db..9d0d0c8279 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt @@ -1,9 +1,12 @@ package org.jetbrains.kotlinx.dataframe.io.h2 +import com.zaxxer.hikari.HikariConfig +import com.zaxxer.hikari.HikariDataSource import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.throwables.shouldThrowExactly import io.kotest.matchers.shouldBe import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema import org.jetbrains.kotlinx.dataframe.api.add @@ -34,6 +37,8 @@ import kotlin.reflect.typeOf private const val URL = "jdbc:h2:mem:test5;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE_TO_UPPER=false" +private const val MAXIMUM_POOL_SIZE = 5 + @DataSchema interface Customer { val id: Int? 
@@ -88,13 +93,30 @@ interface TestTableData { class JdbcTest { companion object { private lateinit var connection: Connection + private lateinit var dataSource: HikariDataSource @BeforeClass @JvmStatic fun setUpClass() { - connection = - DriverManager.getConnection(URL) + initializeConnection() + initializeDataSource() + createTablesAndData() + } + + private fun initializeConnection() { + connection = DriverManager.getConnection(URL) + } + + private fun initializeDataSource() { + val config = HikariConfig().apply { + jdbcUrl = URL + maximumPoolSize = MAXIMUM_POOL_SIZE + minimumIdle = 2 + } + dataSource = HikariDataSource(config) + } + private fun createTablesAndData() { // Create table Customer @Language("SQL") val createCustomerTableQuery = """ @@ -136,13 +158,105 @@ class JdbcTest { @JvmStatic fun tearDownClass() { try { + dataSource.close() connection.close() } catch (e: SQLException) { e.printStackTrace() } } + + // Helper assertion functions + private fun assertCustomerData(df: AnyFrame, expectedRows: Int = 4) { + val casted = df.cast() + casted.rowsCount() shouldBe expectedRows + val expectedOlderThan30 = when (expectedRows) { + 4 -> 2 + 2 -> 1 + else -> 1 // for 1 row or other small limits in tests + } + casted.filter { + it[Customer::age] != null && it[Customer::age]!! > 30 + }.rowsCount() shouldBe expectedOlderThan30 + casted[0][1] shouldBe "John" + } + + private fun assertCustomerSchema(schema: DataFrameSchema) { + schema.columns.size shouldBe 3 + schema.columns["name"]!!.type shouldBe typeOf() + } + + private fun assertCustomerSalesData(df: AnyFrame, expectedRows: Int = 2) { + val casted = df.cast() + casted.rowsCount() shouldBe expectedRows + // In current tests, regardless of limit (2 or 1), the count of totalSalesAmount > 100 is 1 + casted.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 + casted[0][0] shouldBe "John" + } + + private fun assertCustomerSalesSchema(schema: DataFrameSchema) { + schema.columns.size shouldBe 2 + schema.columns["name"]!!.type shouldBe typeOf() + } + + private fun assertAllTablesData(dataFrameMap: Map) { + dataFrameMap.containsKey("Customer") shouldBe true + dataFrameMap.containsKey("Sale") shouldBe true + + val dataframes = dataFrameMap.values.toList() + + val customerDf = dataframes[0].cast() + customerDf.rowsCount() shouldBe 4 + customerDf.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 + customerDf[0][1] shouldBe "John" + + val saleDf = dataframes[1].cast() + saleDf.rowsCount() shouldBe 4 + saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + } + + private fun assertAllTablesDataWithLimit(dataFrameMap: Map) { + val dataframes = dataFrameMap.values.toList() + + val customerDf = dataframes[0].cast() + customerDf.rowsCount() shouldBe 1 + customerDf.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 + customerDf[0][1] shouldBe "John" + + val saleDf = dataframes[1].cast() + saleDf.rowsCount() shouldBe 1 + saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 + (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + } + + private fun assertAllTablesSchema(dataFrameSchemaMap: Map) { + dataFrameSchemaMap.containsKey("Customer") shouldBe true + dataFrameSchemaMap.containsKey("Sale") shouldBe true + + val dataSchemas = dataFrameSchemaMap.values.toList() + + val customerDataSchema = dataSchemas[0] + customerDataSchema.columns.size shouldBe 3 + customerDataSchema.columns["name"]!!.type shouldBe typeOf() + + val saleDataSchema = dataSchemas[1] + saleDataSchema.columns.size shouldBe 3 + saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + } + + @Language("SQL") + private val CUSTOMER_SALES_QUERY = + """ + SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount + FROM Sale s + INNER JOIN Customer c ON s.customerId = c.id + WHERE c.age > 35 + GROUP BY s.customerId, c.name + """.trimIndent() } + // ========== Connection API Tests ========== + @Test fun `read from empty table`() { @Language("SQL") @@ -305,95 +419,60 @@ class JdbcTest { @Test fun `read from table`() { val tableName = "Customer" - val df = DataFrame.readSqlTable(connection, tableName).cast() - - df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df[0][1] shouldBe "John" - - val df1 = DataFrame.readSqlTable(connection, tableName, 1).cast() + val df = DataFrame.readSqlTable(connection, tableName) + assertCustomerData(df) - df1.rowsCount() shouldBe 1 - df1.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - df1[0][1] shouldBe "John" + val df1 = DataFrame.readSqlTable(connection, tableName, 1) + assertCustomerData(df1, 1) val dataSchema = DataFrameSchema.readSqlTable(connection, tableName) - dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema) val dbConfig = DbConnectionConfig(url = URL) - val df2 = DataFrame.readSqlTable(dbConfig, tableName).cast() + val df2 = DataFrame.readSqlTable(dbConfig, tableName) + assertCustomerData(df2) - df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df2[0][1] shouldBe "John" - - val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1).cast() - - df3.rowsCount() shouldBe 1 - df3.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df3[0][1] shouldBe "John" + val df3 = DataFrame.readSqlTable(dbConfig, tableName, 1) + assertCustomerData(df3, 1) val dataSchema1 = DataFrameSchema.readSqlTable(dbConfig, tableName) - dataSchema1.columns.size shouldBe 3 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema1) } @Test fun `read from table with extension functions`() { val tableName = "Customer" - val df = connection.readDataFrame(tableName).cast() - - df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df[0][1] shouldBe "John" + val df = connection.readDataFrame(tableName) + assertCustomerData(df) - val df1 = connection.readDataFrame(tableName, 1).cast() - - df1.rowsCount() shouldBe 1 - df1.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - df1[0][1] shouldBe "John" + val df1 = connection.readDataFrame(tableName, 1) + assertCustomerData(df1, 1) val dataSchema = connection.readDataFrameSchema(tableName) - dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema) val dbConfig = DbConnectionConfig(url = URL) - val df2 = dbConfig.readDataFrame(tableName).cast() - - df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df2[0][1] shouldBe "John" - - val df3 = dbConfig.readDataFrame(tableName, 1).cast() + val df2 = dbConfig.readDataFrame(tableName) + assertCustomerData(df2) - df3.rowsCount() shouldBe 1 - df3.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df3[0][1] shouldBe "John" + val df3 = dbConfig.readDataFrame(tableName, 1) + assertCustomerData(df3, 1) val dataSchema1 = dbConfig.readDataFrameSchema(tableName) - dataSchema1.columns.size shouldBe 3 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema1) } - // to cover a reported case from https://github.com/Kotlin/dataframe/issues/494 @Test fun `repeated read from table with limit`() { val tableName = "Customer" - for (i in 1..10) { - val df1 = DataFrame.readSqlTable(connection, tableName, 2).cast() - - df1.rowsCount() shouldBe 2 - df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df1[0][1] shouldBe "John" + repeat(10) { + val df1 = DataFrame.readSqlTable(connection, tableName, 2) + assertCustomerData(df1, 2) val dbConfig = DbConnectionConfig(url = URL) - val df2 = DataFrame.readSqlTable(dbConfig, tableName, 2).cast() - - df2.rowsCount() shouldBe 2 - df2.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - df2[0][1] shouldBe "John" + val df2 = DataFrame.readSqlTable(dbConfig, tableName, 2) + assertCustomerData(df2, 2) } } @@ -404,47 +483,33 @@ class JdbcTest { val selectStatement = "SELECT * FROM Customer" st.executeQuery(selectStatement).use { rs -> - val df = DataFrame.readResultSet(rs, H2(MySql)).cast() - - df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df[0][1] shouldBe "John" + val df = DataFrame.readResultSet(rs, H2(MySql)) + assertCustomerData(df) rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2(MySql), 1).cast() - - df1.rowsCount() shouldBe 1 - df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df1[0][1] shouldBe "John" + val df1 = DataFrame.readResultSet(rs, H2(MySql), 1) + assertCustomerData(df1, 1) rs.beforeFirst() val dataSchema = DataFrameSchema.readResultSet(rs, H2(MySql)) - dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema) rs.beforeFirst() - val df2 = DataFrame.readResultSet(rs, connection).cast() - - df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df2[0][1] shouldBe "John" + val df2 = DataFrame.readResultSet(rs, connection) + assertCustomerData(df2) rs.beforeFirst() - val df3 = DataFrame.readResultSet(rs, connection, 1).cast() - - df3.rowsCount() shouldBe 1 - df3.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - df3[0][1] shouldBe "John" + val df3 = DataFrame.readResultSet(rs, connection, 1) + assertCustomerData(df3, 1) rs.beforeFirst() val dataSchema1 = DataFrameSchema.readResultSet(rs, H2(MySql)) - dataSchema1.columns.size shouldBe 3 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema1) } } } @@ -456,47 +521,33 @@ class JdbcTest { val selectStatement = "SELECT * FROM Customer" st.executeQuery(selectStatement).use { rs -> - val df = rs.readDataFrame(H2(MySql)).cast() - - df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df[0][1] shouldBe "John" + val df = rs.readDataFrame(H2(MySql)) + assertCustomerData(df) rs.beforeFirst() - val df1 = rs.readDataFrame(H2(MySql), 1).cast() - - df1.rowsCount() shouldBe 1 - df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df1[0][1] shouldBe "John" + val df1 = rs.readDataFrame(H2(MySql), 1) + assertCustomerData(df1, 1) rs.beforeFirst() val dataSchema = rs.readDataFrameSchema(H2(MySql)) - dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema) rs.beforeFirst() - val df2 = rs.readDataFrame(connection).cast() - - df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df2[0][1] shouldBe "John" + val df2 = rs.readDataFrame(connection) + assertCustomerData(df2) rs.beforeFirst() - val df3 = rs.readDataFrame(connection, 1).cast() - - df3.rowsCount() shouldBe 1 - df3.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - df3[0][1] shouldBe "John" + val df3 = rs.readDataFrame(connection, 1) + assertCustomerData(df3, 1) rs.beforeFirst() val dataSchema1 = rs.readDataFrameSchema(H2(MySql)) - dataSchema1.columns.size shouldBe 3 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema1) } } } @@ -509,22 +560,16 @@ class JdbcTest { val selectStatement = "SELECT * FROM Customer" st.executeQuery(selectStatement).use { rs -> - for (i in 1..10) { + repeat(10) { rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2(MySql), 2).cast() - - df1.rowsCount() shouldBe 2 - df1.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df1[0][1] shouldBe "John" + val df1 = DataFrame.readResultSet(rs, H2(MySql), 2) + assertCustomerData(df1, 2) rs.beforeFirst() - val df2 = DataFrame.readResultSet(rs, connection, 2).cast() - - df2.rowsCount() shouldBe 2 - df2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 1 - df2[0][1] shouldBe "John" + val df2 = DataFrame.readResultSet(rs, connection, 2) + assertCustomerData(df2, 2) } } } @@ -819,94 +864,46 @@ class JdbcTest { @Test fun `read from sql query`() { - @Language("SQL") - val sqlQuery = - """ - SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount - FROM Sale s - INNER JOIN Customer c ON s.customerId = c.id - WHERE c.age > 35 - GROUP BY s.customerId, c.name - """.trimIndent() - - val df = DataFrame.readSqlQuery(connection, sqlQuery).cast() - - df.rowsCount() shouldBe 2 - df.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 - df[0][0] shouldBe "John" + val df = DataFrame.readSqlQuery(connection, CUSTOMER_SALES_QUERY) + assertCustomerSalesData(df) - val df1 = DataFrame.readSqlQuery(connection, sqlQuery, 1).cast() + val df1 = DataFrame.readSqlQuery(connection, CUSTOMER_SALES_QUERY, 1) + assertCustomerSalesData(df1, 1) - df1.rowsCount() shouldBe 1 - df1.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 - df1[0][0] shouldBe "John" - - val dataSchema = DataFrameSchema.readSqlQuery(connection, sqlQuery) - dataSchema.columns.size shouldBe 2 - dataSchema.columns["name"]!!.type shouldBe typeOf() + val dataSchema = DataFrameSchema.readSqlQuery(connection, CUSTOMER_SALES_QUERY) + assertCustomerSalesSchema(dataSchema) val dbConfig = DbConnectionConfig(url = URL) - val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery).cast() - - df2.rowsCount() shouldBe 2 - df2.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 - df2[0][0] shouldBe "John" + val df2 = DataFrame.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY) + assertCustomerSalesData(df2) - val df3 = DataFrame.readSqlQuery(dbConfig, sqlQuery, 1).cast() + val df3 = DataFrame.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY, 1) + assertCustomerSalesData(df3, 1) - df3.rowsCount() shouldBe 1 - df3.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 - df3[0][0] shouldBe "John" - - val dataSchema1 = DataFrameSchema.readSqlQuery(dbConfig, sqlQuery) - dataSchema1.columns.size shouldBe 2 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + val dataSchema1 = DataFrameSchema.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY) + assertCustomerSalesSchema(dataSchema1) } @Test fun `read from sql query with extension functions`() { - @Language("SQL") - val sqlQuery = - """ - SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount - FROM Sale s - INNER JOIN Customer c ON s.customerId = c.id - WHERE c.age > 35 - GROUP BY s.customerId, c.name - """.trimIndent() - - val df = connection.readDataFrame(sqlQuery).cast() - - df.rowsCount() shouldBe 2 - df.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 - df[0][0] shouldBe "John" + val df = connection.readDataFrame(CUSTOMER_SALES_QUERY) + assertCustomerSalesData(df) - val df1 = connection.readDataFrame(sqlQuery, 1).cast() + val df1 = connection.readDataFrame(CUSTOMER_SALES_QUERY, 1) + assertCustomerSalesData(df1, 1) - df1.rowsCount() shouldBe 1 - df1.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 - df1[0][0] shouldBe "John" - - val dataSchema = connection.readDataFrameSchema(sqlQuery) - dataSchema.columns.size shouldBe 2 - dataSchema.columns["name"]!!.type shouldBe typeOf() + val dataSchema = connection.readDataFrameSchema(CUSTOMER_SALES_QUERY) + assertCustomerSalesSchema(dataSchema) val dbConfig = DbConnectionConfig(url = URL) - val df2 = dbConfig.readDataFrame(sqlQuery).cast() - - df2.rowsCount() shouldBe 2 - df2.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 - df2[0][0] shouldBe "John" + val df2 = dbConfig.readDataFrame(CUSTOMER_SALES_QUERY) + assertCustomerSalesData(df2) - val df3 = dbConfig.readDataFrame(sqlQuery, 1).cast() + val df3 = dbConfig.readDataFrame(CUSTOMER_SALES_QUERY, 1) + assertCustomerSalesData(df3, 1) - df3.rowsCount() shouldBe 1 - df3.filter { it[CustomerSales::totalSalesAmount]!! > 100 }.rowsCount() shouldBe 1 - df3[0][0] shouldBe "John" - - val dataSchema1 = dbConfig.readDataFrameSchema(sqlQuery) - dataSchema1.columns.size shouldBe 2 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + val dataSchema1 = dbConfig.readDataFrameSchema(CUSTOMER_SALES_QUERY) + assertCustomerSalesSchema(dataSchema1) } @Test @@ -945,90 +942,23 @@ class JdbcTest { @Test fun `read from all tables`() { val dataFrameMap = DataFrame.readAllSqlTables(connection) - dataFrameMap.containsKey("Customer") shouldBe true - dataFrameMap.containsKey("Sale") shouldBe true - - val dataframes = dataFrameMap.values.toList() - - val customerDf = dataframes[0].cast() - - customerDf.rowsCount() shouldBe 4 - customerDf.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - customerDf[0][1] shouldBe "John" - - val saleDf = dataframes[1].cast() - - saleDf.rowsCount() shouldBe 4 - saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 - (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + assertAllTablesData(dataFrameMap) - val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1).values.toList() - - val customerDf1 = dataframes1[0].cast() - - customerDf1.rowsCount() shouldBe 1 - customerDf1.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - customerDf1[0][1] shouldBe "John" - - val saleDf1 = dataframes1[1].cast() - - saleDf1.rowsCount() shouldBe 1 - saleDf1.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 - (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1) + assertAllTablesDataWithLimit(dataframes1) val dataFrameSchemaMap = DataFrameSchema.readAllSqlTables(connection) - dataFrameSchemaMap.containsKey("Customer") shouldBe true - dataFrameSchemaMap.containsKey("Sale") shouldBe true - - val dataSchemas = dataFrameSchemaMap.values.toList() - - val customerDataSchema = dataSchemas[0] - customerDataSchema.columns.size shouldBe 3 - customerDataSchema.columns["name"]!!.type shouldBe typeOf() - - val saleDataSchema = dataSchemas[1] - saleDataSchema.columns.size shouldBe 3 - // TODO: fix nullability - saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + assertAllTablesSchema(dataFrameSchemaMap) val dbConfig = DbConnectionConfig(url = URL) - val dataframes2 = DataFrame.readAllSqlTables(dbConfig).values.toList() - - val customerDf2 = dataframes2[0].cast() - - customerDf2.rowsCount() shouldBe 4 - customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - customerDf2[0][1] shouldBe "John" - - val saleDf2 = dataframes2[1].cast() - - saleDf2.rowsCount() shouldBe 4 - saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 - (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - - val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1).values.toList() - - val customerDf3 = dataframes3[0].cast() + val dataframes2 = DataFrame.readAllSqlTables(dbConfig) + assertAllTablesData(dataframes2) - customerDf3.rowsCount() shouldBe 1 - customerDf3.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 1 - customerDf3[0][1] shouldBe "John" + val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1) + assertAllTablesDataWithLimit(dataframes3) - val saleDf3 = dataframes3[1].cast() - - saleDf3.rowsCount() shouldBe 1 - saleDf3.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1 - (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 - - val dataSchemas1 = DataFrameSchema.readAllSqlTables(dbConfig).values.toList() - - val customerDataSchema1 = dataSchemas1[0] - customerDataSchema1.columns.size shouldBe 3 - customerDataSchema1.columns["name"]!!.type shouldBe typeOf() - - val saleDataSchema1 = dataSchemas1[1] - saleDataSchema1.columns.size shouldBe 3 - saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() + val dataSchemas1 = DataFrameSchema.readAllSqlTables(dbConfig) + assertAllTablesSchema(dataSchemas1) } @Test @@ -1050,135 +980,206 @@ class JdbcTest { @Test fun `read from table from custom database`() { val tableName = "Customer" - val df = DataFrame.readSqlTable(connection, tableName, dbType = CustomDB).cast() - - df.rowsCount() shouldBe 4 - df.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - df[0][1] shouldBe "John" + val df = DataFrame.readSqlTable(connection, tableName, dbType = CustomDB) + assertCustomerData(df) val dataSchema = DataFrameSchema.readSqlTable(connection, tableName, dbType = CustomDB) - dataSchema.columns.size shouldBe 3 - dataSchema.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema) val dbConfig = DbConnectionConfig(url = URL) - val df2 = DataFrame.readSqlTable(dbConfig, tableName, dbType = CustomDB).cast() - - df2.rowsCount() shouldBe 4 - df2.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 2 - df2[0][1] shouldBe "John" + val df2 = DataFrame.readSqlTable(dbConfig, tableName, dbType = CustomDB) + assertCustomerData(df2) val dataSchema1 = DataFrameSchema.readSqlTable(dbConfig, tableName, dbType = CustomDB) - dataSchema1.columns.size shouldBe 3 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + assertCustomerSchema(dataSchema1) } @Test fun `read from query from custom database`() { - @Language("SQL") - val sqlQuery = - """ - SELECT c.name as customerName, SUM(s.amount) as totalSalesAmount - FROM Sale s - INNER JOIN Customer c ON s.customerId = c.id - WHERE c.age > 35 - GROUP BY s.customerId, c.name - """.trimIndent() + val df = DataFrame.readSqlQuery(connection, CUSTOMER_SALES_QUERY, dbType = CustomDB) + assertCustomerSalesData(df) - val df = DataFrame.readSqlQuery(connection, sqlQuery, dbType = CustomDB).cast() + val dataSchema = DataFrameSchema.readSqlQuery(connection, CUSTOMER_SALES_QUERY, dbType = CustomDB) + assertCustomerSalesSchema(dataSchema) - df.rowsCount() shouldBe 2 - df.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 - df[0][0] shouldBe "John" + val dbConfig = DbConnectionConfig(url = URL) + val df2 = DataFrame.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY, dbType = CustomDB) + assertCustomerSalesData(df2) - val dataSchema = DataFrameSchema.readSqlQuery(connection, sqlQuery, dbType = CustomDB) - dataSchema.columns.size shouldBe 2 - dataSchema.columns["name"]!!.type shouldBe typeOf() + val dataSchema1 = DataFrameSchema.readSqlQuery(dbConfig, CUSTOMER_SALES_QUERY, dbType = CustomDB) + assertCustomerSalesSchema(dataSchema1) + } + + @Test + fun `read from all tables from custom database`() { + val dataFrameMap = DataFrame.readAllSqlTables(connection, dbType = CustomDB) + assertAllTablesData(dataFrameMap) + + val dataFrameSchemaMap = DataFrameSchema.readAllSqlTables(connection, dbType = CustomDB) + assertAllTablesSchema(dataFrameSchemaMap) val dbConfig = DbConnectionConfig(url = URL) - val df2 = DataFrame.readSqlQuery(dbConfig, sqlQuery, dbType = CustomDB).cast() + val dataframes2 = DataFrame.readAllSqlTables(dbConfig, dbType = CustomDB) + assertAllTablesData(dataframes2) - df2.rowsCount() shouldBe 2 - df2.filter { it[CustomerSales::totalSalesAmount]!! 
> 100 }.rowsCount() shouldBe 1 - df2[0][0] shouldBe "John" + val dataSchemas1 = DataFrameSchema.readAllSqlTables(dbConfig, dbType = CustomDB) + assertAllTablesSchema(dataSchemas1) + } + + @Test + fun `withReadOnlyConnection sets readOnly and rolls back after execution`() { + val config = DbConnectionConfig("jdbc:h2:mem:test;MODE=MySQL;DB_CLOSE_DELAY=-1", readOnly = true) + + var wasExecuted = false + val result = withReadOnlyConnection(config) { conn -> + wasExecuted = true + conn.autoCommit shouldBe false + 42 + } - val dataSchema1 = DataFrameSchema.readSqlQuery(dbConfig, sqlQuery, dbType = CustomDB) - dataSchema1.columns.size shouldBe 2 - dataSchema1.columns["name"]!!.type shouldBe typeOf() + wasExecuted shouldBe true + result shouldBe 42 } + // ========== DataSource API Tests ========== + @Test - fun `read from all tables from custom database`() { - val dataFrameMap = DataFrame.readAllSqlTables(connection, dbType = CustomDB) - dataFrameMap.containsKey("Customer") shouldBe true - dataFrameMap.containsKey("Sale") shouldBe true + fun `read from table using DataSource`() { + val tableName = "Customer" + val df = DataFrame.readSqlTable(dataSource, tableName) + assertCustomerData(df) - val dataframes = dataFrameMap.values.toList() + val df1 = DataFrame.readSqlTable(dataSource, tableName, 1) + assertCustomerData(df1, 1) - val customerDf = dataframes[0].cast() + val dataSchema = DataFrameSchema.readSqlTable(dataSource, tableName) + assertCustomerSchema(dataSchema) + } - customerDf.rowsCount() shouldBe 4 - customerDf.filter { it[Customer::age] != null && it[Customer::age]!! 
> 30 }.rowsCount() shouldBe 2 - customerDf[0][1] shouldBe "John" + @Test + fun `read from table with extension functions using DataSource`() { + val tableName = "Customer" + val df = dataSource.readDataFrame(tableName) + assertCustomerData(df) - val saleDf = dataframes[1].cast() + val df1 = dataSource.readDataFrame(tableName, 1) + assertCustomerData(df1, 1) - saleDf.rowsCount() shouldBe 4 - saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 - (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + val dataSchema = dataSource.readDataFrameSchema(tableName) + assertCustomerSchema(dataSchema) + } - val dataFrameSchemaMap = DataFrameSchema.readAllSqlTables(connection, dbType = CustomDB) - dataFrameSchemaMap.containsKey("Customer") shouldBe true - dataFrameSchemaMap.containsKey("Sale") shouldBe true + @Test + fun `read from sql query using DataSource`() { + val df = DataFrame.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY) + assertCustomerSalesData(df) - val dataSchemas = dataFrameSchemaMap.values.toList() + val df1 = DataFrame.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY, 1) + assertCustomerSalesData(df1, 1) - val customerDataSchema = dataSchemas[0] - customerDataSchema.columns.size shouldBe 3 - customerDataSchema.columns["name"]!!.type shouldBe typeOf() + val dataSchema = DataFrameSchema.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY) + assertCustomerSalesSchema(dataSchema) + } - val saleDataSchema = dataSchemas[1] - saleDataSchema.columns.size shouldBe 3 - // TODO: fix nullability - saleDataSchema.columns["amount"]!!.type shouldBe typeOf() + @Test + fun `read from sql query with extension functions using DataSource`() { + val df = dataSource.readDataFrame(CUSTOMER_SALES_QUERY) + assertCustomerSalesData(df) - val dbConfig = DbConnectionConfig(url = URL) - val dataframes2 = DataFrame.readAllSqlTables(dbConfig, dbType = CustomDB).values.toList() + val df1 = dataSource.readDataFrame(CUSTOMER_SALES_QUERY, 1) + assertCustomerSalesData(df1, 1) - 
val customerDf2 = dataframes2[0].cast() + val dataSchema = dataSource.readDataFrameSchema(CUSTOMER_SALES_QUERY) + assertCustomerSalesSchema(dataSchema) + } + + @Test + fun `read from all tables using DataSource`() { + val dataFrameMap = DataFrame.readAllSqlTables(dataSource) + assertAllTablesData(dataFrameMap) - customerDf2.rowsCount() shouldBe 4 - customerDf2.filter { it[Customer::age] != null && it[Customer::age]!! > 30 }.rowsCount() shouldBe 2 - customerDf2[0][1] shouldBe "John" + val dataframes1 = DataFrame.readAllSqlTables(dataSource, limit = 1) + assertAllTablesDataWithLimit(dataframes1) - val saleDf2 = dataframes2[1].cast() + val dataFrameSchemaMap = DataFrameSchema.readAllSqlTables(dataSource) + assertAllTablesSchema(dataFrameSchemaMap) + } - saleDf2.rowsCount() shouldBe 4 - saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3 - (saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0 + @Test + fun `read from table from custom database using DataSource`() { + val tableName = "Customer" + val df = DataFrame.readSqlTable(dataSource, tableName, dbType = CustomDB) + assertCustomerData(df) - val dataSchemas1 = DataFrameSchema.readAllSqlTables(dbConfig, dbType = CustomDB).values.toList() + val dataSchema = DataFrameSchema.readSqlTable(dataSource, tableName, dbType = CustomDB) + assertCustomerSchema(dataSchema) + } - val customerDataSchema1 = dataSchemas1[0] - customerDataSchema1.columns.size shouldBe 3 - customerDataSchema1.columns["name"]!!.type shouldBe typeOf() + @Test + fun `read from query from custom database using DataSource`() { + val df = DataFrame.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY, dbType = CustomDB) + assertCustomerSalesData(df) - val saleDataSchema1 = dataSchemas1[1] - saleDataSchema1.columns.size shouldBe 3 - saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() + val dataSchema = DataFrameSchema.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY, dbType = CustomDB) + assertCustomerSalesSchema(dataSchema) } @Test - 
fun `withReadOnlyConnection sets readOnly and rolls back after execution`() { - val config = DbConnectionConfig("jdbc:h2:mem:test;MODE=MySQL;DB_CLOSE_DELAY=-1", readOnly = true) + fun `read from all tables from custom database using DataSource`() { + val dataFrameMap = DataFrame.readAllSqlTables(dataSource, dbType = CustomDB) + assertAllTablesData(dataFrameMap) - var wasExecuted = false - val result = withReadOnlyConnection(config) { conn -> - wasExecuted = true - conn.autoCommit shouldBe false - 42 + val dataFrameSchemaMap = DataFrameSchema.readAllSqlTables(dataSource, dbType = CustomDB) + assertAllTablesSchema(dataFrameSchemaMap) + } + + // ========== Connection Pool Tests ========== + + @Test + fun `repeated read from table with limit using DataSource`() { + // Verify DataSource integration handles repeated sequential reads correctly. + // Covers issue #494 where repeated reads with limit produced incorrect results. + val tableName = "Customer" + repeat(MAXIMUM_POOL_SIZE * 2) { + val df = DataFrame.readSqlTable(dataSource, tableName, 2) + assertCustomerData(df, 2) + } + } + + @Test + fun `DataSource sequential reads return connections to pool`() { + // Verify connections are properly closed and returned to the pool after each read. + // Would fail on iteration 6 if connections leak (maximumPoolSize=5). + repeat(MAXIMUM_POOL_SIZE * 2) { + val df = DataFrame.readSqlTable(dataSource, "Customer", limit = 1) + df.rowsCount() shouldBe 1 + assertCustomerData(df, 1) } + } - wasExecuted shouldBe true - result shouldBe 42 + @Test + fun `DataSource sequential reads with alternating tables`() { + // Test connection reuse when sequentially reading from different tables. + // Ensures no state pollution when switching between table schemas. 
+ repeat(MAXIMUM_POOL_SIZE * 2) { i -> + val tableName = if (i % 2 == 0) "Customer" else "Sale" + val df = DataFrame.readSqlTable(dataSource, tableName, limit = 1) + df.rowsCount() shouldBe 1 + } + } + + @Test + fun `DataSource sequential reads with mixed query and table operations`() { + // Verify both readSqlTable and readSqlQuery properly manage the connection lifecycle. + // Tests that different code paths can alternate sequentially without resource leaks. + repeat(MAXIMUM_POOL_SIZE * 2) { + val dfTable = DataFrame.readSqlTable(dataSource, "Customer", limit = 1) + dfTable.rowsCount() shouldBe 1 + assertCustomerData(dfTable, 1) + + val dfQuery = DataFrame.readSqlQuery(dataSource, CUSTOMER_SALES_QUERY, limit = 1) + dfQuery.rowsCount() shouldBe 1 + assertCustomerSalesData(dfQuery, 1) + } } } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 8e1e3a550d..d7d527934a 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -54,6 +54,7 @@ shadow = "8.3.5" android-gradle-api = "7.3.1" # need to revise our tests to update ktor = "3.0.1" # needs jupyter compatibility with Kotlin 2.1 to update kotlin-compile-testing = "0.7.1" +hikari = "7.0.2" duckdb = "1.3.1.0" buildconfig = "5.6.7" benchmark = "0.4.12" @@ -71,7 +72,6 @@ kandy-stats-notebook = "0.5.0n" exposed = "1.0.0-beta-2" hibernate = "6.5.2.Final" -hikari = "5.1.0" # check the versions down in the [libraries] section too! 
kotlin-spark = "1.2.4" @@ -180,6 +180,7 @@ kotlin-jupyter-test-kit = { group = "org.jetbrains.kotlinx", name = "kotlin-jupy kotlinx-benchmark-runtime = { group = "org.jetbrains.kotlinx", name = "kotlinx-benchmark-runtime", version.ref = "benchmark" } dataframe-symbol-processor = { group = "org.jetbrains.kotlinx.dataframe", name = "symbol-processor-all" } +hikari = { group = "com.zaxxer", name = "HikariCP", version.ref = "hikari" } duckdb-jdbc = { group = "org.duckdb", name = "duckdb_jdbc", version.ref = "duckdb" } exposed-core = { group = "org.jetbrains.exposed", name = "exposed-core", version.ref = "exposed" }