From 73cb01b8cb95dd9e505967e843a5a18567f9c749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 24 Dec 2024 19:27:50 +0100 Subject: [PATCH 01/59] first commit --- .../analysis/ColumnResolutionHelper.scala | 34 +++++++--- .../catalyst/analysis/ResolveCatalogs.scala | 1 + .../analysis/ResolveSetVariable.scala | 4 ++ .../catalyst/analysis/v2ResolutionPlans.scala | 1 + .../catalog/TempVariableManager.scala | 27 ++++++-- .../connector/catalog/CatalogManager.scala | 7 ++- .../command/v2/CreateVariableExec.scala | 13 +++- .../scripting/ScriptingVariableManager.scala | 63 +++++++++++++++++++ .../sql/scripting/SqlScriptingExecution.scala | 3 + .../SqlScriptingExecutionContext.scala | 7 ++- 10 files changed, 139 insertions(+), 21 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 36fd4d02f8da1..8c5568f6713b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -265,22 +265,36 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } - if (maybeTempVariableName(nameParts)) { - val variableName = if (conf.caseSensitiveAnalysis) { - nameParts.last - } else { - nameParts.last.toLowerCase(Locale.ROOT) - } - catalogManager.tempVariableManager.get(variableName).map { varDef => + // todo LOCALVARS: refactor to be more functional with getorelse or something + + val variableName = if (conf.caseSensitiveAnalysis) { + nameParts.last + } else { + nameParts.last.toLowerCase(Locale.ROOT) + } + + catalogManager.scriptingLocalVariableManager + .flatMap(_.get(variableName)) + .map { varDef => VariableReference( nameParts, FakeSystemCatalog, Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), varDef) } - } else { - None - } + .orElse( + if (maybeTempVariableName(nameParts)) { + catalogManager.tempVariableManager.get(variableName).map { varDef => + VariableReference( + nameParts, + FakeSystemCatalog, + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), + varDef) + } + } else { + None + } + ) } // Resolves `UnresolvedAttribute` to its value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 664b68008080d..642fbe7b00287 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -75,6 +75,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { def ident: Identifier = Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), nameParts.last) + // todo LOCALVARS: update to support local vars if (nameParts.length == 1) { ResolvedIdentifier(FakeSystemCatalog, ident) } else if (nameParts.length == 2) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala index bd0204ba06fd8..52104d412ae22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala @@ -42,6 +42,7 @@ class ResolveSetVariable(val catalogManager: CatalogManager) extends Rule[Logica case u: UnresolvedAttribute => lookupVariable(u.nameParts) match { case Some(variable) => variable.copy(canFold = false) + // todo LOCALVARS: change system session to proper message case _ => throw unresolvedVariableError(u.nameParts, Seq("SYSTEM", "SESSION")) } @@ -53,6 +54,9 @@ class ResolveSetVariable(val catalogManager: CatalogManager) extends Rule[Logica // Names are normalized when the variables are created. // No need for case insensitive comparison here. // TODO: we need to group by the qualified variable name once other catalogs support it. + + // todo LOCALVARS: the todo above, although possibly not neceesary because it might work. + // research further val dups = resolvedVars.groupBy(_.identifier.name).filter(kv => kv._2.length > 1) if (dups.nonEmpty) { throw new AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index dee78b8f03af4..cac41c3a50608 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -259,5 +259,6 @@ case class ResolvedIdentifier( // A fake v2 catalog to hold temp views. object FakeSystemCatalog extends CatalogPlugin { override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} + // todo LOCALVARS: why is this here override def name(): String = "system" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala index 2c262da1f4449..b21d7724efddc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala @@ -26,6 +26,22 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} import org.apache.spark.sql.errors.DataTypeErrorsBase +trait VariableManager { + def create( + name: String, + defaultValueSQL: String, + initValue: Literal, + overrideIfExists: Boolean): Unit + + def get(name: String): Option[VariableDefinition] + + def remove(name: String): Boolean + + def clear(): Unit + + def isEmpty: Boolean +} + /** * A thread-safe manager for temporary SQL variables (that live in the schema `SYSTEM.SESSION`), * providing atomic operations to manage them, e.g. create, get, remove, etc. @@ -33,7 +49,7 @@ import org.apache.spark.sql.errors.DataTypeErrorsBase * Note that, the variable name is always case-sensitive here, callers are responsible to format the * variable name w.r.t. case-sensitive config. */ -class TempVariableManager extends DataTypeErrorsBase { +class TempVariableManager extends VariableManager with DataTypeErrorsBase { @GuardedBy("this") private val variables = new mutable.HashMap[String, VariableDefinition] @@ -52,19 +68,20 @@ class TempVariableManager extends DataTypeErrorsBase { variables.put(name, VariableDefinition(defaultValueSQL, initValue)) } - def get(name: String): Option[VariableDefinition] = synchronized { + override def get(name: String): Option[VariableDefinition] = synchronized { variables.get(name) } - def remove(name: String): Boolean = synchronized { + override def remove(name: String): Boolean = synchronized { variables.remove(name).isDefined } - def clear(): Unit = synchronized { + override def clear(): Unit = synchronized { variables.clear() } - def isEmpty: Boolean = synchronized { + // todo LOCALVARS: check what this is for with Vladimir + override def isEmpty: Boolean = synchronized { variables.isEmpty } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index db94659b1033b..3bf07c636d9f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.connector.catalog import scala.collection.mutable - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager} +import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager, VariableManager} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -49,6 +48,9 @@ class CatalogManager( // TODO: create a real SYSTEM catalog to host `TempVariableManager` under the SESSION namespace. val tempVariableManager: TempVariableManager = new TempVariableManager + // todo LOCALVARS: should this be thread local + var scriptingLocalVariableManager: Option[VariableManager] = None + def catalog(name: String): CatalogPlugin = synchronized { if (name.equalsIgnoreCase(SESSION_CATALOG_NAME)) { v2SessionCatalog @@ -159,6 +161,7 @@ class CatalogManager( private[sql] object CatalogManager { val SESSION_CATALOG_NAME: String = "spark_catalog" + // todo LOCALVARS: whats this val SYSTEM_CATALOG_NAME = "system" val SESSION_NAMESPACE = "session" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala index 0ed1c104edb92..113d506f40bd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -31,7 +31,9 @@ case class CreateVariableExec(name: String, defaultExpr: DefaultValueExpression, extends LeafV2CommandExec with ExpressionsEvaluator { override protected def run(): Seq[InternalRow] = { - val variableManager = session.sessionState.catalogManager.tempVariableManager + val scriptingVariableManager = session.sessionState.catalogManager.scriptingLocalVariableManager + + val tempVariableManager = session.sessionState.catalogManager.tempVariableManager val exprs = prepareExpressions(Seq(defaultExpr.child), subExprEliminationEnabled = false) initializeExprs(exprs, 0) val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) @@ -40,8 +42,13 @@ case class CreateVariableExec(name: String, defaultExpr: DefaultValueExpression, } else { name.toLowerCase(Locale.ROOT) } - variableManager.create( - normalizedName, defaultExpr.originalSQL, initValue, replace) + + // create local variable if we are in a script, otherwise create session variable + scriptingVariableManager.getOrElse(tempVariableManager) + .create(normalizedName, defaultExpr.originalSQL, initValue, replace) + +// tempVariableManager.create( +// normalizedName, defaultExpr.originalSQL, initValue, replace) Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala new file mode 100644 index 0000000000000..2d300fef5a28c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.scripting + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} +import org.apache.spark.sql.catalyst.expressions.Literal + + +// todo LOCALVARS: should this be thread safe / synchronized, also should probably be one per frame +class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends VariableManager{ + + // map from scope label to map from variable name to variable definition + private val variables = { + val map = new mutable.HashMap[String, mutable.HashMap[String, VariableDefinition]] + // probably unnecessary as when variable manager is initialized there are no scopes yet + context.currentFrame.scopes.foreach(scope => + map.put(scope.label, new mutable.HashMap[String, VariableDefinition])) + map + } + + override def create( + name: String, + defaultValueSQL: String, + initValue: Literal, + overrideIfExists: Boolean): Unit = { + // todo LOCALVARS: throw meaningful error, qualified name + variables + .getOrElse(context.currentScope.label, throw Exception) + .put(name, VariableDefinition(defaultValueSQL, initValue)) + } + + override def get(name: String): Option[VariableDefinition] = { + // todo LOCALVAR: add support for qualified name + context.currentFrame.scopes + .findLast(scope => variables(scope.label).contains(name)) + .map(scope => variables(scope.label)(name)) + } + + override def remove(name: String): Boolean = { + true + } + + override def clear(): Unit = variables.clear() + + override def isEmpty: Boolean = variables.values.forall(_.isEmpty) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 71b44cbbd0704..070fd34ecb0d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -50,6 +50,9 @@ class SqlScriptingExecution( ctx } + private val variableManager = new ScriptingVariableManager(context) + session.sessionState.catalogManager.scriptingLocalVariableManager = Some(variableManager) + private var current: Option[DataFrame] = getNextResult override def hasNext: Boolean = current.isDefined diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala index 5a2ef62e3bb7d..c167befc0eda1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala @@ -41,6 +41,9 @@ class SqlScriptingExecutionContext { } frames.last.exitScope(label) } + + def currentFrame: SqlScriptingExecutionFrame = frames.last + def currentScope: SqlScriptingExecutionScope = currentFrame.currentScope } /** @@ -53,7 +56,7 @@ class SqlScriptingExecutionFrame( executionPlan: Iterator[CompoundStatementExec]) extends Iterator[CompoundStatementExec] { // List of scopes that are currently active. - private val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty + val scopes: ListBuffer[SqlScriptingExecutionScope] = ListBuffer.empty override def hasNext: Boolean = executionPlan.hasNext @@ -80,6 +83,8 @@ class SqlScriptingExecutionFrame( scopes.remove(scopes.length - 1) } } + + def currentScope: SqlScriptingExecutionScope = scopes.last } /** From 813d28247cb61afb5e71a0ec8f033c2c32bef962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 24 Dec 2024 20:40:10 +0100 Subject: [PATCH 02/59] POC works --- .../analysis/ColumnResolutionHelper.scala | 3 ++- .../scripting/ScriptingVariableManager.scala | 3 ++- .../SqlScriptingInterpreterSuite.scala | 19 +++++++++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 8c5568f6713b1..cae24efacd46e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -273,11 +273,12 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { nameParts.last.toLowerCase(Locale.ROOT) } - catalogManager.scriptingLocalVariableManager + catalogManager.scriptingLocalVariableManager .flatMap(_.get(variableName)) .map { varDef => VariableReference( nameParts, + // todo LOCALVARS: deal with this fakesystemcatalog situation FakeSystemCatalog, Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), varDef) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index 2d300fef5a28c..15096be4cad3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -42,7 +42,8 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va overrideIfExists: Boolean): Unit = { // todo LOCALVARS: throw meaningful error, qualified name variables - .getOrElse(context.currentScope.label, throw Exception) +// .getOrElse(context.currentScope.label, throw Exception) + .get(context.currentScope.label).get .put(name, VariableDefinition(defaultValueSQL, initValue)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 20997504b15eb..50560602b6190 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -69,6 +69,25 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } } + test("testtest") { + val sqlScript = + """ + |BEGIN + |DECLARE var = 1; + |SELECT var + var * 2; + |END + |""".stripMargin + + val r = spark.sql(sqlScript).collect() + + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row] // drop var + ) +// verifySqlScriptResult(sqlScript, expected) + } + // Tests test("multi statement - simple") { withTable("t") { From 1c08f57eaf934a315b607b3487fa2e3669920e8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 25 Dec 2024 12:19:55 +0100 Subject: [PATCH 03/59] make column res helper more functional --- .../analysis/ColumnResolutionHelper.scala | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index cae24efacd46e..dfa9974f09217 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -265,8 +265,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } - // todo LOCALVARS: refactor to be more functional with getorelse or something - val variableName = if (conf.caseSensitiveAnalysis) { nameParts.last } else { @@ -278,24 +276,22 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { .map { varDef => VariableReference( nameParts, - // todo LOCALVARS: deal with this fakesystemcatalog situation + // todo LOCALVARS: deal with this fakesystemcatalog / session_namespace situation FakeSystemCatalog, Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), varDef) } - .orElse( - if (maybeTempVariableName(nameParts)) { - catalogManager.tempVariableManager.get(variableName).map { varDef => + .orElse(Option.when(maybeTempVariableName(nameParts)) { + catalogManager.tempVariableManager + .get(variableName) + .map { varDef => VariableReference( nameParts, FakeSystemCatalog, Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), varDef) } - } else { - None - } - ) + }) } // Resolves `UnresolvedAttribute` to its value. From 18da02f77c9a5207d7aecc8f65c9abff66ccb3a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 25 Dec 2024 14:44:47 +0100 Subject: [PATCH 04/59] move variables map to SqlScriptingScope --- .../analysis/ColumnResolutionHelper.scala | 22 ++++++------- .../connector/catalog/CatalogManager.scala | 2 +- .../parser/SqlScriptingParserSuite.scala | 12 +++++++ .../scripting/ScriptingVariableManager.scala | 33 ++++++------------- .../SqlScriptingExecutionContext.scala | 7 +++- .../SqlScriptingInterpreterSuite.scala | 19 +++++++++++ 6 files changed, 59 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index dfa9974f09217..ea684d1299dbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -281,17 +281,17 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), varDef) } - .orElse(Option.when(maybeTempVariableName(nameParts)) { - catalogManager.tempVariableManager - .get(variableName) - .map { varDef => - VariableReference( - nameParts, - FakeSystemCatalog, - Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), - varDef) - } - }) +// .orElse(Option.when(maybeTempVariableName(nameParts)) { +// catalogManager.tempVariableManager +// .get(variableName) +// .map { varDef => +// VariableReference( +// nameParts, +// FakeSystemCatalog, +// Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), +// varDef) +// } +// }) } // Resolves `UnresolvedAttribute` to its value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 3bf07c636d9f1..90ad72fce30e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -48,7 +48,7 @@ class CatalogManager( // TODO: create a real SYSTEM catalog to host `TempVariableManager` under the SESSION namespace. val tempVariableManager: TempVariableManager = new TempVariableManager - // todo LOCALVARS: should this be thread local + // todo LOCALVARS: should this be thread local (probably) var scriptingLocalVariableManager: Option[VariableManager] = None def catalog(name: String): CatalogPlugin = synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index c9e2f42e164f9..e4cc67c37dd83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -46,6 +46,18 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(!statement.isInstanceOf[CompoundBody]) } + test("testtest") { + val sqlScriptText = + """ + |BEGIN + |DECLARE `my.var.i.ab.le` = 1; + |SELECT `my.var.i.ab.le` + `my.var.i.ab.le` * 2; + |END + |""".stripMargin + val statement = parsePlan(sqlScriptText) + assert(!statement.isInstanceOf[CompoundBody]) + } + test("multi select without ; - should fail") { val sqlScriptText = "SELECT 1 SELECT 1" val e = intercept[ParseException] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index 15096be4cad3d..e13600cdcb04b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -17,48 +17,35 @@ package org.apache.spark.sql.scripting -import scala.collection.mutable - import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} import org.apache.spark.sql.catalyst.expressions.Literal -// todo LOCALVARS: should this be thread safe / synchronized, also should probably be one per frame -class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends VariableManager{ - - // map from scope label to map from variable name to variable definition - private val variables = { - val map = new mutable.HashMap[String, mutable.HashMap[String, VariableDefinition]] - // probably unnecessary as when variable manager is initialized there are no scopes yet - context.currentFrame.scopes.foreach(scope => - map.put(scope.label, new mutable.HashMap[String, VariableDefinition])) - map - } +// todo LOCALVARS: should this be thread safe / synchronized (probably not since its one per script) +class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends VariableManager { override def create( name: String, defaultValueSQL: String, initValue: Literal, overrideIfExists: Boolean): Unit = { - // todo LOCALVARS: throw meaningful error, qualified name - variables -// .getOrElse(context.currentScope.label, throw Exception) - .get(context.currentScope.label).get - .put(name, VariableDefinition(defaultValueSQL, initValue)) + // todo LOCALVARS: qualified name + context.currentScope.variables.put(name, VariableDefinition(defaultValueSQL, initValue)) } override def get(name: String): Option[VariableDefinition] = { - // todo LOCALVAR: add support for qualified name + // todo LOCALVARS: add support for qualified name context.currentFrame.scopes - .findLast(scope => variables(scope.label).contains(name)) - .map(scope => variables(scope.label)(name)) + .findLast(_.variables.contains(name)) + .map(_.variables(name)) } override def remove(name: String): Boolean = { true } - override def clear(): Unit = variables.clear() + // todo LOCALVARS: do we need this + override def clear(): Unit = () - override def isEmpty: Boolean = variables.values.forall(_.isEmpty) + override def isEmpty: Boolean = context.currentFrame.scopes.forall(_.variables.isEmpty) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala index c167befc0eda1..503881cdf1fb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.scripting +import scala.collection.mutable import scala.collection.mutable.ListBuffer import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.catalog.VariableDefinition + /** * SQL scripting execution context - keeps track of the current execution state. @@ -93,4 +96,6 @@ class SqlScriptingExecutionFrame( * @param label * Label of the scope. */ -class SqlScriptingExecutionScope(val label: String) +class SqlScriptingExecutionScope(val label: String) { + val variables = new mutable.HashMap[String, VariableDefinition] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 50560602b6190..8f3bdb60fceaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -88,6 +88,25 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // verifySqlScriptResult(sqlScript, expected) } + test("testtest2") { + val sqlScript = + """ + |BEGIN + |DECLARE `my.var.i.ab.le` = 1; + |SELECT `my.var.i.ab.le` + `my.var.i.ab.le` * 2; + |END + |""".stripMargin + + val r = spark.sql(sqlScript).collect() + + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row] // drop var + ) +// verifySqlScriptResult(sqlScript, expected) + } + // Tests test("multi statement - simple") { withTable("t") { From cee5f1a665784a98cf13d58777d43f76cdbe4f3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 27 Dec 2024 19:38:15 +0100 Subject: [PATCH 05/59] implement proper namespace (scope label name) for local variables --- .../analysis/ColumnResolutionHelper.scala | 25 ++++++++++--------- .../catalog/TempVariableManager.scala | 17 +++++++++---- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 ++--- .../sql/catalyst/analysis/AnalysisTest.scala | 7 ++++-- .../command/v2/CreateVariableExec.scala | 13 +++++----- .../command/v2/SetVariableExec.scala | 3 ++- .../command/v2/V2CommandStrategy.scala | 2 +- .../scripting/ScriptingVariableManager.scala | 15 ++++++++--- .../SqlScriptingInterpreterSuite.scala | 2 ++ 9 files changed, 56 insertions(+), 34 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index ea684d1299dbd..af01ee9470beb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -265,6 +265,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } + // todo LOCALVARS: should we do this for namespace as well val variableName = if (conf.caseSensitiveAnalysis) { nameParts.last } else { @@ -278,20 +279,20 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { nameParts, // todo LOCALVARS: deal with this fakesystemcatalog / session_namespace situation FakeSystemCatalog, - Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), + Identifier.of(Array(varDef.identifier.namespace().last), variableName), varDef) } -// .orElse(Option.when(maybeTempVariableName(nameParts)) { -// catalogManager.tempVariableManager -// .get(variableName) -// .map { varDef => -// VariableReference( -// nameParts, -// FakeSystemCatalog, -// Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), -// varDef) -// } -// }) + .orElse(Option.when(maybeTempVariableName(nameParts)) { + catalogManager.tempVariableManager + .get(variableName) + .map { varDef => + VariableReference( + nameParts, + FakeSystemCatalog, + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), + varDef) + } + }) } // Resolves `UnresolvedAttribute` to its value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala index b21d7724efddc..4185a43a17a80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala @@ -24,14 +24,17 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.errors.DataTypeErrorsBase +// todo LOCALVARS: move this to separate file or rename this file trait VariableManager { def create( name: String, defaultValueSQL: String, initValue: Literal, - overrideIfExists: Boolean): Unit + overrideIfExists: Boolean, + identifier: Identifier): Unit def get(name: String): Option[VariableDefinition] @@ -42,6 +45,11 @@ trait VariableManager { def isEmpty: Boolean } +case class VariableDefinition( + identifier: Identifier, + defaultValueSQL: String, + currentValue: Literal) + /** * A thread-safe manager for temporary SQL variables (that live in the schema `SYSTEM.SESSION`), * providing atomic operations to manage them, e.g. create, get, remove, etc. @@ -58,14 +66,15 @@ class TempVariableManager extends VariableManager with DataTypeErrorsBase { name: String, defaultValueSQL: String, initValue: Literal, - overrideIfExists: Boolean): Unit = synchronized { + overrideIfExists: Boolean, + identifier: Identifier): Unit = synchronized { if (!overrideIfExists && variables.contains(name)) { throw new AnalysisException( errorClass = "VARIABLE_ALREADY_EXISTS", messageParameters = Map( "variableName" -> toSQLId(Seq(SYSTEM_CATALOG_NAME, SESSION_NAMESPACE, name)))) } - variables.put(name, VariableDefinition(defaultValueSQL, initValue)) + variables.put(name, VariableDefinition(identifier, defaultValueSQL, initValue)) } override def get(name: String): Option[VariableDefinition] = synchronized { @@ -85,5 +94,3 @@ class TempVariableManager extends VariableManager with DataTypeErrorsBase { variables.isEmpty } } - -case class VariableDefinition(defaultValueSQL: String, currentValue: Literal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ae27985a3ba64..876bbd83687a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.connector.catalog.InMemoryTable +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf @@ -1483,9 +1483,9 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("Execute Immediate plan transformation") { try { SimpleAnalyzer.catalogManager.tempVariableManager.create( - "res", "1", Literal(1), overrideIfExists = true) + "res", "1", Literal(1), overrideIfExists = true, Identifier.of(Array("res"), "res")) SimpleAnalyzer.catalogManager.tempVariableManager.create( - "res2", "1", Literal(1), overrideIfExists = true) + "res2", "1", Literal(1), overrideIfExists = true, Identifier.of(Array("res"), "res")) val actual1 = parsePlan("EXECUTE IMMEDIATE 'SELECT 42 WHERE ? = 1' USING 2").analyze val expected1 = parsePlan("SELECT 42 where 2 = 1").analyze comparePlans(actual1, expected1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 71744f4d15105..678bc072c1662 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types.{StringType, StructType} @@ -87,9 +88,11 @@ trait AnalysisTest extends PlanTest { overrideIfExists = true) new Analyzer(catalog) { catalogManager.tempVariableManager.create( - "testVarA", "1", Literal(1), overrideIfExists = true) + "testVarA", "1", Literal(1), + overrideIfExists = true, Identifier.of(Array("testA"), "testVarA")) catalogManager.tempVariableManager.create( - "testVarNull", null, Literal(null, StringType), overrideIfExists = true) + "testVarNull", null, Literal(null, StringType), + overrideIfExists = true, Identifier.of(Array("testA"), "testVarA")) override val extendedResolutionRules = extendedAnalysisRules } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala index 113d506f40bd3..02251120abd61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -22,12 +22,15 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionsEvaluator, Literal} import org.apache.spark.sql.catalyst.plans.logical.DefaultValueExpression +import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.execution.datasources.v2.LeafV2CommandExec /** * Physical plan node for creating a variable. */ -case class CreateVariableExec(name: String, defaultExpr: DefaultValueExpression, replace: Boolean) +case class CreateVariableExec( + identifier: Identifier, + defaultExpr: DefaultValueExpression, replace: Boolean) extends LeafV2CommandExec with ExpressionsEvaluator { override protected def run(): Seq[InternalRow] = { @@ -38,17 +41,15 @@ case class CreateVariableExec(name: String, defaultExpr: DefaultValueExpression, initializeExprs(exprs, 0) val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) val normalizedName = if (session.sessionState.conf.caseSensitiveAnalysis) { - name + identifier.name() } else { - name.toLowerCase(Locale.ROOT) + identifier.name().toLowerCase(Locale.ROOT) } // create local variable if we are in a script, otherwise create session variable scriptingVariableManager.getOrElse(tempVariableManager) - .create(normalizedName, defaultExpr.originalSQL, initValue, replace) + .create(normalizedName, defaultExpr.originalSQL, initValue, replace, identifier) -// tempVariableManager.create( -// normalizedName, defaultExpr.originalSQL, initValue, replace) Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala index a5d90b4d154ce..a893b6fdf4976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala @@ -61,7 +61,8 @@ case class SetVariableExec(variables: Seq[VariableReference], query: SparkPlan) variable.identifier.name, variable.varDef.defaultValueSQL, Literal(value, variable.dataType), - overrideIfExists = true) + overrideIfExists = true, + variable.identifier) } override def output: Seq[Attribute] = Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala index ebc2e83e9c5fc..e75c30539e81f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/V2CommandStrategy.scala @@ -28,7 +28,7 @@ object V2CommandStrategy extends Strategy { // TODO: move v2 commands to here which are not data source v2 related. override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateVariable(ident: ResolvedIdentifier, defaultExpr, replace) => - CreateVariableExec(ident.identifier.name, defaultExpr, replace) :: Nil + CreateVariableExec(ident.identifier, defaultExpr, replace) :: Nil case DropVariable(ident: ResolvedIdentifier, ifExists) => DropVariableExec(ident.identifier.name, ifExists) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index e13600cdcb04b..7382b45da3158 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -19,18 +19,24 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.connector.catalog.Identifier - -// todo LOCALVARS: should this be thread safe / synchronized (probably not since its one per script) class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends VariableManager { override def create( name: String, defaultValueSQL: String, initValue: Literal, - overrideIfExists: Boolean): Unit = { + overrideIfExists: Boolean, + identifier: Identifier): Unit = { // todo LOCALVARS: qualified name - context.currentScope.variables.put(name, VariableDefinition(defaultValueSQL, initValue)) + context.currentScope.variables.put( + name, + VariableDefinition( + Identifier.of(Array(context.currentScope.label), name), + defaultValueSQL, + initValue + )) } override def get(name: String): Option[VariableDefinition] = { @@ -41,6 +47,7 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va } override def remove(name: String): Boolean = { + // probably throw error true } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8f3bdb60fceaf..209af5a4ad4d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -73,8 +73,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN + | lbl: BEGIN |DECLARE var = 1; |SELECT var + var * 2; + | END; |END |""".stripMargin From 47934ab73ab9288ac395086618d24b70354f7de5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 30 Dec 2024 17:16:27 +0100 Subject: [PATCH 06/59] qualified names --- .../analysis/ColumnResolutionHelper.scala | 41 +++++++++++-------- .../catalog/TempVariableManager.scala | 6 +-- .../scripting/ScriptingVariableManager.scala | 10 +++-- .../SqlScriptingInterpreterSuite.scala | 13 ++++-- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index af01ee9470beb..e0caca6062cf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -265,34 +265,41 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } - // todo LOCALVARS: should we do this for namespace as well - val variableName = if (conf.caseSensitiveAnalysis) { - nameParts.last +// val variableName = if (conf.caseSensitiveAnalysis) { +// nameParts.last +// } else { +// nameParts.last.toLowerCase(Locale.ROOT) +// } + + // todo LOCALVARS: check if we have 1 or 2 nameparts (here and maybe createvarexec) + // todo LOCALVARS: should we do this for label or only for var name + val namePartsCaseAdjusted = if (conf.caseSensitiveAnalysis) { + nameParts } else { - nameParts.last.toLowerCase(Locale.ROOT) + nameParts.map(_.toLowerCase(Locale.ROOT)) } catalogManager.scriptingLocalVariableManager - .flatMap(_.get(variableName)) + .flatMap(_.get(namePartsCaseAdjusted)) .map { varDef => VariableReference( nameParts, // todo LOCALVARS: deal with this fakesystemcatalog / session_namespace situation FakeSystemCatalog, - Identifier.of(Array(varDef.identifier.namespace().last), variableName), + Identifier.of(Array(varDef.identifier.namespace().last), namePartsCaseAdjusted.last), varDef) } - .orElse(Option.when(maybeTempVariableName(nameParts)) { - catalogManager.tempVariableManager - .get(variableName) - .map { varDef => - VariableReference( - nameParts, - FakeSystemCatalog, - Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), - varDef) - } - }) +// .orElse(Option.when(maybeTempVariableName(nameParts)) { +// catalogManager.tempVariableManager +// .get(variableName) +// .map { varDef => +// VariableReference( +// nameParts, +// FakeSystemCatalog, +// Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), +// varDef) +// } +// }) } // Resolves `UnresolvedAttribute` to its value. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala index 4185a43a17a80..328829815f460 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala @@ -36,7 +36,7 @@ trait VariableManager { overrideIfExists: Boolean, identifier: Identifier): Unit - def get(name: String): Option[VariableDefinition] + def get(nameParts: Seq[String]): Option[VariableDefinition] def remove(name: String): Boolean @@ -77,8 +77,8 @@ class TempVariableManager extends VariableManager with DataTypeErrorsBase { variables.put(name, VariableDefinition(identifier, defaultValueSQL, initValue)) } - override def get(name: String): Option[VariableDefinition] = synchronized { - variables.get(name) + override def get(nameParts: Seq[String]): Option[VariableDefinition] = synchronized { + variables.get(nameParts.last) } override def remove(name: String): Boolean = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index 7382b45da3158..8b57db336ef4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -39,11 +39,15 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va )) } - override def get(name: String): Option[VariableDefinition] = { - // todo LOCALVARS: add support for qualified name - context.currentFrame.scopes + override def get(nameParts: Seq[String]): Option[VariableDefinition] = nameParts match { + case Seq(name) => + context.currentFrame.scopes .findLast(_.variables.contains(name)) .map(_.variables(name)) + case Seq(label, name) => + context.currentFrame.scopes + .findLast(_.label == label) + .map(_.variables(name)) } override def remove(name: String): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 209af5a4ad4d7..faa09c6ac13c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -74,8 +74,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { """ |BEGIN | lbl: BEGIN - |DECLARE var = 1; - |SELECT var + var * 2; + | DECLARE var = 1; + | SELECT lbl.var; | END; |END |""".stripMargin @@ -94,8 +94,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN - |DECLARE `my.var.i.ab.le` = 1; - |SELECT `my.var.i.ab.le` + `my.var.i.ab.le` * 2; + | lbl: BEGIN + | DECLARE var = 1; + | lbl2: BEGIN + | DECLARE var = 2; + | SELECT lbl.var; + | END; + | END; |END |""".stripMargin From 399d4e86b09046fe5be4a738f4fc40c9692111e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 3 Jan 2025 13:24:23 +0100 Subject: [PATCH 07/59] update todos --- .../analysis/ColumnResolutionHelper.scala | 5 ++-- .../analysis/ResolveSetVariable.scala | 3 +-- .../analysis/resolver/ResolverGuard.scala | 1 + .../catalyst/analysis/v2ResolutionPlans.scala | 1 - .../catalog/TempVariableManager.scala | 1 - .../connector/catalog/CatalogManager.scala | 2 -- .../command/v2/CreateVariableExec.scala | 2 ++ .../scripting/ScriptingVariableManager.scala | 6 ++++- .../SqlScriptingInterpreterSuite.scala | 25 +++++++++++++++++++ 9 files changed, 36 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index e0caca6062cf8..de131ca237b08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -271,20 +271,19 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { // nameParts.last.toLowerCase(Locale.ROOT) // } - // todo LOCALVARS: check if we have 1 or 2 nameparts (here and maybe createvarexec) - // todo LOCALVARS: should we do this for label or only for var name + // todo LOCALVARS: should we do this for all nameParts or only name val namePartsCaseAdjusted = if (conf.caseSensitiveAnalysis) { nameParts } else { nameParts.map(_.toLowerCase(Locale.ROOT)) } + // todo LOCALVARS: if system.session.var check only tempVariableManager catalogManager.scriptingLocalVariableManager .flatMap(_.get(namePartsCaseAdjusted)) .map { varDef => VariableReference( nameParts, - // todo LOCALVARS: deal with this fakesystemcatalog / session_namespace situation FakeSystemCatalog, Identifier.of(Array(varDef.identifier.namespace().last), namePartsCaseAdjusted.last), varDef) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala index 52104d412ae22..4b8e414a80e0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSetVariable.scala @@ -55,8 +55,7 @@ class ResolveSetVariable(val catalogManager: CatalogManager) extends Rule[Logica // No need for case insensitive comparison here. // TODO: we need to group by the qualified variable name once other catalogs support it. - // todo LOCALVARS: the todo above, although possibly not neceesary because it might work. - // research further + // todo LOCALVARS: the todo above val dups = resolvedVars.groupBy(_.identifier.name).filter(kv => kv._2.length > 1) if (dups.nonEmpty) { throw new AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index b3b3d4def602d..b515f4d64cb93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -209,6 +209,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { // Case-sensitive inference is not supported for Hive table schema. conf.caseSensitiveInferenceMode == HiveCaseSensitiveInferenceMode.NEVER_INFER + // todo LOCALVARS: add check that no local vars exist here private def checkVariables() = catalogManager.tempVariableManager.isEmpty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index cac41c3a50608..dee78b8f03af4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -259,6 +259,5 @@ case class ResolvedIdentifier( // A fake v2 catalog to hold temp views. object FakeSystemCatalog extends CatalogPlugin { override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {} - // todo LOCALVARS: why is this here override def name(): String = "system" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala index 328829815f460..67deeb7aebe05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala @@ -89,7 +89,6 @@ class TempVariableManager extends VariableManager with DataTypeErrorsBase { variables.clear() } - // todo LOCALVARS: check what this is for with Vladimir override def isEmpty: Boolean = synchronized { variables.isEmpty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 90ad72fce30e3..70764644e3fb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -48,7 +48,6 @@ class CatalogManager( // TODO: create a real SYSTEM catalog to host `TempVariableManager` under the SESSION namespace. val tempVariableManager: TempVariableManager = new TempVariableManager - // todo LOCALVARS: should this be thread local (probably) var scriptingLocalVariableManager: Option[VariableManager] = None def catalog(name: String): CatalogPlugin = synchronized { @@ -161,7 +160,6 @@ class CatalogManager( private[sql] object CatalogManager { val SESSION_CATALOG_NAME: String = "spark_catalog" - // todo LOCALVARS: whats this val SYSTEM_CATALOG_NAME = "system" val SESSION_NAMESPACE = "session" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala index 02251120abd61..fba0a241ef1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -40,6 +40,8 @@ case class CreateVariableExec( val exprs = prepareExpressions(Seq(defaultExpr.child), subExprEliminationEnabled = false) initializeExprs(exprs, 0) val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) + + // todo LOCALVARS: should we do this for the entire identifier or only name val normalizedName = if (session.sessionState.conf.caseSensitiveAnalysis) { identifier.name() } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index 8b57db336ef4b..eb8614e81756d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -29,10 +29,12 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va initValue: Literal, overrideIfExists: Boolean, identifier: Identifier): Unit = { - // todo LOCALVARS: qualified name + // todo LOCALVARS: check for duplicate (somewhere) context.currentScope.variables.put( name, VariableDefinition( + // we use the label name of current scope as namespace for local variables + // e.g. ("scopeName.varName") Identifier.of(Array(context.currentScope.label), name), defaultValueSQL, initValue @@ -48,6 +50,8 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va context.currentFrame.scopes .findLast(_.label == label) .map(_.variables(name)) + // todo LOCALVARS: add error if not 1 or 2 nameparts + // case _ => throw error } override def remove(name: String): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index faa09c6ac13c3..7ef93db465430 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -114,6 +114,31 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // verifySqlScriptResult(sqlScript, expected) } + test("testtest3") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE var = 1; + | lbl2: BEGIN + | DECLARE var = 2; + | SELECT lbl.var; + | SET (var, lbl.var) = (select 1, 2); + | END; + | END; + |END + |""".stripMargin + + val r = spark.sql(sqlScript).collect() + + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row] // drop var + ) +// verifySqlScriptResult(sqlScript, expected) + } + // Tests test("multi statement - simple") { withTable("t") { From 6efe764c81102c8a9ba6e4bca2ea01cd9f14683d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 3 Jan 2025 15:44:34 +0100 Subject: [PATCH 08/59] resolve catalogs + check for duplicates --- .../catalyst/analysis/ResolveCatalogs.scala | 42 ++++++++++-------- .../analysis/resolver/ResolverGuard.scala | 4 +- .../catalog/TempVariableManager.scala | 7 ++- .../command/v2/CreateVariableExec.scala | 2 +- .../scripting/ScriptingVariableManager.scala | 27 ++++++++---- .../SqlScriptingInterpreterSuite.scala | 43 +++++++++++++++++++ 6 files changed, 94 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 642fbe7b00287..2d5e495ba83f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -74,28 +74,32 @@ class ResolveCatalogs(val catalogManager: CatalogManager) } private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { - def ident: Identifier = Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), nameParts.last) - // todo LOCALVARS: update to support local vars - if (nameParts.length == 1) { - ResolvedIdentifier(FakeSystemCatalog, ident) - } else if (nameParts.length == 2) { - if (nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + if (catalogManager.scriptingLocalVariableManager.isDefined && nameParts.length != 1) { + // todo LOCALVARS: create errorclass for this + throw new Exception("must be unqualified") + } + + val ident: Identifier = catalogManager.scriptingLocalVariableManager + .getOrElse(catalogManager.tempVariableManager) + .createIdentifier(nameParts.last) + + nameParts.length match { + case 1 => ResolvedIdentifier(FakeSystemCatalog, ident) - } else { - throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) - } - } else if (nameParts.length == 3) { - if (nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && - nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)) { + + case 2 if nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => ResolvedIdentifier(FakeSystemCatalog, ident) - } else { + + case 3 if nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => + ResolvedIdentifier(FakeSystemCatalog, ident) + + case _ => + // todo LOCALVARS: update this error throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) - } - } else { - throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE)) + nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE) + ) } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index b515f4d64cb93..26192c3eb2716 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -209,8 +209,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { // Case-sensitive inference is not supported for Hive table schema. conf.caseSensitiveInferenceMode == HiveCaseSensitiveInferenceMode.NEVER_INFER - // todo LOCALVARS: add check that no local vars exist here - private def checkVariables() = catalogManager.tempVariableManager.isEmpty + private def checkVariables() = catalogManager.tempVariableManager.isEmpty && + catalogManager.scriptingLocalVariableManager.forall(_.isEmpty) } object ResolverGuard { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala index 67deeb7aebe05..ea23c1cec658b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala @@ -23,8 +23,8 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} -import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.errors.DataTypeErrorsBase // todo LOCALVARS: move this to separate file or rename this file @@ -40,6 +40,8 @@ trait VariableManager { def remove(name: String): Boolean + def createIdentifier(name: String): Identifier + def clear(): Unit def isEmpty: Boolean @@ -85,6 +87,9 @@ class TempVariableManager extends VariableManager with DataTypeErrorsBase { variables.remove(name).isDefined } + override def createIdentifier(name: String): Identifier = + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), name) + override def clear(): Unit = synchronized { variables.clear() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala index fba0a241ef1ae..a36c50568a5aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -35,8 +35,8 @@ case class CreateVariableExec( override protected def run(): Seq[InternalRow] = { val scriptingVariableManager = session.sessionState.catalogManager.scriptingLocalVariableManager - val tempVariableManager = session.sessionState.catalogManager.tempVariableManager + val exprs = prepareExpressions(Seq(defaultExpr.child), subExprEliminationEnabled = false) initializeExprs(exprs, 0) val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index eb8614e81756d..a9ea35d49aa5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.scripting +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.connector.catalog.CatalogManager.SYSTEM_CATALOG_NAME import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.errors.DataTypeErrorsBase -class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends VariableManager { +class ScriptingVariableManager(context: SqlScriptingExecutionContext) + extends VariableManager with DataTypeErrorsBase { override def create( name: String, @@ -29,13 +33,16 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va initValue: Literal, overrideIfExists: Boolean, identifier: Identifier): Unit = { - // todo LOCALVARS: check for duplicate (somewhere) + if (!overrideIfExists && context.currentScope.variables.contains(name)) { + throw new AnalysisException( + errorClass = "VARIABLE_ALREADY_EXISTS", + messageParameters = Map( + "variableName" -> toSQLId(Seq(SYSTEM_CATALOG_NAME, context.currentScope.label, name)))) + } context.currentScope.variables.put( name, VariableDefinition( - // we use the label name of current scope as namespace for local variables - // e.g. ("scopeName.varName") - Identifier.of(Array(context.currentScope.label), name), + identifier, defaultValueSQL, initValue )) @@ -54,13 +61,17 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) extends Va // case _ => throw error } + override def createIdentifier(name: String): Identifier = + Identifier.of(Array(context.currentScope.label), name) + override def remove(name: String): Boolean = { // probably throw error true } - // todo LOCALVARS: do we need this - override def clear(): Unit = () + // todo LOCALVARS: create errorclass for this + override def clear(): Unit = throw new Exception("cant clear() scripting manager") - override def isEmpty: Boolean = context.currentFrame.scopes.forall(_.variables.isEmpty) + // Empty if all scopes of all frames in the script context contain no variables. + override def isEmpty: Boolean = context.frames.forall(_.scopes.forall(_.variables.isEmpty)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 7ef93db465430..bf42b398d873f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -139,6 +139,49 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // verifySqlScriptResult(sqlScript, expected) } + test("testtest4") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE var = 1; + | SELECT lbl.var; + | END; + |END + |""".stripMargin + + val r = spark.sql(sqlScript).collect() + + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row] // drop var + ) + // verifySqlScriptResult(sqlScript, expected) + } + + test("testtest5") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE var = 1; + | DECLARE var = 2; + | SELECT lbl.var; + | END; + |END + |""".stripMargin + + val r = spark.sql(sqlScript).collect() + + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row] // drop var + ) + // verifySqlScriptResult(sqlScript, expected) + } + // Tests test("multi statement - simple") { withTable("t") { From 769607d3eca7eb39a9d60129caf5c52f62b403e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 3 Jan 2025 19:55:48 +0100 Subject: [PATCH 09/59] set variable and normalized identifiers --- .../analysis/ColumnResolutionHelper.scala | 7 ------ .../command/v2/CreateVariableExec.scala | 16 +++++++++----- .../command/v2/SetVariableExec.scala | 13 ++++++----- .../scripting/ScriptingVariableManager.scala | 2 +- .../SqlScriptingInterpreterSuite.scala | 22 +++++++++++++++++++ 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index de131ca237b08..80c63d3720e19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -265,13 +265,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } } -// val variableName = if (conf.caseSensitiveAnalysis) { -// nameParts.last -// } else { -// nameParts.last.toLowerCase(Locale.ROOT) -// } - - // todo LOCALVARS: should we do this for all nameParts or only name val namePartsCaseAdjusted = if (conf.caseSensitiveAnalysis) { nameParts } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala index a36c50568a5aa..49447cf05aa99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/CreateVariableExec.scala @@ -41,16 +41,22 @@ case class CreateVariableExec( initializeExprs(exprs, 0) val initValue = Literal(exprs.head.eval(), defaultExpr.dataType) - // todo LOCALVARS: should we do this for the entire identifier or only name - val normalizedName = if (session.sessionState.conf.caseSensitiveAnalysis) { - identifier.name() + val normalizedIdentifier = if (session.sessionState.conf.caseSensitiveAnalysis) { + identifier } else { - identifier.name().toLowerCase(Locale.ROOT) + Identifier.of( + identifier.namespace().map(_.toLowerCase(Locale.ROOT)), + identifier.name().toLowerCase(Locale.ROOT)) } // create local variable if we are in a script, otherwise create session variable scriptingVariableManager.getOrElse(tempVariableManager) - .create(normalizedName, defaultExpr.originalSQL, initValue, replace, identifier) + .create( + normalizedIdentifier.name(), + defaultExpr.originalSQL, + initValue, + replace, + normalizedIdentifier) Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala index a893b6fdf4976..f98cac39ff52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/v2/SetVariableExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.TempVariableManager +import org.apache.spark.sql.catalyst.catalog.VariableManager import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, VariableReference} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.execution.SparkPlan @@ -32,11 +32,14 @@ case class SetVariableExec(variables: Seq[VariableReference], query: SparkPlan) extends V2CommandExec with UnaryLike[SparkPlan] { override protected def run(): Seq[InternalRow] = { - val variableManager = session.sessionState.catalogManager.tempVariableManager + val tempVariableManager = session.sessionState.catalogManager.tempVariableManager + val scriptingVariableManager = session.sessionState.catalogManager.scriptingLocalVariableManager + val manager = scriptingVariableManager.getOrElse(tempVariableManager) + val values = query.executeCollect() if (values.length == 0) { variables.foreach { v => - createVariable(variableManager, v, null) + createVariable(manager, v, null) } } else if (values.length > 1) { throw new SparkException( @@ -47,14 +50,14 @@ case class SetVariableExec(variables: Seq[VariableReference], query: SparkPlan) val row = values(0) variables.zipWithIndex.foreach { case (v, index) => val value = row.get(index, v.dataType) - createVariable(variableManager, v, value) + createVariable(manager, v, value) } } Seq.empty } private def createVariable( - variableManager: TempVariableManager, + variableManager: VariableManager, variable: VariableReference, value: Any): Unit = { variableManager.create( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index a9ea35d49aa5d..4ed62e92b6089 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -66,7 +66,7 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) override def remove(name: String): Boolean = { // probably throw error - true + throw new Exception("cant remove local var") } // todo LOCALVARS: create errorclass for this diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index bf42b398d873f..0d0e0d6133b28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -182,6 +182,28 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // verifySqlScriptResult(sqlScript, expected) } + test("testtest6") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE var = 1; + | SET lbl.var = 5; + | SELECT lbl.var; + | END; + |END + |""".stripMargin + + val r = spark.sql(sqlScript).collect() + + val expected = Seq( + Seq.empty[Row], // declare var + Seq(Row(1)), // select + Seq.empty[Row] // drop var + ) + // verifySqlScriptResult(sqlScript, expected) + } + // Tests test("multi statement - simple") { withTable("t") { From 622595690cc9052a3e00b7f0bca8a16b68dcd1b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 6 Jan 2025 13:42:49 +0100 Subject: [PATCH 10/59] resolve fully qualified session vars in tempvarManager only and update ResolveCatalogs to support local vars, also to throw error when creating qualified local vars --- .../analysis/ColumnResolutionHelper.scala | 4 +++- .../sql/catalyst/analysis/ResolveCatalogs.scala | 16 ++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 80c63d3720e19..1b66766df0c50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -271,8 +271,10 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { nameParts.map(_.toLowerCase(Locale.ROOT)) } - // todo LOCALVARS: if system.session.var check only tempVariableManager + // todo LOCALVARS: check if system.session.var uses only tempVariableManager catalogManager.scriptingLocalVariableManager + // if variable name is qualified with system.session. treat it as a session variable + .filterNot(_ => namePartsCaseAdjusted.take(2) == Seq("system", "session")) .flatMap(_.get(namePartsCaseAdjusted)) .map { varDef => VariableReference( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 2d5e495ba83f6..c5b9fe70349ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -35,7 +35,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) // We only support temp variables for now and the system catalog is not properly implemented // yet. We need to resolve `UnresolvedIdentifier` for variable commands specially. case c @ CreateVariable(UnresolvedIdentifier(nameParts, _), _, _) => - val resolved = resolveVariableName(nameParts) + val resolved = resolveCreateVariableName(nameParts) c.copy(name = resolved) case d @ DropVariable(UnresolvedIdentifier(nameParts, _), _) => val resolved = resolveVariableName(nameParts) @@ -73,22 +73,22 @@ class ResolveCatalogs(val catalogManager: CatalogManager) } } - private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { + private def resolveCreateVariableName(nameParts: Seq[String]): ResolvedIdentifier = { + // from scripts we can only create local variables, which must be unqualified if (catalogManager.scriptingLocalVariableManager.isDefined && nameParts.length != 1) { - // todo LOCALVARS: create errorclass for this + // todo LOCALVARS: add error class throw new Exception("must be unqualified") } + resolveVariableName(nameParts) + } + private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { val ident: Identifier = catalogManager.scriptingLocalVariableManager .getOrElse(catalogManager.tempVariableManager) .createIdentifier(nameParts.last) nameParts.length match { - case 1 => - ResolvedIdentifier(FakeSystemCatalog, ident) - - case 2 if nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => - ResolvedIdentifier(FakeSystemCatalog, ident) + case 1 | 2 => ResolvedIdentifier(FakeSystemCatalog, ident) case 3 if nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => From 241fc0521830d821a0582f195240c064e8308af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 6 Jan 2025 19:18:46 +0100 Subject: [PATCH 11/59] tests first batch --- .../analysis/ColumnResolutionHelper.scala | 25 +-- .../SqlScriptingExecutionSuite.scala | 173 ++++++++++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 11 ++ 3 files changed, 197 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 1b66766df0c50..a4e2ffb9cb595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -271,7 +271,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { nameParts.map(_.toLowerCase(Locale.ROOT)) } - // todo LOCALVARS: check if system.session.var uses only tempVariableManager catalogManager.scriptingLocalVariableManager // if variable name is qualified with system.session. treat it as a session variable .filterNot(_ => namePartsCaseAdjusted.take(2) == Seq("system", "session")) @@ -283,17 +282,19 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { Identifier.of(Array(varDef.identifier.namespace().last), namePartsCaseAdjusted.last), varDef) } -// .orElse(Option.when(maybeTempVariableName(nameParts)) { -// catalogManager.tempVariableManager -// .get(variableName) -// .map { varDef => -// VariableReference( -// nameParts, -// FakeSystemCatalog, -// Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), variableName), -// varDef) -// } -// }) + .orElse( + Option.when(maybeTempVariableName(nameParts)) { + catalogManager.tempVariableManager + .get(namePartsCaseAdjusted) + .map { varDef => + VariableReference( + nameParts, + FakeSystemCatalog, + Identifier.of(Array(CatalogManager.SESSION_NAMESPACE), namePartsCaseAdjusted.last), + varDef) + } + }.flatten + ) } // Resolves `UnresolvedAttribute` to its value. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index bbeae942f9fe7..763a2522a1b01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -1056,4 +1056,177 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("local variable") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - nested compounds") { + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | DECLARE localVar = 1; + | lbl2: BEGIN + | DECLARE localVar = 2; + | SELECT var; + | SELECT lbl1.localVar; + | SELECT lbl2.localVar; + | END; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)), // select lbl1.localVar + Seq(Row(2)) // select lbl2.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + // todo: fix this case + +// test("testtest3") { +// val sqlScript = +// """ +// |BEGIN +// | lbl: BEGIN +// | DECLARE var = 1; +// | lbl2: BEGIN +// | DECLARE var = 2; +// | SELECT lbl.var; +// | SET (var, lbl.var) = (select 1, 2); +// | END; +// | END; +// |END +// |""".stripMargin +// +// val r = spark.sql(sqlScript).collect() +// +// val expected = Seq( +// Seq.empty[Row], // declare var +// Seq(Row(1)), // select +// Seq.empty[Row] // drop var +// ) +// // verifySqlScriptResult(sqlScript, expected) +// } + + // todo: check error when it's added + +// test("testtest5") { +// val sqlScript = +// """ +// |BEGIN +// | lbl: BEGIN +// | DECLARE var = 1; +// | DECLARE var = 2; +// | SELECT lbl.var; +// | END; +// |END +// |""".stripMargin +// +// val r = spark.sql(sqlScript).collect() +// +// val expected = Seq( +// Seq.empty[Row], // declare var +// Seq(Row(1)), // select +// Seq.empty[Row] // drop var +// ) +// // verifySqlScriptResult(sqlScript, expected) +// } + + test("local variable - set qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | SET lbl.localVar = 5; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select lbl.localVar + Seq(Row(5)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set unqualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | SET localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set unqualified select qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | SET localVar = 5; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select lbl.localVar + Seq(Row(5)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + +// test("local variable - resolved over session variable") { +// withSessionVariable("localVar") { +// spark.sql("DECLARE VARIABLE localVar = 1") +// +// val sqlScript = +// """ +// |BEGIN +// | lbl: BEGIN +// | DECLARE localVar = 5; +// | SELECT localVar; +// | END; +// |END +// |""".stripMargin +// +// val expected = Seq( +// Seq(Row(5)) // select lbl.localVar +// ) +// verifySqlScriptResult(sqlScript, expected) +// } +// } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index c93f17701c620..5ad046a94c10d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -423,6 +423,17 @@ private[sql] trait SQLTestUtilsBase } } + /** + * Drops temporary variable `variableName` after calling `f`. + */ + protected def withSessionVariable(variableNames: String*)(f: => Unit): Unit = { + Utils.tryWithSafeFinally(f) { + variableNames.foreach { name => + spark.sql(s"DROP TEMPORARY VARIABLE IF EXISTS $name") + } + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns. From 068e1eca4a51598a6e7b68219f2e8bd40f0818b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 8 Jan 2025 12:59:00 +0100 Subject: [PATCH 12/59] add more tests --- .../sql/scripting/SqlScriptingExecution.scala | 3 + .../SqlScriptingExecutionSuite.scala | 161 +++++++++++++++--- 2 files changed, 143 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 070fd34ecb0d2..0f10e7d7b9722 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -74,6 +74,9 @@ class SqlScriptingExecution( if (context.frames.nonEmpty) { return Some(context.frames.last.next()) } + // cleanup variable manager after script is completed + // todo LOCALVARS: figure out a better way to do this, also cleanup when script fails + session.sessionState.catalogManager.scriptingLocalVariableManager = None None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 763a2522a1b01..e81695b945d7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkConf -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.CompoundBody import org.apache.spark.sql.internal.SQLConf @@ -1209,24 +1209,143 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } -// test("local variable - resolved over session variable") { -// withSessionVariable("localVar") { -// spark.sql("DECLARE VARIABLE localVar = 1") -// -// val sqlScript = -// """ -// |BEGIN -// | lbl: BEGIN -// | DECLARE localVar = 5; -// | SELECT localVar; -// | END; -// |END -// |""".stripMargin -// -// val expected = Seq( -// Seq(Row(5)) // select lbl.localVar -// ) -// verifySqlScriptResult(sqlScript, expected) -// } -// } + test("local variable - resolved over session variable") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - resolved over session variable nested") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | SELECT localVar; + | lbl: BEGIN + | DECLARE localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - case insensitive name") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LOCALVAR; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - case sensitive name") { + val e = intercept[AnalysisException] { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LOCALVAR; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`LOCALVAR`"), + context = ExpectedContext( + fragment = "LOCALVAR", + start = 52, + stop = 59) + ) + } + + test("local variable - case insensitive label") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LBL.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - case sensitive label") { + val e = intercept[AnalysisException] { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LBL.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`LBL`.`localVar`"), + context = ExpectedContext( + fragment = "LBL.localVar", + start = 52, + stop = 63) + ) + } } From 60335db578b81039d4a68ae58f37278e94d14e40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 8 Jan 2025 15:53:19 +0100 Subject: [PATCH 13/59] add error messages, more tests and some comments --- .../resources/error/error-conditions.json | 5 ++ .../catalyst/analysis/ResolveCatalogs.scala | 9 ++- .../scripting/ScriptingVariableManager.scala | 13 ++-- .../SqlScriptingExecutionSuite.scala | 74 +++++++++++++++---- 4 files changed, 78 insertions(+), 23 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index deb62866f072e..42e2281c568d0 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3325,6 +3325,11 @@ "message" : [ "Variable can only be declared at the beginning of the compound." ] + }, + "QUALIFIED_LOCAL_VARIABLE" : { + "message" : [ + "Variable must be declared without a qualifier." + ] } }, "sqlState" : "42K0M" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index c5b9fe70349ce..5ae61be7995a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import scala.jdk.CollectionConverters._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, SupportsNamespaces} @@ -76,8 +77,9 @@ class ResolveCatalogs(val catalogManager: CatalogManager) private def resolveCreateVariableName(nameParts: Seq[String]): ResolvedIdentifier = { // from scripts we can only create local variables, which must be unqualified if (catalogManager.scriptingLocalVariableManager.isDefined && nameParts.length != 1) { - // todo LOCALVARS: add error class - throw new Exception("must be unqualified") + throw new AnalysisException( + "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", + Map("varName" -> nameParts.mkString("."))) } resolveVariableName(nameParts) } @@ -90,6 +92,8 @@ class ResolveCatalogs(val catalogManager: CatalogManager) nameParts.length match { case 1 | 2 => ResolvedIdentifier(FakeSystemCatalog, ident) + // When there are 3 nameParts the variable must be a fully qualified session variable + // i.e. "system.session." case 3 if nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => ResolvedIdentifier(FakeSystemCatalog, ident) @@ -100,6 +104,5 @@ class ResolveCatalogs(val catalogManager: CatalogManager) nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE) ) } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index 4ed62e92b6089..6a86fb486c518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.scripting +import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} import org.apache.spark.sql.catalyst.expressions.Literal @@ -57,20 +58,20 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) context.currentFrame.scopes .findLast(_.label == label) .map(_.variables(name)) - // todo LOCALVARS: add error if not 1 or 2 nameparts - // case _ => throw error + case _ => + throw SparkException.internalError("ScriptingVariableManager.get expects 1 or 2 nameParts.") } override def createIdentifier(name: String): Identifier = Identifier.of(Array(context.currentScope.label), name) override def remove(name: String): Boolean = { - // probably throw error - throw new Exception("cant remove local var") + throw SparkException.internalError( + "ScriptingVariableManager.remove should never be called as local variables cannot be dropped." + ) } - // todo LOCALVARS: create errorclass for this - override def clear(): Unit = throw new Exception("cant clear() scripting manager") + override def clear(): Unit = context.frames.clear() // Empty if all scopes of all frames in the script context contain no variables. override def isEmpty: Boolean = context.frames.forall(_.scopes.forall(_.variables.isEmpty)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index e81695b945d7e..457ef47575129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -1063,12 +1063,14 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { |BEGIN | lbl: BEGIN | DECLARE localVar = 1; + | SELECT localVar; | SELECT lbl.localVar; | END; |END |""".stripMargin val expected = Seq( + Seq(Row(1)), // select localVar Seq(Row(1)) // select lbl.localVar ) verifySqlScriptResult(sqlScript, expected) @@ -1253,6 +1255,29 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { } } + test("local variable - session variable resolved over local if qualified") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 1") + + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 5; + | SELECT system.session.localVar; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select system.session.localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + test("local variable - case insensitive name") { val sqlScript = """ @@ -1319,24 +1344,22 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { } test("local variable - case sensitive label") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT LBL.localVar; + | END; + |END + |""".stripMargin + val e = intercept[AnalysisException] { withSQLConf(SQLConf.CASE_SENSITIVE.key -> true.toString) { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE localVar = 1; - | SELECT LBL.localVar; - | END; - |END - |""".stripMargin - - val expected = Seq( - Seq(Row(1)) // select lbl.localVar - ) - verifySqlScriptResult(sqlScript, expected) + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) } } + checkError( exception = e, condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", @@ -1348,4 +1371,27 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { stop = 63) ) } + + test("local variable - qualified declare") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE lbl.localVar = 1; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", + sqlState = "42K0M", + parameters = Map("varName" -> "lbl.localVar") + ) + } } From 65b69d391d76ecad1c08c6bc19aeb6e6207f2658 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 8 Jan 2025 16:39:45 +0100 Subject: [PATCH 14/59] rename TempVariableManager.scala and add more tests --- .../catalyst/analysis/ResolveCatalogs.scala | 3 +- ...bleManager.scala => VariableManager.scala} | 1 - .../scripting/ScriptingVariableManager.scala | 3 +- .../SqlScriptingExecutionSuite.scala | 152 ++++++++++++++---- .../SqlScriptingInterpreterSuite.scala | 22 --- 5 files changed, 128 insertions(+), 53 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/{TempVariableManager.scala => VariableManager.scala} (98%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 5ae61be7995a9..78479debbd02d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, SupportsNamespaces} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.util.ArrayImplicits._ @@ -79,7 +80,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) if (catalogManager.scriptingLocalVariableManager.isDefined && nameParts.length != 1) { throw new AnalysisException( "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", - Map("varName" -> nameParts.mkString("."))) + Map("varName" -> toSQLId(nameParts))) } resolveVariableName(nameParts) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/VariableManager.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/VariableManager.scala index ea23c1cec658b..f16859a74bdb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/TempVariableManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/VariableManager.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier} import org.apache.spark.sql.connector.catalog.CatalogManager.{SESSION_NAMESPACE, SYSTEM_CATALOG_NAME} import org.apache.spark.sql.errors.DataTypeErrorsBase -// todo LOCALVARS: move this to separate file or rename this file trait VariableManager { def create( name: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala index 6a86fb486c518..04f8bf108a02b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/ScriptingVariableManager.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{VariableDefinition, VariableManager} import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.connector.catalog.CatalogManager.SYSTEM_CATALOG_NAME import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.errors.DataTypeErrorsBase @@ -38,7 +37,7 @@ class ScriptingVariableManager(context: SqlScriptingExecutionContext) throw new AnalysisException( errorClass = "VARIABLE_ALREADY_EXISTS", messageParameters = Map( - "variableName" -> toSQLId(Seq(SYSTEM_CATALOG_NAME, context.currentScope.label, name)))) + "variableName" -> toSQLId(Seq(context.currentScope.label, name)))) } context.currentScope.variables.put( name, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 457ef47575129..13d3ec262558e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.CompoundBody +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -1125,30 +1126,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { // Seq.empty[Row] // drop var // ) // // verifySqlScriptResult(sqlScript, expected) -// } - - // todo: check error when it's added - -// test("testtest5") { -// val sqlScript = -// """ -// |BEGIN -// | lbl: BEGIN -// | DECLARE var = 1; -// | DECLARE var = 2; -// | SELECT lbl.var; -// | END; -// |END -// |""".stripMargin -// -// val r = spark.sql(sqlScript).collect() -// -// val expected = Seq( -// Seq.empty[Row], // declare var -// Seq(Row(1)), // select -// Seq.empty[Row] // drop var -// ) -// // verifySqlScriptResult(sqlScript, expected) // } test("local variable - set qualified") { @@ -1318,7 +1295,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { exception = e, condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", sqlState = "42703", - parameters = Map("objectName" -> "`LOCALVAR`"), + parameters = Map("objectName" -> toSQLId("LOCALVAR")), context = ExpectedContext( fragment = "LOCALVAR", start = 52, @@ -1364,7 +1341,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { exception = e, condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", sqlState = "42703", - parameters = Map("objectName" -> "`LBL`.`localVar`"), + parameters = Map("objectName" -> toSQLId("LBL.localVar")), context = ExpectedContext( fragment = "LBL.localVar", start = 52, @@ -1391,7 +1368,128 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { exception = e, condition = "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", sqlState = "42K0M", - parameters = Map("varName" -> "lbl.localVar") + parameters = Map("varName" -> toSQLId("lbl.localVar")) + ) + } + + test("local variable - declare var duplicate names") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | DECLARE localVar = 2; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "VARIABLE_ALREADY_EXISTS", + sqlState = "42723", + parameters = Map("variableName" -> toSQLId("lbl.localVar")) ) } + + test("local variable - leaves scope unqualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | END; + | SELECT localVar; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("localVar")), + context = ExpectedContext(fragment = "localVar", start = 76, stop = 83) + ) + } + + test("local variable - leaves scope qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | END; + | SELECT lbl.localVar; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("lbl.localVar")), + context = ExpectedContext(fragment = "lbl.localVar", start = 80, stop = 91) + ) + } + + test("local variable - leaves inner scope") { + val sqlScript = + """ + |BEGIN + | DECLARE localVar = 1; + | lbl: BEGIN + | DECLARE localVar = 2; + | SELECT localVar; + | END; + | SELECT localVar; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + +// test("local variable - inner inner scope -> inner scope -> session var") { +// withSessionVariable("localVar") { +// spark.sql("DECLARE VARIABLE localVar = 0") +// val sqlScript = +// """ +// |BEGIN +// | lbl1: BEGIN +// | DECLARE localVar = 1; +// | lbl: BEGIN +// | DECLARE localVar = 2; +// | SELECT localVar; +// | END; +// | SELECT localVar; +// | END; +// | SELECT localVar; +// |END +// |""".stripMargin +// +// val expected = Seq( +// Seq(Row(2)), // select localVar +// Seq(Row(1)), // select localVar +// Seq(Row(0)) // select localVar +// ) +// verifySqlScriptResult(sqlScript, expected) +// } +// } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 0d0e0d6133b28..2ffc91c3be194 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -160,28 +160,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { // verifySqlScriptResult(sqlScript, expected) } - test("testtest5") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE var = 1; - | DECLARE var = 2; - | SELECT lbl.var; - | END; - |END - |""".stripMargin - - val r = spark.sql(sqlScript).collect() - - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row] // drop var - ) - // verifySqlScriptResult(sqlScript, expected) - } - test("testtest6") { val sqlScript = """ From fe5dc7bad924d2d3ecf8d1ccd52fb6984bf4ff3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 8 Jan 2025 18:02:33 +0100 Subject: [PATCH 15/59] remove old logic for dropping variables, update tests and add more tests --- .../scripting/SqlScriptingExecutionNode.scala | 4 +- .../scripting/SqlScriptingInterpreter.scala | 30 +-- .../SqlScriptingExecutionSuite.scala | 136 ++++++---- .../SqlScriptingInterpreterSuite.scala | 236 +++--------------- 4 files changed, 137 insertions(+), 269 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 2d50d37e2cb83..3d8ab6796e1db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -36,7 +36,6 @@ sealed trait CompoundStatementExec extends Logging { /** * Whether the statement originates from the SQL script or is created during the interpretation. - * Example: DropVariable statements are automatically created at the end of each compound. */ val isInternal: Boolean = false @@ -113,8 +112,7 @@ trait NonLeafStatementExec extends CompoundStatementExec { * A map of parameter names to SQL literal expressions. * @param isInternal * Whether the statement originates from the SQL script or it is created during the - * interpretation. Example: DropVariable statements are automatically created at the end of each - * compound. + * interpretation. * @param context * SqlScriptingExecutionContext keeps the execution state of current script. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 7d00bbb3538df..79201843029e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} -import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} /** * SQL scripting interpreter - builds SQL script execution plan. @@ -50,19 +48,6 @@ case class SqlScriptingInterpreter(session: SparkSession) { .asInstanceOf[CompoundBodyExec] } - /** - * Fetch the name of the Create Variable plan. - * @param plan - * Plan to fetch the name from. - * @return - * Name of the variable. - */ - private def getDeclareVarNameFromPlan(plan: LogicalPlan): Option[UnresolvedIdentifier] = - plan match { - case CreateVariable(name: UnresolvedIdentifier, _, _) => Some(name) - case _ => None - } - /** * Transform the parsed tree to the executable node. * @@ -79,22 +64,11 @@ case class SqlScriptingInterpreter(session: SparkSession) { context: SqlScriptingExecutionContext): CompoundStatementExec = node match { case CompoundBody(collection, label, isScope) => - // TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing. - val variables = collection.flatMap { - case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan) - case _ => None - } - val dropVariables = variables - .map(varName => DropVariable(varName, ifExists = true)) - .map(new SingleStatementExec(_, Origin(), args, isInternal = true, context)) - .reverse - val statements = collection - .map(st => transformTreeIntoExecutable(st, args, context)) ++ dropVariables match { + .map(st => transformTreeIntoExecutable(st, args, context)) match { case Nil => Seq(new NoOpStatementExec) case s => s } - new CompoundBodyExec( statements, label, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 13d3ec262558e..a7143263ddf5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -149,20 +149,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } - test("session vars - drop var statement") { - val sqlScript = - """ - |BEGIN - |DECLARE var = 1; - |SET VAR var = var + 1; - |SELECT var; - |DROP TEMPORARY VARIABLE var; - |END - |""".stripMargin - val expected = Seq(Seq(Row(2))) - verifySqlScriptResult(sqlScript, expected) - } - test("if") { val commands = """ @@ -1085,7 +1071,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | DECLARE localVar = 1; | lbl2: BEGIN | DECLARE localVar = 2; - | SELECT var; + | SELECT localVar; | SELECT lbl1.localVar; | SELECT lbl2.localVar; | END; @@ -1392,7 +1378,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { exception = e, condition = "VARIABLE_ALREADY_EXISTS", sqlState = "42723", - parameters = Map("variableName" -> toSQLId("lbl.localVar")) + parameters = Map("variableName" -> toSQLId("lbl.localvar")) ) } @@ -1466,30 +1452,96 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } -// test("local variable - inner inner scope -> inner scope -> session var") { -// withSessionVariable("localVar") { -// spark.sql("DECLARE VARIABLE localVar = 0") -// val sqlScript = -// """ -// |BEGIN -// | lbl1: BEGIN -// | DECLARE localVar = 1; -// | lbl: BEGIN -// | DECLARE localVar = 2; -// | SELECT localVar; -// | END; -// | SELECT localVar; -// | END; -// | SELECT localVar; -// |END -// |""".stripMargin -// -// val expected = Seq( -// Seq(Row(2)), // select localVar -// Seq(Row(1)), // select localVar -// Seq(Row(0)) // select localVar -// ) -// verifySqlScriptResult(sqlScript, expected) -// } -// } + test("local variable - inner inner scope -> inner scope -> session var") { + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0") + val sqlScript = + """ + |BEGIN + | lbl1: BEGIN + | DECLARE localVar = 1; + | lbl: BEGIN + | DECLARE localVar = 2; + | SELECT localVar; + | END; + | SELECT localVar; + | END; + | SELECT localVar; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(2)), // select localVar + Seq(Row(1)), // select localVar + Seq(Row(0)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // Local variables cannot be dropped. + test("local variable - drop") { + val sqlScript = + """ + |BEGIN + |DECLARE localVar = 1; + |SELECT localVar; + |DROP TEMPORARY VARIABLE localVar; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "VARIABLE_NOT_FOUND", + parameters = Map("variableName" -> toSQLId("system.session.localVar")) + ) + } + + test("local variable - drop session variable") { + val sqlScript = + """ + |BEGIN + |DECLARE localVar = 1; + |SELECT localVar; + |DROP TEMPORARY VARIABLE localVar; + |END + |""".stripMargin + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0") + val expected = Seq( + Seq(Row(1)) + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("local variable - drop session variable successfully") { + val sqlScript = + """ + |BEGIN + |DECLARE localVar = 1; + |SELECT system.session.localVar; + |DROP TEMPORARY VARIABLE localVar; + |SELECT system.session.localVar; + |END + |""".stripMargin + withSessionVariable("localVar") { + spark.sql("DECLARE VARIABLE localVar = 0") + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> toSQLId("system.session.localVar")), + context = ExpectedContext( + fragment = "system.session.localVar", + start = 102, + stop = 124) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 2ffc91c3be194..b166b92149d28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -51,16 +51,22 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context) context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator)) executionPlan.enterScope() - - executionPlan.getTreeIterator.flatMap { - case statement: SingleStatementExec => - if (statement.isExecuted) { - None - } else { - Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) - } - case _ => None - }.toArray + spark.sessionState.catalogManager.scriptingLocalVariableManager = + Some(new ScriptingVariableManager(context)) + + try { + executionPlan.getTreeIterator.flatMap { + case statement: SingleStatementExec => + if (statement.isExecuted) { + None + } else { + Some(Dataset.ofRows(spark, statement.parsedPlan, new QueryPlanningTracker)) + } + case _ => None + }.toArray + } finally { + spark.sessionState.catalogManager.scriptingLocalVariableManager = None + } } private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { @@ -69,119 +75,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } } - test("testtest") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE var = 1; - | SELECT lbl.var; - | END; - |END - |""".stripMargin - - val r = spark.sql(sqlScript).collect() - - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row] // drop var - ) -// verifySqlScriptResult(sqlScript, expected) - } - - test("testtest2") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE var = 1; - | lbl2: BEGIN - | DECLARE var = 2; - | SELECT lbl.var; - | END; - | END; - |END - |""".stripMargin - - val r = spark.sql(sqlScript).collect() - - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row] // drop var - ) -// verifySqlScriptResult(sqlScript, expected) - } - - test("testtest3") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE var = 1; - | lbl2: BEGIN - | DECLARE var = 2; - | SELECT lbl.var; - | SET (var, lbl.var) = (select 1, 2); - | END; - | END; - |END - |""".stripMargin - - val r = spark.sql(sqlScript).collect() - - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row] // drop var - ) -// verifySqlScriptResult(sqlScript, expected) - } - - test("testtest4") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE var = 1; - | SELECT lbl.var; - | END; - |END - |""".stripMargin - - val r = spark.sql(sqlScript).collect() - - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row] // drop var - ) - // verifySqlScriptResult(sqlScript, expected) - } - - test("testtest6") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE var = 1; - | SET lbl.var = 5; - | SELECT lbl.var; - | END; - |END - |""".stripMargin - - val r = spark.sql(sqlScript).collect() - - val expected = Seq( - Seq.empty[Row], // declare var - Seq(Row(1)), // select - Seq.empty[Row] // drop var - ) - // verifySqlScriptResult(sqlScript, expected) - } - // Tests test("multi statement - simple") { withTable("t") { @@ -296,8 +189,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare var Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var + Seq(Row(2)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -314,8 +206,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare var Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row] // drop var + Seq(Row(2)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -342,14 +233,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare var Seq(Row(1)), // select - Seq.empty[Row], // drop var Seq.empty[Row], // declare var Seq(Row(2)), // select - Seq.empty[Row], // drop var Seq.empty[Row], // declare var Seq.empty[Row], // set var - Seq(Row(4)), // select - Seq.empty[Row] // drop var + Seq(Row(4)) // select ) verifySqlScriptResult(sqlScript, expected) } @@ -381,26 +269,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) } - test("session vars - drop var statement") { - val sqlScript = - """ - |BEGIN - |DECLARE var = 1; - |SET VAR var = var + 1; - |SELECT var; - |DROP TEMPORARY VARIABLE var; - |END - |""".stripMargin - val expected = Seq( - Seq.empty[Row], // declare var - Seq.empty[Row], // set var - Seq(Row(2)), // select - Seq.empty[Row], // drop var - explicit - Seq.empty[Row] // drop var - implicit - ) - verifySqlScriptResult(sqlScript, expected) - } - test("if") { val commands = """ @@ -1128,8 +996,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select i Seq.empty[Row], // set i Seq(Row(2)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop var + Seq.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -1147,8 +1014,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare i - Seq.empty[Row] // drop i + Seq.empty[Row] // declare i ) verifySqlScriptResult(commands, expected) } @@ -1184,9 +1050,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // increase j Seq(Row(1, 1)), // select i, j Seq.empty[Row], // increase j - Seq.empty[Row], // increase i - Seq.empty[Row], // drop j - Seq.empty[Row] // drop i + Seq.empty[Row] // increase i ) verifySqlScriptResult(commands, expected) } @@ -1236,8 +1100,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select i Seq.empty[Row], // set i Seq(Row(2)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop var + Seq.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -1259,8 +1122,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // declare i Seq(Row(3)), // select i - Seq.empty[Row], // set i - Seq.empty[Row] // drop i + Seq.empty[Row] // set i ) verifySqlScriptResult(commands, expected) } @@ -1334,9 +1196,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // increase j Seq(Row(1, 1)), // select i, j Seq.empty[Row], // increase j - Seq.empty[Row], // increase i - Seq.empty[Row], // drop j - Seq.empty[Row] // drop i + Seq.empty[Row] // increase i ) verifySqlScriptResult(commands, expected) } @@ -1522,8 +1382,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 0 Seq.empty[Row], // set x = 1 Seq.empty[Row], // set x = 2 - Seq(Row(2)), // select - Seq.empty[Row] // drop + Seq(Row(2)) // select ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1547,8 +1406,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 0 Seq.empty[Row], // set x = 1 Seq.empty[Row], // set x = 2 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1645,8 +1503,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select 1 Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1679,8 +1536,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 Seq(Row(2)), // select 2 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1709,8 +1565,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select 1 Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1740,8 +1595,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(2)), // select x Seq.empty[Row], // set x = 3 Seq(Row(3)), // select x - Seq(Row(3)), // select x - Seq.empty[Row] // drop + Seq(Row(3)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1783,9 +1637,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // increase y Seq(Row(1, 1)), // select x, y Seq.empty[Row], // increase y - Seq.empty[Row], // increase x - Seq.empty[Row], // drop y - Seq.empty[Row] // drop x + Seq.empty[Row] // increase x ) verifySqlScriptResult(commands, expected) } @@ -1811,8 +1663,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 0 Seq.empty[Row], // set x = 1 Seq.empty[Row], // set x = 2 - Seq(Row(2)), // select x - Seq.empty[Row] // drop + Seq(Row(2)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1861,8 +1712,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x = 2 Seq(Row(1)), // select 1 Seq.empty[Row], // set x = 3 - Seq(Row(3)), // select x - Seq.empty[Row] // drop + Seq(Row(3)) // select x ) verifySqlScriptResult(sqlScriptText, expected) } @@ -1962,8 +1812,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set sumOfCols Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq(Row(10)), // select sumOfCols - Seq.empty[Row] // drop sumOfCols + Seq(Row(10)) // select sumOfCols ) verifySqlScriptResult(sqlScript, expected) } @@ -2285,8 +2134,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select intCol Seq.empty[Row], // insert Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2374,8 +2222,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select intCol Seq.empty[Row], // insert Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2572,8 +2419,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set sumOfCols Seq.empty[Row], // set sumOfCols Seq.empty[Row], // drop local var - Seq(Row(10)), // select sumOfCols - Seq.empty[Row] // drop sumOfCols + Seq(Row(10)) // select sumOfCols ) verifySqlScriptResult(sqlScript, expected) } @@ -2804,8 +2650,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(0)), // select intCol Seq(Row(1)), // select intCol Seq.empty[Row], // insert - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2878,8 +2723,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(0)), // select intCol Seq(Row(1)), // select intCol Seq.empty[Row], // insert - Seq.empty[Row], // drop local var - Seq.empty[Row] // drop cnt + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } From 4f8d2c1d91abf27bfb8ec11423848b7d4737b146 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 9 Jan 2025 10:48:08 +0100 Subject: [PATCH 16/59] add cleanup for scripting execution, separate drop and create variable paths in resolvecatalogs --- .../catalyst/analysis/ResolveCatalogs.scala | 44 +++++++++++-------- .../org/apache/spark/sql/SparkSession.scala | 34 +++++++------- .../sql/scripting/SqlScriptingExecution.scala | 8 ++-- .../SqlScriptingExecutionSuite.scala | 29 +++++++++++- 4 files changed, 77 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 78479debbd02d..5082aacf6572f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -40,7 +40,7 @@ class ResolveCatalogs(val catalogManager: CatalogManager) val resolved = resolveCreateVariableName(nameParts) c.copy(name = resolved) case d @ DropVariable(UnresolvedIdentifier(nameParts, _), _) => - val resolved = resolveVariableName(nameParts) + val resolved = resolveDropVariableName(nameParts) d.copy(name = resolved) case UnresolvedIdentifier(nameParts, allowTemp) => @@ -82,28 +82,34 @@ class ResolveCatalogs(val catalogManager: CatalogManager) "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", Map("varName" -> toSQLId(nameParts))) } - resolveVariableName(nameParts) - } - - private def resolveVariableName(nameParts: Seq[String]): ResolvedIdentifier = { - val ident: Identifier = catalogManager.scriptingLocalVariableManager + val ident = catalogManager.scriptingLocalVariableManager .getOrElse(catalogManager.tempVariableManager) .createIdentifier(nameParts.last) - nameParts.length match { - case 1 | 2 => ResolvedIdentifier(FakeSystemCatalog, ident) + resolveVariableName(nameParts, ident) + } - // When there are 3 nameParts the variable must be a fully qualified session variable - // i.e. "system.session." - case 3 if nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && - nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => - ResolvedIdentifier(FakeSystemCatalog, ident) + private def resolveDropVariableName(nameParts: Seq[String]): ResolvedIdentifier = { + // Only session variables can be dropped, so catalogManager.scriptingLocalVariableManager + // is not checked in the case of DropVariable. + val ident = catalogManager.tempVariableManager.createIdentifier(nameParts.last) + resolveVariableName(nameParts, ident) + } - case _ => - // todo LOCALVARS: update this error - throw QueryCompilationErrors.unresolvedVariableError( - nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, CatalogManager.SESSION_NAMESPACE) - ) - } + private def resolveVariableName( + nameParts: Seq[String], + ident: Identifier): ResolvedIdentifier = nameParts.length match { + case 1 | 2 => ResolvedIdentifier(FakeSystemCatalog, ident) + + // When there are 3 nameParts the variable must be a fully qualified session variable + // i.e. "system.session." + case 3 if nameParts(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) && + nameParts(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => + ResolvedIdentifier(FakeSystemCatalog, ident) + + case _ => + throw QueryCompilationErrors.unresolvedVariableError( + nameParts, Seq(CatalogManager.SYSTEM_CATALOG_NAME, ident.namespace().head) + ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 878fdc8e267a5..13f1dda963d27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -448,24 +448,28 @@ class SparkSession private( val sse = new SqlScriptingExecution(script, this, args) var result: Option[Seq[Row]] = None - while (sse.hasNext) { - sse.withErrorHandling { - val df = sse.next() - if (sse.hasNext) { - df.write.format("noop").mode("overwrite").save() - } else { - // Collect results from the last DataFrame. - result = Some(df.collect().toSeq) + try { + while (sse.hasNext) { + sse.withErrorHandling { + val df = sse.next() + if (sse.hasNext) { + df.write.format("noop").mode("overwrite").save() + } else { + // Collect results from the last DataFrame. + result = Some(df.collect().toSeq) + } } } - } - if (result.isEmpty) { - emptyDataFrame - } else { - val attributes = DataTypeUtils.toAttributes(result.get.head.schema) - Dataset.ofRows( - self, LocalRelation.fromExternalRows(attributes, result.get)) + if (result.isEmpty) { + emptyDataFrame + } else { + val attributes = DataTypeUtils.toAttributes(result.get.head.schema) + Dataset.ofRows( + self, LocalRelation.fromExternalRows(attributes, result.get)) + } + } finally { + sse.cleanup() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala index 0f10e7d7b9722..50895a83187c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala @@ -74,9 +74,6 @@ class SqlScriptingExecution( if (context.frames.nonEmpty) { return Some(context.frames.last.next()) } - // cleanup variable manager after script is completed - // todo LOCALVARS: figure out a better way to do this, also cleanup when script fails - session.sessionState.catalogManager.scriptingLocalVariableManager = None None } @@ -115,4 +112,9 @@ class SqlScriptingExecution( handleException(e) } } + + /** Cleans up resources associated with the execution. */ + def cleanup(): Unit = { + session.sessionState.catalogManager.scriptingLocalVariableManager = None + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index a7143263ddf5a..1041e6cf8086e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -44,7 +44,11 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = { val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody] val sse = new SqlScriptingExecution(compoundBody, spark, args) - sse.map { df => df.collect() }.toList + try { + sse.map { df => df.collect() }.toList + } finally { + sse.cleanup() + } } private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = { @@ -1499,6 +1503,29 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) } + test("drop variable - drop too many nameparts") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | DROP TEMPORARY VARIABLE a.b.c.d; + | END; + |END + |""".stripMargin + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + checkError( + exception = e, + condition = "UNRESOLVED_VARIABLE", + parameters = Map( + "variableName" -> toSQLId("a.b.c.d"), + "searchPath" -> toSQLId("system.session")) + ) + } + test("local variable - drop session variable") { val sqlScript = """ From ba5b8d263cfa2a54e015a02de8d7b25ecdf507b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 9 Jan 2025 11:59:51 +0100 Subject: [PATCH 17/59] fix resolvecatalogs and add more tests --- .../catalyst/analysis/ResolveCatalogs.scala | 5 +- .../connector/catalog/CatalogManager.scala | 2 + .../parser/SqlScriptingParserSuite.scala | 12 - .../SqlScriptingExecutionSuite.scala | 272 +++++++++--------- 4 files changed, 142 insertions(+), 149 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 5082aacf6572f..d3a0d62f4c534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -99,7 +99,10 @@ class ResolveCatalogs(val catalogManager: CatalogManager) private def resolveVariableName( nameParts: Seq[String], ident: Identifier): ResolvedIdentifier = nameParts.length match { - case 1 | 2 => ResolvedIdentifier(FakeSystemCatalog, ident) + case 1 => ResolvedIdentifier(FakeSystemCatalog, ident) + + case 2 if nameParts.head.equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE) => + ResolvedIdentifier(FakeSystemCatalog, ident) // When there are 3 nameParts the variable must be a fully qualified session variable // i.e. "system.session." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index 70764644e3fb0..81db7166a6337 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.catalog import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.catalog.{SessionCatalog, TempVariableManager, VariableManager} @@ -48,6 +49,7 @@ class CatalogManager( // TODO: create a real SYSTEM catalog to host `TempVariableManager` under the SESSION namespace. val tempVariableManager: TempVariableManager = new TempVariableManager + // This field will be populated and cleaned up by SqlScriptingExecution. var scriptingLocalVariableManager: Option[VariableManager] = None def catalog(name: String): CatalogPlugin = synchronized { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index e4cc67c37dd83..c9e2f42e164f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -46,18 +46,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(!statement.isInstanceOf[CompoundBody]) } - test("testtest") { - val sqlScriptText = - """ - |BEGIN - |DECLARE `my.var.i.ab.le` = 1; - |SELECT `my.var.i.ab.le` + `my.var.i.ab.le` * 2; - |END - |""".stripMargin - val statement = parsePlan(sqlScriptText) - assert(!statement.isInstanceOf[CompoundBody]) - } - test("multi select without ; - should fail") { val sqlScriptText = "SELECT 1 SELECT 1" val e = intercept[ParseException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 1041e6cf8086e..5912d7ba3b8c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -1091,93 +1091,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } - // todo: fix this case - -// test("testtest3") { -// val sqlScript = -// """ -// |BEGIN -// | lbl: BEGIN -// | DECLARE var = 1; -// | lbl2: BEGIN -// | DECLARE var = 2; -// | SELECT lbl.var; -// | SET (var, lbl.var) = (select 1, 2); -// | END; -// | END; -// |END -// |""".stripMargin -// -// val r = spark.sql(sqlScript).collect() -// -// val expected = Seq( -// Seq.empty[Row], // declare var -// Seq(Row(1)), // select -// Seq.empty[Row] // drop var -// ) -// // verifySqlScriptResult(sqlScript, expected) -// } - - test("local variable - set qualified") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE localVar = 1; - | SELECT lbl.localVar; - | SET lbl.localVar = 5; - | SELECT lbl.localVar; - | END; - |END - |""".stripMargin - - val expected = Seq( - Seq(Row(1)), // select lbl.localVar - Seq(Row(5)) // select lbl.localVar - ) - verifySqlScriptResult(sqlScript, expected) - } - - test("local variable - set unqualified") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE localVar = 1; - | SELECT localVar; - | SET localVar = 5; - | SELECT localVar; - | END; - |END - |""".stripMargin - - val expected = Seq( - Seq(Row(1)), // select localVar - Seq(Row(5)) // select localVar - ) - verifySqlScriptResult(sqlScript, expected) - } - - test("local variable - set unqualified select qualified") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE localVar = 1; - | SELECT lbl.localVar; - | SET localVar = 5; - | SELECT lbl.localVar; - | END; - |END - |""".stripMargin - - val expected = Seq( - Seq(Row(1)), // select lbl.localVar - Seq(Row(5)) // select lbl.localVar - ) - verifySqlScriptResult(sqlScript, expected) - } - test("local variable - resolved over session variable") { withSessionVariable("localVar") { spark.sql("DECLARE VARIABLE localVar = 1") @@ -1339,53 +1252,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) } - test("local variable - qualified declare") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE lbl.localVar = 1; - | SELECT lbl.localVar; - | END; - |END - |""".stripMargin - - val e = intercept[AnalysisException] { - verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) - } - - checkError( - exception = e, - condition = "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", - sqlState = "42K0M", - parameters = Map("varName" -> toSQLId("lbl.localVar")) - ) - } - - test("local variable - declare var duplicate names") { - val sqlScript = - """ - |BEGIN - | lbl: BEGIN - | DECLARE localVar = 1; - | DECLARE localVar = 2; - | SELECT lbl.localVar; - | END; - |END - |""".stripMargin - - val e = intercept[AnalysisException] { - verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) - } - - checkError( - exception = e, - condition = "VARIABLE_ALREADY_EXISTS", - sqlState = "42723", - parameters = Map("variableName" -> toSQLId("lbl.localvar")) - ) - } - test("local variable - leaves scope unqualified") { val sqlScript = """ @@ -1483,6 +1349,53 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { } } + test("local variable - declare - qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE lbl.localVar = 1; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "INVALID_VARIABLE_DECLARATION.QUALIFIED_LOCAL_VARIABLE", + sqlState = "42K0M", + parameters = Map("varName" -> toSQLId("lbl.localVar")) + ) + } + + test("local variable - declare - duplicate names") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | DECLARE localVar = 2; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val e = intercept[AnalysisException] { + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + checkError( + exception = e, + condition = "VARIABLE_ALREADY_EXISTS", + sqlState = "42723", + parameters = Map("variableName" -> toSQLId("lbl.localvar")) + ) + } + // Local variables cannot be dropped. test("local variable - drop") { val sqlScript = @@ -1503,7 +1416,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) } - test("drop variable - drop too many nameparts") { + test("drop variable - drop - too many nameparts") { val sqlScript = """ |BEGIN @@ -1526,7 +1439,7 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) } - test("local variable - drop session variable") { + test("local variable - drop - session variable") { val sqlScript = """ |BEGIN @@ -1571,4 +1484,91 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { ) } } + + test("local variable - set - qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | SET lbl.localVar = 5; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select lbl.localVar + Seq(Row(5)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - unqualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT localVar; + | SET localVar = 5; + | SELECT localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select localVar + Seq(Row(5)) // select localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + test("local variable - set - set unqualified select qualified") { + val sqlScript = + """ + |BEGIN + | lbl: BEGIN + | DECLARE localVar = 1; + | SELECT lbl.localVar; + | SET localVar = 5; + | SELECT lbl.localVar; + | END; + |END + |""".stripMargin + + val expected = Seq( + Seq(Row(1)), // select lbl.localVar + Seq(Row(5)) // select lbl.localVar + ) + verifySqlScriptResult(sqlScript, expected) + } + + // todo: fix this case + + // test("testtest3") { + // val sqlScript = + // """ + // |BEGIN + // | lbl: BEGIN + // | DECLARE var = 1; + // | lbl2: BEGIN + // | DECLARE var = 2; + // | SELECT lbl.var; + // | SET (var, lbl.var) = (select 1, 2); + // | END; + // | END; + // |END + // |""".stripMargin + // + // val r = spark.sql(sqlScript).collect() + // + // val expected = Seq( + // Seq.empty[Row], // declare var + // Seq(Row(1)), // select + // Seq.empty[Row] // drop var + // ) + // // verifySqlScriptResult(sqlScript, expected) + // } } From 33f0aac516df0b1b7ff6c8bba3d648bbb0091a11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 9 Jan 2025 17:51:31 +0100 Subject: [PATCH 18/59] refactor to support properly setting variables --- .../analysis/ColumnResolutionHelper.scala | 4 +- .../analysis/ResolveSetVariable.scala | 1 - .../catalyst/catalog/VariableManager.scala | 31 +++++- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../sql/catalyst/analysis/AnalysisTest.scala | 8 +- .../command/v2/CreateVariableExec.scala | 5 +- .../command/v2/SetVariableExec.scala | 44 ++++++-- .../scripting/ScriptingVariableManager.scala | 41 ++++++- .../SqlScriptingExecutionSuite.scala | 103 +++++++++++++++++- 9 files changed, 206 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index a4e2ffb9cb595..185627397cd1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -272,8 +272,10 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { } catalogManager.scriptingLocalVariableManager - // if variable name is qualified with system.session. treat it as a session variable + // If variable name is qualified with system.session. treat it as a session variable. .filterNot(_ => namePartsCaseAdjusted.take(2) == Seq("system", "session")) + // Local variable must be in format or