[SPARK-54446][ML][CONNECT] FPGrowth supports local filesystem with Arrow file format #53232
base: master
Conversation
This PR is another attempt to save ML models containing DataFrames to the driver's local filesystem.
fileWriter.start()
while (batchBytesIter.hasNext) {
  val batchBytes = batchBytesIter.next()
  val batch = ArrowConverters.loadBatch(batchBytes, allocator)
The batch: ArrowRecordBatch doesn't extend Serializable, so the PR still uses Array[Byte] as the underlying data.
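For context, a minimal sketch of the write loop being discussed, assuming the schema/path setup shown elsewhere in the diff; ArrowConverters.loadBatch is the Spark-internal helper used above, the rest are standard Arrow Java APIs, and this is an illustration rather than the exact PR code:

import java.io.FileOutputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowFileWriter
import org.apache.arrow.vector.types.pojo.Schema

import org.apache.spark.sql.execution.arrow.ArrowConverters

// Each iterator element is one Arrow record batch serialized to IPC bytes,
// because ArrowRecordBatch itself is not Serializable and cannot be collected as-is.
def writeBatches(arrowSchema: Schema, path: String, batchBytesIter: Iterator[Array[Byte]]): Unit = {
  val allocator = new RootAllocator(Long.MaxValue)
  val root = VectorSchemaRoot.create(arrowSchema, allocator)
  val loader = new VectorLoader(root)
  val out = new FileOutputStream(path)
  val fileWriter = new ArrowFileWriter(root, null, out.getChannel)
  try {
    fileWriter.start()
    while (batchBytesIter.hasNext) {
      // Deserialize the IPC bytes back into an ArrowRecordBatch, load it into the
      // shared VectorSchemaRoot, and append it to the Arrow file.
      val batch = ArrowConverters.loadBatch(batchBytesIter.next(), allocator)
      loader.load(batch)
      fileWriter.writeBatch()
      batch.close()
    }
    fileWriter.end()
  } finally {
    fileWriter.close()
    out.close()
    root.close()
    allocator.close()
  }
}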
protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
protected val loader = new VectorLoader(root)
protected val arrowWriter = ArrowWriter.create(root)
Where is arrowWriter used?
good catch!
  fileWriter.close()
}

def write(batchBytesIter: Iterator[Array[Byte]]): Unit = {
This looks like it does the following:
Dataset -> Arrow batches -> bytes -> Arrow batches -> write the Arrow batches with ArrowFileWriter
It looks like the intermediate bytes could be skipped?
I think he's doing it because the local data has to go to the executors, and for that the Arrow batches should be in IPC format.
The Dataset is already distributed on the executors, and rows are written into Arrow batches on the executors. If they are not to be distributed again, they could stay as Arrow batches, no?
Regarding writer.write(rdd.toLocalIterator) below: I think the code path here collects the Arrow batches into the Spark driver and writes them on the driver. So it should collect the Arrow batches from the executors to the driver.
I guess it's because they need to be written to the driver's local filesystem.
Oh I see. Okay.
def saveDataFrame(path: String, df: DataFrame): Unit = {
  if (localSavingModeState.get()) {
    val filePath = Paths.get(path)
    Files.createDirectories(filePath.getParent)

    df match {
      case d: org.apache.spark.sql.classic.DataFrame =>
        ArrowFileReadWrite.save(d, path)
      case _ => throw new UnsupportedOperationException("Unsupported dataframe type")
    }
  } else {
    df.write.parquet(path)
  }
}

def loadDataFrame(path: String, spark: SparkSession): DataFrame = {
  if (localSavingModeState.get()) {
    spark match {
      case s: org.apache.spark.sql.classic.SparkSession =>
        ArrowFileReadWrite.load(s, path)
      case _ => throw new UnsupportedOperationException("Unsupported session type")
    }
  } else {
    spark.read.parquet(path)
  }
}
So if localSavingModeState is set to true, this will write out an Arrow file, which is not a stable format. It does look like localSavingModeState is only set to true in internal Scala methods, but looking at the PySpark docstrings I see we tell people to use this API, so I remain -0.9.
Hi @holdenk, as @WeichenXu123 explained in #53150 (comment), this is a runtime temporary file on the Spark Connect server side, and it will be cleaned up after the session closes.
So I think we don't have to use a stable format here.
localSavingModeState is also only used internally (only Spark driver code can set the flag). Where does the docstring mention it? We should remove it from the docs and mark localSavingModeState as a private field.
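For reference, a driver-only flag like this is typically a thread-local that only internal code flips; a hedged sketch (the container object and exact declaration are assumptions, not Spark's actual code):

object ReadWriteFlags {  // hypothetical container; the real flag lives in Spark ML's read/write utils
  val localSavingModeState: ThreadLocal[Boolean] = new ThreadLocal[Boolean] {
    override def initialValue(): Boolean = false
  }
}

// Driver-internal code would then wrap a save call, e.g.:
// ReadWriteFlags.localSavingModeState.set(true)
// try { saveDataFrame(path, df) } finally { ReadWriteFlags.localSavingModeState.set(false) }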
Hmm, even if it is just a temporary session file, is there any reason to use the Arrow file format rather than Parquet?
We can read/write Parquet with Arrow, but it requires a new dependency:
<dependency>
  <groupId>org.apache.parquet</groupId>
  <artifactId>parquet-arrow</artifactId>
</dependency>
Otherwise, I am not sure whether we have utils to read/write Parquet.
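By contrast, reading the temporary Arrow IPC file back on the driver only needs the Arrow Java libraries Spark already depends on. A minimal sketch using standard Arrow APIs (the file path is a made-up example):

import java.nio.channels.FileChannel
import java.nio.file.{Paths, StandardOpenOption}

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowFileReader

val allocator = new RootAllocator(Long.MaxValue)
val channel = FileChannel.open(Paths.get("/tmp/fpgrowth-model/data"), StandardOpenOption.READ)
val reader = new ArrowFileReader(channel, allocator)
try {
  val root = reader.getVectorSchemaRoot  // schema and vectors shared across batches
  while (reader.loadNextBatch()) {
    println(s"read a batch of ${root.getRowCount} rows")
  }
} finally {
  reader.close()   // also releases the underlying channel
  allocator.close()
}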
viirya
left a comment
I wonder why the Arrow file format was chosen now instead of Parquet?
Due to the batch -> bytes -> batch -> bytes process (when writing to the file), it doesn't look like an efficient approach.
val rdd = df.toArrowBatchRdd(maxRecordsPerBatch, "UTC", true, false)
val arrowSchema = ArrowUtils.toArrowSchema(df.schema, "UTC", true, false)
val writer = new SparkArrowFileWriter(arrowSchema, path)
writer.write(rdd.toLocalIterator)
Instead, can we call toLocalIterator on the original DataFrame's RDD and write the rows into Arrow batches locally? Then we wouldn't need the redundant bytes?
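A hedged sketch of that alternative: pull InternalRows to the driver and fill the VectorSchemaRoot directly with Spark's internal ArrowWriter, skipping the IPC byte hop. toRdd and ArrowWriter are Spark-internal APIs, and the wiring of root, fileWriter and maxRecordsPerBatch is assumed rather than taken from the PR:

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowFileWriter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter

def writeRowsDirectly(
    rows: Iterator[InternalRow],
    root: VectorSchemaRoot,
    fileWriter: ArrowFileWriter,
    maxRecordsPerBatch: Int): Unit = {
  val arrowWriter = ArrowWriter.create(root)
  fileWriter.start()
  rows.grouped(maxRecordsPerBatch).foreach { group =>
    group.foreach(arrowWriter.write)  // append each row into the Arrow vectors
    arrowWriter.finish()              // finalize the current batch in the root
    fileWriter.writeBatch()           // persist the root's contents to the file
    arrowWriter.reset()               // clear the root for the next batch
  }
  fileWriter.end()
}

// e.g. writeRowsDirectly(df.queryExecution.toRdd.toLocalIterator, root, fileWriter, 10000)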
val extraMetadata: JObject = Map("numTrainingRecords" -> instance.numTrainingRecords)
DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
  extraMetadata = Some(extraMetadata))
val dataPath = new Path(path, "data").toString
Can we pass the Path object to saveDataFrame directly?
sounds good!
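One possible shape for that change, sketched as a hypothetical overload rather than the PR's actual code:

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.DataFrame

// Accept a Hadoop Path and delegate to the existing String-based saveDataFrame,
// so call sites can pass new Path(path, "data") directly instead of calling .toString first.
def saveDataFrame(dataPath: Path, df: DataFrame): Unit =
  saveDataFrame(dataPath.toString, df)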
/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
private def toArrowBatchRddImpl(
    plan: SparkPlan,
nit: 4 spaces indentation
import org.apache.spark.sql.util.ArrowUtils

private[sql] class SparkArrowFileWriter(
    arrowSchema: Schema,
nit: 4 spaces indentation
WeichenXu123
left a comment
LGTM
Can we have a shared util to produce an RDD of Arrow batches? Then we can either turn it into an RDD of bytes or write it to local files.
This is actually already reusing a lot of existing utils. Basically, the code below is for ...
spark match {
  case s: org.apache.spark.sql.classic.SparkSession =>
    ArrowFileReadWrite.load(s, path)
  case _ => throw new UnsupportedOperationException("Unsupported session type")
Can we show the actual session type in the error?
Makes sense, will update!
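For example, the match could surface the concrete class; a small sketch of the follow-up (not the committed code):

spark match {
  case s: org.apache.spark.sql.classic.SparkSession =>
    ArrowFileReadWrite.load(s, path)
  case other =>
    throw new UnsupportedOperationException(
      s"Unsupported session type: ${other.getClass.getName}")
}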
What changes were proposed in this pull request?
FPGrowth supports the local filesystem.
Why are the changes needed?
To make FPGrowth work with the local filesystem.
Does this PR introduce any user-facing change?
Yes, FPGrowth will work when local saving mode is on.
How was this patch tested?
updated tests
Was this patch authored or co-authored using generative AI tooling?
no