Skip to content

Commit

Permalink
Convert it to a df and use set for the inital params
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed May 25, 2015
1 parent e8e03a1 commit 38a024b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,17 @@ class LogisticRegression(override val uid: String)
def setThreshold(value: Double): this.type = set(threshold, value)
setDefault(threshold -> 0.5)

override protected def train(dataset: DataFrame): LogisticRegressionModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
train(extractLabeledPoints(dataset), handlePersistence, None)
private var optInitialWeights: Option[Vector] = None
/** @group setParam */
def setInitialWeights(value: Vector): this.type = {
this.optInitialWeights = Some(value)
this
}
private [spark] def train(dataset: RDD[LabeledPoint], handlePersistence: Boolean,
optInitialWeights: Option[Vector]=None): LogisticRegressionModel = {

override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
// Extract columns from data. If dataset is persisted, do not persist instances.
val instances = dataset.map {
val instances = extractLabeledPoints(dataset).map {
case LabeledPoint(label: Double, features: Vector) => (label, features)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.storage.StorageLevel

/**
Expand Down Expand Up @@ -383,12 +384,28 @@ class LogisticRegressionWithLBFGS
// ml's Logisitic regression only supports binary classifcation currently.
if (numOfLinearPredictor == 1 && useFeatureScaling) {
def runWithMlLogisitcRegression(elasticNetParam: Double) = {
// Prepare the ml LogisticRegression based on our settings
val lr = new org.apache.spark.ml.classification.LogisticRegression()
lr.setRegParam(optimizer.getRegParam())
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
lr.setElasticNetParam(elasticNetParam)
val initialWeightsWithIntercept = Vectors.dense(0.0, initialWeights.toArray:_*)
val mlLogisticRegresionModel = lr.train(input, handlePersistence,
Some(initialWeightsWithIntercept))// TODO swap back to including the initialWeights
lr.setInitialWeights(initialWeightsWithIntercept)
// Convert our input into a DataFrame
val sqlContext = new SQLContext(input.context)
import sqlContext.implicits._
val df = input.toDF()
// Determine if we should cache the DF
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
df.persist(StorageLevel.MEMORY_AND_DISK)
}
// Train our model
val mlLogisticRegresionModel = lr.train(df)
// unpersist if we persisted
if (handlePersistence) {
df.unpersist()
}
// convert the model
createModel(mlLogisticRegresionModel.weights, mlLogisticRegresionModel.intercept)
}
optimizer.getUpdater() match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid

setMaxIter(1)

override protected def train(dataset: DataFrame): LogisticRegressionModel = {
override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
val labelSchema = dataset.schema($(labelCol))
// check for label attribute propagation.
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
Expand Down

0 comments on commit 38a024b

Please sign in to comment.