diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 23386ce0e1..52f6db8b10 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,7 @@ Changelog * Enhancements * Added First primitive (:pr:`770`) * Added Entropy aggregation primitive (:pr:`779`) + * Allow custom naming for multi-output primitives (:pr:`780`) * Fixes * Prevents user from removing base entity time index using additional_variables (:pr:`768`) * Fixes error when a multioutput primitive was supplied to dfs as a groupby trans primitive (:pr:`786`) diff --git a/featuretools/feature_base/feature_base.py b/featuretools/feature_base/feature_base.py index accc27e37a..bb5b922b05 100644 --- a/featuretools/feature_base/feature_base.py +++ b/featuretools/feature_base/feature_base.py @@ -83,9 +83,8 @@ def get_name(self): return self._name def get_names(self): - n = self.number_output_features if not self._names: - self._names = [self.generate_name() + "[{}]".format(i) for i in range(n)] + self._names = self.generate_names() return self._names def get_feature_names(self): @@ -618,6 +617,13 @@ def generate_name(self): where_str=self._where_str(), use_prev_str=self._use_prev_str()) + def generate_names(self): + return self.primitive.generate_names(base_feature_names=[bf.get_name() for bf in self.base_features], + relationship_path_name=self.relationship_path_name(), + parent_entity_id=self.parent_entity.id, + where_str=self._where_str(), + use_prev_str=self._use_prev_str()) + def get_arguments(self): return { 'name': self._name, @@ -668,6 +674,9 @@ def copy(self): def generate_name(self): return self.primitive.generate_name(base_feature_names=[bf.get_name() for bf in self.base_features]) + def generate_names(self): + return self.primitive.generate_names(base_feature_names=[bf.get_name() for bf in self.base_features]) + def get_arguments(self): return { 'name': self._name, @@ -713,6 +722,12 @@ def generate_name(self): _name = self.primitive.generate_name(base_names) return u"{} by {}".format(_name, self.groupby.get_name()) + def generate_names(self): + base_names = [bf.get_name() for bf in self.base_features[:-1]] + _names = self.primitive.generate_names(base_names) + names = [name + " by {}".format(self.groupby.get_name()) for name in _names] + return names + def get_arguments(self): # Do not include groupby in base_features. feature_names = [feat.unique_name() for feat in self.base_features diff --git a/featuretools/primitives/base/aggregation_primitive_base.py b/featuretools/primitives/base/aggregation_primitive_base.py index aa75d220a8..07e76f7c0f 100755 --- a/featuretools/primitives/base/aggregation_primitive_base.py +++ b/featuretools/primitives/base/aggregation_primitive_base.py @@ -25,6 +25,16 @@ def generate_name(self, base_feature_names, relationship_path_name, self.get_args_string(), ) + def generate_names(self, base_feature_names, relationship_path_name, + parent_entity_id, where_str, use_prev_str): + n = self.number_output_features + base_name = self.generate_name(base_feature_names, + relationship_path_name, + parent_entity_id, + where_str, + use_prev_str) + return [base_name + "[%s]" % i for i in range(n)] + def make_agg_primitive(function, input_types, return_type, name=None, stack_on_self=True, stack_on=None, diff --git a/featuretools/primitives/base/primitive_base.py b/featuretools/primitives/base/primitive_base.py index 5e35cf571e..3a5f680e72 100644 --- a/featuretools/primitives/base/primitive_base.py +++ b/featuretools/primitives/base/primitive_base.py @@ -46,6 +46,9 @@ def __call__(self, *args, **kwargs): def generate_name(self): raise NotImplementedError("Subclass must implement") + def generate_names(self): + raise NotImplementedError("Subclass must implement") + def get_function(self): raise NotImplementedError("Subclass must implement") diff --git a/featuretools/primitives/base/transform_primitive_base.py b/featuretools/primitives/base/transform_primitive_base.py index bb7e528d03..4472581c55 100755 --- a/featuretools/primitives/base/transform_primitive_base.py +++ b/featuretools/primitives/base/transform_primitive_base.py @@ -20,6 +20,11 @@ def generate_name(self, base_feature_names): self.get_args_string(), ) + def generate_names(self, base_feature_names): + n = self.number_output_features + base_name = self.generate_name(base_feature_names) + return [base_name + "[%s]" % i for i in range(n)] + def make_trans_primitive(function, input_types, return_type, name=None, description=None, cls_attributes=None, diff --git a/featuretools/tests/primitive_tests/test_agg_feats.py b/featuretools/tests/primitive_tests/test_agg_feats.py index 10635cfd3e..d02eb5e6c0 100644 --- a/featuretools/tests/primitive_tests/test_agg_feats.py +++ b/featuretools/tests/primitive_tests/test_agg_feats.py @@ -644,3 +644,38 @@ def _assert_agg_feats_equal(f1, f2): assert f1.parent_entity.id == f2.parent_entity.id assert f1.relationship_path == f2.relationship_path assert f1.use_previous == f2.use_previous + + +def test_override_multi_feature_names(es): + def gen_custom_names(primitive, base_feature_names, relationship_path_name, + parent_entity_id, where_str, use_prev_str): + base_string = 'Custom_%s({}.{})'.format(parent_entity_id, base_feature_names) + return [base_string % i for i in range(primitive.number_output_features)] + + def pd_top3(x): + array = np.array(x.value_counts()[:3].index) + if len(array) < 3: + filler = np.full(3 - len(array), np.nan) + array = np.append(array, filler) + return array + + num_features = 3 + NMostCommoner = make_agg_primitive(function=pd_top3, + input_types=[Numeric], + return_type=Discrete, + number_output_features=num_features, + cls_attributes={"generate_names": gen_custom_names}) + + fm, features = ft.dfs(entityset=es, + target_entity="products", + instance_ids=[0, 1, 2], + agg_primitives=[NMostCommoner], + trans_primitives=[]) + + expected_names = [] + base_names = [['value'], ['value_2'], ['value_many_nans']] + for name in base_names: + expected_names += gen_custom_names(NMostCommoner, name, None, 'products', None, None) + + for name in expected_names: + assert name in fm.columns diff --git a/featuretools/tests/primitive_tests/test_groupby_transform_primitives.py b/featuretools/tests/primitive_tests/test_groupby_transform_primitives.py index 63223ec1f4..a7eb4362d5 100644 --- a/featuretools/tests/primitive_tests/test_groupby_transform_primitives.py +++ b/featuretools/tests/primitive_tests/test_groupby_transform_primitives.py @@ -419,18 +419,76 @@ def multi_cum_sum(x): agg_primitives=[], groupby_trans_primitives=[MultiCumSum, CumSum, CumMax, CumMin]) - correct_answers = [ - [fm['CUM_SUM(age) by cohort'], fm['CUM_SUM(age) by région_id']], - [fm['CUM_MAX(age) by cohort'], fm['CUM_MAX(age) by région_id']], - [fm['CUM_MIN(age) by cohort'], fm['CUM_MIN(age) by région_id']] + # Calculate output in a separate DFS call to make sure the multi-output code + # does not alter any values + fm2, _ = dfs(entityset=es, + target_entity='customers', + trans_primitives=[], + agg_primitives=[], + groupby_trans_primitives=[CumSum, CumMax, CumMin]) + + answer_cols = [ + ['CUM_SUM(age) by cohort', 'CUM_SUM(age) by région_id'], + ['CUM_MAX(age) by cohort', 'CUM_MAX(age) by région_id'], + ['CUM_MIN(age) by cohort', 'CUM_MIN(age) by région_id'] ] for i in range(3): - f = 'MULTI_CUM_SUM(age) by cohort[%d]' % i + # Check that multi-output gives correct answers + f = 'MULTI_CUM_SUM(age)[%d] by cohort' % i assert f in fm.columns - for x, y in zip(fm[f].values, correct_answers[i][0].values): + for x, y in zip(fm[f].values, fm[answer_cols[i][0]].values): assert x == y - f = 'MULTI_CUM_SUM(age) by région_id[%d]' % i + f = 'MULTI_CUM_SUM(age)[%d] by région_id' % i assert f in fm.columns - for x, y in zip(fm[f].values, correct_answers[i][1].values): + for x, y in zip(fm[f].values, fm[answer_cols[i][1]].values): + assert x == y + # Verify single output results are unchanged by inclusion of + # multi-output primitive + for x, y in zip(fm[answer_cols[i][0]], fm2[answer_cols[i][0]]): + assert x == y + for x, y in zip(fm[answer_cols[i][1]], fm2[answer_cols[i][1]]): + assert x == y + + +def test_groupby_with_multioutput_primitive_custom_names(es): + def gen_custom_names(primitive, base_feature_names): + return ["CUSTOM_SUM", "CUSTOM_MAX", "CUSTOM_MIN"] + + def multi_cum_sum(x): + return x.cumsum(), x.cummax(), x.cummin() + + num_features = 3 + MultiCumSum = make_trans_primitive(function=multi_cum_sum, + input_types=[Numeric], + return_type=Numeric, + number_output_features=num_features, + cls_attributes={"generate_names": gen_custom_names}) + + fm, _ = dfs(entityset=es, + target_entity='customers', + trans_primitives=[], + agg_primitives=[], + groupby_trans_primitives=[MultiCumSum, CumSum, CumMax, CumMin]) + + answer_cols = [ + ['CUM_SUM(age) by cohort', 'CUM_SUM(age) by région_id'], + ['CUM_MAX(age) by cohort', 'CUM_MAX(age) by région_id'], + ['CUM_MIN(age) by cohort', 'CUM_MIN(age) by région_id'] + ] + + expected_names = [ + ['CUSTOM_SUM by cohort', 'CUSTOM_SUM by région_id'], + ['CUSTOM_MAX by cohort', 'CUSTOM_MAX by région_id'], + ['CUSTOM_MIN by cohort', 'CUSTOM_MIN by région_id'] + ] + + for i in range(3): + f = expected_names[i][0] + assert f in fm.columns + for x, y in zip(fm[f].values, fm[answer_cols[i][0]].values): + assert x == y + f = expected_names[i][1] + assert f in fm.columns + for x, y in zip(fm[f].values, fm[answer_cols[i][1]].values): assert x == y diff --git a/featuretools/tests/primitive_tests/test_transform_features.py b/featuretools/tests/primitive_tests/test_transform_features.py index e019e2cf67..f76862ce9e 100644 --- a/featuretools/tests/primitive_tests/test_transform_features.py +++ b/featuretools/tests/primitive_tests/test_transform_features.py @@ -860,3 +860,31 @@ def _map(x): assert fm["MOD4(value)"][0] == 0 assert fm["MOD4(value)"][14] == 2 assert pd.isnull(fm["MOD4(value)"][15]) + + +def test_override_multi_feature_names(es): + def gen_custom_names(primitive, base_feature_names): + return ['Above18(%s)' % base_feature_names, + 'Above21(%s)' % base_feature_names, + 'Above65(%s)' % base_feature_names] + + def is_greater(x): + return x > 18, x > 21, x > 65 + + num_features = 3 + IsGreater = make_trans_primitive(function=is_greater, + input_types=[Numeric], + return_type=Numeric, + number_output_features=num_features, + cls_attributes={"generate_names": gen_custom_names}) + + fm, features = ft.dfs(entityset=es, + target_entity="customers", + instance_ids=[0, 1, 2], + agg_primitives=[], + trans_primitives=[IsGreater]) + + expected_names = gen_custom_names(IsGreater, ['age']) + + for name in expected_names: + assert name in fm.columns