/
options_utils.py
216 lines (193 loc) · 10.3 KB
/
options_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import warnings
from featuretools import primitives
from featuretools.feature_base import IdentityFeature
from featuretools.variable_types import Discrete
def _get_primitive_options():
    """Return every valid primitive-option key mapped to its value checker.

    Entity-level keys (``*_entities``) are checked with ``list_entity_check``
    (value must be a list of entity ids); variable-level keys (``*_variables``)
    are checked with ``dict_to_list_variable_check`` (value must be a dict of
    entity id -> list of variable names).
    """
    entity_check = list_entity_check
    variable_check = dict_to_list_variable_check
    return dict([
        ('ignore_entities', entity_check),
        ('include_entities', entity_check),
        ('ignore_variables', variable_check),
        ('include_variables', variable_check),
        ('ignore_groupby_entities', entity_check),
        ('include_groupby_entities', entity_check),
        ('ignore_groupby_variables', variable_check),
        ('include_groupby_variables', variable_check),
    ])
def dict_to_list_variable_check(option, es):
    """Validate a variable-level option value and warn about unknown names.

    A well-formed value is a dict whose values are all lists (entity id ->
    variable names). For a well-formed value, a warning is issued for every
    entity id not in ``es`` and for every variable name not in its entity;
    unknown names do not make the value invalid.

    Args:
        option: The user-supplied option value.
        es: Supports ``entity_id in es`` and ``es[entity_id]``; within this
            module it is the entity id -> variable ids dict built by
            ``generate_all_primitive_options``.

    Returns:
        bool: False if ``option`` has the wrong shape, otherwise True.
    """
    well_formed = isinstance(option, dict) and all(
        isinstance(var_names, list) for var_names in option.values())
    if not well_formed:
        return False
    for entity_id, var_names in option.items():
        if entity_id not in es:
            warnings.warn("Entity '%s' not in entityset" % (entity_id))
            continue
        unknown = [name for name in var_names if name not in es[entity_id]]
        for name in unknown:
            warnings.warn("Variable '%s' not in entity '%s'" % (name, entity_id))
    return True
def list_entity_check(option, es):
    """Validate an entity-level option value and warn about unknown entities.

    Args:
        option: The user-supplied option value; must be a list of entity ids.
        es: Supports ``entity_id in es``.

    Returns:
        bool: True if ``option`` is a list (after warning once for each id
        not in ``es``), otherwise False.
    """
    is_list = isinstance(option, list)
    if is_list:
        for missing_id in [entity_id for entity_id in option if entity_id not in es]:
            warnings.warn("Entity '%s' not in entityset" % (missing_id))
    return is_list
def generate_all_primitive_options(all_primitives,
                                   primitive_options,
                                   ignore_entities,
                                   ignore_variables,
                                   es):
    """Build per-primitive option lists and reconcile them with global ignores.

    Args:
        all_primitives (list): Primitives to consider, given as names or as
            objects with a ``name`` attribute.
        primitive_options (dict): User options keyed by primitive name or
            tuple of names; each value is an option dict or a list of them.
        ignore_entities (set): Globally ignored entity ids.
        ignore_variables (dict): Entity id -> set of globally ignored variable
            names. A plain dict or a defaultdict both work.
        es: EntitySet; only ``es.entities`` (each with ``id`` and
            ``variables``, each variable with ``id``) is read here.

    Returns:
        tuple: ``(primitive_options, global_ignore_entities,
        global_ignore_variables)``.
            ``primitive_options`` maps primitive names to lists of
            initialized option dicts (updated in place here); every primitive
            in ``all_primitives`` gets an entry, and those without user
            options get one dict holding the global ``ignore_entities`` /
            ``ignore_variables`` objects themselves.
            The two globals are the inputs minus any entity / variable that
            some primitive's options explicitly include.
    """
    entityset_dict = {entity.id: [variable.id for variable in entity.variables]
                      for entity in es.entities}
    primitive_options = _init_primitive_options(primitive_options, entityset_dict)
    global_ignore_entities = ignore_entities
    # shallow copy: values are replaced (never mutated) below, so the
    # caller's sets stay untouched
    global_ignore_variables = ignore_variables.copy()
    # for now, only use primitive names as option keys
    for primitive in all_primitives:
        if not isinstance(primitive, str):
            primitive = primitive.name
        if primitive in primitive_options:
            # Reconcile global options with individually-specified options
            options = primitive_options[primitive]
            # Entities this primitive explicitly includes, either wholesale or
            # by listing some of their variables.
            included_entities = set()
            for option in options:
                included_entities.update(option.get('include_entities', ()))
                included_entities.update(option.get('include_variables', {}))
            global_ignore_entities = global_ignore_entities.difference(included_entities)
            for option in options:
                # don't globally ignore a variable if it's included for a primitive
                for entity, include_vars in option.get('include_variables', {}).items():
                    # .get() so a plain-dict ignore_variables without this
                    # entity doesn't raise KeyError (a defaultdict already
                    # yielded an empty set here)
                    global_ignore_variables[entity] = \
                        global_ignore_variables.get(entity, set()).difference(include_vars)
                option['ignore_entities'] = option['ignore_entities'].union(
                    ignore_entities.difference(included_entities)
                )
            for entity, ignore_vars in ignore_variables.items():
                for option in options:
                    if entity in option['ignore_variables']:
                        # already ignoring variables for this entity: add the globals
                        option['ignore_variables'][entity] = \
                            option['ignore_variables'][entity].union(ignore_vars)
                    elif entity not in included_entities:
                        # entity not explicitly included: keep the global option
                        option['ignore_variables'][entity] = ignore_vars
                    # otherwise the entity is explicitly included for this
                    # primitive, so its variables are not globally ignored
        else:
            # no user specified options, just use global defaults
            primitive_options[primitive] = [{'ignore_entities': ignore_entities,
                                             'ignore_variables': ignore_variables}]
    return primitive_options, global_ignore_entities, global_ignore_variables
