-
Notifications
You must be signed in to change notification settings - Fork 1
/
model_api.py
164 lines (121 loc) · 4.39 KB
/
model_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from openfisca_core.model_api import (
DAY,
MONTH,
YEAR,
ETERNITY,
Variable as CoreVariable,
Reform,
max_,
min_,
)
from typing import Callable, Tuple, Union
import numpy as np
# A single reform, or a tuple of reforms of any length (``Tuple[Reform]``
# would only describe a tuple holding exactly one reform).
ReformType = Union[Reform, Tuple[Reform, ...]]
class Variable(CoreVariable):
    """Project-level variable base class built on openfisca-core's ``Variable``.

    The core initialiser is invoked explicitly; a ``ValueError`` whose message
    mentions ``"metadata"`` is tolerated, any other ``ValueError`` propagates.
    Every instance ends construction with ``is_neutralized`` set to False.
    """

    def __init__(self, baseline_variable=None):
        try:
            CoreVariable.__init__(self, baseline_variable=baseline_variable)
        except ValueError as error:
            # NOTE(review): presumably core rejects some metadata declared on
            # project variables; the exact trigger is not visible here.
            tolerated = "metadata" in str(error)
            if not tolerated:
                raise error
        self.is_neutralized = False
# Seed numpy's global (legacy) random generator at import time so draws are
# reproducible from run to run.
np.random.seed(seed=0)
def add(entity, period, variable_names, options=None):
    """Return the element-wise total of several variables for one entity.

    Args:
        entity (Entity): Either person, benunit or household; it is called as
            ``entity(name, period, options=options)`` for each variable.
        period (Period): The period to calculate over.
        variable_names (list): Names of the variables to add together.
        options (list, optional): ADD, DIVIDE or MATCH, defining the
            behaviour on period mismatch. Defaults to None.

    Returns:
        Array: Array of entity values (the integer 0 when ``variable_names``
        is empty).
    """
    return sum(
        entity(name, period, options=options) for name in variable_names
    )
def aggr(entity, period, variable_names, options=None):
    """Total several member-level variables within each group.

    For every variable, the values of the group's members are fetched with
    ``entity.members(...)`` and collapsed per group with ``entity.sum``; the
    per-variable group totals are then added together.

    Args:
        entity (Entity): Either benunit or household.
        period (Period): The period to calculate over.
        variable_names (list): Names of member-level variables.
        options (list, optional): ADD, DIVIDE or MATCH, defining the
            behaviour on period mismatch. Defaults to None.

    Returns:
        Array: Array of entity values (0 when ``variable_names`` is empty).
    """
    total = 0
    for name in variable_names:
        member_values = entity.members(name, period, options=options)
        total = total + entity.sum(member_values)
    return total
def aggr_max(entity, period, variable_names, options=None):
    """Add up, over several variables, each variable's maximum within a group.

    For every name in ``variable_names`` the member-level values are reduced
    to one value per group with ``entity.max``; those per-variable group
    maxima are then summed. With a single variable name this is simply the
    group maximum of that variable.

    Args:
        entity (Entity): Either benunit or household.
        period (Period): The period to calculate over.
        variable_names (list): Names of member-level variables.
        options (list, optional): ADD, DIVIDE or MATCH, defining the
            behaviour on period mismatch. Defaults to None.

    Returns:
        Array: Array of entity values (0 when ``variable_names`` is empty).
    """
    group_maxima = (
        entity.max(entity.members(name, period, options=options))
        for name in variable_names
    )
    return sum(group_maxima)
def select(conditions, choices, default=0):
    """Pick, element-wise, the choice paired with the first true condition.

    Args:
        conditions (list): A list of boolean arrays.
        choices (list): A list of arrays, one per condition.
        default (scalar or array, optional): Value used wherever no condition
            is true. Defaults to 0, which matches ``numpy.select`` and the
            previous behaviour of this helper.

    Returns:
        Array: Array of values.
    """
    return np.select(conditions, choices, default=default)
# numpy helpers re-exported for formula authors.
clip, inf = np.clip, np.inf
# Whole-number calendar constants.
MONTHS_IN_YEAR = 12
WEEKS_IN_YEAR = 52
def amount_over(amount, threshold):
    """Return the part of ``amount`` above ``threshold``, floored at zero."""
    excess = amount - threshold
    return max_(0, excess)
def amount_between(amount, threshold_1, threshold_2):
    """Return how much of ``amount`` lies in the band from threshold_1 to threshold_2."""
    capped = clip(amount, threshold_1, threshold_2)
    return capped - threshold_1
def random(entity, reset=True):
    """Draw ``entity.count`` uniform samples from numpy's global generator.

    When ``reset`` is true the global generator is re-seeded with 0 after the
    draw, so the next call starts again from the seed-0 sequence.
    """
    draws = np.random.rand(entity.count)
    if not reset:
        return draws
    np.random.seed(0)
    return draws
def is_in(values, *targets):
    """Return a boolean mask: True where ``values`` equals any of ``targets``.

    Matches are combined with a logical OR rather than summed. The previous
    summation produced an integer count, so duplicate targets yielded 2. It
    also broke ``~is_in(...)``: ``~1 == -2`` and ``~0 == -1`` are both truthy.
    Element-wise ``==`` is kept so array types that customise equality still
    work.

    Args:
        values: Array (or scalar) to test.
        *targets: Values to compare against.

    Returns:
        Boolean array shaped like ``values`` (all False if no targets given).
    """
    mask = np.zeros(np.shape(values), dtype=bool)
    for target in targets:
        mask = np.logical_or(mask, values == target)
    return mask
def uprated(by: str = None, start_year: int = 2015) -> Callable:
    """Build a class decorator that adds a carry-forward formula to a variable.

    The decorator attaches ``formula_<start_year>``, which returns the
    variable's own value for the previous year. When ``by`` is given, that
    value is multiplied by the ratio of ``parameters(period).uprating[by]`` to
    the same parameter one year earlier. A variable that already defines
    ``formula_<start_year>`` is returned untouched.

    Args:
        by (str, optional): Key under ``parameters.uprating`` supplying the
            uprating index. Defaults to None (value carried over unchanged).
        start_year (int, optional): Year used in the attached formula's name,
            i.e. the earliest year it applies from. Defaults to 2015.

    Returns:
        Callable: A class decorator that mutates and returns the class.
    """
    formula_name = f"formula_{start_year}"

    def uprater(variable: type) -> type:
        # Never overwrite a formula the variable already defines.
        if hasattr(variable, formula_name):
            return variable

        def formula(entity, period, parameters):
            if by is not None:
                growth = (
                    parameters(period).uprating[by]
                    / parameters(period.last_year).uprating[by]
                )
                return growth * entity(variable.__name__, period.last_year)
            return entity(variable.__name__, period.last_year)

        formula.__name__ = formula_name
        setattr(variable, formula_name, formula)
        return variable

    return uprater
def carried_over(variable: type) -> type:
    """Attach a formula that repeats last year's value unchanged.

    Equivalent to decorating with ``uprated()`` (no uprating parameter,
    default start year).
    """
    decorate = uprated()
    return decorate(variable)