Skip to content

Commit

Permalink
Merge pull request #45 from PolicyEngine/data
Browse files Browse the repository at this point in the history
Dataset improvements
  • Loading branch information
nikhilwoodruff authored Apr 13, 2022
2 parents 35eb5c4 + 733a94e commit 7404709
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 30 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.9.0] - 2022-04-11

### Added

* Datasets can now input variables at multiple time periods.

## [0.8.0] - 2022-04-10

### Added
Expand Down
1 change: 1 addition & 0 deletions openfisca_tools/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from openfisca_tools.data.private import PrivateDataset
from openfisca_tools.data.public import PublicDataset
from openfisca_tools.data.dataset import Dataset
from openfisca_tools.data.cli import openfisca_data_cli
6 changes: 4 additions & 2 deletions openfisca_tools/data/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def openfisca_data_cli(datasets: List[Dataset]):
try:
target = getattr(datasets[args.dataset], args.action)
if callable(target):
target(*args.args)
result = target(*args.args)
else:
return target
result = target
if result is not None:
print(result)
except Exception as e:
print("\n\nEncountered an error:")
raise e
Expand Down
45 changes: 42 additions & 3 deletions openfisca_tools/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Callable, Union
import logging
import os
from pathlib import Path
Expand All @@ -22,6 +22,7 @@ class Dataset:
# Data formats
TABLES = "tables"
ARRAYS = "arrays"
TIME_PERIOD_ARRAYS = "time_period_arrays"

def __init__(self):
# Setup dataset
Expand Down Expand Up @@ -56,8 +57,23 @@ def __init__(self):
assert self.data_format in [
Dataset.TABLES,
Dataset.ARRAYS,
Dataset.TIME_PERIOD_ARRAYS,
], "You tried to instantiate a Dataset object, but your data_format attribute is invalid."

# Ensure typed arguments are enforced in `generate`

def cast_first_arg_as_int(fn: Callable) -> Callable:
def wrapper(*args, **kwargs):
args = list(args)
args[0] = int(args[0])
return fn(*args, **kwargs)

return wrapper

self.generate = cast_first_arg_as_int(self.generate)
self.download = cast_first_arg_as_int(self.download)
self.upload = cast_first_arg_as_int(self.upload)

def filename(self, year: int) -> str:
"""Returns the filename of the dataset for a given year.
Expand Down Expand Up @@ -95,7 +111,7 @@ def load(
Union[h5py.File, np.array, pd.DataFrame, pd.HDFStore]: The dataset.
"""
file = self.folder_path / self.filename(year)
if self.data_format == Dataset.ARRAYS:
if self.data_format in (Dataset.ARRAYS, Dataset.TIME_PERIOD_ARRAYS):
if key is None:
# If no key provided, return the basic H5 reader.
return h5py.File(file, mode=mode)
Expand All @@ -118,6 +134,29 @@ def load(
f"Invalid data format {self.data_format} for dataset {self.label}."
)

def save(self, year: int, key: str, values: Union[np.array, pd.DataFrame]):
"""Overwrites the values for `key` with `values`.
Args:
year (int): The year of the dataset to save.
key (str): The key to save.
values (Union[np.array, pd.DataFrame]): The values to save.
"""
file = self.folder_path / self.filename(year)
if self.data_format in (Dataset.ARRAYS, Dataset.TIME_PERIOD_ARRAYS):
with h5py.File(file, "a") as f:
# Overwrite if existing
if key in f:
del f[key]
f.create_dataset(key, data=values)
elif self.data_format == Dataset.TABLES:
with pd.HDFStore(file, "a") as f:
f.put(key, values)
else:
raise ValueError(
f"Invalid data format {self.data_format} for dataset {self.label}."
)

