Skip to content

Commit

Permalink
update ALS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Jan 8, 2015
1 parent 2a8deb3 commit a76da7b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 56 deletions.
Expand Up @@ -643,7 +643,7 @@ private[recommendation] object ALS extends Logging {
*/
def compress(): InBlock = {
val sz = size
assert(sz > 0) // TODO: Check whether it is possible to have empty blocks.
assert(sz > 0, "Empty in-link block should not exist.")
sort()
val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
Expand Down Expand Up @@ -681,7 +681,7 @@ private[recommendation] object ALS extends Logging {

private def sort(): Unit = {
val sz = size
// Since there might be interleaved log messages, we insert a unqiue id for easy pairing.
// Since there might be interleaved log messages, we insert a unique id for easy pairing.
val sortId = Utils.random.nextInt()
logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
val start = System.nanoTime()
Expand Down Expand Up @@ -807,7 +807,7 @@ private[recommendation] object ALS extends Logging {
i += 1
}
logDebug(
"Converting to local indices took " + (System.nanoTime() - start) / 1e9 + "seconds.")
"Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
(srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
}.groupByKey(new HashPartitioner(srcPart.numPartitions))
Expand Down Expand Up @@ -845,7 +845,7 @@ private[recommendation] object ALS extends Logging {
}

/**
* Compute dst factors by forming and solving least square problems.
* Compute dst factors by constructing and solving least squares problems.
*
* @param srcFactorBlocks src factors
* @param srcOutBlocks src out-blocks
Expand All @@ -867,6 +867,7 @@ private[recommendation] object ALS extends Logging {
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.size
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
case (srcBlockId, (srcOutBlock, srcFactors)) =>
Expand All @@ -877,7 +878,10 @@ private[recommendation] object ALS extends Logging {
val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
dstInBlocks.join(merged).mapValues {
case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
val sortedSrcFactors = srcFactors.toSeq.sortBy(_._1).map(_._2).toArray
val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
srcFactors.foreach { case (srcBlockId, factors) =>
sortedSrcFactors(srcBlockId) = factors
}
val dstFactors = new Array[Array[Float]](dstIds.size)
var j = 0
val ls = new NormalEquation(rank)
Expand Down
101 changes: 50 additions & 51 deletions mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
Expand Up @@ -30,14 +30,9 @@ import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}

case class ALSTestData(
training: Seq[Rating],
test: Seq[Rating],
userFactors: Map[Int, Array[Float]],
itemFactors: Map[Int, Array[Float]])

class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {

private var sqlContext: SQLContext = _
Expand Down Expand Up @@ -219,46 +214,61 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
assert(decompressed.toSet === expected)
}

/**
* Generates ratings for testing ALS.
*
* Random user and item factors of the given rank are generated with `genFactors`, and every
* (user, item) pair is assigned to the training set, the test set, or neither, based on a single
* draw of `random.nextDouble()`. The rating of a selected pair is computed from its user and
* item factors with `blas.sdot` (the single-precision BLAS dot product).
*
* NOTE(review): this span looks like a rendered diff in which the pre- and post-change variants
* of some lines both appear (the return type, the `totalFraction` definition, and the training
* noise) -- confirm against the real source file which variant is current.
*
* @param numUsers number of users
* @param numItems number of items
* @param rank rank
* @param trainingFraction fraction for training; a pair goes to training when its draw is below
*                         this value
* @param testFraction fraction for test; a pair goes to test when its draw is at least
*                     `trainingFraction` but below `trainingFraction + testFraction`
* @param noiseLevel noise level for additive Gaussian noise on training data; test ratings are
*                   stored without noise
* @param seed random seed
* @return (training, test)
*/
def genALSTestData(
numUsers: Int,
numItems: Int,
rank: Int,
trainingFraction: Double,
testFraction: Double,
noiseLevel: Double = 0.0,
seed: Long = 11L): ALSTestData = {
seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
// The two fractions partition [0, 1), so together they may not exceed 1.
val totalFraction = trainingFraction + testFraction
require(totalFraction <= 1.0)
// A single seeded generator drives factor generation, pair selection and noise.
val random = new Random(seed)
val userFactors = genFactors(numUsers, rank, random)
val itemFactors = genFactors(numItems, rank, random)
val totalFraction = trainingFraction + testFraction
val training = ArrayBuffer.empty[Rating]
val test = ArrayBuffer.empty[Rating]
// Visit every (user, item) pair; one draw `x` decides training / test / dropped.
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
val x = random.nextDouble()
if (x < totalFraction) {
// Noise-free rating computed from the user and item factors.
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
if (x < trainingFraction) {
// Only training ratings are perturbed; the noise is scaled by `noiseLevel`.
training += Rating(userId, itemId, rating + noiseLevel.toFloat * random.nextFloat())
val noise = noiseLevel * random.nextGaussian()
training += Rating(userId, itemId, rating + noise.toFloat)
} else {
test += Rating(userId, itemId, rating)
}
}
}
logInfo(s"Generated ${training.size} ratings for training and ${test.size} for test.")
ALSTestData(training.toSeq, test.toSeq, userFactors, itemFactors)
// Both sets are distributed as RDDs with 2 partitions each.
(sc.parallelize(training, 2), sc.parallelize(test, 2))
}

/**
* Generates `size` distinct random integer ids, each paired with a factor array of length `rank`
* whose entries are drawn with `random.nextFloat()`.
*
* NOTE(review): this span looks like a rendered diff in which the pre- and post-change variants
* of the signature and of the final expression both appear -- confirm against the real source
* file which variant is current.
*
* @param size number of (id, factor) pairs to generate; must satisfy
*             `0 < size < Int.MaxValue / 3`
* @param rank length of each factor array
* @param random generator used for both the ids and the factor entries
* @return the generated (id, factor) pairs
*/
def genFactors(size: Int, rank: Int, random: Random): Map[Int, Array[Float]] = {
private def genFactors(size: Int, rank: Int, random: Random): Seq[(Int, Array[Float])] = {
// NOTE(review): the upper bound presumably keeps duplicate draws in the loop below rare -- confirm.
require(size > 0 && size < Int.MaxValue / 3)
// Keep drawing ids until `size` distinct values are collected; the set absorbs duplicates.
val ids = mutable.Set.empty[Int]
while (ids.size < size) {
ids += random.nextInt()
}
// Variant returning a Map: factors are drawn in the set's iteration order.
ids.map(id => (id, Array.fill(rank)(random.nextFloat()))).toMap
// Variant returning a Seq: ids are sorted first, so factors are drawn in ascending id order.
ids.toSeq.sorted.map(id => (id, Array.fill(rank)(random.nextFloat())))
}

def testALS(
alsTestData: ALSTestData,
training: RDD[Rating],
test: RDD[Rating],
rank: Int,
maxIter: Int,
regParam: Double,
Expand All @@ -267,8 +277,6 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
numItemBlocks: Int = 3): Unit = {
val sqlContext = this.sqlContext
import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
val training = sc.parallelize(alsTestData.training, 2)
val test = sc.parallelize(alsTestData.test, 2)
val als = new ALS()
.setRank(rank)
.setRegParam(regParam)
Expand All @@ -287,48 +295,39 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
}

test("exact rank-1 matrix") {
val testData = genALSTestData(
numUsers = 20,
numItems = 40,
rank = 1,
trainingFraction = 0.6,
testFraction = 0.3)
testALS(testData, maxIter = 1, rank = 1, regParam = 1e-4, targetRMSE = 0.002)
testALS(testData, maxIter = 1, rank = 2, regParam = 1e-4, targetRMSE = 0.002)
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 1,
trainingFraction = 0.6, testFraction = 0.3)
testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001)
testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001)
}

test("approximate rank-1 matrix") {
val testData = genALSTestData(
numUsers = 20,
numItems = 40,
rank = 1,
trainingFraction = 0.6,
testFraction = 0.3,
noiseLevel = 0.01)
testALS(testData, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
testALS(testData, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 1,
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
}

test("approximate rank-2 matrix") {
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 2,
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
}

test("exact rank-2 matrix") {
val testData = genALSTestData(
numUsers = 20,
numItems = 40,
rank = 2,
trainingFraction = 0.6,
testFraction = 0.3)
testALS(testData, maxIter = 4, rank = 2, regParam = 1e-4, targetRMSE = 0.002)
testALS(testData, maxIter = 6, rank = 3, regParam = 0.01, targetRMSE = 0.04)
test("different block settings") {
val (training, test) = genALSTestData(numUsers = 20, numItems = 40, rank = 2,
trainingFraction = 0.6, testFraction = 0.3, noiseLevel = 0.01)
for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
}
}

test("approximate rank-2 matrix") {
val testData = genALSTestData(
numUsers = 20,
numItems = 40,
rank = 2,
trainingFraction = 0.6,
testFraction = 0.3,
noiseLevel = 0.01)
testALS(testData, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
testALS(testData, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
test("more blocks than ratings") {
val (training, test) = genALSTestData(numUsers = 4, numItems = 4, rank = 1,
trainingFraction = 0.7, testFraction = 0.3)
testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
numItemBlocks = 5, numUserBlocks = 5)
}
}

0 comments on commit a76da7b

Please sign in to comment.