Skip to content

Commit

Permalink
Merge pull request #24 from PolicyEngine/fixes
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
nikhilwoodruff committed Feb 10, 2022
2 parents 678e7f7 + e561f8d commit 2634e93
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 211 deletions.
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,27 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.0] - 2022-02-09

### Added

* `GeneralMicrosimulation` alias for `Microsimulation`.
* `between(values, lower, upper)` function for variable formulas.
* `and_`, `or_` and `multiply_` functions for variable formulas.
* `any_` and `any_of_variables` helper functions for variable formulas.

### Changed

* `is_in` and `amount_over` functions deprecated.
* `select` is now an alias for `np.select`.
* `is_in` will work for both `list` and `*args` inputs.
* `household_net_income` used instead of `net_income` for parameter tests.


### Fixed

* Parameter nesting function bug fixes.

## [0.2.3] - 2022-01-26

### Fixed
Expand Down
5 changes: 4 additions & 1 deletion openfisca_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from openfisca_tools.microsimulation import Microsimulation
from openfisca_tools.microsimulation import (
Microsimulation,
GeneralMicrosimulation,
)
from openfisca_tools.hypothetical import IndividualSim
from openfisca_tools.testing import generate_tests
from openfisca_tools.model_api import *
Expand Down
3 changes: 3 additions & 0 deletions openfisca_tools/microsimulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,6 @@ def deriv_df(
target, wrt=wrt, delta=delta, percent=percent
)
return df


# Alias: exposes the Microsimulation class under the name GeneralMicrosimulation.
GeneralMicrosimulation = Microsimulation
243 changes: 194 additions & 49 deletions openfisca_tools/model_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging
from turtle import pd
from openfisca_core.model_api import (
DAY,
MONTH,
Expand All @@ -8,8 +10,13 @@
max_,
min_,
)
from typing import Callable, Tuple, Type, Union
from openfisca_core.populations import Population
from openfisca_core.entities import Entity
from typing import Callable, List, Tuple, Type, Union
import numpy as np
from numpy.typing import ArrayLike
from pandas import Period
from itertools import product

# A reform is given either as a single Reform or as a tuple of reforms.
# `Tuple[Reform, ...]` is the variable-length form; `Tuple[Reform]` would
# only describe a tuple holding exactly one reform.
ReformType = Union[Reform, Tuple[Reform, ...]]

Expand Down Expand Up @@ -40,78 +47,138 @@ def __init__(self, baseline_variable=None):
np.random.seed(0)


def for_each_variable(
    entity: Population,
    period: Period,
    variables: List[str],
    agg_func: str = "add",
    group_agg_func: str = "add",
    options: List[str] = None,
) -> ArrayLike:
    """Applies operations to lists of variables.

    Each variable is computed for ``entity`` directly when it is defined on
    the same entity; otherwise it is computed for the entity's members and
    reduced with ``group_agg_func``. The per-variable results are then
    combined pairwise, in list order, with ``agg_func``.

    Args:
        entity (Population): The entity population, as passed in formulas.
        period (Period): The period, as passed in formulas.
        variables (List[str]): A list of variable names.
        agg_func (str, optional): The operation used to combine variable results: "add", "multiply", "max" or "min". Defaults to "add".
        group_agg_func (str, optional): The operation used to reduce member values to the target entity level: "add", "all", "max" or "min". Defaults to "add".
        options (List[str], optional): Options to pass to the `entity(variable, period)` call. Defaults to None.

    Raises:
        KeyError: If `agg_func` (or, for a non-person entity, `group_agg_func`) is not one of the supported names.
        ValueError: If a variable defined on another entity cannot be computed from the entity's members.

    Returns:
        ArrayLike: The result of the operation, or None if `variables` is empty.
    """
    combine = dict(
        add=lambda x, y: x + y,
        multiply=lambda x, y: x * y,
        max=max_,
        min=min_,
    )[agg_func]
    if not entity.entity.is_person:
        group_agg_func = dict(
            add=entity.sum, all=entity.all, max=entity.max, min=entity.min
        )[group_agg_func]
    # NOTE: when `entity.entity.is_person` is true, `group_agg_func` is left as
    # the original string, so the member branch below fails and is reported
    # as a ValueError.
    result = None
    for variable in variables:
        variable_entity = entity.entity.get_variable(variable).entity
        if variable_entity.key == entity.entity.key:
            values = entity(variable, period, options=options)
        else:
            try:
                values = group_agg_func(
                    entity.members(variable, period, options=options)
                )
            except Exception as e:
                # Chain the original error so the root cause stays visible.
                raise ValueError(
                    f"Variable {variable} is not defined for {entity.entity.label} or {entity.entity.label} members: {e}"
                ) from e
        result = values if result is None else combine(result, values)
    return result


