forked from PSLmodels/Tax-Calculator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
parameters.py
519 lines (459 loc) · 19.6 KB
/
parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
"""
Tax-Calculator abstract base parameters class.
"""
import os
import json
import six
import abc
import collections as collect
import numpy as np
from taxcalc.utils import read_egg_json
class ParametersBase(object):
    """
    Inherit from this class for Policy, Behavior, Consumption, Growdiff, and
    other groups of parameters that need to have a set_year method.
    Override this __init__ method and DEFAULTS_FILENAME.
    """

    # NOTE: a ``__metaclass__`` class attribute is honored only by Python 2;
    # under Python 3 it is an ordinary class attribute with no effect.
    __metaclass__ = abc.ABCMeta

    # Name of the JSON file (located in this module's directory) holding the
    # default parameter data; inheriting classes must override it.
    DEFAULTS_FILENAME = None

    @classmethod
    def default_data(cls, metadata=False, start_year=None):
        """
        Return parameter data read from the subclass's json file.

        Parameters
        ----------
        metadata: boolean
            if True, return the complete data dictionary for each parameter;
            if False, return only each parameter's 'value' entry.

        start_year: int or None
            if None, use the unmodified contents of DEFAULTS_FILENAME;
            otherwise use data revised so that it begins in start_year.

        Returns
        -------
        params: dictionary of data
        """
        # extract different data from DEFAULTS_FILENAME depending on start_year
        if start_year is None:
            params = cls._params_dict_from_json_file()
        else:
            # JSON_START_YEAR is expected to be defined by the subclass
            nyrs = start_year - cls.JSON_START_YEAR + 1
            ppo = cls(num_years=nyrs)
            ppo.set_year(start_year)
            params = getattr(ppo, '_vals')
            params = ParametersBase._revised_default_data(params, start_year,
                                                          nyrs, ppo)
        # return different data from params dict depending on metadata value
        if metadata:
            return params
        return {name: data['value'] for name, data in params.items()}

    def __init__(self):
        pass

    def initialize(self, start_year, num_years):
        """
        Called from subclass __init__ function.

        Records the year range covered by this object and then expands the
        default values for every parameter via set_default_vals.
        """
        self._current_year = start_year
        self._start_year = start_year
        self._num_years = num_years
        self._end_year = start_year + num_years - 1
        self.set_default_vals()

    def inflation_rates(self):
        """
        Override this method in subclass when appropriate.
        """
        return None

    def wage_growth_rates(self):
        """
        Override this method in subclass when appropriate.
        """
        return None

    def indexing_rates(self, param_name):
        """
        Return appropriate indexing rates for specified param_name.

        The OASDI maximum taxable earnings parameter (_SS_Earnings_c) is
        indexed by wage growth; every other parameter by price inflation.
        """
        if param_name == '_SS_Earnings_c':
            return self.wage_growth_rates()
        return self.inflation_rates()

    def set_default_vals(self):
        """
        Called by initialize method and from some subclass methods.

        For each parameter in self._vals that has a non-empty 'value' list,
        set an attribute (named exactly like the parameter) holding the
        values expanded to self._num_years years, then select the values
        for the start year via set_year.

        Raises
        ------
        ValueError:
            if a parameter name in self._vals is not a string.
        """
        if hasattr(self, '_vals'):
            for name, data in self._vals.items():
                if not isinstance(name, six.string_types):
                    msg = 'parameter name {} is not a string'
                    raise ValueError(msg.format(name))
                integer_values = data.get('integer_value', None)
                values = data.get('value', None)
                if values:
                    cpi_inflated = data.get('cpi_inflated', False)
                    if cpi_inflated:
                        index_rates = self.indexing_rates(name)
                    else:
                        index_rates = None
                    setattr(self, name,
                            self._expand_array(values, integer_values,
                                               inflate=cpi_inflated,
                                               inflation_rates=index_rates,
                                               num_years=self._num_years))
        self.set_year(self._start_year)

    @property
    def num_years(self):
        """
        ParametersBase class number of parameter years property.
        """
        return self._num_years

    @property
    def current_year(self):
        """
        ParametersBase class current calendar year property.
        """
        return self._current_year

    @property
    def start_year(self):
        """
        ParametersBase class first parameter year property.
        """
        return self._start_year

    @property
    def end_year(self):
        """
        ParametersBase class last parameter year property.
        """
        return self._end_year

    def set_year(self, year):
        """
        Set parameters to their values for the specified calendar year.

        Parameters
        ----------
        year: int
            calendar year for which to set current_year and parameter values

        Raises
        ------
        ValueError:
            if year is not in [start_year, end_year] range.

        Returns
        -------
        nothing: void

        Notes
        -----
        To increment the current year, use the following statement::

            behavior.set_year(behavior.current_year + 1)

        where, in this example, behavior is a Behavior object.
        """
        if year < self.start_year or year > self.end_year:
            msg = 'year {} passed to set_year() must be in [{},{}] range.'
            raise ValueError(msg.format(year, self.start_year, self.end_year))
        self._current_year = year
        year_zero_indexed = year - self._start_year
        if hasattr(self, '_vals'):
            for name in self._vals:
                if isinstance(name, six.string_types):
                    # e.g., attribute _II_em holds all years; II_em holds
                    # the single-year value for the current year
                    arr = getattr(self, name)
                    setattr(self, name[1:], arr[year_zero_indexed])

    # ----- begin private methods of ParametersBase class -----

    @staticmethod
    def _revised_default_data(params, start_year, nyrs, ppo):
        """
        Return revised default parameter data.

        Parameters
        ----------
        params: dictionary of NAME:DATA pairs for each parameter
            as defined in calling default_data staticmethod.

        start_year: int
            as defined in calling default_data staticmethod.

        nyrs: int
            as defined in calling default_data staticmethod.

        ppo: Policy object
            as defined in calling default_data staticmethod.

        Returns
        -------
        params: dictionary of revised parameter data

        Notes
        -----
        This staticmethod is called from default_data staticmethod in
        order to reduce the complexity of the default_data staticmethod.
        The params dictionary is revised in place and also returned.
        """
        start_year_str = '{}'.format(start_year)
        for name, data in params.items():
            data['start_year'] = start_year
            num_values = len(data['value'])
            if num_values <= nyrs:
                # val should be the single start_year value
                rawval = getattr(ppo, name[1:])
                if isinstance(rawval, np.ndarray):
                    val = rawval.tolist()
                else:
                    val = rawval
                data['value'] = [val]
                data['row_label'] = [start_year_str]
            else:  # if num_values > nyrs
                # val should extend beyond the start_year value
                data['value'] = data['value'][(nyrs - 1):]
                data['row_label'] = data['row_label'][(nyrs - 1):]
        return params

    @classmethod
    def _params_dict_from_json_file(cls):
        """
        Read DEFAULTS_FILENAME file and return complete dictionary.

        Parameters
        ----------
        nothing: void

        Returns
        -------
        params: dictionary
            containing complete contents of DEFAULTS_FILENAME file.

        Raises
        ------
        NotImplementedError:
            if the inheriting class did not override DEFAULTS_FILENAME.
        """
        if cls.DEFAULTS_FILENAME is None:
            msg = 'DEFAULTS_FILENAME must be overridden by inheriting class'
            raise NotImplementedError(msg)
        path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                            cls.DEFAULTS_FILENAME)
        if os.path.exists(path):
            with open(path) as pfile:
                # OrderedDict keeps parameters in file order
                params_dict = json.load(pfile,
                                        object_pairs_hook=collect.OrderedDict)
        else:
            # cannot call read_egg_ function in unit tests
            params_dict = read_egg_json(
                cls.DEFAULTS_FILENAME)  # pragma: no cover
        return params_dict

    def _update(self, year_mods):
        """
        Private method used by public implement_reform and update_* methods
        in inheriting classes.

        Parameters
        ----------
        year_mods: dictionary containing a single YEAR:MODS pair
            see Notes below for details on dictionary structure.

        Raises
        ------
        ValueError:
            if year_mods is not a dictionary of the expected structure.

        Returns
        -------
        nothing: void

        Notes
        -----
        This is a private method that should **never** be used by clients
        of the inheriting classes.  Instead, always use the public
        implement_reform or update_behavior methods.
        This is a private method that helps the public methods work.

        This method implements a policy reform or behavior modification,
        the provisions of which are specified in the year_mods dictionary,
        that changes the values of some policy parameters in objects of
        inheriting classes.  This year_mods dictionary contains exactly one
        YEAR:MODS pair, where the integer YEAR key indicates the
        calendar year for which the reform provisions in the MODS
        dictionary are implemented.  The MODS dictionary contains
        PARAM:VALUE pairs in which the PARAM is a string specifying
        the policy parameter (as used in the DEFAULTS_FILENAME default
        parameter file) and the VALUE is a Python list of post-reform
        values for that PARAM in that YEAR.  Beginning in the year
        following the implementation of a reform provision, the
        parameter whose value has been changed by the reform continues
        to be inflation indexed, if relevant, or not be inflation indexed
        according to that parameter's cpi_inflated value loaded from
        DEFAULTS_FILENAME.  For a cpi-related parameter, a reform can change
        the indexing status of a parameter by including in the MODS dictionary
        a term that is a PARAM_cpi:BOOLEAN pair specifying the post-reform
        indexing status of the parameter.

        So, for example, to raise the OASDI (i.e., Old-Age, Survivors,
        and Disability Insurance) maximum taxable earnings beginning
        in 2018 to $500,000 and to continue indexing it in subsequent
        years as in current-law policy, the YEAR:MODS dictionary would
        be as follows::

            {2018: {"_SS_Earnings_c":[500000]}}

        But to raise the maximum taxable earnings in 2018 to $500,000
        without any indexing in subsequent years, the YEAR:MODS
        dictionary would be as follows::

            {2018: {"_SS_Earnings_c":[500000], "_SS_Earnings_c_cpi":False}}

        And to raise in 2019 the starting AGI for EITC phaseout for
        married filing jointly filing status (which is a two-dimensional
        policy parameter that varies by the number of children from zero
        to three or more and is inflation indexed), the YEAR:MODS dictionary
        would be as follows::

            {2019: {"_EITC_ps_MarriedJ":[[8000, 8500, 9000, 9500]]}}

        Notice the pair of double square brackets around the four values
        for 2019.  The one-dimensional parameters above require only a pair
        of single square brackets.

        To model a change in behavior substitution effect, a year_mods dict
        example would be::

            {2014: {'_BE_sub': [0.2, 0.3]}}
        """
        # check YEAR value in the single YEAR:MODS dictionary parameter
        if not isinstance(year_mods, dict):
            msg = 'year_mods is not a dictionary'
            raise ValueError(msg)
        if len(year_mods.keys()) != 1:
            msg = 'year_mods dictionary must contain a single YEAR:MODS pair'
            raise ValueError(msg)
        year = list(year_mods.keys())[0]
        if year != self.current_year:
            msg = 'YEAR={} in year_mods is not equal to current_year={}'
            raise ValueError(msg.format(year, self.current_year))
        # check that MODS is a dictionary
        mods = year_mods[year]
        if not isinstance(mods, dict):
            msg = 'mods in year_mods is not a dictionary'
            raise ValueError(msg)
        # implement reform provisions included in the single YEAR:MODS pair
        num_years_to_expand = (self.start_year + self.num_years) - year
        first_index = year - self.start_year
        all_names = set(mods.keys())  # no duplicate keys in a dict
        used_names = set()  # set of used parameter names in MODS dict
        for name, values in mods.items():
            # determine indexing status of parameter with name for year
            if name.endswith('_cpi'):
                continue  # handle elsewhere in this method
            if name not in self._vals:
                msg = 'parameter name {} not in parameter values dictionary'
                raise ValueError(msg.format(name))
            vals_indexed = self._vals[name].get('cpi_inflated', False)
            integer_values = self._vals[name].get('integer_value')
            name_plus_cpi = name + '_cpi'
            if name_plus_cpi in mods:
                used_names.add(name_plus_cpi)
                indexed = mods.get(name_plus_cpi)
                self._vals[name]['cpi_inflated'] = indexed  # remember status
            else:
                indexed = vals_indexed
            # set post-reform values of parameter with name
            used_names.add(name)
            cval = getattr(self, name, None)
            index_rates = self._indexing_rates_for_update(name, year,
                                                          num_years_to_expand)
            nval = self._expand_array(values, integer_values,
                                      inflate=indexed,
                                      inflation_rates=index_rates,
                                      num_years=num_years_to_expand)
            cval[first_index:] = nval
        # handle unused parameter names, all of which end in _cpi, but some
        # parameter names ending in _cpi were handled above
        unused_names = all_names - used_names
        for name in unused_names:
            used_names.add(name)
            pname = name[:-4]  # root parameter name
            if pname not in self._vals:
                msg = 'root parameter name {} not in values dictionary'
                raise ValueError(msg.format(pname))
            pindexed = mods[name]
            self._vals[pname]['cpi_inflated'] = pindexed  # remember status
            cval = getattr(self, pname, None)
            # re-expand from the parameter's current-year value
            pvalues = [cval[first_index]]
            # look up indexing rates by the ROOT name (pname), not the
            # '_cpi'-suffixed name, so that _SS_Earnings_c_cpi selects
            # wage growth rates rather than price inflation rates
            index_rates = self._indexing_rates_for_update(pname, year,
                                                          num_years_to_expand)
            integer_values = self._vals[pname].get('integer_value')
            nval = self._expand_array(pvalues, integer_values,
                                      inflate=pindexed,
                                      inflation_rates=index_rates,
                                      num_years=num_years_to_expand)
            cval[first_index:] = nval
        # confirm that all names have been used (internal invariant)
        assert len(used_names) == len(all_names)
        # implement updated parameters for year
        self.set_year(year)

    @staticmethod
    def _expand_array(x, x_dtype_int, inflate, inflation_rates, num_years):
        """
        Private method called only within this abstract base class.
        Dispatch to either _expand_1D or _expand_2D given dimension of x.

        Parameters
        ----------
        x : value to expand
            x must be either a scalar list or a 1D numpy array, or
            x must be either a list of scalar lists or a 2D numpy array

        x_dtype_int : boolean
            True implies dtype=np.int8; False implies dtype=np.float64
            (applied only when x is a list)

        inflate: boolean
            As we expand, inflate values if this is True, otherwise, just copy

        inflation_rates: list of inflation rates
            Annual decimal inflation rates

        num_years: int
            Number of budget years to expand

        Returns
        -------
        expanded numpy array with specified dtype

        Raises
        ------
        ValueError:
            if x is not a list/array or is neither 1D nor 2D.
        """
        if not isinstance(x, (list, np.ndarray)):
            msg = '_expand_array expects x to be a list or numpy array'
            raise ValueError(msg)
        if isinstance(x, list):
            x = np.array(x, np.int8 if x_dtype_int else np.float64)
        if len(x.shape) == 1:
            return ParametersBase._expand_1D(x, inflate, inflation_rates,
                                             num_years)
        elif len(x.shape) == 2:
            return ParametersBase._expand_2D(x, inflate, inflation_rates,
                                             num_years)
        else:
            raise ValueError('_expand_array expects a 1D or 2D array')

    @staticmethod
    def _expand_1D(x, inflate, inflation_rates, num_years):
        """
        Private method called only from _expand_array method.
        Expand the given data x to account for given number of budget years.
        If necessary, pad out additional years by increasing the last given
        year using the given inflation_rates list.

        Inflated values are rounded to two decimals and capped at 9e99.
        If x already covers num_years, it is returned unchanged.
        """
        if not isinstance(x, np.ndarray):
            raise ValueError('_expand_1D expects x to be a numpy array')
        if len(x) >= num_years:
            return x
        ans = np.zeros(num_years, dtype=x.dtype)
        ans[:len(x)] = x
        if inflate:
            extra = []
            cur = x[-1]
            for i in range(0, num_years - len(x)):
                # rate index len(x)-1 grows the last given year's value
                cur *= (1. + inflation_rates[i + len(x) - 1])
                cur = round(cur, 2) if cur < 9e99 else 9e99
                extra.append(cur)
        else:
            extra = [float(x[-1]) for _ in range(1, num_years - len(x) + 1)]
        ans[len(x):] = extra
        return ans

    @staticmethod
    def _expand_2D(x, inflate, inflation_rates, num_years):
        """
        Private method called only from _expand_array method.
        Expand the given data to account for the given number of budget years.
        For 2D arrays, we expand out the number of rows until we have num_years
        number of rows. For each expanded row, we inflate using the given
        inflation rates list.

        Inflated values are rounded to two decimals and capped at 9e99.
        If x already has num_years rows, it is returned unchanged.
        """
        if not isinstance(x, np.ndarray):
            raise ValueError('_expand_2D expects x to be a numpy array')
        if x.shape[0] >= num_years:
            return x
        ans = np.zeros((num_years, x.shape[1]), dtype=x.dtype)
        ans[:len(x), :] = x
        for i in range(x.shape[0], ans.shape[0]):
            for j in range(ans.shape[1]):
                if inflate:
                    cur = (ans[i - 1, j] *
                           (1. + inflation_rates[i - 1]))
                    cur = round(cur, 2) if cur < 9e99 else 9e99
                    ans[i, j] = cur
                else:
                    ans[i, j] = ans[i - 1, j]
        return ans

    def _indexing_rates_for_update(self, param_name,
                                   calyear, num_years_to_expand):
        """
        Private method called only in the _update method.

        Return the num_years_to_expand indexing rates that begin with the
        rate for calendar year calyear, using the same rate selection as
        indexing_rates(param_name); return None when no rates are available.
        """
        rates = self.indexing_rates(param_name)
        # explicit None/len test (instead of truthiness) also accepts
        # numpy arrays, whose truth value is ambiguous
        if rates is None or len(rates) == 0:
            return None
        offset = calyear - self.start_year
        return [rates[offset + i] for i in range(0, num_years_to_expand)]