From d17d3fbee84fcb0072d3030f3118ca18ce783e0c Mon Sep 17 00:00:00 2001
From: Arseniy Tashoyan
Date: Sun, 11 Feb 2018 00:16:51 +0300
Subject: [PATCH 1/2] [SPARK-23318][ML] Workaround for 'ArrayStoreException:
 [Ljava.lang.Object' when trying to cache the RDD of items.

---
 mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index aa7871d6ff29d..f72b357a5ecee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -159,7 +159,7 @@ class FPGrowth @Since("2.2.0") (
 
   private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
     val data = dataset.select($(itemsCol))
-    val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
+    val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray)
     val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
     if (isSet(numPartitions)) {
       mllibFP.setNumPartitions($(numPartitions))

From e0eb8519bf09db12f5d5bc426eaf17d6488e05c1 Mon Sep 17 00:00:00 2001
From: Arseniy Tashoyan
Date: Sun, 11 Feb 2018 18:21:39 +0300
Subject: [PATCH 2/2] [SPARK-23318][ML] Cache the RDD of items if the user did
 not cache the input dataset of transactions.

This should eliminate the warning about uncached data in mllib.FPGrowth.
---
 .../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index f72b357a5ecee..3d041fc80eb7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
 /**
  * Common params for FPGrowth and FPGrowthModel
  */
@@ -158,18 +159,30 @@ class FPGrowth @Since("2.2.0") (
   }
 
   private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+
     val data = dataset.select($(itemsCol))
     val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray)
     val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
     if (isSet(numPartitions)) {
       mllibFP.setNumPartitions($(numPartitions))
     }
+
+    if (handlePersistence) {
+      items.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
     val parentModel = mllibFP.run(items)
     val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
     val schema = StructType(Seq(
       StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
       StructField("freq", LongType, nullable = false)))
     val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
+
+    if (handlePersistence) {
+      items.unpersist()
+    }
+
     copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
   }
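
Note on PATCH 1/2: Row.getSeq[T] performs an unchecked cast, so the returned Seq is still backed by an Object[] at runtime; toArray then allocates an array whose element type is taken from the ClassTag in scope, and copying the row's boxed values into a mismatched array is what surfaces as an ArrayStoreException. Asking for getSeq[Any] makes toArray build an Array[Any] (an Object[] on the JVM), which accepts any element. The standalone sketch below reproduces this failure mode outside Spark; toTypedArray is a hypothetical helper for illustration, not Spark code:

    import scala.reflect.ClassTag

    object ErasureDemo {
      // Hypothetical helper: the element type of the result array comes from
      // the ClassTag, not from the runtime values, mirroring the shape of
      // r.getSeq[T](0).toArray in the patch.
      def toTypedArray[T: ClassTag](values: Seq[Any]): Array[T] =
        values.asInstanceOf[Seq[T]].toArray // unchecked cast, erased at runtime

      def main(args: Array[String]): Unit = {
        // Array[Any] is Object[] on the JVM: any boxed value fits.
        println(toTypedArray[Any](Seq("a", "b")).mkString(","))
        // Copying strings into an Integer[] throws ArrayStoreException,
        // the same class of error reported in SPARK-23318.
        toTypedArray[Integer](Seq("a", "b"))
      }
    }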
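
Note on PATCH 2/2: the guard manages caching only when the caller has not persisted the input dataset themselves (dataset.storageLevel == StorageLevel.NONE), which avoids double-caching and addresses the uncached-data warning that mllib's FPGrowth logs. A minimal sketch of the same guard as a reusable helper, using only public Spark APIs (Dataset.storageLevel, RDD.persist, RDD.unpersist); the helper name withCachedItems and the try/finally are ours, not part of the patch:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Dataset
    import org.apache.spark.storage.StorageLevel

    object PersistenceGuard {
      // Persist the derived RDD only if the source dataset is uncached,
      // and release it once training is done, e.g.:
      //   withCachedItems(dataset, items)(mllibFP.run)
      def withCachedItems[T, R](dataset: Dataset[_], items: RDD[T])(train: RDD[T] => R): R = {
        val handlePersistence = dataset.storageLevel == StorageLevel.NONE
        if (handlePersistence) {
          items.persist(StorageLevel.MEMORY_AND_DISK)
        }
        try train(items)
        finally if (handlePersistence) items.unpersist()
      }
    }

One point reviewers may still want to check in the patch itself: frequentItems is lazy, so calling items.unpersist() before any action has run on frequentItems could force the FP-Growth lineage to recompute without the cache.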