Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,20 @@ import org.apache.spark.sql.catalyst.plans.logical.views.DropIcebergView
import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View
import org.apache.spark.sql.catalyst.plans.logical.views.ShowIcebergViews
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_FUNCTION
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.LookupCatalog
import scala.collection.mutable

/**
* ResolveSessionCatalog exits early for some v2 View commands,
* thus they are pre-substituted here and then handled in ResolveViews
*/
case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

protected lazy val catalogManager: CatalogManager = spark.sessionState.catalogManager

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
Expand Down Expand Up @@ -83,6 +87,13 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
catalogManager.v1SessionCatalog.isTempView(nameParts)
}

/**
 * Checks whether the given name parts refer to a temporary function.
 *
 * A name with more than one part is never reported as a temporary function and the
 * session catalog is not consulted for it. Otherwise the name is converted to a function
 * identifier and looked up via the v1 session catalog.
 *
 * @param nameParts the (possibly multi-part) function name
 * @return true only when the name has at most one part and the v1 session catalog reports
 *         it as a temporary function
 */
private def isTempFunction(nameParts: Seq[String]): Boolean =
  // short-circuit keeps the catalog lookup from running for multi-part names
  nameParts.size <= 1 &&
    catalogManager.v1SessionCatalog.isTemporaryFunction(nameParts.asFunctionIdentifier)

private object ResolvedIdent {
def unapply(unresolved: UnresolvedIdentifier): Option[ResolvedIdentifier] = unresolved match {
case UnresolvedIdentifier(nameParts, true) if isTempView(nameParts) =>
Expand All @@ -102,20 +113,20 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
private def verifyTemporaryObjectsDontExist(
name: Identifier,
child: LogicalPlan): Unit = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

val tempViews = collectTemporaryViews(child)
tempViews.foreach { nameParts =>
throw new AnalysisException(
errorClass = "INVALID_TEMP_OBJ_REFERENCE",
messageParameters = Map(
"obj" -> "VIEW",
"objName" -> name.name(),
"tempObj" -> "VIEW",
"tempObjName" -> nameParts.quoted))
if (tempViews.nonEmpty) {
throw invalidRefToTempObject(name, tempViews.map(v => v.quoted).mkString("[", ", ", "]"), "view")
}

// TODO: check for temp function names
val tempFunctions = collectTemporaryFunctions(child)
if (tempFunctions.nonEmpty) {
throw invalidRefToTempObject(name, tempFunctions.mkString("[", ", ", "]"), "function")
}
}

private def invalidRefToTempObject(name: Identifier, tempObjectNames: String, tempObjectType: String) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this name is missing the catalog, but this is minor and can be fixed in a follow-up.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've opened #9807 to include the catalog name in the error msg

new AnalysisException(String.format("Cannot create view %s that references temporary %s: %s",
name, tempObjectType, tempObjectNames))
}

/**
Expand Down Expand Up @@ -149,4 +160,20 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
None
}
}

/**
* Collect the names of all temporary functions referenced by the given plan.
*
* The plan's expressions are visited with resolveExpressionsWithPruning, limited to
* expressions that contain the UNRESOLVED_FUNCTION pattern. Every matched expression is
* handed back unchanged and the result of the traversal is discarded, so the traversal
* is only used to record names.
*
* Names are gathered in a hash set, so a function referenced several times is reported
* once; the order of the returned names should not be relied upon.
*
* @param child the plan to inspect
* @return the names of the temporary functions found in the plan and in the plans of its
*         subquery expressions
*/
private def collectTemporaryFunctions(child: LogicalPlan): Seq[String] = {
val tempFunctions = new mutable.HashSet[String]()
child.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
// isTempFunction rejects names with more than one part, so head is the whole name here
case f @ UnresolvedFunction(nameParts, _, _, _, _) if isTempFunction(nameParts) =>
tempFunctions += nameParts.head
f
// descend into the subquery's own plan and merge whatever it references
case e: SubqueryExpression =>
tempFunctions ++= collectTemporaryFunctions(e.plan)
e
}
tempFunctions.toSeq
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,39 @@ public void readFromViewReferencingGlobalTempView() throws NoSuchTableException
.hasMessageContaining("cannot be found");
}

