Skip to content

Commit

Permalink
* Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap.
Browse files Browse the repository at this point in the history

* remove threshold_internal from logreg
* Added Predictor.copy()
* Extended LogisticRegressionSuite
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent e433872 commit 57d54ab
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
* These values override any specified in this Estimator's embedded ParamMap.
*/
def train(dataset: RDD[LabeledPoint], paramMap: ParamMap): LogisticRegressionModel = {
val map = this.paramMap ++ paramMap
val oldDataset = dataset.map { case LabeledPoint(label: Double, features: Vector, weight) =>
org.apache.spark.mllib.regression.LabeledPoint(label, features)
}
Expand All @@ -86,14 +87,13 @@ class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressi
}
val lr = new LogisticRegressionWithLBFGS
lr.optimizer
.setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter))
.setRegParam(map(regParam))
.setNumIterations(map(maxIter))
val model = lr.run(oldDataset)
val lrm = new LogisticRegressionModel(this, paramMap, model.weights, model.intercept)
val lrm = new LogisticRegressionModel(this, map, model.weights, model.intercept)
if (handlePersistence) {
oldDataset.unpersist()
}
lrm.setThreshold(paramMap(threshold))
lrm
}
}
Expand All @@ -115,18 +115,9 @@ class LogisticRegressionModel private[ml] (

setThreshold(0.5)

def setThreshold(value: Double): this.type = {
this.threshold_internal = value
set(threshold, value)
}
/** Sets the [[threshold]] param on this model and returns this instance for chaining. */
def setThreshold(value: Double): this.type = {
  set(threshold, value)
}
/** Sets the [[scoreCol]] param on this model and returns this instance for chaining. */
def setScoreCol(value: String): this.type = {
  set(scoreCol, value)
}

/**
* Store for faster test-time prediction.
* Initialized to threshold in fittingParamMap if exists, else default threshold.
*/
private var threshold_internal: Double = fittingParamMap.get(threshold).getOrElse(getThreshold)

/** Linear margin for a feature vector: dot(features, weights) + intercept. */
private val margin: Vector => Double = { featureVec =>
  val weightedSum = BLAS.dot(featureVec, weights)
  weightedSum + intercept
}
Expand All @@ -142,7 +133,8 @@ class LogisticRegressionModel private[ml] (
val scoreFunction = udf { v: Vector =>
val margin = BLAS.dot(v, weights)
1.0 / (1.0 + math.exp(-margin))
val t = threshold_internal
}
val t = map(threshold)
val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
dataset
.select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
Expand All @@ -151,12 +143,14 @@ class LogisticRegressionModel private[ml] (

override val numClasses: Int = 2

// TODO: Override batch predict() for efficiency.

/**
* Predict label for the given feature vector.
* The behavior of this can be adjusted using [[threshold]].
*/
override def predict(features: Vector): Double = {
if (score(features) > threshold_internal) 1 else 0
if (score(features) > paramMap(threshold)) 1 else 0
}

override def predictProbabilities(features: Vector): Vector = {
Expand All @@ -168,4 +162,10 @@ class LogisticRegressionModel private[ml] (
val m = margin(features)
Vectors.dense(Array(-m, m))
}

/**
 * Creates a new [[LogisticRegressionModel]] from this model's parent, fittingParamMap,
 * weights and intercept, then passes this model's embedded paramMap to
 * `Params.inheritValues` with the new instance as the target.
 *
 * @return the newly constructed model
 */
private[ml] override def copy(): LogisticRegressionModel = {
  val duplicate = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
  // Hand this model's embedded params to the fresh instance before returning it.
  Params.inheritValues(paramMap, this, duplicate)
  duplicate
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,33 @@ private[ml] abstract class PredictionModel[M <: PredictionModel[M]]

transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val tmpModel = this.copy()
Params.inheritValues(paramMap, parent, tmpModel)
val pred: Vector => Double = (features) => {
predict(features)
tmpModel.predict(features)
}
dataset.select(Star(None), pred.call(map(featuresCol).attr) as map(predictionCol))
}

/**
* Default implementation.
* Override for efficiency; e.g., this does not broadcast the model.
* Default implementation using single-instance predict().
*
* Developers should override this for efficiency. E.g., this does not broadcast the model.
*/
def predict(dataset: RDD[Vector]): RDD[Double] = {
dataset.map(predict)
def predict(dataset: RDD[Vector], paramMap: ParamMap): RDD[Double] = {
val tmpModel = this.copy()
Params.inheritValues(paramMap, parent, tmpModel)
dataset.map(tmpModel.predict)
}

/**
 * Predict label for the given features.
 *
 * @param features feature vector for a single instance
 * @return the predicted label as a Double
 */
def predict(features: Vector): Double

/**
 * Create a copy of the model.
 * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
 *
 * Concrete models must implement this; the return type is the concrete model type M.
 *
 * @return a new model instance of type M
 */
private[ml] def copy(): M
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.LabeledPoint
import org.apache.spark.ml.param.{ParamMap, HasMaxIter, HasRegParam}
import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -89,4 +89,10 @@ class LinearRegressionModel private[ml] (
override def predict(features: Vector): Double = {
  // Linear prediction: dot product of the features with the weights, plus the intercept.
  val weightedSum = BLAS.dot(features, weights)
  weightedSum + intercept
}

/**
 * Creates a new [[LinearRegressionModel]] from this model's parent, fittingParamMap,
 * weights and intercept, then passes this model's embedded paramMap to
 * `Params.inheritValues` with the new instance as the target.
 *
 * @return the newly constructed model
 */
private[ml] override def copy(): LinearRegressionModel = {
  val duplicate = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
  // Hand this model's embedded params to the fresh instance before returning it.
  Params.inheritValues(paramMap, this, duplicate)
  duplicate
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}


class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {

Expand All @@ -32,21 +33,29 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
}

test("logistic regression") {
test("logistic regression: default params") {
val lr = new LogisticRegression
val model = lr.fit(dataset)
model.transform(dataset)
.select("label", "prediction")
.collect()
// Check defaults
assert(model.getThreshold === 0.5)
assert(model.getFeaturesCol == "features")
assert(model.getPredictionCol == "prediction")
assert(model.getScoreCol == "score")
}

test("logistic regression with setters") {
// Set params, train, and check as many as we can.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
.setScoreCol("probability")
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
.select("label", "score", "prediction")
Expand All @@ -58,6 +67,33 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
.select("label", "probability", "prediction")
assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
assert(model.getThreshold === 0.6)

// Modify model params, and check that they work.
model.setThreshold(1.0)
val predAllZero = model.transform(dataset)
.select('prediction, 'probability)
.collect()
.map { case Row(pred: Double, prob: Double) => pred }
assert(predAllZero.forall(_ === 0.0))
// Call transform with params, and check that they work.
val predNotAllZero =
model.transform(dataset, model.threshold -> 0.0, model.scoreCol -> "myProb")
.select('prediction, 'myProb)
.collect()
.map { case Row(pred: Double, prob: Double) => pred }
assert(predNotAllZero.exists(_ !== 0.0))

// Call fit() with new params, and check as many as we can.
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
lr.scoreCol -> "theProb")
assert(model2.fittingParamMap.get(lr.maxIter) === Some(5))
assert(model2.fittingParamMap.get(lr.regParam) === Some(0.1))
assert(model2.fittingParamMap.get(lr.threshold) === Some(0.4))
assert(model2.getThreshold === 0.4)
assert(model2.getScoreCol == "theProb")
}
}

0 comments on commit 57d54ab

Please sign in to comment.