Skip to content

Commit

Permalink
Fixed Linear/Logistic RegressionSuites
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent c3c8da5 commit 0a16da9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.scalatest.FunSuite

import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
Expand Down Expand Up @@ -107,39 +106,26 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
import sqlContext._
val lr = new LogisticRegression

// fit() vs. train()
val model1 = lr.fit(dataset)
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
val featuresRDD = rdd.map(_.features)
val model2 = lr.train(rdd)
assert(model1.intercept == model2.intercept)
assert(model1.weights.equals(model2.weights))
assert(model1.numClasses == model2.numClasses)
assert(model1.numClasses === 2)

// transform() vs. predict()
val trans = model1.transform(dataset).select('prediction)
val preds = model1.predict(rdd.map(_.features))
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
assert(pred1 == pred2)
val model = lr.fit(dataset)
assert(model.numClasses === 2)

val threshold = model.getThreshold
val results = model.transform(dataset)

// Compare rawPrediction with probability
results.select('rawPrediction, 'probability).collect().map {
case Row(raw: Vector, prob: Vector) =>
val raw2prob: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
raw.toArray.map(raw2prob).zip(prob.toArray).foreach { case (r, p) =>
assert(r ~== p relTol eps)
}
}

// Check various types of predictions.
val rawPredictions = model1.predictRaw(featuresRDD)
val probabilities = model1.predictProbabilities(featuresRDD)
val predictions = model1.predict(featuresRDD)
val threshold = model1.getThreshold
rawPredictions.zip(probabilities).collect().foreach { case (raw: Vector, prob: Vector) =>
val computeProbFromRaw: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
assert(r ~== p relTol eps)
}
}
probabilities.zip(predictions).collect().foreach { case (prob: Vector, pred: Double) =>
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
// Compare prediction with probability
results.select('prediction, 'probability).collect().map {
case Row(pred: Double, prob: Vector) =>
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,4 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
assert(model2.getPredictionCol == "thePred")
}

test("linear regression: Predictor, Regressor methods") {
val sqlContext = this.sqlContext
import sqlContext._
val lr = new LinearRegression

// fit() vs. train()
val model1 = lr.fit(dataset)
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}
val features = rdd.map(_.features)
val model2 = lr.train(rdd)
assert(model1.intercept == model2.intercept)
assert(model1.weights.equals(model2.weights))

// transform() vs. predict()
val trans = model1.transform(dataset).select('prediction)
val preds = model1.predict(rdd.map(_.features))
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
assert(pred1 == pred2)
}
}
}

0 comments on commit 0a16da9

Please sign in to comment.