Skip to content

Commit

Permalink
made the isClassification parameter for the check numeric types metho…
Browse files Browse the repository at this point in the history
…d optional
  • Loading branch information
BenFradet committed May 12, 2016
1 parent a8e5aa0 commit ce19549
Show file tree
Hide file tree
Showing 17 changed files with 17 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ class DecisionTreeClassifierSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
dt, isClassification = true, spark) { (expected, actual) =>
dt, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
gbt, isClassification = true, spark) { (expected, actual) =>
gbt, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ class LogisticRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LogisticRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
lr, isClassification = true, spark) { (expected, actual) =>
lr, spark) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class MultilayerPerceptronClassifierSuite
val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
MLTestingUtils.checkNumericTypes[
MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
mpc, isClassification = true, spark) { (expected, actual) =>
mpc, spark) { (expected, actual) =>
assert(expected.layers === actual.layers)
assert(expected.weights === actual.weights)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
test("should support all NumericType labels and not support other types") {
val nb = new NaiveBayes()
MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
nb, isClassification = true, spark) { (expected, actual) =>
nb, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
test("should support all NumericType labels and not support other types") {
val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
ovr, isClassification = true, spark) { (expected, actual) =>
ovr, spark) { (expected, actual) =>
val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
assert(expectedModels.length === actualModels.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class RandomForestClassifierSuite
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestClassifier().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
rf, isClassification = true, spark) { (expected, actual) =>
rf, spark) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val css = new ChiSqSelector()
MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector](
css, isClassification = true, sqlContext) { (expected, actual) =>
css, sqlContext) { (expected, actual) =>
assert(expected.selectedFeatures === actual.selectedFeatures)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = spark.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
test("should support all NumericType labels") {
val aft = new AFTSurvivalRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
aft, isClassification = false, spark) { (expected, actual) =>
aft, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
test("should support all NumericType labels and not support other types") {
val dt = new DecisionTreeRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
dt, isClassification = false, spark) { (expected, actual) =>
dt, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
test("should support all NumericType labels and not support other types") {
val gbt = new GBTRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
gbt, isClassification = false, spark) { (expected, actual) =>
gbt, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ class GeneralizedLinearRegressionSuite
val glr = new GeneralizedLinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[
GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
glr, isClassification = false, spark) { (expected, actual) =>
glr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class IsotonicRegressionSuite
test("should support all NumericType labels and not support other types") {
val ir = new IsotonicRegression()
MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
ir, isClassification = false, spark) { (expected, actual) =>
ir, spark, isClassification = false) { (expected, actual) =>
assert(expected.boundaries === actual.boundaries)
assert(expected.predictions === actual.predictions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ class LinearRegressionSuite
test("should support all NumericType labels and not support other types") {
val lr = new LinearRegression().setMaxIter(1)
MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
lr, isClassification = false, spark) { (expected, actual) =>
lr, spark, isClassification = false) { (expected, actual) =>
assert(expected.intercept === actual.intercept)
assert(expected.coefficients === actual.coefficients)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
test("should support all NumericType labels and not support other types") {
val rf = new RandomForestRegressor().setMaxDepth(1)
MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
rf, isClassification = false, spark) { (expected, actual) =>
rf, spark, isClassification = false) { (expected, actual) =>
TreeTests.checkEqual(expected, actual)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ object MLTestingUtils extends SparkFunSuite {

def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
estimator: T,
isClassification: Boolean,
spark: SparkSession)(check: (M, M) => Unit): Unit = {
spark: SparkSession,
isClassification: Boolean = true)(check: (M, M) => Unit): Unit = {
val dfs = if (isClassification) {
genClassifDFWithNumericLabelCol(spark)
} else {
Expand Down

0 comments on commit ce19549

Please sign in to comment.