From 270b60ea18327d0364789da049c9b83c2d1b9b7e Mon Sep 17 00:00:00 2001 From: Charles Bradshaw Date: Wed, 11 Oct 2017 15:52:47 -0400 Subject: [PATCH] Head fix (#9) * Added Cutoff Time Functionality to Head * Updated head function logic and formatting * Removed module scope and added head entity tests * Added TODO and renamed variables in head function * Updated Doc Headers for head function * Swapped pd.Merge for pd.isin * Removed print and comments from head test function * Added head cutoff_time support for datetime * Added head test for datetime --- featuretools/entityset/entity.py | 26 ++++++++++++++++-- featuretools/entityset/entityset.py | 6 +++-- .../tests/entityset_tests/test_pandas_es.py | 27 ++++++++++++++++++- featuretools/variable_types/variable.py | 18 +++---------- 4 files changed, 58 insertions(+), 19 deletions(-) diff --git a/featuretools/entityset/entity.py b/featuretools/entityset/entity.py index a9382fabc2..68f7c280e5 100644 --- a/featuretools/entityset/entity.py +++ b/featuretools/entityset/entity.py @@ -120,13 +120,35 @@ def head(self, n=10, cutoff_time=None): Args: n (int) : number of instances to return + cutoff_time (pd.Timestamp,pd.DataFrame) : Timestamp(s) to restrict rows Returns: :class:`pd.DataFrame` : Pandas DataFrame - """ - return self.df.head(n=n) + + if cutoff_time is None: + valid_data = self.df + + elif isinstance(cutoff_time, pd.Timestamp) or \ + isinstance(cutoff_time, datetime): + valid_data = self.df[self.df[self.time_index] < cutoff_time] + + elif isinstance(cutoff_time, pd.DataFrame): + + instance_ids, time = list(cutoff_time) + + # TODO filtering the top n during "isin" would be more efficient + valid_data = self.df[ + self.df[self.index].isin(cutoff_time[instance_ids])] + valid_data = valid_data[ + valid_data[self.time_index] < cutoff_time[time]] + + else: + raise ValueError( + 'cutoff_time must be None, a Datetime, a pd.Timestamp, or a pd.DataFrame') + + return valid_data.head(n) def get_column_type(self, column_id): """ 
get type of column in underlying data structure """ diff --git a/featuretools/entityset/entityset.py b/featuretools/entityset/entityset.py index 63b8291254..db4aebd9ed 100644 --- a/featuretools/entityset/entityset.py +++ b/featuretools/entityset/entityset.py @@ -115,9 +115,11 @@ def get_sample(self, n): def head(self, entity_id, n=10, variable_id=None, cutoff_time=None): if variable_id is None: - return self.entity_stores[entity_id].head(n, cutoff_time=cutoff_time) + return self.entity_stores[entity_id].head( + n, cutoff_time=cutoff_time) else: - return self.entity_stores[entity_id].head(n, cutoff_time=cutoff_time)[variable_id] + return self.entity_stores[entity_id].head( + n, cutoff_time=cutoff_time)[variable_id] def get_instance_data(self, entity_id, instance_ids): return self.entity_stores[entity_id].query_by_values(instance_ids) diff --git a/featuretools/tests/entityset_tests/test_pandas_es.py b/featuretools/tests/entityset_tests/test_pandas_es.py index 50b73fd71a..bf18fac2f7 100644 --- a/featuretools/tests/entityset_tests/test_pandas_es.py +++ b/featuretools/tests/entityset_tests/test_pandas_es.py @@ -8,7 +8,7 @@ import copy -@pytest.fixture(scope='module') +@pytest.fixture() def entityset(): return make_ecommerce_entityset() @@ -632,3 +632,28 @@ def test_normalize_entity_new_time_index(self, entityset): assert entityset['values'].time_index == 'value_time' assert 'value_time' in entityset['values'].df.columns assert len(entityset['values'].df.columns) == 3 + + +def test_head_of_entity(entityset): + + entity = entityset['log'] + assert(isinstance(entityset.head('log', 3), pd.DataFrame)) + assert(isinstance(entity.head(3), pd.DataFrame)) + assert(isinstance(entity['product_id'].head(3), pd.DataFrame)) + + assert(entity.head(n=5).shape == (5, 9)) + + timestamp1 = pd.to_datetime("2011-04-09 10:30:10") + timestamp2 = pd.to_datetime("2011-04-09 10:30:18") + datetime1 = datetime(2011, 4, 9, 10, 30, 18) + + assert(entity.head(5, cutoff_time=timestamp1).shape == (2, 
9)) + assert(entity.head(5, cutoff_time=timestamp2).shape == (3, 9)) + assert(entity.head(5, cutoff_time=datetime1).shape == (3, 9)) + + time_list = [timestamp2]*3+[timestamp1]*2 + cutoff_times = pd.DataFrame(zip(range(5), time_list)) + + assert(entityset.head('log', 5, cutoff_time=cutoff_times).shape == (3, 9)) + assert(entity.head(5, cutoff_time=cutoff_times).shape == (3, 9)) + assert(entity['product_id'].head(5, cutoff_time=cutoff_times).shape == (3, 1)) diff --git a/featuretools/variable_types/variable.py b/featuretools/variable_types/variable.py index ee1ae2e848..7806fbc082 100644 --- a/featuretools/variable_types/variable.py +++ b/featuretools/variable_types/variable.py @@ -163,25 +163,15 @@ def head(self, n=10, cutoff_time=None): Args: n (int) : number of instances to return + cutoff_time (pd.Timestamp,pd.DataFrame) : Timestamp(s) to restrict rows Returns: :class:`pd.DataFrame` : Pandas DataFrame """ - if cutoff_time is None: - series = self.entityset.head(entity_id=self.entity_id, n=n, - variable_id=self.id) - else: - from featuretools.computational_backends.calculate_feature_matrix import calculate_feature_matrix - from featuretools.primitives import Feature - - f = Feature(self) - - instance_ids = self.entityset.get_top_n_instances(self.entity.id, n) - cutoff_time = pd.DataFrame({'instance_id': instance_ids}) - cutoff_time['time'] = cutoff_time - cfm = calculate_feature_matrix([f], cutoff_time=cutoff_time) - series = cfm[f.get_name()] + series = self.entityset.head(entity_id=self.entity_id, + n=n, variable_id=self.id, + cutoff_time=cutoff_time) return series.to_frame() @property