
[SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing #17172

Closed

Conversation

kiszk
Member

@kiszk kiszk commented Mar 6, 2017

What changes were proposed in this pull request?

This PR improves the performance of Dataset.map() for primitive types by removing boxing/unboxing operations. This is based on a discussion with @cloud-fan.

Catalyst currently generates a call to the apply() method of an anonymous function written in Scala, whose argument and return types are java.lang.Object. As a result, each call on a primitive value boxes the argument, unboxes it inside apply(), then boxes the return value and unboxes it again in the caller.

This PR instead calls the specialized version of the apply() method directly, with no boxing or unboxing. For example, if both the argument and the return value are of type int, the generated code calls apply$mcII$sp. Any combination of Int, Long, Float, and Double is supported.
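
To illustrate the two entry points outside Spark, here is a minimal sketch (not the PR's code; it only shows what scalac produces for a specialized Int => Int function):

object SpecializationIllustration extends App {
  // scalac compiles (i: Int) => i * 7 into a class extending
  // scala.runtime.AbstractFunction1$mcII$sp (see the javap output below),
  // which carries two entry points:
  //   int    apply$mcII$sp(int)  -- primitive in, primitive out
  //   Object apply(Object)       -- unboxes the argument, re-boxes the result
  val f: Int => Int = i => i * 7

  // With the static type Int => Int, scalac routes this call through
  // apply$mcII$sp, so no boxing happens.
  val specialized: Int = f(3)

  // Through the generic type, the call goes via apply(Object): 3 is boxed to
  // java.lang.Integer and the boxed result is unboxed again.
  val generic: Int = f.asInstanceOf[Any => Any].apply(3).asInstanceOf[Int]

  println(s"$specialized $generic") // 21 21
}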

The following is a benchmark result using this program; Dataset.map() becomes about 4.7x faster with this PR. Here is the Dataset part of the program.

Without this PR

OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
back-to-back map:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
RDD                                           1923 / 1952         52.0          19.2       1.0X
DataFrame                                      526 /  548        190.2           5.3       3.7X
Dataset                                       3094 / 3154         32.3          30.9       0.6X

With this PR

OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
back-to-back map:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
RDD                                           1883 / 1892         53.1          18.8       1.0X
DataFrame                                      502 /  642        199.1           5.0       3.7X
Dataset                                        657 /  784        152.2           6.6       2.9X
  def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
    import spark.implicits._
    val rdd = spark.sparkContext.range(0, numRows)
    val ds = spark.range(0, numRows)
    val func = (l: Long) => l + 1
    val benchmark = new Benchmark("back-to-back map", numRows)
...
    benchmark.addCase("Dataset") { iter =>
      var res = ds.as[Long]
      var i = 0
      while (i < numChains) {
        res = res.map(func)
        i += 1
      }
      res.queryExecution.toRdd.foreach(_ => Unit)
    }
    benchmark
  }

A motivating example

Seq(1, 2, 3).toDS.map(i => i * 7).show

Generated code without this PR

/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow deserializetoobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 012 */   private int mapelements_argValue;
/* 013 */   private UnsafeRow mapelements_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 016 */   private UnsafeRow serializefromobject_result;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 019 */
/* 020 */   public GeneratedIterator(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     deserializetoobject_result = new UnsafeRow(1);
/* 029 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0);
/* 030 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 031 */
/* 032 */     mapelements_result = new UnsafeRow(1);
/* 033 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0);
/* 034 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 035 */     serializefromobject_result = new UnsafeRow(1);
/* 036 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 037 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 038 */
/* 039 */   }
/* 040 */
/* 041 */   protected void processNext() throws java.io.IOException {
/* 042 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 043 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 044 */       int inputadapter_value = inputadapter_row.getInt(0);
/* 045 */
/* 046 */       boolean mapelements_isNull = true;
/* 047 */       int mapelements_value = -1;
/* 048 */       if (!false) {
/* 049 */         mapelements_argValue = inputadapter_value;
/* 050 */
/* 051 */         mapelements_isNull = false;
/* 052 */         if (!mapelements_isNull) {
/* 053 */           Object mapelements_funcResult = null;
/* 054 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 055 */           if (mapelements_funcResult == null) {
/* 056 */             mapelements_isNull = true;
/* 057 */           } else {
/* 058 */             mapelements_value = (Integer) mapelements_funcResult;
/* 059 */           }
/* 060 */
/* 061 */         }
/* 062 */
/* 063 */       }
/* 064 */
/* 065 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 066 */
/* 067 */       if (mapelements_isNull) {
/* 068 */         serializefromobject_rowWriter.setNullAt(0);
/* 069 */       } else {
/* 070 */         serializefromobject_rowWriter.write(0, mapelements_value);
/* 071 */       }
/* 072 */       append(serializefromobject_result);
/* 073 */       if (shouldStop()) return;
/* 074 */     }
/* 075 */   }
/* 076 */ }

Generated code with this PR (lines 48-56 are changed)

/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow deserializetoobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 012 */   private int mapelements_argValue;
/* 013 */   private UnsafeRow mapelements_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 016 */   private UnsafeRow serializefromobject_result;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 019 */
/* 020 */   public GeneratedIterator(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */     inputadapter_input = inputs[0];
/* 028 */     deserializetoobject_result = new UnsafeRow(1);
/* 029 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0);
/* 030 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 031 */
/* 032 */     mapelements_result = new UnsafeRow(1);
/* 033 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0);
/* 034 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 035 */     serializefromobject_result = new UnsafeRow(1);
/* 036 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 037 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 038 */
/* 039 */   }
/* 040 */
/* 041 */   protected void processNext() throws java.io.IOException {
/* 042 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 043 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 044 */       int inputadapter_value = inputadapter_row.getInt(0);
/* 045 */
/* 046 */       boolean mapelements_isNull = true;
/* 047 */       int mapelements_value = -1;
/* 048 */       if (!false) {
/* 049 */         mapelements_argValue = inputadapter_value;
/* 050 */
/* 051 */         mapelements_isNull = false;
/* 052 */         if (!mapelements_isNull) {
/* 053 */           mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue);
/* 054 */         }
/* 055 */
/* 056 */       }
/* 057 */
/* 058 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 059 */
/* 060 */       if (mapelements_isNull) {
/* 061 */         serializefromobject_rowWriter.setNullAt(0);
/* 062 */       } else {
/* 063 */         serializefromobject_rowWriter.write(0, mapelements_value);
/* 064 */       }
/* 065 */       append(serializefromobject_result);
/* 066 */       if (shouldStop()) return;
/* 067 */     }
/* 068 */   }
/* 069 */ }

Java bytecode of the methods generated for i => i * 7

$ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class
Compiled from "Test.scala"
public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable {
  public static final long serialVersionUID;

  public final int apply(int);
    Code:
       0: aload_0
       1: iload_1
       2: invokevirtual #18                 // Method apply$mcII$sp:(I)I
       5: ireturn

  public int apply$mcII$sp(int);
    Code:
       0: iload_1
       1: bipush        7
       3: imul
       4: ireturn

  public final java.lang.Object apply(java.lang.Object);
    Code:
       0: aload_0
       1: aload_1
       2: invokestatic  #29                 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
       5: invokevirtual #31                 // Method apply:(I)I
       8: invokestatic  #35                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
      11: areturn

  public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5);
    Code:
       0: aload_0
       1: invokespecial #42                 // Method scala/runtime/AbstractFunction1$mcII$sp."<init>":()V
       4: return
}

How was this patch tested?

Added new test suites to DatasetPrimitiveSuite.
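
For illustration, a test exercising the new primitive path could look roughly like this (a hypothetical sketch in the style of DatasetPrimitiveSuite, not necessarily the exact tests added by this PR; it assumes Spark's QueryTest/SharedSQLContext test helpers):

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class DatasetPrimitiveMapSketchSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("map with specialized primitive functions") {
    // Int => Int: the generated code can call apply$mcII$sp.
    checkDataset(Seq(1, 2, 3).toDS().map(_ * 7), 7, 14, 21)

    // Long => Long: the generated code can call apply$mcJJ$sp.
    checkDataset(Seq(1L, 2L, 3L).toDS().map(_ + 1L), 2L, 3L, 4L)
  }
}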

@SparkQA

SparkQA commented Mar 6, 2017

Test build #73971 has finished for PR 17172 at commit d8b5f8d.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution

import com.sun.org.apache.xalan.internal.xsltc.compiler.util.VoidType
Member

Is this a mistaken import? I don't see it used in the change and can't imagine we'd be invoking Xalan here

Member Author

Yes, that is why the first try failed. It was unintentionally imported during my debugging.

@@ -217,9 +219,33 @@ case class MapElementsExec(
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val inType = if (child.output.length == 1) child.output(0).dataType else NullType
Member

These two are only needed inside the case _ block right?

Member Author

Good catch. I simplified their scope.

@SparkQA

SparkQA commented Mar 6, 2017

Test build #73987 has finished for PR 17172 at commit a885907.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 6, 2017

Test build #73990 has finished for PR 17172 at commit 65fa05a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@kiszk
Member Author

kiszk commented Mar 6, 2017

cc @cloud-fan

@@ -219,7 +219,30 @@ case class MapElementsExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
case _ => classOf[Any => Any] -> "apply"
case _ =>
(if (child.output.length == 1) child.output(0).dataType else NullType,
Contributor

the if is not needed, see the assert in ObjectConsumerExec

outputObjAttr.dataType) match {
// if the pair of argument and return types is one of the specific types
// whose specialized method (apply$mc..$sp) is generated by scalac,
// Catalyst generates a direct method call to the specialized method.
Contributor

can you link to some official document or blogpost?

@cloud-fan
Contributor

This is cool! Can you also update the benchmark result in DatasetBenchmark?

@kiszk
Member Author

kiszk commented Mar 7, 2017

The latest DatasetBenchmark uses a function val func = (d: Data) => Data(d.l + 1, d.s), to which this PR cannot be applied.

Based on this suggestion, should we add a new benchmark with a function val func = (d: Data) => Data(d.l + 1) instead of replacing the current benchmark that uses val func = (d: Data) => Data(d.l + 1, d.s)?

@cloud-fan
Contributor

yea let's add a new case in the benchmark

@SparkQA

SparkQA commented Mar 8, 2017

Test build #74201 has finished for PR 17172 at commit dfbce2a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

// SpecializeTypes.scala
// http://www.cakesolutions.net/teamblogs/scala-dissection-functions
// http://axel22.github.io/2013/11/03/specialization-quirks.html
case (IntegerType, IntegerType) => classOf[Int => Int] -> "apply$mcII$sp"
Contributor

what about boolean type?

Member Author

Good catch. I overlooked the boolean type as a return type.

// http://www.cakesolutions.net/teamblogs/scala-dissection-functions
// http://axel22.github.io/2013/11/03/specialization-quirks.html
case (IntegerType, IntegerType) => classOf[Int => Int] -> "apply$mcII$sp"
case (IntegerType, LongType) => classOf[Int => Long] -> "apply$mcJI$sp"
Contributor

is it possible to do it in a composable way instead of enumerating all the combinations?

Member Author

Yes, I found a composable way.
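
Roughly, the compositional approach maps each Catalyst type to its JVM signature letter and splices the letters into the specialized method name. A sketch of the idea (hypothetical names; the merged code may differ in details such as scoping and error handling):

import org.apache.spark.sql.types._

object SpecializedNameSketch {
  // Boolean is specialized by Function1 only in the return position,
  // hence the isOutput guard.
  def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = dt match {
    case BooleanType if isOutput => Some("Z")
    case IntegerType => Some("I")
    case LongType => Some("J")
    case FloatType => Some("F")
    case DoubleType => Some("D")
    case _ => None
  }

  // Fall back to the generic apply() when either side is not specialized.
  def methodName(inType: DataType, outType: DataType): String =
    (getMethodType(inType, isOutput = false), getMethodType(outType, isOutput = true)) match {
      case (Some(in), Some(out)) => s"apply$$mc$out$in$$sp"
      case _ => "apply"
    }
}

// e.g. methodName(IntegerType, LongType) == "apply$mcJI$sp"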

Use compositional approach instead of enumeration approach
@SparkQA

SparkQA commented Mar 9, 2017

Test build #74249 has finished for PR 17172 at commit 8ee91af.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 9, 2017

Test build #74257 has finished for PR 17172 at commit 1fb2933.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -216,10 +217,39 @@ case class MapElementsExec(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

private def getMethodType(dt: DataType, isOutput: Boolean): String = {
dt match {
case BooleanType if isOutput => "Z"
Contributor

so boolean type can't be a parameter?

Member Author

IIUC, scalac specializes the boolean type only as a return type (Function1 does not specialize boolean arguments), so this code handles it only for the return type.
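
For context, this follows from how scala.Function1 is declared in the standard library (paraphrased below; the exact annotation text may vary slightly between Scala versions):

// Paraphrase of the Scala standard library declaration: the argument is
// specialized on Int/Long/Float/Double, while the result additionally covers
// Unit and Boolean. So Int => Boolean gets apply$mcZI$sp, but there is no
// specialized entry point taking a primitive boolean argument.
trait Function1[@specialized(scala.Int, scala.Long, scala.Float, scala.Double) -T1,
                @specialized(scala.Unit, scala.Boolean, scala.Int, scala.Float, scala.Long, scala.Double) +R] {
  def apply(v1: T1): R
}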

@@ -216,10 +217,39 @@ case class MapElementsExec(
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

private def getMethodType(dt: DataType, isOutput: Boolean): String = {
Contributor

let's return an Option[String]

Member Author

Sure, I will do this

@@ -165,11 +208,23 @@ object DatasetBenchmark {
val numRows = 100000000
val numChains = 10

val benchmark = backToBackMap(spark, numRows, numChains)
val benchmark0 = backToBackMapLong(spark, numRows, numChains)
val benchmark1 = backToBackMap(spark, numRows, numChains)
val benchmark2 = backToBackFilter(spark, numRows, numChains)
Contributor

we can also add a new case for backToBackFilterLong, since we handle the boolean type now.

Member Author

filter() is handled by FilterExec. Should this PR handle filter() too, or should I open another PR for filter()?

Member Author

Correction: FilterExec generates the code; TypedFilter generates the code fragment for the method invocation.

Member Author

@kiszk kiszk Mar 10, 2017

Added a new case backToBackFilterLong, but I made a mistake and posted the wrong result.
Let me correct it soon.
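
Such a case would mirror the existing map benchmarks; a hypothetical sketch (the committed DatasetBenchmark code may differ, and it assumes the same SparkSession/Benchmark setup as the backToBackMap snippet above):

  def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
    import spark.implicits._
    val ds = spark.range(0, numRows)
    // A Long => Boolean predicate lets the generated code call apply$mcZJ$sp.
    val func = (l: Long) => l % 2L == 0L
    val benchmark = new Benchmark("back-to-back filter Long", numRows)

    benchmark.addCase("Dataset") { iter =>
      var res = ds.as[Long]
      var i = 0
      while (i < numChains) {
        res = res.filter(func)
        i += 1
      }
      res.queryExecution.toRdd.foreach(_ => Unit)
    }
    benchmark
  }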

case _ => null
}
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
Contributor

let's put this thing in a util so that FilterExec can also use it

Member Author

Sure. Now Catalyst can generate a call to a specialized method for Dataset.filter(), too.
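
For example, a typed filter over Long like the one below can now invoke the Boolean-returning specialization instead of the boxed apply() (illustrative snippet, assuming import spark.implicits._ is in scope):

// Long => Boolean predicate; with this change the generated code can call
// apply$mcZJ$sp rather than boxing each Long and unboxing the Boolean result.
Seq(1L, 2L, 3L).toDS().filter((l: Long) => l % 2L == 0L).show()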

@SparkQA

SparkQA commented Mar 10, 2017

Test build #74296 has finished for PR 17172 at commit 200cec7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 10, 2017

Test build #74297 has finished for PR 17172 at commit b25b191.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@cloud-fan
Contributor

cool! merging to master!

@asfgit asfgit closed this in 5949e6c Mar 10, 2017