[SPARK-42519][CONNECT][TESTS] Add More WriteTo Tests In Spark Connect Client

### What changes were proposed in this pull request?
Add more WriteTo tests for the Spark Connect client.

### Why are the changes needed?
Improve test coverage and remove some TODOs.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Add new tests
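
For reference, the added tests exercise the DataFrameWriterV2 surface of the client against the in-memory `testcat` catalog. A minimal sketch of the calls they cover (not code from this PR; `spark` is the Connect session and the table name is illustrative):

    import org.apache.spark.sql.functions

    val df = spark.range(3).toDF("id")
    df.writeTo("testcat.t").create()                           // create the table from the DataFrame
    df.writeTo("testcat.t").append()                           // append rows to the existing table
    df.writeTo("testcat.t").overwrite(functions.expr("true"))  // overwrite rows matching a condition
    df.writeTo("testcat.t").overwritePartitions()              // dynamic partition overwrite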

Closes #40564 from Hisoka-X/connec_test.

Authored-by: Hisoka <fanjiaeminem@qq.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit 4199325)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Hisoka-X authored and HyukjinKwon committed Apr 3, 2023
1 parent 807abf9 commit beb8928
Showing 6 changed files with 118 additions and 39 deletions.

@@ -220,41 +220,109 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
     }
   }
 
+  test("writeTo with create") {
+    withTable("testcat.myTableV2") {
+
+      val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").create()
+
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+    }
+  }
+
   test("writeTo with create and using") {
-    // TODO (SPARK-42519): Add more test after we can set configs. See more WriteTo test cases
-    // in SparkConnectProtoSuite.
-    // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
-    withTable("myTableV2") {
-      spark.range(3).writeTo("myTableV2").using("parquet").create()
-      val result = spark.sql("select * from myTableV2").sort("id").collect()
-      assert(result.length == 3)
-      assert(result(0).getLong(0) == 0)
-      assert(result(1).getLong(0) == 1)
-      assert(result(2).getLong(0) == 2)
+    withTable("testcat.myTableV2") {
+      val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").create()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+
+      val columns = spark.table("testcat.myTableV2").columns
+      assert(columns.length == 2)
+
+      val sqlOutputRows = spark.sql("select * from testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+      assert(sqlOutputRows(0).schema == schema)
+      assert(sqlOutputRows(1).getString(1) == "b")
     }
   }
 
-  // TODO (SPARK-42519): Revisit this test after we can set configs.
-  // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
   test("writeTo with create and append") {
-    withTable("myTableV2") {
-      spark.range(3).writeTo("myTableV2").using("parquet").create()
-      withTable("myTableV2") {
-        assertThrows[StatusRuntimeException] {
-          // Failed to append as Cannot write into v1 table: `spark_catalog`.`default`.`mytablev2`.
-          spark.range(3).writeTo("myTableV2").append()
-        }
+    withTable("testcat.myTableV2") {
+
+      val rows = Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.sql("CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo")
+
+      assert(spark.table("testcat.myTableV2").collect().isEmpty)
+
+      spark.createDataFrame(rows.asJava, schema).writeTo("testcat.myTableV2").append()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
     }
   }
 
+  test("WriteTo with overwrite") {
+    withTable("testcat.myTableV2") {
+
+      val rows1 = (1L to 3L).map { i =>
+        Row(i, "" + (i - 1 + 'a'))
+      }
+      val rows2 = (4L to 7L).map { i =>
+        Row(i, "" + (i - 1 + 'a'))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.sql(
+        "CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+      assert(spark.table("testcat.myTableV2").collect().isEmpty)
+
+      spark.createDataFrame(rows1.asJava, schema).writeTo("testcat.myTableV2").append()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 3)
+
+      spark
+        .createDataFrame(rows2.asJava, schema)
+        .writeTo("testcat.myTableV2")
+        .overwrite(functions.expr("true"))
+      val outputRows2 = spark.table("testcat.myTableV2").collect()
+      assert(outputRows2.length == 4)
+
+    }
+  }
+
-  // TODO (SPARK-42519): Revisit this test after we can set configs.
-  // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
-  test("writeTo with create") {
-    assume(IntegrationTestUtils.isSparkHiveJarAvailable)
-    withTable("myTableV2") {
-      // Failed to create as Hive support is required.
-      spark.range(3).writeTo("myTableV2").create()
+  test("WriteTo with overwritePartitions") {
+    withTable("testcat.myTableV2") {
+
+      val rows = (4L to 7L).map { i =>
+        Row(i, "" + (i - 1 + 'a'))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+
+      spark.sql(
+        "CREATE TABLE testcat.myTableV2 (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+      assert(spark.table("testcat.myTableV2").collect().isEmpty)
+
+      spark
+        .createDataFrame(rows.asJava, schema)
+        .writeTo("testcat.myTableV2")
+        .overwritePartitions()
+      val outputRows = spark.table("testcat.myTableV2").collect()
+      assert(outputRows.length == 4)
+
     }
   }
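
The tests above resolve `testcat.*` table identifiers only because the Connect server registers that catalog at startup (see the SparkConnectServerUtils change below). A rough local-mode equivalent of that setup, assuming the catalyst test jar that provides InMemoryTableCatalog is on the classpath, would look like:

    import org.apache.spark.sql.SparkSession

    // Register an in-memory v2 catalog named "testcat"; the implementation class ships in
    // catalyst's test jar, so that jar must be on the classpath (an assumption of this sketch).
    val session = SparkSession
      .builder()
      .master("local[*]")
      .config(
        "spark.sql.catalog.testcat",
        "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
      .getOrCreate()

    // "USING foo" works because the in-memory catalog does not act on the provider name.
    session.sql("CREATE TABLE testcat.demo (id bigint, data string) USING foo")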

@@ -65,22 +65,27 @@ object IntegrationTestUtils {
    * @return
    *   the jar
    */
-  private[sql] def findJar(path: String, sbtName: String, mvnName: String): File = {
+  private[sql] def findJar(
+      path: String,
+      sbtName: String,
+      mvnName: String,
+      test: Boolean = false): File = {
     val targetDir = new File(new File(sparkHome, path), "target")
     assert(
       targetDir.exists(),
       s"Fail to locate the target folder: '${targetDir.getCanonicalPath}'. " +
         s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " +
         "Make sure the spark project jars has been built (e.g. using build/sbt package)" +
         "and the env variable `SPARK_HOME` is set correctly.")
+    val suffix = if (test) "-tests.jar" else ".jar"
     val jars = recursiveListFiles(targetDir).filter { f =>
       // SBT jar
       (f.getParentFile.getName == scalaDir &&
-        f.getName.startsWith(sbtName) && f.getName.endsWith(".jar")) ||
+        f.getName.startsWith(sbtName) && f.getName.endsWith(suffix)) ||
       // Maven Jar
       (f.getParent.endsWith("target") &&
         f.getName.startsWith(mvnName) &&
-        f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}.jar"))
+        f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix"))
     }
     // It is possible we found more than one: one built by maven, and another by SBT
     assert(jars.nonEmpty, s"Failed to find the jar inside folder: ${targetDir.getCanonicalPath}")
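
With `test = true` the helper now looks for the `-tests` classifier jar rather than the main artifact. A hypothetical call and the kind of file it is expected to match (Scala and Spark versions illustrative):

    // Hypothetical usage; matches e.g. sql/catalyst/target/scala-2.12/spark-catalyst_2.12-3.4.0-tests.jar
    val catalystTestJar = findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true)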

@@ -58,10 +58,12 @@ object SparkConnectServerUtils {
 
   private lazy val sparkConnect: Process = {
     debug("Starting the Spark Connect Server...")
-    val jar = findJar(
+    val connectJar = findJar(
       "connector/connect/server",
       "spark-connect-assembly",
       "spark-connect").getCanonicalPath
+    val driverClassPath = connectJar + ":" +
+      findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath
     val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
       "hive"
     } else {
@@ -78,16 +80,16 @@
       Seq(
         "bin/spark-submit",
         "--driver-class-path",
-        jar,
+        driverClassPath,
         "--conf",
         s"spark.connect.grpc.binding.port=$port",
         "--conf",
-        "spark.sql.catalog.testcat=org.apache.spark.sql.connect.catalog.InMemoryTableCatalog",
+        "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
         "--conf",
         s"spark.sql.catalogImplementation=$catalogImplementation",
         "--class",
         "org.apache.spark.sql.connect.SimpleSparkConnectService",
-        jar),
+        connectJar),
       new File(sparkHome))
 
     val io = new ProcessIO(
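
Put together, the launch command assembled above now has roughly this shape (jar paths, port, and catalog implementation are placeholders):

    bin/spark-submit \
      --driver-class-path <spark-connect-assembly jar>:<spark-catalyst tests jar> \
      --conf spark.connect.grpc.binding.port=<port> \
      --conf spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog \
      --conf spark.sql.catalogImplementation=<catalogImplementation> \
      --class org.apache.spark.sql.connect.SimpleSparkConnectService \
      <spark-connect-assembly jar>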

project/SparkBuild.scala (1 change: 1 addition & 0 deletions)
@@ -854,6 +854,7 @@ object SparkConnectClient {
 
     buildTestDeps := {
       (LocalProject("assembly") / Compile / Keys.`package`).value
+      (LocalProject("catalyst") / Test / Keys.`package`).value
     },
 
     // SPARK-42538: Make sure the `${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars` is available for testing.
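
This makes the client's test-dependency step also package catalyst's test classes, so the `-tests` jar that `findJar(..., test = true)` looks for actually exists. The roughly equivalent manual step (project name as defined in SparkBuild.scala) would be:

    build/sbt "catalyst/Test/package"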

@@ -25,7 +25,6 @@ import java.util.OptionalLong
 import scala.collection.mutable
 
 import com.google.common.base.Objects
-import org.scalatest.Assertions._
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
@@ -433,7 +432,9 @@ abstract class InMemoryBaseTable(
   protected var streamingWriter: StreamingWrite = StreamingAppend
 
   override def overwriteDynamicPartitions(): WriteBuilder = {
-    assert(writer == Append)
+    if (writer != Append) {
+      throw new IllegalArgumentException(s"Unsupported writer type: $writer")
+    }
     writer = DynamicOverwrite
     streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
     this

@@ -19,8 +19,6 @@ package org.apache.spark.sql.connector.catalog
 
 import java.util
 
-import org.scalatest.Assertions.assert
-
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage}
@@ -89,14 +87,18 @@ class InMemoryTable(
     with SupportsOverwrite {
 
   override def truncate(): WriteBuilder = {
-    assert(writer == Append)
+    if (writer != Append) {
+      throw new IllegalArgumentException(s"Unsupported writer type: $writer")
+    }
     writer = TruncateAndAppend
     streamingWriter = StreamingTruncateAndAppend
     this
   }
 
   override def overwrite(filters: Array[Filter]): WriteBuilder = {
-    assert(writer == Append)
+    if (writer != Append) {
+      throw new IllegalArgumentException(s"Unsupported writer type: $writer")
+    }
     writer = new Overwrite(filters)
     streamingWriter = new StreamingNotSupportedOperation(
       s"overwrite (${filters.mkString("filters(", ", ", ")")})")
