Skip to content

Commit

Permalink
CR feedback, pass RDD of Labeled points to ml implemetnation. Also fr…
Browse files Browse the repository at this point in the history
…om tests require that feature scaling is turned on to use ml implementation.
  • Loading branch information
holdenk committed May 24, 2015
1 parent 4febcc3 commit e8e03a1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ class LogisticRegression(override val uid: String)
setDefault(threshold -> 0.5)

override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractLabeledPoints(dataset).map {
case LabeledPoint(label: Double, features: Vector) => (label, features)
}
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
trainOnInstances(instances, handlePersistence)
train(extractLabeledPoints(dataset), handlePersistence, None)
}
private [spark] def train(dataset: RDD[LabeledPoint], handlePersistence: Boolean,
optInitialWeights: Option[Vector]=None): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist instances.
val instances = dataset.map {
case LabeledPoint(label: Double, features: Vector) => (label, features)
}

protected[spark] def trainOnInstances(instances: RDD[(Double, Vector)],
handlePersistence: Boolean, optInitialWeights: Option[Vector]=None): LogisticRegressionModel = {
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val (summarizer, labelSummarizer) = instances.treeAggregate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
Expand Down Expand Up @@ -377,20 +377,18 @@ class LogisticRegressionWithLBFGS
* If a known updater is used calls the ml implementation, to avoid
* applying a regularization penalty to the intercept, otherwise
* defaults to the mllib implementation. If more than two classes
* always uses mllib implementation.
* or feature scaling is disabled, always uses mllib implementation.
*/
override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = {
// ml's Logisitic regression only supports binary classifcation currently.
if (numOfLinearPredictor == 1) {
if (numOfLinearPredictor == 1 && useFeatureScaling) {
def runWithMlLogisitcRegression(elasticNetParam: Double) = {
val lr = new org.apache.spark.ml.classification.LogisticRegression()
lr.setRegParam(optimizer.getRegParam())
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
val instances = input.map {
case LabeledPoint(label: Double, features: Vector) => (label, features)
}
val mlLogisticRegresionModel = lr.trainOnInstances(instances, handlePersistence,
Some(initialWeights))
val initialWeightsWithIntercept = Vectors.dense(0.0, initialWeights.toArray:_*)
val mlLogisticRegresionModel = lr.train(input, handlePersistence,
Some(initialWeightsWithIntercept))// TODO swap back to including the initialWeights
createModel(mlLogisticRegresionModel.weights, mlLogisticRegresionModel.intercept)
}
optimizer.getUpdater() match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* translated back to resulting model weights, so it's transparent to users.
* Note: This technique is used in both libsvm and glmnet packages. Default false.
*/
private var useFeatureScaling = false
private[mllib] var useFeatureScaling = false

/**
* The dimension of training features.
Expand Down

0 comments on commit e8e03a1

Please sign in to comment.