[SPARK-6685][MLLIB] Use DSYRK to compute AtA in ALS #13891

Closed
wants to merge 12 commits
83 changes: 74 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -200,10 +200,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
/** @group expertGetParam */
def getFinalStorageLevel: String = $(finalStorageLevel)

/**
* Param for the threshold used in the computation of dst factors: when both the number of
* factors to combine and the rank exceed this value, the factor vectors are stacked into
* matrices so that AtA can be accumulated with DSYRK (must be >= 1).
* Default: 1024
* @group expertParam
*/
val threshold = new IntParam(this, "threshold", "threshold in the computation of dst factors " +
"that decides whether to stack factors to speed up the computation (>= 1)",
ParamValidators.gtEq(1))

/** @group expertGetParam */
def getThreshold: Int = $(threshold)

setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
threshold -> 1024)

/**
* Validates and transforms the input schema.
@@ -432,6 +446,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)

/** @group expertSetParam */
@Since("2.1.0")
def setThreshold(value: Int): this.type = set(threshold, value)

/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
@@ -460,14 +478,15 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val instrLog = Instrumentation.create(this, ratings)
instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
userCol, itemCol, ratingCol, predictionCol, maxIter,
regParam, nonnegative, checkpointInterval, seed)
regParam, nonnegative, threshold, checkpointInterval, seed)
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
checkpointInterval = $(checkpointInterval), seed = $(seed))
threshold = $(threshold), checkpointInterval = $(checkpointInterval),
seed = $(seed))
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
@@ -621,6 +640,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
val atb = new Array[Double](k)

private val da = new Array[Double](k)
private val ata2 = new Array[Double](k * k) // dense k x k buffer for DSYRK accumulation
private val upper = "U"

private def copyToDouble(a: Array[Float]): Unit = {
@@ -631,6 +651,22 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

/** Accumulates the upper triangle of the dense k x k buffer ata2 into the packed upper-triangular array ata. */
private def copyToTri(): Unit = {
var i = 0
var j = 0
var ii = 0
while (i < k) {
val temp = i * k
j = 0
while (j <= i) {
ata(ii) += ata2(temp + j)
j += 1
ii += 1
}
i += 1
}
}

/** Adds an observation. */
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
require(c >= 0.0)
@@ -643,6 +679,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
this
}

/** Adds a stack of n observations: a holds the factor vectors stacked column-wise (k x n, column-major), b the corresponding ratings. */
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
copyToTri()
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}

/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
@@ -654,6 +699,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/** Resets everything to zero, which should be called after each solve. */
def reset(): Unit = {
ju.Arrays.fill(ata, 0.0)
ju.Arrays.fill(ata2, 0.0)
ju.Arrays.fill(atb, 0.0)
}
}
@@ -675,6 +721,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
threshold: Int = 1024,
checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
@@ -721,7 +768,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, implicitPrefs, alpha, solver)
userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
// TODO: Generalize PeriodicGraphCheckpointer and use it here.
@@ -731,7 +778,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, implicitPrefs, alpha, solver)
itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
if (shouldCheckpoint(iter)) {
ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()
@@ -742,7 +789,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
userLocalIndexEncoder, solver = solver)
userLocalIndexEncoder, solver = solver, threshold = threshold)
if (shouldCheckpoint(iter)) {
val deps = itemFactors.dependencies
itemFactors.checkpoint()
@@ -752,7 +799,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
previousCheckpointFile = itemFactors.getCheckpointFile
}
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
itemLocalIndexEncoder, solver = solver)
itemLocalIndexEncoder, solver = solver, threshold = threshold)
}
}
val userIdAndFactors = userInBlocks
@@ -1266,7 +1313,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0,
solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
solver: LeastSquaresNESolver,
threshold: Int): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -1292,6 +1340,11 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
var i = srcPtrs(j)
var numExplicits = 0
// Stack the factor vectors into matrices to speed up the computation
// when both the number of factors and the rank exceed the threshold.
val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold
val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
while (i < srcPtrs(j + 1)) {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
@@ -1309,11 +1362,23 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
} else {
ls.add(srcFactor, rating)
numExplicits += 1
if (doStack) {
bBuffer += rating
var ii = 0
while (ii < srcFactor.length) {
srcFactorBuffer += srcFactor(ii)
ii += 1
}
} else {
ls.add(srcFactor, rating)
}
}
i += 1
}
if (!implicitPrefs && doStack) {
ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
j += 1
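The change above hinges on a simple equivalence: accumulating AtA for n stacked factor vectors with a single DSYRK call gives the same packed upper-triangular result as n rank-1 DSPR updates, which is the per-observation path NormalEquation.add takes. A minimal, self-contained sketch of that equivalence (the object name is made up for illustration; it assumes the same netlib-java BLAS wrapper that ALS.scala already imports):

import com.github.fommil.netlib.BLAS.{getInstance => blas}

object DsyrkVsDsprSketch {
  def main(args: Array[String]): Unit = {
    val k = 3 // rank
    val n = 4 // number of stacked factor vectors
    // Column-major k x n matrix: each column is one factor vector.
    val stacked = Array(
      1.0, 2.0, 3.0,
      4.0, 5.0, 6.0,
      7.0, 8.0, 9.0,
      1.5, 2.5, 3.5)

    // Path 1: one packed rank-1 update (DSPR) per vector.
    val viaDspr = new Array[Double](k * (k + 1) / 2)
    for (col <- 0 until n) {
      val x = java.util.Arrays.copyOfRange(stacked, col * k, (col + 1) * k)
      blas.dspr("U", k, 1.0, x, 1, viaDspr)
    }

    // Path 2: one DSYRK over the whole stack into a dense k x k buffer,
    // then fold its upper triangle into packed form (what copyToTri does).
    val dense = new Array[Double](k * k)
    blas.dsyrk("U", "N", k, n, 1.0, stacked, k, 0.0, dense, k)
    val viaDsyrk = new Array[Double](k * (k + 1) / 2)
    var ii = 0
    for (i <- 0 until k; j <- 0 to i) {
      viaDsyrk(ii) = dense(i * k + j)
      ii += 1
    }

    assert(viaDspr.zip(viaDsyrk).forall { case (a, b) => math.abs(a - b) < 1e-9 })
    println(viaDsyrk.mkString(", "))
  }
}

The win comes from letting the native library block the Level-3 computation instead of issuing one rank-1 update per rating; copyToTri then folds the dense result back into the packed layout the solvers expect.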
@@ -301,7 +301,8 @@ class ALSSuite
implicitPrefs: Boolean = false,
numUserBlocks: Int = 2,
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
targetRMSE: Double = 0.05,
threshold: Int = 1024): Unit = {
val spark = this.spark
import spark.implicits._
val als = new ALS()
@@ -311,6 +312,7 @@ class ALSSuite
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
.setSeed(0)
.setThreshold(threshold)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
@@ -382,6 +384,12 @@ class ALSSuite
numItemBlocks = 5, numUserBlocks = 5)
}

test("do stacking factors in matrices") {
val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1)
testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02,
threshold = 128)
}

test("implicit feedback") {
val (training, test) =
genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -864,6 +864,9 @@ object MimaExcludes {
// [SPARK-12221] Add CPU time to metrics
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
) ++ Seq(
// SPARK-6685
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.recommendation.ALS.train")
Member:

Does this actually remove a method? That shouldn't be necessary, I imagine.

Contributor Author:

Actually, I don't know how to write the "MiMa exclude" exactly. It may not be the proper fix for the MiMa failure, which is probably caused by the modification to def train.

Member:

I see. It's a developer API, so it's more reasonable to change, though it's still ideal not to change these APIs unless necessary. Try putting the new param at the end? I don't think that changes the situation, but it at least makes it source-compatible with any current invocations.

)
}
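On the reviewer's suggestion above about putting the new parameter at the end: a minimal sketch of why appending a defaulted parameter keeps existing call sites source-compatible (the train below is a simplified stand-in, not the real ALS.train signature); binary compatibility is still affected, which is what the MiMa exclude covers:

object TrainSignatureSketch {
  // Simplified stand-in for ALS.train with the new parameter appended last.
  def train(rank: Int = 10, maxIter: Int = 10, threshold: Int = 1024): String =
    s"rank=$rank, maxIter=$maxIter, threshold=$threshold"

  def main(args: Array[String]): Unit = {
    // Existing call sites that never mention threshold still compile unchanged.
    println(train(rank = 20, maxIter = 5))
    // New callers can opt in explicitly.
    println(train(rank = 20, maxIter = 5, threshold = 2048))
  }
}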
