diff --git a/dataframe-jdbc/api/dataframe-jdbc.api b/dataframe-jdbc/api/dataframe-jdbc.api index bd57ed2bcd..bb142cba42 100644 --- a/dataframe-jdbc/api/dataframe-jdbc.api +++ b/dataframe-jdbc/api/dataframe-jdbc.api @@ -135,19 +135,39 @@ public class org/jetbrains/kotlinx/dataframe/io/db/H2 : org/jetbrains/kotlinx/da public static final field MODE_POSTGRESQL Ljava/lang/String; public fun ()V public fun (Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;)V - public synthetic fun (Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode;)V + public synthetic fun (Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun buildSqlQueryWithLimit (Ljava/lang/String;I)Ljava/lang/String; public fun buildTableMetadata (Ljava/sql/ResultSet;)Lorg/jetbrains/kotlinx/dataframe/io/db/TableMetadata; public fun convertSqlTypeToColumnSchemaValue (Lorg/jetbrains/kotlinx/dataframe/io/db/TableColumnMetadata;)Lorg/jetbrains/kotlinx/dataframe/schema/ColumnSchema; public fun convertSqlTypeToKType (Lorg/jetbrains/kotlinx/dataframe/io/db/TableColumnMetadata;)Lkotlin/reflect/KType; - public final fun getDialect ()Lorg/jetbrains/kotlinx/dataframe/io/db/DbType; public fun getDriverClassName ()Ljava/lang/String; + public final fun getMode ()Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; public fun isSystemTable (Lorg/jetbrains/kotlinx/dataframe/io/db/TableMetadata;)Z } public final class org/jetbrains/kotlinx/dataframe/io/db/H2$Companion { } +public final class org/jetbrains/kotlinx/dataframe/io/db/H2$Mode : java/lang/Enum { + public static final field Companion Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode$Companion; + public static final field MariaDb Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public static final field MsSqlServer Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public static final field MySql Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public static final field PostgreSql Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public static final field Regular Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public final fun getValue ()Ljava/lang/String; + public final fun toDbType ()Lorg/jetbrains/kotlinx/dataframe/io/db/DbType; + public static fun valueOf (Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public static fun values ()[Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; +} + +public final class org/jetbrains/kotlinx/dataframe/io/db/H2$Mode$Companion { + public final fun fromDbType (Lorg/jetbrains/kotlinx/dataframe/io/db/DbType;)Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; + public final fun fromValue (Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/io/db/H2$Mode; +} + public final class org/jetbrains/kotlinx/dataframe/io/db/MariaDb : org/jetbrains/kotlinx/dataframe/io/db/DbType { public static final field INSTANCE Lorg/jetbrains/kotlinx/dataframe/io/db/MariaDb; public fun buildTableMetadata (Ljava/sql/ResultSet;)Lorg/jetbrains/kotlinx/dataframe/io/db/TableMetadata; diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt index 4a7b47481d..96cea43724 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/H2.kt @@ -4,6 +4,10 @@ import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema import java.sql.ResultSet import java.util.Locale import kotlin.reflect.KType +import org.jetbrains.kotlinx.dataframe.io.db.MariaDb as MariaDbType +import org.jetbrains.kotlinx.dataframe.io.db.MsSql as MsSqlType +import org.jetbrains.kotlinx.dataframe.io.db.MySql as MySqlType +import org.jetbrains.kotlinx.dataframe.io.db.PostgreSql as PostgreSqlType /** * Represents the H2 database type. @@ -13,9 +17,78 @@ import kotlin.reflect.KType * * NOTE: All date and timestamp-related types are converted to String to avoid java.sql.* types. */ -public open class H2(public val dialect: DbType = MySql) : DbType("h2") { - init { - require(dialect::class != H2::class) { "H2 database could not be specified with H2 dialect!" } + +public open class H2(public val mode: Mode = Mode.Regular) : DbType("h2") { + @Deprecated("Use H2(mode = Mode.XXX) instead", ReplaceWith("H2(H2.Mode.MySql)")) + public constructor(dialect: DbType) : this( + Mode.fromDbType(dialect) + ?: throw IllegalArgumentException("H2 database could not be specified with H2 dialect!"), + ) + + private val delegate: DbType? = mode.toDbType() + + /** + * Represents the compatibility modes supported by an H2 database. + * + * @property value The string value used in H2 JDBC URL and settings. + */ + public enum class Mode(public val value: String) { + /** Native H2 mode (no compatibility), our synthetic marker. */ + Regular("H2-Regular"), + MySql("MySQL"), + PostgreSql("PostgreSQL"), + MsSqlServer("MSSQLServer"), + MariaDb("MariaDB"), ; + + /** + * Converts this Mode to the corresponding DbType delegate. + * + * @return The DbType for this mode, or null for Regular mode. + */ + public fun toDbType(): DbType? = + when (this) { + Regular -> null + MySql -> MySqlType + PostgreSql -> PostgreSqlType + MsSqlServer -> MsSqlType + MariaDb -> MariaDbType + } + + public companion object { + /** + * Creates a Mode from the given DbType. + * + * @param dialect The DbType to convert. + * @return The corresponding Mode, or null if the dialect is H2. + */ + public fun fromDbType(dialect: DbType): Mode? = + when (dialect) { + is H2 -> null + MySqlType -> MySql + PostgreSqlType -> PostgreSql + MsSqlType -> MsSqlServer + MariaDbType -> MariaDb + else -> Regular + } + + /** + * Finds a Mode by its string value (case-insensitive). + * Handles both URL values (MySQL, PostgreSQL, etc.) and + * INFORMATION_SCHEMA values (Regular). + * + * @param value The string value to search for. + * @return The matching Mode, or null if not found. + */ + public fun fromValue(value: String): Mode? { + // "Regular" from INFORMATION_SCHEMA or "H2-Regular" from URL + if (value.equals("regular", ignoreCase = true) || + value.equals("h2-regular", ignoreCase = true) + ) { + return Regular + } + return entries.find { it.value.equals(value, ignoreCase = true) } + } + } } /** @@ -29,16 +102,17 @@ public open class H2(public val dialect: DbType = MySql) : DbType("h2") { * @see [createH2Instance] */ public companion object { - /** It represents the mode value "MySQL" for the H2 database. */ + + @Deprecated("Use Mode.MySql.value instead", ReplaceWith("Mode.MySql.value")) public const val MODE_MYSQL: String = "MySQL" - /** It represents the mode value "PostgreSQL" for the H2 database. */ + @Deprecated("Use Mode.PostgreSql.value instead", ReplaceWith("Mode.PostgreSql.value")) public const val MODE_POSTGRESQL: String = "PostgreSQL" - /** It represents the mode value "MSSQLServer" for the H2 database. */ + @Deprecated("Use Mode.MsSqlServer.value instead", ReplaceWith("Mode.MsSqlServer.value")) public const val MODE_MSSQLSERVER: String = "MSSQLServer" - /** It represents the mode value "MariaDB" for the H2 database. */ + @Deprecated("Use Mode.MariaDb.value instead", ReplaceWith("Mode.MariaDb.value")) public const val MODE_MARIADB: String = "MariaDB" } @@ -46,7 +120,7 @@ public open class H2(public val dialect: DbType = MySql) : DbType("h2") { get() = "org.h2.Driver" override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? = - dialect.convertSqlTypeToColumnSchemaValue(tableColumnMetadata) + delegate?.convertSqlTypeToColumnSchemaValue(tableColumnMetadata) override fun isSystemTable(tableMetadata: TableMetadata): Boolean { val locale = Locale.getDefault() @@ -57,14 +131,24 @@ public open class H2(public val dialect: DbType = MySql) : DbType("h2") { // could be extended for other symptoms of the system tables for H2 val isH2SystemTable = schemaName.containsWithLowercase("information_schema") - return isH2SystemTable || dialect.isSystemTable(tableMetadata) + return if (delegate == null) { + isH2SystemTable + } else { + isH2SystemTable || delegate.isSystemTable(tableMetadata) + } } - override fun buildTableMetadata(tables: ResultSet): TableMetadata = dialect.buildTableMetadata(tables) + override fun buildTableMetadata(tables: ResultSet): TableMetadata = + delegate?.buildTableMetadata(tables) + ?: TableMetadata( + tables.getString("table_name"), + tables.getString("table_schem"), + tables.getString("table_cat"), + ) override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? = - dialect.convertSqlTypeToKType(tableColumnMetadata) + delegate?.convertSqlTypeToKType(tableColumnMetadata) public override fun buildSqlQueryWithLimit(sqlQuery: String, limit: Int): String = - dialect.buildSqlQueryWithLimit(sqlQuery, limit) + delegate?.buildSqlQueryWithLimit(sqlQuery, limit) ?: super.buildSqlQueryWithLimit(sqlQuery, limit) } diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt index 1d6752cf90..bea3f53d77 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/db/util.kt @@ -3,12 +3,20 @@ package org.jetbrains.kotlinx.dataframe.io.db import io.github.oshai.kotlinlogging.KotlinLogging import java.sql.Connection import java.sql.SQLException -import java.util.Locale private val logger = KotlinLogging.logger {} +private const val UNSUPPORTED_H2_MODE_MESSAGE = + "Unsupported H2 MODE: %s. Supported: MySQL, PostgreSQL, MSSQLServer, MariaDB, REGULAR/H2-Regular (or omit MODE)." + +private const val H2_MODE_QUERY = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'" + +private val H2_MODE_URL_PATTERN = "MODE=([^;:&]+)".toRegex(RegexOption.IGNORE_CASE) + /** * Extracts the database type from the given connection. + * For H2, fetches the actual MODE from the active connection settings. + * For other databases, extracts type from URL. * * @param [connection] the database connection. * @return the corresponding [DbType]. @@ -21,44 +29,56 @@ public fun extractDBTypeFromConnection(connection: Connection): DbType { ?: throw IllegalStateException("URL information is missing in connection meta data!") logger.info { "Processing DB type extraction for connection url: $url" } - return if (url.contains(H2().dbTypeInJdbcUrl)) { - // works only for H2 version 2 - val modeQuery = "SELECT SETTING_VALUE FROM INFORMATION_SCHEMA.SETTINGS WHERE SETTING_NAME = 'MODE'" - var mode = "" - connection.prepareStatement(modeQuery).use { st -> - st.executeQuery().use { rs -> - if (rs.next()) { - mode = rs.getString("SETTING_VALUE") - logger.debug { "Fetched H2 DB mode: $mode" } - } else { - throw IllegalStateException("The information about H2 mode is not found in the H2 meta-data!") - } - } - } - - // H2 doesn't support MariaDB and SQLite - when (mode.lowercase(Locale.getDefault())) { - H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql) + // First, determine the base database type from URL + val baseDbType = extractDBTypeFromUrl(url) - H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql) - - H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql) - - H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb) - - else -> { - val message = "Unsupported database type in the url: $url. " + - "Only MySQL, MariaDB, MSSQL and PostgreSQL are supported!" - logger.error { message } + // For H2, refine the mode by querying the active connection settings + // This handles cases where MODE is not specified in URL, but H2 returns "Regular" from settings + return if (baseDbType is H2) { + val mode = fetchH2ModeFromConnection(connection) + parseH2ModeOrThrow(mode) + } else { + logger.info { "Identified DB type as $baseDbType from url: $url" } + baseDbType + } +} - throw IllegalArgumentException(message) +/** + * Fetches H2 database mode from an active connection. + * Works only for H2 version 2. + * + * @param [connection] the database connection. + * @return the mode string or null if not set. + */ +private fun fetchH2ModeFromConnection(connection: Connection): String? { + var mode: String? = null + connection.prepareStatement(H2_MODE_QUERY).use { st -> + st.executeQuery().use { rs -> + if (rs.next()) { + mode = rs.getString("SETTING_VALUE") + logger.debug { "Fetched H2 DB mode: $mode" } } } - } else { - val dbType = extractDBTypeFromUrl(url) - logger.info { "Identified DB type as $dbType from url: $url" } - dbType } + + return mode?.trim()?.takeIf { it.isNotEmpty() } +} + +/** + * Parses H2 mode string and returns the corresponding H2 DbType instance. + * + * @param [mode] the mode string (may be null or empty for Regular mode). + * @return H2 instance with the appropriate mode. + * @throws [IllegalArgumentException] if the mode is not supported. + */ +private fun parseH2ModeOrThrow(mode: String?): H2 { + if (mode.isNullOrEmpty()) { + return H2(H2.Mode.Regular) + } + return H2.Mode.fromValue(mode)?.let { H2(it) } + ?: throw IllegalArgumentException(UNSUPPORTED_H2_MODE_MESSAGE.format(mode)).also { + logger.error { it.message } + } } /** @@ -66,33 +86,31 @@ public fun extractDBTypeFromConnection(connection: Connection): DbType { * * @param [url] the JDBC URL. * @return the corresponding [DbType]. - * @throws [RuntimeException] if the url is null. + * @throws [SQLException] if the url is null. + * @throws [IllegalArgumentException] if the URL specifies an unsupported database type. */ public fun extractDBTypeFromUrl(url: String?): DbType { - if (url != null) { - val helperH2Instance = H2() - return when { - helperH2Instance.dbTypeInJdbcUrl in url -> createH2Instance(url) + url ?: throw SQLException("Database URL could not be null.") - MariaDb.dbTypeInJdbcUrl in url -> MariaDb + return when { + H2().dbTypeInJdbcUrl in url -> createH2Instance(url) - MySql.dbTypeInJdbcUrl in url -> MySql + MariaDb.dbTypeInJdbcUrl in url -> MariaDb - Sqlite.dbTypeInJdbcUrl in url -> Sqlite + MySql.dbTypeInJdbcUrl in url -> MySql - PostgreSql.dbTypeInJdbcUrl in url -> PostgreSql + Sqlite.dbTypeInJdbcUrl in url -> Sqlite - MsSql.dbTypeInJdbcUrl in url -> MsSql + PostgreSql.dbTypeInJdbcUrl in url -> PostgreSql - DuckDb.dbTypeInJdbcUrl in url -> DuckDb + MsSql.dbTypeInJdbcUrl in url -> MsSql - else -> throw IllegalArgumentException( - "Unsupported database type in the url: $url. " + - "Only H2, MariaDB, MySQL, MSSQL, SQLite, PostgreSQL, and DuckDB are supported!", - ) - } - } else { - throw SQLException("Database URL could not be null. The existing value is $url") + DuckDb.dbTypeInJdbcUrl in url -> DuckDb + + else -> throw IllegalArgumentException( + "Unsupported database type in the url: $url. " + + "Only H2, MariaDB, MySQL, MSSQL, SQLite, PostgreSQL, and DuckDB are supported!", + ) } } @@ -104,30 +122,8 @@ public fun extractDBTypeFromUrl(url: String?): DbType { * @throws [IllegalArgumentException] if the provided URL does not contain a valid mode. */ private fun createH2Instance(url: String): DbType { - val modePattern = "MODE=(.*?);".toRegex() - val matchResult = modePattern.find(url) - - val mode: String = if (matchResult != null && matchResult.groupValues.size == 2) { - matchResult.groupValues[1] - } else { - throw IllegalArgumentException("The provided URL `$url` does not contain a valid mode.") - } - - // H2 doesn't support MariaDB and SQLite - return when (mode.lowercase(Locale.getDefault())) { - H2.MODE_MYSQL.lowercase(Locale.getDefault()) -> H2(MySql) - - H2.MODE_MSSQLSERVER.lowercase(Locale.getDefault()) -> H2(MsSql) - - H2.MODE_POSTGRESQL.lowercase(Locale.getDefault()) -> H2(PostgreSql) - - H2.MODE_MARIADB.lowercase(Locale.getDefault()) -> H2(MariaDb) - - else -> throw IllegalArgumentException( - "Unsupported database mode: $mode. " + - "Only MySQL, MariaDB, MSSQL, PostgreSQL modes are supported!", - ) - } + val mode = H2_MODE_URL_PATTERN.find(url)?.groupValues?.getOrNull(1) + return parseH2ModeOrThrow(mode?.takeIf { it.isNotBlank() }) } /** diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt index 9d0d0c8279..40c09fd04a 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/h2Test.kt @@ -5,6 +5,7 @@ import com.zaxxer.hikari.HikariDataSource import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.throwables.shouldThrowExactly import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.DataFrame @@ -15,7 +16,14 @@ import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.select import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig import org.jetbrains.kotlinx.dataframe.io.db.H2 +import org.jetbrains.kotlinx.dataframe.io.db.H2.Mode import org.jetbrains.kotlinx.dataframe.io.db.MySql +import org.jetbrains.kotlinx.dataframe.io.db.PostgreSql +import org.jetbrains.kotlinx.dataframe.io.db.Sqlite +import org.jetbrains.kotlinx.dataframe.io.db.TableMetadata +import org.jetbrains.kotlinx.dataframe.io.db.driverClassNameFromUrl +import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromConnection +import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl import org.jetbrains.kotlinx.dataframe.io.inferNullability import org.jetbrains.kotlinx.dataframe.io.readAllSqlTables import org.jetbrains.kotlinx.dataframe.io.readDataFrame @@ -39,6 +47,8 @@ private const val URL = "jdbc:h2:mem:test5;DB_CLOSE_DELAY=-1;MODE=MySQL;DATABASE private const val MAXIMUM_POOL_SIZE = 5 +private const val QUERY_SELECT_ONE = "SELECT 1" + @DataSchema interface Customer { val id: Int? @@ -478,22 +488,24 @@ class JdbcTest { @Test fun `read from ResultSet`() { + val dbType = H2(Mode.MySql) + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> @Language("SQL") val selectStatement = "SELECT * FROM Customer" st.executeQuery(selectStatement).use { rs -> - val df = DataFrame.readResultSet(rs, H2(MySql)) + val df = DataFrame.readResultSet(rs, dbType) assertCustomerData(df) rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2(MySql), 1) + val df1 = DataFrame.readResultSet(rs, dbType, 1) assertCustomerData(df1, 1) rs.beforeFirst() - val dataSchema = DataFrameSchema.readResultSet(rs, H2(MySql)) + val dataSchema = DataFrameSchema.readResultSet(rs, dbType) assertCustomerSchema(dataSchema) rs.beforeFirst() @@ -508,7 +520,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema1 = DataFrameSchema.readResultSet(rs, H2(MySql)) + val dataSchema1 = DataFrameSchema.readResultSet(rs, dbType) assertCustomerSchema(dataSchema1) } } @@ -516,22 +528,24 @@ class JdbcTest { @Test fun `read from extension function on ResultSet`() { + val dbType = H2(Mode.MySql) + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> @Language("SQL") val selectStatement = "SELECT * FROM Customer" st.executeQuery(selectStatement).use { rs -> - val df = rs.readDataFrame(H2(MySql)) + val df = rs.readDataFrame(dbType) assertCustomerData(df) rs.beforeFirst() - val df1 = rs.readDataFrame(H2(MySql), 1) + val df1 = rs.readDataFrame(dbType, 1) assertCustomerData(df1, 1) rs.beforeFirst() - val dataSchema = rs.readDataFrameSchema(H2(MySql)) + val dataSchema = rs.readDataFrameSchema(dbType) assertCustomerSchema(dataSchema) rs.beforeFirst() @@ -546,7 +560,7 @@ class JdbcTest { rs.beforeFirst() - val dataSchema1 = rs.readDataFrameSchema(H2(MySql)) + val dataSchema1 = rs.readDataFrameSchema(dbType) assertCustomerSchema(dataSchema1) } } @@ -563,7 +577,7 @@ class JdbcTest { repeat(10) { rs.beforeFirst() - val df1 = DataFrame.readResultSet(rs, H2(MySql), 2) + val df1 = DataFrame.readResultSet(rs, H2(Mode.MySql), 2) assertCustomerData(df1, 2) rs.beforeFirst() @@ -974,8 +988,190 @@ class JdbcTest { exception.message shouldBe "H2 database could not be specified with H2 dialect!" } + @Test + fun `regular mode for H2 with DbConnectionConfig`() { + val url = "jdbc:h2:mem:testDatabase" + + val dbConfig = DbConnectionConfig(url) + + val df = DataFrame.readSqlQuery(dbConfig, QUERY_SELECT_ONE) + df.rowsCount() shouldBe 1 + } + + @Test + fun `regular mode for H2 with Connection`() { + val url = "jdbc:h2:mem:testDatabase" + + DriverManager.getConnection(url).use { connection -> + val df = DataFrame.readSqlQuery(connection, QUERY_SELECT_ONE) + df.rowsCount() shouldBe 1 + } + } + + // ========== H2 Mode Tests ========== + + private fun testH2ModeWithDbConnectionConfig(modeUrl: String) { + val dbConfig = DbConnectionConfig(modeUrl) + val df = DataFrame.readSqlQuery(dbConfig, QUERY_SELECT_ONE) + df.rowsCount() shouldBe 1 + } + + private fun testH2ModeWithConnection(modeUrl: String) { + DriverManager.getConnection(modeUrl).use { connection -> + val df = DataFrame.readSqlQuery(connection, QUERY_SELECT_ONE) + df.rowsCount() shouldBe 1 + } + } + + @Test + fun `MySQL mode for H2 with DbConnectionConfig`() { + testH2ModeWithDbConnectionConfig("jdbc:h2:mem:testMySql;MODE=MySQL") + } + + @Test + fun `MySQL mode for H2 with Connection`() { + testH2ModeWithConnection("jdbc:h2:mem:testMySql;MODE=MySQL") + } + + @Test + fun `PostgreSQL mode for H2 with DbConnectionConfig`() { + testH2ModeWithDbConnectionConfig("jdbc:h2:mem:testPostgres;MODE=PostgreSQL") + } + + @Test + fun `PostgreSQL mode for H2 with Connection`() { + testH2ModeWithConnection("jdbc:h2:mem:testPostgres;MODE=PostgreSQL") + } + + @Test + fun `MSSQLServer mode for H2 with DbConnectionConfig`() { + testH2ModeWithDbConnectionConfig("jdbc:h2:mem:testMsSql;MODE=MSSQLServer") + } + + @Test + fun `MSSQLServer mode for H2 with Connection`() { + testH2ModeWithConnection("jdbc:h2:mem:testMsSql;MODE=MSSQLServer") + } + + @Test + fun `MariaDB mode for H2 with DbConnectionConfig`() { + testH2ModeWithDbConnectionConfig("jdbc:h2:mem:testMariaDb;MODE=MariaDB") + } + + @Test + fun `MariaDB mode for H2 with Connection`() { + testH2ModeWithConnection("jdbc:h2:mem:testMariaDb;MODE=MariaDB") + } + + @Test + fun `H2 with unsupported mode throws exception`() { + val url = "jdbc:h2:mem:testUnsupported;MODE=DB2" + + DriverManager.getConnection(url).use { connection -> + shouldThrow { + DataFrame.readSqlQuery(connection, QUERY_SELECT_ONE) + } + } + } + + @Test + fun `H2 with unsupported mode throws exception using DbConnectionConfig`() { + val url = "jdbc:h2:mem:testUnsupported;MODE=Oracle" + val dbConfig = DbConnectionConfig(url) + + shouldThrow { + DataFrame.readSqlQuery(dbConfig, QUERY_SELECT_ONE) + } + } + + @Test + fun `H2 Regular mode extraction and fallbacks`() { + // 1. Create a connection without explicit MODE in URL. + // H2 defaults to Regular mode. extractDBTypeFromConnection should detect this by querying settings. + DriverManager.getConnection("jdbc:h2:mem:testRegularFallback").use { conn -> + val dbType = extractDBTypeFromConnection(conn) + + (dbType is H2) shouldBe true + (dbType as H2).mode shouldBe Mode.Regular + + // 2. Verify fallback behaviors (when delegate is null) + + // buildSqlQueryWithLimit: Check fallback to super implementation (standard LIMIT syntax) + val query = "SELECT * FROM table" + dbType.buildSqlQueryWithLimit(query, 10) shouldBe "SELECT * FROM table LIMIT 10" + + // isSystemTable: Check fallback to H2-specific logic (INFORMATION_SCHEMA) + val systemTable = TableMetadata("SETTINGS", "INFORMATION_SCHEMA", "TEST_DB") + dbType.isSystemTable(systemTable) shouldBe true + + val userTable = TableMetadata("USERS", "PUBLIC", "TEST_DB") + dbType.isSystemTable(userTable) shouldBe false + + // buildTableMetadata: Check fallback to reading from ResultSet directly + conn.createStatement().use { st -> + st.execute("CREATE TABLE MY_FALLBACK_TABLE (ID INT)") + } + conn.metaData.getTables(null, null, "MY_FALLBACK_TABLE", null).use { rs -> + if (rs.next()) { + val metadata = dbType.buildTableMetadata(rs) + metadata.name shouldBe "MY_FALLBACK_TABLE" + metadata.schemaName shouldBe "PUBLIC" + metadata.catalogue shouldNotBe null + } else { + throw IllegalStateException("Could not find created table metadata") + } + } + } + } + + @Test + fun `database type extraction utils`() { + // 1. Test direct extraction from URL for various DBs + (extractDBTypeFromUrl("jdbc:mysql://localhost:3306/db") is MySql) shouldBe true + (extractDBTypeFromUrl("jdbc:postgresql://localhost:5432/db") is PostgreSql) shouldBe true + (extractDBTypeFromUrl("jdbc:sqlite:sample.db") is Sqlite) shouldBe true + + // Test driverClassNameFromUrl + driverClassNameFromUrl("jdbc:mysql://localhost:3306/db") shouldBe "com.mysql.jdbc.Driver" + driverClassNameFromUrl("jdbc:postgresql://localhost:5432/db") shouldBe "org.postgresql.Driver" + driverClassNameFromUrl("jdbc:h2:mem:test") shouldBe "org.h2.Driver" + + // 2. Test unsupported Database URL + shouldThrow { + extractDBTypeFromUrl("jdbc:oracle:thin:@localhost:1521:xe") + } + + // 3. Test null URL + shouldThrow { + extractDBTypeFromUrl(null) + } + + // 4. Test H2 specific mode extraction from Connection (End-to-End) + + // Case A: MySQL Mode via URL + DriverManager.getConnection("jdbc:h2:mem:testExtractMySql;MODE=MySQL").use { conn -> + val dbType = extractDBTypeFromConnection(conn) + (dbType is H2) shouldBe true + (dbType as H2).mode shouldBe H2.Mode.MySql + } + + // Case B: PostgreSQL Mode via URL + DriverManager.getConnection("jdbc:h2:mem:testExtractPostgres;MODE=PostgreSQL").use { conn -> + val dbType = extractDBTypeFromConnection(conn) + (dbType is H2) shouldBe true + (dbType as H2).mode shouldBe H2.Mode.PostgreSql + } + + // Case C: MSSQLServer Mode via URL + DriverManager.getConnection("jdbc:h2:mem:testExtractMsSql;MODE=MSSQLServer").use { conn -> + val dbType = extractDBTypeFromConnection(conn) + (dbType is H2) shouldBe true + (dbType as H2).mode shouldBe H2.Mode.MsSqlServer + } + } + // helper object created for API testing purposes - object CustomDB : H2(MySql) + object CustomDB : H2(Mode.MySql) @Test fun `read from table from custom database`() { diff --git a/docs/StardustDocs/topics/dataSources/sql/H2.md b/docs/StardustDocs/topics/dataSources/sql/H2.md index a288628fd7..bc57d64488 100644 --- a/docs/StardustDocs/topics/dataSources/sql/H2.md +++ b/docs/StardustDocs/topics/dataSources/sql/H2.md @@ -56,10 +56,18 @@ It is also possible to load all data from non-system tables, each into a separat See [](readSqlDatabases.md) for more details. +### H2 Compatibility Modes + +When working with H2 database, the library automatically detects the compatibility mode from the connection. +If no `MODE` is specified in the JDBC URL, the default `Regular` mode is used. +H2 supports the following compatibility modes: `MySQL`, `PostgreSQL`, `MSSQLServer`, `MariaDB`, and `Regular`. + ```kotlin import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig import org.jetbrains.kotlinx.dataframe.api.* +// Basic H2 connection (uses Regular mode by default) + val url = "jdbc:h2:mem:testDatabase" val username = "sa" val password = "" @@ -71,3 +79,20 @@ val tableName = "Customer" val df = DataFrame.readSqlTable(dbConfig, tableName) ``` +```kotlin +import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig +import org.jetbrains.kotlinx.dataframe.api.* + +// H2 with PostgreSQL compatibility mode + +val postgresUrl = "jdbc:h2:mem:testDatabase;MODE=PostgreSQL" +val username = "sa" +val password = "" + +val postgresConfig = DbConnectionConfig(postgresUrl, username, password) + +val tableName = "Customer" + +val dfPostgres = DataFrame.readSqlTable(postgresConfig, tableName) +``` +