From 71792411083a71bcfd7a0d94ddf754bf09a27054 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Thu, 11 Aug 2022 20:19:24 +0900
Subject: [PATCH] [SPARK-40027][PYTHON][SS][DOCS] Add self-contained examples for pyspark.sql.streaming.readwriter

### What changes were proposed in this pull request?

This PR proposes to improve the examples in `pyspark.sql.streaming.readwriter` by making each example self-contained, with a brief explanation and a slightly more realistic example.

### Why are the changes needed?

To make the documentation more readable and to allow the examples to be copied and pasted directly into the PySpark shell.

### Does this PR introduce _any_ user-facing change?

Yes, it changes the documentation.

### How was this patch tested?

Manually ran each doctest.

Closes #37461 from HyukjinKwon/SPARK-40027.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/sql/streaming/readwriter.py | 441 +++++++++++++++++----
 1 file changed, 357 insertions(+), 84 deletions(-)

diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py
index 74b89dbe46c20..ef3b7e525e3e9 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -24,7 +24,7 @@
 from pyspark.sql.column import _to_seq
 from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.streaming.query import StreamingQuery
-from pyspark.sql.types import Row, StructType, StructField, StringType
+from pyspark.sql.types import Row, StructType
 from pyspark.sql.utils import ForeachBatchFunction

 if TYPE_CHECKING:
@@ -46,6 +46,22 @@ class DataStreamReader(OptionUtils):
     Notes
     -----
     This API is evolving.
+
+    Examples
+    --------
+    >>> spark.readStream
+
+    The example below uses the Rate source, which generates rows continuously.
+    After that, we compute the value modulo 3 and write the stream out to the console.
+    The streaming query is stopped after 3 seconds.
+
+    >>> import time
+    >>> df = spark.readStream.format("rate").load()
+    >>> df = df.selectExpr("value % 3 as v")
+    >>> q = df.writeStream.format("console").start()
+    >>> time.sleep(3)
+    >>> q.stop()
     """

     def __init__(self, spark: "SparkSession") -> None:
@@ -73,7 +89,23 @@ def format(self, source: str) -> "DataStreamReader":

         Examples
         --------
-        >>> s = spark.readStream.format("text")
+        >>> spark.readStream.format("text")
+
+        This API allows configuring other sources to read. The example below writes a small text
+        file, and reads it back via the Text source.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary text file to read it.
+        ...     spark.createDataFrame(
+        ...         [("hello",), ("this",)]).write.mode("overwrite").format("text").save(d)
+        ...
+        ...     # Start a streaming query to read the text file.
+        ...     q = spark.readStream.format("text").load(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._jreader = self._jreader.format(source)
         return self
@@ -99,8 +131,22 @@ def schema(self, schema: Union[StructType, str]) -> "DataStreamReader":

         Examples
         --------
-        >>> s = spark.readStream.schema(sdf_schema)
-        >>> s = spark.readStream.schema("col0 INT, col1 DOUBLE")
+        >>> from pyspark.sql.types import StructField, StructType, StringType
+        >>> spark.readStream.schema(StructType([StructField("data", StringType(), True)]))
+
+        >>> spark.readStream.schema("col0 INT, col1 DOUBLE")
+
+        The example below specifies a different schema for a CSV file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Read a CSV stream with the specified schema and print the schema.
+        ...     spark.readStream.schema("col0 INT, col1 STRING").format("csv").load(d).printSchema()
+        root
+         |-- col0: integer (nullable = true)
+         |-- col1: string (nullable = true)
         """
         from pyspark.sql import SparkSession
@@ -125,7 +171,17 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader"

         Examples
         --------
-        >>> s = spark.readStream.option("x", 1)
+        >>> spark.readStream.option("x", 1)
+
+        The example below specifies the 'rowsPerSecond' option for the Rate source in order to
+        generate 10 rows every second.
+
+        >>> import time
+        >>> q = spark.readStream.format(
+        ...     "rate").option("rowsPerSecond", 10).load().writeStream.format("console").start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         self._jreader = self._jreader.option(key, to_str(value))
         return self
