Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/reject-nan-set-input.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Raised a clear error when numeric simulation inputs contain NaN values.
3 changes: 2 additions & 1 deletion policyengine_core/holders/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def set_input_dispatch_by_period(holder: Holder, period: Period, array: ArrayLik

To read more about ``set_input`` attributes, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#set-input-automatically-process-variable-inputs-defined-for-periods-not-matching-the-definition-period>`_.
"""
array = holder._to_array(array)
array = holder._to_array(array, validate_nan=True)

period_size = period.size
period_unit = period.unit
Expand Down Expand Up @@ -70,6 +70,7 @@ def set_input_divide_by_period(holder: Holder, period: Period, array: ArrayLike)
"""
if not isinstance(array, numpy.ndarray):
array = numpy.array(array)
array = holder._to_array(array, validate_nan=True)
period_size = period.size
period_unit = period.unit

Expand Down
34 changes: 30 additions & 4 deletions policyengine_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def set_input(
return warnings.warn(warning_message, Warning)
if self.variable.value_type in (float, int) and isinstance(array, str):
array = tools.eval_expression(array)
self._raise_if_input_contains_nan(numpy.asarray(array))
simulation = getattr(self, "simulation", None)
if simulation is not None:
if not hasattr(simulation, "_user_input_keys"):
Expand All @@ -263,12 +264,29 @@ def set_input(
and period.unit != self.variable.definition_period
):
return self.variable.set_input(self, period, array)
return self._set(period, array, branch_name)
return self._set(period, array, branch_name, validate_nan=True)
finally:
if simulation is not None:
simulation._user_input_contexts.pop()

def _to_array(self, value: Any) -> ArrayLike:
def _raise_if_input_contains_nan(self, value: ArrayLike) -> None:
if self.variable.value_type not in (float, int):
return
value = numpy.asarray(value)
try:
if value.dtype.kind in ("O", "S", "U"):
value = value.astype(float)
contains_nan = numpy.isnan(value).any()
except (TypeError, ValueError):
return
if contains_nan:
raise ValueError(
'Unable to set value for variable "{}", as the input contains NaN values.'.format(
self.variable.name,
)
)

def _to_array(self, value: Any, validate_nan: bool = False) -> ArrayLike:
if not isinstance(value, numpy.ndarray):
value = numpy.asarray(value)
if value.ndim == 0:
Expand All @@ -284,6 +302,8 @@ def _to_array(self, value: Any) -> ArrayLike:
self.population.entity.plural,
)
)
if validate_nan:
self._raise_if_input_contains_nan(value)
if self.variable.value_type == Enum:
original_value = value
value = self.variable.possible_values.encode(value)
Expand All @@ -301,16 +321,22 @@ def _to_array(self, value: Any) -> ArrayLike:
value.dtype,
)
)
if validate_nan:
self._raise_if_input_contains_nan(value)
return value

def _set(
self, period: Period, value: ArrayLike, branch_name: str = "default"
self,
period: Period,
value: ArrayLike,
branch_name: str = "default",
validate_nan: bool = False,
) -> None:
simulation = getattr(self, "simulation", None)
user_input_contexts = getattr(simulation, "_user_input_contexts", None)
if user_input_contexts and branch_name == "default":
branch_name = user_input_contexts[-1]
value = self._to_array(value)
value = self._to_array(value, validate_nan=validate_nan)
if self.variable.definition_period != periods.ETERNITY:
if period is None:
raise ValueError(
Expand Down
63 changes: 63 additions & 0 deletions tests/core/test_holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,66 @@ def test_set_input_float_to_int(single):
simulation.person.get_holder("age").set_input(period, age)
result = simulation.calculate("age", period)
assert result == numpy.asarray([50])


def test__given_nan_float_array__then_set_input_raises_value_error(single):
    """A NaN float array passed to ``set_input`` for "salary" raises ValueError."""
    nan_salary = numpy.asarray([numpy.nan])

    with pytest.raises(ValueError, match='variable "salary".*NaN'):
        single.set_input("salary", period, nan_salary)


def test__given_nan_int_array__then_set_input_raises_value_error(single):
    """A NaN array passed to ``set_input`` for "age" raises ValueError."""
    nan_age = numpy.asarray([numpy.nan])

    with pytest.raises(ValueError, match='variable "age".*NaN'):
        single.set_input("age", period, nan_age)


def test__given_object_array_containing_nan__then_set_input_raises_value_error(
    single,
):
    """An object-dtype array holding NaN is rejected by ``set_input`` for "age"."""
    object_nan_age = numpy.asarray([numpy.nan], dtype=object)

    with pytest.raises(ValueError, match='variable "age".*NaN'):
        single.set_input("age", period, object_nan_age)


def test__given_nan_yearly_input__then_set_input_divide_by_period_raises_value_error(
    single,
):
    """``set_input_divide_by_period`` raises ValueError for a NaN 2017 salary input."""
    holder = single.person.get_holder("salary")
    year = periods.period(2017)
    nan_values = numpy.asarray([numpy.nan])

    with pytest.raises(ValueError, match='variable "salary".*NaN'):
        holders.set_input_divide_by_period(holder, year, nan_values)


def test__given_nan_period_dispatch_input__then_helper_raises_value_error(
    single,
):
    """``set_input_dispatch_by_period`` raises ValueError for a NaN 2017 age input."""
    holder = single.person.get_holder("age")
    year = periods.period(2017)
    nan_values = numpy.asarray([numpy.nan])

    with pytest.raises(ValueError, match='variable "age".*NaN'):
        holders.set_input_dispatch_by_period(holder, year, nan_values)


def test__given_nan_cache_value__then_put_in_cache_keeps_internal_write_allowed(
    single,
):
    """A NaN array written via ``put_in_cache`` is stored and read back as NaN."""
    holder = single.person.get_holder("salary")
    nan_values = numpy.asarray([numpy.nan])

    # No exception is expected here: the test asserts on the stored value.
    holder.put_in_cache(nan_values, period)

    cached = holder.get_array(period)
    assert numpy.isnan(cached).all()
Loading