Skip to content

Commit

Permalink
extended checking when testing RFormula against other types for the l…
Browse files Browse the repository at this point in the history
…abel column
  • Loading branch information
BenFradet committed May 12, 2016
1 parent ce19549 commit 3786ef9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
Expand Up @@ -85,7 +85,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val css = new ChiSqSelector()
MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector](
css, sqlContext) { (expected, actual) =>
css, spark) { (expected, actual) =>
assert(expected.selectedFeatures === actual.selectedFeatures)
}
}
Expand Down
Expand Up @@ -312,11 +312,16 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("should support all NumericType labels") {
val formula = new RFormula().setFormula("label ~ features")
.setLabelCol("x")
.setFeaturesCol("z")
.setFeaturesCol("y")
val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark)
val expected = formula.fit(dfs(DoubleType))
val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t)))
actuals.foreach { actual =>
assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length)
expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach {
case (exTransformer, acTransformer) =>
assert(exTransformer.params === acTransformer.params)
}
assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
Expand Down

0 comments on commit 3786ef9

Please sign in to comment.