Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ class JDBCRDD(
logInfo(log"Generated JDBC query to fetch data: ${MDC(SQL_TEXT, sqlText)}")
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
stmt.setFetchSize(dialect.getFetchSize(options))
stmt.setQueryTimeout(options.queryTimeout)

rs = SQLMetrics.withTimingNs(queryExecutionTimeMetric) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

package org.apache.spark.sql.jdbc

import java.sql.SQLException
import java.sql.{Connection, SQLException}

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.types.{DataType, MetadataBuilder}

/**
Expand Down Expand Up @@ -83,4 +84,12 @@ private class AggregatedDialect(dialects: List[JdbcDialect])
cascade: Option[Boolean] = isCascadingTruncateTable()): String = {
dialects.head.getTruncateQuery(table, cascade)
}

/**
 * Resolves the effective fetch size by delegating to the primary (head) dialect,
 * consistent with how this aggregated dialect resolves its other behavior.
 */
override def getFetchSize(options: JDBCOptions): Int =
  dialects.head.getFetchSize(options)

/**
 * Runs the primary (head) dialect's connection setup before a fetch is issued.
 */
override def beforeFetch(connection: Connection, options: JDBCOptions): Unit =
  dialects.head.beforeFetch(connection, options)
}
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,15 @@ abstract class JdbcDialect extends Serializable with Logging {
s"TRUNCATE TABLE $table"
}

/**
 * Returns the effective fetch size for reading from the JDBC source.
 * By default, returns the user-specified fetchSize from [[JDBCOptions]].
 * Dialects can override this to provide a sensible default when the user does not
 * explicitly set the fetchSize option.
 *
 * @param options the JDBC options for this read; `options.fetchSize` carries the
 *                user-supplied value (or the JDBCOptions default when unset).
 * @return the fetch size to pass to `java.sql.Statement.setFetchSize`.
 */
@Since("4.2.0")
def getFetchSize(options: JDBCOptions): Int = options.fetchSize

/**
* Override connection specific properties to run before a select is made. This is in place to
* allow dialects that need special treatment to optimize behavior.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,32 @@ private case class PostgresDialect()
}
}

// PostgreSQL JDBC driver fetches all rows into memory by default (fetchSize=0),
// which can cause executor OOM. Override to use 1000 as a sensible default when
// the user does not explicitly set the fetchSize option.
private val POSTGRES_DEFAULT_FETCH_SIZE = 1000

/**
 * Returns the user-specified fetch size when the option is present, otherwise
 * falls back to the PostgreSQL-specific default of 1000.
 */
override def getFetchSize(options: JDBCOptions): Int = {
  if (options.parameters.contains(JDBCOptions.JDBC_BATCH_FETCH_SIZE)) {
    // Reuse the value already parsed and validated by JDBCOptions instead of
    // re-parsing the raw string (which could bypass that validation).
    options.fetchSize
  } else {
    logInfo(s"No fetchSize option set for PostgreSQL JDBC read. " +
      s"Defaulting to $POSTGRES_DEFAULT_FETCH_SIZE to avoid loading all rows into memory. " +
      s"Set the 'fetchsize' option explicitly to override this behavior.")
    POSTGRES_DEFAULT_FETCH_SIZE
  }
}

/**
 * Connection setup run before a select: disables autocommit whenever a positive
 * fetch size will be used, so the driver streams rows with a cursor instead of
 * materializing the whole result set.
 */
override def beforeFetch(connection: Connection, options: JDBCOptions): Unit = {
  super.beforeFetch(connection, options)

  // According to the postgres jdbc documentation we need to be in autocommit=false if we actually
  // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the
  // rows inside the driver when fetching.
  //
  // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
  //
  // Use getFetchSize (not the raw option) so the dialect's default of 1000 also
  // triggers cursor-based fetching when the user did not set fetchsize.
  if (getFetchSize(options) > 0) {
    connection.setAutoCommit(false)
  }
}
Expand Down
54 changes: 54 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,60 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-56251: Dialect getFetchSize is applied when user does not specify fetchsize") {
  // Records the value the dialect resolved so the assertion below can inspect it.
  @volatile var capturedFetchSize: Int = -1

  val probeDialect = new JdbcDialect {
    override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
    override def getFetchSize(options: JDBCOptions): Int = {
      // Mirror a dialect-specific default: use 100 when the option is absent.
      val resolved = options.parameters
        .get(JDBCOptions.JDBC_BATCH_FETCH_SIZE)
        .map(_.toInt)
        .getOrElse(100)
      capturedFetchSize = resolved
      resolved
    }
  }

  JdbcDialects.registerDialect(probeDialect)
  try {
    val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties()).collect()
    assert(rows.length === 3)
    assert(capturedFetchSize === 100,
      s"Expected getFetchSize to return 100 (dialect default), got $capturedFetchSize")
  } finally {
    JdbcDialects.unregisterDialect(probeDialect)
  }
}

