@@ -17,11 +17,11 @@

package org.apache.spark.sql.catalyst.encoders



import scala.reflect.ClassTag

import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
import org.apache.spark.sql.catalyst.expressions._

/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
/** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
}

object Encoder {
  import scala.reflect.runtime.universe._

  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
Member: @cloud-fan Could you share your idea of why we do not add the other basic types like DecimalType, DateType and TimestampType? Thank you!

DecimalType -> java.math.BigDecimal
DateType -> java.sql.Date
TimestampType -> java.sql.Timestamp

Contributor: We should add these.

Member: Will do it soon. Thanks!
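For reference, a rough sketch of what the additional flat encoders discussed above might look like, assuming they would follow the same ExpressionEncoder(flat = true) pattern as the existing primitives and that ScalaReflection can already derive serializers for these external types (names and signatures here are illustrative, not part of this diff):

  def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder(flat = true)
  def DATE: Encoder[java.sql.Date] = ExpressionEncoder(flat = true)
  def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder(flat = true)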


  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
  }

  def tuple[T1, T2, T3](
      enc1: Encoder[T1],
      enc2: Encoder[T2],
      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
  }

  def tuple[T1, T2, T3, T4](
      enc1: Encoder[T1],
      enc2: Encoder[T2],
      enc3: Encoder[T3],
      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
  }

  def tuple[T1, T2, T3, T4, T5](
      enc1: Encoder[T1],
      enc2: Encoder[T2],
      enc3: Encoder[T3],
      enc4: Encoder[T4],
      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
  }

  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
Contributor Author: I came up with a new approach to create encoders on the Java side, which supports nested tuples as well. I also kept the old code below so you can judge which way is better.

cc @marmbrus @rxin

Member: @cloud-fan, does that mean the limit will be 22? Do you think we should at least add it up to Tuple22, which is the limit of Scala?

Contributor Author: We can hold it off until some use cases come up that need more than Tuple5.

Member: Thank you!
    assert(encoders.length > 1)
    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))

    val schema = StructType(encoders.zipWithIndex.map {
      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
    })

    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

    val extractExpressions = encoders.map {
      case e if e.flat => e.extractExpressions.head
      case other => CreateStruct(other.extractExpressions)
    }.zipWithIndex.map { case (expr, index) =>
      expr.transformUp {
        case BoundReference(0, t: ObjectType, _) =>
          Invoke(
            BoundReference(0, ObjectType(cls), true),
            s"_${index + 1}",
            t)
      }
    }

    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
      if (enc.flat) {
        enc.constructExpression.transform {
          case b: BoundReference => b.copy(ordinal = index)
        }
      } else {
        enc.constructExpression.transformUp {
          case BoundReference(ordinal, dt, _) =>
            GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
        }
      }
    }

    val constructExpression =
      NewInstance(cls, constructExpressions, false, ObjectType(cls))

    new ExpressionEncoder[Any](
      schema,
      false,
      extractExpressions,
      constructExpression,
      ClassTag.apply(cls))
  }
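To make the factory above concrete, a rough usage sketch (illustrative only; it assumes the component encoders are already resolved, as the assertion at the top of the private tuple method requires):

  val pairEncoder = Encoder.tuple(Encoder.INT, Encoder.STRING)    // Encoder[(java.lang.Integer, String)]
  // For two flat inputs the combined schema is a two-field struct:
  //   StructType(StructField("_1", IntegerType) :: StructField("_2", StringType) :: Nil)
  // Non-flat inputs are wrapped in CreateStruct on extraction and read back through
  // GetInternalRowField on construction, so the result of tuple() can itself be nested:
  val nestedEncoder = Encoder.tuple(pairEncoder, Encoder.LONG)    // Encoder[((java.lang.Integer, String), java.lang.Long)]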


  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
Contributor: private


  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
    import scala.reflect.api

    // val mirror = runtimeMirror(c.getClassLoader)
    val mirror = rootMirror
    val sym = mirror.staticClass(c.getName)
    val tpe = sym.selfType
    TypeTag(mirror, new api.TypeCreator {
      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
        if (m eq mirror) tpe.asInstanceOf[U # Type]
        else throw new IllegalArgumentException(
          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
    })
  }

  def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
Contributor: How about just forTuple? The type of tuple returned is obvious from the number of arguments. We should also add at least up to Tuple5.

Member: @marmbrus Any reason why it is Tuple5, instead of Tuple22 which is the current limit of Scala?

    implicit val typeTag1 = getTypeTag(c1)
    implicit val typeTag2 = getTypeTag(c2)
    ExpressionEncoder[(T1, T2)]()
  }
}
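A rough sketch of how the class-based path above might be called (hypothetical usage, not part of this diff): the TypeTags are rebuilt from the Class objects at runtime and handed to the normal ExpressionEncoder derivation.

  // Hypothetical: build an Encoder[(Integer, String)] from plain Class objects.
  val enc = Encoder.forTuple2(classOf[java.lang.Integer], classOf[java.lang.String])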
@@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
}
}

case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
  extends UnaryExpression {

  override def nullable: Boolean = true

  override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException("Only code-generated evaluation is supported")

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val row = child.gen(ctx)
    s"""
      ${row.code}
      final boolean ${ev.isNull} = ${row.isNull};
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
      if (!${ev.isNull}) {
        ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
      }
    """
  }
}
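GetInternalRowField is code-generation only, but for intuition the generated Java behaves roughly like the following interpreted sketch (the helper name is hypothetical, not part of the patch):

  // Propagate null from the child row; otherwise read field `ordinal` with the declared type.
  private def interpretedEval(childRow: InternalRow): Any =
    if (childRow == null) null else childRow.get(ordinal, dataType)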
126 changes: 121 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,13 @@

package org.apache.spark.sql

import scala.collection.JavaConverters._
Contributor: let's use explicit conversions instead of implicit ones.

Contributor: JavaConverters is the explicit one (.asScala / .asJava); the more implicit one was banned by me in a Scalastyle update.


import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}

import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
@@ -148,18 +152,37 @@ class Dataset[T] private(
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)

/**
* (Scala-specific)
* Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
* @since 1.6.0
*/
def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))

/**
* (Java-specific)
* Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
* @since 1.6.0
*/
def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
Contributor: After talking with @rxin, we should probably create a FilterFunction, both because here we can avoid boxing and also because this might be less confusing to Java users.

filter(t => func.call(t).booleanValue())
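As a rough sketch of the suggestion above (the interface name and shape are assumptions, not part of this patch), a dedicated functional interface would let the Java overload work with an unboxed Boolean rather than java.lang.Boolean; the same idea applies to the MapFunction suggested for map below:

  // Hypothetical dedicated interface for the Java-specific filter overload.
  trait FilterFunction[T] extends Serializable {
    def call(value: T): Boolean
  }
  // The overload could then be: def filter(f: FilterFunction[T]): Dataset[T] = filter(t => f.call(t))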

/**
* (Scala-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))

/**
* (Java-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
Contributor: Same here, MapFunction.

map(t => func.call(t))(encoder)

/**
* (Scala-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
@@ -174,37 +197,93 @@
logicalPlan))
}

/**
* (Java-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
def mapPartitions[U](
f: FlatMapFunction[java.util.Iterator[T], U],
encoder: Encoder[U]): Dataset[U] = {
val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
mapPartitions(func)(encoder)
}

/**
* (Scala-specific)
* Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
* and then flattening the results.
* @since 1.6.0
*/
def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
mapPartitions(_.flatMap(func))

/**
* (Java-specific)
* Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
* and then flattening the results.
* @since 1.6.0
*/
def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
val func: (T) => Iterable[U] = x => f.call(x).asScala
flatMap(func)(encoder)
}

/* ************** *
* Side effects *
* ************** */

/**
* (Scala-specific)
* Runs `func` on each element of this Dataset.
* @since 1.6.0
*/
def foreach(func: T => Unit): Unit = rdd.foreach(func)

/**
* (Java-specific)
* Runs `func` on each element of this Dataset.
* @since 1.6.0
*/
def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))

/**
* (Scala-specific)
* Runs `func` on each partition of this Dataset.
* @since 1.6.0
*/
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)

/**
* (Java-specific)
* Runs `func` on each partition of this Dataset.
* @since 1.6.0
*/
def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
foreachPartition(it => func.call(it.asJava))

/* ************* *
* Aggregation *
* ************* */

/**
* (Scala-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: (T, T) => T): T = rdd.reduce(func)

/**
* (Java-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))

/**
* (Scala-specific)
* Aggregates the elements of each partition, and then the results for all the partitions, using a
* given associative and commutative function and a neutral "zero value".
*
Expand All @@ -218,6 +297,15 @@ class Dataset[T] private(
def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)

/**
* (Java-specific)
* Aggregates the elements of each partition, and then the results for all the partitions, using a
* given associative and commutative function and a neutral "zero value".
* @since 1.6.0
*/
def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))

/**
* (Scala-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* @since 1.6.0
*/
@@ -255,6 +343,14 @@ class Dataset[T] private(
keyAttributes)
}

/**
* (Java-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* @since 1.6.0
*/
def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(f.call(_))(encoder)
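Putting the typed operations above together, a small Scala-side usage sketch (illustrative only; it assumes a SQLContext with its implicit encoders in scope):

  val ds: Dataset[Int] = Seq(1, 2, 3, 4).toDS()
  val evens = ds.filter(_ % 2 == 0)     // Dataset[Int]
  val doubled = ds.map(_ * 2)           // Dataset[Int], encoder resolved implicitly
  val total = ds.reduce(_ + _)          // 10; the function must be commutative and associative
  val byParity = ds.groupBy(_ % 2)      // GroupedDataset[Int, Int]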

/* ****************** *
* Typed Relational *
* ****************** */
@@ -264,8 +360,7 @@ class Dataset[T] private(
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
* @group dfops
* @since 1.3.0
* @since 1.6.0
*/
// Copied from Dataframe to make sure we don't have invalid overloads.
@scala.annotation.varargs
@@ -276,7 +371,7 @@
*
* {{{
* val ds = Seq(1, 2, 3).toDS()
* val newDS = ds.select(e[Int]("value + 1"))
* val newDS = ds.select(expr("value + 1").as[Int])
* }}}
* @since 1.6.0
*/
@@ -402,6 +497,8 @@ class Dataset[T] private(
* This type of join can be useful both for preserving type-safety with the original object
* types as well as working with relational data where either side of the join has column
* names in common.
*
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
val left = this.logicalPlan
@@ -435,12 +532,31 @@ class Dataset[T] private(
* Gather to Driver Actions *
* ************************** */

/** Returns the first element in this [[Dataset]]. */
/**
* Returns the first element in this [[Dataset]].
* @since 1.6.0
*/
def first(): T = rdd.first()

/** Collects the elements to an Array. */
/**
* Collects the elements to an Array.
* @since 1.6.0
*/
def collect(): Array[T] = rdd.collect()

/**
* (Java-specific)
* Collects the elements to a Java list.
*
* Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
Contributor: This just means that the RDD has the wrong ClassTag. We need to find a way to pass the ClassTag from the encoder before calling collect.

Contributor Author: Will the ClassTag do the trick? I tried to define a generic class with a ClassTag:

class MyTest[T : ClassTag] {
  def t(): Array[T] = null
}

object MyTest {
  def apply[T](cls: Class[T]): MyTest[T] = {
    new MyTest[T]()(ClassTag(cls))
  }
}

The return type of MyTest.t() is still Object on the Java side.
I also tried to use the Scala RDD from the Java side; the return type of RDD.collect() is also Object.

One possible solution is to define T <: AnyRef, but I think it's hard to do that for Dataset or RDD.

Contributor: RDD holds a ClassTag of the element type that it uses to construct the correct type of array when you do a collect.

Contributor Author: cc @rxin

Contributor Author: We can construct the right type of array while calling RDD.collect; the problem is the interface. On the Java side the return type of RDD.collect() is java.lang.Object and we need to do a type cast, which is not friendly to users.

Contributor: I see, then we should have collect as a list too.

* Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
* instead and keep the generic type for result.
Contributor: *
* @since 1.6.0
*/
def collectAsList(): java.util.List[T] =
rdd.collect().toSeq.asJava
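A small sketch of the erasure issue discussed in the thread above (illustrative only): even when a ClassTag builds the correct runtime array, the signature Java callers see is erased, which is why returning a java.util.List keeps the element type usable from Java.

  import scala.reflect.ClassTag

  class Box[T: ClassTag](items: Seq[T]) {
    // The ClassTag gives the array the right runtime element class,
    // but the Java-visible signature still erases to `Object toArray()`.
    def toArray: Array[T] = items.toArray
    // A java.util.List[T] keeps the generic element type in the signature Java sees.
    def toList: java.util.List[T] = {
      val list = new java.util.ArrayList[T](items.size)
      items.foreach(list.add)
      list
    }
  }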

/** Returns the first `num` elements of this [[Dataset]] as an Array. */
def take(num: Int): Array[T] = rdd.take(num)
