From 5eee170bfa67b7dd75b4a93c30559c354b99b541 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Thu, 2 Nov 2017 18:03:23 -0700
Subject: [PATCH] Move ClusteringEvaluatorSuite test data to data/mllib.

---
 .../iris.libsvm => data/mllib/iris_libsvm.txt |  0
 .../evaluation/ClusteringEvaluatorSuite.scala | 30 +++++++------------
 2 files changed, 11 insertions(+), 19 deletions(-)
 rename mllib/src/test/resources/test-data/iris.libsvm => data/mllib/iris_libsvm.txt (100%)

diff --git a/mllib/src/test/resources/test-data/iris.libsvm b/data/mllib/iris_libsvm.txt
similarity index 100%
rename from mllib/src/test/resources/test-data/iris.libsvm
rename to data/mllib/iris_libsvm.txt
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
index e60ebbd7c852d..677ce49a903ab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -22,8 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.Dataset
 
 
 class ClusteringEvaluatorSuite
@@ -31,6 +30,13 @@ class ClusteringEvaluatorSuite
 
   import testImplicits._
 
+  @transient var irisDataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt")
+  }
+
   test("params") {
     ParamsSuite.checkParams(new ClusteringEvaluator)
   }
@@ -53,37 +59,23 @@ class ClusteringEvaluatorSuite
      0.6564679231
   */
   test("squared euclidean Silhouette") {
-    val iris = ClusteringEvaluatorSuite.irisDataset(spark)
     val evaluator = new ClusteringEvaluator()
       .setFeaturesCol("features")
       .setPredictionCol("label")
 
-    assert(evaluator.evaluate(iris) ~== 0.6564679231 relTol 1e-5)
+    assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5)
   }
 
   test("number of clusters must be greater than one") {
-    val iris = ClusteringEvaluatorSuite.irisDataset(spark)
-      .where($"label" === 0.0)
+    val singleClusterDataset = irisDataset.where($"label" === 0.0)
     val evaluator = new ClusteringEvaluator()
       .setFeaturesCol("features")
       .setPredictionCol("label")
 
     val e = intercept[AssertionError]{
-      evaluator.evaluate(iris)
+      evaluator.evaluate(singleClusterDataset)
     }
     assert(e.getMessage.contains("Number of clusters must be greater than one"))
   }
 
 }
-
-object ClusteringEvaluatorSuite {
-  def irisDataset(spark: SparkSession): DataFrame = {
-
-    val irisPath = Thread.currentThread()
-      .getContextClassLoader
-      .getResource("test-data/iris.libsvm")
-      .toString
-
-    spark.read.format("libsvm").load(irisPath)
-  }
-}