def keys(self, year: int):
"""Returns the keys of the dataset for a given year.
Expand All @@ -127,7 +166,7 @@ def keys(self, year: int):
Returns:
list: The keys of the dataset.
"""
if self.data_format == Dataset.ARRAYS:
if self.data_format in (Dataset.ARRAYS, Dataset.TIME_PERIOD_ARRAYS):
with h5py.File(self.file(year), mode="r") as f:
return list(f.keys())
elif self.data_format == Dataset.TABLES:
Expand Down
67 changes: 47 additions & 20 deletions openfisca_tools/microsimulation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Microsimulation interfaces and utility functions.
"""
import logging
from re import S
from typing import Callable, List, Tuple
from openfisca_core.entities.entity import Entity
Expand All @@ -11,6 +12,7 @@
from openfisca_core.simulation_builder import SimulationBuilder
from microdf import MicroSeries
from openfisca_core.taxbenefitsystems import TaxBenefitSystem
from openfisca_tools.data.dataset import Dataset

from openfisca_tools.model_api import carried_over, ReformType

Expand Down Expand Up @@ -115,19 +117,31 @@ def apply(s):
builder = SimulationBuilder()
builder.create_entities(self.system)

if not hasattr(dataset, "data_format"):
dataset.data_format = Dataset.ARRAYS

key_suffix = (
f"/{year}"
if dataset.data_format == Dataset.TIME_PERIOD_ARRAYS
else ""
)

for person_entity in self.person_entity_names:
builder.declare_person_entity(
person_entity, np.array(data[f"{person_entity}_id"])
person_entity,
np.array(data[f"{person_entity}_id{key_suffix}"]),
)

for group_entity in self.group_entity_names:
primary_keys = np.array(data[f"{group_entity}_id"])
primary_keys = np.array(data[f"{group_entity}_id{key_suffix}"])
group = builder.declare_entity(group_entity, primary_keys)
foreign_keys = np.array(data[f"person_{group_entity}_id"])
if f"person_{group_entity}_role" in data.keys():
roles = np.array(data[f"person_{group_entity}_role"]).astype(
str
)
foreign_keys = np.array(
data[f"person_{group_entity}_id{key_suffix}"]
)
if f"person_{group_entity}_role{key_suffix}" in data.keys():
roles = np.array(
data[f"person_{group_entity}_role{key_suffix}"]
).astype(str)
elif "role" in data.keys():
roles = np.array(data["role"]).astype(str)
else:
Expand All @@ -138,22 +152,35 @@ def apply(s):

self.simulation = builder.build(self.system)
self.simulation.max_spiral_loops = 10
self.set_input = self.simulation.set_input
skipped = []
for variable in data.keys():
if variable in self.system.variables:
values = np.array(data[variable])
target_dtype = self.system.variables[variable].value_type
if target_dtype in (Enum, str):
values = values.astype(str)
else:
values = values.astype(target_dtype)
if dataset.data_format == Dataset.TIME_PERIOD_ARRAYS:
for variable in data.keys():
for period in data[variable].keys():
try:
self.set_input(
variable, period, data[variable][period]
)
except Exception as e:
logging.warn(
f"Could not set {variable} for {period}: {e}"
)
else:
for variable in data.keys():
try:
self.simulation.set_input(variable, year, values)
except:
skipped += [variable]
self.set_input(variable, year, data[variable])
except Exception as e:
logging.warn(f"Could not set {variable} for {period}: {e}")
data.close()

def set_input(self, variable: str, year: int, values: np.ndarray) -> None:
if variable in self.system.variables:
values = np.array(values)
target_dtype = self.system.variables[variable].value_type
if target_dtype in (Enum, str):
values = values.astype(str)
else:
values = values.astype(target_dtype)
self.simulation.set_input(variable, year, values)

def map_to(
self, arr: np.array, entity: str, target_entity: str, how: str = None
):
Expand Down
10 changes: 6 additions & 4 deletions openfisca_tools/model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,12 @@ def formula_start_year(entity, period, parameters):
if by is None:
return entity(variable.__name__, period.last_year)
else:
uprating = (
parameters(period).uprating[by]
/ parameters(period.last_year).uprating[by]
)
current_parameter = parameters(period)
last_year_parameter = parameters(period.last_year)
for name in by.split("."):
current_parameter = getattr(current_parameter, name)
last_year_parameter = getattr(last_year_parameter, name)
uprating = current_parameter / last_year_parameter
old = entity(variable.__name__, period.last_year)
if (formula is not None) and (all(old) == 0):
# If no values have been inputted, don't uprate and
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="OpenFisca-Tools",
version="0.8.0",
version="0.9.0",
author="PolicyEngine",
license="http://www.fsf.org/licensing/licenses/agpl-3.0.html",
url="https://github.com/policyengine/openfisca-tools",
Expand All @@ -13,6 +13,8 @@
"pandas",
"wheel",
"h5py",
"tables",
"google-cloud-storage",
],
extras_require={
"test": [
Expand Down

0 comments on commit 7404709

Please sign in to comment.