diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 8beeb99f55..434d7ffabd 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -12,7 +12,7 @@ Changelog * Fixes * Fixed entity set deserialization (:pr:`720`) * Added error message when DateTimeIndex is a variable but not set as the time_index (:pr:`723`) - * Fixed CumCount and other group-by transform primitives that take ID as input (:pr:`733`) + * Fixed CumCount and other group-by transform primitives that take ID as input (:pr:`733`, :pr:`754`) * Fix progress bar undercounting (:pr:`743`) * Updated training_window error assertion to only check against observations (:pr:`728`) * Don't delete the whole destination folder while saving entityset (:pr:`717`) diff --git a/featuretools/synthesis/deep_feature_synthesis.py b/featuretools/synthesis/deep_feature_synthesis.py index a8696ed904..90a8ab2c88 100644 --- a/featuretools/synthesis/deep_feature_synthesis.py +++ b/featuretools/synthesis/deep_feature_synthesis.py @@ -517,16 +517,24 @@ def _build_transform_features(self, all_features, entity, max_depth=0, entity, new_max_depth, input_types, - groupby_prim, - require_direct_input=require_direct_input) + groupby_prim) # get IDs to use as groupby id_matches = self._features_by_type(all_features=all_features, entity=entity, max_depth=new_max_depth, variable_type=set([Id])) + # If require_direct_input, require a DirectFeature in input or as a + # groupby, and don't create features of inputs/groupbys which are + # all direct features with the same relationship path for matching_input in matching_inputs: if all(bf.number_output_features == 1 for bf in matching_input): for id_groupby in id_matches: + if require_direct_input and ( + _all_direct_and_same_path(matching_input + (id_groupby,)) or + not any([isinstance(feature, DirectFeature) for + feature in (matching_input + (id_groupby, ))]) + ): + continue new_f = GroupByTransformFeature(list(matching_input), groupby=id_groupby, primitive=groupby_prim) diff --git a/featuretools/tests/synthesis/test_deep_feature_synthesis.py b/featuretools/tests/synthesis/test_deep_feature_synthesis.py index 7a6abb3caa..fa1ad04db1 100644 --- a/featuretools/tests/synthesis/test_deep_feature_synthesis.py +++ b/featuretools/tests/synthesis/test_deep_feature_synthesis.py @@ -271,6 +271,17 @@ def test_make_groupby_features(es): "CUM_SUM(value) by session_id")) +def test_make_indirect_groupby_features(es): + dfs_obj = DeepFeatureSynthesis(target_entity_id='log', + entityset=es, + agg_primitives=[], + trans_primitives=[], + groupby_trans_primitives=['cum_sum']) + features = dfs_obj.build_features() + assert (feature_with_name(features, + "CUM_SUM(products.rating) by session_id")) + + def test_make_groupby_features_with_id(es): dfs_obj = DeepFeatureSynthesis(target_entity_id='sessions', entityset=es,