@@ -1115,13 +1115,20 @@ class Dataset[T] private[sql] (
   }
 
   /** @inheritdoc */
-  protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
+  protected def checkpoint(
+      eager: Boolean,
+      reliableCheckpoint: Boolean,
+      storageLevel: Option[StorageLevel]): Dataset[T] = {
     sparkSession.newDataset(agnosticEncoder) { builder =>
       val command = sparkSession.newCommand { builder =>
-        builder.getCheckpointCommandBuilder
+        val checkpointBuilder = builder.getCheckpointCommandBuilder
           .setLocal(!reliableCheckpoint)
           .setEager(eager)
           .setRelation(this.plan.getRoot)
+        storageLevel.foreach { storageLevel =>
+          checkpointBuilder.setStorageLevel(
+            StorageLevelProtoConverter.toConnectProtoType(storageLevel))
+        }
       }
       val responseIter = sparkSession.execute(command)
       try {
@@ -1304,6 +1311,10 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager)
 
+  /** @inheritdoc */
+  override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] =
+    super.localCheckpoint(eager, storageLevel)
+
   /** @inheritdoc */
   override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] =
     super.joinWith(other, condition)
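For reference, a minimal usage sketch of the new overload from a Spark Connect client. This is not part of the diff; the endpoint URL and data are illustrative assumptions.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.storage.StorageLevel

    // Hypothetical Connect endpoint; any reachable server works.
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()

    // Truncate lineage and keep the checkpointed blocks on disk only, instead of
    // the default storage level used for RDD local checkpoints.
    val df = spark.range(1000).localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY)

    df.count() // served from the materialized checkpoint, not by re-running the plan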
@@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper}
+import org.apache.spark.storage.StorageLevel
 
 class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper {
 
@@ -50,12 +51,20 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper {
     checkFragments(captureStdOut(block), fragmentsToCheck)
   }
 
-  test("checkpoint") {
+  test("localCheckpoint") {
[Review comment — juliuszsompolski (Contributor, author), on lines -53 to +54: note: there are no tests that test Connect reliable checkpoint. I renamed this test accordingly.]

     val df = spark.range(100).localCheckpoint()
     testCapturedStdOut(df.explain(), "ExistingRDD")
   }
 
-  test("checkpoint gc") {
+  test("localCheckpoint with StorageLevel") {
+    // We don't have a way to reach into the server and assert the storage level server side, but
+    // this test should cover for unexpected errors in the API.
[Review comment — juliuszsompolski (Contributor, author), on lines +60 to +61: @hvanhovell with the SQL API refactoring, would it be now possible to have tests that use a connect client to self-connect, and have server side objects (SparkContext) etc. available inside the test to verify? The existing SparkConnectServerTest can only test internal SparkConnectClient with the server, due to past namespace conflicts between server and client SparkSession APIs.]

[Reply — hvanhovell (Contributor): @juliuszsompolski it will take a few more PRs, but yeah that is the objective.]

+    val df =
+      spark.range(100).localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY)
+    df.collect()
+  }
+
+  test("localCheckpoint gc") {
     val df = spark.range(100).localCheckpoint(eager = true)
     val encoder = df.agnosticEncoder
     val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId
@@ -77,7 +86,7 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper {
 
   // This test is flaky because cannot guarantee GC
   // You can locally run this to verify the behavior.
-  ignore("checkpoint gc derived DataFrame") {
+  ignore("localCheckpoint gc derived DataFrame") {
     var df1 = spark.range(100).localCheckpoint(eager = true)
     var derived = df1.repartition(10)
     val encoder = df1.agnosticEncoder
9 changes: 7 additions & 2 deletions python/pyspark/sql/classic/dataframe.py
@@ -360,8 +360,13 @@ def checkpoint(self, eager: bool = True) -> ParentDataFrame:
         jdf = self._jdf.checkpoint(eager)
         return DataFrame(jdf, self.sparkSession)
 
-    def localCheckpoint(self, eager: bool = True) -> ParentDataFrame:
-        jdf = self._jdf.localCheckpoint(eager)
+    def localCheckpoint(
+        self, eager: bool = True, storageLevel: Optional[StorageLevel] = None
+    ) -> ParentDataFrame:
+        if storageLevel is None:
+            jdf = self._jdf.localCheckpoint(eager)
+        else:
+            jdf = self._jdf.localCheckpoint(eager, self._sc._getJavaStorageLevel(storageLevel))
         return DataFrame(jdf, self.sparkSession)
 
     def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame:
6 changes: 4 additions & 2 deletions python/pyspark/sql/connect/dataframe.py
@@ -2173,8 +2173,10 @@ def checkpoint(self, eager: bool = True) -> ParentDataFrame:
         assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
         return checkpointed
 
-    def localCheckpoint(self, eager: bool = True) -> ParentDataFrame:
-        cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
+    def localCheckpoint(
+        self, eager: bool = True, storageLevel: Optional[StorageLevel] = None
+    ) -> ParentDataFrame:
+        cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager, storage_level=storageLevel)
         _, properties, self._execution_info = self._session.client.execute_command(
             cmd.command(self._session.client)
         )
22 changes: 15 additions & 7 deletions python/pyspark/sql/connect/plan.py
@@ -1868,21 +1868,29 @@ def command(self, session: "SparkConnectClient") -> proto.Command:


 class Checkpoint(LogicalPlan):
-    def __init__(self, child: Optional["LogicalPlan"], local: bool, eager: bool) -> None:
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        local: bool,
+        eager: bool,
+        storage_level: Optional[StorageLevel] = None,
+    ) -> None:
         super().__init__(child)
         self._local = local
         self._eager = eager
+        self._storage_level = storage_level
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         cmd = proto.Command()
         assert self._child is not None
-        cmd.checkpoint_command.CopyFrom(
-            proto.CheckpointCommand(
-                relation=self._child.plan(session),
-                local=self._local,
-                eager=self._eager,
-            )
-        )
+        checkpoint_command = proto.CheckpointCommand(
+            relation=self._child.plan(session),
+            local=self._local,
+            eager=self._eager,
+        )
+        if self._storage_level is not None:
+            checkpoint_command.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
+        cmd.checkpoint_command.CopyFrom(checkpoint_command)
         return cmd


14 changes: 7 additions & 7 deletions python/pyspark/sql/connect/proto/commands_pb2.py

(Generated file; large diff not rendered by default.)

29 changes: 27 additions & 2 deletions python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -2188,6 +2188,7 @@ class CheckpointCommand(google.protobuf.message.Message):
     RELATION_FIELD_NUMBER: builtins.int
     LOCAL_FIELD_NUMBER: builtins.int
     EAGER_FIELD_NUMBER: builtins.int
+    STORAGE_LEVEL_FIELD_NUMBER: builtins.int
     @property
     def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
         """(Required) The logical plan to checkpoint."""
@@ -2197,22 +2198,46 @@ class CheckpointCommand(google.protobuf.message.Message):
"""
eager: builtins.bool
"""(Required) Whether to checkpoint this dataframe immediately."""
@property
def storage_level(self) -> pyspark.sql.connect.proto.common_pb2.StorageLevel:
"""(Optional) For local checkpoint, the storage level to use."""
def __init__(
self,
*,
relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
local: builtins.bool = ...,
eager: builtins.bool = ...,
storage_level: pyspark.sql.connect.proto.common_pb2.StorageLevel | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["relation", b"relation"]
self,
field_name: typing_extensions.Literal[
"_storage_level",
b"_storage_level",
"relation",
b"relation",
"storage_level",
b"storage_level",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"eager", b"eager", "local", b"local", "relation", b"relation"
"_storage_level",
b"_storage_level",
"eager",
b"eager",
"local",
b"local",
"relation",
b"relation",
"storage_level",
b"storage_level",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_storage_level", b"_storage_level"]
) -> typing_extensions.Literal["storage_level"] | None: ...

global___CheckpointCommand = CheckpointCommand

9 changes: 8 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -1015,7 +1015,9 @@ def checkpoint(self, eager: bool = True) -> "DataFrame":
"""
...

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
def localCheckpoint(
self, eager: bool = True, storageLevel: Optional[StorageLevel] = None
) -> "DataFrame":
"""Returns a locally checkpointed version of this :class:`DataFrame`. Checkpointing can
be used to truncate the logical plan of this :class:`DataFrame`, which is especially
useful in iterative algorithms where the plan may grow exponentially. Local checkpoints
@@ -1026,12 +1028,17 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame":

         .. versionchanged:: 4.0.0
             Supports Spark Connect.
+            Added storageLevel parameter.
 
         Parameters
         ----------
         eager : bool, optional, default True
             Whether to checkpoint this :class:`DataFrame` immediately.
 
+        storageLevel : :class:`StorageLevel`, optional, default None
+            The StorageLevel with which the checkpoint will be stored.
+            If not specified, the default storage level for RDD local checkpoints is used.
+
         Returns
         -------
         :class:`DataFrame`
8 changes: 7 additions & 1 deletion python/pyspark/sql/tests/test_dataframe.py
@@ -951,11 +951,17 @@ def test_union_classmethod_usage(self):
     def test_isinstance_dataframe(self):
         self.assertIsInstance(self.spark.range(1), DataFrame)
 
-    def test_checkpoint_dataframe(self):
+    def test_local_checkpoint_dataframe(self):
[Review comment — juliuszsompolski (Contributor, author): note: there are no tests at all for reliable checkpoint in pyspark API. I renamed this test accordingly.]

         with io.StringIO() as buf, redirect_stdout(buf):
             self.spark.range(1).localCheckpoint().explain()
         self.assertIn("ExistingRDD", buf.getvalue())
 
+    def test_local_checkpoint_dataframe_with_storage_level(self):
+        # We don't have a way to reach into the server and assert the storage level server side, but
+        # this test should cover for unexpected errors in the API.
+        df = self.spark.range(10).localCheckpoint(eager=True, storageLevel=StorageLevel.DISK_ONLY)
+        df.collect()
 
     def test_transpose(self):
         df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}])

42 changes: 36 additions & 6 deletions sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -312,7 +312,8 @@ abstract class Dataset[T] extends Serializable {
    * @group basic
    * @since 2.1.0
    */
-  def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
+  def checkpoint(): Dataset[T] =
+    checkpoint(eager = true, reliableCheckpoint = true, storageLevel = None)
 
   /**
    * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
@@ -332,7 +333,7 @@ abstract class Dataset[T] extends Serializable {
    * @since 2.1.0
    */
   def checkpoint(eager: Boolean): Dataset[T] =
-    checkpoint(eager = eager, reliableCheckpoint = true)
+    checkpoint(eager = eager, reliableCheckpoint = true, storageLevel = None)
 
   /**
    * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used
@@ -343,7 +344,8 @@ abstract class Dataset[T] extends Serializable {
    * @group basic
    * @since 2.3.0
    */
-  def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+  def localCheckpoint(): Dataset[T] =
+    checkpoint(eager = true, reliableCheckpoint = false, storageLevel = None)
 
   /**
    * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to
@@ -363,7 +365,29 @@ abstract class Dataset[T] extends Serializable {
    * @since 2.3.0
    */
   def localCheckpoint(eager: Boolean): Dataset[T] =
-    checkpoint(eager = eager, reliableCheckpoint = false)
+    checkpoint(eager = eager, reliableCheckpoint = false, storageLevel = None)
+
+  /**
+   * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to
+   * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+   * where the plan may grow exponentially. Local checkpoints are written to executor storage and
+   * despite potentially faster they are unreliable and may compromise job completion.
+   *
+   * @param eager
+   *   Whether to checkpoint this dataframe immediately
+   * @param storageLevel
+   *   StorageLevel with which to checkpoint the data.
+   * @note
+   *   When checkpoint is used with eager = false, the final data that is checkpointed after the
+   *   first action may be different from the data that was used during the job due to
+   *   non-determinism of the underlying operation and retries. If checkpoint is used to achieve
+   *   saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
+   *   only deterministic after the first execution, after the checkpoint was finalized.
+   * @group basic
+   * @since 4.0.0
+   */
+  def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] =
+    checkpoint(eager = eager, reliableCheckpoint = false, storageLevel = Some(storageLevel))

/**
* Returns a checkpointed version of this Dataset.
@@ -373,8 +397,14 @@ abstract class Dataset[T] extends Serializable {
    * @param reliableCheckpoint
    *   Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If
    *   false creates a local checkpoint using the caching subsystem
-   */
-  protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T]
+   * @param storageLevel
+   *   Option. If defined, the StorageLevel with which to checkpoint the data. Only valid with
+   *   reliableCheckpoint = false.
+   */
+  protected def checkpoint(
+      eager: Boolean,
+      reliableCheckpoint: Boolean,
+      storageLevel: Option[StorageLevel]): Dataset[T]

/**
* Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time
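To make the @note above concrete, a short sketch of eager versus lazy local checkpointing with an explicit storage level; `spark` is an assumed active session and the column is illustrative.

    import org.apache.spark.sql.functions.rand
    import org.apache.spark.storage.StorageLevel

    // Eager: executed and materialized now, so the snapshot of the
    // non-deterministic column is fixed at this point.
    val snapshot = spark.range(100).withColumn("r", rand())
      .localCheckpoint(eager = true, storageLevel = StorageLevel.MEMORY_ONLY)

    // Lazy: nothing runs yet; the checkpoint is only finalized by the first
    // action, so "r" is deterministic only after that first execution.
    val deferred = spark.range(100).withColumn("r", rand())
      .localCheckpoint(eager = false, storageLevel = StorageLevel.MEMORY_ONLY)
    deferred.count() // first action finalizes the checkpoint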
@@ -507,6 +507,9 @@ message CheckpointCommand {

   // (Required) Whether to checkpoint this dataframe immediately.
   bool eager = 3;
+
+  // (Optional) For local checkpoint, the storage level to use.
+  optional StorageLevel storage_level = 4;
 }
 
 message MergeIntoTableCommand {
@@ -3354,9 +3354,18 @@ class SparkConnectPlanner(
       responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
     val target = Dataset
       .ofRows(session, transformRelation(checkpointCommand.getRelation))
-    val checkpointed = target.checkpoint(
-      eager = checkpointCommand.getEager,
-      reliableCheckpoint = !checkpointCommand.getLocal)
+    val checkpointed = if (checkpointCommand.getLocal) {
+      if (checkpointCommand.hasStorageLevel) {
+        target.localCheckpoint(
+          eager = checkpointCommand.getEager,
+          storageLevel =
+            StorageLevelProtoConverter.toStorageLevel(checkpointCommand.getStorageLevel))
+      } else {
+        target.localCheckpoint(eager = checkpointCommand.getEager)
+      }
+    } else {
+      target.checkpoint(eager = checkpointCommand.getEager)
+    }
 
     val dfId = UUID.randomUUID().toString
     logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
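A sketch of the wire-level round trip this dispatch completes, using the generated proto builders. The values are illustrative, and the converter's import path is an assumption; only the calls that appear in the diff are taken as given.

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.connect.common.StorageLevelProtoConverter
    import org.apache.spark.storage.StorageLevel

    // Client side: the storage level travels as an optional proto field.
    val command = proto.CheckpointCommand.newBuilder()
      .setLocal(true) // local checkpoint; reliable checkpoints ignore the field
      .setEager(true)
      .setStorageLevel(StorageLevelProtoConverter.toConnectProtoType(StorageLevel.DISK_ONLY))
      .build()

    // Server side: hasStorageLevel distinguishes "not set" from any concrete
    // level, which is why commands.proto declares the field as `optional`.
    assert(command.hasStorageLevel)
    assert(StorageLevelProtoConverter.toStorageLevel(command.getStorageLevel) ==
      StorageLevel.DISK_ONLY)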
11 changes: 10 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -540,13 +540,18 @@ class Dataset[T] private[sql](
   def isStreaming: Boolean = logicalPlan.isStreaming
 
   /** @inheritdoc */
-  protected[sql] def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
+  protected[sql] def checkpoint(
+      eager: Boolean,
+      reliableCheckpoint: Boolean,
+      storageLevel: Option[StorageLevel]): Dataset[T] = {
     val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint"
     withAction(actionName, queryExecution) { physicalPlan =>
       val internalRdd = physicalPlan.execute().map(_.copy())
       if (reliableCheckpoint) {
+        assert(storageLevel.isEmpty, "StorageLevel should not be defined for reliableCheckpoint")
         internalRdd.checkpoint()
       } else {
+        storageLevel.foreach(storageLevel => internalRdd.persist(storageLevel))
         internalRdd.localCheckpoint()
       }
 
@@ -1810,6 +1815,10 @@ class Dataset[T] private[sql](
   /** @inheritdoc */
   override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager)
 
+  /** @inheritdoc */
+  override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] =
+    super.localCheckpoint(eager, storageLevel)
+
   /** @inheritdoc */
   override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] =
     super.joinWith(other, condition)
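The else-branch above is effectively a one-call version of persisting the underlying RDD before locally checkpointing it. A rough sketch of that equivalence, assuming an existing `spark` session:

    import org.apache.spark.storage.StorageLevel

    val df = spark.range(100)

    // New API: one call pins the storage level and truncates lineage.
    val viaApi = df.localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY)

    // Roughly equivalent manual route on the RDD: persist at the desired level
    // first, then mark the RDD for local checkpointing and run an action.
    val rdd = df.rdd
    rdd.persist(StorageLevel.DISK_ONLY)
    rdd.localCheckpoint()
    rdd.count() // materializes the checkpoint, mirroring eager = true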