From b002d601f1d017582497613d94a733bec15fd7ad Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 31 Jan 2016 15:20:09 +0100 Subject: [PATCH] Add first/last ignoreNulls in python --- python/pyspark/sql/functions.py | 16 ++++++++++++++-- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 719eca8f5559e..7ba8d91b79ba9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -81,8 +81,6 @@ def _(): 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', 'count': 'Aggregate function: returns the number of items in a group.', 'sum': 'Aggregate function: returns the sum of all values in the expression.', 'avg': 'Aggregate function: returns the average of the values in a group.', @@ -277,6 +275,13 @@ def countDistinct(col, *cols): jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) return Column(jc) +@since(1.3) +def first(col, ignorenulls=False): + """Aggregate function: returns the first value in a group. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) + return Column(jc) @since(1.6) def input_file_name(): @@ -309,6 +314,13 @@ def isnull(col): sc = SparkContext._active_spark_context return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.3) +def last(col, ignorenulls=False): + """Aggregate function: returns the last value in a group. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) + return Column(jc) @since(1.6) def monotonically_increasing_id(): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 410efbafe0792..e30aa0a796924 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -641,6 +641,16 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.sqlCtx.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()