From f487cb216fdfddbd12772456f277ccbb1bb7a116 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 6 Feb 2015 00:22:48 -0800 Subject: [PATCH] add unit tests --- .../MatrixFactorizationModel.scala | 33 ++++++++++++++----- .../MatrixFactorizationModelSuite.scala | 19 +++++++++++ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 78f4acba2d15e..08ff68a1b3d6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.recommendation +import java.io.IOException import java.lang.{Integer => JavaInteger} import org.apache.hadoop.fs.Path @@ -24,7 +25,6 @@ import org.jblas.DoubleMatrix import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel.SaveLoadV1_0 import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} @@ -130,9 +130,10 @@ class MatrixFactorizationModel( recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) + override val formatVersion: String = "1.0" override def save(sc: SparkContext, path: String): Unit = { - SaveLoadV1_0.save(this, path) + MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } private def recommend( @@ -151,12 +152,30 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] import org.apache.spark.mllib.util.Loader._ - private object SaveLoadV1_0 { + override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { + val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path) + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => + throw new IOException("" + + "MatrixFactorizationModel.load did not recognize model with" + + s"(class: $loadedClassName, version: $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private object SaveLoadV1_0 extends Loader[MatrixFactorizationModel] { private val thisFormatVersion = "1.0" - private val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel" + /** + * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and + * product features are saved under `data/products`. + */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext val sqlContext = new SQLContext(sc) @@ -173,9 +192,7 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) - val rank = metadata.select("rank").map { case Row(r: Int) => - r - }.first() + val rank = metadata.select("rank").first().getInt(0) val userFeatures = sqlContext.parquetFile(userPath(path)) .map { case Row(id: Int, features: Seq[Double]) => (id, features.toArray) @@ -184,7 +201,7 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] .map { case Row(id: Int, features: Seq[Double]) => (id, features.toArray) } - new MatrixFactorizationModel(r, userFeatures, productFeatures) + new MatrixFactorizationModel(rank, userFeatures, productFeatures) } private def userPath(path: String): String = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc904a23..9801e87576744 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) + assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } }