[SPARK-43321][CONNECT] Dataset#joinWith #40997

Closed
wants to merge 10 commits

@@ -20,6 +20,7 @@ import java.util.{Collections, Locale}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.spark.SparkException
@@ -568,7 +569,7 @@ class Dataset[T] private[sql] (
}
}

private def toJoinType(name: String): proto.Join.JoinType = {
private def toJoinType(name: String, skipSemiAnti: Boolean = false): proto.Join.JoinType = {
name.trim.toLowerCase(Locale.ROOT) match {
case "inner" =>
proto.Join.JoinType.JOIN_TYPE_INNER
@@ -580,12 +581,12 @@
proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
case "right" | "rightouter" | "right_outer" =>
proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
case "semi" | "leftsemi" | "left_semi" =>
case "semi" | "leftsemi" | "left_semi" if !skipSemiAnti =>
proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
case "anti" | "leftanti" | "left_anti" =>
case "anti" | "leftanti" | "left_anti" if !skipSemiAnti =>
proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
case _ =>
throw new IllegalArgumentException(s"Unsupported join type `joinType`.")
case e =>
throw new IllegalArgumentException(s"Unsupported join type '$e'.")
}
}

@@ -835,6 +836,80 @@ class Dataset[T] private[sql] (
}
}

/**
* Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true.
*
* This is similar to the relation `join` function with one important difference in the result
* schema. Since `joinWith` preserves objects present on either side of the join, the result
* schema is similarly nested into a tuple under the column names `_1` and `_2`.
*
* This type of join can be useful both for preserving type-safety with the original object
* types as well as working with relational data where either side of the join has column names
* in common.
*
* @param other
* Right side of the join.
* @param condition
* Join expression.
* @param joinType
* Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
* `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`, `rightouter`,
* `right_outer`.
*
* @group typedrel
* @since 3.5.0
*/
def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
val joinTypeValue = toJoinType(joinType, skipSemiAnti = true)
val (leftNullable, rightNullable) = joinTypeValue match {
case proto.Join.JoinType.JOIN_TYPE_INNER | proto.Join.JoinType.JOIN_TYPE_CROSS =>
(false, false)
case proto.Join.JoinType.JOIN_TYPE_FULL_OUTER =>
(true, true)
case proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER =>
(false, true)
case proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER =>
(true, false)
case e =>
throw new IllegalArgumentException(s"Unsupported join type '$e'.")
}

val tupleEncoder =
ProductEncoder[(T, U)](
ClassTag(Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
Seq(
EncoderField(s"_1", this.encoder, leftNullable, Metadata.empty),
EncoderField(s"_2", other.encoder, rightNullable, Metadata.empty)))

sparkSession.newDataset(tupleEncoder) { builder =>
val joinBuilder = builder.getJoinBuilder
joinBuilder
.setLeft(plan.getRoot)
.setRight(other.plan.getRoot)
.setJoinType(joinTypeValue)
.setJoinCondition(condition.expr)
.setJoinDataType(joinBuilder.getJoinDataTypeBuilder
.setIsLeftFlattenableToRow(this.encoder.isFlattenable)
.setIsRightFlattenableToRow(other.encoder.isFlattenable))
}
}
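// Usage sketch (assumes hypothetical case classes Person(id: Long, name: String) and
// Order(personId: Long, amount: Double) with Datasets `people` and `orders`):
//   val joined: Dataset[(Person, Order)] =
//     people.joinWith(orders, people.col("id") === orders.col("personId"), "left")
// With "left", unmatched rows on the left produce a null right object, which is why
// the tuple encoder built above marks `_2` as nullable for left outer joins.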

/**
* Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where
* `condition` evaluates to true.
*
* @param other
* Right side of the join.
* @param condition
* Join expression.
*
* @group typedrel
* @since 3.5.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
joinWith(other, condition, "inner")
}
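// Usage sketch for the inner-join shorthand (same assumed `people` and `orders` as above):
//   val joined: Dataset[(Person, Order)] =
//     people.joinWith(orders, people.col("id") === orders.col("personId"))
// This delegates to joinWith(other, condition, "inner"), so neither tuple field is nullable.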

/**
* Returns a new Dataset with each partition sorted by the given expressions.
*
@@ -28,11 +28,11 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}

@@ -50,15 +50,33 @@ private[sql] class SparkResult[T](
private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]

private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
val agnosticEncoder = if (encoder == UnboundRowEncoder) {
// Create a row encoder based on the schema.
RowEncoder.encoderFor(schema).asInstanceOf[AgnosticEncoder[T]]
} else {
encoder
}
val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]]
ExpressionEncoder(agnosticEncoder)
}

/**
* Update RowEncoder and recursively update the fields of the ProductEncoder if found.
*/
private def createEncoder[_](
enc: AgnosticEncoder[_],
dataType: DataType): AgnosticEncoder[_] = {
enc match {
case UnboundRowEncoder =>
// Replace the row encoder with the encoder inferred from the schema.
RowEncoder.encoderFor(dataType.asInstanceOf[StructType])
case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
// Recursively continue updating the tuple product encoder
val schema = dataType.asInstanceOf[StructType]
assert(fields.length <= schema.fields.length)
val updatedFields = fields.zipWithIndex.map { case (f, id) =>
f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
}
ProductEncoder(clsTag, updatedFields)
case _ =>
enc
}
}
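// Worked example (a sketch) of the replacement performed above: for a Dataset[(Row, Int)]
// returned by df.joinWith(ds, cond), the client-side encoder is roughly
//   ProductEncoder(Tuple2, Seq(_1 -> UnboundRowEncoder, _2 -> primitive int encoder))
// and the server reports a schema such as
//   StructType(Seq(StructField("_1", <struct of df>), StructField("_2", IntegerType)))
// The UnboundRowEncoder for `_1` is replaced by RowEncoder.encoderFor(<struct of df>),
// while the already-bound `_2` encoder falls through the default case unchanged.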

private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
while (responses.hasNext) {
val response = responses.next()
@@ -88,6 +88,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
}

test("read and write") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
@@ -158,6 +159,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
}

test("textFile") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
val testDataPath = java.nio.file.Paths
.get(
IntegrationTestUtils.sparkHome,
@@ -178,6 +180,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
}

test("write table") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
withTable("myTable") {
val df = spark.range(10).limit(3)
df.write.mode(SaveMode.Overwrite).saveAsTable("myTable")
@@ -221,6 +224,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
}

test("write without table or path") {
assume(IntegrationTestUtils.isSparkHiveJarAvailable)
// Should receive no error to write noop
spark.range(10).write.format("noop").mode("append").save()
}
@@ -970,8 +974,197 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
val result2 = spark.sql("select :c0 limit :l0", Map("l0" -> 1, "c0" -> "abc")).collect()
assert(result2.length == 1 && result2(0).getString(0) === "abc")
}

test("joinWith, flat schema") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 2, 3).toDS().as("a")
val ds2 = Seq(1, 2).toDS().as("b")

val joined = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner")

val expectedSchema = StructType(
Seq(
StructField("_1", IntegerType, nullable = false),
StructField("_2", IntegerType, nullable = false)))

assert(joined.schema === expectedSchema)

val expected = Seq((1, 1), (2, 2))
checkSameResult(expected, joined)
}

test("joinWith tuple with primitive, expression") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = Seq(("a", 1), ("b", 2)).toDS()

val joined = ds1.joinWith(ds2, $"value" === $"_2")

// This is an inner join, so both output fields are non-nullable
val expectedSchema = StructType(
Seq(
StructField("_1", IntegerType, nullable = false),
StructField(
"_2",
StructType(
Seq(StructField("_1", StringType), StructField("_2", IntegerType, nullable = false))),
nullable = false)))
assert(joined.schema === expectedSchema)

checkSameResult(Seq((1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))), joined)
}

test("joinWith tuple with primitive, rows") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 1, 2).toDF()
val ds2 = Seq(("a", 1), ("b", 2)).toDF()

val joined = ds1.joinWith(ds2, $"value" === $"_2")

checkSameResult(
Seq((Row(1), Row("a", 1)), (Row(1), Row("a", 1)), (Row(2), Row("b", 2))),
joined)
}

test("joinWith class with primitive, toDF") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(1, 1, 2).toDS()
val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()

val df = ds1
.joinWith(ds2, $"value" === $"b")
.toDF()
.select($"_1", $"_2.a", $"_2.b")
checkSameResult(Seq(Row(1, "a", 1), Row(1, "a", 1), Row(2, "b", 2)), df)
}

test("multi-level joinWith") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")

val joined = ds1
.joinWith(ds2, $"a._2" === $"b._2")
.as("ab")
.joinWith(ds3, $"ab._1._2" === $"c._2")

checkSameResult(
Seq(((("a", 1), ("a", 1)), ("a", 1)), ((("b", 2), ("b", 2)), ("b", 2))),
joined)
}

test("multi-level joinWith, rows") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(("a", 1), ("b", 2)).toDF().as("a")
val ds2 = Seq(("a", 1), ("b", 2)).toDF().as("b")
val ds3 = Seq(("a", 1), ("b", 2)).toDF().as("c")

val joined = ds1
.joinWith(ds2, $"a._2" === $"b._2")
.as("ab")
.joinWith(ds3, $"ab._1._2" === $"c._2")

checkSameResult(
Seq(((Row("a", 1), Row("a", 1)), Row("a", 1)), ((Row("b", 2), Row("b", 2)), Row("b", 2))),
joined)
}

test("self join") {
val session: SparkSession = spark
import session.implicits._
val ds = Seq("1", "2").toDS().as("a")
val joined = ds.joinWith(ds, lit(true), "cross")
checkSameResult(Seq(("1", "1"), ("1", "2"), ("2", "1"), ("2", "2")), joined)
}

test("SPARK-11894: Incorrect results are returned when using null") {
val session: SparkSession = spark
import session.implicits._
val nullInt = null.asInstanceOf[java.lang.Integer]
val ds1 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()
val ds2 = Seq((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")).toDS()

checkSameResult(
Seq(
((nullInt, "1"), (nullInt, "1")),
((nullInt, "1"), (java.lang.Integer.valueOf(22), "2")),
((java.lang.Integer.valueOf(22), "2"), (nullInt, "1")),
((java.lang.Integer.valueOf(22), "2"), (java.lang.Integer.valueOf(22), "2"))),
ds1.joinWith(ds2, lit(true), "cross"))
}

test("SPARK-15441: Dataset outer join") {
val session: SparkSession = spark
import session.implicits._
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left")
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right")
val joined = left.joinWith(right, $"left.b" === $"right.b", "left")

val expectedSchema = StructType(
Seq(
StructField(
"_1",
StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false))),
nullable = false),
// This is a left join, so the right output is nullable:
StructField(
"_2",
StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false))))))
assert(joined.schema === expectedSchema)

val result = joined.collect().toSet
assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2)))
}

test("SPARK-37829: DataFrame outer join") {
// Same as "SPARK-15441: Dataset outer join" but using DataFrames instead of Datasets
val session: SparkSession = spark
import session.implicits._
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDF().as("left")
val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDF().as("right")
val joined = left.joinWith(right, $"left.b" === $"right.b", "left")

val leftFieldSchema = StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false)))
val rightFieldSchema = StructType(
Seq(StructField("a", StringType), StructField("b", IntegerType, nullable = false)))
val expectedSchema = StructType(
Seq(
StructField("_1", leftFieldSchema, nullable = false),
// This is a left join, so the right output is nullable:
StructField("_2", rightFieldSchema)))
assert(joined.schema === expectedSchema)

val result = joined.collect().toSet
val expected = Set(
new GenericRowWithSchema(Array("a", 1), leftFieldSchema) ->
null,
new GenericRowWithSchema(Array("b", 2), leftFieldSchema) ->
new GenericRowWithSchema(Array("x", 2), rightFieldSchema))
assert(result == expected)
}

test("SPARK-24762: joinWith on Option[Product]") {
val session: SparkSession = spark
import session.implicits._
val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
checkSameResult(Seq((Some((2, 3)), Some((1, 2)))), joined)
}
}

private[sql] case class ClassData(a: String, b: Int)

private[sql] case class MyType(id: Long, a: Double, b: Double)
private[sql] case class KV(key: String, value: Int)
private[sql] class SimpleBean {