add unit tests
mengxr committed Feb 6, 2015
1 parent 62fc43c commit f487cb2
Showing 2 changed files with 44 additions and 8 deletions.
mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,14 +17,14 @@

 package org.apache.spark.mllib.recommendation

+import java.io.IOException
 import java.lang.{Integer => JavaInteger}

 import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix

 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModel.SaveLoadV1_0
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
@@ -130,9 +130,10 @@ class MatrixFactorizationModel(
     recommend(productFeatures.lookup(product).head, userFeatures, num)
       .map(t => Rating(t._1, product, t._2))

+  override val formatVersion: String = "1.0"

   override def save(sc: SparkContext, path: String): Unit = {
-    SaveLoadV1_0.save(this, path)
+    MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
   }

   private def recommend(
@@ -151,12 +152,30 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]

   import org.apache.spark.mllib.util.Loader._

-  private object SaveLoadV1_0 {
+  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+    val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, formatVersion) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path)
+      case _ =>
+        throw new IOException("" +
+          "MatrixFactorizationModel.load did not recognize model with " +
+          s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
+          s"  ($classNameV1_0, 1.0)")
+    }
+  }
+
+  private object SaveLoadV1_0 extends Loader[MatrixFactorizationModel] {

     private val thisFormatVersion = "1.0"

-    private val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+    val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"

+    /**
+     * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and
+     * product features are saved under `data/products`.
+     */
     def save(model: MatrixFactorizationModel, path: String): Unit = {
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
@@ -173,9 +192,7 @@
       val (className, formatVersion, metadata) = loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
-      val rank = metadata.select("rank").map { case Row(r: Int) =>
-        r
-      }.first()
+      val rank = metadata.select("rank").first().getInt(0)
       val userFeatures = sqlContext.parquetFile(userPath(path))
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)
@@ -184,7 +201,7 @@
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)
         }
-      new MatrixFactorizationModel(r, userFeatures, productFeatures)
+      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
     }

     private def userPath(path: String): String = {
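Taken together, the changes above hook MatrixFactorizationModel into MLlib's Saveable/Loader export format: save writes versioned metadata plus the user and product factor RDDs, and the new companion-object load verifies the recorded class name and format version before reading them back. The following is a rough usage sketch, not part of the commit: the toy ratings, ALS settings, and output path are invented for illustration, a live SparkContext sc is assumed, and the companion object is still package-private at this point in the PR, so calling load like this assumes the eventually public API.

import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}

// Toy ratings and settings, chosen only for illustration.
val ratings = sc.parallelize(Seq(Rating(1, 10, 4.0), Rating(1, 20, 1.0), Rating(2, 10, 5.0)))
val model = ALS.train(ratings, 2, 5, 0.01)  // rank = 2, 5 iterations, lambda = 0.01

// save() writes versioned metadata plus the factor RDDs (user factors under data/users,
// product factors under data/products) beneath the given path.
val path = "/tmp/matrix-factorization-model"  // hypothetical output location
model.save(sc, path)

// The new loader checks the recorded class name and format version ("1.0"), throwing an
// IOException for anything it does not recognize, then rebuilds the model from the saved data.
val sameModel = MatrixFactorizationModel.load(sc, path)
assert(sameModel.rank == model.rank)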
mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils

 class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

@@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
       new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
     }
   }
+
+  test("save/load") {
+    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
+      features.mapValues(_.toSeq).collect().toSet
+    }
+    try {
+      model.save(sc, path)
+      val newModel = MatrixFactorizationModel.load(sc, path)
+      assert(newModel.rank === rank)
+      assert(collect(newModel.userFeatures) === collect(userFeatures))
+      assert(collect(newModel.productFeatures) === collect(prodFeatures))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
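A side note on the collect helper in the new test: it converts each Array[Double] of factors to a Seq[Double] before building the sets because JVM arrays compare by reference, so structurally identical factor arrays would otherwise make the assertions fail. A tiny standalone illustration, not part of the commit:

val a = Array(1.0, 2.0)
val b = Array(1.0, 2.0)
a == b              // false: arrays use reference equality
a.toSeq == b.toSeq  // true: sequences compare element by element
Set((1, a)) == Set((1, b))              // false for the same reason
Set((1, a.toSeq)) == Set((1, b.toSeq))  // true, which is what the test relies on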
