diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index c575e95485cea..097b371ea0722 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -172,12 +172,12 @@ object JDBCRDD extends Logging {
    *
    * @param sc - Your SparkContext.
    * @param schema - The Catalyst schema of the underlying database table.
-   * @param requiredColumns - The names of the columns to SELECT.
+   * @param requiredColumns - The names of the columns or aggregate columns to SELECT.
    * @param filters - The filters to include in all WHERE clauses.
    * @param parts - An array of JDBCPartitions specifying partition ids and
    *    per-partition WHERE clauses.
    * @param options - JDBC options that contains url, table and other information.
-   * @param outputSchema - The schema of the columns to SELECT.
+   * @param outputSchema - The schema of the columns or aggregate columns to SELECT.
    * @param groupByColumns - The pushed down group by columns.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
@@ -213,8 +213,8 @@ object JDBCRDD extends Logging {
 }
 
 /**
- * An RDD representing a table in a database accessed via JDBC. Both the
- * driver code and the workers must be able to access the database; the driver
+ * An RDD representing a query is related to a table in a database accessed via JDBC.
+ * Both the driver code and the workers must be able to access the database; the driver
  * needs to fetch the schema while the workers need to fetch the data.
  */
 private[jdbc] class JDBCRDD(
@@ -237,11 +237,7 @@ private[jdbc] class JDBCRDD(
   /**
    * `columns`, but as a String suitable for injection into a SQL query.
    */
-  private val columnList: String = {
-    val sb = new StringBuilder()
-    columns.foreach(x => sb.append(",").append(x))
-    if (sb.isEmpty) "1" else sb.substring(1)
-  }
+  private val columnList: String = if (columns.isEmpty) "1" else columns.mkString(",")
 
   /**
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 60d88b6690587..8098fa0b83a95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -278,12 +278,18 @@ private[sql] case class JDBCRelation(
   }
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+    // When pushDownPredicate is false, all Filters that need to be pushed down should be ignored
+    val pushedFilters = if (jdbcOptions.pushDownPredicate) {
+      filters
+    } else {
+      Array.empty[Filter]
+    }
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
       schema,
       requiredColumns,
-      filters,
+      pushedFilters,
       parts,
       jdbcOptions).asInstanceOf[RDD[Row]]
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 95a91616f80cc..8842db2a2aca4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -1723,6 +1723,40 @@ class JDBCSuite extends QueryTest
       Row("fred", 1) :: Nil)
   }
 
+  test(
+    "SPARK-36574: pushDownPredicate=false should prevent push down filters to JDBC data source") {
+    val df = spark.read.format("jdbc")
+      .option("Url", urlWithUserAndPass)
+      .option("dbTable", "test.people")
+    val df1 = df
+      .option("pushDownPredicate", false)
+      .load()
+      .filter("theid = 1")
+      .select("name", "theid")
+    val df2 = df
+      .option("pushDownPredicate", true)
+      .load()
+      .filter("theid = 1")
+      .select("name", "theid")
+    val df3 = df
+      .load()
+      .select("name", "theid")
+
+    def getRowCount(df: DataFrame): Long = {
+      val queryExecution = df.queryExecution
+      val rawPlan = queryExecution.executedPlan.collect {
+        case p: DataSourceScanExec => p
+      } match {
+        case Seq(p) => p
+        case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
+      }
+      rawPlan.execute().count()
+    }
+
+    assert(getRowCount(df1) == df3.count)
+    assert(getRowCount(df2) < df3.count)
+  }
+
   test("SPARK-26383 throw IllegalArgumentException if wrong kind of driver to the given url") {
     val e = intercept[IllegalArgumentException] {
       val opts = Map(
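For reference, a minimal usage sketch of the behavior the new test covers (the URL and table name below are placeholders, not taken from the patch): with pushDownPredicate set to false the JDBC scan returns every row and Spark evaluates the filter itself, while the default of true compiles the predicate into the remote query's WHERE clause.

// Placeholder URL and table; any reachable JDBC source behaves the same way.
val noPushdown = spark.read.format("jdbc")
  .option("url", "jdbc:h2:mem:testdb0;user=testUser;password=testPass")
  .option("dbtable", "test.people")
  .option("pushDownPredicate", "false")  // filter evaluated by Spark after a full table scan
  .load()
  .filter("theid = 1")

val withPushdown = spark.read.format("jdbc")
  .option("url", "jdbc:h2:mem:testdb0;user=testUser;password=testPass")
  .option("dbtable", "test.people")
  .load()                                // pushDownPredicate defaults to true
  .filter("theid = 1")                   // pushed down as WHERE theid = 1 in the generated SQL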