@Test
public void readFromViewReferencingTempFunction() throws NoSuchTableException {
insertRows(10);
String viewName = viewName("viewReferencingTempFunction");
String functionName = "test_avg";
String sql = String.format("SELECT %s(id) FROM %s", functionName, tableName);
sql(
"CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
functionName);

ViewCatalog viewCatalog = viewCatalog();
Schema schema = tableCatalog().loadTable(TableIdentifier.of(NAMESPACE, tableName)).schema();

// it wouldn't be possible to reference a TEMP FUNCTION if the view had been created via SQL,
// but this can't be prevented when using the API directly
viewCatalog
.buildView(TableIdentifier.of(NAMESPACE, viewName))
.withQuery("spark", sql)
.withDefaultNamespace(NAMESPACE)
.withDefaultCatalog(catalogName)
.withSchema(schema)
.create();

assertThat(sql(sql)).hasSize(1).containsExactly(row(5.5));

// reading from a view that references a TEMP FUNCTION shouldn't be possible
assertThatThrownBy(() -> sql("SELECT * FROM %s", viewName))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("The function")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it is easy to fix, but this message could also be improved using our standard format: Cannot load function: %s where %s is the fully-qualified name, including the catalog name so that it is clear that it was resolved to a catalog function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to do that we could pass a custom error msg in https://github.com/apache/iceberg/blob/main/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/SupportsFunctions.java#L61 instead of relying on Spark's default error msg. But that's probably something we'd want to handle in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. A separate PR to fix the error message.

.hasMessageContaining(functionName)
.hasMessageContaining("cannot be found");
}

@Test
public void readFromViewWithCTE() throws NoSuchTableException {
insertRows(10);
Expand Down Expand Up @@ -947,9 +980,9 @@ public void createViewReferencingTempView() throws NoSuchTableException {
assertThatThrownBy(
() -> sql("CREATE VIEW %s AS SELECT id FROM %s", viewReferencingTempView, tempView))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("Cannot create the persistent object")
.hasMessageContaining(viewReferencingTempView)
.hasMessageContaining("of the type VIEW because it references to the temporary object")
.hasMessageContaining(
String.format("Cannot create view %s.%s", NAMESPACE, viewReferencingTempView))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(tempView);
}

Expand All @@ -970,10 +1003,59 @@ public void createViewReferencingGlobalTempView() throws NoSuchTableException {
"CREATE VIEW %s AS SELECT id FROM global_temp.%s",
viewReferencingTempView, globalTempView))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("Cannot create the persistent object")
.hasMessageContaining(viewReferencingTempView)
.hasMessageContaining("of the type VIEW because it references to the temporary object")
.hasMessageContaining(globalTempView);
.hasMessageContaining(
String.format("Cannot create view %s.%s", NAMESPACE, viewReferencingTempView))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView));
}

