Skip to content

Commit

Permalink
Head fix (#9)
Browse files Browse the repository at this point in the history
* Added Cutoff Time Functionality to Head

* Updated head function logic and formatting

* Remvoed module scope and added head entity tests

* Added TODO and renamed variables in head function

* Updated Doc Headers for head function

* Swapped pd.Merge for pd.isin

* Removed print and comments from head test function

* Added head cutoff_time support for datetime

* Added head test for datetime
  • Loading branch information
CharlesBradshaw authored and kmax12 committed Oct 11, 2017
1 parent 309c41e commit 270b60e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 19 deletions.
26 changes: 24 additions & 2 deletions featuretools/entityset/entity.py
Expand Up @@ -120,13 +120,35 @@ def head(self, n=10, cutoff_time=None):
Args:
n (int) : number of instances to return
cutoff_time (pd.Timestamp,pd.DataFrame) : Timestamp(s) to restrict rows
Returns:
:class:`pd.DataFrame` : Pandas DataFrame
"""
return self.df.head(n=n)

if cutoff_time is None:
valid_data = self.df

elif isinstance(cutoff_time, pd.Timestamp) or \
isinstance(cutoff_time, datetime):
valid_data = self.df[self.df[self.time_index] < cutoff_time]

elif isinstance(cutoff_time, pd.DataFrame):

instance_ids, time = list(cutoff_time)

# TODO filtering the top n during "isin" would be more efficient
valid_data = self.df[
self.df[self.index].isin(cutoff_time[instance_ids])]
valid_data = valid_data[
valid_data[self.time_index] < cutoff_time[time]]

else:
raise ValueError(
'cutoff_time must be None, a Datetime, a pd.Timestamp, or a pd.DataFrame')

return valid_data.head(n)

def get_column_type(self, column_id):
""" get type of column in underlying data structure """
Expand Down
6 changes: 4 additions & 2 deletions featuretools/entityset/entityset.py
Expand Up @@ -115,9 +115,11 @@ def get_sample(self, n):

def head(self, entity_id, n=10, variable_id=None, cutoff_time=None):
if variable_id is None:
return self.entity_stores[entity_id].head(n, cutoff_time=cutoff_time)
return self.entity_stores[entity_id].head(
n, cutoff_time=cutoff_time)
else:
return self.entity_stores[entity_id].head(n, cutoff_time=cutoff_time)[variable_id]
return self.entity_stores[entity_id].head(
n, cutoff_time=cutoff_time)[variable_id]

def get_instance_data(self, entity_id, instance_ids):
return self.entity_stores[entity_id].query_by_values(instance_ids)
Expand Down
27 changes: 26 additions & 1 deletion featuretools/tests/entityset_tests/test_pandas_es.py
Expand Up @@ -8,7 +8,7 @@
import copy


@pytest.fixture(scope='module')
@pytest.fixture()
def entityset():
return make_ecommerce_entityset()

Expand Down Expand Up @@ -632,3 +632,28 @@ def test_normalize_entity_new_time_index(self, entityset):
assert entityset['values'].time_index == 'value_time'
assert 'value_time' in entityset['values'].df.columns
assert len(entityset['values'].df.columns) == 3


def test_head_of_entity(entityset):

entity = entityset['log']
assert(isinstance(entityset.head('log', 3), pd.DataFrame))
assert(isinstance(entity.head(3), pd.DataFrame))
assert(isinstance(entity['product_id'].head(3), pd.DataFrame))

assert(entity.head(n=5).shape == (5, 9))

timestamp1 = pd.to_datetime("2011-04-09 10:30:10")
timestamp2 = pd.to_datetime("2011-04-09 10:30:18")
datetime1 = datetime(2011, 4, 9, 10, 30, 18)

assert(entity.head(5, cutoff_time=timestamp1).shape == (2, 9))
assert(entity.head(5, cutoff_time=timestamp2).shape == (3, 9))
assert(entity.head(5, cutoff_time=datetime1).shape == (3, 9))

time_list = [timestamp2]*3+[timestamp1]*2
cutoff_times = pd.DataFrame(zip(range(5), time_list))

assert(entityset.head('log', 5, cutoff_time=cutoff_times).shape == (3, 9))
assert(entity.head(5, cutoff_time=cutoff_times).shape == (3, 9))
assert(entity['product_id'].head(5, cutoff_time=cutoff_times).shape == (3, 1))
18 changes: 4 additions & 14 deletions featuretools/variable_types/variable.py
Expand Up @@ -163,25 +163,15 @@ def head(self, n=10, cutoff_time=None):
Args:
n (int) : number of instances to return
cutoff_time (pd.Timestamp,pd.DataFrame) : Timestamp(s) to restrict rows
Returns:
:class:`pd.DataFrame` : Pandas DataFrame
"""
if cutoff_time is None:
series = self.entityset.head(entity_id=self.entity_id, n=n,
variable_id=self.id)
else:
from featuretools.computational_backends.calculate_feature_matrix import calculate_feature_matrix
from featuretools.primitives import Feature

f = Feature(self)

instance_ids = self.entityset.get_top_n_instances(self.entity.id, n)
cutoff_time = pd.DataFrame({'instance_id': instance_ids})
cutoff_time['time'] = cutoff_time
cfm = calculate_feature_matrix([f], cutoff_time=cutoff_time)
series = cfm[f.get_name()]
series = self.entityset.head(entity_id=self.entity_id,
n=n, variable_id=self.id,
cutoff_time=cutoff_time)
return series.to_frame()

@property
Expand Down

0 comments on commit 270b60e

Please sign in to comment.