Skip to content

Commit

Permalink
Merge pull request #50 from PolicyEngine/mapping
Browse files Browse the repository at this point in the history
Add mapping logic
  • Loading branch information
nikhilwoodruff committed Apr 24, 2022
2 parents 323d6a5 + 807673b commit dede391
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.10.1] - 2022-04-24

### Fixed

* `map_to` is correctly used in `deriv` calls.

## [0.10.0] - 2022-04-18

### Fixed
Expand Down
66 changes: 65 additions & 1 deletion openfisca_tools/hypothetical.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def __init__(self, reform: ReformType = (), year: int = 2021) -> None:
}
self.varying = False
self.num_points = None
self.group_entity_names = [
entity.key
for entity in self.system.entities
if not entity.is_person
]

# Add add_entity functions

Expand Down Expand Up @@ -164,6 +169,51 @@ def get_entity(self, name: str) -> Entity:
][0]
return entity_type

def map_to(
self, arr: np.array, entity: str, target_entity: str, how: str = None
):
"""Maps values from one entity to another.
Args:
arr (np.array): The values in their original position.
entity (str): The source entity.
target_entity (str): The target entity.
how (str, optional): A function to use when mapping. Defaults to None.
Raises:
ValueError: If an invalid (dis)aggregation function is passed.
Returns:
np.array: The mapped values.
"""
entity_pop = self.simulation.populations[entity]
target_pop = self.simulation.populations[target_entity]
if entity == "person" and target_entity in self.group_entity_names:
if how and how not in (
"sum",
"any",
"min",
"max",
"all",
"value_from_first_person",
):
raise ValueError("Not a valid function.")
return target_pop.__getattribute__(how or "sum")(arr)
elif entity in self.group_entity_names and target_entity == "person":
if not how:
return entity_pop.project(arr)
if how == "mean":
return entity_pop.project(arr / entity_pop.nb_persons())
elif entity == target_entity:
return arr
else:
return self.map_to(
self.map_to(arr, entity, "person", how="mean"),
"person",
target_entity,
how="sum",
)

def get_group(self, entity: str, name: str) -> str:
"""Gets the name of the containing entity for a named entity and group type.
Expand All @@ -188,6 +238,7 @@ def calc(
period: int = None,
target: str = None,
index: int = None,
map_to: str = None,
reform: ReformType = None,
) -> np.array:
"""Calculates the value of a variable, executing any required formulas.
Expand All @@ -207,7 +258,14 @@ def calc(

if self.parametric_vary and reform is None:
results = [
self.calc(var, period, target, index, reform)
self.calc(
var,
period=period,
target=target,
index=index,
map_to=map_to,
reform=reform,
)
for reform in self.parametric_reforms
]
return np.array(results)
Expand All @@ -217,6 +275,9 @@ def calc(
self.apply_reform(reform)
self.build()

if map_to is None and target is not None:
map_to = self.get_entity(target).key

period = period or self.year
entity = self.system.variables[var].entity
if target is not None:
Expand All @@ -230,6 +291,9 @@ def calc(
result = self.sim.calculate_add(var, period)
except:
result = self.simulation.calculate_divide(var, period)
if map_to is not None:
result = self.map_to(result, entity.key, map_to)
entity = self.entities[map_to]
if self.varying:
result = result.reshape(
(self.num_points, len(self.situation_data[entity.plural]))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="OpenFisca-Tools",
version="0.9.1",
version="0.10.1",
author="PolicyEngine",
license="http://www.fsf.org/licensing/licenses/agpl-3.0.html",
url="https://github.com/policyengine/openfisca-tools",
Expand Down

0 comments on commit dede391

Please sign in to comment.