Skip to content

Commit

Permalink
add assertion for 0-length vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Feb 11, 2018
1 parent ba73fc8 commit 4b41213
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
Expand Up @@ -761,6 +761,7 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* @return the cosine distance between the two input vectors
*/
override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
  // Dividing by a zero norm would make the cosine undefined, so both inputs
  // are checked to have a strictly positive norm before computing anything.
  val bothNormsPositive = v1.norm > 0 && v2.norm > 0
  assert(bothNormsPositive, "Cosine distance is not defined for zero-length vectors.")

  // Cosine similarity: the dot product divided by each vector's norm in turn.
  val cosineSimilarity = dot(v1.vector, v2.vector) / v1.norm / v2.norm
  1 - cosineSimilarity
}

Expand Down
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering

import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
Expand Down Expand Up @@ -182,6 +182,18 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
}

test("KMeans with cosine distance is not supported for 0-length vectors") {
  // The first point is the zero vector; the other two are ordinary non-zero points.
  val points = Array(
    Vectors.dense(0.0, 0.0),
    Vectors.dense(10.0, 10.0),
    Vectors.dense(1.0, 0.5)
  )
  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(points).map(v => TestRow(v)))

  val kmeans = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)

  // Fitting is expected to fail with a SparkException whose cause is the
  // AssertionError carrying the cosine-distance message.
  val e = intercept[SparkException] {
    kmeans.fit(df)
  }
  val cause = e.getCause
  assert(cause.isInstanceOf[AssertionError])
  assert(cause.getMessage.contains("Cosine distance is not defined"))
}

test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)
Expand Down

0 comments on commit 4b41213

Please sign in to comment.