Skip to content

Commit

Permalink
Fix transform stacking on multi-output aggregation (#394)
Browse files Browse the repository at this point in the history
* added test case

* added dfs test

* comment fix

* move test

* fix

* use func

* target entity:

* better test case

* remove import

* isort
  • Loading branch information
gsheni committed Feb 2, 2019
1 parent a22d852 commit 6d991ad
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion featuretools/feature_base/feature_base.py
Expand Up @@ -450,7 +450,7 @@ def __init__(self, base_features, primitive):
base_features = [_check_feature(base_features)]

# R TODO handle stacking on sub-features
assert (bf.number_output_features == 1 for bf in base_features)
assert all(bf.number_output_features == 1 for bf in base_features)

super(TransformFeature, self).__init__(base_features[0].entity,
base_features, primitive=primitive)
Expand Down
8 changes: 5 additions & 3 deletions featuretools/synthesis/deep_feature_synthesis.py
Expand Up @@ -486,9 +486,11 @@ def _build_transform_features(self, all_features, entity, max_depth=0):
commutative=trans_prim.commutative)

for matching_input in matching_inputs:
new_f = TransformFeature(matching_input, primitive=trans_prim)
self._handle_new_feature(all_features=all_features,
new_feature=new_f)
if all(bf.number_output_features == 1 for bf in matching_input):
new_f = TransformFeature(matching_input,
primitive=trans_prim)
self._handle_new_feature(all_features=all_features,
new_feature=new_f)

def _build_forward_features(self, all_features, parent_entity,
child_entity, relationship, max_depth=0):
Expand Down
11 changes: 11 additions & 0 deletions featuretools/tests/dfs_tests/test_deep_feature_synthesis.py
Expand Up @@ -24,6 +24,7 @@
Last,
Mode,
NMostCommon,
NotEqual,
Sum,
TimeSincePrevious
)
Expand Down Expand Up @@ -662,6 +663,16 @@ def test_transform_consistency():
assert feature_with_name(feature_defs, 'OR(AND(b, b1), b1)')


def test_transform_no_stack_agg(es):
    """DFS must not emit a NotEqual feature built on an NMostCommon feature.

    Runs feature-definition-only DFS on the ``customers`` entity with
    ``NMostCommon`` as the sole aggregation primitive and ``NotEqual`` as
    the sole transform primitive, then checks that the stacked feature
    name is absent from the generated definitions.
    """
    dfs_options = dict(entityset=es,
                       target_entity="customers",
                       agg_primitives=[NMostCommon],
                       trans_primitives=[NotEqual],
                       max_depth=3,
                       features_only=True)
    generated_defs = ft.dfs(**dfs_options)

    # Name the stacked feature would carry if DFS had (wrongly) built it.
    forbidden_name = 'id != N_MOST_COMMON(sessions.device_type)'
    assert not feature_with_name(generated_defs, forbidden_name)


def test_intialized_trans_prim(es):
prim = IsIn(list_of_outputs=['coke zero'])
dfs_obj = DeepFeatureSynthesis(target_entity_id='log',
Expand Down
9 changes: 9 additions & 0 deletions featuretools/tests/primitive_tests/test_transform_features.py
Expand Up @@ -43,6 +43,7 @@
NotEqualScalar,
NumCharacters,
NumWords,
NMostCommon,
Percentile,
ScalarSubtractNumericFeature,
Second,
Expand Down Expand Up @@ -1052,6 +1053,14 @@ def gen_feat_names(self):
assert base_feature.hash() != join_time_split.hash()


def test_tranform_stack_agg(es):
    """Stacking Percentile on an NMostCommon(n=3) feature must be rejected.

    Builds an ``NMostCommon(n=3)`` aggregation of ``log.product_id`` onto
    ``customers`` and expects an ``AssertionError`` when a ``Percentile``
    feature is constructed on top of it.
    """
    product_id = es['log']['product_id']
    customers = es['customers']
    most_common = ft.Feature(product_id,
                             parent_entity=customers,
                             primitive=NMostCommon(n=3))

    # The constructor is expected to refuse this base feature outright.
    with pytest.raises(AssertionError):
        ft.Feature(most_common, primitive=Percentile)


def test_feature_names_inherit_from_make_trans_primitive():
# R TODO
pass
Expand Down

0 comments on commit 6d991ad

Please sign in to comment.