diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index b8bc90f458cdf..5cd5f69a85a61 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -128,11 +128,23 @@ def rowsBetween(start: int, end: int) -> "WindowSpec": -------- >>> from pyspark.sql import Window >>> from pyspark.sql import functions as func - >>> from pyspark.sql import SQLContext - >>> sc = SparkContext.getOrCreate() - >>> sqlContext = SQLContext(sc) - >>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")] - >>> df = sqlContext.createDataFrame(tup, ["id", "category"]) + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Calculate sum of ``id`` in the range from currentRow to currentRow + 1 + in partition ``category`` + >>> window = Window.partitionBy("category").orderBy("id").rowsBetween(Window.currentRow, 1) >>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category", "sum").show() +---+--------+---+ @@ -196,11 +208,23 @@ def rangeBetween(start: int, end: int) -> "WindowSpec": -------- >>> from pyspark.sql import Window >>> from pyspark.sql import functions as func - >>> from pyspark.sql import SQLContext - >>> sc = SparkContext.getOrCreate() - >>> sqlContext = SQLContext(sc) - >>> tup = [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")] - >>> df = sqlContext.createDataFrame(tup, ["id", "category"]) + >>> df = spark.createDataFrame( + ... [(1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b")], ["id", "category"]) + >>> df.show() + +---+--------+ + | id|category| + +---+--------+ + | 1| a| + | 1| a| + | 2| a| + | 1| b| + | 2| b| + | 3| b| + +---+--------+ + + Calculate sum of ``id`` in the range from ``id`` of currentRow to ``id`` of currentRow + 1 + in partition ``category`` + >>> window = Window.partitionBy("category").orderBy("id").rangeBetween(Window.currentRow, 1) >>> df.withColumn("sum", func.sum("id").over(window)).sort("id", "category").show() +---+--------+---+ @@ -329,9 +353,10 @@ def rangeBetween(self, start: int, end: int) -> "WindowSpec": def _test() -> None: import doctest + from pyspark.sql import SparkSession import pyspark.sql.window - SparkContext("local[4]", "PythonTest") + spark = SparkSession.builder.master("local[4]").appName("sql.window tests").getOrCreate() globs = pyspark.sql.window.__dict__.copy() (failure_count, test_count) = doctest.testmod( pyspark.sql.window, globs=globs, optionflags=doctest.NORMALIZE_WHITESPACE