Skip to content

Commit

Permalink
implicitprefs ut fails fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hqzizania committed Jun 24, 2016
1 parent 8fb4a82 commit 7e3d238
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -619,14 +619,13 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {

/** Number of entries in the upper triangular part of a k-by-k matrix. */
val triK = k * (k + 1) / 2
/** The upper triangular of A^T^ * A */
val ata = new Array[Double](triK)
/** A^T^ * A */
val ata2 = new Array[Double](k * k)
val ata = new Array[Double](triK)
/** A^T^ * b */
val atb = new Array[Double](k)

private val da = new Array[Double](k)
private val ata2 = new Array[Double](k * k)
private val upper = "U"

private def copyToDouble(a: Array[Float]): Unit = {
Expand All @@ -637,7 +636,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
}

private def copyToTri(a: Array[Double]): Unit = {
private def copyToTri(): Unit = {
var ii = 0
for(i <- 0 until k)
for(j <- 0 to i) {
Expand All @@ -662,7 +661,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
def addStack(a: Array[Double], b: Array[Double], n: Int): this.type = {
require(a.length == n * k)
blas.dsyrk(upper, "N", k, n, 1.0, a, k, 1.0, ata2, k)
copyToTri(ata2)
copyToTri()
blas.dgemv("N", k, n, 1.0, a, k, b, 1, 1.0, atb, 1)
this
}
Expand Down Expand Up @@ -1334,17 +1333,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
// for rating > 0. Because YtY is already added, we need to adjust the scaling here.
if (rating > 0) {
numExplicits += 1
if (doStack) {
var ii = 0
while(ii < srcFactor.length) {
srcFactorBuffer += srcFactor(ii) * c1
ii += 1
}
bBuffer += (c1 + 1.0) / c1
}
else {
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
ls.add(srcFactor, (c1 + 1.0) / c1, c1)
}
} else {
numExplicits += 1
Expand All @@ -1355,14 +1344,13 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
srcFactorBuffer += srcFactor(ii)
ii += 1
}
}
else {
} else {
ls.add(srcFactor, rating)
}
}
i += 1
}
if(numExplicits > 0 && doStack) {
if (!implicitPrefs && doStack) {
ls.addStack(srcFactorBuffer.result(), bBuffer.result(), numExplicits)
}
// Weight lambda by the number of explicit ratings based on the ALS-WR paper.
Expand Down

0 comments on commit 7e3d238

Please sign in to comment.