[SPARK-11269][SQL] Java API support & test cases #9358
Changes from all commits: 0166e71, 33de26b, 0eea82c, f6a674c, fbf791e, d8d5a19
Changed file: org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst.encoders

import scala.reflect.ClassTag

import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
import org.apache.spark.sql.catalyst.expressions._

/**
 * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.

@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
  /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
  def clsTag: ClassTag[T]
}
object Encoder {
  import scala.reflect.runtime.universe._

  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)

  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
  }

  def tuple[T1, T2, T3](
      enc1: Encoder[T1],
      enc2: Encoder[T2],
      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
  }

  def tuple[T1, T2, T3, T4](
      enc1: Encoder[T1],
      enc2: Encoder[T2],
      enc3: Encoder[T3],
      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
  }

  def tuple[T1, T2, T3, T4, T5](
      enc1: Encoder[T1],
      enc2: Encoder[T2],
      enc3: Encoder[T3],
      enc4: Encoder[T4],
      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
  }

  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
Review comment (Member): @cloud-fan, does that mean the limit will be 22? Do you think we should at least add it up to Tuple22, which is the limit of Scala?
Review comment (Contributor, Author): We can hold it off until some use cases come out that need more than Tuple5.
Review comment (Member): Thank you!
    assert(encoders.length > 1)
    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))

    val schema = StructType(encoders.zipWithIndex.map {
      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
    })

    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

    val extractExpressions = encoders.map {
      case e if e.flat => e.extractExpressions.head
      case other => CreateStruct(other.extractExpressions)
    }.zipWithIndex.map { case (expr, index) =>
      expr.transformUp {
        case BoundReference(0, t: ObjectType, _) =>
          Invoke(
            BoundReference(0, ObjectType(cls), true),
            s"_${index + 1}",
            t)
      }
    }

    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
      if (enc.flat) {
        enc.constructExpression.transform {
          case b: BoundReference => b.copy(ordinal = index)
        }
      } else {
        enc.constructExpression.transformUp {
          case BoundReference(ordinal, dt, _) =>
            GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
        }
      }
    }

    val constructExpression =
      NewInstance(cls, constructExpressions, false, ObjectType(cls))

    new ExpressionEncoder[Any](
      schema,
      false,
      extractExpressions,
      constructExpression,
      ClassTag.apply(cls))
  }

  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
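For orientation, here is a rough sketch (not part of the PR) of how a Java caller could reach the encoders defined above and compose a pair encoder. Going through the compiler-generated `Encoder$.MODULE$` singleton is my interop assumption, and the class and method names are hypothetical.

```java
import org.apache.spark.sql.catalyst.encoders.Encoder;
import org.apache.spark.sql.catalyst.encoders.Encoder$;
import scala.Tuple2;

public class EncoderSketch {
  // One way to reach the Scala `object Encoder` from Java is through the
  // compiler-generated singleton Encoder$.MODULE$.
  static Encoder<Tuple2<String, Integer>> pairEncoder() {
    Encoder<String> stringEnc = Encoder$.MODULE$.STRING();
    Encoder<Integer> intEnc = Encoder$.MODULE$.INT();
    // Per the private tuple() above, the resulting schema should be _1: string, _2: int.
    return Encoder$.MODULE$.tuple(stringEnc, intEnc);
  }
}
```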
|
  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
    import scala.reflect.api

    // val mirror = runtimeMirror(c.getClassLoader)
    val mirror = rootMirror
    val sym = mirror.staticClass(c.getName)
    val tpe = sym.selfType
    TypeTag(mirror, new api.TypeCreator {
      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
        if (m eq mirror) tpe.asInstanceOf[U # Type]
        else throw new IllegalArgumentException(
          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
    })
  }

  def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
    implicit val typeTag1 = getTypeTag(c1)
    implicit val typeTag2 = getTypeTag(c2)
    ExpressionEncoder[(T1, T2)]()
  }
}

Review comment (Contributor): How about just (…)
Review comment (Member): @marmbrus Any reason why it is tuple 5, instead of tuple 22 which is the current limit of Scala?
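And a similarly hedged sketch of the `Class`-based `forTuple2` variant from Java (again my own illustration, with assumed names):

```java
import org.apache.spark.sql.catalyst.encoders.Encoder;
import org.apache.spark.sql.catalyst.encoders.Encoder$;
import scala.Tuple2;

public class ForTuple2Sketch {
  // Builds a Tuple2 encoder directly from Class objects, so a Java caller does
  // not have to compose the primitive encoders by hand.
  static Encoder<Tuple2<String, Long>> pairEncoder() {
    return Encoder$.MODULE$.forTuple2(String.class, Long.class);
  }
}
```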
Changed file: org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql

import scala.collection.JavaConverters._
Review comment (Contributor): let's use explicit conversions instead of implicit ones.
Review comment (Contributor): JavaConverters is the explicit one (…)
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}

import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner

@@ -148,18 +152,37 @@ class Dataset[T] private(
  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)

  /**
   * (Scala-specific)
   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
   * @since 1.6.0
   */
  def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))

  /**
   * (Java-specific)
   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
   * @since 1.6.0
   */
  def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
    filter(t => func.call(t).booleanValue())

Review comment (Contributor): After talking with @rxin we should probably create (…)
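A minimal sketch of calling the Java-specific filter overload, assuming an existing `Dataset<String>` (the helper name is hypothetical):

```java
import org.apache.spark.sql.Dataset;

public class FilterSketch {
  // Keeps only non-empty strings; the lambda targets the
  // org.apache.spark.api.java.function.Function overload added above.
  static Dataset<String> nonEmpty(Dataset<String> ds) {
    return ds.filter(s -> !s.isEmpty());
  }
}
```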
  /**
   * (Scala-specific)
   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
   * @since 1.6.0
   */
  def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))

  /**
   * (Java-specific)
   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
   * @since 1.6.0
   */
  def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
    map(t => func.call(t))(encoder)

Review comment (Contributor): Same here, (…)
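A hedged sketch of the two-argument Java `map`: unlike the Scala overload, the result encoder is passed explicitly because Java has no implicit Encoder resolution (helper and parameter names are mine):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.catalyst.encoders.Encoder;

public class MapSketch {
  // Maps each string to its length, supplying the Integer encoder by hand.
  static Dataset<Integer> lengths(Dataset<String> ds, Encoder<Integer> intEncoder) {
    return ds.map(s -> s.length(), intEncoder);
  }
}
```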
  /**
   * (Scala-specific)
   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
   * @since 1.6.0
   */

@@ -174,37 +197,93 @@ class Dataset[T] private(
      logicalPlan))
  }

  /**
   * (Java-specific)
   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
   * @since 1.6.0
   */
  def mapPartitions[U](
      f: FlatMapFunction[java.util.Iterator[T], U],
      encoder: Encoder[U]): Dataset[U] = {
    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
    mapPartitions(func)(encoder)
  }
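A sketch of the Java `mapPartitions` overload under the same assumptions; note that the 1.x `FlatMapFunction.call` returns an `Iterable`, so the partition is materialized into a list here:

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.catalyst.encoders.Encoder;

public class MapPartitionsSketch {
  // Upper-cases one whole partition at a time and returns it as an Iterable.
  static Dataset<String> upperCase(Dataset<String> ds, Encoder<String> stringEncoder) {
    return ds.mapPartitions(it -> {
      List<String> out = new ArrayList<>();
      while (it.hasNext()) {
        out.add(it.next().toUpperCase());
      }
      return out;
    }, stringEncoder);
  }
}
```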
  /**
   * (Scala-specific)
   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
   * and then flattening the results.
   * @since 1.6.0
   */
  def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
    mapPartitions(_.flatMap(func))

  /**
   * (Java-specific)
   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
   * and then flattening the results.
   * @since 1.6.0
   */
  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
    val func: (T) => Iterable[U] = x => f.call(x).asScala
    flatMap(func)(encoder)
  }
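A minimal sketch of the Java `flatMap` overload, assuming a `Dataset<String>` of lines and a caller-supplied string encoder:

```java
import java.util.Arrays;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.catalyst.encoders.Encoder;

public class FlatMapSketch {
  // Splits each line into words; the returned Iterable is flattened into the result.
  static Dataset<String> words(Dataset<String> lines, Encoder<String> stringEncoder) {
    return lines.flatMap(line -> Arrays.asList(line.split(" ")), stringEncoder);
  }
}
```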
  /* ************** *
   *  Side effects  *
   * ************** */

  /**
   * (Scala-specific)
   * Runs `func` on each element of this Dataset.
   * @since 1.6.0
   */
  def foreach(func: T => Unit): Unit = rdd.foreach(func)

  /**
   * (Java-specific)
   * Runs `func` on each element of this Dataset.
   * @since 1.6.0
   */
  def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))

  /**
   * (Scala-specific)
   * Runs `func` on each partition of this Dataset.
   * @since 1.6.0
   */
  def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)

  /**
   * (Java-specific)
   * Runs `func` on each partition of this Dataset.
   * @since 1.6.0
   */
  def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
    foreachPartition(it => func.call(it.asJava))
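A short sketch of the Java `foreach` overload; the lambda becomes a `VoidFunction`, which is what separates it from the Scala `T => Unit` variant:

```java
import org.apache.spark.sql.Dataset;

public class ForeachSketch {
  // Prints every element (on the executors, not the driver).
  static void printAll(Dataset<String> ds) {
    ds.foreach(s -> System.out.println(s));
  }
}
```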
  /* ************* *
   *  Aggregation  *
   * ************* */

  /**
   * (Scala-specific)
   * Reduces the elements of this Dataset using the specified binary function. The given function
   * must be commutative and associative or the result may be non-deterministic.
   * @since 1.6.0
   */
  def reduce(func: (T, T) => T): T = rdd.reduce(func)

  /**
   * (Java-specific)
   * Reduces the elements of this Dataset using the specified binary function. The given function
   * must be commutative and associative or the result may be non-deterministic.
   * @since 1.6.0
   */
  def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))

  /**
   * (Scala-specific)
   * Aggregates the elements of each partition, and then the results for all the partitions, using a
   * given associative and commutative function and a neutral "zero value".
   *

@@ -218,6 +297,15 @@ class Dataset[T] private(
  def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)

  /**
   * (Java-specific)
   * Aggregates the elements of each partition, and then the results for all the partitions, using a
   * given associative and commutative function and a neutral "zero value".
   * @since 1.6.0
   */
  def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
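A sketch of the Java `reduce` and `fold` overloads on a `Dataset<Integer>`; as the Scaladoc above says, the function should be commutative and associative:

```java
import org.apache.spark.sql.Dataset;

public class AggregateSketch {
  // Sums the dataset with reduce (requires a non-empty dataset).
  static int sumByReduce(Dataset<Integer> ds) {
    return ds.reduce((a, b) -> a + b);
  }

  // Same result via fold, starting from an explicit zero value.
  static int sumByFold(Dataset<Integer> ds) {
    return ds.fold(0, (a, b) -> a + b);
  }
}
```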
  /**
   * (Scala-specific)
   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
   * @since 1.6.0
   */

@@ -255,6 +343,14 @@ class Dataset[T] private(
      keyAttributes)
  }

  /**
   * (Java-specific)
   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
   * @since 1.6.0
   */
  def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
    groupBy(f.call(_))(encoder)
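A sketch of the Java `groupBy` overload, grouping strings by length; the key encoder is supplied explicitly and the helper name is hypothetical:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.GroupedDataset;
import org.apache.spark.sql.catalyst.encoders.Encoder;

public class GroupBySketch {
  // Groups a Dataset<String> by string length, yielding Integer keys.
  static GroupedDataset<Integer, String> byLength(Dataset<String> ds, Encoder<Integer> intEncoder) {
    return ds.groupBy(s -> s.length(), intEncoder);
  }
}
```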
  /* ****************** *
   *  Typed Relational  *
   * ****************** */

@@ -264,8 +360,7 @@ class Dataset[T] private(
   * {{{
   *   df.select($"colA", $"colB" + 1)
   * }}}
   * @group dfops
   * @since 1.3.0
   * @since 1.6.0
   */
  // Copied from Dataframe to make sure we don't have invalid overloads.
  @scala.annotation.varargs

@@ -276,7 +371,7 @@ class Dataset[T] private(
   *
   * {{{
   *   val ds = Seq(1, 2, 3).toDS()
   *   val newDS = ds.select(e[Int]("value + 1"))
   *   val newDS = ds.select(expr("value + 1").as[Int])
   * }}}
   * @since 1.6.0
   */

@@ -402,6 +497,8 @@ class Dataset[T] private(
   * This type of join can be useful both for preserving type-safety with the original object
   * types as well as working with relational data where either side of the join has column
   * names in common.
   *
   * @since 1.6.0
   */
  def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
    val left = this.logicalPlan
@@ -435,12 +532,31 @@ class Dataset[T] private(
   *  Gather to Driver Actions  *
   * ************************** */

  /** Returns the first element in this [[Dataset]]. */
  /**
   * Returns the first element in this [[Dataset]].
   * @since 1.6.0
   */
  def first(): T = rdd.first()

  /** Collects the elements to an Array. */
  /**
   * Collects the elements to an Array.
   * @since 1.6.0
   */
  def collect(): Array[T] = rdd.collect()
  /**
   * (Java-specific)
   * Collects the elements to a Java list.
   *
   * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
   * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
   * instead and keep the generic type for result.
   *
   * @since 1.6.0
   */
  def collectAsList(): java.util.List[T] =
    rdd.collect().toSeq.asJava

Review comment (Contributor): This just means that the RDD has the wrong classtag. We need to find a way to pass the classtag from the encoder before calling collect.
Review comment (Contributor, Author): Will the class tag do the trick? I tried to define a generic class with ClassTag: (…) The return type of (…) One possible solution is to define (…)
Review comment (Contributor): RDD holds a class tag of the element type that it uses to construct the (…)
Review comment (Contributor, Author): cc @rxin
Review comment (Contributor, Author): We can construct right type of array while calling (…)
Review comment (Contributor): I see, then we should have collect as list too.
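A minimal sketch of `collectAsList` from the Java side, which preserves the element type where `collect()` surfaces `Object[]` through the Scala interop discussed above:

```java
import java.util.List;
import org.apache.spark.sql.Dataset;

public class CollectSketch {
  // Brings the whole dataset back to the driver as a typed Java list.
  static List<String> toList(Dataset<String> ds) {
    return ds.collectAsList();
  }
}
```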
  /** Returns the first `num` elements of this [[Dataset]] as an Array. */
  def take(num: Int): Array[T] = rdd.take(num)
Review comment: @cloud-fan Could you share your idea of why we do not add the other basic types like DecimalType, DateType and TimestampType? Thank you!
  DecimalType -> java.math.BigDecimal
  DateType -> java.sql.Date
  TimestampType -> java.sql.Timestamp
Review comment: We should add these.
Review comment: Will do it soon. Thanks!