From 9064e7bde92f206602ebde9b3d99a861b2a90f8a Mon Sep 17 00:00:00 2001 From: zhengruifeng Date: Thu, 26 Jul 2018 15:02:09 +0800 Subject: [PATCH] Support vector-compatible (array) feature columns in ClusteringEvaluator --- .../spark/ml/evaluation/ClusteringEvaluator.scala | 15 +++++++++------ .../ml/evaluation/ClusteringEvaluatorSuite.scala | 15 ++++++++++++++- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala index 4353c46781e9d..a6d6b4ea8b965 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala @@ -21,11 +21,10 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, - SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql.{Column, DataFrame, Dataset} import org.apache.spark.sql.functions.{avg, col, udf} import org.apache.spark.sql.types.DoubleType @@ -107,15 +106,19 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str @Since("2.3.0") override def evaluate(dataset: Dataset[_]): Double = { - SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) + SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol)) SchemaUtils.checkNumericType(dataset.schema, $(predictionCol)) + val vectorCol = 
DatasetUtils.columnToVector(dataset, $(featuresCol)) + val df = dataset.select(col($(predictionCol)), + vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)) + ($(metricName), $(distanceMeasure)) match { case ("silhouette", "squaredEuclidean") => SquaredEuclideanSilhouette.computeSilhouetteScore( - dataset, $(predictionCol), $(featuresCol)) + df, $(predictionCol), $(featuresCol)) case ("silhouette", "cosine") => - CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol)) + CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala index 2c175ff68e0b8..e2d77560293fa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite import testImplicits._ @transient var irisDataset: Dataset[_] = _ + @transient var newIrisDataset: Dataset[_] = _ + @transient var newIrisDatasetD: Dataset[_] = _ + @transient var newIrisDatasetF: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt") + val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset) + newIrisDataset = datasets._1 + 
newIrisDatasetD = datasets._2 + newIrisDatasetF = datasets._3 } test("params") { @@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite .setPredictionCol("label") assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5) } /* @@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite .setDistanceMeasure("cosine") assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5) + assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 relTol 1e-5) } test("number of clusters must be greater than one") {