3 changes: 3 additions & 0 deletions connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -240,6 +240,9 @@ message WriteStreamOperationStart {

StreamingForeachFunction foreach_writer = 13;
StreamingForeachFunction foreach_batch = 14;

// (Optional) Columns used for clustering the table.
repeated string clustering_column_names = 15;
}

message StreamingForeachFunction {
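The new repeated field is filled through the ordinary generated protobuf builder, and the server reads it back in the SparkConnectPlanner change that follows. A minimal sketch of that round trip, assuming the generated classes live under org.apache.spark.connect.proto as in the existing Connect client code:

    import scala.jdk.CollectionConverters._
    import org.apache.spark.connect.proto.WriteStreamOperationStart

    // Build a WriteStreamOperationStart that carries clustering columns; the
    // add/get accessors below come from the standard protobuf codegen.
    val start = WriteStreamOperationStart.newBuilder()
      .addAllClusteringColumnNames(Seq("value").asJava)
      .build()
    assert(start.getClusteringColumnNamesList.asScala.toSeq == Seq("value"))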
@@ -3216,6 +3216,10 @@ class SparkConnectPlanner(
writer.partitionBy(writeOp.getPartitioningColumnNamesList.asScala.toList: _*)
}

if (writeOp.getClusteringColumnNamesCount > 0) {
writer.clusterBy(writeOp.getClusteringColumnNamesList.asScala.toList: _*)
}

writeOp.getTriggerCase match {
case TriggerCase.PROCESSING_TIME_INTERVAL =>
writer.trigger(Trigger.ProcessingTime(writeOp.getProcessingTimeInterval))
@@ -159,6 +159,23 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging {
this
}

/**
* Clusters the output by the given columns. If specified, the output is laid out such that
* records with similar values on the clustering column are grouped together in the same file.
*
* Clustering improves query efficiency by allowing queries with predicates on the clustering
* columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high
* cardinality columns.
*
* @since 4.0.0
*/
@scala.annotation.varargs
def clusterBy(colNames: String*): DataStreamWriter[T] = {
sinkBuilder.clearClusteringColumnNames()
sinkBuilder.addAllClusteringColumnNames(colNames.asJava)
this
}

/**
* Adds an output option for the underlying data source.
*
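The clusterBy method added above mirrors partitionBy and only records the column names in the proto sink builder; the server applies them when the query starts. A minimal usage sketch from the caller's side (table name and checkpoint path are illustrative):

    // Stream from the rate source into a clustered table.
    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      .clusterBy("value")
      .option("checkpointLocation", "/tmp/clusterBy-ckpt")
      .toTable("events_clustered")

    query.processAllAvailable()
    query.stop()

DESCRIBE on the resulting table is then expected to list value under "# Clustering Information", which is what the suite below asserts.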
@@ -268,6 +268,42 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
}
}

test("clusterBy") {
withSQLConf(
"spark.sql.shuffle.partitions" -> "1" // Avoid too many reducers.
) {
spark.sql("DROP TABLE IF EXISTS my_table").collect()

withTempPath { ckpt =>
val q1 = spark.readStream
.format("rate")
.load()
.writeStream
.clusterBy("value")
.option("checkpointLocation", ckpt.getCanonicalPath)
.toTable("my_table")

try {
q1.processAllAvailable()
eventually(timeout(30.seconds)) {
checkAnswer(
spark.sql("DESCRIBE my_table"),
Seq(
Row("timestamp", "timestamp", null),
Row("value", "bigint", null),
Row("# Clustering Information", "", ""),
Row("# col_name", "data_type", "comment"),
Row("value", "bigint", null)))
assert(spark.table("my_table").count() > 0)
}
} finally {
q1.stop()
spark.sql("DROP TABLE my_table")
}
}
}
}

test("throw exception in streaming") {
try {
val session = spark
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -102,7 +102,9 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.session"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext"),
// SPARK-48761: Add clusterBy() to CreateTableWriter.
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy"),
// SPARK-48901: Add clusterBy() to DataStreamWriter.
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy")
)

// Default exclude rules
136 changes: 68 additions & 68 deletions python/pyspark/sql/connect/proto/commands_pb2.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -905,6 +905,7 @@ class WriteStreamOperationStart(google.protobuf.message.Message):
TABLE_NAME_FIELD_NUMBER: builtins.int
FOREACH_WRITER_FIELD_NUMBER: builtins.int
FOREACH_BATCH_FIELD_NUMBER: builtins.int
CLUSTERING_COLUMN_NAMES_FIELD_NUMBER: builtins.int
@property
def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
"""(Required) The output of the `input` streaming relation will be written."""
@@ -932,6 +933,11 @@
def foreach_writer(self) -> global___StreamingForeachFunction: ...
@property
def foreach_batch(self) -> global___StreamingForeachFunction: ...
@property
def clustering_column_names(
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""(Optional) Columns used for clustering the table."""
def __init__(
self,
*,
@@ -949,6 +955,7 @@
table_name: builtins.str = ...,
foreach_writer: global___StreamingForeachFunction | None = ...,
foreach_batch: global___StreamingForeachFunction | None = ...,
clustering_column_names: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self,
@@ -982,6 +989,8 @@
field_name: typing_extensions.Literal[
"available_now",
b"available_now",
"clustering_column_names",
b"clustering_column_names",
"continuous_checkpoint_interval",
b"continuous_checkpoint_interval",
"foreach_batch",
19 changes: 19 additions & 0 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -445,6 +445,25 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc]

partitionBy.__doc__ = PySparkDataStreamWriter.partitionBy.__doc__

@overload
def clusterBy(self, *cols: str) -> "DataStreamWriter":
...

@overload
def clusterBy(self, __cols: List[str]) -> "DataStreamWriter":
...

def clusterBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc]
if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
# Clear any previously set clustering columns.
while len(self._write_proto.clustering_column_names) > 0:
self._write_proto.clustering_column_names.pop()
self._write_proto.clustering_column_names.extend(cast(List[str], cols))
return self

clusterBy.__doc__ = PySparkDataStreamWriter.clusterBy.__doc__

def queryName(self, queryName: str) -> "DataStreamWriter":
if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
raise PySparkValueError(
59 changes: 59 additions & 0 deletions python/pyspark/sql/streaming/readwriter.py
@@ -1123,6 +1123,65 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc]
self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
return self

@overload
def clusterBy(self, *cols: str) -> "DataStreamWriter":
...

@overload
def clusterBy(self, __cols: List[str]) -> "DataStreamWriter":
...

def clusterBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc]
"""Clusters the output by the given columns.

If specified, the output is laid out such that records with similar values on the clustering
column(s) are grouped together in the same file.

Clustering improves query efficiency by allowing queries with predicates on the clustering
columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high
cardinality columns.

.. versionadded:: 4.0.0

Parameters
----------
cols : str or list
name of columns

Notes
-----
This API is evolving.

Examples
--------
>>> df = spark.readStream.format("rate").load()
>>> df.writeStream.clusterBy("value")
<...streaming.readwriter.DataStreamWriter object ...>

Cluster by the timestamp column from the Rate source.

>>> import tempfile
>>> import time
>>> with tempfile.TemporaryDirectory(prefix="clusterBy1") as d:
... with tempfile.TemporaryDirectory(prefix="clusterBy2") as cp:
... df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()
... q = df.writeStream.clusterBy(
... "timestamp").format("parquet").option("checkpointLocation", cp).start(d)
... time.sleep(5)
... q.stop()
... spark.read.schema(df.schema).parquet(d).show()
+...---------+-----+
|...timestamp|value|
+...---------+-----+
...
"""
from pyspark.sql.classic.column import _to_seq

if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
cols = cols[0]
self._jwrite = self._jwrite.clusterBy(_to_seq(self._spark._sc, cols))
return self

def queryName(self, queryName: str) -> "DataStreamWriter":
"""Specifies the name of the :class:`StreamingQuery` that can be started with
:func:`start`. This name must be unique among all the currently active queries
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests/streaming/test_streaming.py
@@ -429,6 +429,31 @@ def test_streaming_write_to_table(self):
result = self.spark.sql("SELECT value FROM output_table").collect()
self.assertTrue(len(result) > 0)

def test_streaming_write_to_table_cluster_by(self):
with self.table("output_table"), tempfile.TemporaryDirectory(prefix="to_table") as tmpdir:
df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
q = df.writeStream.clusterBy("value").toTable(
"output_table", format="parquet", checkpointLocation=tmpdir
)
self.assertTrue(q.isActive)
time.sleep(10)
q.stop()
result = self.spark.sql("DESCRIBE output_table").collect()
self.assertEqual(
set(
[
Row(col_name="timestamp", data_type="timestamp", comment=None),
Row(col_name="value", data_type="bigint", comment=None),
Row(col_name="# Clustering Information", data_type="", comment=""),
Row(col_name="# col_name", data_type="data_type", comment="comment"),
Row(col_name="value", data_type="bigint", comment=None),
]
),
set(result),
)
result = self.spark.sql("SELECT value FROM output_table").collect()
self.assertTrue(len(result) > 0)

def test_streaming_with_temporary_view(self):
"""
This verifies createOrReplaceTempView() works with a streaming dataframe. An SQL
@@ -36,9 +36,10 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
import org.apache.spark.sql.execution.streaming._
@@ -166,6 +167,24 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
@scala.annotation.varargs
def partitionBy(colNames: String*): DataStreamWriter[T] = {
this.partitioningColumns = Option(colNames)
validatePartitioningAndClustering()
this
}

/**
* Clusters the output by the given columns. If specified, the output is laid out such that
* records with similar values on the clustering column are grouped together in the same file.
*
* Clustering improves query efficiency by allowing queries with predicates on the clustering
* columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high
* cardinality columns.
*
* @since 4.0.0
*/
@scala.annotation.varargs
def clusterBy(colNames: String*): DataStreamWriter[T] = {
this.clusteringColumns = Option(colNames)
validatePartitioningAndClustering()
this
}
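Both partitionBy and clusterBy call validatePartitioningAndClustering() (added near the end of this file), so combining the two on one writer fails eagerly. A small sketch of the expected behavior, assuming the error surfaces as an AnalysisException, as QueryCompilationErrors factories typically produce:

    import org.apache.spark.sql.AnalysisException

    // df is any streaming DataFrame; column names are illustrative.
    val writer = df.writeStream.partitionBy("day")
    try {
      writer.clusterBy("value")  // clusterBy cannot be combined with partitionBy
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }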

@@ -288,12 +307,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

if (!catalog.asTableCatalog.tableExists(identifier)) {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

val properties = normalizedClusteringCols.map { cols =>
Map(
DataSourceUtils.CLUSTERING_COLUMNS_KEY -> DataSourceUtils.encodePartitioningColumns(cols))
}.getOrElse(Map.empty)
val partitioningOrClusteringTransform = normalizedClusteringCols.map { colNames =>
Array(ClusterByTransform(colNames.map(col => FieldReference(col)))).toImmutableArraySeq
}.getOrElse(partitioningColumns.getOrElse(Nil).asTransforms.toImmutableArraySeq)

/**
* Note, currently the new table creation by this API doesn't fully cover the V2 table.
* TODO (SPARK-33638): Full support of v2 table creation
*/
val tableSpec = UnresolvedTableSpec(
Map.empty[String, String],
properties,
Some(source),
OptionList(Seq.empty),
extraOptions.get("path"),
Expand All @@ -303,7 +331,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
val cmd = CreateTable(
UnresolvedIdentifier(originalMultipartIdentifier),
df.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)),
partitioningColumns.getOrElse(Nil).asTransforms.toImmutableArraySeq,
partitioningOrClusteringTransform,
tableSpec,
ignoreIfExists = false)
Dataset.ofRows(df.sparkSession, cmd)
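When the target table does not exist yet, the clustering columns become a single ClusterByTransform on the CreateTable command instead of the usual partition transforms. A short sketch of that value (column names illustrative):

    import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference}

    // One transform carries all clustering columns, matching the mapping above.
    val transform = ClusterByTransform(Seq(FieldReference("value"), FieldReference("region")))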
@@ -439,10 +467,22 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
}

private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
// Do not allow the user to specify clustering columns in the options. Ignoring this option is
// consistent with the behavior of DataFrameWriter on non Path-based tables and with the
// behavior of DataStreamWriter on partitioning columns specified in options.
val optionsWithoutClusteringKey =
optionsWithPath.originalMap - DataSourceUtils.CLUSTERING_COLUMNS_KEY

val optionsWithClusteringColumns = normalizedClusteringCols match {
case Some(cols) => optionsWithoutClusteringKey + (
DataSourceUtils.CLUSTERING_COLUMNS_KEY ->
DataSourceUtils.encodePartitioningColumns(cols))
case None => optionsWithoutClusteringKey
}
val ds = DataSource(
df.sparkSession,
className = source,
options = optionsWithPath.originalMap,
options = optionsWithClusteringColumns,
partitionColumns = normalizedParCols.getOrElse(Nil))
ds.createSink(outputMode)
}
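For V1 sinks there is no transform to attach, so the clustering columns travel through the data source options instead: the list is JSON-encoded with the same helper used for partition columns and stored under DataSourceUtils.CLUSTERING_COLUMNS_KEY. A rough sketch of that encoding (these are internal sql helpers, shown for illustration only):

    import org.apache.spark.sql.execution.datasources.DataSourceUtils

    val cols = Seq("value", "region")
    // Produces a JSON array string such as ["value","region"].
    val encoded = DataSourceUtils.encodePartitioningColumns(cols)
    val sinkOptions = Map(DataSourceUtils.CLUSTERING_COLUMNS_KEY -> encoded)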
@@ -514,6 +554,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
cols.map(normalize(_, "Partition"))
}

private def normalizedClusteringCols: Option[Seq[String]] = clusteringColumns.map { cols =>
cols.map(normalize(_, "Clustering"))
}

/**
* The given column name may not be equal to any of the existing column names if we were in
* case-insensitive context. Normalize the given column name to the real one so that we don't
@@ -532,6 +576,13 @@
}
}

// Validate that partitionBy isn't used with clusterBy.
private def validatePartitioningAndClustering(): Unit = {
if (clusteringColumns.nonEmpty && partitioningColumns.nonEmpty) {
throw QueryCompilationErrors.clusterByWithPartitionedBy()
}
}

///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -554,6 +605,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
private var foreachBatchWriter: (Dataset[T], Long) => Unit = null

private var partitioningColumns: Option[Seq[String]] = None

private var clusteringColumns: Option[Seq[String]] = None
}

object DataStreamWriter {