diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 6fd20ceb562b..e25fdc3e05ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -343,16 +343,11 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
   class FPGrowthModelWriter(instance: FPGrowthModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support saving to local filesystem path."
-        )
-      }
       val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
       DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
         extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
-      instance.freqItemsets.write.parquet(dataPath)
+      ReadWriteUtils.saveDataFrame(dataPath, instance.freqItemsets)
     }
   }
 
@@ -362,11 +357,6 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
     private val className = classOf[FPGrowthModel].getName
 
     override def load(path: String): FPGrowthModel = {
-      if (ReadWriteUtils.localSavingModeState.get()) {
-        throw new UnsupportedOperationException(
-          "FPGrowthModel does not support loading from local filesystem path."
-        )
-      }
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
@@ -378,7 +368,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
         (metadata.metadata \ "numTrainingRecords").extract[Long]
       }
       val dataPath = new Path(path, "data").toString
-      val frequentItems = sparkSession.read.parquet(dataPath)
+      val frequentItems = ReadWriteUtils.loadDataFrame(dataPath, sparkSession)
       val itemSupport = if (numTrainingRecords == 0L) {
         Map.empty[Any, Double]
       } else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index e3f31874a4c2..dc003884c5dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -46,7 +46,8 @@ import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, Matrix, SparseMatrix, SparseVector, Vector}
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
-import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.arrow.ArrowFileReadWrite
 import org.apache.spark.util.{Utils, VersionUtils}
 
 /**
@@ -1142,4 +1143,32 @@ private[spark] object ReadWriteUtils {
       spark.read.parquet(path).as[T].collect()
     }
   }
+
+  def saveDataFrame(path: String, df: DataFrame): Unit = {
+    if (localSavingModeState.get()) {
+      df match {
+        case d: org.apache.spark.sql.classic.DataFrame =>
+          val filePath = Paths.get(path)
+          Files.createDirectories(filePath.getParent)
+          ArrowFileReadWrite.save(d, filePath)
+        case o => throw new UnsupportedOperationException(
+          s"Unsupported dataframe type: ${o.getClass.getName}")
+      }
+    } else {
+      df.write.parquet(path)
+    }
+  }
+
+  def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
+    if (localSavingModeState.get()) {
+      spark match {
+        case s: org.apache.spark.sql.classic.SparkSession =>
+          ArrowFileReadWrite.load(s, Paths.get(path))
+        case o => throw new UnsupportedOperationException(
+          s"Unsupported session type: ${o.getClass.getName}")
+      }
+    } else {
+      spark.read.parquet(path)
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 1630a5d07d8e..3d994366b891 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -165,7 +165,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     }
     val fPGrowth = new FPGrowth()
     testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
-      FPGrowthSuite.allParamSettings, checkModelData, skipTestSaveLocal = true)
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 59bcc864ac81..d73918586b09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2340,14 +2340,13 @@ class Dataset[T] private[sql](
   }
 
   /** Convert to an RDD of serialized ArrowRecordBatches. */
-  private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+  private def toArrowBatchRddImpl(
+      plan: SparkPlan,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): RDD[Array[Byte]] = {
     val schemaCaptured = this.schema
-    val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
-    val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
-    val errorOnDuplicatedFieldNames =
-      sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
-    val largeVarTypes =
-      sparkSession.sessionState.conf.arrowUseLargeVarTypes
     plan.execute().mapPartitionsInternal { iter =>
       val context = TaskContext.get()
       ArrowConverters.toBatchIterator(
@@ -2361,7 +2360,24 @@ class Dataset[T] private[sql](
     }
   }
 
-  // This is only used in tests, for now.
+  private[sql] def toArrowBatchRdd(
+      maxRecordsPerBatch: Int,
+      timeZoneId: String,
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): RDD[Array[Byte]] = {
+    toArrowBatchRddImpl(queryExecution.executedPlan,
+      maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+  }
+
+  private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+    toArrowBatchRddImpl(
+      plan,
+      sparkSession.sessionState.conf.arrowMaxRecordsPerBatch,
+      sparkSession.sessionState.conf.sessionLocalTimeZone,
+      sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy",
+      sparkSession.sessionState.conf.arrowUseLargeVarTypes)
+  }
+
   private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
     toArrowBatchRdd(queryExecution.executedPlan)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 8b031af14e8b..683d8de25e0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -128,8 +128,7 @@ private[sql] object ArrowConverters extends Logging {
     }
 
     override def next(): Array[Byte] = {
-      val out = new ByteArrayOutputStream()
-      val writeChannel = new WriteChannel(Channels.newChannel(out))
+      var bytes: Array[Byte] = null
 
       Utils.tryWithSafeFinally {
         var rowCount = 0L
@@ -140,13 +139,13 @@ private[sql] object ArrowConverters extends Logging {
         }
         arrowWriter.finish()
         val batch = unloader.getRecordBatch()
-        MessageSerializer.serialize(writeChannel, batch)
+        bytes = serializeBatch(batch)
         batch.close()
       } {
         arrowWriter.reset()
       }
 
-      out.toByteArray
+      bytes
     }
 
     override def close(): Unit = {
@@ -548,6 +547,13 @@ private[sql] object ArrowConverters extends Logging {
       new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
   }
 
+  private[arrow] def serializeBatch(batch: ArrowRecordBatch): Array[Byte] = {
+    val out = new ByteArrayOutputStream()
+    val writeChannel = new WriteChannel(Channels.newChannel(out))
+    MessageSerializer.serialize(writeChannel, batch)
+    out.toByteArray
+  }
+
   /**
    * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
    */
@@ -555,25 +561,41 @@ private[sql] object ArrowConverters extends Logging {
       arrowBatches: Iterator[Array[Byte]],
       schemaString: String,
       session: SparkSession): DataFrame = {
-    val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+    toDataFrame(
+      arrowBatches,
+      DataType.fromJson(schemaString).asInstanceOf[StructType],
+      session,
+      session.sessionState.conf.sessionLocalTimeZone,
+      false,
+      session.sessionState.conf.arrowUseLargeVarTypes)
+  }
+
+  /**
+   * Create a DataFrame from an iterator of serialized ArrowRecordBatches.
+   */
+  private[sql] def toDataFrame(
+      arrowBatches: Iterator[Array[Byte]],
+      schema: StructType,
+      session: SparkSession,
+      timeZoneId: String,
+      errorOnDuplicatedFieldNames: Boolean,
+      largeVarTypes: Boolean): DataFrame = {
     val attrs = toAttributes(schema)
     val batchesInDriver = arrowBatches.toArray
-    val largeVarTypes = session.sessionState.conf.arrowUseLargeVarTypes
     val shouldUseRDD = session.sessionState.conf
       .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum
 
     if (shouldUseRDD) {
       logDebug("Using RDD-based createDataFrame with Arrow optimization.")
-      val timezone = session.sessionState.conf.sessionLocalTimeZone
       val rdd = session.sparkContext
         .parallelize(batchesInDriver.toImmutableArraySeq, batchesInDriver.length)
         .mapPartitions { batchesInExecutors =>
           ArrowConverters.fromBatchIterator(
             batchesInExecutors,
             schema,
-            timezone,
-            errorOnDuplicatedFieldNames = false,
-            largeVarTypes = largeVarTypes,
+            timeZoneId,
+            errorOnDuplicatedFieldNames,
+            largeVarTypes,
             TaskContext.get())
         }
       session.internalCreateDataFrame(rdd.setName("arrow"), schema)
@@ -582,9 +604,9 @@ private[sql] object ArrowConverters extends Logging {
       val data = ArrowConverters.fromBatchIterator(
         batchesInDriver.iterator,
         schema,
-        session.sessionState.conf.sessionLocalTimeZone,
-        errorOnDuplicatedFieldNames = false,
-        largeVarTypes = largeVarTypes,
+        timeZoneId,
+        errorOnDuplicatedFieldNames,
+        largeVarTypes,
         TaskContext.get())
 
       // Project/copy it. Otherwise, the Arrow column vectors will be closed and released out.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala
new file mode 100644
index 000000000000..e7ec2d2b7984
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWrite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import java.nio.channels.Channels
+import java.nio.file.{Files, Path}
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.ipc.{ArrowFileReader, ArrowFileWriter}
+import org.apache.arrow.vector.types.pojo.Schema
+
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
+import org.apache.spark.sql.util.ArrowUtils
+
+private[sql] class SparkArrowFileWriter(schema: Schema, path: Path) extends AutoCloseable {
+  private val allocator = ArrowUtils.rootAllocator
+    .newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
+
+  protected val root = VectorSchemaRoot.create(schema, allocator)
+  protected val loader = new VectorLoader(root)
+
+  protected val fileWriter =
+    new ArrowFileWriter(root, null, Channels.newChannel(Files.newOutputStream(path)))
+
+  override def close(): Unit = {
+    fileWriter.close()
+    root.close()
+    allocator.close()
+  }
+
+  def write(batchBytesIter: Iterator[Array[Byte]]): Unit = {
+    fileWriter.start()
+    while (batchBytesIter.hasNext) {
+      val batchBytes = batchBytesIter.next()
+      val batch = ArrowConverters.loadBatch(batchBytes, allocator)
+      loader.load(batch)
+      fileWriter.writeBatch()
+      batch.close()
+    }
+    fileWriter.close()
+  }
+}
+
+private[sql] class SparkArrowFileReader(path: Path) extends AutoCloseable {
+  private val allocator = ArrowUtils.rootAllocator
+    .newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
+
+  protected val fileReader =
+    new ArrowFileReader(Files.newByteChannel(path), allocator)
+
+  override def close(): Unit = {
+    fileReader.close()
+    allocator.close()
+  }
+
+  val schema: Schema = fileReader.getVectorSchemaRoot.getSchema
+
+  def read(): Iterator[Array[Byte]] = {
+    fileReader.getRecordBlocks.iterator().asScala.map { block =>
+      fileReader.loadRecordBatch(block)
+      val root = fileReader.getVectorSchemaRoot
+      val unloader = new VectorUnloader(root)
+      val batch = unloader.getRecordBatch
+      val bytes = ArrowConverters.serializeBatch(batch)
+      batch.close()
+      bytes
+    }
+  }
+}
+
+private[spark] object ArrowFileReadWrite {
+  def save(df: DataFrame, path: Path): Unit = {
+    val maxRecordsPerBatch = df.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+    val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
+    val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false)
+    val writer = new SparkArrowFileWriter(arrowSchema, path)
+    writer.write(rdd.toLocalIterator)
+  }
+
+  def load(spark: SparkSession, path: Path): DataFrame = {
+    val reader = new SparkArrowFileReader(path)
+    val schema = ArrowUtils.fromArrowSchema(reader.schema)
+    ArrowConverters.toDataFrame(reader.read(), schema, spark, "UTC", true, false)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWriteSuite.scala
new file mode 100644
index 000000000000..aa55b9706b5b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowFileReadWriteSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.arrow
+
+import java.io.File
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
+
+class ArrowFileReadWriteSuite extends QueryTest with SharedSparkSession {
+
+  private var tempDataPath: String = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    tempDataPath = Utils.createTempDir(namePrefix = "arrowFileReadWrite").getAbsolutePath
+  }
+
+  test("simple") {
+    val df = spark.range(0, 100, 1, 10).select(
+      col("id"),
+      lit(1).alias("int"),
+      lit(2L).alias("long"),
+      lit(3.0).alias("double"),
+      lit("a string").alias("str"),
+      lit(Array(1.0, 2.0, Double.NaN, Double.NegativeInfinity)).alias("arr"))
+
+    val path = new File(tempDataPath, "simple.arrowfile").toPath
+    ArrowFileReadWrite.save(df, path)
+
+    val df2 = ArrowFileReadWrite.load(spark, path)
+    checkAnswer(df, df2)
+  }
+
+  test("empty dataframe") {
+    val df = spark.range(0).withColumn("v", lit(1))
+    assert(df.count() === 0)
+
+    val path = new File(tempDataPath, "empty.arrowfile").toPath
+    ArrowFileReadWrite.save(df, path)
+
+    val df2 = ArrowFileReadWrite.load(spark, path)
+    checkAnswer(df, df2)
+  }
+}
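
Not part of the patch: a minimal test-style sketch of how the new ReadWriteUtils.saveDataFrame / loadDataFrame helpers would round-trip a DataFrame in local saving mode. It assumes localSavingModeState is a plain ThreadLocal[Boolean] (the patch only shows its get()), and it would have to live under org.apache.spark.ml.util because ReadWriteUtils is private[spark]; the suite name, test name, and temp-path layout are hypothetical.

package org.apache.spark.ml.util

import java.nio.file.Paths

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.functions.lit
import org.apache.spark.util.Utils

class ReadWriteUtilsDataFrameSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("saveDataFrame / loadDataFrame round trip in local saving mode") {
    val df = spark.range(0, 10).withColumn("v", lit("x"))
    val dir = Utils.createTempDir(namePrefix = "readWriteUtilsDataFrame")
    // saveDataFrame creates the parent directory, so a nested "model/data" path is fine.
    val path = Paths.get(dir.getAbsolutePath, "model", "data").toString

    // Assumed to be a ThreadLocal[Boolean]; the patch only shows get().
    ReadWriteUtils.localSavingModeState.set(true)
    try {
      // Local mode: the data is written as a single Arrow IPC file via ArrowFileReadWrite.
      ReadWriteUtils.saveDataFrame(path, df)
      val reloaded = ReadWriteUtils.loadDataFrame(path, spark)
      assert(reloaded.collect().toSet === df.collect().toSet)
    } finally {
      ReadWriteUtils.localSavingModeState.set(false)
    }
  }
}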