From 560f222dfc23929a4f13b2b6a9b038907ac1cc58 Mon Sep 17 00:00:00 2001
From: Alex Rhee
Date: Mon, 13 Apr 2026 10:48:45 -0700
Subject: [PATCH] [VL] Fix struct field binding after explode aliasing

---
Review note: angle-bracketed type arguments were missing from the SQL in this copy.
`items ARRAY<STRUCT<score: INT, label: STRING>>` is restored from the NAMED_STRUCT literals
(TODO confirm exact field types); the context line `c1 ARRAY, c2 ARRAY` is left as found.

 .../gluten/execution/MiscOperatorSuite.scala  | 86 +++++++++++++++++++
 .../expression/ExpressionConverter.scala      | 55 ++++++++----
 2 files changed, 122 insertions(+), 19 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 9fbd99752ed4..d0af3aba991b 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -1200,6 +1200,92 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
     }
   }
 
+  test("struct field projection after LATERAL VIEW EXPLODE stays native") {
+    withTable("t_lv_struct") {
+      sql("""CREATE TABLE t_lv_struct (
+            |  name STRING,
+            |  items ARRAY<STRUCT<score: INT, label: STRING>>
+            |) USING parquet""".stripMargin)
+      sql("""INSERT INTO t_lv_struct VALUES
+            |('alice', ARRAY(NAMED_STRUCT('score', 90, 'label', 'A'))),
+            |('bob', ARRAY(NAMED_STRUCT('score', 40, 'label', 'B')))
+            |""".stripMargin)
+
+      runQueryAndCompare("""SELECT name, item.score, item.label
+                           |FROM t_lv_struct
+                           |LATERAL VIEW EXPLODE(items) AS item
+                           |""".stripMargin) {
+        df =>
+          val executedPlan = getExecutedPlan(df)
+          assert(
+            executedPlan.exists(
+              plan =>
+                plan.isInstanceOf[ProjectExecTransformer] &&
+                  plan.find(_.isInstanceOf[GenerateExecTransformer]).isDefined),
+            s"Expected ProjectExecTransformer on top of explode output in executed plan:\n" +
+              s"${executedPlan.last}"
+          )
+      }
+    }
+  }
+
+  test("struct field filter after LATERAL VIEW EXPLODE stays native") {
+    withTable("t_lv_filter") {
+      sql("""CREATE TABLE t_lv_filter (
+            |  name STRING,
+            |  items ARRAY<STRUCT<score: INT, label: STRING>>
+            |) USING parquet""".stripMargin)
+      sql("""INSERT INTO t_lv_filter VALUES
+            |('alice', ARRAY(NAMED_STRUCT('score', 90, 'label', 'A'),
+            |                NAMED_STRUCT('score', 30, 'label', 'B'))),
+            |('bob', ARRAY(NAMED_STRUCT('score', 40, 'label', 'C')))
+            |""".stripMargin)
+
+      runQueryAndCompare("""SELECT name, item.score, item.label
+                           |FROM t_lv_filter
+                           |LATERAL VIEW EXPLODE(items) AS item
+                           |WHERE item.score > 50
+                           |ORDER BY name, item.score
+                           |""".stripMargin) {
+        df =>
+          val executedPlan = getExecutedPlan(df)
+          assert(
+            executedPlan.exists(
+              plan =>
+                plan.isInstanceOf[FilterExecTransformer] &&
+                  plan.find(_.isInstanceOf[GenerateExecTransformer]).isDefined),
+            s"Expected FilterExecTransformer on explode output in executed plan:\n" +
+              s"${executedPlan.last}"
+          )
+      }
+    }
+  }
+
+  test("nested struct field rollup after explode_outer stays native") {
+    runQueryAndCompare("""SELECT elem.a.x, elem.a.y, count(*)
+                         |FROM (
+                         |  SELECT explode_outer(arr) AS elem
+                         |  FROM (
+                         |    SELECT array(named_struct(
+                         |      'a', named_struct('x', id * 1.5, 'y', 'foo'))) AS arr
+                         |    FROM range(1)
+                         |  )
+                         |)
+                         |GROUP BY ROLLUP(elem.a.x, elem.a.y)
+                         |""".stripMargin) {
+      df =>
+        val executedPlan = getExecutedPlan(df)
+        assert(
+          executedPlan.exists(
+            plan =>
+              plan.isInstanceOf[ProjectExecTransformer] &&
+                plan.find(_.isInstanceOf[GenerateExecTransformer]).isDefined),
+          s"Expected ProjectExecTransformer on explode_outer output in executed plan:\n" +
+            s"${executedPlan.last}"
+        )
+    }
+  }
+
   test("test array functions") {
     withTable("t") {
       sql("CREATE TABLE t (c1 ARRAY, c2 ARRAY, c3 STRING) using parquet")
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 73f7b8699630..f8236475e15c 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -930,10 +930,25 @@ object ExpressionConverter extends
SQLConfHelper with Logging {
     }
   }
 
+  /**
+   * Walks `names` (leaf field first, root attribute name last) through `dataType` and returns
+   * the leaf field's ordinal in its parent struct, or -1 when the path does not resolve.
+   */
+  private def resolveStructOrdinal(names: ArrayBuffer[String], dataType: DataType): Int = {
+    @scala.annotation.tailrec
+    def descend(level: Int, current: DataType, ordinal: Int): Int = current match {
+      case _ if level < 0 => ordinal
+      case struct: StructType =>
+        val idx = struct.fields.indexWhere(_.name == names(level))
+        if (idx < 0) -1 else descend(level - 1, struct.fields(idx).dataType, idx)
+      case _ => -1
+    }
+    descend(names.size - 2, dataType, -1)
+  }
+
   private def bindGetStructField(
       structField: GetStructField,
       input: AttributeSeq): BoundReference = {
-    // get the new ordinal base input
     var newOrdinal: Int = -1
     val names = new ArrayBuffer[String]
     var root: Expression = structField
@@ -947,26 +962,28 @@ object ExpressionConverter extends SQLConfHelper with Logging {
     if (!root.isInstanceOf[AttributeReference]) {
       return BoundReference(structField.ordinal, structField.dataType, structField.nullable)
     }
-    names += root.asInstanceOf[AttributeReference].name
-    input.attrs.foreach(
-      attribute => {
-        var level = names.size - 1
-        if (names(level) == attribute.name) {
-          var candidateFields: Array[StructField] = null
-          var dtType = attribute.dataType
-          while (dtType.isInstanceOf[StructType] && level >= 1) {
-            candidateFields = dtType.asInstanceOf[StructType].fields
-            level -= 1
-            val curName = names(level)
-            for (i <- 0 until candidateFields.length) {
-              if (candidateFields(i).name == curName) {
-                dtType = candidateFields(i).dataType
-                newOrdinal = i
-              }
-            }
+    val ref = root.asInstanceOf[AttributeReference]
+    names += ref.name
+    input.attrs.foreach {
+      attribute =>
+        if (newOrdinal == -1 && names.last == attribute.name) {
+          val ordinal = resolveStructOrdinal(names, attribute.dataType)
+          if (ordinal != -1) {
+            newOrdinal = ordinal
           }
         }
-      })
+    }
+    // Keep name-based binding as the primary path. Some post-generate projections rename
+    // attributes while preserving exprId, so retry by exprId only after the name lookup fails.
+    input.attrs.foreach {
+      attribute =>
+        if (newOrdinal == -1 && attribute.exprId == ref.exprId) {
+          val ordinal = resolveStructOrdinal(names, attribute.dataType)
+          if (ordinal != -1) {
+            newOrdinal = ordinal
+          }
+        }
+    }
     if (newOrdinal == -1) {
       throw new IllegalStateException(
         s"Couldn't find $structField in ${input.attrs.mkString("[", ",", "]")}")