@@ -141,7 +197,18 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamReader":

         Examples
         --------
-        >>> s = spark.readStream.options(x="1", y=2)
+        >>> spark.readStream.options(x="1", y=2)
+
+        The example below specifies the 'rowsPerSecond' and 'numPartitions' options for the
+        Rate source in order to generate 10 rows with 10 partitions every second.
+
+        >>> import time
+        >>> q = spark.readStream.format("rate").options(
+        ...     rowsPerSecond=10, numPartitions=10
+        ... ).load().writeStream.format("console").start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         for k in options:
             self._jreader = self._jreader.option(k, to_str(options[k]))
@@ -177,13 +244,22 @@ def load(

         Examples
         --------
-        >>> json_sdf = spark.readStream.format("json") \\
-        ...     .schema(sdf_schema) \\
-        ...     .load(tempfile.mkdtemp())
-        >>> json_sdf.isStreaming
-        True
-        >>> json_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary JSON file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary JSON file to read it.
+        ...     spark.createDataFrame(
+        ...         [(100, "Hyukjin Kwon"),], ["age", "name"]
+        ...     ).write.mode("overwrite").format("json").save(d)
+        ...
+        ...     # Start a streaming query to read the JSON file.
+        ...     q = spark.readStream.schema(
+        ...         "age INT, name STRING"
+        ...     ).format("json").load(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         if format is not None:
             self.format(format)
@@ -260,11 +336,22 @@ def json(

         Examples
         --------
-        >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
-        >>> json_sdf.isStreaming
-        True
-        >>> json_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary JSON file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary JSON file to read it.
+        ...     spark.createDataFrame(
+        ...         [(100, "Hyukjin Kwon"),], ["age", "name"]
+        ...     ).write.mode("overwrite").format("json").save(d)
+        ...
+        ...     # Start a streaming query to read the JSON file.
+        ...     q = spark.readStream.schema(
+        ...         "age INT, name STRING"
+        ...     ).json(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             schema=schema,
@@ -316,11 +403,18 @@ def orc(

         Examples
         --------
-        >>> orc_sdf = spark.readStream.schema(sdf_schema).orc(tempfile.mkdtemp())
-        >>> orc_sdf.isStreaming
-        True
-        >>> orc_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary ORC file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary ORC file to read it.
+        ...     spark.range(10).write.mode("overwrite").format("orc").save(d)
+        ...
+        ...     # Start a streaming query to read the ORC file.
+        ...     q = spark.readStream.schema("id LONG").orc(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             mergeSchema=mergeSchema,
@@ -362,11 +456,19 @@ def parquet(

         Examples
         --------
-        >>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp())
-        >>> parquet_sdf.isStreaming
-        True
-        >>> parquet_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary Parquet file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary Parquet file to read it.
+        ...     spark.range(10).write.mode("overwrite").format("parquet").save(d)
+        ...
+        ...     # Start a streaming query to read the Parquet file.
+        ...     q = spark.readStream.schema(
+        ...         "id LONG").parquet(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             mergeSchema=mergeSchema,
@@ -418,11 +520,19 @@ def text(

         Examples
         --------
-        >>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
-        >>> text_sdf.isStreaming
-        True
-        >>> "value" in str(text_sdf.schema)
-        True
+        Load a data stream from a temporary text file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary text file to read it.
+        ...     spark.createDataFrame(
+        ...         [("hello",), ("this",)]).write.mode("overwrite").format("text").save(d)
+        ...
+        ...     # Start a streaming query to read the text file.
+        ...     q = spark.readStream.text(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             wholetext=wholetext,
@@ -500,11 +610,20 @@ def csv(

         Examples
         --------
-        >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema)
-        >>> csv_sdf.isStreaming
-        True
-        >>> csv_sdf.schema == sdf_schema
-        True
+        Load a data stream from a temporary CSV file.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a temporary CSV file to read it.
+        ...     spark.createDataFrame([(1, "2"),]).write.mode("overwrite").format("csv").save(d)
+        ...
+        ...     # Start a streaming query to read the CSV file.
+        ...     q = spark.readStream.schema(
+        ...         "col0 INT, col1 STRING"
+        ...     ).format("csv").load(d).writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q.stop()
         """
         self._set_opts(
             schema=schema,
@@ -564,7 +683,22 @@ def table(self, tableName: str) -> "DataFrame":

         Examples
         --------
-        >>> spark.readStream.table('input_table') # doctest: +SKIP
+        Load a data stream from a table.
+
+        >>> import tempfile
+        >>> import time
+        >>> _ = spark.sql("DROP TABLE IF EXISTS my_table")
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Create a table with the Rate source.
+        ...     q1 = spark.readStream.format("rate").load().writeStream.toTable(
+        ...         "my_table", checkpointLocation=d)
+        ...
+        ...     # Read the table back and print it out to the console.
+        ...     q2 = spark.readStream.table("my_table").writeStream.format("console").start()
+        ...     time.sleep(3)
+        ...     q1.stop()
+        ...     q2.stop()
+        ...     _ = spark.sql("DROP TABLE my_table")
         """
         if isinstance(tableName, str):
             return self._df(self._jreader.table(tableName))
@@ -584,6 +718,19 @@ class DataStreamWriter:
     Notes
     -----
     This API is evolving.
+
+    Examples
+    --------
+    The example below uses the Rate source, which generates rows continuously.
+    After that, we compute the value modulo 3 and write the stream out to the console.
+    The streaming query is stopped after 3 seconds.
+
+    >>> import time
+    >>> df = spark.readStream.format("rate").load()
+    >>> df = df.selectExpr("value % 3 as v")
+    >>> q = df.writeStream.format("console").start()
+    >>> time.sleep(3)
+    >>> q.stop()
     """

     def __init__(self, df: "DataFrame") -> None:
@@ -615,7 +762,18 @@ def outputMode(self, outputMode: str) -> "DataStreamWriter":

         Examples
         --------
-        >>> writer = sdf.writeStream.outputMode('append')
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.outputMode('append')
+
+        The example below uses Complete mode, in which the entire aggregated counts are
+        printed out.
+
+        >>> import time
+        >>> df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        >>> df = df.groupby().count()
+        >>> q = df.writeStream.outputMode("complete").format("console").start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0:
             raise ValueError("The output mode must be a non-empty string. Got: %s" % outputMode)
@@ -638,7 +796,25 @@ def format(self, source: str) -> "DataStreamWriter":

         Examples
         --------
-        >>> writer = sdf.writeStream.format('json')
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.format("text")
+
+        This API allows configuring the sink to write to. The example below writes a CSV
+        file from the Rate source in a streaming manner.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d, tempfile.TemporaryDirectory() as cp:
+        ...     df = spark.readStream.format("rate").load()
+        ...     q = df.writeStream.format("csv").option("checkpointLocation", cp).start(d)
+        ...     time.sleep(5)
+        ...     q.stop()
+        ...     spark.read.schema("timestamp TIMESTAMP, value STRING").csv(d).show()
+        +...---------+-----+
+        |...timestamp|value|
+        +...---------+-----+
+        ...
         """
         self._jwrite = self._jwrite.format(source)
         return self
@@ -651,6 +827,22 @@ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamWriter"
         Notes
         -----
         This API is evolving.
+
+        Examples
+        --------
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.option("x", 1)
+
+        The example below specifies the 'numRows' option for the console sink in order to print
+        3 rows for every batch.
+
+        >>> import time
+        >>> q = spark.readStream.format(
+        ...     "rate").option("rowsPerSecond", 10).load().writeStream.format(
+        ...     "console").option("numRows", 3).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         self._jwrite = self._jwrite.option(key, to_str(value))
         return self
@@ -663,6 +855,22 @@ def options(self, **options: "OptionalPrimitiveType") -> "DataStreamWriter":
         Notes
         -----
         This API is evolving.
+
+        Examples
+        --------
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.options(x="1", y=2)
+
+        The example below specifies the 'numRows' and 'truncate' options for the console sink
+        in order to print 3 rows for every batch without truncating the results.
+
+        >>> import time
+        >>> q = spark.readStream.format(
+        ...     "rate").option("rowsPerSecond", 10).load().writeStream.format(
+        ...     "console").options(numRows=3, truncate=False).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         for k in options:
             self._jwrite = self._jwrite.option(k, to_str(options[k]))
@@ -692,6 +900,28 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter":  # type: ignore[misc]
         Notes
         -----
         This API is evolving.
+
+        Examples
+        --------
+        >>> df = spark.readStream.format("rate").load()
+        >>> df.writeStream.partitionBy("value")
+
+        Partition the output by the timestamp column from the Rate source.
+
+        >>> import tempfile
+        >>> import time
+        >>> with tempfile.TemporaryDirectory() as d, tempfile.TemporaryDirectory() as cp:
+        ...     df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
+        ...     q = df.writeStream.partitionBy(
+        ...         "timestamp").format("parquet").option("checkpointLocation", cp).start(d)
+        ...     time.sleep(5)
+        ...     q.stop()
+        ...     spark.read.schema(df.schema).parquet(d).show()
+        +...---------+-----+
+        |...timestamp|value|
+        +...---------+-----+
+        ...
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
@@ -716,7 +946,12 @@ def queryName(self, queryName: str) -> "DataStreamWriter":

         Examples
         --------
-        >>> writer = sdf.writeStream.queryName('streaming_query')
+        >>> import time
+        >>> df = spark.readStream.format("rate").load()
+        >>> q = df.writeStream.queryName("streaming_query").format("console").start()
+        >>> q.stop()
+        >>> q.name
+        'streaming_query'
         """
         if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
             raise ValueError("The queryName must be a non-empty string. Got: %s" % queryName)
@@ -775,14 +1010,22 @@ def trigger(

         Examples
         --------
-        >>> # trigger the query for execution every 5 seconds
-        >>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
-        >>> # trigger the query for just once batch of data
-        >>> writer = sdf.writeStream.trigger(once=True)
-        >>> # trigger the query for execution every 5 seconds
-        >>> writer = sdf.writeStream.trigger(continuous='5 seconds')
-        >>> # trigger the query for reading all available data with multiple batches
-        >>> writer = sdf.writeStream.trigger(availableNow=True)
+        >>> df = spark.readStream.format("rate").load()
+
+        Trigger the query for execution every 5 seconds
+
+        >>> df.writeStream.trigger(processingTime='5 seconds')
+
+        Trigger the query for execution in continuous mode with a 5-second checkpoint interval
+
+        >>> df.writeStream.trigger(continuous='5 seconds')
+
+        Trigger the query for reading all available data with multiple batches
+
+        >>> df.writeStream.trigger(availableNow=True)
+
         """

         params = [processingTime, once, continuous, availableNow]
@@ -908,22 +1151,34 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt

         Examples
         --------
-        >>> # Print every row using a function
+        >>> import time
+        >>> df = spark.readStream.format("rate").load()
+
+        Print every row using a function
+
         >>> def print_row(row):
         ...     print(row)
         ...
-        >>> writer = sdf.writeStream.foreach(print_row)
-        >>> # Print every row using a object with process() method
+        >>> q = df.writeStream.foreach(print_row).start()
+        >>> time.sleep(3)
+        >>> q.stop()
+
+        Print every row using an object with a process() method
+
         >>> class RowPrinter:
         ...     def open(self, partition_id, epoch_id):
         ...         print("Opened %d, %d" % (partition_id, epoch_id))
         ...         return True
+        ...
         ...     def process(self, row):
         ...         print(row)
+        ...
         ...     def close(self, error):
         ...         print("Closed with error: %s" % str(error))
         ...
-        >>> writer = sdf.writeStream.foreach(RowPrinter())
+        >>> q = df.writeStream.foreach(RowPrinter()).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         from pyspark.rdd import _wrap_function
@@ -1025,10 +1280,14 @@ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamW

         Examples
         --------
+        >>> import time
+        >>> df = spark.readStream.format("rate").load()
         >>> def func(batch_df, batch_id):
         ...     batch_df.collect()
         ...
-        >>> writer = sdf.writeStream.foreachBatch(func)
+        >>> q = df.writeStream.foreachBatch(func).start()
+        >>> time.sleep(3)
+        >>> q.stop()
         """
         from pyspark.java_gateway import ensure_callback_server_started
@@ -1090,21 +1349,28 @@ def start(

         Examples
         --------
-        >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
-        >>> sq.isActive
-        True
-        >>> sq.name
-        'this_query'
-        >>> sq.stop()
-        >>> sq.isActive
-        False
-        >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start(
-        ...     queryName='that_query', outputMode="append", format='memory')
-        >>> sq.name
-        'that_query'
-        >>> sq.isActive
-        True
-        >>> sq.stop()
+        >>> df = spark.readStream.format("rate").load()
+
+        Basic example.
+
+        >>> q = df.writeStream.format('memory').queryName('this_query').start()
+        >>> q.isActive
+        True
+        >>> q.name
+        'this_query'
+        >>> q.stop()
+        >>> q.isActive
+        False
+
+        Example of using other parameters with a trigger.
+
+        >>> q = df.writeStream.trigger(processingTime='5 seconds').start(
+        ...     queryName='that_query', outputMode="append", format='memory')
+        >>> q.name
+        'that_query'
+        >>> q.isActive
+        True
+        >>> q.stop()
         """
         self.options(**options)
         if outputMode is not None:
@@ -1176,15 +1442,28 @@ def toTable(

         Examples
         --------
-        >>> sdf.writeStream.format('parquet').queryName('query').toTable('output_table')
-        ... # doctest: +SKIP
-
-        >>> sdf.writeStream.trigger(processingTime='5 seconds').toTable(
-        ...     'output_table',
-        ...     queryName='that_query',
-        ...     outputMode="append",
-        ...     format='parquet',
-        ...     checkpointLocation='/tmp/checkpoint') # doctest: +SKIP
+        Save a data stream to a table.
+
+        >>> import tempfile
+        >>> import time
+        >>> _ = spark.sql("DROP TABLE IF EXISTS my_table2")
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Create a table with the Rate source.
+        ...     q = spark.readStream.format("rate").option(
+        ...         "rowsPerSecond", 10).load().writeStream.toTable(
+        ...             "my_table2",
+        ...             queryName='that_query',
+        ...             outputMode="append",
+        ...             format='parquet',
+        ...             checkpointLocation=d)
+        ...     time.sleep(3)
+        ...     q.stop()
+        ...     spark.read.table("my_table2").show()
+        ...     _ = spark.sql("DROP TABLE my_table2")
+        +...---------+-----+
+        |...timestamp|value|
+        +...---------+-----+
+        ...
         """
         self.options(**options)
         if outputMode is not None:
@@ -1201,23 +1480,17 @@ def toTable(
 def _test() -> None:
     import doctest
     import os
-    import tempfile

     from pyspark.sql import SparkSession
     import pyspark.sql.streaming.readwriter
-    from py4j.protocol import Py4JError

     os.chdir(os.environ["SPARK_HOME"])

     globs = pyspark.sql.streaming.readwriter.__dict__.copy()
-    try:
-        spark = SparkSession._getActiveSessionOrCreate()
-    except Py4JError:  # noqa: F821
-        spark = SparkSession(sc)  # type: ignore[name-defined]  # noqa: F821
-
-    globs["tempfile"] = tempfile
-    globs["spark"] = spark
-    globs["sdf"] = spark.readStream.format("text").load("python/test_support/sql/streaming")
-    globs["sdf_schema"] = StructType([StructField("data", StringType(), True)])
+    globs["spark"] = (
+        SparkSession.builder.master("local[4]")
+        .appName("sql.streaming.readwriter tests")
+        .getOrCreate()
+    )

     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.streaming.readwriter,