
[SPARK-12374][SPARK-12150][SQL] Adding logical/physical operators for Range #10335

Closed
wants to merge 8 commits into from

Conversation

gatorsmile
Member

Based on the suggestions from @marmbrus, I added logical/physical operators for Range to improve performance.

Also added another API to resolve SPARK-12150.

Could you take a look at my implementation, @marmbrus? If it is not good, I can rework it. :)

Thank you very much!

}

bufferHolder.reset()
unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize())
Contributor

Why point to the same buffer after every iteration? We could do this once during construction of the iterator. BufferHolder might be overkill here; pointing to a 16-byte array should also do the trick.

Member Author

@hvanhovell Thank you! You are right; let me move it to the construction of the iterator.

When writing a prototype, I used a 16-byte array, but I was afraid the Spark community would prefer using the existing library functions here, so I changed it to BufferHolder.

Contributor

I am fine with both, as long as we move it out of the next call.

@SparkQA

SparkQA commented Dec 16, 2015

Test build #47852 has finished for PR 10335 at commit 2aab4d6.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • case class Range(
      • case class Range(

*
* [[LeafNode]]s must override this.
*/
val sizeInBytes = LongType.defaultSize * numElements
Contributor

protected

@marmbrus
Contributor

The high-level structure of this looks pretty good to me. Could you also post some numbers from a micro-benchmark? It would be good to make sure we're actually speeding things up.

@gatorsmile
Member Author

Sure, will do it! Thank you for your guidance!

@SparkQA

SparkQA commented Dec 17, 2015

Test build #47889 has finished for PR 10335 at commit 258b40a.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • case class Range(
      • case class Range(

}
val safePartitionStart = getSafeMargin(partitionStart)
val safePartitionEnd = getSafeMargin(partitionEnd)
val bufferHolder = new BufferHolder(LongType.defaultSize)
Contributor

So I am not confident that this works. You are allocating 8 bytes by calling this constructor. Without growing the buffer, you hand an 8-byte array to the UnsafeRow, which is what happens here. You would need at least 16 bytes for UnsafeRow to work (8 for the null bitset and 8 for the long).

BufferHolder is meant to be used with the Unsafe*Writer classes. I don't think it adds much value here. I think we should just use a 16-byte array instead.
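For readers unfamiliar with the layout being discussed: for a single non-nullable long column, an UnsafeRow occupies 16 bytes, an 8-byte null bitset followed by the 8-byte long value. A minimal standalone sketch of that layout (plain Scala with no Spark classes; the byte order is fixed to little-endian here purely for illustration, whereas UnsafeRow uses the platform's native order):

```scala
import java.nio.{ByteBuffer, ByteOrder}

object UnsafeRowLayoutSketch {
  // Bytes 0-7: null bitset (all zeros = no null fields).
  // Bytes 8-15: the long value of field 0.
  def encode(id: Long): Array[Byte] = {
    val buf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN)
    buf.putLong(0, 0L) // null bitset: no fields are null
    buf.putLong(8, id) // field 0: the long value
    buf.array()
  }

  def decode(bytes: Array[Byte]): Long =
    ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getLong(8)
}
```

This makes the reviewer's point concrete: an 8-byte buffer only has room for the bitset, so the long value would be written out of bounds.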

Member Author

uh... True! Let me correct it. Thank you!

@SparkQA

SparkQA commented Dec 18, 2015

Test build #47972 has finished for PR 10335 at commit 576fea9.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • case class Range(
      • case class Range(

@gatorsmile
Member Author

RDD range API with collect (workload 10,000,000 rows):

scala> val startTime = System.currentTimeMillis; sc.range(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450412989642
endTime: Long = 1450412990221
start: java.sql.Timestamp = 2015-12-17 20:29:49.642
end: java.sql.Timestamp = 2015-12-17 20:29:50.221
elapsed: Double = 0.579

scala> val startTime = System.currentTimeMillis; sc.range(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450412991330
endTime: Long = 1450412991892
start: java.sql.Timestamp = 2015-12-17 20:29:51.33
end: java.sql.Timestamp = 2015-12-17 20:29:51.892
elapsed: Double = 0.562

scala> val startTime = System.currentTimeMillis; sc.range(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450412995824
endTime: Long = 1450412996666
start: java.sql.Timestamp = 2015-12-17 20:29:55.824
end: java.sql.Timestamp = 2015-12-17 20:29:56.666
elapsed: Double = 0.842

@gatorsmile
Member Author

New range API using logical/physical operators with collect (workload 10,000,000 rows):

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450413004393                                                 
endTime: Long = 1450413022682
start: java.sql.Timestamp = 2015-12-17 20:30:04.393
end: java.sql.Timestamp = 2015-12-17 20:30:22.682
elapsed: Double = 18.289

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450413029542                                                 
endTime: Long = 1450413050828
start: java.sql.Timestamp = 2015-12-17 20:30:29.542
end: java.sql.Timestamp = 2015-12-17 20:30:50.828
elapsed: Double = 21.286

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450413052748                                                 
endTime: Long = 1450413072478
start: java.sql.Timestamp = 2015-12-17 20:30:52.748
end: java.sql.Timestamp = 2015-12-17 20:31:12.478
elapsed: Double = 19.73

@gatorsmile
Member Author

Original range API with collect (workload 10,000,000 rows):

scala> val startTime = System.currentTimeMillis; sqlContext.oldRange(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450413096482                                                 
endTime: Long = 1450413125867
start: java.sql.Timestamp = 2015-12-17 20:31:36.482
end: java.sql.Timestamp = 2015-12-17 20:32:05.867
elapsed: Double = 29.385

scala> val startTime = System.currentTimeMillis; sqlContext.oldRange(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450413127738                                                 
endTime: Long = 1450413157937
start: java.sql.Timestamp = 2015-12-17 20:32:07.738
end: java.sql.Timestamp = 2015-12-17 20:32:37.937
elapsed: Double = 30.199

scala> val startTime = System.currentTimeMillis; sqlContext.oldRange(0, 10000000, 1, 15).collect(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450413159888                                                 
endTime: Long = 1450413188929
start: java.sql.Timestamp = 2015-12-17 20:32:39.888
end: java.sql.Timestamp = 2015-12-17 20:33:08.929
elapsed: Double = 29.041

@gatorsmile
Member Author

@marmbrus Compared with the original range API (elapsed time around 30 seconds), the new version is around 33% faster. Of course, the RDD range API is still much faster (its elapsed time is 0.6 seconds).

Do you think the performance improvement is good enough? Or is there any way to further reduce the overhead?

FYI: I tried this on my local machine using spark-shell: ./bin/spark-shell --driver-java-options "-Xmx2g -XX:MaxPermSize=2G"

Thanks!
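The REPL one-liners above all repeat the same timing boilerplate. A small helper (hypothetical, not part of this PR; the name `timed` is my own) captures the pattern and adds warm-up runs to reduce JIT noise:

```scala
object BenchUtil {
  // Runs `body` `warmups` times (discarding the results), then times one
  // final run. Returns the result together with the elapsed time in seconds.
  def timed[A](warmups: Int = 1)(body: => A): (A, Double) = {
    var i = 0
    while (i < warmups) { body; i += 1 }
    val start = System.currentTimeMillis()
    val result = body
    val elapsed = (System.currentTimeMillis() - start) / 1000.0
    (result, elapsed)
  }
}
```

For example, `BenchUtil.timed() { sqlContext.range(0, 10000000, 1, 15).collect() }` would return the collected array along with the elapsed seconds, replacing the four-statement REPL incantation used above.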

@gatorsmile
Member Author

Let me add the fix from #10337 and try the count function.

@gatorsmile
Member Author

RDD range API with count (workload 1,000,000,000 rows):

scala> val startTime = System.currentTimeMillis; sc.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416352767
endTime: Long = 1450416353302
start: java.sql.Timestamp = 2015-12-17 21:25:52.767
end: java.sql.Timestamp = 2015-12-17 21:25:53.302
elapsed: Double = 0.535

scala> val startTime = System.currentTimeMillis; sc.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416354143
endTime: Long = 1450416354673
start: java.sql.Timestamp = 2015-12-17 21:25:54.143
end: java.sql.Timestamp = 2015-12-17 21:25:54.673
elapsed: Double = 0.53

scala> val startTime = System.currentTimeMillis; sc.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416355390
endTime: Long = 1450416355936
start: java.sql.Timestamp = 2015-12-17 21:25:55.39
end: java.sql.Timestamp = 2015-12-17 21:25:55.936
elapsed: Double = 0.546

@gatorsmile
Member Author

New range API using logical/physical operators with count (workload 1,000,000,000 rows):

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416360107                                                 
endTime: Long = 1450416368590
start: java.sql.Timestamp = 2015-12-17 21:26:00.107
end: java.sql.Timestamp = 2015-12-17 21:26:08.59
elapsed: Double = 8.483

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416371664                                                 
endTime: Long = 1450416380654
start: java.sql.Timestamp = 2015-12-17 21:26:11.664
end: java.sql.Timestamp = 2015-12-17 21:26:20.654
elapsed: Double = 8.99

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416383434                                                 
endTime: Long = 1450416392369
start: java.sql.Timestamp = 2015-12-17 21:26:23.434
end: java.sql.Timestamp = 2015-12-17 21:26:32.369
elapsed: Double = 8.935

@gatorsmile
Member Author

Original range API with count (workload 1,000,000,000 rows):

scala> val startTime = System.currentTimeMillis; sqlContext.oldRange(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416394240                                                 
endTime: Long = 1450416421199
start: java.sql.Timestamp = 2015-12-17 21:26:34.24
end: java.sql.Timestamp = 2015-12-17 21:27:01.199
elapsed: Double = 26.959

scala> val startTime = System.currentTimeMillis; sqlContext.oldRange(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416424035                                                 
endTime: Long = 1450416450869
start: java.sql.Timestamp = 2015-12-17 21:27:04.035
end: java.sql.Timestamp = 2015-12-17 21:27:30.869
elapsed: Double = 26.834

scala> val startTime = System.currentTimeMillis; sqlContext.oldRange(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val start = new Timestamp(startTime); val end = new Timestamp(endTime); val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450416452321                                                 
endTime: Long = 1450416480277
start: java.sql.Timestamp = 2015-12-17 21:27:32.321
end: java.sql.Timestamp = 2015-12-17 21:28:00.277
elapsed: Double = 27.956

@gatorsmile
Member Author

When the workload is small, both APIs finish in less than 1 second, so I increased the workload by a factor of 100. Compared with the old range API, the new version is 3 times faster. Due to the overhead, the RDD API is of course still much faster.

@hvanhovell
Contributor

@gatorsmile I have been playing around with this for a bit. Overall I think we should do this. I do have two things for you to consider.

The current approach follows the formal route: it implements a logical/physical operator and changes the planner. It also, as you stated, reuses quite a bit of the code from SparkContext.range. The current range operator's slowness comes from the fact that we create a normal Row for each element; this is expensive because it creates 1E9 objects and tries to convert each Row to an internal one. We could also address these two issues directly by wrapping the iterator provided by sc.range differently:

def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
  val logicalPlan = LogicalRDD(
    AttributeReference("id", LongType, nullable = false)() :: Nil,
    sparkContext.range(start, end, step, numPartitions).mapPartitions({ iter =>
      // Reuse a single UnsafeRow backed by a 16-byte array:
      // 8 bytes for the null bitset and 8 bytes for the long value.
      val unsafeRow = new UnsafeRow
      unsafeRow.pointTo(new Array[Byte](16), 1, 16)
      iter.map { id =>
        unsafeRow.setLong(0, id)
        unsafeRow
      }
    }, preservesPartitioning = true))(self)
  DataFrame(this, logicalPlan)
}

What do you think?

My second point is about the benchmarking. I have been toying with this PR and the benchmarking code, and I am not sure the current results are as revealing as they should be. I think the sqlContext.range code in this PR is nearly as fast as the sc.range code. The big difference is caused by the fact that the collect() call involves serialization: serializing a Long is nowhere near as expensive as serializing an UnsafeRow. In my (tiny) benchmark of sqlContext.range, serialization accounts for about 80-90% of the execution time (use the Spark Stage Timeline to see this).

@hvanhovell
Contributor

@gatorsmile, a small follow-up on the benchmarking. If I execute the following code (note the .rdd.map(_.getLong(0))):

val startTime = System.currentTimeMillis;
sqlContext.range(0, 10000000, 1, 15).rdd.map(_.getLong(0)).collect();
val endTime = System.currentTimeMillis;
val elapsed = (endTime - startTime)/ 1000.0

I get an average of 845 ms per run (versus 477 ms for sc.range).

@gatorsmile
Member Author

Hi @hvanhovell,

Thank you for your comments! Regarding the benchmarking, I do not have a better way to measure this. collect() is not a good choice when the workload is huge, since it causes large-scale data movement. Sorry, the performance numbers in my test are exaggerated when comparing the RDD range API and the DataFrame range API.

I just tried your suggested method. To compare results, I had to increase the workload to 1,000,000,000 rows; when the scale is small, it is hard to compare performance since the result can be affected by many factors. It is 2 times slower when using count(). Maybe the performance penalty is caused by building a new RDD via mapPartitions?

scala> val startTime = System.currentTimeMillis; sqlContext.logicalRDD_Range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450466566418                                                 
endTime: Long = 1450466581232
elapsed: Double = 14.814

scala> val startTime = System.currentTimeMillis; sqlContext.logicalRDD_Range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450466583781                                                 
endTime: Long = 1450466597751
elapsed: Double = 13.97

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450466600825                                                 
endTime: Long = 1450466608397
elapsed: Double = 7.572

scala> val startTime = System.currentTimeMillis; sqlContext.range(0, 1000000000, 1, 15).count(); val endTime = System.currentTimeMillis; val elapsed = (endTime - startTime)/ 1000.0
startTime: Long = 1450466611679                                                 
endTime: Long = 1450466619421
elapsed: Double = 7.742

@marmbrus
Contributor

Regarding benchmarking, we typically measure this kind of thing using ForeachResults in spark-sql-perf (which measures just the time to pull the rows out of the iterator plus the conversion to the external format). We should probably add an internal version as well that avoids the conversion cost.

Regarding @hvanhovell's simplified implementation, I thought about proposing it like this. The only question is whether we will ever want to add optimizations on top of this (e.g., we could do a count(*) on this kind of plan really quickly). Since it is already implemented, I lean towards the more logically transparent implementation. However, it might be nice to reuse the code from RDD as he proposes in the physical operator.

Super minor point: we should probably use UnsafeRow.createFromByteArray instead of pointTo.

@marmbrus
Contributor

BTW, the benchmarks look reasonable to me. I'm okay with merging this as soon as we are happy with the implementation.

@gatorsmile
Member Author

Thank you very much, @marmbrus!

I am trying to get an account. Hopefully, next time I can directly use your performance benchmarks for other performance-related topics. Otherwise, I will try to mimic your benchmarking on my local laptop. :)

Also changed the code to use createFromByteArray. It looks much more concise now.

@SparkQA

SparkQA commented Dec 19, 2015

Test build #48032 has finished for PR 10335 at commit a1abc2f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • case class Range(
      • case class Range(

@SparkQA

SparkQA commented Dec 20, 2015

Test build #48085 has finished for PR 10335 at commit 36c862b.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
      • case class Range(
      • case class Range(

@marmbrus
Contributor

Thanks, merging to master.

@asfgit asfgit closed this in 4883a50 Dec 21, 2015
@gatorsmile gatorsmile deleted the rangeOperators branch January 19, 2016 05:45