Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.RewriteViewCommands
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParameterContext
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext
Expand Down Expand Up @@ -125,6 +126,24 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface)
}
}

/**
* Parse a string to a LogicalPlan, binding the given parameters.
*/
override def parsePlanWithParameters(
sqlText: String,
parameterContext: ParameterContext): LogicalPlan = {
val sqlTextAfterSubstitution = substitutor.substitute(sqlText)
if (isIcebergCommand(sqlTextAfterSubstitution)) {
Comment thread
j1wonpark marked this conversation as resolved.
// Iceberg DDL grammars do not accept parameter markers (`?` / `:name`), so the
// parameterContext is intentionally not propagated on this path.
parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) }
.asInstanceOf[LogicalPlan]
} else {
RewriteViewCommands(SparkSession.active)
.apply(delegate.parsePlanWithParameters(sqlText, parameterContext))
}
}

private def isIcebergCommand(sqlText: String): Boolean = {
val normalized = sqlText
.toLowerCase(Locale.ROOT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,28 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.lang.reflect.Field;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.iceberg.NullOrder;
import org.apache.iceberg.SortDirection;
import org.apache.iceberg.expressions.Term;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.parser.AbstractSqlParser;
import org.apache.spark.sql.catalyst.parser.AstBuilder;
import org.apache.spark.sql.catalyst.parser.ParameterContext;
import org.apache.spark.sql.catalyst.parser.ParserInterface;
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -184,6 +191,47 @@ public void testParseSortOrderFindsExtendedParserInParentClassField() throws Exc
verify(icebergParser).parseSortOrder("id ASC NULLS FIRST");
}

/** Tests that non-Iceberg SQL delegates with the parameter context intact. */
@Test
public void testParsePlanWithParametersDelegatesForNonIcebergSql() throws Exception {
ParserInterface delegate = mock(ParserInterface.class);
ParameterContext context = mock(ParameterContext.class);
LogicalPlan plan = new OneRowRelation();
when(delegate.parsePlanWithParameters(anyString(), any(ParameterContext.class)))
.thenReturn(plan);

IcebergSparkSqlExtensionsParser parser = new IcebergSparkSqlExtensionsParser(delegate);

parser.parsePlanWithParameters("SELECT 1 WHERE 1 = ?", context);

verify(delegate).parsePlanWithParameters("SELECT 1 WHERE 1 = ?", context);
}

/** Tests that a positional parameter binds through a real Iceberg-extended parser. */
@Test
public void testParsePlanWithParametersBindsPositionalParameter() throws Exception {
IcebergSparkSqlExtensionsParser parser = new IcebergSparkSqlExtensionsParser(originalParser);
setSessionStateParser(spark.sessionState(), parser);

List<Row> rows = spark.sql("SELECT ? AS id", new Object[] {42}).collectAsList();

assertThat(rows).hasSize(1);
assertThat(rows.get(0).get(0)).isEqualTo(42);
}

/** Tests that a named parameter binds through a real Iceberg-extended parser. */
@Test
public void testParsePlanWithParametersBindsNamedParameter() throws Exception {
IcebergSparkSqlExtensionsParser parser = new IcebergSparkSqlExtensionsParser(originalParser);
setSessionStateParser(spark.sessionState(), parser);

Map<String, Object> args = Collections.singletonMap("id", 42);
List<Row> rows = spark.sql("SELECT :id AS id", args).collectAsList();

assertThat(rows).hasSize(1);
assertThat(rows.get(0).get(0)).isEqualTo(42);
}

private static void setSessionStateParser(Object sessionState, ParserInterface parser)
throws Exception {
Class<?> clazz = sessionState.getClass();
Expand Down