[SPARK-7150] SparkContext.range() and SQLContext.range() #6230

Closed · wants to merge 10 commits
72 changes: 72 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -697,6 +697,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}

/**
* Creates a new RDD[Long] containing elements from `start` to `end` (exclusive), with `step`
* as the increment between consecutive elements.
*
* @note If this RDD needs to be cached, make sure each partition does not exceed the size limit.
*
* @param start the start value
* @param end the end value (exclusive)
* @param step the incremental step
* @param numSlices the number of partitions of the new RDD
* @return a new RDD[Long] spanning the requested range
*/
def range(
start: Long,
end: Long,
step: Long = 1,
numSlices: Int = defaultParallelism): RDD[Long] = withScope {
assertNotStopped()
// when step is 0, range will run infinitely
require(step != 0, "step cannot be 0")
val numElements: BigInt = {
val safeStart = BigInt(start)
val safeEnd = BigInt(end)
if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) {
(safeEnd - safeStart) / step
} else {
// the remainder has the same sign as the step, so one more element fits in the range
(safeEnd - safeStart) / step + 1
}
}
parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => {
val partitionStart = (i * numElements) / numSlices * step + start
val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
def getSafeMargin(bi: BigInt): Long =
if (bi.isValidLong) {
bi.toLong
} else if (bi > 0) {
Long.MaxValue
} else {
Long.MinValue
}
val safePartitionStart = getSafeMargin(partitionStart)
val safePartitionEnd = getSafeMargin(partitionEnd)

new Iterator[Long] {
private[this] var number: Long = safePartitionStart
private[this] var overflow: Boolean = false

override def hasNext =
if (!overflow) {
if (step > 0) {
number < safePartitionEnd
} else {
number > safePartitionEnd
}
} else false

override def next() = {
val ret = number
number += step
if (number < ret ^ step < 0) {
// We have Long.MaxValue + Long.MaxValue < Long.MaxValue and
// Long.MinValue + Long.MinValue > Long.MinValue, so if adding the step moved `number`
// in the wrong direction, an overflow has occurred.
overflow = true
}
ret
}
}
})
}
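
To make the partition arithmetic above easier to follow, here is a small standalone sketch (not part of the patch; the object name `RangePartitionSketch` and the hard-coded example values are hypothetical) that reproduces the `numElements` and per-partition boundary computation with plain Scala, no Spark required:

// Standalone illustration of the partition arithmetic in SparkContext.range (illustrative only).
object RangePartitionSketch {
  def main(args: Array[String]): Unit = {
    val start = 3L
    val end = 15L
    val step = 3L
    val numSlices = 2

    // Total element count, computed with BigInt to avoid Long overflow.
    val numElements: BigInt = {
      val safeStart = BigInt(start)
      val safeEnd = BigInt(end)
      if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) ^ (step > 0)) {
        (safeEnd - safeStart) / step
      } else {
        // a leftover partial step still contributes one more element
        (safeEnd - safeStart) / step + 1
      }
    }

    // Partition i produces values in [partitionStart, partitionEnd), stepping by `step`.
    for (i <- 0 until numSlices) {
      val partitionStart = (i * numElements) / numSlices * step + start
      val partitionEnd = ((i + 1) * numElements) / numSlices * step + start
      println(s"partition $i covers [$partitionStart, $partitionEnd)")
    }
    // For (start=3, end=15, step=3, numSlices=2) this prints [3, 9) and [9, 15),
    // i.e. elements 3, 6 in partition 0 and 9, 12 in partition 1: 4 elements, sum 30.
  }
}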

/** Distribute a local Scala collection to form an RDD.
*
* This method is identical to `parallelize`.
16 changes: 16 additions & 0 deletions python/pyspark/context.py
@@ -318,6 +318,22 @@ def stop(self):
with SparkContext._lock:
SparkContext._active_spark_context = None

def range(self, start, end, step=1, numSlices=None):
"""
Create a new RDD of int containing elements from `start` to `end`
(exclusive), with `step` as the increment between consecutive elements.

:param start: the start value
:param end: the end value (exclusive)
:param step: the incremental step (default: 1)
:param numSlices: the number of partitions of the new RDD
:return: An RDD of int

>>> sc.range(1, 7, 2).collect()
[1, 3, 5]
"""
return self.parallelize(xrange(start, end, step), numSlices)

def parallelize(self, c, numSlices=None):
"""
Distribute a local Python collection to form an RDD. Using xrange
20 changes: 20 additions & 0 deletions python/pyspark/sql/context.py
@@ -122,6 +122,26 @@ def udf(self):
"""Returns a :class:`UDFRegistration` for UDF registration."""
return UDFRegistration(self)

def range(self, start, end, step=1, numPartitions=None):
"""
Create a :class:`DataFrame` with a single LongType column named `id`,
containing elements in a range from `start` to `end` (exclusive) with
step value `step`.

:param start: the start value
:param end: the end value (exclusive)
:param step: the incremental step (default: 1)
:param numPartitions: the number of partitions of the DataFrame
:return: A new DataFrame

>>> sqlContext.range(1, 7, 2).collect()
[Row(id=1), Row(id=3), Row(id=5)]
"""

Inline review comments on the doctest above:
Contributor: can we add a test for large ints (i.e. > 32 bits)?
Contributor: might make sense to have that in tests.py
Author: done

if numPartitions is None:
numPartitions = self._sc.defaultParallelism
jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
Inline review comments on the line above:
Contributor: This will make the parameters unpredictable and lead to exceptions.
Author: If start or end is invalid, you will get an exception anyway. By converting them in Python, we get a Python exception (failed to convert to int) rather than a Py4j exception (failed to find a method to call); the latter is much harder for most users to understand.
Contributor: you are right.

return DataFrame(jdf, self)

@ignore_unicode_prefix
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests.py
@@ -117,6 +117,11 @@ def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)

def test_range(self):
self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)

def test_explode(self):
from pyspark.sql.functions import explode
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
5 changes: 5 additions & 0 deletions python/pyspark/tests.py
@@ -444,6 +444,11 @@ def func(x):

class RDDTests(ReusedPySparkTestCase):

def test_range(self):
self.assertEqual(self.sc.range(1, 1).count(), 0)
self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)

def test_id(self):
rdd = self.sc.parallelize(range(10))
id = rdd.id()
31 changes: 31 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -684,6 +684,37 @@ class SQLContext(@transient val sparkContext: SparkContext)
catalog.unregisterTable(Seq(tableName))
}

/**
* :: Experimental ::
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
* in a range from `start` to `end` (exclusive) with a step value of 1.
*
* @since 1.4.0
* @group dataframe
*/
@Experimental
def range(start: Long, end: Long): DataFrame = {
createDataFrame(
sparkContext.range(start, end).map(Row(_)),
StructType(StructField("id", LongType, nullable = false) :: Nil))
}

/**
* :: Experimental ::
* Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
* in a range from `start` to `end` (exclusive) with the given step value, using the specified
* number of partitions.
*
* @since 1.4.0
* @group dataframe
*/
@Experimental
def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
createDataFrame(
sparkContext.range(start, end, step, numPartitions).map(Row(_)),
StructType(StructField("id", LongType, nullable = false) :: Nil))
}
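
As a quick usage sketch of the new Scala API (not part of the patch; it assumes an already-constructed `SQLContext` bound to `sqlContext`), the DataFrame produced for range(3, 15, 3, 2) looks like this:

// Illustrative only: assumes `sqlContext` is an existing SQLContext with a live SparkContext.
val df = sqlContext.range(3, 15, 3, 2)   // single non-nullable LongType column named "id"
df.count()                               // 4
df.collect()                             // rows with id = 3, 6, 9, 12
df.agg(org.apache.spark.sql.functions.sum("id")).collect()   // Array(Row(30)); the suite below asserts the same sum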

/**
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
40 changes: 40 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -532,4 +532,44 @@ class DataFrameSuite extends QueryTest {
val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
assert(!p.child.isInstanceOf[Project])
}

test("SPARK-7150 range api") {
// numSlices is greater than the number of elements
val res1 = TestSQLContext.range(0, 10, 1, 15).select("id")
assert(res1.count == 10)
assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))

val res2 = TestSQLContext.range(3, 15, 3, 2).select("id")
assert(res2.count == 4)
assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))

val res3 = TestSQLContext.range(1, -2).select("id")
assert(res3.count == 0)

// start is positive, end is negative, step is negative
val res4 = TestSQLContext.range(1, -2, -2, 6).select("id")
assert(res4.count == 2)
assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))

// start, end, step are negative
val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id")
assert(res5.count == 3)
assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))

// start, end are negative, step is positive
val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id")
assert(res6.count == 2)
assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))

val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id")
assert(res7.count == 0)

val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
assert(res8.count == 3)
assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))

val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
assert(res9.count == 2)
assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
}
}