Skip to content

Commit

Permalink
Remove max_hlevel from DFS (#608)
Browse files Browse the repository at this point in the history
This was not exposed in the dfs method so it always used the default
value of 2, it was undocumented, and it was unclear if it provided any
value.

Also remove EntitySet.find_path as it is no longer used.
  • Loading branch information
CJStadler committed Jun 19, 2019
1 parent 7ecd4c7 commit d1390d5
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 181 deletions.
1 change: 1 addition & 0 deletions docs/source/changelog.rst
Expand Up @@ -34,6 +34,7 @@ Changelog
* Refactor get_pandas_data_slice to take single entity (:pr:`547`)
* Updates TimeSincePrevious and Diff Primitives (:pr:`561`)
* Remove unecessary time_last variable (:pr:`546`)
* Remove max_hlevel from DeepFeatureSynthesis (:pr:`608`)
* Documentation Changes
* Add Featuretools Enterprise to documentation (:pr:`563`)
* Miscellaneous changes (:pr:`552`, :pr:`573`, :pr:`577`, :pr:`599`)
Expand Down
79 changes: 0 additions & 79 deletions featuretools/entityset/entityset.py
Expand Up @@ -271,83 +271,6 @@ def add_relationship(self, relationship):
# Relationship access/helper methods ###################################
###########################################################################

def find_path(self, start_entity_id, goal_entity_id,
include_num_forward=False):
"""Find a path in the entityset represented as a DAG
between start_entity and goal_entity
Args:
start_entity_id (str) : Id of entity to start the search from.
goal_entity_id (str) : Id of entity to find forward path to.
include_num_forward (bool) : If True, return number of forward
relationships in path if the path ends on a forward
relationship, otherwise return 0.
Returns:
List of relationships that go from start entity to goal
entity. None is returned if no path exists.
If include_num_forward is True,
returns a tuple of (relationship_list, forward_distance).
See Also:
:func:`EntitySet.find_forward_paths`
:func:`EntitySet.find_backward_paths`
"""
if start_entity_id == goal_entity_id:
if include_num_forward:
return [], 0
else:
return []

# Search for path using BFS to get the shortest path.
# Start by initializing the queue with all relationships from start entity
queue = [[r] for r in self.get_forward_relationships(start_entity_id)] + \
[[r] for r in self.get_backward_relationships(start_entity_id)]
visited = set([start_entity_id])

while len(queue) > 0:
# get first path from queue
current_path = queue.pop(0)

# last entity in path will be which ever one we haven't visited
if current_path[-1].parent_entity.id not in visited:
next_entity_id = current_path[-1].parent_entity.id
elif current_path[-1].child_entity.id not in visited:
next_entity_id = current_path[-1].child_entity.id
else:
# if we've visited both, we don't need to explore this path further
continue

# we've found a path to goal
if next_entity_id == goal_entity_id:
if include_num_forward:
# count the number of forward relationships along this path
# starting from beginning
check_entity = start_entity_id
num_forward = 0
for r in current_path:
# if the current entity we're checking is a child, that means the
# relationship is a forward and the next entity to check is the parent
if r.child_entity.id == check_entity:
num_forward += 1
check_entity = r.parent_entity.id
else:
check_entity = r.child_entity.id

return current_path, num_forward
else:
return current_path

next_relationships = self.get_forward_relationships(next_entity_id)
next_relationships += self.get_backward_relationships(next_entity_id)

for r in next_relationships:
queue.append(current_path + [r])

visited.add(next_entity_id)
e = "No path from {} to {}. Check that all entities are connected by relationships".format(start_entity_id, goal_entity_id)
raise ValueError(e)

def find_forward_paths(self, start_entity_id, goal_entity_id):
"""
Generator which yields all forward paths between a start and goal
Expand All @@ -359,7 +282,6 @@ def find_forward_paths(self, start_entity_id, goal_entity_id):
See Also:
:func:`BaseEntitySet.find_backward_paths`
:func:`BaseEntitySet.find_path`
"""
for sub_entity_id, path in self._forward_entity_paths(start_entity_id):
if sub_entity_id == goal_entity_id:
Expand All @@ -376,7 +298,6 @@ def find_backward_paths(self, start_entity_id, goal_entity_id):
See Also:
:func:`BaseEntitySet.find_forward_paths`
:func:`BaseEntitySet.find_path`
"""
for path in self.find_forward_paths(goal_entity_id, start_entity_id):
# Reverse path
Expand Down
38 changes: 2 additions & 36 deletions featuretools/synthesis/deep_feature_synthesis.py
Expand Up @@ -53,9 +53,6 @@ class DeepFeatureSynthesis(object):
max_depth (int, optional) : maximum allowed depth of features.
Default: 2. If -1, no limit.
max_hlevel (int, optional) : #TODO how to document.
Default: 2. If -1, no limit.
max_features (int, optional) : Cap the number of generated features to
this number. If -1, no limit.
Expand Down Expand Up @@ -90,7 +87,6 @@ def __init__(self,
where_primitives=None,
groupby_trans_primitives=None,
max_depth=2,
max_hlevel=2,
max_features=-1,
allowed_paths=None,
ignore_entities=None,
Expand All @@ -105,15 +101,11 @@ def __init__(self,
msg = 'Provided target entity %s does not exist in %s' % (target_entity_id, es_name)
raise KeyError(msg)

# need to change max_depth and max_hlevel to None because DFs terminates when <0
# need to change max_depth to None because DFs terminates when <0
if max_depth == -1:
max_depth = None
self.max_depth = max_depth

if max_hlevel == -1:
max_hlevel = None
self.max_hlevel = max_hlevel

self.max_features = max_features

self.allowed_paths = allowed_paths
Expand Down Expand Up @@ -401,9 +393,6 @@ def _handle_new_feature(self, new_feature, all_features):
Raises:
Exception: Attempted to add a single feature multiple times
"""
if (self.max_hlevel is not None and
self._max_hlevel(new_feature) > self.max_hlevel):
return
entity_id = new_feature.entity.id
name = new_feature.unique_name()

Expand Down Expand Up @@ -648,16 +637,11 @@ def _features_by_type(self, all_features, entity, max_depth,
if (variable_type == variable_types.PandasTypes._all or
f.variable_type == variable_type or
any(issubclass(f.variable_type, vt) for vt in variable_type)):
if ((max_depth is None or self._get_depth(f) <= max_depth) and
(self.max_hlevel is None or
self._max_hlevel(f) <= self.max_hlevel)):
if max_depth is None or f.get_depth(stop_at=self.seed_features) <= max_depth:
selected_features.append(f)

return selected_features

def _get_depth(self, f):
return f.get_depth(stop_at=self.seed_features)

def _feature_in_relationship_path(self, relationship_path, feature):
# must be identity feature to be in the relationship path
if not isinstance(feature, IdentityFeature):
Expand All @@ -674,24 +658,6 @@ def _feature_in_relationship_path(self, relationship_path, feature):

return False

def _max_hlevel(self, f):
# for each base_feat along each path in f,
# if base_feat is a direct_feature of an agg_primitive
# determine aggfeat's hlevel
# return max hlevel
deps = [f] + f.get_dependencies(deep=True)
hlevel = 0
for d in deps:
if isinstance(d, DirectFeature) and \
isinstance(d.base_features[0], AggregationFeature):

assert d.parent_entity.id == d.base_features[0].entity.id
path, new_hlevel = self.es.find_path(self.target_entity_id,
d.parent_entity.id,
include_num_forward=True)
hlevel = max(hlevel, new_hlevel)
return hlevel


def check_stacking(primitive, inputs):
"""checks if features in inputs can be used with supplied primitive
Expand Down
33 changes: 0 additions & 33 deletions featuretools/tests/entityset_tests/test_es_metadata.py
Expand Up @@ -203,39 +203,6 @@ def test_find_backward_paths_multiple_relationships(games_es):
assert r2.parent_variable.id == 'id'


def test_find_path(es):
path, forward = es.find_path('products', 'customers',
include_num_forward=True)

assert len(path) == 3
assert forward == 2
assert path[0].child_entity.id == 'log'
assert path[0].parent_entity.id == 'products'
assert path[1].child_entity.id == 'log'
assert path[1].parent_entity.id == 'sessions'
assert path[2].child_entity.id == 'sessions'
assert path[2].parent_entity.id == 'customers'


def test_find_path_same_entity(es):
path, forward = es.find_path('products', 'products',
include_num_forward=True)
assert len(path) == 0
assert forward == 0

# also test include_num_forward==False
path = es.find_path('products', 'products',
include_num_forward=False)
assert len(path) == 0


def test_find_path_no_path_found(es):
es.relationships = []
error_text = "No path from products to customers. Check that all entities are connected by relationships"
with pytest.raises(ValueError, match=error_text):
es.find_path('products', 'customers')


def test_has_unique_path(diamond_es):
assert diamond_es.has_unique_forward_path('customers', 'regions')
assert not diamond_es.has_unique_forward_path('transactions', 'regions')
Expand Down
33 changes: 0 additions & 33 deletions featuretools/tests/synthesis/test_deep_feature_synthesis.py
Expand Up @@ -575,39 +575,6 @@ def test_dfeats_where(es):
features, 'COUNT(log WHERE products.department = electronics)'))


def test_max_hlevel(es):
kwargs = dict(
target_entity_id='log',
entityset=es,
agg_primitives=[Count, Last],
trans_primitives=[Hour],
max_depth=-1,
)

dfs_h_n1 = DeepFeatureSynthesis(max_hlevel=-1, **kwargs)
dfs_h_0 = DeepFeatureSynthesis(max_hlevel=0, **kwargs)
dfs_h_1 = DeepFeatureSynthesis(max_hlevel=1, **kwargs)
feats_n1 = dfs_h_n1.build_features()
feats_n1 = [f.get_name() for f in feats_n1]
feats_0 = dfs_h_0.build_features()
feats_0 = [f.get_name() for f in feats_0]
feats_1 = dfs_h_1.build_features()
feats_1 = [f.get_name() for f in feats_1]

customer_log = ft.Feature(es['log']['value'], parent_entity=es['customers'], primitive=Last)
session_log = ft.Feature(es['log']['value'], parent_entity=es['sessions'], primitive=Last)
log_customer_log = ft.Feature(ft.Feature(customer_log, es["sessions"]), es['log'])
log_session_log = ft.Feature(session_log, es['log'])
assert log_customer_log.get_name() in feats_n1
assert log_session_log.get_name() in feats_n1

assert log_customer_log.get_name() not in feats_1
assert log_session_log.get_name() in feats_1

assert log_customer_log.get_name() not in feats_0
assert log_session_log.get_name() not in feats_0


def test_commutative(es):
dfs_obj = DeepFeatureSynthesis(target_entity_id='log',
entityset=es,
Expand Down

0 comments on commit d1390d5

Please sign in to comment.