From d93588d7a4188df8101b3556afe8e9c194019953 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 21 Nov 2025 10:45:48 +0800 Subject: [PATCH 1/7] fix --- .../org/apache/spark/ml/fpm/FPGrowth.scala | 14 +---- .../org/apache/spark/ml/util/ReadWrite.scala | 51 +++++++++++++++++-- .../apache/spark/ml/fpm/FPGrowthSuite.scala | 2 +- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 6fd20ceb562b..e25fdc3e05ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter { override protected def saveImpl(path: String): Unit = { - if (ReadWriteUtils.localSavingModeState.get()) { - throw new UnsupportedOperationException( - "FPGrowthModel does not support saving to local filesystem path." - ) - } val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords) DefaultParamsWriter.saveMetadata(instance, path, sparkSession, extraMetadata = Some(extraMetadata)) val dataPath = new Path(path, "data").toString - instance.freqItemsets.write.parquet(dataPath) + ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets) } } @@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { private val className = classOf[FPGrowthModel].getName override def load(path: String): FPGrowthModel = { - if (ReadWriteUtils.localSavingModeState.get()) { - throw new UnsupportedOperationException( - "FPGrowthModel does not support loading from local filesystem path." - ) - } implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className) val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion) @@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] { (metadata.metadata \ "numTrainingRecords").extract[Long] } val dataPath = new Path(path, "data").toString - val frequentItems = sparkSession.read.parquet(dataPath) + val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession) val itemSupport = if (numTrainingRecords == 0L) { Map.empty[Any, Double] } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index e3f31874a4c2..812515aa62a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -19,10 +19,11 @@ package org.apache.spark.ml.util import java.io.{ BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, - File, FileInputStream, FileOutputStream, IOException + File, FileInputStream, FileOutputStream, IOException, ObjectInputStream, + ObjectOutputStream } import java.nio.file.{Files, Paths} -import java.util.{Locale, ServiceLoader} +import java.util.{ArrayList, Locale, ServiceLoader} import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -46,7 +47,8 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector} import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.{DataFrame, 
Row, SparkSession, SQLContext} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Utils, VersionUtils} /** @@ -1142,4 +1144,47 @@ private[spark] object ReadWriteUtils { spark.read.parquet(path).as[T].collect() } } + + def saveDataFrame(path: String, df: DataFrame): Unit = { + if (localSavingModeState.get()) { + val filePath = Paths.get(path) + Files.createDirectories(filePath.getParent) + + Using.resource( + new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) + ) { oos => + val schema: StructType = df.schema + oos.writeObject(schema) + val it = df.toLocalIterator() + while (it.hasNext) { + oos.writeBoolean(true) // hasNext = True + val row: Row = it.next() + oos.writeObject(row) + } + oos.writeBoolean(false) // hasNext = False + } + } else { + df.write.parquet(path) + } + } + + def loadDataFrame(path: String, spark: SparkSession): DataFrame = { + if (localSavingModeState.get()) { + Using.resource( + new ObjectInputStream(new BufferedInputStream(new FileInputStream(path))) + ) { ois => + val schema = ois.readObject().asInstanceOf[StructType] + val rows = new ArrayList[Row] + var hasNext = ois.readBoolean() + while (hasNext) { + val row = ois.readObject().asInstanceOf[Row] + rows.add(row) + hasNext = ois.readBoolean() + } + spark.createDataFrame(rows, schema) + } + } else { + spark.read.parquet(path) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 1630a5d07d8e..3d994366b891 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } val fPGrowth = new FPGrowth() testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, - FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true) + FPGrowthSuite.allParamSettings, checkModelData) } } From 9474f6951a94adc126fd65fa9cb44be9e82d7e1c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 21 Nov 2025 22:43:15 +0800 Subject: [PATCH 2/7] apply arrow nit --- .../apache/spark/ml/util/DatasetUtils.scala | 50 ++++++++++++++- .../org/apache/spark/ml/util/ReadWrite.scala | 64 ++++++++++++------- .../apache/spark/sql/types/StructType.scala | 2 +- .../spark/sql/classic/SparkSession.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 6 +- 5 files changed, 95 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala index 06de43260b30..c64e8d3007e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{CLASS_NAME, LABEL_COLUMN, NUM_CLASSES} import org.apache.spark.ml.PredictorParams @@ -28,6 +28,7 @@ import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.types._ @@ -212,4 +213,51 @@ private[spark] object DatasetUtils extends Logging { dataset.select(columnToVector(dataset, vectorCol)).head().getAs[Vector](0).size } } + + private[ml] def toArrowBatchRDD( + dataFrame: DataFrame, + timeZoneId: String): RDD[Array[Byte]] = { + dataFrame match { + case df: org.apache.spark.sql.classic.DataFrame => + val spark = df.sparkSession + val schema = df.schema + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch + df.queryExecution.executedPlan.execute().mapPartitionsInternal { iter => + val context = TaskContext.get() + ArrowConverters.toBatchIterator( + iter, + schema, + maxRecordsPerBatch, + timeZoneId, + true, + false, + context) + } + + case _ => throw new UnsupportedOperationException("Not implemented") + } + } + + private[ml] def fromArrowBatchRDD( + rdd: RDD[Array[Byte]], + schema: StructType, + timeZoneId: String, + sparkSession: SparkSession): DataFrame = { + sparkSession match { + case spark: org.apache.spark.sql.classic.SparkSession => + val rowRDD = rdd.mapPartitions { iter => + val context = TaskContext.get() + ArrowConverters.fromBatchIterator( + iter, + schema, + timeZoneId, + true, + false, + context) + } + spark.internalCreateDataFrame(rowRDD.setName("arrow"), schema) + + case _ => throw new UnsupportedOperationException("Not implemented") + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 812515aa62a3..c04e798dccae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.util import java.io.{ BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, - File, FileInputStream, FileOutputStream, IOException, ObjectInputStream, - ObjectOutputStream + File, FileInputStream, FileOutputStream, IOException } import java.nio.file.{Files, Paths} -import java.util.{ArrayList, Locale, ServiceLoader} +import java.util.{Locale, ServiceLoader} import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -47,7 +46,7 @@ import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector} import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams -import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Utils, VersionUtils} @@ -1151,17 +1150,21 @@ private[spark] object ReadWriteUtils { Files.createDirectories(filePath.getParent) Using.resource( - new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) - ) { oos => + new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) + ) { dos => + dos.writeUTF("ARROW") // format + val schema: StructType = df.schema - oos.writeObject(schema) - val it = df.toLocalIterator() - while (it.hasNext) { - oos.writeBoolean(true) // hasNext = True - val row: Row = it.next() - oos.writeObject(row) + dos.writeUTF(schema.json) + + val iter = DatasetUtils.toArrowBatchRDD(df, "UTC").toLocalIterator + while (iter.hasNext) { + val bytes = iter.next() + require(bytes != null) + dos.writeInt(bytes.length) + dos.write(bytes) } - oos.writeBoolean(false) // hasNext = False + 
dos.writeInt(-1) // End } } else { df.write.parquet(path) @@ -1170,18 +1173,33 @@ private[spark] object ReadWriteUtils { def loadDataFrame(path: String, spark: SparkSession): DataFrame = { if (localSavingModeState.get()) { + val sc = spark match { + case s: org.apache.spark.sql.classic.SparkSession => s.sparkContext + } + Using.resource( - new ObjectInputStream(new BufferedInputStream(new FileInputStream(path))) - ) { ois => - val schema = ois.readObject().asInstanceOf[StructType] - val rows = new ArrayList[Row] - var hasNext = ois.readBoolean() - while (hasNext) { - val row = ois.readObject().asInstanceOf[Row] - rows.add(row) - hasNext = ois.readBoolean() + new DataInputStream(new BufferedInputStream(new FileInputStream(path))) + ) { dis => + val format = dis.readUTF() + require(format == "ARROW") + + val schema: StructType = StructType.fromString(dis.readUTF()) + + val buff = mutable.ListBuffer.empty[Array[Byte]] + var nextBytes = dis.readInt() + while (nextBytes >= 0) { + val bytes = dis.readNBytes(nextBytes) + buff.append(bytes) + nextBytes = dis.readInt() } - spark.createDataFrame(rows, schema) + require(nextBytes == -1) + + DatasetUtils.fromArrowBatchRDD( + sc.parallelize[Array[Byte]](buff.result()), + schema, + "UTC", + spark + ) } } else { spark.read.parquet(path) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 5b1d9f1f116a..bbc27b2a73f9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -531,7 +531,7 @@ object StructType extends AbstractDataType { override private[sql] def simpleString: String = "struct" - private[sql] def fromString(raw: String): StructType = { + private[spark] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match { case t: StructType => t case _ => throw DataTypeErrors.failedParsingStructTypeError(raw) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index 5811fe759d3e..e002a1a616fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -399,7 +399,7 @@ class SparkSession private( /** * Creates a `DataFrame` from an `RDD[InternalRow]`. */ - private[sql] def internalCreateDataFrame( + private[spark] def internalCreateDataFrame( catalystRows: RDD[InternalRow], schema: StructType, isStreaming: Boolean = false): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 8b031af14e8b..ac2b873d9beb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -80,7 +80,7 @@ private[sql] class ArrowBatchStreamWriter( } } -private[sql] object ArrowConverters extends Logging { +private[spark] object ArrowConverters extends Logging { private[sql] class ArrowBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, @@ -231,7 +231,7 @@ private[sql] object ArrowConverters extends Logging { * Maps Iterator from InternalRow to serialized ArrowRecordBatches. 
Limit ArrowRecordBatch size * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ - private[sql] def toBatchIterator( + private[spark] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, @@ -484,7 +484,7 @@ private[sql] object ArrowConverters extends Logging { /** * Maps iterator from serialized ArrowRecordBatches to InternalRows. */ - private[sql] def fromBatchIterator( + private[spark] def fromBatchIterator( arrowBatchIter: Iterator[Array[Byte]], schema: StructType, timeZoneId: String, From 07f67e67d015b1cd594b4d9887e5ea17f99c04d7 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 24 Nov 2025 21:30:24 +0800 Subject: [PATCH 3/7] init --- .../org/apache/spark/ml/util/ReadWrite.scala | 6 + .../sql/execution/arrow/ArrowConverters.scala | 4 +- .../sql/execution/arrow/ArrowFileWriter.scala | 116 ++++++++++++++++++ 3 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index c04e798dccae..23bcc69e28c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -1149,6 +1149,12 @@ private[spark] object ReadWriteUtils { val filePath = Paths.get(path) Files.createDirectories(filePath.getParent) + Using.resource( + new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) + ) { dos => + + } + Using.resource( new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) ) { dos => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index ac2b873d9beb..3b8d0f8c1f8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.arrow -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream, OutputStream} import java.nio.channels.{Channels, ReadableByteChannel} import scala.collection.mutable.ArrayBuffer @@ -28,7 +28,7 @@ import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ArrowFileWriter, ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} import org.apache.spark.SparkException diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala new file mode 100644 index 000000000000..efec4fc8bebd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream, OutputStream} +import java.nio.channels.{Channels, ReadableByteChannel} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} +import org.apache.arrow.flatbuf.MessageHeader +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector._ +import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} +import org.apache.arrow.vector.ipc.{ArrowFileWriter, ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} + +import org.apache.spark.SparkException +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} +import org.apache.spark.util.ArrayImplicits._ + +private[spark] class ArrowBatchFileWriter( + schema: StructType, + out: FileOutputStream, + maxRecordsPerBatch: Long, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, + context: TaskContext) extends AutoCloseable { + + protected val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) + + protected val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = new ArrowFileWriter(root, null, out.getChannel) + + // Create compression codec based on config + private val compressionCodecName = SQLConf.get.arrowCompressionCodec + private val codec = compressionCodecName match { + case "none" => NoCompressionCodec.INSTANCE + case "zstd" => + val compressionLevel = SQLConf.get.arrowZstdCompressionLevel + val factory = CompressionCodec.Factory.INSTANCE + val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType() + factory.createCodec(codecType) + case "lz4" => + val factory = CompressionCodec.Factory.INSTANCE + val 
codecType = new Lz4CompressionCodec().getCodecType() + factory.createCodec(codecType) + case other => + throw SparkException.internalError( + s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") + } + protected val unloader = new VectorUnloader(root, true, codec, true) + protected val arrowWriter = ArrowWriter.create(root) + + Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => + close() + }} + + def writeRows(rowIter: Iterator[InternalRow]): Unit = { + writer.start() + while (rowIter.hasNext) { + Utils.tryWithSafeFinally { + var rowCount = 0L + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + } { + arrowWriter.reset() + root.clear() + } + } + writer.end() + } + + override def close(): Unit = { + root.close() + allocator.close() + } +} From 536b4036b29a70f605081a014040bbf5bba93724 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 26 Nov 2025 19:47:59 +0800 Subject: [PATCH 4/7] test --- .../org/apache/spark/ml/util/ReadWrite.scala | 63 ++++----- .../apache/spark/sql/classic/Dataset.scala | 32 +++-- .../sql/execution/arrow/ArrowConverters.scala | 38 ++++-- .../execution/arrow/ArrowFileReadWrite.scala | 122 ++++++++++++++++++ .../sql/execution/arrow/ArrowFileWriter.scala | 116 ----------------- 5 files changed, 195 insertions(+), 176 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 23bcc69e28c4..8860242bbafc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -47,6 +47,7 @@ import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatri import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} +import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Utils, VersionUtils} @@ -1147,30 +1148,20 @@ private[spark] object ReadWriteUtils { def saveDataFrame(path: String, df: DataFrame): Unit = { if (localSavingModeState.get()) { val filePath = Paths.get(path) - Files.createDirectories(filePath.getParent) + val parentPath = filePath.getParent + Files.createDirectories(parentPath) + val schemaPath = new Path(parentPath.toString, "schema").toString Using.resource( - new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) + new DataOutputStream(new BufferedOutputStream(new FileOutputStream(schemaPath))) ) { dos => - + dos.writeUTF(df.schema.json) } - Using.resource( - new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filePath.toFile))) - ) { dos => - dos.writeUTF("ARROW") // format - - val schema: StructType = df.schema - dos.writeUTF(schema.json) - - val iter = DatasetUtils.toArrowBatchRDD(df, "UTC").toLocalIterator - while (iter.hasNext) { - val bytes = iter.next() - require(bytes != null) - dos.writeInt(bytes.length) - dos.write(bytes) - } - dos.writeInt(-1) // End + df match { + case d: org.apache.spark.sql.classic.DataFrame => + 
ArrowFileReadWrite.save(d, path) + case _ => throw new UnsupportedOperationException("Unsupported dataframe type") } } else { df.write.parquet(path) @@ -1179,33 +1170,23 @@ private[spark] object ReadWriteUtils { def loadDataFrame(path: String, spark: SparkSession): DataFrame = { if (localSavingModeState.get()) { - val sc = spark match { - case s: org.apache.spark.sql.classic.SparkSession => s.sparkContext - } + val filePath = Paths.get(path) + val parentPath = filePath.getParent + val schemaPath = new Path(parentPath.toString, "schema").toString + var schemaString: String = null Using.resource( - new DataInputStream(new BufferedInputStream(new FileInputStream(path))) + new DataInputStream(new BufferedInputStream(new FileInputStream(schemaPath))) ) { dis => - val format = dis.readUTF() - require(format == "ARROW") + schemaString = dis.readUTF() + } - val schema: StructType = StructType.fromString(dis.readUTF()) + spark match { + case s: org.apache.spark.sql.classic.SparkSession => + val schema = StructType.fromString(schemaString) + ArrowFileReadWrite.load(s, path, schema) - val buff = mutable.ListBuffer.empty[Array[Byte]] - var nextBytes = dis.readInt() - while (nextBytes >= 0) { - val bytes = dis.readNBytes(nextBytes) - buff.append(bytes) - nextBytes = dis.readInt() - } - require(nextBytes == -1) - - DatasetUtils.fromArrowBatchRDD( - sc.parallelize[Array[Byte]](buff.result()), - schema, - "UTC", - spark - ) + case _ => throw new UnsupportedOperationException("Unsupported session type") } } else { spark.read.parquet(path) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index 59bcc864ac81..397d42e77b5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -2340,14 +2340,13 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of serialized ArrowRecordBatches. */ - private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { + private def toArrowBatchRddImpl( + plan: SparkPlan, + maxRecordsPerBatch: Int, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean): RDD[Array[Byte]] = { val schemaCaptured = this.schema - val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch - val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - val errorOnDuplicatedFieldNames = - sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy" - val largeVarTypes = - sparkSession.sessionState.conf.arrowUseLargeVarTypes plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toBatchIterator( @@ -2361,7 +2360,24 @@ class Dataset[T] private[sql]( } } - // This is only used in tests, for now. 
+ private[sql] def toArrowBatchRdd( + maxRecordsPerBatch: Int, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean): RDD[Array[Byte]] = { + toArrowBatchRddImpl(queryExecution.executedPlan, + maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + } + + private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { + toArrowBatchRddImpl( + plan, + sparkSession.sessionState.conf.arrowMaxRecordsPerBatch, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy", + sparkSession.sessionState.conf.arrowUseLargeVarTypes) + } + private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = { toArrowBatchRdd(queryExecution.executedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 3b8d0f8c1f8b..226840d69cbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.arrow -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream, OutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} import java.nio.channels.{Channels, ReadableByteChannel} import scala.collection.mutable.ArrayBuffer @@ -28,7 +28,7 @@ import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} -import org.apache.arrow.vector.ipc.{ArrowFileWriter, ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} import org.apache.spark.SparkException @@ -555,25 +555,41 @@ private[spark] object ArrowConverters extends Logging { arrowBatches: Iterator[Array[Byte]], schemaString: String, session: SparkSession): DataFrame = { - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + toDataFrame( + arrowBatches, + DataType.fromJson(schemaString).asInstanceOf[StructType], + session, + session.sessionState.conf.sessionLocalTimeZone, + false, + session.sessionState.conf.arrowUseLargeVarTypes) + } + + /** + * Create a DataFrame from an iterator of serialized ArrowRecordBatches. 
+ */ + private[sql] def toDataFrame( + arrowBatches: Iterator[Array[Byte]], + schema: StructType, + session: SparkSession, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean): DataFrame = { val attrs = toAttributes(schema) val batchesInDriver = arrowBatches.toArray - val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes val shouldUseRDD = session.sessionState.conf .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum if (shouldUseRDD) { logDebug("Using RDD-based createDataFrame with Arrow optimization.") - val timezone = session.sessionState.conf.sessionLocalTimeZone val rdd = session.sparkContext .parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length) .mapPartitions { batchesInExecutors => ArrowConverters.fromBatchIterator( batchesInExecutors, schema, - timezone, - errorOnDuplicatedFieldNames = false, - largeVarTypes = largeVarTypes, + timeZoneId, + errorOnDuplicatedFieldNames, + largeVarTypes, TaskContext.get()) } session.internalCreateDataFrame(rdd.setName("arrow"), schema) @@ -582,9 +598,9 @@ private[spark] object ArrowConverters extends Logging { val data = ArrowConverters.fromBatchIterator( batchesInDriver.iterator, schema, - session.sessionState.conf.sessionLocalTimeZone, - errorOnDuplicatedFieldNames = false, - largeVarTypes = largeVarTypes, + timeZoneId, + errorOnDuplicatedFieldNames, + largeVarTypes, TaskContext.get()) // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala new file mode 100644 index 000000000000..b5f7cb1360a6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.arrow + +import java.io.{ByteArrayOutputStream, FileOutputStream} +import java.nio.channels.Channels +import java.nio.file.Files +import java.nio.file.Paths + +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter, WriteChannel} +import org.apache.arrow.vector.ipc.message.MessageSerializer +import org.apache.arrow.vector.types.pojo.Schema + +import org.apache.spark.TaskContext +import org.apache.spark.sql.classic.{DataFrame, SparkSession} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils + +private[spark] class SparkArrowFileWriter( + arrowSchema: Schema, + out: FileOutputStream, + context: TaskContext) extends AutoCloseable { + + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) + + protected val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected val fileWriter = new ArrowFileWriter(root, null, Channels.newChannel(out)) + protected val loader = new VectorLoader(root) + protected val arrowWriter = ArrowWriter.create(root) + + Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => + close() + }} + + override def close(): Unit = { + root.close() + allocator.close() + fileWriter.close() + } + + def write(batchBytesIter: Iterator[Array[Byte]]): Unit = { + fileWriter.start() + while (batchBytesIter.hasNext) { + val batchBytes = batchBytesIter.next() + val batch = ArrowConverters.loadBatch(batchBytes, allocator) + loader.load(batch) + fileWriter.writeBatch() + } + fileWriter.close() + } +} + +private[spark] class SparkArrowFileReader( + path: String, + context: TaskContext) extends AutoCloseable { + + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) + + protected val fileReader = + new ArrowFileReader(Files.newByteChannel(Paths.get(path)), allocator) + + Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => + close() + }} + + override def close(): Unit = { + allocator.close() + fileReader.close() + } + + def read(): Iterator[Array[Byte]] = { + fileReader.getRecordBlocks.iterator().asScala.map { block => + fileReader.loadRecordBatch(block) + val root = fileReader.getVectorSchemaRoot + val unloader = new VectorUnloader(root) + val batch = unloader.getRecordBatch + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + MessageSerializer.serialize(writeChannel, batch) + out.toByteArray + } + } +} + +private[spark] object ArrowFileReadWrite { + def save(df: DataFrame, path: String): Unit = { + val spark = df.sparkSession + val rdd = df.toArrowBatchRdd( + spark.sessionState.conf.arrowMaxRecordsPerBatch, + "UTC", true, false) + val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false) + val writer = new SparkArrowFileWriter(arrowSchema, new FileOutputStream(path), null) + writer.write(rdd.toLocalIterator) + } + + def load(spark: SparkSession, path: String, schema: StructType): DataFrame = { + val reader = new SparkArrowFileReader(path, null) + ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala deleted file mode 100644 index efec4fc8bebd..000000000000 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileWriter.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.arrow - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream, OutputStream} -import java.nio.channels.{Channels, ReadableByteChannel} - -import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters._ - -import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec} -import org.apache.arrow.flatbuf.MessageHeader -import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector._ -import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec} -import org.apache.arrow.vector.ipc.{ArrowFileWriter, ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} -import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} - -import org.apache.spark.SparkException -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} -import org.apache.spark.util.ArrayImplicits._ - -private[spark] class ArrowBatchFileWriter( - schema: StructType, - out: FileOutputStream, - maxRecordsPerBatch: Long, - timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean, - largeVarTypes: Boolean, - context: TaskContext) extends AutoCloseable { - - protected val arrowSchema = - ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) - private val allocator = - ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) - - protected val root = VectorSchemaRoot.create(arrowSchema, allocator) - val writer = new ArrowFileWriter(root, null, out.getChannel) - - // Create compression codec based on config - private val compressionCodecName = SQLConf.get.arrowCompressionCodec - private val codec = compressionCodecName match { - case "none" => NoCompressionCodec.INSTANCE - case "zstd" => - val compressionLevel = SQLConf.get.arrowZstdCompressionLevel - val factory = 
CompressionCodec.Factory.INSTANCE - val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType() - factory.createCodec(codecType) - case "lz4" => - val factory = CompressionCodec.Factory.INSTANCE - val codecType = new Lz4CompressionCodec().getCodecType() - factory.createCodec(codecType) - case other => - throw SparkException.internalError( - s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4") - } - protected val unloader = new VectorUnloader(root, true, codec, true) - protected val arrowWriter = ArrowWriter.create(root) - - Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => - close() - }} - - def writeRows(rowIter: Iterator[InternalRow]): Unit = { - writer.start() - while (rowIter.hasNext) { - Utils.tryWithSafeFinally { - var rowCount = 0L - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { - val row = rowIter.next() - arrowWriter.write(row) - rowCount += 1 - } - arrowWriter.finish() - writer.writeBatch() - } { - arrowWriter.reset() - root.clear() - } - } - writer.end() - } - - override def close(): Unit = { - root.close() - allocator.close() - } -} From 4dcf366b42d007231194673de5c4eaf716e988b5 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 26 Nov 2025 19:49:44 +0800 Subject: [PATCH 5/7] test --- .../apache/spark/ml/util/DatasetUtils.scala | 50 +------------------ 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala index c64e8d3007e4..06de43260b30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{CLASS_NAME, LABEL_COLUMN, NUM_CLASSES} import org.apache.spark.ml.PredictorParams @@ -28,7 +28,6 @@ import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -213,51 +212,4 @@ private[spark] object DatasetUtils extends Logging { dataset.select(columnToVector(dataset, vectorCol)).head().getAs[Vector](0).size } } - - private[ml] def toArrowBatchRDD( - dataFrame: DataFrame, - timeZoneId: String): RDD[Array[Byte]] = { - dataFrame match { - case df: org.apache.spark.sql.classic.DataFrame => - val spark = df.sparkSession - val schema = df.schema - val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch - df.queryExecution.executedPlan.execute().mapPartitionsInternal { iter => - val context = TaskContext.get() - ArrowConverters.toBatchIterator( - iter, - schema, - maxRecordsPerBatch, - timeZoneId, - true, - false, - context) - } - - case _ => throw new UnsupportedOperationException("Not implemented") - } - } - - private[ml] def fromArrowBatchRDD( - rdd: RDD[Array[Byte]], - schema: StructType, - timeZoneId: String, - sparkSession: SparkSession): DataFrame = { - sparkSession match { - case spark: org.apache.spark.sql.classic.SparkSession => - val rowRDD = rdd.mapPartitions { iter => - val context = TaskContext.get() - ArrowConverters.fromBatchIterator( 
- iter, - schema, - timeZoneId, - true, - false, - context) - } - spark.internalCreateDataFrame(rowRDD.setName("arrow"), schema) - - case _ => throw new UnsupportedOperationException("Not implemented") - } - } } From e09ece31fd5c2779329df657d7049d2b8a62c969 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 26 Nov 2025 20:25:58 +0800 Subject: [PATCH 6/7] test --- .../org/apache/spark/ml/util/ReadWrite.scala | 15 +----------- .../spark/sql/classic/SparkSession.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 6 ++--- .../execution/arrow/ArrowFileReadWrite.scala | 24 ++++++++++--------- 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 8860242bbafc..e2c6f8429fcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -48,7 +48,6 @@ import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite -import org.apache.spark.sql.types.StructType import org.apache.spark.util.{Utils, VersionUtils} /** @@ -1170,21 +1169,9 @@ private[spark] object ReadWriteUtils { def loadDataFrame(path: String, spark: SparkSession): DataFrame = { if (localSavingModeState.get()) { - val filePath = Paths.get(path) - val parentPath = filePath.getParent - val schemaPath = new Path(parentPath.toString, "schema").toString - - var schemaString: String = null - Using.resource( - new DataInputStream(new BufferedInputStream(new FileInputStream(schemaPath))) - ) { dis => - schemaString = dis.readUTF() - } - spark match { case s: org.apache.spark.sql.classic.SparkSession => - val schema = StructType.fromString(schemaString) - ArrowFileReadWrite.load(s, path, schema) + ArrowFileReadWrite.load(s, path) case _ => throw new UnsupportedOperationException("Unsupported session type") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala index e002a1a616fd..5811fe759d3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala @@ -399,7 +399,7 @@ class SparkSession private( /** * Creates a `DataFrame` from an `RDD[InternalRow]`. */ - private[spark] def internalCreateDataFrame( + private[sql] def internalCreateDataFrame( catalystRows: RDD[InternalRow], schema: StructType, isStreaming: Boolean = false): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 226840d69cbf..ad456f38ed42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -80,7 +80,7 @@ private[sql] class ArrowBatchStreamWriter( } } -private[spark] object ArrowConverters extends Logging { +private[sql] object ArrowConverters extends Logging { private[sql] class ArrowBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, @@ -231,7 +231,7 @@ private[spark] object ArrowConverters extends Logging { * Maps Iterator from InternalRow to serialized ArrowRecordBatches. 
Limit ArrowRecordBatch size * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. */ - private[spark] def toBatchIterator( + private[sql] def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, @@ -484,7 +484,7 @@ private[spark] object ArrowConverters extends Logging { /** * Maps iterator from serialized ArrowRecordBatches to InternalRows. */ - private[spark] def fromBatchIterator( + private[sql] def fromBatchIterator( arrowBatchIter: Iterator[Array[Byte]], schema: StructType, timeZoneId: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala index b5f7cb1360a6..9a84b8dcb710 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala @@ -31,12 +31,11 @@ import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.TaskContext import org.apache.spark.sql.classic.{DataFrame, SparkSession} -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils -private[spark] class SparkArrowFileWriter( +private[sql] class SparkArrowFileWriter( arrowSchema: Schema, - out: FileOutputStream, + path: String, context: TaskContext) extends AutoCloseable { private val allocator = @@ -44,10 +43,12 @@ private[spark] class SparkArrowFileWriter( s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) protected val root = VectorSchemaRoot.create(arrowSchema, allocator) - protected val fileWriter = new ArrowFileWriter(root, null, Channels.newChannel(out)) protected val loader = new VectorLoader(root) protected val arrowWriter = ArrowWriter.create(root) + protected val fileWriter = + new ArrowFileWriter(root, null, Channels.newChannel(new FileOutputStream(path))) + Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => close() }} @@ -70,7 +71,7 @@ private[spark] class SparkArrowFileWriter( } } -private[spark] class SparkArrowFileReader( +private[sql] class SparkArrowFileReader( path: String, context: TaskContext) extends AutoCloseable { @@ -90,6 +91,8 @@ private[spark] class SparkArrowFileReader( fileReader.close() } + val schema: Schema = fileReader.getVectorSchemaRoot.getSchema + def read(): Iterator[Array[Byte]] = { fileReader.getRecordBlocks.iterator().asScala.map { block => fileReader.loadRecordBatch(block) @@ -106,17 +109,16 @@ private[spark] class SparkArrowFileReader( private[spark] object ArrowFileReadWrite { def save(df: DataFrame, path: String): Unit = { - val spark = df.sparkSession - val rdd = df.toArrowBatchRdd( - spark.sessionState.conf.arrowMaxRecordsPerBatch, - "UTC", true, false) + val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false) val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false) - val writer = new SparkArrowFileWriter(arrowSchema, new FileOutputStream(path), null) + val writer = new SparkArrowFileWriter(arrowSchema, path, null) writer.write(rdd.toLocalIterator) } - def load(spark: SparkSession, path: String, schema: StructType): DataFrame = { + def load(spark: SparkSession, path: String): DataFrame = { val reader = new SparkArrowFileReader(path, null) + val schema = ArrowUtils.fromArrowSchema(reader.schema) ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false) } } 
From 76bc0a87812a96d2cc414eab4b2c29bba15bbd40 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 26 Nov 2025 20:31:28 +0800 Subject: [PATCH 7/7] nit --- .../org/apache/spark/ml/util/ReadWrite.scala | 11 +--------- .../apache/spark/sql/types/StructType.scala | 2 +- .../execution/arrow/ArrowFileReadWrite.scala | 22 ++++--------------- 3 files changed, 6 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index e2c6f8429fcc..bdfb1a6a5cd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -1147,15 +1147,7 @@ private[spark] object ReadWriteUtils { def saveDataFrame(path: String, df: DataFrame): Unit = { if (localSavingModeState.get()) { val filePath = Paths.get(path) - val parentPath = filePath.getParent - Files.createDirectories(parentPath) - - val schemaPath = new Path(parentPath.toString, "schema").toString - Using.resource( - new DataOutputStream(new BufferedOutputStream(new FileOutputStream(schemaPath))) - ) { dos => - dos.writeUTF(df.schema.json) - } + Files.createDirectories(filePath.getParent) df match { case d: org.apache.spark.sql.classic.DataFrame => @@ -1172,7 +1164,6 @@ private[spark] object ReadWriteUtils { spark match { case s: org.apache.spark.sql.classic.SparkSession => ArrowFileReadWrite.load(s, path) - case _ => throw new UnsupportedOperationException("Unsupported session type") } } else { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index bbc27b2a73f9..5b1d9f1f116a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -531,7 +531,7 @@ object StructType extends AbstractDataType { override private[sql] def simpleString: String = "struct" - private[spark] def fromString(raw: String): StructType = { + private[sql] def fromString(raw: String): StructType = { Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parseString(raw)) match { case t: StructType => t case _ => throw DataTypeErrors.failedParsingStructTypeError(raw) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala index 9a84b8dcb710..a4557aebf607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala @@ -29,15 +29,12 @@ import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter, WriteChann import org.apache.arrow.vector.ipc.message.MessageSerializer import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.TaskContext import org.apache.spark.sql.classic.{DataFrame, SparkSession} import org.apache.spark.sql.util.ArrowUtils private[sql] class SparkArrowFileWriter( arrowSchema: Schema, - path: String, - context: TaskContext) extends AutoCloseable { - + path: String) extends AutoCloseable { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) @@ -49,10 +46,6 @@ private[sql] class SparkArrowFileWriter( protected val fileWriter = new ArrowFileWriter(root, null, Channels.newChannel(new FileOutputStream(path))) - Option(context).foreach 
{_.addTaskCompletionListener[Unit] { _ => - close() - }} - override def close(): Unit = { root.close() allocator.close() @@ -71,10 +64,7 @@ private[sql] class SparkArrowFileWriter( } } -private[sql] class SparkArrowFileReader( - path: String, - context: TaskContext) extends AutoCloseable { - +private[sql] class SparkArrowFileReader(path: String) extends AutoCloseable { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) @@ -82,10 +72,6 @@ private[sql] class SparkArrowFileReader( protected val fileReader = new ArrowFileReader(Files.newByteChannel(Paths.get(path)), allocator) - Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => - close() - }} - override def close(): Unit = { allocator.close() fileReader.close() @@ -112,12 +98,12 @@ private[spark] object ArrowFileReadWrite { val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false) val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false) - val writer = new SparkArrowFileWriter(arrowSchema, path, null) + val writer = new SparkArrowFileWriter(arrowSchema, path) writer.write(rdd.toLocalIterator) } def load(spark: SparkSession, path: String): DataFrame = { - val reader = new SparkArrowFileReader(path, null) + val reader = new SparkArrowFileReader(path) val schema = ArrowUtils.fromArrowSchema(reader.schema) ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false) }