Skip to content

Commit

Permalink
For comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jun 19, 2015
1 parent b5f52c1 commit c3dd8d9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ final class RegressionEvaluator(override val uid: String)
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "rmse" =>
metrics.rootMeanSquaredError
1 / metrics.rootMeanSquaredError
case "mse" =>
metrics.meanSquaredError
1 / metrics.meanSquaredError
case "r2" =>
metrics.r2
case "mae" =>
metrics.meanAbsoluteError
1 / metrics.meanAbsoluteError
}
metric
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,6 @@ private[ml] trait CrossValidatorParams extends Params {
def getNumFolds: Int = $(numFolds)

setDefault(numFolds -> 3)

/**
* Param for whether maximize the evaluation value during cross validation.
* If false, turn to minimize the evaluation value.
* Default: true
* @group param
*/
val useMax: BooleanParam = new BooleanParam(this, "useMax",
"whether maximize the evaluation value durin cross validation")

/** @group getParam */
def getUseMax: Boolean = $(useMax)

setDefault(useMax -> true)
}

/**
Expand All @@ -116,9 +102,6 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
/** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)

/** @group setParam */
def setUseMax(value: Boolean): this.type = set(useMax, value)

override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
Expand Down Expand Up @@ -148,11 +131,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
}
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) = if ($(useMax)) {
metrics.zipWithIndex.maxBy(_._1)
} else {
metrics.zipWithIndex.minBy(_._1)
}
val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,18 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setNumFolds(3)
.setUseMax(false)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)

eval.setMetricName("r2")
val cvModel2 = cv.fit(dataset)
val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent2.getRegParam === 0.001)
assert(parent2.getMaxIter === 10)
assert(cvModel2.avgMetrics.length === lrParamMaps.length)
}

test("validateParams should check estimatorParamMaps") {
Expand Down

0 comments on commit c3dd8d9

Please sign in to comment.