Skip to content

Commit

Permalink
CR feedback: make the setInitialWeights function private, don't mess …
Browse files Browse the repository at this point in the history
…with the weights when they are user supplied, validate that the user supplied weights are reasonable.
  • Loading branch information
holdenk committed May 26, 2015
1 parent 478b8c5 commit 08589f5
Showing 1 changed file with 18 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,27 @@ class LogisticRegression(override val uid: String)

private var optInitialWeights: Option[Vector] = None
/** @group setParam */
def setInitialWeights(value: Vector): this.type = {
private[spark] def setInitialWeights(value: Vector): this.type = {
this.optInitialWeights = Some(value)
this
}

/**
 * Validate optional initial weights against the expected vector size.
 *
 * @param vectorOpt   the initial weights to check, if any were provided
 * @param numFeatures the size the weight vector is required to have
 * @return the weights unchanged when their size equals `numFeatures`; `None` when no
 *         weights were provided or when the size differs (a warning is logged in the
 *         latter case)
 */
private def validateWeights(vectorOpt: Option[Vector], numFeatures: Int): Option[Vector] = {
  vectorOpt.filter { vec =>
    val sizeMatches = vec.size == numFeatures
    if (!sizeMatches) {
      // Mismatched weights are dropped rather than failing; log so the caller can notice.
      logWarning(
        s"Initial weights provided ($vec) did not match the expected size $numFeatures")
    }
    sizeMatches
  }
}

override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
// Extract columns from data. If dataset is persisted, do not persist instances.
val instances = extractLabeledPoints(dataset).map {
case LabeledPoint(label: Double, features: Vector) => (label, features)
}

val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

val (summarizer, labelSummarizer) = instances.treeAggregate(
Expand Down Expand Up @@ -168,10 +177,12 @@ class LogisticRegression(override val uid: String)
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
}

val initialWeightsWithIntercept = optInitialWeights.getOrElse(
Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures))
val numFeaturesWithIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures
val userSuppliedWeights = validateWeights(optInitialWeights, numFeaturesWithIntercept)
val initialWeightsWithIntercept = userSuppliedWeights.getOrElse(
Vectors.zeros(numFeaturesWithIntercept))

if ($(fitIntercept)) {
if ($(fitIntercept) && !userSuppliedWeights.isDefined) {
/**
* For binary logistic regression, when we initialize the weights as zeros,
* it will converge faster if we initialize the intercept such that
Expand Down

0 comments on commit 08589f5

Please sign in to comment.