diff --git a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala index ac127f754a91..22c28dce1558 100644 --- a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala +++ b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.RewriteViewCommands import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.ParameterContext import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext @@ -125,6 +126,24 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) } } + /** + * Parse a string to a LogicalPlan, binding the given parameters. + */ + override def parsePlanWithParameters( + sqlText: String, + parameterContext: ParameterContext): LogicalPlan = { + val sqlTextAfterSubstitution = substitutor.substitute(sqlText) + if (isIcebergCommand(sqlTextAfterSubstitution)) { + // Iceberg DDL grammars do not accept parameter markers (`?` / `:name`), so the + // parameterContext is intentionally not propagated on this path. + parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) } + .asInstanceOf[LogicalPlan] + } else { + RewriteViewCommands(SparkSession.active) + .apply(delegate.parsePlanWithParameters(sqlText, parameterContext)) + } + } + private def isIcebergCommand(sqlText: String): Boolean = { val normalized = sqlText .toLowerCase(Locale.ROOT) diff --git a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/TestExtendedParser.java b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/TestExtendedParser.java index ef4f0090292c..36e7314473aa 100644 --- a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/TestExtendedParser.java +++ b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/TestExtendedParser.java @@ -20,6 +20,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -27,14 +29,19 @@ import java.lang.reflect.Field; import java.util.Collections; import java.util.List; +import java.util.Map; import org.apache.iceberg.NullOrder; import org.apache.iceberg.SortDirection; import org.apache.iceberg.expressions.Term; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.parser.AbstractSqlParser; import org.apache.spark.sql.catalyst.parser.AstBuilder; +import org.apache.spark.sql.catalyst.parser.ParameterContext; import org.apache.spark.sql.catalyst.parser.ParserInterface; import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -184,6 +191,47 @@ public void testParseSortOrderFindsExtendedParserInParentClassField() throws Exc verify(icebergParser).parseSortOrder("id ASC NULLS FIRST"); } + /** Tests that non-Iceberg SQL delegates with the parameter context intact. */ + @Test + public void testParsePlanWithParametersDelegatesForNonIcebergSql() throws Exception { + ParserInterface delegate = mock(ParserInterface.class); + ParameterContext context = mock(ParameterContext.class); + LogicalPlan plan = new OneRowRelation(); + when(delegate.parsePlanWithParameters(anyString(), any(ParameterContext.class))) + .thenReturn(plan); + + IcebergSparkSqlExtensionsParser parser = new IcebergSparkSqlExtensionsParser(delegate); + + parser.parsePlanWithParameters("SELECT 1 WHERE 1 = ?", context); + + verify(delegate).parsePlanWithParameters("SELECT 1 WHERE 1 = ?", context); + } + + /** Tests that a positional parameter binds through a real Iceberg-extended parser. */ + @Test + public void testParsePlanWithParametersBindsPositionalParameter() throws Exception { + IcebergSparkSqlExtensionsParser parser = new IcebergSparkSqlExtensionsParser(originalParser); + setSessionStateParser(spark.sessionState(), parser); + + List rows = spark.sql("SELECT ? AS id", new Object[] {42}).collectAsList(); + + assertThat(rows).hasSize(1); + assertThat(rows.get(0).get(0)).isEqualTo(42); + } + + /** Tests that a named parameter binds through a real Iceberg-extended parser. */ + @Test + public void testParsePlanWithParametersBindsNamedParameter() throws Exception { + IcebergSparkSqlExtensionsParser parser = new IcebergSparkSqlExtensionsParser(originalParser); + setSessionStateParser(spark.sessionState(), parser); + + Map args = Collections.singletonMap("id", 42); + List rows = spark.sql("SELECT :id AS id", args).collectAsList(); + + assertThat(rows).hasSize(1); + assertThat(rows.get(0).get(0)).isEqualTo(42); + } + private static void setSessionStateParser(Object sessionState, ParserInterface parser) throws Exception { Class clazz = sessionState.getClass();