[SPARK-47129][CONNECT][SQL] Make ResolveRelations cache connect plan properly

### What changes were proposed in this pull request?
Make `ResolveRelations` handle plan id properly

### Why are the changes needed?
Bug fix for Spark Connect; it does not affect classic Spark SQL.

Before this PR:
```
from pyspark.sql import functions as sf

spark.range(10).withColumn("value_1", sf.lit(1)).write.saveAsTable("test_table_1")
spark.range(10).withColumnRenamed("id", "index").withColumn("value_2", sf.lit(2)).write.saveAsTable("test_table_2")

df1 = spark.read.table("test_table_1")
df2 = spark.read.table("test_table_2")
df3 = spark.read.table("test_table_1")

join1 = df1.join(df2, on=df1.id==df2.index).select(df2.index, df2.value_2)
join2 = df3.join(join1, how="left", on=join1.index==df3.id)

join2.schema
```

fails with
```
AnalysisException: [CANNOT_RESOLVE_DATAFRAME_COLUMN] Cannot resolve dataframe column "id". It's probably because of illegal references like `df1.select(df2.col("a"))`. SQLSTATE: 42704
```

That is because the existing plan caching in `ResolveRelations` does not work with Spark Connect's plan ids:

```
=== Applying Rule org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRelations ===
 '[#12]Join LeftOuter, '`==`('index, 'id)                     '[#12]Join LeftOuter, '`==`('index, 'id)
!:- '[#9]UnresolvedRelation [test_table_1], [], false         :- '[#9]SubqueryAlias spark_catalog.default.test_table_1
!+- '[#11]Project ['index, 'value_2]                          :  +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
!   +- '[#10]Join Inner, '`==`('id, 'index)                   +- '[#11]Project ['index, 'value_2]
!      :- '[#7]UnresolvedRelation [test_table_1], [], false      +- '[#10]Join Inner, '`==`('id, 'index)
!      +- '[#8]UnresolvedRelation [test_table_2], [], false         :- '[#9]SubqueryAlias spark_catalog.default.test_table_1
!                                                                   :  +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
!                                                                   +- '[#8]SubqueryAlias spark_catalog.default.test_table_2
!                                                                      +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_2`, [], false

Can not resolve 'id with plan 7
```

Here `[#7]UnresolvedRelation [test_table_1], [], false` was wrongly resolved to the cached plan, which carries plan id 9 rather than 7:
```
:- '[#9]SubqueryAlias spark_catalog.default.test_table_1
   +- 'UnresolvedCatalogRelation `spark_catalog`.`default`.`test_table_1`, [], false
```
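The fix (see the `Analyzer.scala` diff below) stamps the relation returned from the cache with the plan id of the `UnresolvedRelation` being resolved, so each Connect DataFrame gets a copy tagged with its own plan id. A minimal sketch of the pattern, using the tag API that appears in the diff (`u` is the `UnresolvedRelation` under resolution, `cachedRelation` a hit from the relation cache):

```scala
// Sketch of the re-tagging pattern from this fix; not the verbatim patch.
u.getTagValue(LogicalPlan.PLAN_ID_TAG) match {
  case Some(planId) =>
    // Clone so the shared cache entry itself is never mutated.
    val tagged = cachedRelation.clone()
    tagged.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
    tagged
  case None =>
    cachedRelation // classic SQL: no plan id attached, reuse the cache hit as-is
}
```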

### Does this PR introduce _any_ user-facing change?
Yes, bug fix.

### How was this patch tested?
Added a unit test (`test_cached_table` in `test_readwriter.py`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#45214 from zhengruifeng/connect_fix_read_join.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
zhengruifeng authored and dongjoon-hyun committed Feb 23, 2024
1 parent 3baa60a commit 06c741a
Showing 2 changed files with 42 additions and 8 deletions.
23 changes: 22 additions & 1 deletion python/pyspark/sql/tests/test_readwriter.py
```diff
@@ -20,7 +20,7 @@
 import tempfile
 
 from pyspark.errors import AnalysisException
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, lit
 from pyspark.sql.readwriter import DataFrameWriterV2
 from pyspark.sql.types import StructType, StructField, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -181,6 +181,27 @@ def test_insert_into(self):
         df.write.mode("overwrite").insertInto("test_table", False)
         self.assertEqual(6, self.spark.sql("select * from test_table").count())
 
+    def test_cached_table(self):
+        with self.table("test_cached_table_1"):
+            self.spark.range(10).withColumn(
+                "value_1",
+                lit(1),
+            ).write.saveAsTable("test_cached_table_1")
+
+            with self.table("test_cached_table_2"):
+                self.spark.range(10).withColumnRenamed("id", "index").withColumn(
+                    "value_2", lit(2)
+                ).write.saveAsTable("test_cached_table_2")
+
+                df1 = self.spark.read.table("test_cached_table_1")
+                df2 = self.spark.read.table("test_cached_table_2")
+                df3 = self.spark.read.table("test_cached_table_1")
+
+                join1 = df1.join(df2, on=df1.id == df2.index).select(df2.index, df2.value_2)
+                join2 = df3.join(join1, how="left", on=join1.index == df3.id)
+
+                self.assertEqual(join2.columns, ["id", "value_1", "index", "value_2"])
+
 
 class ReadwriterV2TestsMixin:
     def test_api(self):
```
27 changes: 20 additions & 7 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
```diff
@@ -1275,16 +1275,29 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         val key =
           ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq,
             finalTimeTravelSpec)
-        AnalysisContext.get.relationCache.get(key).map(_.transform {
-          case multi: MultiInstanceRelation =>
-            val newRelation = multi.newInstance()
-            newRelation.copyTagsFrom(multi)
-            newRelation
-        }).orElse {
+        AnalysisContext.get.relationCache.get(key).map { cache =>
+          val cachedRelation = cache.transform {
+            case multi: MultiInstanceRelation =>
+              val newRelation = multi.newInstance()
+              newRelation.copyTagsFrom(multi)
+              newRelation
+          }
+          u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+            val cachedConnectRelation = cachedRelation.clone()
+            cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+            cachedConnectRelation
+          }.getOrElse(cachedRelation)
+        }.orElse {
           val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
           val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
           loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
-          loaded
+          u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
+            loaded.map { loadedRelation =>
+              val loadedConnectRelation = loadedRelation.clone()
+              loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
+              loadedConnectRelation
+            }
+          }.getOrElse(loaded)
         }
       case _ => None
     }
```
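For readability, here is the cache-lookup branch as it reads after this patch, assembled from the `+` lines above; the comments are editorial, not part of the commit:

```scala
AnalysisContext.get.relationCache.get(key).map { cache =>
  // Re-instantiate multi-instance relations so attribute ids stay unique.
  val cachedRelation = cache.transform {
    case multi: MultiInstanceRelation =>
      val newRelation = multi.newInstance()
      newRelation.copyTagsFrom(multi)
      newRelation
  }
  // Spark Connect: tag a private clone with the caller's plan id so that
  // column resolution maps df.col back to this relation, not the cached one.
  u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
    val cachedConnectRelation = cachedRelation.clone()
    cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
    cachedConnectRelation
  }.getOrElse(cachedRelation)
}.orElse {
  // Cache miss: load the table, cache it, and tag the result the same way.
  val table = CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec)
  val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming)
  loaded.foreach(AnalysisContext.get.relationCache.update(key, _))
  u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId =>
    loaded.map { loadedRelation =>
      val loadedConnectRelation = loadedRelation.clone()
      loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId)
      loadedConnectRelation
    }
  }.getOrElse(loaded)
}
```

Classic Spark SQL plans carry no `PLAN_ID_TAG`, so both `getOrElse` branches fall back to the previous behavior there.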
