Skip to content

Commit

Permalink
[SPARK-16965][MLLIB][PYSPARK] Fix bound checking for SparseVector.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

1. In scala, add negative low bound checking and put all the low/upper bound checking in one place
2. In python, add low/upper bound checking of indices.

## How was this patch tested?

unit test added

Author: Jeff Zhang <zjffdu@apache.org>

Closes #14555 from zjffdu/SPARK-16965.
  • Loading branch information
zjffdu authored and srowen committed Aug 19, 2016
1 parent 864be93 commit 072acf5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
34 changes: 19 additions & 15 deletions mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,7 @@ object Vectors {
*/
@Since("2.0.0")
def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
require(size > 0, "The size of the requested sparse vector must be greater than 0.")

val (indices, values) = elements.sortBy(_._1).unzip
var prev = -1
indices.foreach { i =>
require(prev < i, s"Found duplicate indices: $i.")
prev = i
}
require(prev < size, s"You may not write an element to index $prev because the declared " +
s"size of your vector is $size")

new SparseVector(size, indices.toArray, values.toArray)
}

Expand Down Expand Up @@ -560,11 +550,25 @@ class SparseVector @Since("2.0.0") (
@Since("2.0.0") val indices: Array[Int],
@Since("2.0.0") val values: Array[Double]) extends Vector {

require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
s" ${values.length} values.")
require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
s"which exceeds the specified vector size ${size}.")
// validate the data
{
require(size >= 0, "The size of the requested sparse vector must be greater than 0.")
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
s" ${values.length} values.")
require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
s"which exceeds the specified vector size ${size}.")

if (indices.nonEmpty) {
require(indices(0) >= 0, s"Found negative index: ${indices(0)}.")
}
var prev = -1
indices.foreach { i =>
require(prev < i, s"Index $i follows $prev and is not strictly increasing")
prev = i
}
require(prev < size, s"Index $prev out of bounds for vector of size $size")
}

override def toString: String =
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class VectorsSuite extends SparkMLFunSuite {
}
}

test("sparse vector construction with negative indices") {
intercept[IllegalArgumentException] {
Vectors.sparse(3, Array(-1, 1), Array(3.0, 5.0))
}
}

test("dense to array") {
val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
assert(vec.toArray.eq(arr))
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/ml/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,14 @@ def __init__(self, size, *args):
SparseVector(4, {1: 1.0, 3: 5.5})
>>> SparseVector(4, [1, 3], [1.0, 5.5])
SparseVector(4, {1: 1.0, 3: 5.5})
>>> SparseVector(4, {1:1.0, 6:2.0})
Traceback (most recent call last):
...
AssertionError: Index 6 is out of the the size of vector with size=4
>>> SparseVector(4, {-1:1.0})
Traceback (most recent call last):
...
AssertionError: Contains negative index -1
"""
self.size = int(size)
""" Size of the vector. """
Expand Down Expand Up @@ -511,6 +519,13 @@ def __init__(self, size, *args):
"Indices %s and %s are not strictly increasing"
% (self.indices[i], self.indices[i + 1]))

if self.indices.size > 0:
assert np.max(self.indices) < self.size, \
"Index %d is out of the the size of vector with size=%d" \
% (np.max(self.indices), self.size)
assert np.min(self.indices) >= 0, \
"Contains negative index %d" % (np.min(self.indices))

def numNonzeros(self):
"""
Number of nonzero elements. This scans all active values and count non zeros.
Expand Down

0 comments on commit 072acf5

Please sign in to comment.