From 06c741a0061bcf2c6e2c08212cab9f4e774cb70a Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Fri, 23 Feb 2024 09:26:13 -0800
Subject: [PATCH] [SPARK-47129][CONNECT][SQL] Make `ResolveRelations` cache
 connect plan properly

### What changes were proposed in this pull request?

Make `ResolveRelations` handle the plan id properly.

### Why are the changes needed?

Bug fix for Spark Connect; it does not affect classic Spark SQL. Before this PR:

```python
from pyspark.sql import functions as sf

spark.range(10).withColumn("value_1", sf.lit(1)).write.saveAsTable("test_table_1")
spark.range(10).withColumnRenamed("id", "index").withColumn("value_2", sf.lit(2)).write.saveAsTable("test_table_2")

df1 = spark.read.table("test_table_1")
df2 = spark.read.table("test_table_2")
df3 = spark.read.table("test_table_1")

join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
join2 = df3.join(join1, how="left", on=join1.index == df3.id)

join2.schema
```

fails with

```
AnalysisException: [CANNOT_RESOLVE_DATAFRAME_COLUMN] Cannot resolve dataframe column "id". It's probably because of illegal references like `df1.select(df2.col("a"))`. SQLSTATE: 42704
```

That is because the existing plan caching in `ResolveRelations` does not work with Spark Connect:

```
=== Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations ===
 '[#12]Join LeftOuter, '`==`('index, 'id)                      '[#12]Join LeftOuter, '`==`('index, 'id)
!:- '[#9]UnresolvedRelation [test_table_1], [], false          :- '[#9]SubqueryAlias spark_catalog.default.test_table_1
!+- '[#11]Project ['index, 'value_2]                           :  +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
!   +- '[#10]Join Inner, '`==`('id, 'index)                    +- '[#11]Project ['index, 'value_2]
!      :- '[#7]UnresolvedRelation [test_table_1], [], false       +- '[#10]Join Inner, '`==`('id, 'index)
!      +- '[#8]UnresolvedRelation [test_table_2], [], false          :- '[#9]SubqueryAlias spark_catalog.default.test_table_1
!                                                                    :  +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
!                                                                    +- '[#8]SubqueryAlias spark_catalog.default.test_table_2
!                                                                       +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_2`, [], false

Can not resolve 'id with plan 7
```

`[#7]UnresolvedRelation [test_table_1], [], false` was wrongly resolved to the cached plan

```
:- '[#9]SubqueryAlias spark_catalog.default.test_table_1
   +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
```

which already carries plan id 9 (from resolving `[#9]`), so after substitution no node with plan id 7 remains and `df1.id` can no longer be resolved.
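In short, the fix re-stamps whatever `ResolveRelations` returns, whether a `relationCache` hit or a freshly loaded table, with the plan id of the `UnresolvedRelation` currently being resolved. A minimal sketch of that idea, assuming Spark's catalyst classes are on the classpath (`withPlanIdOf` is a hypothetical helper, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object PlanIdPropagation {
  // Hypothetical helper (not in the patch): reuse `resolved` as the answer for
  // the UnresolvedRelation `u`, carrying over u's Spark Connect plan id if any.
  def withPlanIdOf(u: LogicalPlan, resolved: LogicalPlan): LogicalPlan =
    u.getTagValue(LogicalPlan.PLAN_ID_TAG) match {
      case Some(planId) =>
        val copy = resolved.clone() // never tag the instance shared via the cache
        copy.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
        copy
      case None => resolved // classic SQL path: no plan id to propagate
    }
}
```

The patch below inlines this pattern at the two places a relation can come from: the cache hit and the freshly loaded table.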
### Does this PR introduce _any_ user-facing change?

Yes, bug fix.

### How was this patch tested?

Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #45214 from zhengruifeng/connect_fix_read_join.

Authored-by: Ruifeng Zheng
Signed-off-by: Dongjoon Hyun
---
 python/pyspark/sql/tests/test_readwriter.py | 23 +++++++++++++++-
 .../sql/catalyst/analysis/Analyzer.scala    | 27 ++++++++++++++-----
 2 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 70a320fc53b69..85057f37a1817 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -20,7 +20,7 @@
 import tempfile
 
 from pyspark.errors import AnalysisException
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, lit
 from pyspark.sql.readwriter import DataFrameWriterV2
 from pyspark.sql.types import StructType, StructField, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -181,6 +181,27 @@ def test_insert_into(self):
         df.write.mode("overwrite").insertInto("test_table", False)
         self.assertEqual(6, self.spark.sql("select * from test_table").count())
 
+    def test_cached_table(self):
+        with self.table("test_cached_table_1"):
+            self.spark.range(10).withColumn(
+                "value_1",
+                lit(1),
+            ).write.saveAsTable("test_cached_table_1")
+
+            with self.table("test_cached_table_2"):
+                self.spark.range(10).withColumnRenamed("id", "index").withColumn(
+                    "value_2", lit(2)
+                ).write.saveAsTable("test_cached_table_2")
+
+                df1 = self.spark.read.table("test_cached_table_1")
+                df2 = self.spark.read.table("test_cached_table_2")
+                df3 = self.spark.read.table("test_cached_table_1")
+
+                join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
+                join2 = df3.join(join1, how="left", on=join1.index == df3.id)
+
+                self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"])
+
 
 class ReadwriterV2TestsMixin:
     def test_api(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d8127fe03da4e..1fb5d00bdf39a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1275,16 +1275,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         val key =
           ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq,
             finalTimeTravelSpec)
-        AnalysisContext.get.relationCache.get(key).map(_.transform {
-          case multi: MultiInstanceRelation =>
-            val newRelation = multi.newInstance()
-            newRelation.copyTagsFrom(multi)
-            newRelation
-        }).orElse {
+        AnalysisContext.get.relationCache.get(key).map { cache =>
+          val cachedRelation = cache.transform {
+            case multi: MultiInstanceRelation =>
+              val newRelation = multi.newInstance()
+              newRelation.copyTagsFrom(multi)
+              newRelation
+          }
+          u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+            val cachedConnectRelation = cachedRelation.clone()
+            cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+            cachedConnectRelation
+          }.getOrElse(cachedRelation)
+        }.orElse {
           val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
           val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
           loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
-          loaded
+          u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+            loaded.map { loadedRelation =>
+              val loadedConnectRelation = loadedRelation.clone()
+              loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+              loadedConnectRelation
+            }
+          }.getOrElse(loaded)
         }
       case _ => None
     }
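A note on why the patch calls `clone()` before `setTagValue`: `TreeNodeTag` values are stored per plan-node instance, so tagging the cached plan in place would leak one query's plan id into every later query that resolves the same table from the cache. A small self-contained sketch of that behavior (the plan shape and object name are illustrative only, not from the patch):

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, SubqueryAlias}

object PlanIdTagDemo extends App {
  // Stand-in for a relation held in AnalysisContext's relationCache.
  val cached: LogicalPlan = SubqueryAlias("test_table_1", LocalRelation())

  // The patch's pattern: clone first, then stamp the caller's plan id on the copy.
  val tagged = cached.clone()
  tagged.setTagValue(LogicalPlan.PLAN_ID_TAG, 7L)

  assert(cached.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty)      // cached plan stays untouched
  assert(tagged.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(7L)) // copy carries the caller's id
}
```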