
[SPARK-22538][ML] SQLTransformer should not unpersist possibly cached input dataset

## What changes were proposed in this pull request?

`SQLTransformer.transform` unpersists the input dataset when it drops the temporary view it created, because the user-facing `Catalog.dropTempView` also uncaches the view's underlying data. The transformer should not change the input dataset's cache status.
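The difference between the two drop paths can be illustrated with a toy model (plain Python classes standing in for Spark's `Dataset`, `SessionCatalog`, and user-facing `Catalog`; all names here are hypothetical, not the real Spark API):

```python
# Toy model of SPARK-22538: dropping a temp view through the user-facing
# catalog also unpersists the backing dataset, while the low-level
# SessionCatalog only removes the name mapping.

class Dataset:
    """Stand-in for a Spark Dataset that tracks only its cache flag."""
    def __init__(self):
        self.cached = False
    def cache(self):
        self.cached = True
    def unpersist(self):
        self.cached = False

class SessionCatalog:
    """Low-level catalog: drop_temp_view removes only the name -> data mapping."""
    def __init__(self):
        self.temp_views = {}
    def create_temp_view(self, name, ds):
        self.temp_views[name] = ds
    def drop_temp_view(self, name):
        self.temp_views.pop(name, None)

class UserFacingCatalog:
    """Stand-in for spark.catalog: drop_temp_view also uncaches the view's data."""
    def __init__(self, session_catalog):
        self._sc = session_catalog
    def drop_temp_view(self, name):
        ds = self._sc.temp_views.get(name)
        if ds is not None:
            ds.unpersist()  # this side effect clobbers the caller's cache
        self._sc.drop_temp_view(name)

if __name__ == "__main__":
    catalog = SessionCatalog()

    # Old behavior: dropping via the user-facing catalog loses the cache.
    df = Dataset()
    df.cache()
    catalog.create_temp_view("__THIS__", df)
    UserFacingCatalog(catalog).drop_temp_view("__THIS__")
    print(df.cached)   # False: input dataset was silently unpersisted

    # Fixed behavior: dropping via SessionCatalog preserves the cache.
    df2 = Dataset()
    df2.cache()
    catalog.create_temp_view("__THIS__", df2)
    catalog.drop_temp_view("__THIS__")
    print(df2.cached)  # True: cache status unchanged
```

The patch below applies exactly this idea: it switches `SQLTransformer.transform` from the user-facing drop path to the session-catalog drop path.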

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19772 from viirya/SPARK-22538.
viirya authored and cloud-fan committed Nov 17, 2017
1 parent 7d039e0 commit fccb337
Showing 2 changed files with 14 additions and 1 deletion.
@@ -70,7 +70,8 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
     dataset.createOrReplaceTempView(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
     val result = dataset.sparkSession.sql(realStatement)
-    dataset.sparkSession.catalog.dropTempView(tableName)
+    // Call SessionCatalog.dropTempView to avoid unpersisting the possibly cached dataset.
+    dataset.sparkSession.sessionState.catalog.dropTempView(tableName)
     result
   }

@@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel

class SQLTransformerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -60,4 +61,15 @@ class SQLTransformerSuite
     val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
     assert(outputSchema === expected)
   }
+
+  test("SPARK-22538: SQLTransformer should not unpersist given dataset") {
+    val df = spark.range(10)
+    df.cache()
+    df.count()
+    assert(df.storageLevel != StorageLevel.NONE)
+    new SQLTransformer()
+      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+      .transform(df)
+    assert(df.storageLevel != StorageLevel.NONE)
+  }
 }
