From f6e412653f4ccfb0b345e5c2c38ed70226c873ac Mon Sep 17 00:00:00 2001 From: zaleslaw Date: Thu, 13 Nov 2025 16:11:40 +0100 Subject: [PATCH] Add progress tracking and memory estimation for JDBC module This commit introduces progress tracking functionality for the JDBC module, including both simple and detailed modes configurable via system properties. Added new tests, configuration class (`JdbcConfig`), and memory estimation utilities to ensure efficient handling of large datasets. --- .../kotlinx/dataframe/io/JdbcConfig.kt | 68 +++ .../kotlinx/dataframe/io/JdbcSafeLoading.kt | 419 ++++++++++++++++++ .../kotlinx/dataframe/io/MemoryEstimate.kt | 52 +++ .../kotlinx/dataframe/io/MemoryEstimator.kt | 208 +++++++++ .../kotlinx/dataframe/io/progressTracker.kt | 301 +++++++++++++ .../kotlinx/dataframe/io/readJdbc.kt | 29 +- .../io/h2/tracking/JdbcProgressOutputTest.kt | 257 +++++++++++ .../io/h2/tracking/JdbcProgressTest.kt | 137 ++++++ .../io/h2/tracking/JdbcSafeLoadingTest.kt | 177 ++++++++ 9 files changed, 1646 insertions(+), 2 deletions(-) create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt create mode 100644 dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt create mode 100644 dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt create mode 100644 dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt create mode 100644 dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt diff --git 
a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt new file mode 100644 index 0000000000..513ab82a17 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt @@ -0,0 +1,68 @@ +package org.jetbrains.kotlinx.dataframe.io + +/** + * Internal configuration for JDBC module behavior. + * Controlled via system properties or environment variables. + * + * Example usage: + * ```kotlin + * // Enable progress with detailed statistics + * System.setProperty("dataframe.jdbc.progress", "true") + * System.setProperty("dataframe.jdbc.progress.detailed", "true") + * + * // Or via environment variables + * export DATAFRAME_JDBC_PROGRESS=true + * export DATAFRAME_JDBC_PROGRESS_DETAILED=true + * ``` + */ +internal object JdbcConfig { + /** + * Enable progress logging during data loading. + * + * When `true`, enables DEBUG-level progress messages and memory estimates. + * When `false` (default), only respects logger level (e.g., DEBUG). + * + * System property: `-Ddataframe.jdbc.progress=true` + * Environment variable: `DATAFRAME_JDBC_PROGRESS=true` + * Default: `false` + */ + @JvmStatic + val PROGRESS_ENABLED: Boolean by lazy { + System.getProperty("dataframe.jdbc.progress")?.toBoolean() + ?: System.getenv("DATAFRAME_JDBC_PROGRESS")?.toBoolean() + ?: false + } + + /** + * Enable detailed progress logging with statistics. + * + * When `true`: Shows row count, percentage, speed (rows/sec), memory usage + * When `false`: Shows only basic "Loaded X rows" messages + * + * Only takes effect when progress is enabled. 
+ * + * System property: `-Ddataframe.jdbc.progress.detailed=true` + * Environment variable: `DATAFRAME_JDBC_PROGRESS_DETAILED=true` + * Default: `true` + */ + @JvmStatic + val PROGRESS_DETAILED: Boolean by lazy { + System.getProperty("dataframe.jdbc.progress.detailed")?.toBoolean() + ?: System.getenv("DATAFRAME_JDBC_PROGRESS_DETAILED")?.toBoolean() + ?: true + } + + /** + * Report progress every N rows. + * + * System property: `-Ddataframe.jdbc.progress.interval=5000` + * Environment variable: `DATAFRAME_JDBC_PROGRESS_INTERVAL=5000` + * Default: `1000` + */ + @JvmStatic + val PROGRESS_INTERVAL: Int by lazy { + System.getProperty("dataframe.jdbc.progress.interval")?.toIntOrNull() + ?: System.getenv("DATAFRAME_JDBC_PROGRESS_INTERVAL")?.toIntOrNull() + ?: 1000 + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt new file mode 100644 index 0000000000..f8174872d1 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt @@ -0,0 +1,419 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.kotlinx.dataframe.AnyFrame +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.io.db.DbType +import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromConnection +import java.sql.Connection +import java.sql.PreparedStatement +import javax.sql.DataSource + +private val logger = KotlinLogging.logger {} + +/** + * Safe JDBC data loading with automatic memory estimation and limit management. 
+ * + * Uses the same method names as the main API but adds automatic memory checking: + * + * ```kotlin + * // Instead of: + * DataFrame.readSqlTable(config, "table") + * + * // Use: + * JdbcSafeDataLoading.load(maxMemoryGb = 1.0) { + * readSqlTable(config, "table") + * } + * ``` + * + * ## Features + * - Automatic memory estimation before loading (10 sample rows + COUNT(*)) + * - Automatic limit application when memory threshold exceeded + * - Configurable behavior: throw, warn, or auto-limit + * - Callbacks for estimates and limit events + * + * ## Examples + * + * Basic usage: + * ```kotlin + * val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) { + * readSqlTable(config, "large_table") + * } + * ``` + * + * Advanced configuration: + * ```kotlin + * val df = JdbcSafeDataLoading.load { + * maxMemoryGb = 2.0 + * onExceed = ExceedAction.THROW + * + * onEstimate = { estimate -> + * println("Estimated: ${estimate.humanReadable}") + * } + * + * readSqlTable(config, "huge_table") + * } + * ``` + */ +public object JdbcSafeDataLoading { + + /** + * Loads JDBC data with automatic memory checking. + */ + public fun load( + maxMemoryGb: Double = 1.0, + block: SafeLoadContext.() -> T, + ): T { + val config = LoadConfig().apply { this.maxMemoryGb = maxMemoryGb } + return SafeLoadContext(config).block() + } + + /** + * Advanced version with full configuration. + */ + public fun load( + configure: LoadConfig.() -> Unit, + block: SafeLoadContext.() -> T, + ): T { + val config = LoadConfig().apply(configure) + return SafeLoadContext(config).block() + } + + /** + * For loading multiple results (like Map). + */ + public fun loadMultiple( + maxMemoryGb: Double = 1.0, + block: SafeLoadContext.() -> T, + ): T { + val config = LoadConfig().apply { this.maxMemoryGb = maxMemoryGb } + return SafeLoadContext(config).block() + } + + /** + * Advanced version for multiple results. 
+ */ + public fun loadMultiple( + configure: LoadConfig.() -> Unit, + block: SafeLoadContext.() -> T, + ): T { + val config = LoadConfig().apply(configure) + return SafeLoadContext(config).block() + } + + /** + * Configuration for safe JDBC loading. + */ + public class LoadConfig { + /** Maximum allowed memory in bytes. */ + public var maxMemoryBytes: Long = 1024L * 1024L * 1024L + + /** Maximum allowed memory in gigabytes (convenience setter). */ + public var maxMemoryGb: Double + get() = maxMemoryBytes / (1024.0 * 1024.0 * 1024.0) + set(value) { maxMemoryBytes = (value * 1024 * 1024 * 1024).toLong() } + + /** What to do when memory limit is exceeded. */ + public var onExceed: ExceedAction = ExceedAction.APPLY_LIMIT + + /** Callback invoked when estimate is available (before loading). */ + public var onEstimate: ((MemoryEstimate) -> Unit)? = null + + /** Callback invoked when limit is automatically applied. */ + public var onLimitApplied: ((estimate: MemoryEstimate, appliedLimit: Int) -> Unit)? = null + + /** For loadAllTables: callback invoked for each table's estimate. */ + public var onTableEstimate: ((tableName: String, estimate: MemoryEstimate) -> Unit)? = null + } + + /** Action to take when estimated memory exceeds the limit. */ + public enum class ExceedAction { + /** Automatically apply limit to stay under threshold */ + APPLY_LIMIT, + + /** Throw MemoryLimitExceededException and don't load */ + THROW, + + /** Log warning but proceed with full load */ + WARN_AND_PROCEED, + } + + /** Exception thrown when memory limit exceeded and action is THROW. */ + public class MemoryLimitExceededException( + message: String, + public val estimate: MemoryEstimate, + ) : IllegalStateException(message) + + /** + * Context with the same method names as DataFrame.Companion. 
+ */ + public class SafeLoadContext internal constructor( + private val config: LoadConfig, + ) { + // region readSqlTable + + public fun readSqlTable( + dbConfig: DbConnectionConfig, + tableName: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlTable(dbConfig, tableName, null, dbType) }, + load = { limit -> + DataFrame.readSqlTable(dbConfig, tableName, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlTable( + connection: Connection, + tableName: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlTable(connection, tableName, null, dbType) }, + load = { limit -> + DataFrame.readSqlTable(connection, tableName, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlTable( + dataSource: DataSource, + tableName: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = dataSource.connection.use { conn -> + readSqlTable(conn, tableName, dbType, inferNullability, strictValidation, configureStatement) + } + + // endregion + + // region readSqlQuery + + public fun readSqlQuery( + dbConfig: DbConnectionConfig, + sqlQuery: String, + dbType: DbType? 
= null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlQuery(dbConfig, sqlQuery, null, dbType) }, + load = { limit -> + DataFrame.readSqlQuery(dbConfig, sqlQuery, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlQuery( + connection: Connection, + sqlQuery: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlQuery(connection, sqlQuery, null, dbType) }, + load = { limit -> + DataFrame.readSqlQuery(connection, sqlQuery, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlQuery( + dataSource: DataSource, + sqlQuery: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = dataSource.connection.use { conn -> + readSqlQuery(conn, sqlQuery, dbType, inferNullability, strictValidation, configureStatement) + } + + // endregion + + // region readAllSqlTables + + public fun readAllSqlTables( + dbConfig: DbConnectionConfig, + catalogue: String? = null, + dbType: DbType? = null, + inferNullability: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): Map = withReadOnlyConnection(dbConfig, dbType) { connection -> + readAllSqlTables(connection, catalogue, dbType, inferNullability, configureStatement) + } + + public fun readAllSqlTables( + connection: Connection, + catalogue: String? = null, + dbType: DbType? 
= null, + inferNullability: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): Map { + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) + val metaData = connection.metaData + val tablesResultSet = retrieveTableMetadata(metaData, catalogue, determinedDbType) + + return buildMap { + while (tablesResultSet.next()) { + val tableMetadata = determinedDbType.buildTableMetadata(tablesResultSet) + + if (determinedDbType.isSystemTable(tableMetadata)) continue + + val fullTableName = buildFullTableName(catalogue, tableMetadata.schemaName, tableMetadata.name) + + val dataFrame = estimateAndLoad( + estimate = { + val est = estimateSqlTable(connection, fullTableName, null, dbType) + config.onTableEstimate?.invoke(fullTableName, est) + est + }, + load = { limit -> + DataFrame.readSqlTable(connection, fullTableName, limit, inferNullability, determinedDbType, true, configureStatement) + } + ) + + put(fullTableName, dataFrame) + } + } + } + + public fun readAllSqlTables( + dataSource: DataSource, + catalogue: String? = null, + dbType: DbType? = null, + inferNullability: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): Map = dataSource.connection.use { conn -> + readAllSqlTables(conn, catalogue, dbType, inferNullability, configureStatement) + } + + // endregion + + private fun estimateAndLoad( + estimate: () -> MemoryEstimate, + load: (limit: Int?) -> AnyFrame, + ): AnyFrame { + val memoryEstimate = estimate() + config.onEstimate?.invoke(memoryEstimate) + val finalLimit = handleMemoryLimit(memoryEstimate, null) + return load(finalLimit) + } + + private fun handleMemoryLimit(estimate: MemoryEstimate, originalLimit: Int?): Int? 
{ + if (!estimate.exceeds(config.maxMemoryBytes)) return originalLimit + + return when (config.onExceed) { + ExceedAction.APPLY_LIMIT -> { + if (originalLimit != null) { + originalLimit + } else { + val recommendedLimit = estimate.recommendedLimit(config.maxMemoryBytes) + logger.warn { + "Estimated memory ${estimate.humanReadable} exceeds limit " + + "${formatBytes(config.maxMemoryBytes)}. Applying limit: $recommendedLimit rows" + } + config.onLimitApplied?.invoke(estimate, recommendedLimit) + recommendedLimit + } + } + + ExceedAction.THROW -> throw MemoryLimitExceededException( + "Estimated memory ${estimate.humanReadable} exceeds limit ${formatBytes(config.maxMemoryBytes)}", + estimate + ) + + ExceedAction.WARN_AND_PROCEED -> { + logger.warn { "Memory ${estimate.humanReadable} exceeds limit, but proceeding" } + originalLimit + } + } + } + } +} + +// Helper functions for estimation + +private fun estimateSqlTable( + dbConfig: DbConnectionConfig, + tableName: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + return withReadOnlyConnection(dbConfig, dbType) { conn -> + val determinedDbType = dbType ?: extractDBTypeFromConnection(conn) + MemoryEstimator.estimateTable(conn, tableName, determinedDbType, limit) + } +} + +private fun estimateSqlQuery( + dbConfig: DbConnectionConfig, + sqlQuery: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + require(isValidSqlQuery(sqlQuery)) { + "SQL query should start from SELECT and be a valid query" + } + + return withReadOnlyConnection(dbConfig, dbType) { conn -> + val determinedDbType = dbType ?: extractDBTypeFromConnection(conn) + MemoryEstimator.estimateQuery(conn, sqlQuery, determinedDbType, limit) + } +} + +private fun estimateSqlTable( + connection: Connection, + tableName: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) + return 
MemoryEstimator.estimateTable(connection, tableName, determinedDbType, limit) +} + +private fun estimateSqlQuery( + connection: Connection, + sqlQuery: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + require(isValidSqlQuery(sqlQuery)) { + "SQL query should start from SELECT and be a valid query" + } + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) + return MemoryEstimator.estimateQuery(connection, sqlQuery, determinedDbType, limit) +} + +private fun buildFullTableName(catalogue: String?, schemaName: String?, tableName: String): String { + return when { + catalogue != null && schemaName != null -> "$catalogue.$schemaName.$tableName" + catalogue != null -> "$catalogue.$tableName" + else -> tableName + } +} + +private fun retrieveTableMetadata( + metaData: java.sql.DatabaseMetaData, + catalogue: String?, + dbType: DbType, +): java.sql.ResultSet { + val tableTypes = dbType.tableTypes?.toTypedArray() + return metaData.getTables(catalogue, null, null, tableTypes) +} + +private fun formatBytes(bytes: Long): String = when { + bytes < 1024 -> "$bytes bytes" + bytes < 1024 * 1024 -> "%.1f KB".format(bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> "%.1f MB".format(bytes / (1024.0 * 1024.0)) + else -> "%.2f GB".format(bytes / (1024.0 * 1024.0 * 1024.0)) +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt new file mode 100644 index 0000000000..95262dcac1 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt @@ -0,0 +1,52 @@ +package org.jetbrains.kotlinx.dataframe.io + +/** + * Result of memory estimation for a database query. 
+ */ +public data class MemoryEstimate( + /** Estimated number of rows */ + val estimatedRows: Long, + + /** Estimated bytes per row (including DataFrame overhead) */ + val bytesPerRow: Long, + + /** Total estimated memory in bytes */ + val totalBytes: Long, + + /** Human-readable size */ + val humanReadable: String, + + /** Whether this is an exact count or estimate */ + val isExact: Boolean, +) { + /** Total estimated memory in megabytes */ + val megabytes: Double get() = totalBytes / (1024.0 * 1024.0) + + /** Total estimated memory in gigabytes */ + val gigabytes: Double get() = totalBytes / (1024.0 * 1024.0 * 1024.0) + + /** + * Returns true if estimated memory exceeds the given threshold. + */ + public fun exceeds(thresholdBytes: Long): Boolean = totalBytes > thresholdBytes + + /** + * Returns true if estimated memory exceeds the given threshold in gigabytes. + */ + public fun exceedsGb(thresholdGb: Double): Boolean = gigabytes > thresholdGb + + /** + * Calculates recommended limit to stay under the given memory threshold. 
+ */ + public fun recommendedLimit(maxBytes: Long): Int { + if (bytesPerRow == 0L) return Int.MAX_VALUE + val recommendedRows = maxBytes / bytesPerRow + return recommendedRows.coerceAtMost(Int.MAX_VALUE.toLong()).toInt() + } + + override fun toString(): String = buildString { + append("Memory Estimate: $humanReadable") + append(" (~$estimatedRows rows × $bytesPerRow bytes/row)") + if (!isExact) append(" [approximate]") + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt new file mode 100644 index 0000000000..6ee67c91a4 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt @@ -0,0 +1,208 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.kotlinx.dataframe.io.db.DbType +import org.jetbrains.kotlinx.dataframe.io.db.TableColumnMetadata +import java.sql.Connection +import java.sql.ResultSet + +private val logger = KotlinLogging.logger {} + +/** + * Estimates memory usage for database queries WITHOUT loading all data. + * + * Strategy: + * 1. Load 10 sample rows to calculate average row size + * 2. Use COUNT(*) or database statistics for total row count + * 3. Multiply to get total memory estimate + */ +internal object MemoryEstimator { + + private const val SAMPLE_SIZE = 10 + private const val COUNT_TIMEOUT_SECONDS = 5 + + /** + * Estimates memory for a table. 
+ */ + fun estimateTable( + connection: Connection, + tableName: String, + dbType: DbType, + limit: Int?, + ): MemoryEstimate { + val bytesPerRow = estimateAverageRowSize(connection, tableName, dbType) + val totalRows = limit?.toLong() ?: estimateTableRowCount(connection, tableName, dbType) + val totalBytes = bytesPerRow * totalRows + + return MemoryEstimate( + estimatedRows = totalRows, + bytesPerRow = bytesPerRow, + totalBytes = totalBytes, + humanReadable = formatBytes(totalBytes), + isExact = limit != null, + ) + } + + /** + * Estimates memory for a custom SQL query. + */ + fun estimateQuery( + connection: Connection, + sqlQuery: String, + dbType: DbType, + limit: Int?, + ): MemoryEstimate { + val sampleQuery = dbType.buildSqlQueryWithLimit(sqlQuery, SAMPLE_SIZE) + + var totalSize = 0L + var rowCount = 0 + + connection.prepareStatement(sampleQuery).use { stmt -> + stmt.executeQuery().use { rs -> + val tableColumns = getTableColumnsMetadata(rs) + + while (rs.next()) { + totalSize += calculateRowSizeFromResultSet(rs, tableColumns) + rowCount++ + } + + if (rowCount == 0) { + return MemoryEstimate(0, 0, 0, "0 bytes", true) + } + + val bytesPerRow = totalSize / rowCount + val totalRows = limit?.toLong() ?: estimateQueryRowCount(connection, sqlQuery) + val totalBytes = bytesPerRow * totalRows + + return MemoryEstimate( + estimatedRows = totalRows, + bytesPerRow = bytesPerRow, + totalBytes = totalBytes, + humanReadable = formatBytes(totalBytes), + isExact = limit != null, + ) + } + } + } + + /** + * Calculates average row size from sample rows. 
+ */ + private fun estimateAverageRowSize( + connection: Connection, + tableName: String, + dbType: DbType, + ): Long { + val sampleQuery = dbType.buildSelectTableQueryWithLimit(tableName, SAMPLE_SIZE) + + var totalSize = 0L + var rowCount = 0 + + connection.prepareStatement(sampleQuery).use { stmt -> + stmt.executeQuery().use { rs -> + val tableColumns = getTableColumnsMetadata(rs) + + while (rs.next()) { + totalSize += calculateRowSizeFromResultSet(rs, tableColumns) + rowCount++ + } + } + } + + return if (rowCount > 0) totalSize / rowCount else 64L + } + + /** + * Calculates the size of a single row from ResultSet. + */ + private fun calculateRowSizeFromResultSet( + rs: ResultSet, + tableColumns: List, + ): Long { + var size = 16L // DataRow overhead + + for (i in tableColumns.indices) { + val value = rs.getObject(i + 1) + val metadata = tableColumns[i] + size += estimateValueSize(value, metadata) + } + + size += tableColumns.size * 64L // DataFrame column overhead + return size + } + + /** + * Estimates value size based on actual value or metadata. + */ + private fun estimateValueSize(value: Any?, metadata: TableColumnMetadata): Long { + if (value != null) { + return when (value) { + is Boolean, is Byte -> 16 + is Short -> 16 + is Int -> 16 + is Long, is Double -> 24 + is Float -> 16 + is String -> 40 + (value.length * 2L) + is ByteArray -> 16 + value.size + else -> 48 + } as Long + } + + // Fallback to metadata + return when (metadata.jdbcType) { + java.sql.Types.BIT, java.sql.Types.BOOLEAN, java.sql.Types.TINYINT -> 16 + java.sql.Types.SMALLINT -> 16 + java.sql.Types.INTEGER -> 16 + java.sql.Types.BIGINT, java.sql.Types.DOUBLE -> 24 + java.sql.Types.REAL, java.sql.Types.FLOAT -> 16 + java.sql.Types.CHAR, java.sql.Types.VARCHAR -> 40 + (metadata.size.coerceAtMost(100) * 2L) + else -> 48 + } + } + + /** + * Estimates total row count using COUNT(*) with timeout. 
+ */ + private fun estimateTableRowCount( + connection: Connection, + tableName: String, + dbType: DbType, + ): Long { + return try { + connection.prepareStatement("SELECT COUNT(*) FROM $tableName").use { stmt -> + stmt.queryTimeout = COUNT_TIMEOUT_SECONDS + stmt.executeQuery().use { rs -> + if (rs.next()) rs.getLong(1) else 0L + } + } + } catch (e: Exception) { + logger.warn(e) { "COUNT(*) failed for $tableName, using fallback" } + 1000L // Conservative fallback + } + } + + /** + * Estimates query row count using COUNT(*) wrapper. + */ + private fun estimateQueryRowCount(connection: Connection, sqlQuery: String): Long { + return try { + val countQuery = "SELECT COUNT(*) FROM ($sqlQuery) AS temp_count" + connection.prepareStatement(countQuery).use { stmt -> + stmt.queryTimeout = COUNT_TIMEOUT_SECONDS + stmt.executeQuery().use { rs -> + if (rs.next()) rs.getLong(1) else 0L + } + } + } catch (e: Exception) { + logger.warn(e) { "Cannot wrap query with COUNT(*), using sample-based estimate" } + 1000L // Conservative fallback + } + } +} + +private fun formatBytes(bytes: Long): String = when { + bytes < 1024 -> "$bytes bytes" + bytes < 1024 * 1024 -> "%.1f KB".format(bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> "%.1f MB".format(bytes / (1024.0 * 1024.0)) + else -> "%.2f GB".format(bytes / (1024.0 * 1024.0 * 1024.0)) +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt new file mode 100644 index 0000000000..df7c0ee593 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt @@ -0,0 +1,301 @@ +package org.jetbrains.kotlinx.dataframe.io + +import java.sql.Date +import java.sql.Time +import java.sql.Types +import io.github.oshai.kotlinlogging.KLogger +import org.jetbrains.kotlinx.dataframe.io.db.TableColumnMetadata +import java.math.BigDecimal +import java.math.BigInteger +import 
java.sql.Blob +import java.sql.Clob +import java.sql.ResultSet +import java.sql.Timestamp +import java.time.Instant +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import java.time.OffsetDateTime +import java.time.ZonedDateTime +import java.util.UUID + +/** + * Strategy interface for tracking progress during data loading. + */ +internal interface ProgressTracker { + fun onStart() + fun onRowLoaded() + fun onComplete(rowCount: Int) + + companion object { + fun create(logger: KLogger): ProgressTracker { + return when { + JdbcConfig.PROGRESS_ENABLED -> { + if (JdbcConfig.PROGRESS_DETAILED) { + DetailedProgressTracker(logger) + } else { + SimpleProgressTracker(logger) + } + } + logger.isDebugEnabled -> { + DetailedProgressTracker(logger) + } + else -> NoOpProgressTracker + } + } + } +} + +/** + * No-op implementation - zero overhead when disabled. + */ +private object NoOpProgressTracker : ProgressTracker { + override fun onStart() {} + override fun onRowLoaded() {} + override fun onComplete(rowCount: Int) {} +} + +/** + * Simple progress tracker - shows only basic information. + */ +private class SimpleProgressTracker(private val logger: KLogger) : ProgressTracker { + private var rowsLoaded = 0 + + override fun onStart() {} + + override fun onRowLoaded() { + rowsLoaded++ + if (rowsLoaded % JdbcConfig.PROGRESS_INTERVAL == 0) { + logger.debug { "Loaded $rowsLoaded rows" } + } + } + + override fun onComplete(rowCount: Int) { + if (rowCount > 0) { + logger.debug { "Loading complete: $rowCount rows" } + } + } +} + +/** + * Detailed progress tracker with statistics and memory estimation. + */ +internal class DetailedProgressTracker(private val logger: KLogger) : ProgressTracker { + private val startTime = System.currentTimeMillis() + private var rowsLoaded = 0 + private var firstRowSize: Long? = null + private var estimatedTotalRows: Long? = null + private var estimatedMemoryBytes: Long? 
= null + private var memoryWarningShown = false + + override fun onStart() {} + + override fun onRowLoaded() { + rowsLoaded++ + + if (rowsLoaded % JdbcConfig.PROGRESS_INTERVAL == 0) { + logDetailedProgress() + } + } + + override fun onComplete(rowCount: Int) { + if (rowCount == 0) return + + val totalMillis = System.currentTimeMillis() - startTime + val rowsPerSecond = if (totalMillis > 0) { + (rowCount.toDouble() / totalMillis * 1000).toInt() + } else { + 0 + } + + val memoryInfo = estimatedMemoryBytes?.let { + val actualMemory = firstRowSize?.let { size -> size * rowCount } ?: it + " using ${formatBytes(actualMemory)}" + } ?: "" + + if (totalMillis > 0) { + logger.debug { + "Loading complete: $rowCount rows in ${totalMillis}ms (~$rowsPerSecond rows/sec)$memoryInfo" + } + } else { + logger.debug { "Loading complete: $rowCount rows$memoryInfo" } + } + } + + /** + * Estimates memory on first row and shows warnings. + */ + fun estimateMemoryOnFirstRow( + columnData: List>, + tableColumns: List, + rs: ResultSet, + ) { + if (memoryWarningShown) return + + try { + firstRowSize = calculateRowSize(columnData, tableColumns) + estimatedTotalRows = estimateTotalRows(rs) + + if (firstRowSize != null && estimatedTotalRows != null) { + estimatedMemoryBytes = firstRowSize!! * estimatedTotalRows!! + + if (estimatedMemoryBytes!! > 100 * 1024 * 1024) { + logger.debug { + "Estimated memory: ${formatBytes(estimatedMemoryBytes!!)} for ~$estimatedTotalRows rows" + } + + if (estimatedMemoryBytes!! > 1024L * 1024L * 1024L) { + logger.warn { + "Large dataset detected (${formatBytes(estimatedMemoryBytes!!)}). " + + "Consider using 'limit' parameter." 
+                    }
+                    memoryWarningShown = true
+                }
+            }
+            }
+        } catch (e: Exception) {
+            logger.debug(e) { "Failed to estimate memory" }
+        }
+    }
+
+    private fun logDetailedProgress() {
+        val elapsedMillis = System.currentTimeMillis() - startTime
+        if (elapsedMillis == 0L) return
+
+        val rowsPerSecond = (rowsLoaded.toDouble() / elapsedMillis * 1000).toInt()
+
+        val progressPercent = estimatedTotalRows?.let {
+            " (${(rowsLoaded * 100.0 / it).toInt()}%)"
+        } ?: ""
+
+        logger.debug {
+            "Loaded $rowsLoaded rows$progressPercent in ${elapsedMillis}ms (~$rowsPerSecond rows/sec)"
+        }
+    }
+
+    private fun calculateRowSize(
+        columnData: List<List<Any?>>,
+        tableColumns: List<TableColumnMetadata>,
+    ): Long {
+        var size = 16L // DataRow overhead
+
+        columnData.forEachIndexed { index, values ->
+            val value = values.firstOrNull()
+            val columnMetadata = tableColumns.getOrNull(index)
+            size += estimateValueSize(value, columnMetadata)
+        }
+
+        size += columnData.size * 64L // DataFrame column overhead
+        return size
+    }
+
+    private fun estimateValueSize(value: Any?, metadata: TableColumnMetadata?): Long {
+        if (value != null) {
+            return estimateActualValueSize(value)
+        }
+
+        if (metadata != null) {
+            return estimateValueSizeFromMetadata(metadata)
+        }
+
+        return 8L
+    }
+
+    private fun estimateActualValueSize(value: Any): Long {
+        // All branches are Long: a mixed Int/Long `when` followed by `as Long`
+        // would throw ClassCastException at runtime for the boxed-Int branches.
+        return when (value) {
+            is Boolean, is Byte -> 16L
+            is Short -> 16L
+            is Int -> 16L
+            is Long, is Double -> 24L
+            is Float -> 16L
+            is Char -> 16L
+            is String -> 40L + (value.length * 2L)
+            is ByteArray -> 16L + value.size
+            is CharArray -> 16L + (value.size * 2L)
+            is BigInteger -> 32L + (value.bitLength() / 8)
+            is BigDecimal -> {
+                val unscaledValue = value.unscaledValue()
+                48L + (unscaledValue.bitLength() / 8)
+            }
+
+            is LocalDate -> 24L
+            is LocalTime -> 24L
+            is LocalDateTime -> 48L
+            is Instant -> 24L
+            is OffsetDateTime -> 56L
+            is ZonedDateTime -> 64L
+            is Timestamp -> 32L
+            is Date -> 32L
+            is Time -> 32L
+            is UUID -> 32L
+            is Blob -> 48L + estimateBlobSize(value)
+            is Clob -> 48L + estimateClobSize(value)
+            else -> 48L
+        }
+    }
+
+    private fun estimateValueSizeFromMetadata(metadata: TableColumnMetadata): Long {
+        return when (metadata.jdbcType) {
+            Types.BIT, Types.BOOLEAN -> 16L
+            Types.TINYINT -> 16L
+            Types.SMALLINT -> 16L
+            Types.INTEGER -> 16L
+            Types.BIGINT -> 24L
+            Types.REAL, Types.FLOAT -> 16L
+            Types.DOUBLE -> 24L
+            Types.NUMERIC, Types.DECIMAL -> 48L + (metadata.size / 4)
+            Types.CHAR, Types.VARCHAR, Types.LONGVARCHAR ->
+                40L + (metadata.size.coerceAtMost(1000) * 2L)
+
+            Types.DATE -> 24L
+            Types.TIME -> 24L
+            Types.TIMESTAMP -> 32L
+            Types.TIMESTAMP_WITH_TIMEZONE -> 56L
+            Types.BINARY, Types.VARBINARY, Types.LONGVARBINARY ->
+                16L + metadata.size.coerceAtMost(1000)
+
+            Types.BLOB -> 64L
+            Types.CLOB -> 64L
+            else -> 48L
+        }
+    }
+
+    private fun estimateBlobSize(blob: Blob): Long {
+        return try {
+            // Blob.length() returns Long; coerceAtMost requires a Long argument
+            blob.length().coerceAtMost(10L * 1024 * 1024)
+        } catch (e: Exception) {
+            1024L
+        }
+    }
+
+    private fun estimateClobSize(clob: Clob): Long {
+        return try {
+            clob.length() * 2L
+        } catch (e: Exception) {
+            2048L
+        }
+    }
+
+    private fun estimateTotalRows(rs: ResultSet): Long?
{ + return try { + if (rs.type != ResultSet.TYPE_FORWARD_ONLY) { + val currentRow = rs.row + rs.last() + val total = rs.row.toLong() + rs.absolute(currentRow) + total + } else { + null + } + } catch (e: Exception) { + null + } + } +} + +private fun formatBytes(bytes: Long): String = when { + bytes < 1024 -> "$bytes bytes" + bytes < 1024 * 1024 -> "%.1f KB".format(bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> "%.1f MB".format(bytes / (1024.0 * 1024.0)) + else -> "%.2f GB".format(bytes / (1024.0 * 1024.0 * 1024.0)) +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index a2594ce131..44dbd110b3 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -937,9 +937,16 @@ internal fun fetchAndConvertDataFromResultSet( inferNullability: Boolean, ): AnyFrame { val columnKTypes = buildColumnKTypes(tableColumns, dbType) - val columnData = readAllRowsFromResultSet(rs, tableColumns, columnKTypes, dbType, limit) + + val progressTracker = ProgressTracker.create(logger) + + progressTracker.onStart() + + val columnData = readAllRowsFromResultSet(rs, tableColumns, columnKTypes, dbType, limit, progressTracker) val dataFrame = buildDataFrameFromColumnData(columnData, tableColumns, columnKTypes, dbType, inferNullability) + progressTracker.onComplete(dataFrame.rowsCount()) + logger.debug { "DataFrame with ${dataFrame.rowsCount()} rows and ${dataFrame.columnsCount()} columns created as a result of SQL query." 
 }
@@ -965,6 +972,7 @@ private fun readAllRowsFromResultSet(
     columnKTypes: Map<String, KType>,
     dbType: DbType,
     limit: Int?,
+    progressTracker: ProgressTracker,
 ): List<MutableList<Any?>> {
     val columnsCount = tableColumns.size
     val columnData = List(columnsCount) { mutableListOf<Any?>() }
@@ -981,12 +989,29 @@ private fun readAllRowsFromResultSet(
             columnData[columnIndex].add(value)
         }
         rowsRead++
-        // if (rowsRead % 1000 == 0) logger.debug { "Loaded $rowsRead rows." } // TODO: https://github.com/Kotlin/dataframe/issues/455
+        // Progress tracking
+        trackProgress(rowsRead, columnData, tableColumns, rs, progressTracker)
     }
 
     return columnData
 }
 
+private fun trackProgress(
+    rowsRead: Int,
+    columnData: List<List<Any?>>,
+    tableColumns: List<TableColumnMetadata>,
+    rs: ResultSet,
+    progressTracker: ProgressTracker,
+) {
+    // Memory estimation on first row (only for DetailedProgressTracker)
+    if (rowsRead == 1 && progressTracker is DetailedProgressTracker) {
+        progressTracker.estimateMemoryOnFirstRow(columnData, tableColumns, rs)
+    }
+
+    // Progress update
+    progressTracker.onRowLoaded()
+}
+
 /**
  * Builds DataFrame from column-oriented data.
  * Accepts mutable lists to enable efficient in-place transformations.
diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt new file mode 100644 index 0000000000..02303e6f2c --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt @@ -0,0 +1,257 @@ +package org.jetbrains.kotlinx.dataframe.io.h2.tracking + +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.io.JdbcSafeDataLoading +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.io.ByteArrayOutputStream +import java.io.PrintStream +import java.sql.Connection +import java.sql.DriverManager + +/** + * Tests for verifying log output of progress tracking and safe loading. + * + * Note: These tests capture System.err output instead of using Logback-specific + * log appenders to remain compatible with any SLF4J implementation. 
+ */ +class JdbcProgressOutputTest { + + private lateinit var connection: Connection + private lateinit var originalErr: PrintStream + private lateinit var capturedErr: ByteArrayOutputStream + + @Before + fun setUp() { + connection = DriverManager.getConnection("jdbc:h2:mem:test_output;DB_CLOSE_DELAY=-1;MODE=MySQL") + + // Setup System.err capture for SLF4J Simple logging + originalErr = System.err + capturedErr = ByteArrayOutputStream() + System.setErr(PrintStream(capturedErr)) + + // Create test table + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE test_progress ( + id INT PRIMARY KEY, + data VARCHAR(100) + ) + """.trimIndent() + ) + + // Insert 3000 rows + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO test_progress VALUES (?, ?)") + for (i in 1..3000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "Data_$i") + insertStmt.addBatch() + + if (i % 500 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + } + + @After + fun tearDown() { + System.setErr(originalErr) + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_progress") } + connection.close() + System.clearProperty("dataframe.jdbc.progress") + System.clearProperty("dataframe.jdbc.progress.detailed") + System.clearProperty("dataframe.jdbc.progress.interval") + } + + private fun getLogOutput(): String = capturedErr.toString("UTF-8") + + private fun clearLogOutput() { + capturedErr.reset() + } + + @Test + fun `progress tracking logs debug messages`() { + System.setProperty("dataframe.jdbc.progress", "true") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have progress messages + val hasProgressMessages = logOutput.contains("Loaded") && logOutput.contains("rows") + hasProgressMessages shouldBe true + } + + @Test + fun `detailed progress shows 
statistics`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.detailed", "true") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should contain statistics + val hasStatistics = logOutput.contains("rows/sec") || logOutput.contains("ms") + hasStatistics shouldBe true + } + + @Test + fun `simple progress shows basic messages`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.detailed", "false") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have messages but without detailed stats + val hasBasicMessages = logOutput.contains("Loaded") && logOutput.contains("rows") + hasBasicMessages shouldBe true + + // Should NOT contain detailed statistics + val hasDetailedStats = logOutput.contains("rows/sec") + hasDetailedStats shouldBe false + } + + @Test + fun `progress interval is respected`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.interval", "1000") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have messages at 1000, 2000, 3000 + val has1000 = logOutput.contains("1000 rows") + val has2000 = logOutput.contains("2000 rows") + has1000 shouldBe true + has2000 shouldBe true + } + + @Test + fun `completion message is logged`() { + System.setProperty("dataframe.jdbc.progress", "true") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have completion message + val hasCompletion = logOutput.contains("Loading complete") || logOutput.contains("3000 rows") + hasCompletion shouldBe true + } + + @Test + fun `memory warning is logged for large datasets`() { + System.setProperty("dataframe.jdbc.progress", "true") + + // Create 
a larger table + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE large_test ( + id INT PRIMARY KEY, + data VARCHAR(10000) + ) + """.trimIndent() + ) + + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO large_test VALUES (?, ?)") + for (i in 1..1000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "x".repeat(10000)) // 10KB per row = 10MB total + insertStmt.addBatch() + + if (i % 100 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + + clearLogOutput() + + DataFrame.readSqlTable(connection, "large_test") + + val logOutput = getLogOutput() + + // Should have memory estimate message (if logged at DEBUG/INFO level) + // This is optional as it depends on logging configuration + + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS large_test") } + } + + @Test + fun `safe loading logs warnings when limit applied`() { + clearLogOutput() + + JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.00001 // Very small limit to trigger warning + onExceed = JdbcSafeDataLoading.ExceedAction.APPLY_LIMIT + } + ) { + readSqlTable(connection, "test_progress") + } + + val logOutput = getLogOutput() + + // Should have warning about exceeding limit + val hasWarning = logOutput.contains("exceeds limit") || logOutput.contains("WARN") + hasWarning shouldBe true + } + + @Test + fun `safe loading logs warning when proceeding despite limit`() { + clearLogOutput() + + JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.00001 + onExceed = JdbcSafeDataLoading.ExceedAction.WARN_AND_PROCEED + } + ) { + readSqlTable(connection, "test_progress") + } + + val logOutput = getLogOutput() + + // Should have warning about proceeding + val hasWarning = (logOutput.contains("exceeds limit") && logOutput.contains("proceeding")) || + logOutput.contains("WARN") + hasWarning shouldBe true + } + + @Test + fun `no 
logs when progress disabled`() { + System.clearProperty("dataframe.jdbc.progress") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should NOT have progress messages (only connection/query logs possibly) + val hasProgressMessages = logOutput.matches(Regex(".*Loaded \\d+ rows.*")) + hasProgressMessages shouldBe false + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt new file mode 100644 index 0000000000..5c1e37782a --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt @@ -0,0 +1,137 @@ + +package org.jetbrains.kotlinx.dataframe.io.h2.tracking + +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.sql.Connection +import java.sql.DriverManager + +/** + * Tests for JDBC progress tracking functionality. 
+ */ +class JdbcProgressTest { + + private lateinit var connection: Connection + + @Before + fun setUp() { + connection = DriverManager.getConnection("jdbc:h2:mem:test_progress;DB_CLOSE_DELAY=-1;MODE=MySQL") + + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE test_table ( + id INT PRIMARY KEY, + name VARCHAR(100), + amount INT + ) + """.trimIndent() + ) + + // Insert 5000 rows for testing + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO test_table VALUES (?, ?, ?)") + for (i in 1..5000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "Name_$i") + insertStmt.setInt(3, i * 10) + insertStmt.addBatch() + + if (i % 500 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + } + + @After + fun tearDown() { + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_table") } + connection.close() + } + + @Test + fun `progress tracking disabled by default`() { + // Progress should be disabled by default + System.clearProperty("dataframe.jdbc.progress") + + val df = DataFrame.readSqlTable(connection, "test_table") + + df.rowsCount() shouldBe 5000 + } + + @Test + fun `progress tracking can be enabled via system property`() { + System.setProperty("dataframe.jdbc.progress", "true") + + try { + val df = DataFrame.readSqlTable(connection, "test_table") + df.rowsCount() shouldBe 5000 + } finally { + System.clearProperty("dataframe.jdbc.progress") + } + } + + @Test + fun `progress tracking with limit`() { + System.setProperty("dataframe.jdbc.progress", "true") + + try { + val df = DataFrame.readSqlTable(connection, "test_table", limit = 100) + df.rowsCount() shouldBe 100 + } finally { + System.clearProperty("dataframe.jdbc.progress") + } + } + + @Test + fun `detailed progress can be disabled`() { + System.setProperty("dataframe.jdbc.progress", "true") + 
System.setProperty("dataframe.jdbc.progress.detailed", "false") + + try { + val df = DataFrame.readSqlTable(connection, "test_table") + df.rowsCount() shouldBe 5000 + } finally { + System.clearProperty("dataframe.jdbc.progress") + System.clearProperty("dataframe.jdbc.progress.detailed") + } + } + + @Test + fun `progress interval can be configured`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.interval", "500") + + try { + val df = DataFrame.readSqlTable(connection, "test_table") + df.rowsCount() shouldBe 5000 + } finally { + System.clearProperty("dataframe.jdbc.progress") + System.clearProperty("dataframe.jdbc.progress.interval") + } + } + + @Test + fun `progress works with query`() { + System.setProperty("dataframe.jdbc.progress", "true") + + try { + val df = DataFrame.readSqlQuery( + connection, + "SELECT * FROM test_table WHERE amount > 1000" + ) + df.rowsCount() shouldBe 4900 + } finally { + System.clearProperty("dataframe.jdbc.progress") + } + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt new file mode 100644 index 0000000000..33fe97d4c4 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt @@ -0,0 +1,177 @@ +package org.jetbrains.kotlinx.dataframe.io.h2.tracking + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.ints.shouldBeLessThan +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig +import org.jetbrains.kotlinx.dataframe.io.JdbcSafeDataLoading +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.sql.Connection +import java.sql.DriverManager + +/** + * Tests for JdbcSafeDataLoading functionality. 
+ */ +class JdbcSafeLoadingTest { + + private lateinit var connection: Connection + private val dbConfig = DbConnectionConfig("jdbc:h2:mem:test_safe;DB_CLOSE_DELAY=-1;MODE=MySQL") + + @Before + fun setUp() { + connection = DriverManager.getConnection(dbConfig.url) + + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE test_data ( + id INT PRIMARY KEY, + data VARCHAR(1000) + ) + """.trimIndent() + ) + + // Insert 10000 rows (~10KB each = ~100MB total) + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO test_data VALUES (?, ?)") + for (i in 1..10000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "x".repeat(1000)) // 1KB per row + insertStmt.addBatch() + + if (i % 1000 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + } + + @After + fun tearDown() { + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_data") } + connection.close() + } + + @Test + fun `safe load with sufficient memory limit`() { + val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) { + readSqlTable(connection, "test_data") + } + + df.rowsCount() shouldBe 10000 + } + + @Test + fun `safe load applies automatic limit when memory exceeded`() { + var limitApplied = false + var appliedLimitValue = 0 + + val df = JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.001 // Very small limit + onExceed = JdbcSafeDataLoading.ExceedAction.APPLY_LIMIT + + onLimitApplied = { _, limit -> + limitApplied = true + appliedLimitValue = limit + } + } + ) { + readSqlTable(connection, "test_data") + } + + limitApplied shouldBe true + appliedLimitValue shouldBeLessThan 10000 + df.rowsCount() shouldBe appliedLimitValue + } + + @Test + fun `safe load throws when configured`() { + shouldThrow { + JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.001 + onExceed = JdbcSafeDataLoading.ExceedAction.THROW + } + ) { + 
readSqlTable(connection, "test_data") + } + } + } + + @Test + fun `safe load warns and proceeds when configured`() { + val df = JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.001 + onExceed = JdbcSafeDataLoading.ExceedAction.WARN_AND_PROCEED + } + ) { + readSqlTable(connection, "test_data") + } + + df.rowsCount() shouldBe 10000 + } + + @Test + fun `onEstimate callback is invoked`() { + var estimateCalled = false + var estimatedRows = 0L + + JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 1.0 + + onEstimate = { estimate -> + estimateCalled = true + estimatedRows = estimate.estimatedRows + } + } + ) { + readSqlTable(connection, "test_data") + } + + estimateCalled shouldBe true + estimatedRows shouldBe 10000L + } + + @Test + fun `safe load works with DbConnectionConfig`() { + val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) { + readSqlTable(dbConfig, "test_data") + } + + df.rowsCount() shouldBe 10000 + } + + @Test + fun `safe load works with query`() { + val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) { + readSqlQuery(connection, "SELECT * FROM test_data WHERE id <= 100") + } + + df.rowsCount() shouldBe 100 + } + + @Test + fun `loadMultiple works with readAllSqlTables`() { + // Create another table + connection.createStatement().use { stmt -> + stmt.executeUpdate("CREATE TABLE test_data2 (id INT PRIMARY KEY, amount INT)") + stmt.executeUpdate("INSERT INTO test_data2 VALUES (1, 10), (2, 20)") + } + + val tables = JdbcSafeDataLoading.loadMultiple(maxMemoryGb = 1.0) { + readAllSqlTables(connection) + } + + tables.keys shouldBe setOf("TEST_DATA", "TEST_DATA2") + + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_data2") } + } +}