[SPARK-25234][SPARKR] avoid integer overflow in parallelize
## What changes were proposed in this pull request?

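`parallelize` computes the split indices with integer multiplication (`x * length(coll)`), which can overflow R's 32-bit integer range when the collection and the number of slices are both large. This patch converts the slice index to numeric before multiplying.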
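As a minimal illustration of the overflow described above (not part of the patch), the snippet below multiplies two integers whose product exceeds `.Machine$integer.max`. The values are chosen to match the unit test in this commit; the variable names `overflowed` and `safe` are hypothetical.

```r
# Both operands are integers, so R multiplies them in 32-bit integer arithmetic.
x <- 46999L
len <- 47000L

# The product 47000 * 47000 does not fit in an integer: R returns NA with a warning.
overflowed <- suppressWarnings((x + 1L) * len)

# Converting one operand to numeric makes R multiply in double precision instead.
safe <- (as.numeric(x) + 1) * len

print(is.na(overflowed))
print(safe > .Machine$integer.max)
```

The `NA` then propagates into `start` and `end`, which is why the split computation breaks rather than silently producing wrong indices.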

## How was this patch tested?

Added a unit test that calls `parallelize` on `1:47000` with 47000 slices.

Closes #22225 from mengxr/SPARK-25234.

Authored-by: Xiangrui Meng <meng@databricks.com>
Signed-off-by: Xiangrui Meng <meng@databricks.com>
(cherry picked from commit 9714fa5)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
mengxr committed Aug 24, 2018
1 parent fcc9bd6 commit 42c1fdd
Showing 2 changed files with 11 additions and 5 deletions.
9 changes: 4 additions & 5 deletions R/pkg/R/context.R
@@ -138,11 +138,10 @@ parallelize <- function(sc, coll, numSlices = 1) {

   sizeLimit <- getMaxAllocationLimit(sc)
   objectSize <- object.size(coll)
+  len <- length(coll)

   # For large objects we make sure the size of each slice is also smaller than sizeLimit
-  numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
-  if (numSerializedSlices > length(coll))
-    numSerializedSlices <- length(coll)
+  numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit)))

   # Generate the slice ids to put each row
   # For instance, for numSerializedSlices of 22, length of 50
@@ -153,8 +152,8 @@ parallelize <- function(sc, coll, numSlices = 1) {
   splits <- if (numSerializedSlices > 0) {
     unlist(lapply(0: (numSerializedSlices - 1), function(x) {
       # nolint start
-      start <- trunc((x * length(coll)) / numSerializedSlices)
-      end <- trunc(((x + 1) * length(coll)) / numSerializedSlices)
+      start <- trunc((as.numeric(x) * len) / numSerializedSlices)
+      end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices)
       # nolint end
       rep(start, end - start)
     }))
7 changes: 7 additions & 0 deletions R/pkg/tests/fulltests/test_context.R
@@ -240,3 +240,10 @@ test_that("add and get file to be downloaded with Spark job on every node", {
   unlink(path, recursive = TRUE)
   sparkR.session.stop()
 })
+
+test_that("SPARK-25234: parallelize should not have integer overflow", {
+  sc <- sparkR.sparkContext(master = sparkRTestMaster)
+  # 47000 * 47000 exceeds integer range
+  parallelize(sc, 1:47000, 47000)
+  sparkR.session.stop()
+})
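
The split computation changed in `context.R` can be tried without a Spark context. The sketch below lifts the patched `start`/`end` arithmetic into a standalone function. The name `slice_ids` is hypothetical, and returning an empty vector when there are no slices is an assumption, since the diff does not show that branch.

```r
# Hypothetical standalone helper: returns one slice id per element of the collection.
slice_ids <- function(len, numSerializedSlices) {
  # Assumption: an empty result for zero slices; the diff does not show this branch.
  if (numSerializedSlices <= 0) {
    return(numeric(0))
  }
  unlist(lapply(0:(numSerializedSlices - 1), function(x) {
    # as.numeric(x) forces double-precision multiplication, as in the patch.
    start <- trunc((as.numeric(x) * len) / numSerializedSlices)
    end <- trunc(((as.numeric(x) + 1) * len) / numSerializedSlices)
    # Every element in [start, end) gets the id `start`.
    rep(start, end - start)
  }))
}

# The example from the source comment: 50 elements in 22 slices.
ids <- slice_ids(50L, 22)
print(length(ids))
print(length(unique(ids)))

# The case from the unit test, which overflowed with integer multiplication.
big <- slice_ids(47000L, 47000)
print(length(big))
```

Because consecutive `end` and `start` values coincide, the slice lengths sum to `len`, so every element receives exactly one slice id.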