def _init_primitive_options(primitive_options, es):
    """Normalize user primitive options into a flat ``name -> [option dict]`` map.

    Tuple keys are expanded so that every primitive named in the tuple maps to
    the same list of initialized option dicts. A list value supplies one
    option dict per primitive input; a single dict is wrapped in a list.

    Args:
        primitive_options (dict): User options keyed by primitive name or
            tuple of names.
        es (dict): Entity id -> list of variable ids, forwarded to
            ``_init_option_dict`` for validation.

    Returns:
        dict: Primitive name -> list of initialized option dicts.

    Raises:
        AssertionError: A list of options does not have one entry per input
            of a built-in primitive it is given for.
        KeyError: A primitive is given options more than once, or (from
            ``_init_option_dict``) an option name is unrecognized.
        TypeError: (from ``_init_option_dict``) an option value is malformed.
    """
    flattened_options = {}
    for primitive_key, options in primitive_options.items():
        # A tuple key applies the same options to each primitive it names.
        if isinstance(primitive_key, tuple):
            primitive_names = primitive_key
        else:
            primitive_names = (primitive_key,)
        if isinstance(options, list):
            # Check every named primitive, not the tuple itself: looking up a
            # tuple key (or an unknown name) used to crash with AttributeError.
            for primitive_name in primitive_names:
                _check_option_count(primitive_name, options)
            options = [_init_option_dict(primitive_key, option, es) for option in options]
        else:
            options = [_init_option_dict(primitive_key, options, es)]
        for each_primitive in primitive_names:
            # if primitive is specified more than once, raise error
            if each_primitive in flattened_options:
                raise KeyError('Multiple options found for primitive %s' %
                               (each_primitive))
            flattened_options[each_primitive] = options
    return flattened_options


def _check_option_count(primitive_name, options):
    """Raise AssertionError unless ``options`` has one entry per primitive input.

    The input count comes from the built-in aggregation/transform primitive
    registered under ``primitive_name``. Names not found there (presumably
    custom or misspelled primitives) are not checked, because their input
    count cannot be looked up here.
    """
    primitive = primitives.get_aggregation_primitives().get(primitive_name) or \
        primitives.get_transform_primitives().get(primitive_name)
    if primitive is None:
        return
    input_types = primitive.input_types
    # input_types may be a list of lists; if so, the first entry's length is
    # taken as the number of inputs.
    if isinstance(input_types[0], list):
        n_inputs = len(input_types[0])
    else:
        n_inputs = len(input_types)
    if n_inputs != len(options):
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; AssertionError is kept for existing callers.
        raise AssertionError(
            "Number of options does not match number of inputs for primitive %s"
            % (primitive_name))
def _init_option_dict(key, option_dict, es):
    """Validate one user option dict and return a normalized copy.

    Args:
        key: The primitive key the options were given under; used only in
            error messages.
        option_dict (dict): Option name -> value, where a value is a list
            (entity ids) or a dict (entity id -> list of variable names).
        es: Passed to each option's checker from ``_get_primitive_options``.

    Returns:
        dict: Lists converted to sets and dict values converted to dicts of
        sets. ``ignore_variables`` (empty dict) and ``ignore_entities`` (empty
        set) are always present.

    Raises:
        KeyError: An option name is not recognized.
        TypeError: An option value fails its checker.
    """
    validators = _get_primitive_options()
    normalized = {}
    for name, value in option_dict.items():
        if name not in validators:
            raise KeyError("Unrecognized primitive option \'%s\' for %s" %
                           (name, key))
        if not validators[name](value, es):
            raise TypeError("Incorrect type formatting for \'%s\' for %s" %
                            (name, key))
        if isinstance(value, list):
            normalized[name] = set(value)
        elif isinstance(value, dict):
            normalized[name] = {entity_id: set(var_names)
                                for entity_id, var_names in value.items()}
    # guarantee the two ignore keys exist so callers can index them directly
    normalized.setdefault('ignore_variables', dict())
    normalized.setdefault('ignore_entities', set())
    return normalized
