14 changes: 2 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {

override protected def saveImpl(path: String): Unit = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"FPGrowthModel does not support saving to local filesystem path."
)
}
val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
extraMetadata = Some(extraMetadata))
val dataPath = new Path(path, "data").toString
Contributor: Can we pass a Path object to saveDataFrame directly?

Contributor Author: Sounds good!
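For illustration only, the suggested change could look roughly like this (the overload below is an assumption, not part of this diff; it simply converts the Hadoop Path back to a String):

def saveDataFrame(path: org.apache.hadoop.fs.Path, df: DataFrame): Unit = {
  // Hypothetical convenience overload that delegates to the existing String-based helper.
  saveDataFrame(path.toString, df)
}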

instance.freqItemsets.write.parquet(dataPath)
ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
}
}

@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
private val className = classOf[FPGrowthModel].getName

override def load(path: String): FPGrowthModel = {
if (ReadWriteUtils.localSavingModeState.get()) {
throw new UnsupportedOperationException(
"FPGrowthModel does not support loading from local filesystem path."
)
}
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
(metadata.metadata \ "numTrainingRecords").extract[Long]
}
val dataPath = new Path(path, "data").toString
val frequentItems = sparkSession.read.parquet(dataPath)
val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
val itemSupport = if (numTrainingRecords == 0L) {
Map.empty[Any, Double]
} else {
30 changes: 29 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite
import org.apache.spark.util.{Utils, VersionUtils}

/**
@@ -1142,4 +1143,31 @@ private[spark] object ReadWriteUtils {
spark.read.parquet(path).as[T].collect()
}
}

def saveDataFrame(path: String, df: DataFrame): Unit = {
if (localSavingModeState.get()) {
val filePath = Paths.get(path)
Files.createDirectories(filePath.getParent)

df match {
case d: org.apache.spark.sql.classic.DataFrame =>
ArrowFileReadWrite.save(d, path)
case _ => throw new UnsupportedOperationException("Unsupported dataframe type")
}
} else {
df.write.parquet(path)
}
}

def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
if (localSavingModeState.get()) {
spark match {
case s: org.apache.spark.sql.classic.SparkSession =>
ArrowFileReadWrite.load(s, path)
case _ => throw new UnsupportedOperationException("Unsupported session type")
Member: Can we show the actual session type in the error?

Contributor Author: Makes sense, will update!
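For illustration, the promised follow-up could be as small as including the runtime class in the message (a sketch; the exact wording is an assumption):

case other => throw new UnsupportedOperationException(
  s"Unsupported session type: ${other.getClass.getName}")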

}
} else {
spark.read.parquet(path)
}
}
Comment on lines +1147 to +1172
Contributor: So if we have localSavingModeState set to true, this will write out an Arrow file, which is not stable format-wise. It does look like localSavingModeState is only set to true in internal methods in Scala, but looking at the PySpark docstrings I see we tell people to use this API, so I remain -0.9.

Contributor Author: Hi @holdenk, as @WeichenXu123 explained in #53150 (comment), this is a runtime temporary file on the Spark Connect server side and is cleaned up after the session closes, so I think we don't have to use a stable format here.

@WeichenXu123 (Contributor, Nov 27, 2025): localSavingModeState is also used internally (only Spark driver code can set the flag). Where does the docstring mention it? We should remove it from the doc and mark localSavingModeState as a private field.

Member: Hmm, even if it is just a temporary session file, is there any reason to use the Arrow file format rather than Parquet?

Contributor Author: We can read/write Parquet with Arrow, but it requires a new dependency:

<dependency>
    <groupId>org.apache.parquet</groupId>
    <artifactId>parquet-arrow</artifactId>
</dependency>

Otherwise, I am not sure whether we have utils to read/write Parquet.
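For context, here is a minimal round-trip sketch of the two helpers added in this hunk (hypothetical path and toy data; with localSavingModeState unset this takes the Parquet branch, and with it set on the driver it writes the single Arrow IPC file discussed above):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(0, 10).toDF("id")

// Writes a Parquet directory (or one Arrow file in local saving mode) under the given path.
ReadWriteUtils.saveDataFrame("/tmp/example-model/data", df)
val restored = ReadWriteUtils.loadDataFrame("/tmp/example-model/data", spark)
assert(restored.count() == 10)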

}
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
FPGrowthSuite.allParamSettings, checkModelData)
}
}

32 changes: 24 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2340,14 +2340,13 @@ class Dataset[T] private[sql](
}

/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
private def toArrowBatchRddImpl(
plan: SparkPlan,
Contributor: nit: 4 spaces indentation

maxRecordsPerBatch: Int,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): RDD[Array[Byte]] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
val largeVarTypes =
sparkSession.sessionState.conf.arrowUseLargeVarTypes
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toBatchIterator(
Expand All @@ -2361,7 +2360,24 @@ class Dataset[T] private[sql](
}
}

// This is only used in tests, for now.
private[sql] def toArrowBatchRdd(
maxRecordsPerBatch: Int,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): RDD[Array[Byte]] = {
toArrowBatchRddImpl(queryExecution.executedPlan,
maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
}

private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
toArrowBatchRddImpl(
plan,
sparkSession.sessionState.conf.arrowMaxRecordsPerBatch,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy",
sparkSession.sessionState.conf.arrowUseLargeVarTypes)
}

private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
@@ -555,25 +555,41 @@
arrowBatches: Iterator[Array[Byte]],
schemaString: String,
session: SparkSession): DataFrame = {
val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
toDataFrame(
arrowBatches,
DataType.fromJson(schemaString).asInstanceOf[StructType],
session,
session.sessionState.conf.sessionLocalTimeZone,
false,
session.sessionState.conf.arrowUseLargeVarTypes)
}

/**
* Create a DataFrame from an iterator of serialized ArrowRecordBatches.
*/
private[sql] def toDataFrame(
arrowBatches: Iterator[Array[Byte]],
schema: StructType,
session: SparkSession,
timeZoneId: String,
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean): DataFrame = {
val attrs = toAttributes(schema)
val batchesInDriver = arrowBatches.toArray
val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
val shouldUseRDD = session.sessionState.conf
.arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum

if (shouldUseRDD) {
logDebug("Using RDD-based createDataFrame with Arrow optimization.")
val timezone = session.sessionState.conf.sessionLocalTimeZone
val rdd = session.sparkContext
.parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length)
.mapPartitions { batchesInExecutors =>
ArrowConverters.fromBatchIterator(
batchesInExecutors,
schema,
timezone,
errorOnDuplicatedFieldNames = false,
largeVarTypes = largeVarTypes,
timeZoneId,
errorOnDuplicatedFieldNames,
largeVarTypes,
TaskContext.get())
}
session.internalCreateDataFrame(rdd.setName("arrow"), schema)
Expand All @@ -582,9 +598,9 @@ private[sql] object ArrowConverters extends Logging {
val data = ArrowConverters.fromBatchIterator(
batchesInDriver.iterator,
schema,
session.sessionState.conf.sessionLocalTimeZone,
errorOnDuplicatedFieldNames = false,
largeVarTypes = largeVarTypes,
timeZoneId,
errorOnDuplicatedFieldNames,
largeVarTypes,
TaskContext.get())

// Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.arrow

import java.io.{ByteArrayOutputStream, FileOutputStream}
import java.nio.channels.Channels
import java.nio.file.Files
import java.nio.file.Paths

import scala.jdk.CollectionConverters._

import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter, WriteChannel}
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.classic.{DataFrame, SparkSession}
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkArrowFileWriter(
arrowSchema: Schema,
Contributor: nit: 4 spaces indentation

path: String) extends AutoCloseable {
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
protected val loader = new VectorLoader(root)
protected val arrowWriter = ArrowWriter.create(root)
Member: Where is arrowWriter used?

Contributor Author: Good catch!


protected val fileWriter =
new ArrowFileWriter(root, null, Channels.newChannel(new FileOutputStream(path)))

override def close(): Unit = {
root.close()
allocator.close()
fileWriter.close()
}

def write(batchBytesIter: Iterator[Array[Byte]]): Unit = {
Member: This looks like it does the following:

Dataset -> Arrow batches -> bytes -> Arrow batches -> write Arrow batches with ArrowFileWriter

Looks like the intermediate bytes could be skipped?

Member: I think he's doing it because local data has to go to executors, and to do that, the Arrow batches should be in IPC form.

Member: The Dataset is already distributed on executors, and rows are written into Arrow batches on the executors. If they are not to be distributed again, they could stay as Arrow batches, no?

Member: Below, in writer.write(rdd.toLocalIterator), I think the code path is to collect the Arrow batches into the Spark driver and write them there. So it should collect the Arrow batches from the executors to the driver.

Member: I guess it's because it writes to the driver's local filesystem.

Member: Oh I see. Okay.
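To make the flow under discussion concrete, this is roughly the current path, using the names from this PR (a recap sketch, not a proposed change):

// Executors: rows -> Arrow record batches -> serialized IPC bytes (Array[Byte]).
val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
// Driver: pull the serialized batches back, rehydrate them, and append each to the Arrow file.
rdd.toLocalIterator.foreach { bytes =>
  val batch = ArrowConverters.loadBatch(bytes, allocator)
  loader.load(batch)
  fileWriter.writeBatch()
}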

fileWriter.start()
while (batchBytesIter.hasNext) {
val batchBytes = batchBytesIter.next()
val batch = ArrowConverters.loadBatch(batchBytes, allocator)
@zhengruifeng (Contributor Author, Nov 26, 2025): The batch (ArrowRecordBatch) doesn't extend Serializable, so the PR still uses Array[Byte] as the underlying data.

loader.load(batch)
fileWriter.writeBatch()
}
fileWriter.close()
}
}

private[sql] class SparkArrowFileReader(path: String) extends AutoCloseable {
private val allocator =
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)

protected val fileReader =
new ArrowFileReader(Files.newByteChannel(Paths.get(path)), allocator)

override def close(): Unit = {
allocator.close()
fileReader.close()
}

val schema: Schema = fileReader.getVectorSchemaRoot.getSchema

def read(): Iterator[Array[Byte]] = {
fileReader.getRecordBlocks.iterator().asScala.map { block =>
fileReader.loadRecordBatch(block)
val root = fileReader.getVectorSchemaRoot
val unloader = new VectorUnloader(root)
val batch = unloader.getRecordBatch
val out = new ByteArrayOutputStream()
val writeChannel = new WriteChannel(Channels.newChannel(out))
MessageSerializer.serialize(writeChannel, batch)
out.toByteArray
}
}
}

private[spark] object ArrowFileReadWrite {
def save(df: DataFrame, path: String): Unit = {
val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false)
val writer = new SparkArrowFileWriter(arrowSchema, path)
writer.write(rdd.toLocalIterator)
@viirya (Member, Nov 27, 2025): Instead, can we call toLocalIterator on the original DataFrame's RDD and write the rows to Arrow batches locally? Then we don't need the redundant bytes.
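A rough sketch of that alternative, assuming rows are pulled to the driver and written into the writer's Arrow vectors via ArrowWriter (batching by arrowMaxRecordsPerBatch and error handling omitted; this is not part of the PR):

// Hypothetical: populate Arrow vectors directly from InternalRows on the driver,
// so no intermediate Array[Byte] IPC payload is needed.
val arrowWriter = ArrowWriter.create(root)
fileWriter.start()
df.queryExecution.executedPlan.executeToIterator().foreach(arrowWriter.write)
arrowWriter.finish()
fileWriter.writeBatch()
fileWriter.end()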

}

def load(spark: SparkSession, path: String): DataFrame = {
val reader = new SparkArrowFileReader(path)
val schema = ArrowUtils.fromArrowSchema(reader.schema)
ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false)
}
}