[SPARK-13049] Add First/last with ignore nulls to functions.scala #10957

Closed · wants to merge 5 commits
Changes from 4 commits
20 changes: 18 additions & 2 deletions python/pyspark/sql/functions.py
@@ -81,8 +81,6 @@ def _():

'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
'first': 'Aggregate function: returns the first value in a group.',
'last': 'Aggregate function: returns the last value in a group.',
'count': 'Aggregate function: returns the number of items in a group.',
'sum': 'Aggregate function: returns the sum of all values in the expression.',
'avg': 'Aggregate function: returns the average of the values in a group.',
@@ -278,6 +276,15 @@ def countDistinct(col, *cols):
return Column(jc)


@since(1.3)
def first(col, ignorenulls=False):
"""Aggregate function: returns the first value in a group.
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
return Column(jc)


@since(1.6)
def input_file_name():
"""Creates a string column for the file name of the current Spark task.
@@ -310,6 +317,15 @@ def isnull(col):
return Column(sc._jvm.functions.isnull(_to_java_column(col)))


@since(1.3)
def last(col, ignorenulls=False):
"""Aggregate function: returns the last value in a group.
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
return Column(jc)


@since(1.6)
def monotonically_increasing_id():
"""A column that generates monotonically increasing 64-bit integers.
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests.py
@@ -641,6 +641,16 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

def test_first_last_ignorenulls(self):
from pyspark.sql import functions
df = self.sqlCtx.range(0, 100)
df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
df3 = df2.select(functions.first(df2.id, False).alias('a'),
functions.first(df2.id, True).alias('b'),
functions.last(df2.id, False).alias('c'),
functions.last(df2.id, True).alias('d'))
self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())

def test_corr(self):
import math
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
90 changes: 67 additions & 23 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -349,19 +349,41 @@ object functions extends LegacyFunctions {
}

/**
* Aggregate function: returns the first value in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def first(e: Column): Column = withAggregateFunction { new First(e.expr) }
* Aggregate function: returns the first value in a group. The function does not consider null
* values when the ignoreNulls flag is set to true.
*
* @group agg_funcs
* @since 2.0.0
*/
def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
new First(e.expr, Literal(ignoreNulls))
}

Review comment (Contributor), on the doc comment above:
Can you write something like this to be more clear? And update all the docs (including Python).
"The function by default includes the first value it sees. When ignoreNulls is set to true, then it ignores the null values and includes the first non-null value. If all values are null, then null is returned."
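To make the reviewer's description concrete, here is a minimal usage sketch. It is not part of this change; it assumes sqlContext.implicits._ (for toDF and $) and org.apache.spark.sql.functions._ are in scope, and the column names and values are invented for illustration.

val nullStr: String = null
val df = Seq(("a", nullStr), ("a", "x"), ("a", "y")).toDF("key", "value")

df.groupBy($"key").agg(
  // Default: keeps the first value seen, here null (assuming the local input
  // order is preserved; without an ordering the result is order-dependent).
  first($"value").as("first_any"),
  // ignoreNulls = true: nulls are skipped, so the first non-null value "x" is kept.
  first($"value", ignoreNulls = true).as("first_non_null"))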

/**
* Aggregate function: returns the first value of a column in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
* Aggregate function: returns the first value of a column in a group. The function does not
* consider null values when the ignoreNulls flag is set to true.
*
* @group agg_funcs
* @since 2.0.0
*/
def first(columnName: String, ignoreNulls: Boolean): Column = {
first(Column(columnName), ignoreNulls)
}

/**
* Aggregate function: returns the first value in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def first(e: Column): Column = first(e, ignoreNulls = false)

/**
* Aggregate function: returns the first value of a column in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def first(columnName: String): Column = first(Column(columnName))

/**
@@ -381,20 +403,42 @@ object functions extends LegacyFunctions {
def kurtosis(columnName: String): Column = kurtosis(Column(columnName))

/**
* Aggregate function: returns the last value in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def last(e: Column): Column = withAggregateFunction { new Last(e.expr) }
* Aggregate function: returns the last value in a group. The function does not
* consider null values when the ignoreNulls flag is set to true.
*
* @group agg_funcs
* @since 2.0.0
*/
def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
new Last(e.expr, Literal(ignoreNulls))
}

/**
* Aggregate function: returns the last value of the column in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def last(columnName: String): Column = last(Column(columnName))
* Aggregate function: returns the last value of the column in a group. The function does not
* consider null values when the ignoreNulls flag is set to true.
*
* @group agg_funcs
* @since 2.0.0
*/
def last(columnName: String, ignoreNulls: Boolean): Column = {
last(Column(columnName), ignoreNulls)
}

/**
* Aggregate function: returns the last value in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def last(e: Column): Column = last(e, ignoreNulls = false)

/**
* Aggregate function: returns the last value of the column in a group.
*
* @group agg_funcs
* @since 1.3.0
*/
def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false)
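A matching sketch for last, again not part of this change and under the same import assumptions, covering the all-null case called out in the review comment: when every value in a group is null, ignoreNulls = true still yields null.

val nullStr: String = null
val df = Seq(
  ("a", "x"), ("a", nullStr),    // group "a": last value seen is null, last non-null is "x"
  ("b", nullStr), ("b", nullStr) // group "b": only nulls, so both variants return null
).toDF("key", "value")

df.groupBy($"key").agg(
  last($"value").as("last_any"),
  last($"value", ignoreNulls = true).as("last_non_null"))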

/**
* Aggregate function: returns the maximum value of the expression in a group.
@@ -312,4 +312,36 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
Row("b", 3, null, null),
Row("b", 2, null, null)))
}

test("last/first with ignoreNulls") {
val nullStr: String = null
val df = Seq(
("a", 0, nullStr),
("a", 1, "x"),
("a", 2, "y"),
("a", 3, "z"),
("a", 4, nullStr),
("b", 1, nullStr),
("b", 2, nullStr)).
toDF("key", "order", "value")
val window = Window.partitionBy($"key").orderBy($"order")
checkAnswer(
df.select(
$"key",
$"order",
first($"value").over(window),
first($"value", ignoreNulls = false).over(window),
first($"value", ignoreNulls = true).over(window),
last($"value").over(window),
last($"value", ignoreNulls = false).over(window),
last($"value", ignoreNulls = true).over(window)),
Seq(
Row("a", 0, null, null, null, null, null, null),
Row("a", 1, null, null, "x", "x", "x", "x"),
Row("a", 2, null, null, "x", "y", "y", "y"),
Row("a", 3, null, null, "x", "z", "z", "z"),
Row("a", 4, null, null, "x", null, null, "z"),
Row("b", 1, null, null, null, null, null, null),
Row("b", 2, null, null, null, null, null, null)))
}
}
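The expected rows in the window test above read as running first/last values because an ordered window's implicit frame reaches from the start of the partition to the current row. Below is a sketch of the same query with that frame spelled out; it is not part of the PR, runningWindow is an invented name, and rowsBetween(Long.MinValue, 0) stands for UNBOUNDED PRECEDING to CURRENT ROW. It behaves identically here because the order keys within each partition are unique.

val runningWindow = Window
  .partitionBy($"key")
  .orderBy($"order")
  .rowsBetween(Long.MinValue, 0)

df.select(
  $"key",
  $"order",
  first($"value", ignoreNulls = true).over(runningWindow),
  last($"value", ignoreNulls = true).over(runningWindow))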