diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt new file mode 100644 index 0000000000..513ab82a17 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcConfig.kt @@ -0,0 +1,68 @@ +package org.jetbrains.kotlinx.dataframe.io + +/** + * Internal configuration for JDBC module behavior. + * Controlled via system properties or environment variables. + * + * Example usage: + * ```kotlin + * // Enable progress with detailed statistics + * System.setProperty("dataframe.jdbc.progress", "true") + * System.setProperty("dataframe.jdbc.progress.detailed", "true") + * + * // Or via environment variables + * export DATAFRAME_JDBC_PROGRESS=true + * export DATAFRAME_JDBC_PROGRESS_DETAILED=true + * ``` + */ +internal object JdbcConfig { + /** + * Enable progress logging during data loading. + * + * When `true`, enables DEBUG-level progress messages and memory estimates. + * When `false` (default), only respects logger level (e.g., DEBUG). + * + * System property: `-Ddataframe.jdbc.progress=true` + * Environment variable: `DATAFRAME_JDBC_PROGRESS=true` + * Default: `false` + */ + @JvmStatic + val PROGRESS_ENABLED: Boolean by lazy { + System.getProperty("dataframe.jdbc.progress")?.toBoolean() + ?: System.getenv("DATAFRAME_JDBC_PROGRESS")?.toBoolean() + ?: false + } + + /** + * Enable detailed progress logging with statistics. + * + * When `true`: Shows row count, percentage, speed (rows/sec), memory usage + * When `false`: Shows only basic "Loaded X rows" messages + * + * Only takes effect when progress is enabled. + * + * System property: `-Ddataframe.jdbc.progress.detailed=true` + * Environment variable: `DATAFRAME_JDBC_PROGRESS_DETAILED=true` + * Default: `true` + */ + @JvmStatic + val PROGRESS_DETAILED: Boolean by lazy { + System.getProperty("dataframe.jdbc.progress.detailed")?.toBoolean() + ?: System.getenv("DATAFRAME_JDBC_PROGRESS_DETAILED")?.toBoolean() + ?: true + } + + /** + * Report progress every N rows. + * + * System property: `-Ddataframe.jdbc.progress.interval=5000` + * Environment variable: `DATAFRAME_JDBC_PROGRESS_INTERVAL=5000` + * Default: `1000` + */ + @JvmStatic + val PROGRESS_INTERVAL: Int by lazy { + System.getProperty("dataframe.jdbc.progress.interval")?.toIntOrNull() + ?: System.getenv("DATAFRAME_JDBC_PROGRESS_INTERVAL")?.toIntOrNull() + ?: 1000 + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt new file mode 100644 index 0000000000..f8174872d1 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/JdbcSafeLoading.kt @@ -0,0 +1,419 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.kotlinx.dataframe.AnyFrame +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.io.db.DbType +import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromConnection +import java.sql.Connection +import java.sql.PreparedStatement +import javax.sql.DataSource + +private val logger = KotlinLogging.logger {} + +/** + * Safe JDBC data loading with automatic memory estimation and limit management. 
+ *
+ * Uses the same method names as the main API but adds automatic memory checking:
+ *
+ * ```kotlin
+ * // Instead of:
+ * DataFrame.readSqlTable(config, "table")
+ *
+ * // Use:
+ * JdbcSafeDataLoading.load(maxMemoryGb = 1.0) {
+ *     readSqlTable(config, "table")
+ * }
+ * ```
+ *
+ * ## Features
+ * - Automatic memory estimation before loading (10 sample rows + COUNT(*))
+ * - Automatic limit application when the memory threshold is exceeded
+ * - Configurable behavior: throw, warn, or auto-limit
+ * - Callbacks for estimates and limit events
+ *
+ * ## Examples
+ *
+ * Basic usage:
+ * ```kotlin
+ * val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) {
+ *     readSqlTable(config, "large_table")
+ * }
+ * ```
+ *
+ * Advanced configuration:
+ * ```kotlin
+ * val df = JdbcSafeDataLoading.load(
+ *     configure = {
+ *         maxMemoryGb = 2.0
+ *         onExceed = ExceedAction.THROW
+ *
+ *         onEstimate = { estimate ->
+ *             println("Estimated: ${estimate.humanReadable}")
+ *         }
+ *     },
+ * ) {
+ *     readSqlTable(config, "huge_table")
+ * }
+ * ```
+ */
+public object JdbcSafeDataLoading {
+
+    /**
+     * Loads JDBC data with automatic memory checking.
+     */
+    public fun <T> load(
+        maxMemoryGb: Double = 1.0,
+        block: SafeLoadContext.() -> T,
+    ): T {
+        val config = LoadConfig().apply { this.maxMemoryGb = maxMemoryGb }
+        return SafeLoadContext(config).block()
+    }
+
+    /**
+     * Advanced version with full configuration.
+     */
+    public fun <T> load(
+        configure: LoadConfig.() -> Unit,
+        block: SafeLoadContext.() -> T,
+    ): T {
+        val config = LoadConfig().apply(configure)
+        return SafeLoadContext(config).block()
+    }
+
+    /**
+     * For loading multiple results (like a `Map<String, AnyFrame>` of all tables).
+     */
+    public fun <T> loadMultiple(
+        maxMemoryGb: Double = 1.0,
+        block: SafeLoadContext.() -> T,
+    ): T {
+        val config = LoadConfig().apply { this.maxMemoryGb = maxMemoryGb }
+        return SafeLoadContext(config).block()
+    }
+
+    /**
+     * Advanced version for multiple results.
+     */
+    public fun <T> loadMultiple(
+        configure: LoadConfig.() -> Unit,
+        block: SafeLoadContext.() -> T,
+    ): T {
+        val config = LoadConfig().apply(configure)
+        return SafeLoadContext(config).block()
+    }
+
+    /**
+     * Configuration for safe JDBC loading.
+     */
+    public class LoadConfig {
+        /** Maximum allowed memory in bytes. */
+        public var maxMemoryBytes: Long = 1024L * 1024L * 1024L
+
+        /** Maximum allowed memory in gigabytes (convenience setter). */
+        public var maxMemoryGb: Double
+            get() = maxMemoryBytes / (1024.0 * 1024.0 * 1024.0)
+            set(value) { maxMemoryBytes = (value * 1024 * 1024 * 1024).toLong() }
+
+        /** What to do when the memory limit is exceeded. */
+        public var onExceed: ExceedAction = ExceedAction.APPLY_LIMIT
+
+        /** Callback invoked when an estimate is available (before loading). */
+        public var onEstimate: ((MemoryEstimate) -> Unit)? = null
+
+        /** Callback invoked when a limit is automatically applied. */
+        public var onLimitApplied: ((estimate: MemoryEstimate, appliedLimit: Int) -> Unit)? = null
+
+        /** For [SafeLoadContext.readAllSqlTables]: callback invoked for each table's estimate. */
+        public var onTableEstimate: ((tableName: String, estimate: MemoryEstimate) -> Unit)? = null
+    }
+
+    /** Action to take when estimated memory exceeds the limit. */
+    public enum class ExceedAction {
+        /** Automatically apply a limit to stay under the threshold. */
+        APPLY_LIMIT,
+
+        /** Throw [MemoryLimitExceededException] and don't load. */
+        THROW,
+
+        /** Log a warning but proceed with the full load. */
+        WARN_AND_PROCEED,
+    }
+
+    /**
+     * Exception thrown when the memory limit is exceeded and the action is [ExceedAction.THROW].
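+     *
+     * A minimal handling sketch (assuming a `config` [DbConnectionConfig] is in scope):
+     * ```kotlin
+     * try {
+     *     JdbcSafeDataLoading.load(
+     *         configure = { onExceed = JdbcSafeDataLoading.ExceedAction.THROW },
+     *     ) {
+     *         readSqlTable(config, "large_table")
+     *     }
+     * } catch (e: JdbcSafeDataLoading.MemoryLimitExceededException) {
+     *     println("Refused to load: ${e.estimate.humanReadable}")
+     * }
+     * ```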
*/ + public class MemoryLimitExceededException( + message: String, + public val estimate: MemoryEstimate, + ) : IllegalStateException(message) + + /** + * Context with the same method names as DataFrame.Companion. + */ + public class SafeLoadContext internal constructor( + private val config: LoadConfig, + ) { + // region readSqlTable + + public fun readSqlTable( + dbConfig: DbConnectionConfig, + tableName: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlTable(dbConfig, tableName, null, dbType) }, + load = { limit -> + DataFrame.readSqlTable(dbConfig, tableName, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlTable( + connection: Connection, + tableName: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlTable(connection, tableName, null, dbType) }, + load = { limit -> + DataFrame.readSqlTable(connection, tableName, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlTable( + dataSource: DataSource, + tableName: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = dataSource.connection.use { conn -> + readSqlTable(conn, tableName, dbType, inferNullability, strictValidation, configureStatement) + } + + // endregion + + // region readSqlQuery + + public fun readSqlQuery( + dbConfig: DbConnectionConfig, + sqlQuery: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlQuery(dbConfig, sqlQuery, null, dbType) }, + load = { limit -> + DataFrame.readSqlQuery(dbConfig, sqlQuery, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlQuery( + connection: Connection, + sqlQuery: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = estimateAndLoad( + estimate = { estimateSqlQuery(connection, sqlQuery, null, dbType) }, + load = { limit -> + DataFrame.readSqlQuery(connection, sqlQuery, limit, inferNullability, dbType, strictValidation, configureStatement) + } + ) + + public fun readSqlQuery( + dataSource: DataSource, + sqlQuery: String, + dbType: DbType? = null, + inferNullability: Boolean = true, + strictValidation: Boolean = true, + configureStatement: (PreparedStatement) -> Unit = {}, + ): AnyFrame = dataSource.connection.use { conn -> + readSqlQuery(conn, sqlQuery, dbType, inferNullability, strictValidation, configureStatement) + } + + // endregion + + // region readAllSqlTables + + public fun readAllSqlTables( + dbConfig: DbConnectionConfig, + catalogue: String? = null, + dbType: DbType? 
= null,
+        inferNullability: Boolean = true,
+        configureStatement: (PreparedStatement) -> Unit = {},
+    ): Map<String, AnyFrame> = withReadOnlyConnection(dbConfig, dbType) { connection ->
+        readAllSqlTables(connection, catalogue, dbType, inferNullability, configureStatement)
+    }
+
+    public fun readAllSqlTables(
+        connection: Connection,
+        catalogue: String? = null,
+        dbType: DbType? = null,
+        inferNullability: Boolean = true,
+        configureStatement: (PreparedStatement) -> Unit = {},
+    ): Map<String, AnyFrame> {
+        val determinedDbType = dbType ?: extractDBTypeFromConnection(connection)
+        val metaData = connection.metaData
+        val tablesResultSet = retrieveTableMetadata(metaData, catalogue, determinedDbType)
+
+        return buildMap {
+            while (tablesResultSet.next()) {
+                val tableMetadata = determinedDbType.buildTableMetadata(tablesResultSet)
+
+                if (determinedDbType.isSystemTable(tableMetadata)) continue
+
+                val fullTableName = buildFullTableName(catalogue, tableMetadata.schemaName, tableMetadata.name)
+
+                val dataFrame = estimateAndLoad(
+                    estimate = {
+                        val est = estimateSqlTable(connection, fullTableName, null, determinedDbType)
+                        config.onTableEstimate?.invoke(fullTableName, est)
+                        est
+                    },
+                    load = { limit ->
+                        DataFrame.readSqlTable(connection, fullTableName, limit, inferNullability, determinedDbType, true, configureStatement)
+                    }
+                )
+
+                put(fullTableName, dataFrame)
+            }
+        }
+    }
+
+    public fun readAllSqlTables(
+        dataSource: DataSource,
+        catalogue: String? = null,
+        dbType: DbType? = null,
+        inferNullability: Boolean = true,
+        configureStatement: (PreparedStatement) -> Unit = {},
+    ): Map<String, AnyFrame> = dataSource.connection.use { conn ->
+        readAllSqlTables(conn, catalogue, dbType, inferNullability, configureStatement)
+    }
+
+    // endregion
+
+    private fun estimateAndLoad(
+        estimate: () -> MemoryEstimate,
+        load: (limit: Int?) -> AnyFrame,
+    ): AnyFrame {
+        val memoryEstimate = estimate()
+        config.onEstimate?.invoke(memoryEstimate)
+        val finalLimit = handleMemoryLimit(memoryEstimate, null)
+        return load(finalLimit)
+    }
+
+    private fun handleMemoryLimit(estimate: MemoryEstimate, originalLimit: Int?): Int? {
+        if (!estimate.exceeds(config.maxMemoryBytes)) return originalLimit
+
+        return when (config.onExceed) {
+            ExceedAction.APPLY_LIMIT -> {
+                if (originalLimit != null) {
+                    originalLimit
+                } else {
+                    val recommendedLimit = estimate.recommendedLimit(config.maxMemoryBytes)
+                    logger.warn {
+                        "Estimated memory ${estimate.humanReadable} exceeds limit " +
+                            "${formatBytes(config.maxMemoryBytes)}. 
Applying limit: $recommendedLimit rows" + } + config.onLimitApplied?.invoke(estimate, recommendedLimit) + recommendedLimit + } + } + + ExceedAction.THROW -> throw MemoryLimitExceededException( + "Estimated memory ${estimate.humanReadable} exceeds limit ${formatBytes(config.maxMemoryBytes)}", + estimate + ) + + ExceedAction.WARN_AND_PROCEED -> { + logger.warn { "Memory ${estimate.humanReadable} exceeds limit, but proceeding" } + originalLimit + } + } + } + } +} + +// Helper functions for estimation + +private fun estimateSqlTable( + dbConfig: DbConnectionConfig, + tableName: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + return withReadOnlyConnection(dbConfig, dbType) { conn -> + val determinedDbType = dbType ?: extractDBTypeFromConnection(conn) + MemoryEstimator.estimateTable(conn, tableName, determinedDbType, limit) + } +} + +private fun estimateSqlQuery( + dbConfig: DbConnectionConfig, + sqlQuery: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + require(isValidSqlQuery(sqlQuery)) { + "SQL query should start from SELECT and be a valid query" + } + + return withReadOnlyConnection(dbConfig, dbType) { conn -> + val determinedDbType = dbType ?: extractDBTypeFromConnection(conn) + MemoryEstimator.estimateQuery(conn, sqlQuery, determinedDbType, limit) + } +} + +private fun estimateSqlTable( + connection: Connection, + tableName: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) + return MemoryEstimator.estimateTable(connection, tableName, determinedDbType, limit) +} + +private fun estimateSqlQuery( + connection: Connection, + sqlQuery: String, + limit: Int?, + dbType: DbType?, +): MemoryEstimate { + validateLimit(limit) + require(isValidSqlQuery(sqlQuery)) { + "SQL query should start from SELECT and be a valid query" + } + val determinedDbType = dbType ?: extractDBTypeFromConnection(connection) + return MemoryEstimator.estimateQuery(connection, sqlQuery, determinedDbType, limit) +} + +private fun buildFullTableName(catalogue: String?, schemaName: String?, tableName: String): String { + return when { + catalogue != null && schemaName != null -> "$catalogue.$schemaName.$tableName" + catalogue != null -> "$catalogue.$tableName" + else -> tableName + } +} + +private fun retrieveTableMetadata( + metaData: java.sql.DatabaseMetaData, + catalogue: String?, + dbType: DbType, +): java.sql.ResultSet { + val tableTypes = dbType.tableTypes?.toTypedArray() + return metaData.getTables(catalogue, null, null, tableTypes) +} + +private fun formatBytes(bytes: Long): String = when { + bytes < 1024 -> "$bytes bytes" + bytes < 1024 * 1024 -> "%.1f KB".format(bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> "%.1f MB".format(bytes / (1024.0 * 1024.0)) + else -> "%.2f GB".format(bytes / (1024.0 * 1024.0 * 1024.0)) +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt new file mode 100644 index 0000000000..95262dcac1 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimate.kt @@ -0,0 +1,52 @@ +package org.jetbrains.kotlinx.dataframe.io + +/** + * Result of memory estimation for a database query. 
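+ *
+ * A usage sketch (the `estimate` value would typically arrive via
+ * [JdbcSafeDataLoading.LoadConfig.onEstimate]):
+ * ```kotlin
+ * if (estimate.exceedsGb(1.0)) {
+ *     val limit = estimate.recommendedLimit(1024L * 1024 * 1024)
+ *     println("$estimate is too large; retry with limit = $limit")
+ * }
+ * ```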
+ */ +public data class MemoryEstimate( + /** Estimated number of rows */ + val estimatedRows: Long, + + /** Estimated bytes per row (including DataFrame overhead) */ + val bytesPerRow: Long, + + /** Total estimated memory in bytes */ + val totalBytes: Long, + + /** Human-readable size */ + val humanReadable: String, + + /** Whether this is an exact count or estimate */ + val isExact: Boolean, +) { + /** Total estimated memory in megabytes */ + val megabytes: Double get() = totalBytes / (1024.0 * 1024.0) + + /** Total estimated memory in gigabytes */ + val gigabytes: Double get() = totalBytes / (1024.0 * 1024.0 * 1024.0) + + /** + * Returns true if estimated memory exceeds the given threshold. + */ + public fun exceeds(thresholdBytes: Long): Boolean = totalBytes > thresholdBytes + + /** + * Returns true if estimated memory exceeds the given threshold in gigabytes. + */ + public fun exceedsGb(thresholdGb: Double): Boolean = gigabytes > thresholdGb + + /** + * Calculates recommended limit to stay under the given memory threshold. + */ + public fun recommendedLimit(maxBytes: Long): Int { + if (bytesPerRow == 0L) return Int.MAX_VALUE + val recommendedRows = maxBytes / bytesPerRow + return recommendedRows.coerceAtMost(Int.MAX_VALUE.toLong()).toInt() + } + + override fun toString(): String = buildString { + append("Memory Estimate: $humanReadable") + append(" (~$estimatedRows rows × $bytesPerRow bytes/row)") + if (!isExact) append(" [approximate]") + } +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt new file mode 100644 index 0000000000..6ee67c91a4 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/MemoryEstimator.kt @@ -0,0 +1,208 @@ +package org.jetbrains.kotlinx.dataframe.io + +import io.github.oshai.kotlinlogging.KotlinLogging +import org.jetbrains.kotlinx.dataframe.io.db.DbType +import org.jetbrains.kotlinx.dataframe.io.db.TableColumnMetadata +import java.sql.Connection +import java.sql.ResultSet + +private val logger = KotlinLogging.logger {} + +/** + * Estimates memory usage for database queries WITHOUT loading all data. + * + * Strategy: + * 1. Load 10 sample rows to calculate average row size + * 2. Use COUNT(*) or database statistics for total row count + * 3. Multiply to get total memory estimate + */ +internal object MemoryEstimator { + + private const val SAMPLE_SIZE = 10 + private const val COUNT_TIMEOUT_SECONDS = 5 + + /** + * Estimates memory for a table. + */ + fun estimateTable( + connection: Connection, + tableName: String, + dbType: DbType, + limit: Int?, + ): MemoryEstimate { + val bytesPerRow = estimateAverageRowSize(connection, tableName, dbType) + val totalRows = limit?.toLong() ?: estimateTableRowCount(connection, tableName, dbType) + val totalBytes = bytesPerRow * totalRows + + return MemoryEstimate( + estimatedRows = totalRows, + bytesPerRow = bytesPerRow, + totalBytes = totalBytes, + humanReadable = formatBytes(totalBytes), + isExact = limit != null, + ) + } + + /** + * Estimates memory for a custom SQL query. 
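+     *
+     * Samples up to [SAMPLE_SIZE] rows via [DbType.buildSqlQueryWithLimit] to measure an
+     * average row size, then scales it by the row count from a COUNT(*) over the query.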
+     */
+    fun estimateQuery(
+        connection: Connection,
+        sqlQuery: String,
+        dbType: DbType,
+        limit: Int?,
+    ): MemoryEstimate {
+        val sampleQuery = dbType.buildSqlQueryWithLimit(sqlQuery, SAMPLE_SIZE)
+
+        var totalSize = 0L
+        var rowCount = 0
+
+        connection.prepareStatement(sampleQuery).use { stmt ->
+            stmt.executeQuery().use { rs ->
+                val tableColumns = getTableColumnsMetadata(rs)
+
+                while (rs.next()) {
+                    totalSize += calculateRowSizeFromResultSet(rs, tableColumns)
+                    rowCount++
+                }
+
+                if (rowCount == 0) {
+                    return MemoryEstimate(0, 0, 0, "0 bytes", true)
+                }
+
+                val bytesPerRow = totalSize / rowCount
+                val totalRows = limit?.toLong() ?: estimateQueryRowCount(connection, sqlQuery)
+                val totalBytes = bytesPerRow * totalRows
+
+                return MemoryEstimate(
+                    estimatedRows = totalRows,
+                    bytesPerRow = bytesPerRow,
+                    totalBytes = totalBytes,
+                    humanReadable = formatBytes(totalBytes),
+                    isExact = limit != null,
+                )
+            }
+        }
+    }
+
+    /**
+     * Calculates average row size from sample rows.
+     */
+    private fun estimateAverageRowSize(
+        connection: Connection,
+        tableName: String,
+        dbType: DbType,
+    ): Long {
+        val sampleQuery = dbType.buildSelectTableQueryWithLimit(tableName, SAMPLE_SIZE)
+
+        var totalSize = 0L
+        var rowCount = 0
+
+        connection.prepareStatement(sampleQuery).use { stmt ->
+            stmt.executeQuery().use { rs ->
+                val tableColumns = getTableColumnsMetadata(rs)
+
+                while (rs.next()) {
+                    totalSize += calculateRowSizeFromResultSet(rs, tableColumns)
+                    rowCount++
+                }
+            }
+        }
+
+        return if (rowCount > 0) totalSize / rowCount else 64L
+    }
+
+    /**
+     * Calculates the size of a single row from a [ResultSet].
+     */
+    private fun calculateRowSizeFromResultSet(
+        rs: ResultSet,
+        tableColumns: List<TableColumnMetadata>,
+    ): Long {
+        var size = 16L // DataRow overhead
+
+        for (i in tableColumns.indices) {
+            val value = rs.getObject(i + 1)
+            val metadata = tableColumns[i]
+            size += estimateValueSize(value, metadata)
+        }
+
+        size += tableColumns.size * 64L // DataFrame column overhead
+        return size
+    }
+
+    /**
+     * Estimates value size based on the actual value or, for nulls, on column metadata.
+     */
+    private fun estimateValueSize(value: Any?, metadata: TableColumnMetadata): Long {
+        if (value != null) {
+            return when (value) {
+                is Boolean, is Byte -> 16L
+                is Short -> 16L
+                is Int -> 16L
+                is Long, is Double -> 24L
+                is Float -> 16L
+                is String -> 40L + value.length * 2L
+                is ByteArray -> 16L + value.size
+                else -> 48L
+            }
+        }
+
+        // Fallback to metadata
+        return when (metadata.jdbcType) {
+            java.sql.Types.BIT, java.sql.Types.BOOLEAN, java.sql.Types.TINYINT -> 16L
+            java.sql.Types.SMALLINT -> 16L
+            java.sql.Types.INTEGER -> 16L
+            java.sql.Types.BIGINT, java.sql.Types.DOUBLE -> 24L
+            java.sql.Types.REAL, java.sql.Types.FLOAT -> 16L
+            java.sql.Types.CHAR, java.sql.Types.VARCHAR -> 40L + metadata.size.coerceAtMost(100) * 2L
+            else -> 48L
+        }
+    }
+
+    /**
+     * Estimates total row count using COUNT(*) with a timeout.
+     */
+    private fun estimateTableRowCount(
+        connection: Connection,
+        tableName: String,
+        dbType: DbType,
+    ): Long {
+        return try {
+            connection.prepareStatement("SELECT COUNT(*) FROM $tableName").use { stmt ->
+                stmt.queryTimeout = COUNT_TIMEOUT_SECONDS
+                stmt.executeQuery().use { rs ->
+                    if (rs.next()) rs.getLong(1) else 0L
+                }
+            }
+        } catch (e: Exception) {
+            logger.warn(e) { "COUNT(*) failed for $tableName, using fallback" }
+            1000L // Conservative fallback
+        }
+    }
+
+    /**
+     * Estimates query row count by wrapping the query in COUNT(*).
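+     * For example, `SELECT * FROM sales WHERE amount > 10` (the `sales` table is illustrative)
+     * is counted as `SELECT COUNT(*) FROM (SELECT * FROM sales WHERE amount > 10) AS temp_count`.
+     * Some databases reject this wrapping for queries that cannot appear in a derived table,
+     * which is why the fallback below exists.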
+ */ + private fun estimateQueryRowCount(connection: Connection, sqlQuery: String): Long { + return try { + val countQuery = "SELECT COUNT(*) FROM ($sqlQuery) AS temp_count" + connection.prepareStatement(countQuery).use { stmt -> + stmt.queryTimeout = COUNT_TIMEOUT_SECONDS + stmt.executeQuery().use { rs -> + if (rs.next()) rs.getLong(1) else 0L + } + } + } catch (e: Exception) { + logger.warn(e) { "Cannot wrap query with COUNT(*), using sample-based estimate" } + 1000L // Conservative fallback + } + } +} + +private fun formatBytes(bytes: Long): String = when { + bytes < 1024 -> "$bytes bytes" + bytes < 1024 * 1024 -> "%.1f KB".format(bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> "%.1f MB".format(bytes / (1024.0 * 1024.0)) + else -> "%.2f GB".format(bytes / (1024.0 * 1024.0 * 1024.0)) +} diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt new file mode 100644 index 0000000000..df7c0ee593 --- /dev/null +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/progressTracker.kt @@ -0,0 +1,301 @@ +package org.jetbrains.kotlinx.dataframe.io + +import java.sql.Date +import java.sql.Time +import java.sql.Types +import io.github.oshai.kotlinlogging.KLogger +import org.jetbrains.kotlinx.dataframe.io.db.TableColumnMetadata +import java.math.BigDecimal +import java.math.BigInteger +import java.sql.Blob +import java.sql.Clob +import java.sql.ResultSet +import java.sql.Timestamp +import java.time.Instant +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import java.time.OffsetDateTime +import java.time.ZonedDateTime +import java.util.UUID + +/** + * Strategy interface for tracking progress during data loading. + */ +internal interface ProgressTracker { + fun onStart() + fun onRowLoaded() + fun onComplete(rowCount: Int) + + companion object { + fun create(logger: KLogger): ProgressTracker { + return when { + JdbcConfig.PROGRESS_ENABLED -> { + if (JdbcConfig.PROGRESS_DETAILED) { + DetailedProgressTracker(logger) + } else { + SimpleProgressTracker(logger) + } + } + logger.isDebugEnabled -> { + DetailedProgressTracker(logger) + } + else -> NoOpProgressTracker + } + } + } +} + +/** + * No-op implementation - zero overhead when disabled. + */ +private object NoOpProgressTracker : ProgressTracker { + override fun onStart() {} + override fun onRowLoaded() {} + override fun onComplete(rowCount: Int) {} +} + +/** + * Simple progress tracker - shows only basic information. + */ +private class SimpleProgressTracker(private val logger: KLogger) : ProgressTracker { + private var rowsLoaded = 0 + + override fun onStart() {} + + override fun onRowLoaded() { + rowsLoaded++ + if (rowsLoaded % JdbcConfig.PROGRESS_INTERVAL == 0) { + logger.debug { "Loaded $rowsLoaded rows" } + } + } + + override fun onComplete(rowCount: Int) { + if (rowCount > 0) { + logger.debug { "Loading complete: $rowCount rows" } + } + } +} + +/** + * Detailed progress tracker with statistics and memory estimation. + */ +internal class DetailedProgressTracker(private val logger: KLogger) : ProgressTracker { + private val startTime = System.currentTimeMillis() + private var rowsLoaded = 0 + private var firstRowSize: Long? = null + private var estimatedTotalRows: Long? = null + private var estimatedMemoryBytes: Long? 
= null
+    private var memoryWarningShown = false
+
+    override fun onStart() {}
+
+    override fun onRowLoaded() {
+        rowsLoaded++
+
+        if (rowsLoaded % JdbcConfig.PROGRESS_INTERVAL == 0) {
+            logDetailedProgress()
+        }
+    }
+
+    override fun onComplete(rowCount: Int) {
+        if (rowCount == 0) return
+
+        val totalMillis = System.currentTimeMillis() - startTime
+        val rowsPerSecond = if (totalMillis > 0) {
+            (rowCount.toDouble() / totalMillis * 1000).toInt()
+        } else {
+            0
+        }
+
+        val memoryInfo = estimatedMemoryBytes?.let {
+            val actualMemory = firstRowSize?.let { size -> size * rowCount } ?: it
+            " using ${formatBytes(actualMemory)}"
+        } ?: ""
+
+        if (totalMillis > 0) {
+            logger.debug {
+                "Loading complete: $rowCount rows in ${totalMillis}ms (~$rowsPerSecond rows/sec)$memoryInfo"
+            }
+        } else {
+            logger.debug { "Loading complete: $rowCount rows$memoryInfo" }
+        }
+    }
+
+    /**
+     * Estimates memory on the first row and logs warnings for large datasets.
+     */
+    fun estimateMemoryOnFirstRow(
+        columnData: List<MutableList<Any?>>,
+        tableColumns: List<TableColumnMetadata>,
+        rs: ResultSet,
+    ) {
+        if (memoryWarningShown) return
+
+        try {
+            firstRowSize = calculateRowSize(columnData, tableColumns)
+            estimatedTotalRows = estimateTotalRows(rs)
+
+            if (firstRowSize != null && estimatedTotalRows != null) {
+                estimatedMemoryBytes = firstRowSize!! * estimatedTotalRows!!
+
+                if (estimatedMemoryBytes!! > 100 * 1024 * 1024) {
+                    logger.debug {
+                        "Estimated memory: ${formatBytes(estimatedMemoryBytes!!)} for ~$estimatedTotalRows rows"
+                    }
+
+                    if (estimatedMemoryBytes!! > 1024L * 1024L * 1024L) {
+                        logger.warn {
+                            "Large dataset detected (${formatBytes(estimatedMemoryBytes!!)}). " +
+                                "Consider using the 'limit' parameter."
+                        }
+                        memoryWarningShown = true
+                    }
+                }
+            }
+        } catch (e: Exception) {
+            logger.debug(e) { "Failed to estimate memory" }
+        }
+    }
+
+    private fun logDetailedProgress() {
+        val elapsedMillis = System.currentTimeMillis() - startTime
+        if (elapsedMillis == 0L) return
+
+        val rowsPerSecond = (rowsLoaded.toDouble() / elapsedMillis * 1000).toInt()
+
+        val progressPercent = estimatedTotalRows?.let {
+            " (${(rowsLoaded * 100.0 / it).toInt()}%)"
+        } ?: ""
+
+        logger.debug {
+            "Loaded $rowsLoaded rows$progressPercent in ${elapsedMillis}ms (~$rowsPerSecond rows/sec)"
+        }
+    }
+
+    private fun calculateRowSize(
+        columnData: List<MutableList<Any?>>,
+        tableColumns: List<TableColumnMetadata>,
+    ): Long {
+        var size = 16L // DataRow overhead
+
+        columnData.forEachIndexed { index, values ->
+            val value = values.firstOrNull()
+            val columnMetadata = tableColumns.getOrNull(index)
+            size += estimateValueSize(value, columnMetadata)
+        }
+
+        size += columnData.size * 64L // DataFrame column overhead
+        return size
+    }
+
+    private fun estimateValueSize(value: Any?, metadata: TableColumnMetadata?): Long {
+        if (value != null) {
+            return estimateActualValueSize(value)
+        }
+
+        if (metadata != null) {
+            return estimateValueSizeFromMetadata(metadata)
+        }
+
+        return 8L
+    }
+
+    private fun estimateActualValueSize(value: Any): Long {
+        return when (value) {
+            is Boolean, is Byte -> 16L
+            is Short -> 16L
+            is Int -> 16L
+            is Long, is Double -> 24L
+            is Float -> 16L
+            is Char -> 16L
+            is String -> 40L + value.length * 2L
+            is ByteArray -> 16L + value.size
+            is CharArray -> 16L + value.size * 2L
+            is BigInteger -> 32L + value.bitLength() / 8
+            is BigDecimal -> {
+                val unscaledValue = value.unscaledValue()
+                48L + unscaledValue.bitLength() / 8
+            }
+
+            is LocalDate -> 24L
+            is LocalTime -> 24L
+            is LocalDateTime -> 48L
+            is Instant -> 24L
+            is OffsetDateTime -> 56L
+            is ZonedDateTime -> 64L
+            is Timestamp -> 32L
+            is Date -> 32L
+            is Time -> 32L
+            is UUID -> 32L
+            is Blob -> 48L + estimateBlobSize(value)
+            is Clob -> 48L + estimateClobSize(value)
+            else -> 48L
+        }
+    }
+
+    private fun estimateValueSizeFromMetadata(metadata: TableColumnMetadata): Long {
+        return when (metadata.jdbcType) {
+            Types.BIT, Types.BOOLEAN -> 16L
+            Types.TINYINT -> 16L
+            Types.SMALLINT -> 16L
+            Types.INTEGER -> 16L
+            Types.BIGINT -> 24L
+            Types.REAL, Types.FLOAT -> 16L
+            Types.DOUBLE -> 24L
+            Types.NUMERIC, Types.DECIMAL -> 48L + metadata.size / 4
+            Types.CHAR, Types.VARCHAR, Types.LONGVARCHAR ->
+                40L + metadata.size.coerceAtMost(1000) * 2L
+
+            Types.DATE -> 24L
+            Types.TIME -> 24L
+            Types.TIMESTAMP -> 32L
+            Types.TIMESTAMP_WITH_TIMEZONE -> 56L
+            Types.BINARY, Types.VARBINARY, Types.LONGVARBINARY ->
+                16L + metadata.size.coerceAtMost(1000)
+
+            Types.BLOB -> 64L
+            Types.CLOB -> 64L
+            else -> 48L
+        }
+    }
+
+    private fun estimateBlobSize(blob: Blob): Long {
+        return try {
+            blob.length().coerceAtMost(10L * 1024 * 1024)
+        } catch (e: Exception) {
+            1024L
+        }
+    }
+
+    private fun estimateClobSize(clob: Clob): Long {
+        return try {
+            clob.length() * 2L
+        } catch (e: Exception) {
+            2048L
+        }
+    }
+
+    private fun estimateTotalRows(rs: ResultSet): Long? {
+        return try {
+            if (rs.type != ResultSet.TYPE_FORWARD_ONLY) {
+                val currentRow = rs.row
+                rs.last()
+                val total = rs.row.toLong()
+                rs.absolute(currentRow)
+                total
+            } else {
+                null
+            }
+        } catch (e: Exception) {
+            null
+        }
+    }
+}
+
+private fun formatBytes(bytes: Long): String = when {
+    bytes < 1024 -> "$bytes bytes"
+    bytes < 1024 * 1024 -> "%.1f KB".format(bytes / 1024.0)
+    bytes < 1024 * 1024 * 1024 -> "%.1f MB".format(bytes / (1024.0 * 1024.0))
+    else -> "%.2f GB".format(bytes / (1024.0 * 1024.0 * 1024.0))
+}
diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt
index a2594ce131..44dbd110b3 100644
--- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt
+++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt
@@ -937,9 +937,16 @@ internal fun fetchAndConvertDataFromResultSet(
     inferNullability: Boolean,
 ): AnyFrame {
     val columnKTypes = buildColumnKTypes(tableColumns, dbType)
-    val columnData = readAllRowsFromResultSet(rs, tableColumns, columnKTypes, dbType, limit)
+
+    val progressTracker = ProgressTracker.create(logger)
+
+    progressTracker.onStart()
+
+    val columnData = readAllRowsFromResultSet(rs, tableColumns, columnKTypes, dbType, limit, progressTracker)
     val dataFrame = buildDataFrameFromColumnData(columnData, tableColumns, columnKTypes, dbType, inferNullability)
 
+    progressTracker.onComplete(dataFrame.rowsCount())
+
     logger.debug {
         "DataFrame with ${dataFrame.rowsCount()} rows and ${dataFrame.columnsCount()} columns created as a result of SQL query."
     }
@@ -965,6 +972,7 @@ private fun readAllRowsFromResultSet(
     columnKTypes: Map<String, KType>,
     dbType: DbType,
     limit: Int?,
+    progressTracker: ProgressTracker,
 ): List<MutableList<Any?>> {
     val columnsCount = tableColumns.size
     val columnData = List(columnsCount) { mutableListOf<Any?>() }
@@ -981,12 +989,29 @@
             columnData[columnIndex].add(value)
         }
         rowsRead++
-        // if (rowsRead % 1000 == 0) logger.debug { "Loaded $rowsRead rows." } // TODO: https://github.com/Kotlin/dataframe/issues/455
+
+        // Progress tracking
+        trackProgress(rowsRead, columnData, tableColumns, rs, progressTracker)
     }
 
     return columnData
 }
 
+private fun trackProgress(
+    rowsRead: Int,
+    columnData: List<MutableList<Any?>>,
+    tableColumns: List<TableColumnMetadata>,
+    rs: ResultSet,
+    progressTracker: ProgressTracker,
+) {
+    // Memory estimation on first row (only for DetailedProgressTracker)
+    if (rowsRead == 1 && progressTracker is DetailedProgressTracker) {
+        progressTracker.estimateMemoryOnFirstRow(columnData, tableColumns, rs)
+    }
+
+    // Progress update
+    progressTracker.onRowLoaded()
+}
+
 /**
  * Builds DataFrame from column-oriented data.
  * Accepts mutable lists to enable efficient in-place transformations.
diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt
new file mode 100644
index 0000000000..02303e6f2c
--- /dev/null
+++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressOutputTest.kt
@@ -0,0 +1,257 @@
+package org.jetbrains.kotlinx.dataframe.io.h2.tracking
+
+import io.kotest.matchers.shouldBe
+import org.jetbrains.kotlinx.dataframe.DataFrame
+import org.jetbrains.kotlinx.dataframe.io.JdbcSafeDataLoading
+import org.jetbrains.kotlinx.dataframe.io.readSqlTable
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import java.io.ByteArrayOutputStream
+import java.io.PrintStream
+import java.sql.Connection
+import java.sql.DriverManager
+
+/**
+ * Tests for verifying log output of progress tracking and safe loading.
+ *
+ * Note: These tests capture System.err output instead of using Logback-specific
+ * log appenders to remain compatible with any SLF4J implementation.
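+ *
+ * With `dataframe.jdbc.progress=true` and the default interval, the captured stream is
+ * expected to contain lines like `Loaded 1000 rows` and a final `Loading complete: ...` entry.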
+ */ +class JdbcProgressOutputTest { + + private lateinit var connection: Connection + private lateinit var originalErr: PrintStream + private lateinit var capturedErr: ByteArrayOutputStream + + @Before + fun setUp() { + connection = DriverManager.getConnection("jdbc:h2:mem:test_output;DB_CLOSE_DELAY=-1;MODE=MySQL") + + // Setup System.err capture for SLF4J Simple logging + originalErr = System.err + capturedErr = ByteArrayOutputStream() + System.setErr(PrintStream(capturedErr)) + + // Create test table + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE test_progress ( + id INT PRIMARY KEY, + data VARCHAR(100) + ) + """.trimIndent() + ) + + // Insert 3000 rows + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO test_progress VALUES (?, ?)") + for (i in 1..3000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "Data_$i") + insertStmt.addBatch() + + if (i % 500 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + } + + @After + fun tearDown() { + System.setErr(originalErr) + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_progress") } + connection.close() + System.clearProperty("dataframe.jdbc.progress") + System.clearProperty("dataframe.jdbc.progress.detailed") + System.clearProperty("dataframe.jdbc.progress.interval") + } + + private fun getLogOutput(): String = capturedErr.toString("UTF-8") + + private fun clearLogOutput() { + capturedErr.reset() + } + + @Test + fun `progress tracking logs debug messages`() { + System.setProperty("dataframe.jdbc.progress", "true") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have progress messages + val hasProgressMessages = logOutput.contains("Loaded") && logOutput.contains("rows") + hasProgressMessages shouldBe true + } + + @Test + fun `detailed progress shows statistics`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.detailed", "true") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should contain statistics + val hasStatistics = logOutput.contains("rows/sec") || logOutput.contains("ms") + hasStatistics shouldBe true + } + + @Test + fun `simple progress shows basic messages`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.detailed", "false") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have messages but without detailed stats + val hasBasicMessages = logOutput.contains("Loaded") && logOutput.contains("rows") + hasBasicMessages shouldBe true + + // Should NOT contain detailed statistics + val hasDetailedStats = logOutput.contains("rows/sec") + hasDetailedStats shouldBe false + } + + @Test + fun `progress interval is respected`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.interval", "1000") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have messages at 1000, 2000, 3000 + val has1000 = logOutput.contains("1000 rows") + val has2000 = logOutput.contains("2000 rows") + has1000 shouldBe true + has2000 shouldBe true + } + + @Test + fun `completion message is logged`() { + 
System.setProperty("dataframe.jdbc.progress", "true") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should have completion message + val hasCompletion = logOutput.contains("Loading complete") || logOutput.contains("3000 rows") + hasCompletion shouldBe true + } + + @Test + fun `memory warning is logged for large datasets`() { + System.setProperty("dataframe.jdbc.progress", "true") + + // Create a larger table + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE large_test ( + id INT PRIMARY KEY, + data VARCHAR(10000) + ) + """.trimIndent() + ) + + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO large_test VALUES (?, ?)") + for (i in 1..1000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "x".repeat(10000)) // 10KB per row = 10MB total + insertStmt.addBatch() + + if (i % 100 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + + clearLogOutput() + + DataFrame.readSqlTable(connection, "large_test") + + val logOutput = getLogOutput() + + // Should have memory estimate message (if logged at DEBUG/INFO level) + // This is optional as it depends on logging configuration + + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS large_test") } + } + + @Test + fun `safe loading logs warnings when limit applied`() { + clearLogOutput() + + JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.00001 // Very small limit to trigger warning + onExceed = JdbcSafeDataLoading.ExceedAction.APPLY_LIMIT + } + ) { + readSqlTable(connection, "test_progress") + } + + val logOutput = getLogOutput() + + // Should have warning about exceeding limit + val hasWarning = logOutput.contains("exceeds limit") || logOutput.contains("WARN") + hasWarning shouldBe true + } + + @Test + fun `safe loading logs warning when proceeding despite limit`() { + clearLogOutput() + + JdbcSafeDataLoading.load( + configure = { + maxMemoryGb = 0.00001 + onExceed = JdbcSafeDataLoading.ExceedAction.WARN_AND_PROCEED + } + ) { + readSqlTable(connection, "test_progress") + } + + val logOutput = getLogOutput() + + // Should have warning about proceeding + val hasWarning = (logOutput.contains("exceeds limit") && logOutput.contains("proceeding")) || + logOutput.contains("WARN") + hasWarning shouldBe true + } + + @Test + fun `no logs when progress disabled`() { + System.clearProperty("dataframe.jdbc.progress") + clearLogOutput() + + DataFrame.readSqlTable(connection, "test_progress") + + val logOutput = getLogOutput() + + // Should NOT have progress messages (only connection/query logs possibly) + val hasProgressMessages = logOutput.matches(Regex(".*Loaded \\d+ rows.*")) + hasProgressMessages shouldBe false + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt new file mode 100644 index 0000000000..5c1e37782a --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcProgressTest.kt @@ -0,0 +1,137 @@ + +package org.jetbrains.kotlinx.dataframe.io.h2.tracking + +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.io.readSqlQuery +import org.jetbrains.kotlinx.dataframe.io.readSqlTable +import org.junit.After +import 
org.junit.Before +import org.junit.Test +import java.sql.Connection +import java.sql.DriverManager + +/** + * Tests for JDBC progress tracking functionality. + */ +class JdbcProgressTest { + + private lateinit var connection: Connection + + @Before + fun setUp() { + connection = DriverManager.getConnection("jdbc:h2:mem:test_progress;DB_CLOSE_DELAY=-1;MODE=MySQL") + + connection.createStatement().use { stmt -> + stmt.executeUpdate( + """ + CREATE TABLE test_table ( + id INT PRIMARY KEY, + name VARCHAR(100), + amount INT + ) + """.trimIndent() + ) + + // Insert 5000 rows for testing + connection.autoCommit = false + val insertStmt = connection.prepareStatement("INSERT INTO test_table VALUES (?, ?, ?)") + for (i in 1..5000) { + insertStmt.setInt(1, i) + insertStmt.setString(2, "Name_$i") + insertStmt.setInt(3, i * 10) + insertStmt.addBatch() + + if (i % 500 == 0) { + insertStmt.executeBatch() + } + } + insertStmt.executeBatch() + connection.commit() + connection.autoCommit = true + } + } + + @After + fun tearDown() { + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_table") } + connection.close() + } + + @Test + fun `progress tracking disabled by default`() { + // Progress should be disabled by default + System.clearProperty("dataframe.jdbc.progress") + + val df = DataFrame.readSqlTable(connection, "test_table") + + df.rowsCount() shouldBe 5000 + } + + @Test + fun `progress tracking can be enabled via system property`() { + System.setProperty("dataframe.jdbc.progress", "true") + + try { + val df = DataFrame.readSqlTable(connection, "test_table") + df.rowsCount() shouldBe 5000 + } finally { + System.clearProperty("dataframe.jdbc.progress") + } + } + + @Test + fun `progress tracking with limit`() { + System.setProperty("dataframe.jdbc.progress", "true") + + try { + val df = DataFrame.readSqlTable(connection, "test_table", limit = 100) + df.rowsCount() shouldBe 100 + } finally { + System.clearProperty("dataframe.jdbc.progress") + } + } + + @Test + fun `detailed progress can be disabled`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.detailed", "false") + + try { + val df = DataFrame.readSqlTable(connection, "test_table") + df.rowsCount() shouldBe 5000 + } finally { + System.clearProperty("dataframe.jdbc.progress") + System.clearProperty("dataframe.jdbc.progress.detailed") + } + } + + @Test + fun `progress interval can be configured`() { + System.setProperty("dataframe.jdbc.progress", "true") + System.setProperty("dataframe.jdbc.progress.interval", "500") + + try { + val df = DataFrame.readSqlTable(connection, "test_table") + df.rowsCount() shouldBe 5000 + } finally { + System.clearProperty("dataframe.jdbc.progress") + System.clearProperty("dataframe.jdbc.progress.interval") + } + } + + @Test + fun `progress works with query`() { + System.setProperty("dataframe.jdbc.progress", "true") + + try { + val df = DataFrame.readSqlQuery( + connection, + "SELECT * FROM test_table WHERE amount > 1000" + ) + df.rowsCount() shouldBe 4900 + } finally { + System.clearProperty("dataframe.jdbc.progress") + } + } +} diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt new file mode 100644 index 0000000000..33fe97d4c4 --- /dev/null +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2/tracking/JdbcSafeLoadingTest.kt @@ -0,0 +1,177 @@ +package 
org.jetbrains.kotlinx.dataframe.io.h2.tracking
+
+import io.kotest.assertions.throwables.shouldThrow
+import io.kotest.matchers.ints.shouldBeLessThan
+import io.kotest.matchers.shouldBe
+import org.jetbrains.kotlinx.dataframe.io.DbConnectionConfig
+import org.jetbrains.kotlinx.dataframe.io.JdbcSafeDataLoading
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import java.sql.Connection
+import java.sql.DriverManager
+
+/**
+ * Tests for JdbcSafeDataLoading functionality.
+ */
+class JdbcSafeLoadingTest {
+
+    private lateinit var connection: Connection
+    private val dbConfig = DbConnectionConfig("jdbc:h2:mem:test_safe;DB_CLOSE_DELAY=-1;MODE=MySQL")
+
+    @Before
+    fun setUp() {
+        connection = DriverManager.getConnection(dbConfig.url)
+
+        connection.createStatement().use { stmt ->
+            stmt.executeUpdate(
+                """
+                CREATE TABLE test_data (
+                    id INT PRIMARY KEY,
+                    data VARCHAR(1000)
+                )
+                """.trimIndent()
+            )
+
+            // Insert 10000 rows (~1KB each, ~10MB total)
+            connection.autoCommit = false
+            val insertStmt = connection.prepareStatement("INSERT INTO test_data VALUES (?, ?)")
+            for (i in 1..10000) {
+                insertStmt.setInt(1, i)
+                insertStmt.setString(2, "x".repeat(1000)) // 1KB per row
+                insertStmt.addBatch()
+
+                if (i % 1000 == 0) {
+                    insertStmt.executeBatch()
+                }
+            }
+            insertStmt.executeBatch()
+            connection.commit()
+            connection.autoCommit = true
+        }
+    }
+
+    @After
+    fun tearDown() {
+        connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_data") }
+        connection.close()
+    }
+
+    @Test
+    fun `safe load with sufficient memory limit`() {
+        val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) {
+            readSqlTable(connection, "test_data")
+        }
+
+        df.rowsCount() shouldBe 10000
+    }
+
+    @Test
+    fun `safe load applies automatic limit when memory exceeded`() {
+        var limitApplied = false
+        var appliedLimitValue = 0
+
+        val df = JdbcSafeDataLoading.load(
+            configure = {
+                maxMemoryGb = 0.001 // Very small limit
+                onExceed = JdbcSafeDataLoading.ExceedAction.APPLY_LIMIT
+
+                onLimitApplied = { _, limit ->
+                    limitApplied = true
+                    appliedLimitValue = limit
+                }
+            }
+        ) {
+            readSqlTable(connection, "test_data")
+        }
+
+        limitApplied shouldBe true
+        appliedLimitValue shouldBeLessThan 10000
+        df.rowsCount() shouldBe appliedLimitValue
+    }
+
+    @Test
+    fun `safe load throws when configured`() {
+        shouldThrow<JdbcSafeDataLoading.MemoryLimitExceededException> {
+            JdbcSafeDataLoading.load(
+                configure = {
+                    maxMemoryGb = 0.001
+                    onExceed = JdbcSafeDataLoading.ExceedAction.THROW
+                }
+            ) {
+                readSqlTable(connection, "test_data")
+            }
+        }
+    }
+
+    @Test
+    fun `safe load warns and proceeds when configured`() {
+        val df = JdbcSafeDataLoading.load(
+            configure = {
+                maxMemoryGb = 0.001
+                onExceed = JdbcSafeDataLoading.ExceedAction.WARN_AND_PROCEED
+            }
+        ) {
+            readSqlTable(connection, "test_data")
+        }
+
+        df.rowsCount() shouldBe 10000
+    }
+
+    @Test
+    fun `onEstimate callback is invoked`() {
+        var estimateCalled = false
+        var estimatedRows = 0L
+
+        JdbcSafeDataLoading.load(
+            configure = {
+                maxMemoryGb = 1.0
+
+                onEstimate = { estimate ->
+                    estimateCalled = true
+                    estimatedRows = estimate.estimatedRows
+                }
+            }
+        ) {
+            readSqlTable(connection, "test_data")
+        }
+
+        estimateCalled shouldBe true
+        estimatedRows shouldBe 10000L
+    }
+
+    @Test
+    fun `safe load works with DbConnectionConfig`() {
+        val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) {
+            readSqlTable(dbConfig, "test_data")
+        }
+
+        df.rowsCount() shouldBe 10000
+    }
+
+    @Test
+    fun `safe load works with query`() {
+        val df = JdbcSafeDataLoading.load(maxMemoryGb = 1.0) {
readSqlQuery(connection, "SELECT * FROM test_data WHERE id <= 100") + } + + df.rowsCount() shouldBe 100 + } + + @Test + fun `loadMultiple works with readAllSqlTables`() { + // Create another table + connection.createStatement().use { stmt -> + stmt.executeUpdate("CREATE TABLE test_data2 (id INT PRIMARY KEY, amount INT)") + stmt.executeUpdate("INSERT INTO test_data2 VALUES (1, 10), (2, 20)") + } + + val tables = JdbcSafeDataLoading.loadMultiple(maxMemoryGb = 1.0) { + readAllSqlTables(connection) + } + + tables.keys shouldBe setOf("TEST_DATA", "TEST_DATA2") + + connection.createStatement().use { it.executeUpdate("DROP TABLE IF EXISTS test_data2") } + } +}