def add(
    entity: Population,
    period: Period,
    variables: List[str],
    options: List[str] = None,
) -> ArrayLike:
    """Sums a list of variables.

    Delegates to `for_each_variable` with ``agg_func="add"``; the member
    reduction is left at that function's default.

    Args:
        entity (Population): The entity population, as passed in formulas.
        period (Period): The period, as passed in formulas.
        variables (List[str]): A list of variable names.
        options (List[str], optional): Options to pass to the `entity(variable, period)` call. Defaults to None.

    Raises:
        ValueError: If any target variable is not at or below the target entity level.

    Returns:
        ArrayLike: The result of the operation.
    """
    return for_each_variable(
        entity, period, variables, agg_func="add", options=options
    )


def aggr(
    entity: Population,
    period: Period,
    variables: List[str],
    options: List[str] = None,
) -> ArrayLike:
    """Sums a list of variables belonging to entity members.

    Delegates to `for_each_variable` with both ``agg_func`` and
    ``group_agg_func`` set to "add".

    Args:
        entity (Population): The entity population, as passed in formulas.
        period (Period): The period, as passed in formulas.
        variables (List[str]): A list of variable names.
        options (List[str], optional): Options to pass to the `entity(variable, period)` call. Defaults to None.

    Raises:
        ValueError: If any target variable is not below the target entity level.

    Returns:
        ArrayLike: The result of the operation.
    """
    return for_each_variable(
        entity,
        period,
        variables,
        agg_func="add",
        group_agg_func="add",
        options=options,
    )


def and_(
    entity: Population,
    period: Period,
    variables: List[str],
    options: List[str] = None,
) -> ArrayLike:
    """Performs a logical and operation on a list of variables.

    Implemented as a product: delegates to `for_each_variable` with
    ``agg_func="multiply"``, so for boolean inputs the result is true only
    where every variable is true.

    Args:
        entity (Population): The entity population, as passed in formulas.
        period (Period): The period, as passed in formulas.
        variables (List[str]): A list of variable names.
        options (List[str], optional): Options to pass to the `entity(variable, period)` call. Defaults to None.

    Raises:
        ValueError: If any target variable is not at the target entity level.

    Returns:
        ArrayLike: The result of the operation.
    """
    return for_each_variable(
        entity, period, variables, agg_func="multiply", options=options
    )


# Logical "or" is implemented as a sum of the variables.
or_ = add
any_ = or_
# `and_` combines variables by multiplication, so it doubles as a product.
multiply = and_
# Trailing-underscore spelling, matching `and_`/`or_` and the name used in
# the changelog entry for this release.
multiply_ = and_

# `select` is a plain alias for `np.select`.
select = np.select


# `clip` is a plain alias for `np.clip`.
clip = np.clip
Expand All @@ -121,11 +188,35 @@ def select(conditions, choices):
# Number of months in a year.
MONTHS_IN_YEAR = 12


def amount_over(amount: ArrayLike, threshold: float) -> ArrayLike:
    """Calculates the amounts over a threshold.

    Deprecated: a debug-level log message is emitted on every call pointing
    to `max_(x - y, 0)` as the replacement.

    Args:
        amount (ArrayLike): The amount to calculate for.
        threshold (float): The threshold.

    Returns:
        ArrayLike: The amounts over the threshold (zero where the amount does not exceed it).
    """
    logging.debug(
        "amount_over(x, y) is deprecated, use max_(x - y, 0) instead."
    )
    return max_(0, amount - threshold)


