55 changes: 33 additions & 22 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -29,7 +29,7 @@ import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.jblas.DoubleMatrix
import org.netlib.util.intW

import org.apache.spark.{HashPartitioner, Logging, Partitioner}
import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
@@ -501,8 +501,8 @@ object ALS extends Logging {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
"ALS is not designed to run without persisting intermediate RDDs.")
val sc = ratings.sparkContext
val userPart = new HashPartitioner(numUserBlocks)
val itemPart = new HashPartitioner(numItemBlocks)
val userPart = new ALSPartitioner(numUserBlocks)
val itemPart = new ALSPartitioner(numItemBlocks)
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
@@ -550,13 +550,23 @@ object ALS extends Logging {
val userIdAndFactors = userInBlocks
.mapValues(_.srcIds)
.join(userFactors)
.values
.mapPartitions({ items =>
items.flatMap { case (_, (ids, factors)) =>
ids.view.zip(factors)
}
// Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks
// and userFactors.
}, preservesPartitioning = true)
.setName("userFactors")
.persist(finalRDDStorageLevel)
val itemIdAndFactors = itemInBlocks
.mapValues(_.srcIds)
.join(itemFactors)
.values
.mapPartitions({ items =>
items.flatMap { case (_, (ids, factors)) =>
ids.view.zip(factors)
}
}, preservesPartitioning = true)
.setName("itemFactors")
.persist(finalRDDStorageLevel)
if (finalRDDStorageLevel != StorageLevel.NONE) {
@@ -569,13 +579,7 @@
itemOutBlocks.unpersist()
blockRatings.unpersist()
}
val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
ids.view.zip(factors)
}
val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) =>
ids.view.zip(factors)
}
(userOutput, itemOutput)
(userIdAndFactors, itemIdAndFactors)
}

/**
@@ -995,15 +999,15 @@ object ALS extends Logging {
"Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
(srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
}.groupByKey(new HashPartitioner(srcPart.numPartitions))
.mapValues { iter =>
val builder =
new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
}
builder.build().compress()
}.setName(prefix + "InBlocks")
}.groupByKey(new ALSPartitioner(srcPart.numPartitions))
.mapValues { iter =>
val builder =
new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
}
builder.build().compress()
}.setName(prefix + "InBlocks")
.persist(storageLevel)
val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
val encoder = new LocalIndexEncoder(dstPart.numPartitions)
@@ -1064,7 +1068,7 @@
(dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
}
}
val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length))
val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length))
dstInBlocks.join(merged).mapValues {
case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
@@ -1149,4 +1153,11 @@ object ALS extends Logging {
encoded & localIndexMask
}
}

/**
* Partitioner used by ALS. We require that getPartition is a projection. That is, for any key k,
* we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner
* satisfies this requirement, we simply use a type alias here.
*/
private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner
}
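
The projection property documented above is what makes the preservesPartitioning = true calls earlier in this diff safe: each record in userInBlocks/itemInBlocks lives in the partition whose index equals its block ID, and block ID = getPartition(originalId), so re-keying a block by its original IDs never moves a record across partitions. A minimal sketch of the check, using only the public HashPartitioner that the ALSPartitioner alias points to (no SparkContext needed; the block count and IDs below are illustrative):

import org.apache.spark.HashPartitioner

val numBlocks = 4
val part = new HashPartitioner(numBlocks) // same class as the ALSPartitioner alias

// Projection: block IDs (partition indices) are fixed points of getPartition.
(0 until numBlocks).foreach { b => assert(part.getPartition(b) == b) }

// Consequence: re-keying a block-partitioned record by its original ID keeps it in
// the same partition, so the partitioner can be preserved across mapPartitions.
val someIds = Seq(0, 7, 12, 33, 1001)
someIds.foreach { id =>
  assert(part.getPartition(part.getPartition(id)) == part.getPartition(id))
}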
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.scalatest.FunSuite

import org.apache.spark.Logging
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -455,4 +455,34 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
assert(isNonnegative(itemFactors))
// TODO: Validate the solution.
}

test("als partitioner is a projection") {
for (p <- Seq(1, 10, 100, 1000)) {
val part = new ALSPartitioner(p)
var k = 0
while (k < p) {
assert(k === part.getPartition(k))
assert(k === part.getPartition(k.toLong))
k += 1
}
}
}

test("partitioner in returned factors") {
val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
val (userFactors, itemFactors) = ALS.train(
ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4)
for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) {
assert(factors.partitioner.isDefined, s"$tpe factors should have partitioner.")
val part = factors.partitioner.get
factors.mapPartitionsWithIndex { (idx, items) =>
items.foreach { case (id, _) =>
if (part.getPartition(id) != idx) {
throw new SparkException(s"$tpe with ID $id should not be in partition $idx.")
}
}
Iterator.empty
}.count()
}
}
}
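
The new test checks the property that downstream code can rely on: the returned factor RDDs now carry a partitioner keyed by the original IDs. A hedged sketch of why that matters (not part of the patch; the helper name and both RDD arguments are made up for illustration, and it assumes a Spark version where pair-RDD operations are available on RDDs of tuples): a join against the factors can reuse their partitioner, so only the other side needs to be shuffled.

import org.apache.spark.HashPartitioner
import org.apache.spark.rdd.RDD

def joinWithUserFactors(
    userIdAndFactors: RDD[(Int, Array[Float])], // e.g. the first RDD returned by ALS.train
    userData: RDD[(Int, Float)]): RDD[(Int, (Array[Float], Float))] = {
  // Reuse the factors' partitioner when present so the join does not re-shuffle
  // the factor side; fall back to a fresh HashPartitioner otherwise.
  val part = userIdAndFactors.partitioner.getOrElse(
    new HashPartitioner(userIdAndFactors.partitions.length))
  userIdAndFactors.join(userData.partitionBy(part))
}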