diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 9790c87147..aad46c98c4 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -533,8 +533,8 @@ class CometExecSuite extends CometTestBase { Seq("struct", "array").foreach { dataType => val df = sql(s"""SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f - |JOIN dim_stats s - |ON $dataType(f.store_id) = $dataType(s.store_id) WHERE s.country = 'DE' + |JOIN dim_stats s + |ON $dataType(f.store_id) = $dataType(s.store_id) WHERE s.country = 'DE' """.stripMargin) checkSparkAnswer(df) } @@ -553,8 +553,8 @@ class CometExecSuite extends CometTestBase { Seq("struct", "array").foreach { dataType => val df = sql(s"""SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f - |JOIN dim_stats s - |ON $dataType(f.store_id) = $dataType(s.store_id) WHERE s.country = 'DE' + |JOIN dim_stats s + |ON $dataType(f.store_id) = $dataType(s.store_id) WHERE s.country = 'DE' """.stripMargin) val (_, cometPlan) = checkSparkAnswer(df) @@ -570,6 +570,178 @@ class CometExecSuite extends CometTestBase { } } + test("non-AQE DPP with two separate broadcast joins") { + withTempDir { dir => /* fact is partitioned by store_id and region_id; dims are not */ + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr( + "cast(id % 5 as int) as store_id", + "cast(id % 3 as int) as region_id", + "cast(id as int) as amount") + .write + .partitionBy("store_id", "region_id") + .parquet(s"$path/fact") + spark + .range(5) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as store_name") + .write + .parquet(s"$path/store_dim") + spark + .range(3) + .selectExpr("cast(id as int) as region_id", "cast(id as string) as region_name") + .write + .parquet(s"$path/region_dim") + } + + withSQLConf( + 
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("fact_two_joins") + spark.read.parquet(s"$path/store_dim").createOrReplaceTempView("store_dim") + spark.read.parquet(s"$path/region_dim").createOrReplaceTempView("region_dim") + + val df = spark.sql("""SELECT f.amount, s.store_name, r.region_name + |FROM fact_two_joins f + |JOIN store_dim s ON f.store_id = s.store_id + |JOIN region_dim r ON f.region_id = r.region_id + |WHERE s.store_name = '1' AND r.region_name = '2'""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + val nativeScans = cometPlan.collect { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + + val dppScans = /* native scans with a DynamicPruningExpression partition filter */ + nativeScans.filter(_.partitionFilters.exists(_.isInstanceOf[DynamicPruningExpression])) + assert( + dppScans.nonEmpty, + "Expected at least one CometNativeScanExec with DynamicPruningExpression") + } + } + } + + test("non-AQE DPP fallback when broadcast exchange is not Comet") { + withTempDir { dir => /* fact is partitioned by store_id; dim is not */ + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr("cast(id % 10 as int) as store_id", "cast(id as int) as amount") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) /* country below is the id as a string: '0' to '9' */ + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + // Disable Comet broadcast exchange so SubqueryBroadcastExec wraps a Spark + // BroadcastExchangeExec. convertSubqueryBroadcasts should skip it (child isn't + // CometNativeExec). Query should still produce correct results via Spark's standard path. 
+ withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.key -> "false", + CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.key -> "false") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("fact_fallback") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("dim_fallback") + + val df = spark.sql("""SELECT f.amount, f.store_id + |FROM fact_fallback f JOIN dim_fallback d + |ON f.store_id = d.store_id + |WHERE d.country = '1'""".stripMargin) + checkSparkAnswer(df) /* result check only; the plan shape is not asserted */ + } + } + } + + test("non-AQE DPP with empty broadcast result") { + withTempDir { dir => /* fact is partitioned by store_id; dim is not */ + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr("cast(id % 10 as int) as store_id", "cast(id as int) as amount") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("fact_empty") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("dim_empty") + + // Filter on dim that matches nothing -- DPP prunes all partitions + val df = spark.sql("""SELECT f.amount, f.store_id + |FROM fact_empty f JOIN dim_empty d + |ON f.store_id = d.store_id + |WHERE d.country = 'NONEXISTENT'""".stripMargin) + val result = df.collect() /* executes the query here, ahead of checkSparkAnswer */ + assert(result.isEmpty, s"Expected empty result but got ${result.length} rows") + checkSparkAnswer(df) + } + } + } + + test("non-AQE DPP resolves both outer and inner partition filters") { + // CometNativeScanExec.partitionFilters and CometScanExec.partitionFilters contain + // different InSubqueryExec instances. 
Both must be resolved for partition selection + // to work correctly. This test verifies correct results, which requires both sets + // of filters to be resolved. + withTempDir { dir => /* fact is partitioned by store_id only; dim is not */ + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr( + "cast(id % 10 as int) as store_id", + "cast(id as int) as date_id", + "cast(id as int) as amount") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) /* country below is the id as a string: '0' to '9' */ + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("fact_dual") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("dim_dual") + + val df = spark.sql("""SELECT f.date_id, f.store_id + |FROM fact_dual f JOIN dim_dual d + |ON f.store_id = d.store_id + |WHERE d.country = '1'""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + // Verify native scan is used (at least one CometNativeScanExec in the plan) + val nativeScans = cometPlan.collect { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + + // Verify DPP is present (a DynamicPruningExpression among the scan's partition filters) + val dppScans = + nativeScans.filter(_.partitionFilters.exists(_.isInstanceOf[DynamicPruningExpression])) + assert(dppScans.nonEmpty, "Expected DPP filter on native scan") + } + } + } + test("ShuffleQueryStageExec could be direct child node of CometBroadcastExchangeExec") { withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { val table = "src"