def amount_between(
    amount: ArrayLike, threshold_1: float, threshold_2: float
) -> ArrayLike:
    """Calculates the amounts between two thresholds.

    The amount is clipped to the range [threshold_1, threshold_2] and the
    lower threshold is then subtracted.

    Args:
        amount (ArrayLike): The amount to calculate for.
        threshold_1 (float): The lower threshold.
        threshold_2 (float): The upper threshold.

    Returns:
        ArrayLike: The amounts between the thresholds.
    """
    return clip(amount, threshold_1, threshold_2) - threshold_1


Expand All @@ -136,8 +227,35 @@ def random(entity, reset=True):
return x


def is_in(values: ArrayLike, *targets: object) -> ArrayLike:
    """Returns true if the value is in the list of targets.

    Targets may be given as separate positional arguments
    (``is_in(x, a, b)``) or as one collection (``is_in(x, [a, b])``).

    Args:
        values (ArrayLike): The values to test.
        *targets: The target values, or a single list, tuple or set containing them.

    Returns:
        ArrayLike: Boolean array, true where the value equals at least one target. All false when no targets are given.
    """
    if len(targets) == 1 and isinstance(
        targets[0], (list, tuple, set, frozenset)
    ):
        targets = tuple(targets[0])
    if not targets:
        # Nothing to match: keep the result shaped like the input.
        return np.zeros(np.shape(values), dtype=bool)
    # Compare with `==` per target (rather than np.isin) so that any custom
    # equality defined by the type of `values` is respected.
    return np.any([values == target for target in targets], axis=0)


def between(
    values: ArrayLike, lower: float, upper: float, inclusive: str = "both"
) -> ArrayLike:
    """Returns true if values are between lower and upper.

    Args:
        values (ArrayLike): The input array.
        lower (float): The lower bound.
        upper (float): The upper bound.
        inclusive (str, optional): Which bounds to include, passed through to `pandas.Series.between` ("both", "neither", "left" or "right"). Defaults to "both".

    Returns:
        ArrayLike: A boolean pandas Series, true where the value lies within the bounds.
    """
    # Imported locally: the module-level name `pd` comes from
    # `from turtle import pd` (turtle's pen-down function), not pandas, so
    # `pd.Series` would fail without this.
    import pandas as pd

    return pd.Series(values).between(lower, upper, inclusive=inclusive)


def uprated(by: str = None, start_year: int = 2015) -> Callable:
Expand Down Expand Up @@ -180,3 +298,30 @@ def formula_start_year(entity, period, parameters):

def carried_over(variable: type) -> type:
    """Applies `uprated` with its default arguments to a variable.

    Args:
        variable (type): The variable to transform.

    Returns:
        type: The result of calling the decorator returned by `uprated()` on the variable.
    """
    decorator = uprated()
    return decorator(variable)


def sum_of_variables(variables: Union[List[str], str]) -> Callable:
    """Builds a formula function that adds up a set of variables.

    Args:
        variables (Union[List[str], str]): Either a list of variable names, or a dot-separated parameter path. A path is resolved attribute by attribute on `parameters(period)` each time the formula runs, and the resulting node is used as the list of variable names.

    Returns:
        Callable: A function `(entity, period, parameters)` returning `add(entity, period, names)`.
    """

    def sum_of_variables(entity, period, parameters):
        if not isinstance(variables, str):
            # An explicit list of names: no parameter lookup is needed.
            return add(entity, period, variables)
        # A string is a parameter path: walk it one attribute at a time.
        current = parameters(period)
        for attribute in variables.split("."):
            current = getattr(current, attribute)
        return add(entity, period, current)

    return sum_of_variables


# Alias of sum_of_variables: the generated formula returns the sum of the
# variables rather than a boolean array.
any_of_variables = sum_of_variables
Loading

0 comments on commit 2634e93

Please sign in to comment.