[SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs.
Author: Reynold Xin <rxin@databricks.com>

Closes #5709 from rxin/inc-id and squashes the following commits:

7853611 [Reynold Xin] private sql.
a9fda0d [Reynold Xin] Missed a few numbers.
343d896 [Reynold Xin] Self review feedback.
a7136cb [Reynold Xin] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs.
rxin committed Apr 28, 2015
1 parent bf35edd commit d94cd1a
Showing 5 changed files with 103 additions and 5 deletions.
22 changes: 21 additions & 1 deletion python/pyspark/sql/functions.py
@@ -103,8 +103,28 @@ def countDistinct(col, *cols):
return Column(jc)


def monotonicallyIncreasingId():
    """A column that generates monotonically increasing 64-bit integers.

    The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
    The current implementation puts the partition ID in the upper 31 bits, and the record number
    within each partition in the lower 33 bits. The assumption is that the data frame has
    less than 1 billion partitions, and each partition has less than 8 billion records.

    As an example, consider a [[DataFrame]] with two partitions, each with 3 records.
    This expression would return the following IDs:
    0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.

    >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1'])
    >>> df0.select(monotonicallyIncreasingId().alias('id')).collect()
    [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.monotonicallyIncreasingId())


def sparkPartitionId():
    """Returns a column for partition ID of the Spark task.
    """A column for partition ID of the Spark task.

    Note that this is indeterministic because it depends on data partitioning and task scheduling.
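For readers who want to sanity-check the bit layout described in the monotonicallyIncreasingId docstring above, here is a minimal standalone Scala sketch (not part of the commit; the helper name `decodeId` is purely illustrative) that splits a generated ID back into its partition and per-partition record components:

```scala
object MonotonicIdLayout {
  // Per the docstring above: partition ID in the upper 31 bits, record number in the lower 33 bits.
  def decodeId(id: Long): (Int, Long) = {
    val partitionId = (id >>> 33).toInt       // upper bits: which partition produced the row
    val recordNumber = id & ((1L << 33) - 1)  // lower 33 bits: row count within that partition
    (partitionId, recordNumber)
  }

  def main(args: Array[String]): Unit = {
    // IDs from the docstring example: two partitions, three records each.
    val ids = Seq(0L, 1L, 2L, 8589934592L, 8589934593L, 8589934594L)
    ids.foreach { id =>
      val (p, r) = decodeId(id)
      println(s"id=$id -> partition=$p, record=$r")
    }
  }
}
```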
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression}
import org.apache.spark.sql.types.{LongType, DataType}

/**
 * Returns monotonically increasing 64-bit integers.
 *
 * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
 * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits
 * represent the record number within each partition. The assumption is that the data frame has
 * less than 1 billion partitions, and each partition has less than 8 billion records.
 *
 * Since this expression is stateful, it cannot be a case object.
 */
private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {

  /**
   * Record ID within each partition. By being transient, count's value is reset to 0 every time
   * we serialize and deserialize it.
   */
  @transient private[this] var count: Long = 0L

  override type EvaluatedType = Long

  override def nullable: Boolean = false

  override def dataType: DataType = LongType

  override def eval(input: Row): Long = {
    val currentCount = count
    count += 1
    (TaskContext.get().partitionId().toLong << 33) + currentCount
  }
}
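To make the arithmetic in `eval` concrete, the following standalone sketch (illustrative only, no Spark dependency; object and method names are assumptions) mirrors the same composition, `(partitionId << 33) + recordNumber`, for the two-partition, three-record example used in the docs:

```scala
object MonotonicIdComposition {
  // Mirrors eval(): shift the partition ID into the upper bits, then add the per-partition counter.
  def makeId(partitionId: Int, recordNumber: Long): Long =
    (partitionId.toLong << 33) + recordNumber

  def main(args: Array[String]): Unit = {
    val numPartitions = 2
    val recordsPerPartition = 3
    for (p <- 0 until numPartitions; r <- 0 until recordsPerPartition) {
      println(s"partition=$p record=$r -> id=${makeId(p, r)}")
    }
    // Prints 0, 1, 2 for partition 0 and 8589934592, 8589934593, 8589934594 for partition 1:
    // unique and increasing within the DataFrame, but not consecutive across partitions.
  }
}
```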
@@ -18,16 +18,14 @@
package org.apache.spark.sql.execution.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row}
import org.apache.spark.sql.types.{IntegerType, DataType}


/**
 * Expression that returns the current partition id of the Spark task.
 */
case object SparkPartitionID extends Expression with trees.LeafNode[Expression] {
  self: Product =>
private[sql] case object SparkPartitionID extends LeafExpression {

  override type EvaluatedType = Int

16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -301,6 +301,22 @@ object functions {
   */
  def lower(e: Column): Column = Lower(e.expr)

  /**
   * A column expression that generates monotonically increasing 64-bit integers.
   *
   * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
   * The current implementation puts the partition ID in the upper 31 bits, and the record number
   * within each partition in the lower 33 bits. The assumption is that the data frame has
   * less than 1 billion partitions, and each partition has less than 8 billion records.
   *
   * As an example, consider a [[DataFrame]] with two partitions, each with 3 records.
   * This expression would return the following IDs:
   * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
   *
   * @group normal_funcs
   */
  def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID()

  /**
   * Unary minus, i.e. negate the expression.
   * {{{
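As a quick usage sketch of the new `monotonicallyIncreasingId` function in the Scala DataFrame API (a standalone example; the object name, local master, and column names are assumptions, not part of the commit):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.monotonicallyIncreasingId

object MonotonicIdDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("monotonic-id-demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Two partitions, three rows each, mirroring the example in the scaladoc.
    val df = sc.parallelize((1 to 6).map(Tuple1(_)), 2).toDF("value")

    // Attach a unique (but not consecutive) 64-bit ID to every row.
    df.withColumn("id", monotonicallyIncreasingId()).show()

    sc.stop()
  }
}
```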
@@ -309,6 +309,17 @@ class ColumnExpressionSuite extends QueryTest {
    )
  }

  test("monotonicallyIncreasingId") {
    // Make sure we have 2 partitions, each with 2 records.
    val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter =>
      Iterator(Tuple1(1), Tuple1(2))
    }.toDF("a")
    checkAnswer(
      df.select(monotonicallyIncreasingId()),
      Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil
    )
  }

test("sparkPartitionId") {
val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
checkAnswer(