test("SPARK-56251: User-specified fetchsize takes precedence over dialect getFetchSize") {
  // Records the value the dialect resolved so the assertion below can inspect it.
  @volatile var capturedFetchSize: Int = -1

  val probeDialect = new JdbcDialect {
    override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
    override def getFetchSize(options: JDBCOptions): Int = {
      // The explicit option must win over the dialect's fallback of 100.
      val resolved = options.parameters
        .get(JDBCOptions.JDBC_BATCH_FETCH_SIZE)
        .map(_.toInt)
        .getOrElse(100)
      capturedFetchSize = resolved
      resolved
    }
  }

  JdbcDialects.registerDialect(probeDialect)
  try {
    val props = new Properties()
    props.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "42")
    val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", props).collect()
    assert(rows.length === 3)
    assert(capturedFetchSize === 42,
      s"Expected getFetchSize to return 42 (user-specified), got $capturedFetchSize")
  } finally {
    JdbcDialects.unregisterDialect(probeDialect)
  }
}

test("Partitioning via JDBCPartitioningInfo API") {
val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
checkNumPartitions(df, expectedNumPartitions = 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

class PostgresDialectSuite extends SparkFunSuite with MockitoSugar {

private val dialect = PostgresDialect()

private def createJDBCOptions(extraOptions: Map[String, String]): JDBCOptions = {
new JDBCOptions(Map(
"url" -> "jdbc:postgresql://localhost:5432/test",
Expand All @@ -37,29 +39,43 @@ class PostgresDialectSuite extends SparkFunSuite with MockitoSugar {

test("beforeFetch sets autoCommit=false with lowercase fetchsize") {
  val conn = mock[Connection]
  // A positive fetch size must switch the connection out of autocommit.
  // Uses the suite-level shared dialect; a local PostgresDialect() here would
  // redundantly shadow it.
  dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "100")))
  verify(conn).setAutoCommit(false)
}

test("beforeFetch sets autoCommit=false with camelCase fetchSize") {
  val conn = mock[Connection]
  // Option lookup is case-insensitive, so camelCase must behave the same.
  // Uses the suite-level shared dialect instead of shadowing it locally.
  dialect.beforeFetch(conn, createJDBCOptions(Map("fetchSize" -> "100")))
  verify(conn).setAutoCommit(false)
}

test("beforeFetch sets autoCommit=false with uppercase FETCHSIZE") {
  val conn = mock[Connection]
  // Option lookup is case-insensitive, so uppercase must behave the same.
  // Uses the suite-level shared dialect instead of shadowing it locally.
  dialect.beforeFetch(conn, createJDBCOptions(Map("FETCHSIZE" -> "100")))
  verify(conn).setAutoCommit(false)
}

test("beforeFetch does not set autoCommit when fetchSize is 0") {
  val conn = mock[Connection]
  // An explicit fetch size of 0 means "fetch everything", so autocommit stays on.
  // Uses the suite-level shared dialect instead of shadowing it locally.
  dialect.beforeFetch(conn, createJDBCOptions(Map("fetchsize" -> "0")))
  verify(conn, never()).setAutoCommit(false)
}

test("SPARK-56251: getFetchSize: returns 1000 when not set (Postgres default)") {
  // Without an explicit fetchsize option, the Postgres dialect falls back to 1000.
  val resolved = dialect.getFetchSize(createJDBCOptions(Map.empty))
  assert(resolved === 1000)
}

test("SPARK-56251: getFetchSize: base dialect returns 0 when not set") {
  // A plain JdbcDialect has no Postgres-specific fallback, so the JDBCOptions
  // default of 0 is returned when the option is absent.
  val plainDialect = new JdbcDialect {
    override def canHandle(url: String): Boolean = true
  }
  val resolved = plainDialect.getFetchSize(createJDBCOptions(Map.empty))
  assert(resolved === 0)
}

test("SPARK-56251: beforeFetch sets autoCommit=false when using default fetchSize") {
  val conn = mock[Connection]
  // No explicit fetchsize: the dialect default (1000) is positive, so the
  // connection must be switched out of autocommit for cursor-based fetching.
  dialect.beforeFetch(conn, createJDBCOptions(Map.empty))
  verify(conn).setAutoCommit(false)
}
}