@Test
public void createViewReferencingTempFunction() {
  String tempFunction = "test_avg_func";
  String view = viewName("viewReferencingTemporaryFunction");
  String expectedPrefix = String.format("Cannot create view %s.%s", NAMESPACE, view);

  // register the temporary function that the view definition will try to use
  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  // a persisted view must not be allowed to depend on a TEMP FUNCTION
  assertThatThrownBy(
          () -> sql("CREATE VIEW %s AS SELECT %s(id) FROM %s", view, tempFunction, tableName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining(expectedPrefix)
      .hasMessageContaining("that references temporary function:")
      .hasMessageContaining(tempFunction);
}

@Test
public void createViewReferencingQualifiedTempFunction() {
  String tempFunction = "test_avg_func_qualified";
  String view = viewName("viewReferencingTemporaryFunction");

  // quoted forms of the qualified names as they are expected to appear in the error message
  String threePartName =
      String.format("`%s`.`%s`.`%s`", catalogName, NAMESPACE, tempFunction);
  String twoPartName = String.format("`%s`.`%s`", NAMESPACE, tempFunction);

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  // referencing the TEMP function as catalog.schema.name must fail to resolve
  assertThatThrownBy(
          () ->
              sql(
                  "CREATE VIEW %s AS SELECT %s.%s.%s(id) FROM %s",
                  view, catalogName, NAMESPACE, tempFunction, tableName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining("Cannot resolve function")
      .hasMessageContaining(threePartName);

  // referencing the TEMP function as schema.name must fail to resolve as well
  assertThatThrownBy(
          () ->
              sql(
                  "CREATE VIEW %s AS SELECT %s.%s(id) FROM %s",
                  view, NAMESPACE, tempFunction, tableName))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining("Cannot resolve function")
      .hasMessageContaining(twoPartName);
}

@Test
Expand Down Expand Up @@ -1118,12 +1200,32 @@ public void createViewWithCTEReferencingTempView() {

assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining("Cannot create the persistent object")
.hasMessageContaining(viewName)
.hasMessageContaining("of the type VIEW because it references to the temporary object")
.hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(tempViewInCTE);
}

@Test
public void createViewWithCTEReferencingTempFunction() {
  // built through the viewName(...) helper, as the other temp-function tests in this file do
  String viewName = viewName("viewWithCTEReferencingTempFunction");
  String functionName = "avg_function_in_cte";
  // The CTE only exposes the column "avg", so the outer query must group by "avg".
  // Grouping by a non-existent column would make the statement invalid for a reason
  // unrelated to the temporary function this test is about.
  String sql =
      String.format(
          "WITH avg_data AS (SELECT %s(id) as avg FROM %s) "
              + "SELECT avg, count(1) AS count FROM avg_data GROUP BY avg",
          functionName, tableName);

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      functionName);

  // creating a view whose CTE references a TEMP FUNCTION shouldn't be possible
  assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
      .hasMessageContaining("that references temporary function:")
      .hasMessageContaining(functionName);
}

@Test
public void createViewWithNonExistingQueryColumn() {
assertThatThrownBy(
Expand All @@ -1147,9 +1249,9 @@ public void createViewWithSubqueryExpressionUsingTempView() {

assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining(String.format("Cannot create the persistent object %s", viewName))
.hasMessageContaining(
String.format("because it references to the temporary object %s", tempView));
.hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(tempView);
}

@Test
Expand All @@ -1167,10 +1269,29 @@ public void createViewWithSubqueryExpressionUsingGlobalTempView() {

assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", viewName, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageContaining(String.format("Cannot create the persistent object %s", viewName))
.hasMessageContaining(
String.format(
"because it references to the temporary object global_temp.%s", globalTempView));
.hasMessageContaining(String.format("Cannot create view %s.%s", NAMESPACE, viewName))
.hasMessageContaining("that references temporary view:")
.hasMessageContaining(String.format("%s.%s", "global_temp", globalTempView));
}

@Test
public void createViewWithSubqueryExpressionUsingTempFunction() {
  String tempFunction = "avg_function_in_subquery";
  String view = viewName("viewWithSubqueryExpression");

  // the temporary function is only referenced from inside the scalar subquery
  String query =
      String.format(
          "SELECT * FROM %s WHERE id < (SELECT %s(id) FROM %s)",
          tableName, tempFunction, tableName);
  String expectedPrefix = String.format("Cannot create view %s.%s", NAMESPACE, view);

  sql(
      "CREATE TEMPORARY FUNCTION %s AS 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFAverage'",
      tempFunction);

  // view creation must be rejected even though the reference is nested in a subquery
  assertThatThrownBy(() -> sql("CREATE VIEW %s AS %s", view, query))
      .isInstanceOf(AnalysisException.class)
      .hasMessageContaining(expectedPrefix)
      .hasMessageContaining("that references temporary function:")
      .hasMessageContaining(tempFunction);
}

@Test
Expand Down