def variable_filter(f, options, groupby=False):
    """Check whether feature ``f`` is allowed by a single option dict.

    ``f`` and every feature returned by ``f.get_dependencies(deep=True)`` are
    checked in turn; any rejection rejects ``f``:

    * An ``IdentityFeature`` whose entity has an include-variables entry
      passes only if its name is listed there. When it is listed, the
      entity-level checks are skipped for that feature.
    * Otherwise, an ``IdentityFeature`` whose name is in its entity's
      ignore-variables entry is rejected.
    * Any feature whose entity is outside a given include-entities set, or
      inside the ignore-entities set, is rejected.

    Args:
        f: Candidate feature.
        options (dict): One option dict for a primitive input.
        groupby (bool): Use the ``*_groupby_*`` option keys, and reject ``f``
            up front unless ``f.variable_type`` is a ``Discrete`` subclass.

    Returns:
        bool: True if no rule rejects ``f``.
    """
    if groupby:
        if not issubclass(f.variable_type, Discrete):
            return False
        include_vars = 'include_groupby_variables'
        ignore_vars = 'ignore_groupby_variables'
        include_entities = 'include_groupby_entities'
        ignore_entities = 'ignore_groupby_entities'
    else:
        include_vars = 'include_variables'
        ignore_vars = 'ignore_variables'
        include_entities = 'include_entities'
        ignore_entities = 'ignore_entities'

    for feature in f.get_dependencies(deep=True) + [f]:
        entity_id = feature.entity.id
        if isinstance(feature, IdentityFeature):
            if include_vars in options and entity_id in options[include_vars]:
                if feature.get_name() not in options[include_vars][entity_id]:
                    return False  # entity restricts variables and this one isn't listed
                continue  # explicitly included variable: skip entity-level checks
            if (ignore_vars in options and entity_id in options[ignore_vars]
                    and feature.get_name() in options[ignore_vars][entity_id]):
                return False  # explicitly ignored variable
        if include_entities in options and entity_id not in options[include_entities]:
            return False  # entity not among the included ones
        if ignore_entities in options and entity_id in options[ignore_entities]:
            return False  # entity explicitly ignored
    return True
def ignore_entity_for_primitive(options, entity, groupby=False):
    """Return True if any of a primitive's option dicts rules out ``entity``.

    For each option dict:

    * With ``groupby``, unless ``entity.id`` appears in
      ``include_groupby_variables``, the entity is ignored if
      ``include_groupby_entities`` is given and lacks it, or if it is in
      ``ignore_groupby_entities``.
    * Then: not ignored if it appears in ``include_variables``; ignored if
      ``include_entities`` is given and lacks it; otherwise ignored exactly
      when it is in ``ignore_entities``.

    Args:
        options (list[dict]): Option dicts for the primitive.
        entity: Object with an ``id`` attribute.
        groupby (bool): Also apply the groupby-specific rules.

    Returns:
        bool
    """
    def _option_ignores(option):
        entity_id = entity.id
        if groupby and entity_id not in option.get('include_groupby_variables', ()):
            if ('include_groupby_entities' in option
                    and entity_id not in option['include_groupby_entities']):
                return True
            if entity_id in option.get('ignore_groupby_entities', ()):
                return True
        if entity_id in option.get('include_variables', ()):
            return False
        if 'include_entities' in option and entity_id not in option['include_entities']:
            return True
        return entity_id in option['ignore_entities']

    # every option is evaluated before combining, as in a list-based any()
    verdicts = [_option_ignores(option) for option in options]
    return any(verdicts)
def filter_groupby_matches_by_options(groupby_matches, options):
    """Filter candidate groupby features against ``options``.

    Each feature is wrapped in a one-element tuple so it can go through
    ``filter_matches_by_options`` with ``groupby=True``; the result is
    therefore a set of 1-tuples.
    """
    wrapped = [(groupby_f,) for groupby_f in groupby_matches]
    return filter_matches_by_options(wrapped, options, groupby=True)
def filter_matches_by_options(matches, options, groupby=False):
    """Keep only the matches whose features all satisfy the primitive's options.

    Args:
        matches (iterable of tuple): Candidate input combinations, one feature
            per primitive input. Matches must be hashable (they are collected
            into a set).
        options (list[dict]): Option dicts for the primitive. With more than
            one, the i-th option applies to the i-th feature of each match
            (``zip`` stops at the shorter of the two); otherwise ``options[0]``
            applies to every feature, so ``options`` must be non-empty whenever
            ``matches`` is.
        groupby (bool): Passed through to ``variable_filter``.

    Returns:
        set: The matches for which every feature passes ``variable_filter``.
    """
    if len(options) > 1:
        # one option per input: pair each feature with its own option
        def is_valid_match(match):
            return all(variable_filter(f, option, groupby)
                       for f, option in zip(match, options))
    else:
        # a single option applies to every input
        def is_valid_match(match):
            return all(variable_filter(f, options[0], groupby) for f in match)

    # generator-based all() stops at the first failing feature
    return {match for match in matches if is_valid_match(match)}