diff --git a/featuretools/entityset/entityset.py b/featuretools/entityset/entityset.py index d9f8ce5ef3..f1028f42b9 100644 --- a/featuretools/entityset/entityset.py +++ b/featuretools/entityset/entityset.py @@ -710,6 +710,8 @@ def normalize_entity(self, base_entity_id, new_entity_id, index, transfer_types = {} transfer_types[new_index] = type(base_entity[index]) + for v in additional_variables + copy_variables: + transfer_types[v] = type(base_entity[v]) # create and add new entity new_entity_df = self.get_dataframe(base_entity_id) diff --git a/featuretools/tests/entityset_tests/test_pandas_es.py b/featuretools/tests/entityset_tests/test_pandas_es.py index 1351549abd..b488e9582a 100644 --- a/featuretools/tests/entityset_tests/test_pandas_es.py +++ b/featuretools/tests/entityset_tests/test_pandas_es.py @@ -609,6 +609,25 @@ def test_normalize_entity(self, entityset): assert 'device_name' not in entityset['sessions'].df.columns assert 'device_type' in entityset['device_types'].df.columns + def test_normalize_entity_copies_variable_types(self, entityset): + entityset['log'].convert_variable_type('value', variable_types.Ordinal, convert_data=False) + assert entityset['log'].variable_types['value'] == variable_types.Ordinal + assert entityset['log'].variable_types['priority_level'] == variable_types.Ordinal + entityset.normalize_entity('log', 'values_2', 'value_2', + additional_variables=['priority_level'], + copy_variables=['value'], + make_time_index=False) + + assert len(entityset.get_forward_relationships('log')) == 3 + assert entityset.get_forward_relationships('log')[2].parent_entity.id == 'values_2' + assert 'priority_level' in entityset['values_2'].df.columns + assert 'value' in entityset['values_2'].df.columns + assert 'priority_level' not in entityset['log'].df.columns + assert 'value' in entityset['log'].df.columns + assert 'value_2' in entityset['values_2'].df.columns + assert entityset['values_2'].variable_types['priority_level'] == variable_types.Ordinal + assert entityset['values_2'].variable_types['value'] == variable_types.Ordinal + def test_make_time_index_keeps_original_sorting(self): trips = { 'trip_id': [999 - i for i in xrange(1000)],