Skip to content

Commit

Permalink
[SPARK-20596][ML][TEST] Consolidate and improve ALS recommendAll test…
Browse files Browse the repository at this point in the history
… cases

Existing test cases for `recommendForAllX` methods (added in [SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)) test `k < num items` and `k = num items`. Technically we should also test that `k > num items` returns the same results as `k = num items`.

## How was this patch tested?

Updated existing unit tests.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #17860 from MLnick/SPARK-20596-als-rec-tests.
  • Loading branch information
Nick Pentreath committed May 8, 2017
1 parent 1552665 commit 58518d0
Showing 1 changed file with 25 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -671,58 +671,45 @@ class ALSSuite
.setItemCol("item")
}

// Requests the top 2 items for every user and checks the result against the
// hard-coded (item, rating) pairs below: one row per user, keyed by "user".
test("recommendForAllUsers with k < num_items") {
  val expectedTop2 = Map(
    0 -> Array((3, 54f), (4, 44f)),
    1 -> Array((3, 39f), (5, 33f)),
    2 -> Array((3, 51f), (5, 45f))
  )

  val recs = getALSModel.recommendForAllUsers(2)
  assert(recs.count() == 3)
  assert(recs.columns.contains("user"))
  checkRecommendations(recs, expectedTop2, "item")
}

test("recommendForAllUsers with k = num_items") {
val topItems = getALSModel.recommendForAllUsers(4)
assert(topItems.count() == 3)
assert(topItems.columns.contains("user"))

test("recommendForAllUsers with k <, = and > num_items") {
val model = getALSModel
val numUsers = model.userFactors.count
val numItems = model.itemFactors.count
val expected = Map(
0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
)
checkRecommendations(topItems, expected, "item")
}

test("recommendForAllItems with k < num_users") {
val topUsers = getALSModel.recommendForAllItems(2)
assert(topUsers.count() == 4)
assert(topUsers.columns.contains("item"))

val expected = Map(
3 -> Array((0, 54f), (2, 51f)),
4 -> Array((0, 44f), (2, 30f)),
5 -> Array((2, 45f), (0, 42f)),
6 -> Array((0, 28f), (2, 18f))
)
checkRecommendations(topUsers, expected, "user")
// Run the same checks for several values of k. The expected arrays are cut to
// min(k, numItems) entries, so a k larger than the item count is compared
// against the full per-user list.
for (k <- Seq(2, 4, 6)) {
  val limit = math.min(k, numItems).toInt
  val truncated = expected.mapValues(_.take(limit))
  val recs = model.recommendForAllUsers(k)
  assert(recs.count() == numUsers)
  assert(recs.columns.contains("user"))
  checkRecommendations(recs, truncated, "item")
}
}

test("recommendForAllItems with k = num_users") {
val topUsers = getALSModel.recommendForAllItems(3)
assert(topUsers.count() == 4)
assert(topUsers.columns.contains("item"))

test("recommendForAllItems with k <, = and > num_users") {
val model = getALSModel
val numUsers = model.userFactors.count
val numItems = model.itemFactors.count
val expected = Map(
3 -> Array((0, 54f), (2, 51f), (1, 39f)),
4 -> Array((0, 44f), (2, 30f), (1, 26f)),
5 -> Array((2, 45f), (0, 42f), (1, 33f)),
6 -> Array((0, 28f), (2, 18f), (1, 16f))
)
checkRecommendations(topUsers, expected, "user")

// Run the same checks for several values of k. The expected arrays are cut to
// min(k, numUsers) entries, so a k larger than the user count is compared
// against the full per-item list.
Seq(2, 3, 4).foreach { k =>
  val n = math.min(k, numUsers).toInt
  val expectedUpToN = expected.mapValues(_.slice(0, n))
  // Reuse the `model` that numUsers/numItems were read from instead of building
  // a fresh one via getALSModel on every iteration (matches the users test).
  val topUsers = model.recommendForAllItems(k)
  assert(topUsers.count() == numItems)
  assert(topUsers.columns.contains("item"))
  checkRecommendations(topUsers, expectedUpToN, "user")
}
}

private def checkRecommendations(
Expand Down

0 comments on commit 58518d0

Please sign in to comment.