Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor EntitySet.find_path(...) #295

Merged
merged 10 commits into from Oct 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 47 additions & 58 deletions featuretools/entityset/entityset.py
Expand Up @@ -20,36 +20,6 @@
logger = logging.getLogger('featuretools.entityset')


class BFSNode(object):

def __init__(self, entity_id, parent, relationship):
self.entity_id = entity_id
self.parent = parent
self.relationship = relationship

def build_path(self):
path = []
cur_node = self
num_forward = 0
i = 0
last_forward = False
while cur_node.parent is not None:
path.append(cur_node.relationship)
if cur_node.relationship.parent_entity.id == cur_node.entity_id:
num_forward += 1
if i == 0:
last_forward = True
cur_node = cur_node.parent
i += 1
path.reverse()

# if path ends on a forward relationship, return number of
# forward relationships, otherwise 0
if len(path) == 0 or not last_forward:
num_forward = 0
return path, num_forward


class EntitySet(object):
"""
Stores all actual data for a entityset
Expand Down Expand Up @@ -456,48 +426,67 @@ def find_path(self, start_entity_id, goal_entity_id,
Returns:
List of relationships that go from start entity to goal
entity. None is returned if no path exists.
If include_forward_distance is True,
If include_num_forward is True,
returns a tuple of (relationship_list, forward_distance).

See Also:
:func:`BaseEntitySet.find_forward_path`
:func:`BaseEntitySet.find_backward_path`
:func:`EntitySet.find_forward_path`
:func:`EntitySet.find_backward_path`
"""
if start_entity_id == goal_entity_id:
if include_num_forward:
return [], 0
else:
return []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This else case not tested according to codecov

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


# BFS so we get shortest path
start_node = BFSNode(start_entity_id, None, None)
queue = [start_node]
nodes = {}
# Search for path using BFS to get the shortest path.
# Start by initializing the queue with all relationships from start entity
queue = [[r] for r in self.get_forward_relationships(start_entity_id)] + \
[[r] for r in self.get_backward_relationships(start_entity_id)]
rwedge marked this conversation as resolved.
Show resolved Hide resolved
visited = set([start_entity_id])

while len(queue) > 0:
current_node = queue.pop(0)
if current_node.entity_id == goal_entity_id:
path, num_forward = current_node.build_path()
# get first path from queue
current_path = queue.pop(0)

# last entity in path will be which ever one we haven't visited
if current_path[-1].parent_entity.id not in visited:
next_entity_id = current_path[-1].parent_entity.id
elif current_path[-1].child_entity.id not in visited:
next_entity_id = current_path[-1].child_entity.id
else:
# if we've visited both, we don't need to explore this path further
continue
kmax12 marked this conversation as resolved.
Show resolved Hide resolved

# we've found a path to goal
if next_entity_id == goal_entity_id:
if include_num_forward:
return path, num_forward
# count the number of forward relationships along this path
# starting from beginning
check_entity = start_entity_id
num_forward = 0
for r in current_path:
# if the current entity we're checking is a child, that means the
# relationship is a forward and the next entity to check is the parent
if r.child_entity.id == check_entity:
num_forward += 1
check_entity = r.parent_entity.id
else:
check_entity = r.child_entity.id

return current_path, num_forward
else:
return path

for r in self.get_forward_relationships(current_node.entity_id):
if r.parent_entity.id not in nodes:
parent_node = BFSNode(r.parent_entity.id, current_node, r)
nodes[r.parent_entity.id] = parent_node
queue.append(parent_node)

for r in self.get_backward_relationships(current_node.entity_id):
if r.child_entity.id not in nodes:
child_node = BFSNode(r.child_entity.id, current_node, r)
nodes[r.child_entity.id] = child_node
queue.append(child_node)

raise ValueError(("No path from {} to {}! Check that all entities "
.format(start_entity_id, goal_entity_id)),
"are connected by relationships")
return current_path

next_relationships = self.get_forward_relationships(next_entity_id)
next_relationships += self.get_backward_relationships(next_entity_id)

for r in next_relationships:
queue.append(current_path + [r])

visited.add(next_entity_id)
e = "No path from {} to {}. Check that all entities are connected by relationships".format(start_entity_id, goal_entity_id)
raise ValueError(e)

def find_forward_path(self, start_entity_id, goal_entity_id):
"""Find a forward path between a start and goal entity
Expand Down
19 changes: 19 additions & 0 deletions featuretools/tests/entityset_tests/test_es_metadata.py
Expand Up @@ -108,6 +108,25 @@ def test_find_path(es):
assert path[2].parent_entity.id == 'customers'


def test_find_path_same_entity(es):
path, forward = es.find_path('products', 'products',
include_num_forward=True)
assert len(path) == 0
assert forward == 0

# also test include_num_forward==False
path = es.find_path('products', 'products',
include_num_forward=False)
assert len(path) == 0


def test_find_path_no_path_found(es):
es.relationships = []
error_text = "No path from products to customers. Check that all entities are connected by relationships"
with pytest.raises(ValueError, match=error_text):
es.find_path('products', 'customers')


def test_raise_key_error_missing_entity(es):
with pytest.raises(KeyError):
es["this entity doesn't exist"]
Expand Down