Skip to content

Commit

Permalink
[SPARK-5598][MLLIB] model save/load for ALS
Browse files Browse the repository at this point in the history
following #4233. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4422 from mengxr/SPARK-5598 and squashes the following commits:

a059394 [Xiangrui Meng] SaveLoad not extending Loader
14b7ea6 [Xiangrui Meng] address comments
f487cb2 [Xiangrui Meng] add unit tests
62fc43c [Xiangrui Meng] implement save/load for MFM
  • Loading branch information
mengxr committed Feb 9, 2015
1 parent 804949d commit 5c299c5
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.mllib.recommendation

import org.apache.spark.Logging
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.recommendation.{ALS => NewALS}
import org.apache.spark.rdd.RDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@

package org.apache.spark.mllib.recommendation

import java.io.IOException
import java.lang.{Integer => JavaInteger}

import org.apache.hadoop.fs.Path
import org.jblas.DoubleMatrix

import org.apache.spark.Logging
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.storage.StorageLevel

/**
Expand All @@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
val productFeatures: RDD[(Int, Array[Double])])
extends Saveable with Serializable with Logging {

require(rank > 0)
validateFeatures("User", userFeatures)
Expand Down Expand Up @@ -125,6 +130,12 @@ class MatrixFactorizationModel(
recommend(productFeatures.lookup(product).head, userFeatures, num)
.map(t => Rating(t._1, product, t._2))

protected override val formatVersion: String = "1.0"

override def save(sc: SparkContext, path: String): Unit = {
MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
}

private def recommend(
recommendToFeatures: Array[Double],
recommendableFeatures: RDD[(Int, Array[Double])],
Expand All @@ -136,3 +147,70 @@ class MatrixFactorizationModel(
scored.top(num)(Ordering.by(_._2))
}
}

object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

import org.apache.spark.mllib.util.Loader._

override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, formatVersion) match {
case (className, "1.0") if className == classNameV1_0 =>
SaveLoadV1_0.load(sc, path)
case _ =>
throw new IOException("MatrixFactorizationModel.load did not recognize model with" +
s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}

private[recommendation]
object SaveLoadV1_0 {

private val thisFormatVersion = "1.0"

private[recommendation]
val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"

/**
* Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and
* product features are saved under `data/products`.
*/
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
val sqlContext = new SQLContext(sc)
import sqlContext.implicits.createDataFrame
val metadata = (thisClassName, thisFormatVersion, model.rank)
val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
}

def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
val sqlContext = new SQLContext(sc)
val (className, formatVersion, metadata) = loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val rank = metadata.select("rank").first().getInt(0)
val userFeatures = sqlContext.parquetFile(userPath(path))
.map { case Row(id: Int, features: Seq[Double]) =>
(id, features.toArray)
}
val productFeatures = sqlContext.parquetFile(productPath(path))
.map { case Row(id: Int, features: Seq[Double]) =>
(id, features.toArray)
}
new MatrixFactorizationModel(rank, userFeatures, productFeatures)
}

private def userPath(path: String): String = {
new Path(dataPath(path), "user").toUri.toString
}

private def productPath(path: String): String = {
new Path(dataPath(path), "product").toUri.toString
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

Expand Down Expand Up @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
}
}

test("save/load") {
val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
features.mapValues(_.toSeq).collect().toSet
}
try {
model.save(sc, path)
val newModel = MatrixFactorizationModel.load(sc, path)
assert(newModel.rank === rank)
assert(collect(newModel.userFeatures) === collect(userFeatures))
assert(collect(newModel.productFeatures) === collect(prodFeatures))
} finally {
Utils.deleteRecursively(tempDir)
}
}
}

0 comments on commit 5c299c5

Please sign in to comment.