From 6dc22b96a3fa5706111329917f860aed32e31cc9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 31 Mar 2016 21:46:27 +0800 Subject: [PATCH 01/10] create MapElements --- .../sql/catalyst/plans/logical/object.scala | 26 +++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 7 ++++- .../spark/sql/execution/SparkStrategies.scala | 2 ++ .../sql/execution/WholeStageCodegen.scala | 2 +- .../apache/spark/sql/execution/objects.scala | 21 +++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 5 +--- 6 files changed, 57 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 058fb6bff1c6e..a5bad61148bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -91,6 +91,32 @@ case class MapPartitions( override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } +object MapElements { + def apply[T : Encoder, U : Encoder]( + func: T => U, + child: LogicalPlan): MapElements = { + MapElements( + func.asInstanceOf[Any => Any], + encoderFor[T].deserializer, + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each element of the `child`. + * + * @param deserializer used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapElements( + func: Any => Any, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator { + override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) +} + /** Factory for constructing new `AppendColumn` nodes. 
*/ object AppendColumns { def apply[T : Encoder, U : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 41cb799b97141..6435e76912fa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1882,7 +1882,12 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + def map[U : Encoder](func: T => U): Dataset[U] = { + new Dataset[U]( + sqlContext, + MapElements[T, U](func, logicalPlan), + implicitly[Encoder[U]]) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7841ff01f93c2..b1e30da8f6237 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -375,6 +375,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil + case logical.MapElements(f, in, out, child) => + execution.MapElements(f, in, out, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil case logical.MapGroups(f, key, in, out, grouping, data, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 6a779abd40a3c..5e14d1979d308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -169,7 +169,7 @@ trait CodegenSupport extends SparkPlan { /** * Returns source code to evaluate the variables for required attributes, and clear the code - * of evaluated variables, to prevent them to be evaluated twice.. + * of evaluated variables, to prevent them to be evaluated twice. */ protected def evaluateRequiredVariables( attributes: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 582dda8603f4e..fde37eead635b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -67,6 +67,27 @@ case class MapPartitions( } } +/** + * Applies the given function to each input row and encodes the result. + */ +case class MapElements( + func: Any => Any, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(deserializer, child.output) + val outputObject = generateToRow(serializer) + iter.map(getObject).map(func).map(outputObject) + } + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + /** * Applies the given function to each input row, appending the encoded result at the end of the row. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 854a662cc4d3d..02580c7c409b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest { val logicalPlan = df.queryExecution.analyzed // bypass some cases that we can't handle currently. logicalPlan.transform { - case _: MapPartitions => return - case _: MapGroups => return - case _: AppendColumns => return - case _: CoGroup => return + case _: ObjectOperator => return case _: LogicalRelation => return }.transformAllExpressions { case a: ImperativeAggregate => return From 770d9bbbb03de79905dada70e186adb74676edb0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 31 Mar 2016 22:23:06 +0800 Subject: [PATCH 02/10] support whole stage codegen --- .../apache/spark/sql/execution/objects.scala | 33 +++++++++++++++++-- .../execution/WholeStageCodegenSuite.scala | 11 +++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index fde37eead635b..958350ada809e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.ObjectType @@ -74,9 +74,38 @@ case class MapElements( func: Any => Any, deserializer: Expression, serializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryNode with ObjectOperator { + child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + // Mark this as empty. We'll always deserialize the input row and apply the given function at the + // beginning. 
+ override def usedInputs: AttributeSet = AttributeSet.empty + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(deserializer, child.output)) + ctx.currentVars = input + val evaluated = bound.gen(ctx) + + val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, bound.dataType) + val outputFields = serializer.map(_ transform { + case _: BoundReference => resultObj + }) + val resultVars = outputFields.map(_.gen(ctx)) + s""" + ${evaluateRequiredVariables(output, Seq(evaluated), bound.references)} + ${consume(ctx, resultVars)} + """ + } + override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => val getObject = generateToObject(deserializer, child.output) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6d5be0b5dda12..1733cb931f9d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -70,4 +70,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined) assert(df.collect() === Array(Row(1), Row(2), Row(3))) } + + test("MapElements should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = sqlContext.range(10).map(_.toString) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) + assert(ds.collect() === 0.until(10).map(_.toString).toArray) + } } From b22752b65dda134dab04cbfd0500e148e828e551 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Apr 2016 14:12:50 +0800 Subject: [PATCH 03/10] finish --- .../spark/api/java/function/MapFunction.java | 2 +- .../sql/catalyst/plans/logical/object.scala | 6 ++-- .../scala/org/apache/spark/sql/Dataset.scala | 9 ++++-- .../sql/execution/WholeStageCodegen.scala | 9 +++--- .../apache/spark/sql/execution/objects.scala | 29 +++++++++++++------ .../apache/spark/sql/JavaDatasetSuite.java | 6 ++-- .../sources/JavaDatasetAggregatorSuite.java | 2 +- .../execution/WholeStageCodegenSuite.scala | 3 +- 8 files changed, 41 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java index 3ae6ef44898e1..1e874c8cc8fb4 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -23,5 +23,5 @@ * Base interface for a map function used in Dataset's map function. 
*/ public interface MapFunction extends Serializable { - U call(T value) throws Exception; + U call(T value); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index a5bad61148bc0..e4af526106a66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -93,10 +93,10 @@ case class MapPartitions( object MapElements { def apply[T : Encoder, U : Encoder]( - func: T => U, + func: AnyRef, child: LogicalPlan): MapElements = { MapElements( - func.asInstanceOf[Any => Any], + func, encoderFor[T].deserializer, encoderFor[U].namedExpressions, child) @@ -110,7 +110,7 @@ object MapElements { * @param serializer use to serialize the output of `func`. */ case class MapElements( - func: Any => Any, + func: AnyRef, deserializer: Expression, serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6435e76912fa2..7fbc42b154ae0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1898,8 +1898,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = - map(t => func.call(t))(encoder) + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + implicit val uEnc = encoder + new Dataset[U]( + sqlContext, + MapElements[T, U](func, logicalPlan), + uEnc) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index a6d273c7717cb..4e75a3a7945e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan { s""" | |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */ - |${evaluated} + |$evaluated |${parent.doConsume(ctx, inputVars, rowVar)} """.stripMargin } @@ -175,14 +175,14 @@ trait CodegenSupport extends SparkPlan { attributes: Seq[Attribute], variables: Seq[ExprCode], required: AttributeSet): String = { - var evaluateVars = "" + val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars += ev.code.trim + "\n" + evaluateVars.append(ev.code.trim + "\n") ev.code = "" } } - evaluateVars + evaluateVars.toString() } /** @@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup def doCodeGen(): (CodegenContext, String) = { val ctx = new CodegenContext val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { return new GeneratedIterator(references); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 958350ada809e..d3b747fddbf69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution +import scala.language.existentials + +import org.apache.spark.api.java.function.MapFunction import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -71,7 +74,7 @@ case class MapPartitions( * Applies the given function to each input row and encodes the result. */ case class MapElements( - func: Any => Any, + func: AnyRef, deserializer: Expression, serializer: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { @@ -85,32 +88,40 @@ case class MapElements( child.asInstanceOf[CodegenSupport].produce(ctx, this) } - // Mark this as empty. We'll always deserialize the input row and apply the given function at the - // beginning. - override def usedInputs: AttributeSet = AttributeSet.empty - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val (funcClass, methodName) = func match { + case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" + case _ => classOf[Any => Any] -> "apply" + } + val funcObj = Literal.create(func, ObjectType(funcClass)) + val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType + val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer)) + val bound = ExpressionCanonicalizer.execute( - BindReferences.bindReference(deserializer, child.output)) + BindReferences.bindReference(callFunc, child.output)) ctx.currentVars = input val evaluated = bound.gen(ctx) - val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, bound.dataType) + val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType) val outputFields = serializer.map(_ transform { case _: BoundReference => resultObj }) val resultVars = outputFields.map(_.gen(ctx)) s""" - ${evaluateRequiredVariables(output, Seq(evaluated), bound.references)} + ${evaluated.code} ${consume(ctx, resultVars)} """ } override protected def doExecute(): RDD[InternalRow] = { + val callFunc: Any => Any = func match { + case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i) + case _ => func.asInstanceOf[Any => Any] + } child.execute().mapPartitionsInternal { iter => val getObject = generateToObject(deserializer, child.output) val outputObject = generateToRow(serializer) - iter.map(getObject).map(func).map(outputObject) + iter.map(getObject).map(callFunc).map(outputObject) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a5ab446e08d65..36196b82f70a0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -102,7 +102,7 @@ public boolean call(String v) throws Exception { Dataset mapped = ds.map(new MapFunction() { @Override - public Integer call(String v) throws Exception { + public Integer call(String v) { return v.length(); } }, Encoders.INT()); @@ -171,7 +171,7 @@ public void testGroupBy() { KeyValueGroupedDataset grouped = ds.groupByKey( new MapFunction() { @Override - public Integer call(String v) throws Exception { + public Integer call(String v) { return v.length(); } }, @@ -221,7 +221,7 @@ public String call(String v1, String v2) throws Exception { KeyValueGroupedDataset grouped2 = ds2.groupByKey( new 
MapFunction() { @Override - public Integer call(Integer v) throws Exception { + public Integer call(Integer v) { return v / 2; } }, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index c4c455b6e6fa2..f11300bb748c5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -73,7 +73,7 @@ private KeyValueGroupedDataset> generateGroupedD return ds.groupByKey( new MapFunction, String>() { @Override - public String call(Tuple2 value) throws Exception { + public String call(Tuple2 value) { return value._1(); } }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 1733cb931f9d5..f73ca887f165a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} From 35d1cad76b48147e5fc939ea08b89523fd38681a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Apr 2016 15:57:53 +0800 Subject: [PATCH 04/10] try catch --- .../spark/api/java/function/MapFunction.java | 2 +- .../sql/catalyst/expressions/objects.scala | 38 +++++++++++++------ .../apache/spark/sql/JavaDatasetSuite.java | 6 +-- .../sources/JavaDatasetAggregatorSuite.java | 2 +- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java index 1e874c8cc8fb4..3ae6ef44898e1 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -23,5 +23,5 @@ * Base interface for a map function used in Dataset's map function. 
*/ public interface MapFunction extends Serializable { - U call(T value); + U call(T value) throws Exception; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 07b67a0240f0f..d73300d02622d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -112,23 +112,27 @@ case class Invoke( arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { override def nullable: Boolean = true - override def children: Seq[Expression] = arguments.+:(targetObject) + override def children: Seq[Expression] = targetObject +: arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - lazy val method = targetObject.dataType match { + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => cls .getMethods .find(_.getName == functionName) .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) - .getReturnType - .getName - case _ => "" + case _ => null } - lazy val unboxer = (dataType, method) match { + private def returnType = if (method == null) { + "" + } else { + method.getReturnType.getName + } + + lazy val unboxer = (dataType, returnType) match { case (IntegerType, "java.lang.Object") => (s: String) => s"((java.lang.Integer)$s).intValue()" case (LongType, "java.lang.Object") => (s: String) => @@ -155,21 +159,31 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + s"boolean ${ev.isNull} = ${ev.value} == null;" } else { + ev.isNull = obj.isNull "" } val value = unboxer(s"${obj.value}.$functionName($argString)") + val evaluate = if (method == null || method.getExceptionTypes.isEmpty) { + s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" + } else { + s""" + $javaType ${ev.value} = null; + try { + ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + } + s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} - - boolean ${ev.isNull} = ${obj.isNull}; - $javaType ${ev.value} = - ${ev.isNull} ? 
- ${ctx.defaultValue(dataType)} : ($javaType) $value; + $evaluate $objNullCheck """ } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 36196b82f70a0..a5ab446e08d65 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -102,7 +102,7 @@ public boolean call(String v) throws Exception { Dataset mapped = ds.map(new MapFunction() { @Override - public Integer call(String v) { + public Integer call(String v) throws Exception { return v.length(); } }, Encoders.INT()); @@ -171,7 +171,7 @@ public void testGroupBy() { KeyValueGroupedDataset grouped = ds.groupByKey( new MapFunction() { @Override - public Integer call(String v) { + public Integer call(String v) throws Exception { return v.length(); } }, @@ -221,7 +221,7 @@ public String call(String v1, String v2) throws Exception { KeyValueGroupedDataset grouped2 = ds2.groupByKey( new MapFunction() { @Override - public Integer call(Integer v) { + public Integer call(Integer v) throws Exception { return v / 2; } }, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index f11300bb748c5..c4c455b6e6fa2 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -73,7 +73,7 @@ private KeyValueGroupedDataset> generateGroupedD return ds.groupByKey( new MapFunction, String>() { @Override - public String call(Tuple2 value) { + public String call(Tuple2 value) throws Exception { return value._1(); } }, From 52f164db8512e3ff2bfd9fb57779e57f2333e3e7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 4 Apr 2016 17:10:04 +0800 Subject: [PATCH 05/10] add comment --- .../main/scala/org/apache/spark/sql/execution/objects.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d3b747fddbf69..91421a590ab92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -72,6 +72,11 @@ case class MapPartitions( /** * Applies the given function to each input row and encodes the result. + * + * TODO: Each serializer expression needs the result object which is returned by the given function, + * as input. This operator uses some tricks to make sure we only calculate the result object once, + * we can use [[Project]] to replace this operator after we make subexpression elimination work in + * whole stage codegen. 
*/ case class MapElements( func: AnyRef, From 60a0da1fa20b124d59cfb5b80065ad8fb179aa37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Apr 2016 11:17:28 +0800 Subject: [PATCH 06/10] address comments --- .../scala/org/apache/spark/sql/Dataset.scala | 24 ++++++------------- .../apache/spark/sql/execution/objects.scala | 11 +++++---- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 7fbc42b154ae0..5318e0ed14241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -748,7 +748,8 @@ class Dataset[T] private[sql]( implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) - withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) => + + withTypedPlan { Project( leftData :: rightData :: Nil, joined.analyzed) @@ -1882,11 +1883,8 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U : Encoder](func: T => U): Dataset[U] = { - new Dataset[U]( - sqlContext, - MapElements[T, U](func, logicalPlan), - implicitly[Encoder[U]]) + def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { + MapElements[T, U](func, logicalPlan) } /** @@ -1900,10 +1898,7 @@ class Dataset[T] private[sql]( @Experimental def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { implicit val uEnc = encoder - new Dataset[U]( - sqlContext, - MapElements[T, U](func, logicalPlan), - uEnc) + withTypedPlan(MapElements[T, U](func, logicalPlan)) } /** @@ -2380,12 +2375,7 @@ class Dataset[T] private[sql]( } /** A convenient function to wrap a logical plan and produce a Dataset. */ - @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = { - new Dataset[T](sqlContext, logicalPlan, encoder) + @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + Dataset(sqlContext, logicalPlan) } - - private[sql] def withTypedPlan[R]( - other: Dataset[_], encoder: Encoder[R])( - f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 91421a590ab92..f48f3f09c74f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -73,10 +73,11 @@ case class MapPartitions( /** * Applies the given function to each input row and encodes the result. * - * TODO: Each serializer expression needs the result object which is returned by the given function, - * as input. This operator uses some tricks to make sure we only calculate the result object once, - * we can use [[Project]] to replace this operator after we make subexpression elimination work in - * whole stage codegen. + * Note that, each serializer expression needs the result object which is returned by the given + * function, as input. This operator uses some tricks to make sure we only calculate the result + * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with + * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of + * a project while explain. 
*/ case class MapElements( func: AnyRef, @@ -126,7 +127,7 @@ case class MapElements( child.execute().mapPartitionsInternal { iter => val getObject = generateToObject(deserializer, child.output) val outputObject = generateToRow(serializer) - iter.map(getObject).map(callFunc).map(outputObject) + iter.map(row => outputObject(callFunc(getObject(row)))) } } From dbfb4ac03957a2fd55e36f93e790cdd222ffe541 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Apr 2016 13:52:08 +0800 Subject: [PATCH 07/10] minor fix --- .../sql/catalyst/expressions/objects.scala | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index d73300d02622d..6bef5e44f3312 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -119,20 +119,16 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - cls - .getMethods - .find(_.getName == functionName) - .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) - case _ => null - } - - private def returnType = if (method == null) { - "" - } else { - method.getReturnType.getName + val m = cls.getMethods.find(_.getName == functionName) + if (m.isEmpty) { + sys.error(s"Couldn't find $functionName on $cls") + } else { + m + } + case _ => None } - lazy val unboxer = (dataType, returnType) match { + lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match { case (IntegerType, "java.lang.Object") => (s: String) => s"((java.lang.Integer)$s).intValue()" case (LongType, "java.lang.Object") => (s: String) => @@ -167,13 +163,13 @@ case class Invoke( val value = unboxer(s"${obj.value}.$functionName($argString)") - val evaluate = if (method == null || method.getExceptionTypes.isEmpty) { + val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" } else { s""" - $javaType ${ev.value} = null; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; try { - ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value; + ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; } catch (Exception e) { org.apache.spark.unsafe.Platform.throwException(e); } From a5b0d57bb7bce985771ede00b0f116a4cce82900 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Apr 2016 22:09:43 +0800 Subject: [PATCH 08/10] add benchmark --- .../sql/catalyst/optimizer/Optimizer.scala | 9 +++ .../apache/spark/sql/DatasetBenchmark.scala | 79 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 69b09bcb35f01..c085a377ff0fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. 
*/ object EliminateSerialization extends Rule[LogicalPlan] { + // TODO: find a more general way to do this optimization. def apply(plan: LogicalPlan): LogicalPlan = plan transform { case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) if !deserializer.isInstanceOf[Attribute] && @@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] { m.copy( deserializer = childWithoutSerialization.output.head, child = childWithoutSerialization) + + case m @ MapElements(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => + val childWithoutSerialization = child.withObjectOutput + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala new file mode 100644 index 0000000000000..deec8064c175d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.util.Benchmark + +/** + * Benchmark for Dataset typed operations. 
+ */ +object DatasetBenchmark { + + case class Data(i: Int, s: String) + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "benchmark") + val sqlContext = new SQLContext(sparkContext) + + import sqlContext.implicits._ + + val numRows = 10000000 + val ds = sqlContext.range(numRows).map(l => Data(l.toInt, l.toString)) + ds.cache() + ds.collect() // make sure data are cached + + val benchmark = new Benchmark("Dataset.map", numRows) + + val scalaFunc = (d: Data) => Data(d.i + 1, d.s) + benchmark.addCase("scala function") { iter => + var res = ds + var i = 0 + while (i < 10) { + res = res.map(scalaFunc) + i += 1 + } + res.queryExecution.toRdd.count() + } + + val javaFunc = new MapFunction[Data, Data] { + override def call(d: Data): Data = Data(d.i + 1, d.s) + } + val enc = implicitly[Encoder[Data]] + benchmark.addCase("java function") { iter => + var res = ds + var i = 0 + while (i < 10) { + res = res.map(javaFunc, enc) + i += 1 + } + res.queryExecution.toRdd.count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Dataset.map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + scala function 1029 / 1080 9.7 102.9 1.0X + java function 965 / 999 10.4 96.5 1.1X + */ + benchmark.run() + } +} From 0ce90fe22da1d5887c669f859ab8af099f73e577 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Apr 2016 10:11:17 +0800 Subject: [PATCH 09/10] improve benchmark --- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../sql/catalyst/plans/logical/object.scala | 10 ++-- .../apache/spark/sql/DatasetBenchmark.scala | 50 +++++++++++-------- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index b2f362b6b8a38..4ec43aba02d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty * if we want to resolve deserializer by children output. */ -case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute]) +case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil) extends UnaryExpression with Unevaluable with NonSQLExpression { // The input attributes used to resolve deserializer expression must be all resolved. 
require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 90b652f31aad1..ec33a538a914d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -65,7 +65,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - UnresolvedDeserializer(encoderFor[T].deserializer, Nil), + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } @@ -89,7 +89,7 @@ object MapElements { child: LogicalPlan): MapElements = { MapElements( func, - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } @@ -105,9 +105,7 @@ case class MapElements( func: AnyRef, deserializer: Expression, serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator { - override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) -} + child: LogicalPlan) extends UnaryNode with ObjectOperator /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -116,7 +114,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - UnresolvedDeserializer(encoderFor[T].deserializer, Nil), + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index deec8064c175d..163759b8d730c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -18,61 +18,67 @@ package org.apache.spark.sql import org.apache.spark.SparkContext -import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.types.StringType import org.apache.spark.util.Benchmark /** - * Benchmark for Dataset typed operations. + * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions. 
*/ object DatasetBenchmark { - case class Data(i: Int, s: String) + case class Data(l: Long, s: String) def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "benchmark") + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") val sqlContext = new SQLContext(sparkContext) import sqlContext.implicits._ val numRows = 10000000 - val ds = sqlContext.range(numRows).map(l => Data(l.toInt, l.toString)) - ds.cache() - ds.collect() // make sure data are cached + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val benchmark = new Benchmark("Dataset.map", numRows) + val benchmark = new Benchmark("back-to-back map", numRows) - val scalaFunc = (d: Data) => Data(d.i + 1, d.s) - benchmark.addCase("scala function") { iter => - var res = ds + val scalaFunc = (d: Data) => Data(d.l + 1, d.s) + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < 10) { res = res.map(scalaFunc) i += 1 } - res.queryExecution.toRdd.count() + res.queryExecution.toRdd.foreach(_ => Unit) } - val javaFunc = new MapFunction[Data, Data] { - override def call(d: Data): Data = Data(d.i + 1, d.s) + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < 10) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) } - val enc = implicitly[Encoder[Data]] - benchmark.addCase("java function") { iter => - var res = ds + + val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < 10) { - res = res.map(javaFunc, enc) + res = rdd.map(scalaFunc) i += 1 } - res.queryExecution.toRdd.count() + res.foreach(_ => Unit) } /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Dataset.map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - scala function 1029 / 1080 9.7 102.9 1.0X - java function 965 / 999 10.4 96.5 1.1X + Dataset 902 / 995 11.1 90.2 1.0X + DataFrame 132 / 167 75.5 13.2 6.8X + RDD 216 / 237 46.3 21.6 4.2X */ benchmark.run() } From af0d19312fd7dc89aed192a9787f224c9d347180 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Apr 2016 10:30:54 +0800 Subject: [PATCH 10/10] some renaming --- .../org/apache/spark/sql/DatasetBenchmark.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 163759b8d730c..6eb952445f221 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -36,15 +36,16 @@ object DatasetBenchmark { val numRows = 10000000 val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val numChains = 10 val benchmark = new Benchmark("back-to-back map", numRows) - val scalaFunc = (d: Data) => Data(d.l + 1, d.s) + val func = (d: Data) => Data(d.l + 1, d.s) benchmark.addCase("Dataset") { iter => var res = df.as[Data] var i = 0 - while (i < 10) { - res = res.map(scalaFunc) + while (i < numChains) { + res = res.map(func) i += 1 } res.queryExecution.toRdd.foreach(_ => Unit) @@ -53,7 +54,7 @@ object DatasetBenchmark { 
benchmark.addCase("DataFrame") { iter => var res = df var i = 0 - while (i < 10) { + while (i < numChains) { res = res.select($"l" + 1 as "l") i += 1 } @@ -64,8 +65,8 @@ object DatasetBenchmark { benchmark.addCase("RDD") { iter => var res = rdd var i = 0 - while (i < 10) { - res = rdd.map(scalaFunc) + while (i < numChains) { + res = rdd.map(func) i += 1 } res.foreach(_ => Unit)
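
A minimal end-to-end sketch of the behaviour this series introduces, written against a hypothetical build that includes these patches (the `sqlContext` instance and its implicits are assumed to be in scope, e.g. in a spark-shell, as in the added WholeStageCodegenSuite test):

    import org.apache.spark.sql.execution.{MapElements, WholeStageCodegen}
    import sqlContext.implicits._   // assumes a SQLContext instance named sqlContext

    // With these patches a typed map is planned as MapElements rather than
    // MapPartitions, so it can be fused by whole-stage codegen.
    val ds = sqlContext.range(10).map(_.toString)

    // The executed plan should contain a WholeStageCodegen node whose child is a
    // MapElements (the same check the added WholeStageCodegenSuite test performs).
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegen] &&
        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)

    // Results are unchanged relative to the old MapPartitions-based map.
    assert(ds.collect().toSeq == (0 until 10).map(_.toString))

The Java-friendly overload `map(MapFunction, Encoder)` goes through the same MapElements plan node; following PATCH 04 and PATCH 07, Invoke emits a try/catch around the generated call only when the target method declares exceptions (such as MapFunction.call's `throws Exception`), so the Scala closure path is unaffected.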