Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frances-h committed Nov 20, 2019
1 parent 7ac2493 commit 5f4f4bb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
3 changes: 1 addition & 2 deletions featuretools/primitives/options_utils.py
Expand Up @@ -142,8 +142,7 @@ def variable_filter(f, options, groupby=False):
include_entities = 'include_groupby_entities' if groupby else 'include_entities'
ignore_entities = 'ignore_groupby_entities' if groupby else 'ignore_entities'

dependencies = f.get_dependencies(deep=True)
dependencies = [f] if not dependencies else dependencies + [f]
dependencies = f.get_dependencies(deep=True) + [f]
for base_f in dependencies:
if isinstance(base_f, IdentityFeature):
if include_vars in options and base_f.entity.id in options[include_vars]:
Expand Down
12 changes: 7 additions & 5 deletions featuretools/tests/synthesis/test_deep_feature_synthesis.py
Expand Up @@ -981,7 +981,7 @@ def test_primitive_options(es):
if identity_base.entity.id == 'customers':
assert identity_base.get_name() == 'age'
if isinstance(f.primitive, Mean):
assert 'customers' in entities
assert all([entity in ['customers'] for entity in entities])
if isinstance(f.primitive, Mode):
assert 'sessions' not in entities
if isinstance(f.primitive, NumUnique):
Expand Down Expand Up @@ -1014,7 +1014,7 @@ def test_primitive_options(es):
assert identity_base.get_name() == 'signup_date' or \
identity_base.get_name() == 'upgrade_date'
if isinstance(f.primitive, Year):
assert 'customers' in entities
assert all([entity in ['customers'] for entity in entities])


def test_primitive_options_with_globals(es):
Expand Down Expand Up @@ -1072,16 +1072,18 @@ def test_primitive_options_with_globals(es):
entities = [d.entity.id for d in deps]
variables = [d for d in deps if isinstance(d, IdentityFeature)]
if isinstance(f.primitive, Mode):
assert 'sessions' in entities or 'customers' in entities
assert [all([entity in ['sessions', 'customers'] for entity in entities])]
for identity_base in variables:
assert not (identity_base.entity.id == 'customers' and
(identity_base.get_name() == 'age' or
identity_base.get_name() == u'région_id'))
elif isinstance(f.primitive, NumUnique):
assert 'sessions' in entities or 'customers' in entities
assert [all([entity in ['sessions', 'customers'] for entity in entities])]
for identity_base in variables:
if identity_base.entity.id == 'sessions':
assert identity_base.get_name() == 'device_type'
if identity_base.entity.id == 'customers':
assert identity_base.get_name() != 'age'
# All other primitives ignore 'sessions' and 'age'
else:
assert 'sessions' not in entities
Expand Down Expand Up @@ -1157,7 +1159,7 @@ def test_primitive_options_multiple_inputs(es):
entities = [d.entity.id for d in deps]
variables = [d.get_name() for d in deps]
if f.primitive.name == 'trend':
assert 'log' in entities
assert all([entity in ['log'] for entity in entities])
assert 'datetime' in variables
if len(variables) == 2:
assert 'value' != variables[0]

0 comments on commit 5f4f4bb

Please sign in to comment.