Skip to content

Commit

Permalink
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
Browse files Browse the repository at this point in the history
This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in 2 scenarios: (a) production scoring, and (b) cross-validation & evaluation.

The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit` since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.

The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is made an `expertParam`. The param is made a string so that the set of strategies can be extended in future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).
## How was this patch tested?

New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #12896 from MLnick/SPARK-14489-als-nan.
  • Loading branch information
Nick Pentreath committed Feb 28, 2017
1 parent 9b8eca6 commit b405466
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 9 deletions.
44 changes: 42 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
n.toInt
}
}

/**
* Param for strategy for dealing with unknown or new users/items at prediction time.
* This may be useful in cross-validation or production scenarios, for handling user/item ids
* the model has not seen in the training data.
* Supported values:
* - "nan": predicted value for unknown ids will be NaN.
* - "drop": rows in the input DataFrame containing unknown ids will be dropped from
* the output DataFrame containing predictions.
* Default: "nan".
* @group expertParam
*/
val coldStartStrategy = new Param[String](this, "coldStartStrategy",
"strategy for dealing with unknown or new users/items at prediction time. This may be " +
"useful in cross-validation or production scenarios, for handling user/item ids the model " +
"has not seen in the training data. Supported values: " +
s"${ALSModel.supportedColdStartStrategies.mkString(",")}.",
(s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase))

/** @group expertGetParam */
def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase
}

/**
Expand Down Expand Up @@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
coldStartStrategy -> "nan")

/**
* Validates and transforms the input schema.
Expand Down Expand Up @@ -248,6 +270,10 @@ class ALSModel private[ml] (
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/** @group expertSetParam */
@Since("2.2.0")
def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
Expand All @@ -260,13 +286,19 @@ class ALSModel private[ml] (
Float.NaN
}
}
dataset
val predictions = dataset
.join(userFactors,
checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
.join(itemFactors,
checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
getColdStartStrategy match {
case ALSModel.Drop =>
predictions.na.drop("all", Seq($(predictionCol)))
case ALSModel.NaN =>
predictions
}
}

@Since("1.3.0")
Expand All @@ -290,6 +322,10 @@ class ALSModel private[ml] (
@Since("1.6.0")
object ALSModel extends MLReadable[ALSModel] {

private val NaN = "nan"
private val Drop = "drop"
private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)

@Since("1.6.0")
override def read: MLReader[ALSModel] = new ALSModelReader

Expand Down Expand Up @@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)

/** @group expertSetParam */
@Since("2.2.0")
def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)

/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ class ALSSuite
(ex, act) =>
ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
} { (ex, act, _) =>
ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
}
}
// check user/item ids falling outside of Int range
Expand Down Expand Up @@ -547,6 +547,53 @@ class ALSSuite
ALS.train(ratings)
}
}

test("ALS cold start user/item prediction strategy") {
val spark = this.spark
import spark.implicits._
import org.apache.spark.sql.functions._

val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
val data = ratings.toDF
val knownUser = data.select(max("user")).as[Int].first()
val unknownUser = knownUser + 10
val knownItem = data.select(max("item")).as[Int].first()
val unknownItem = knownItem + 20
val test = Seq(
(unknownUser, unknownItem),
(knownUser, unknownItem),
(unknownUser, knownItem),
(knownUser, knownItem)
).toDF("user", "item")

val als = new ALS().setMaxIter(1).setRank(1)
// default is 'nan'
val defaultModel = als.fit(data)
val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
assert(defaultPredictions.length == 4)
assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
assert(!defaultPredictions.last.isNaN)

// check 'drop' strategy should filter out rows with unknown users/items
val dropPredictions = defaultModel
.setColdStartStrategy("drop")
.transform(test)
.select("prediction").as[Float].collect()
assert(dropPredictions.length == 1)
assert(!dropPredictions.head.isNaN)
assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
}

test("case insensitive cold start param value") {
val spark = this.spark
import spark.implicits._
val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1)
val data = ratings.toDF
val model = new ALS().fit(data)
Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
model.setColdStartStrategy(s).transform(data)
}
}
}

class ALSCleanerSuite extends SparkFunSuite {
Expand Down
30 changes: 25 additions & 5 deletions python/pyspark/ml/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,27 +125,33 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
"StorageLevel for ALS model factors.",
typeConverter=TypeConverters.toString)
coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " +
"unknown or new users/items at prediction time. This may be useful " +
"in cross-validation or production scenarios, for handling " +
"user/item ids the model has not seen in the training data. " +
"Supported values: 'nan', 'drop'.",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
finalStorageLevel="MEMORY_AND_DISK"):
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
"""
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=false, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
finalStorageLevel="MEMORY_AND_DISK")
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
"""
super(ALS, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
finalStorageLevel="MEMORY_AND_DISK")
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

Expand All @@ -155,13 +161,13 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
finalStorageLevel="MEMORY_AND_DISK"):
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
"""
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=False, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
finalStorageLevel="MEMORY_AND_DISK")
finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
Sets params for ALS.
"""
kwargs = self.setParams._input_kwargs
Expand Down Expand Up @@ -332,6 +338,20 @@ def getFinalStorageLevel(self):
"""
return self.getOrDefault(self.finalStorageLevel)

@since("2.2.0")
def setColdStartStrategy(self, value):
"""
Sets the value of :py:attr:`coldStartStrategy`.
"""
return self._set(coldStartStrategy=value)

@since("2.2.0")
def getColdStartStrategy(self):
"""
Gets the value of coldStartStrategy or its default value.
"""
return self.getOrDefault(self.coldStartStrategy)


class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Expand Down

0 comments on commit b405466

Please sign in to comment.