diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 0f698d8aa6..ccf218cf6c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -129,11 +129,16 @@ case class CometBroadcastExchangeExec(
       case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
           if s.plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
+      case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] =>
+        CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
       case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
       case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
           if plan.isInstanceOf[CometPlan] =>
         CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
+      case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _)
+          if plan.isInstanceOf[CometPlan] =>
+        CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
       case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) =>
         throw new CometRuntimeException(
           "Child of CometBroadcastExchangeExec should be CometExec, " +
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index bfddd74d8e..88256d55a2 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -64,6 +64,22 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("ShuffleQueryStageExec could be direct child node of CometBroadcastExchangeExec") {
+    val table = "src"
+    withTable(table) {
+      withView("lv_noalias") {
+        sql(s"CREATE TABLE $table (key INT, value STRING) USING PARQUET")
+        sql(s"INSERT INTO $table VALUES(238, 'val_238')")
+
+        sql(
+          "CREATE VIEW lv_noalias AS SELECT myTab.* FROM src " +
+            "LATERAL VIEW explode(map('key1', 100, 'key2', 200)) myTab LIMIT 2")
+        val df = sql("SELECT * FROM lv_noalias a JOIN lv_noalias b ON a.key=b.key")
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
   test("Sort on single struct should fallback to Spark") {
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",