Skip to content

Commit

Permalink
Bounds checking
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Apr 25, 2018
1 parent 449983b commit c449afd
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 20 deletions.
22 changes: 7 additions & 15 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,24 +1488,16 @@ def take(arr, indexer, allow_fill=False, fill_value=None):
--------
numpy.take
"""
indexer = np.asarray(indexer)
from pandas.core.indexing import validate_indices

# Do we require int64 here?
indexer = np.asarray(indexer, dtype='int')

if allow_fill:
# Pandas style, -1 means NA
# bounds checking
if (indexer < -1).any():
raise ValueError("Invalid value in 'indexer'. All values "
"must be non-negative or -1. When "
"'fill_value' is specified.")
if (indexer > len(arr)).any():
# TODO: reuse with logic elsewhere.
raise IndexError

# # take on empty array not handled as desired by numpy
# # in case of -1 (all missing take)
# if not len(arr) and mask.all():
# return arr._from_sequence([fill_value] * len(indexer))
result = take_1d(arr, indexer, fill_value=fill_value)
# Use for bounds checking, we don't actually want to convert.
validate_indices(indexer, len(arr))
result = take_1d(arr, indexer, allow_fill=True, fill_value=fill_value)
else:
# NumPy style
result = arr.take(indexer)
Expand Down
41 changes: 41 additions & 0 deletions pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,12 +2417,53 @@ def maybe_convert_indices(indices, n):
mask = indices < 0
if mask.any():
indices[mask] += n

mask = (indices >= n) | (indices < 0)
if mask.any():
raise IndexError("indices are out-of-bounds")
return indices


def validate_indices(indices, n):
"""Perform bounds-checking for an indexer.
-1 is allowed for indicating missing values.
Parameters
----------
indices : ndarray
n : int
length of the array being indexed
Raises
------
ValueError
Examples
--------
>>> validate_indices([1, 2], 3)
# OK
>>> validate_indices([1, -2], 3)
ValueError
>>> validate_indices([1, 2, 3], 3)
IndexError
>>> validate_indices([-1, -1], 0)
# OK
>>> validate_indices([0, 1], 0)
IndexError
"""
if len(indices):
min_idx = indices.min()
if min_idx < -1:
msg = ("'indices' contains values less than allowed ({} < {})"
.format(min_idx, -1))
raise ValueError(msg)

max_idx = indices.max()
if max_idx >= n:
raise IndexError("indices are out-of-bounds")


def maybe_convert_ix(*args):
"""
We likely want to take the cross-product
Expand Down
14 changes: 12 additions & 2 deletions pandas/tests/extension/base/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,12 @@ def test_take(self, data, na_value, na_cmp):

def test_take_empty(self, data, na_value, na_cmp):
empty = data[:0]
# result = empty.take([-1])
# na_cmp(result[0], na_value)

result = empty.take([-1], allow_fill=True)
na_cmp(result[0], na_value)

with pytest.raises(IndexError):
empty.take([-1])

with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"):
empty.take([0, 1])
Expand All @@ -160,6 +164,12 @@ def test_take_pandas_style_negative_raises(self, data, na_value):
with pytest.raises(ValueError):
data.take([0, -2], fill_value=na_value, allow_fill=True)

@pytest.mark.parametrize('allow_fill', [True, False])
def test_take_out_of_bounds_raises(self, data, allow_fill):
arr = data[:3]
with pytest.raises(IndexError):
arr.take(np.asarray([0, 3]), allow_fill=allow_fill)

@pytest.mark.xfail(reason="Series.take with extension array buggy for -1")
def test_take_series(self, data):
s = pd.Series(data)
Expand Down
4 changes: 4 additions & 0 deletions pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def test_take_pandas_style_negative_raises(self):
def test_take_non_na_fill_value(self):
pass

@skip_take
def test_take_out_of_bounds_raises(self):
pass

@pytest.mark.xfail(reason="Categorical.take buggy")
def test_take_empty(self):
pass
Expand Down
8 changes: 6 additions & 2 deletions pandas/tests/extension/json/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
class JSONDtype(ExtensionDtype):
type = collections.Mapping
name = 'json'
na_value = collections.UserDict()
try:
na_value = collections.UserDict()
except AttributeError:
# source compatibility with Py2.
na_value = {}

@classmethod
def construct_from_string(cls, string):
Expand Down Expand Up @@ -112,7 +116,7 @@ def take(self, indexer, allow_fill=False, fill_value=None):
output = [self.data[loc] if loc != -1 else fill_value
for loc in indexer]
except IndexError:
raise msg
raise IndexError(msg)
else:
try:
output = [self.data[loc] for loc in indexer]
Expand Down
27 changes: 26 additions & 1 deletion pandas/tests/indexing/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import numpy as np

import pandas as pd
from pandas.core.indexing import _non_reducing_slice, _maybe_numeric_slice
from pandas.core.indexing import (_non_reducing_slice, _maybe_numeric_slice,
validate_indices)
from pandas import NaT, DataFrame, Index, Series, MultiIndex
import pandas.util.testing as tm

Expand Down Expand Up @@ -994,3 +995,27 @@ def test_none_coercion_mixed_dtypes(self):
datetime(2000, 1, 3)],
'd': [None, 'b', 'c']})
tm.assert_frame_equal(start_dataframe, exp)


def test_validate_indices_ok():
indices = np.asarray([0, 1])
validate_indices(indices, 2)
validate_indices(indices[:0], 0)
validate_indices(np.array([-1, -1]), 0)


def test_validate_indices_low():
indices = np.asarray([0, -2])
with tm.assert_raises_regex(ValueError, "'indices' contains"):
validate_indices(indices, 2)


def test_validate_indices_high():
indices = np.asarray([0, 1, 2])
with tm.assert_raises_regex(IndexError, "indices are out"):
validate_indices(indices, 2)


def test_validate_indices_empty():
with tm.assert_raises_regex(IndexError, "indices are out"):
validate_indices(np.array([0, 1]), 0)
36 changes: 36 additions & 0 deletions pandas/tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,3 +1564,39 @@ def test_index(self):
idx = Index(['1 day', '1 day', '-1 day', '-1 day 2 min',
'2 min', '2 min'], dtype='timedelta64[ns]')
tm.assert_series_equal(algos.mode(idx), exp)


class TestTake(object):

def test_bounds_check_large(self):
arr = np.array([1, 2])
with pytest.raises(IndexError):
algos.take(arr, [2, 3], allow_fill=True)

with pytest.raises(IndexError):
algos.take(arr, [2, 3], allow_fill=False)

def test_bounds_check_small(self):
arr = np.array([1, 2, 3], dtype=np.int64)
indexer = [0, -1, -2]
with pytest.raises(ValueError):
algos.take(arr, indexer, allow_fill=True)

result = algos.take(arr, indexer)
expected = np.array([1, 3, 2], dtype=np.int64)
tm.assert_numpy_array_equal(result, expected)

@pytest.mark.parametrize('allow_fill', [True, False])
def test_take_empty(self, allow_fill):
arr = np.array([], dtype=np.int64)
# empty take is ok
result = algos.take(arr, [], allow_fill=allow_fill)
tm.assert_numpy_array_equal(arr, result)

with pytest.raises(IndexError):
algos.take(arr, [0], allow_fill=allow_fill)

def test_take_na_empty(self):
result = algos.take([], [-1, -1], allow_fill=True, fill_value=0)
expected = np.array([0, 0], dtype=np.int64)
tm.assert_numpy_array_equal(result, expected)

0 comments on commit c449afd

Please sign in to comment.