Skip to content

Commit

Permalink
add dsyrk to ALS
Browse files Browse the repository at this point in the history
  • Loading branch information
hqzizania committed Jun 24, 2016
1 parent cc6778e commit 8fb4a82
Showing 1 changed file with 50 additions and 3 deletions.
53 changes: 50 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -619,8 +619,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {

/** Number of entries in the upper triangular part of a k-by-k matrix. */
val triK = k * (k + 1) / 2
/** A^T^ * A */
/** The upper triangular of A^T^ * A */
val ata = new Array[Double](triK)
/** A^T^ * A */
val ata2 = new Array[Double](k * k)
/** A^T^ * b */
val atb = new Array[Double](k)

Expand All @@ -635,6 +637,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

private def copyToTri(a: Array[Double]): Unit = {
var ii = 0
for(i <- 0 until k)
for(j <- 0 to i) {
ata(ii) += ata2(i * k + j)
ii += 1
}
}

/** Adds an observation. */
def add(a: Array[Float], b: Double, c: Double = 1.0): this.type = {
require(c >= 0.0)
Expand All @@ -647,6 +658,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
this
}

/** Adds a stack of observations. */
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
copyToTri(ata2)
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}

/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
Expand All @@ -658,6 +678,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
/** Resets everything to zero, which should be called after each solve. */
def reset(): Unit = {
ju.Arrays.fill(ata, 0.0)
ju.Arrays.fill(ata2, 0.0)
ju.Arrays.fill(atb, 0.0)
}
}
Expand Down Expand Up @@ -1296,6 +1317,9 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
var i = srcPtrs(j)
var numExplicits = 0
val doStack = if (srcPtrs(j + 1) - srcPtrs(j) > 10) true else false
val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
while (i < srcPtrs(j + 1)) {
val encoded = srcEncodedIndices(i)
val blockId = srcEncoder.blockId(encoded)
Expand All @@ -1310,14 +1334,37 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
// for rating > 0. Because YtY is already added, we need to adjust the scaling here.
if (rating > 0) {
numExplicits += 1
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
if (doStack) {
var ii = 0
while(ii < srcFactor.length) {
srcFactorBuffer += srcFactor(ii) * c1
ii += 1
}
bBuffer += (c1 + 1.0) / c1
}
else {
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
}
} else {
ls.add(srcFactor, rating)
numExplicits += 1
if (doStack) {
bBuffer += rating
var ii = 0
while(ii < srcFactor.length) {
srcFactorBuffer += srcFactor(ii)
ii += 1
}
}
else {
ls.add(srcFactor, rating)
}
}
i += 1
}
if(numExplicits > 0 && doStack) {
ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
dstFactors(j) = solver.solve(ls, numExplicits * regParam)
j += 1
Expand Down

0 comments on commit 8fb4a82

Please sign in to comment.