Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #117: Fix bug with finding comment node #118

Merged
merged 1 commit into from Dec 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 29 additions & 65 deletions redbaron/base_nodes.py
Expand Up @@ -304,12 +304,14 @@ def __getattr__(self, key):

def __setitem__(self, key, value):
self.data[key] = self._convert_input_to_node_object(value, parent=self.parent, on_attribute=self.on_attribute)

def find_iter(self, identifier, *args, **kwargs):
for node in self.data:
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node

def find_all(self, identifier, *args, **kwargs):
to_return = NodeList([])
for i in self.data:
to_return += i.find_all(identifier, *args, **kwargs)
return to_return
return NodeList(list(self.find_iter(identifier, *args, **kwargs)))

findAll = find_all
__call__ = find_all
Expand Down Expand Up @@ -644,40 +646,6 @@ def _get_list_attribute_is_member_off(self):

return in_list

def find(self, identifier, *args, **kwargs):
if "recursive" in kwargs:
recursive = kwargs["recursive"]
kwargs = kwargs.copy()
del kwargs["recursive"]
else:
recursive = True

if self._node_match_query(self, identifier, *args, **kwargs):
return self

if not recursive:
return None

for kind, key, _ in filter(lambda x: x[0] in ("list", "key"), self._render()):
if kind == "key":
i = getattr(self, key)
if not i:
continue

found = i.find(identifier, *args, **kwargs)
if found is not None:
return found

elif kind == "list":
attr = getattr(self, key).node_list if isinstance(getattr(self, key), ProxyList) else getattr(self, key)
for i in attr:
found = i.find(identifier, *args, **kwargs)
if found is not None:
return found

else:
raise Exception()

def __getattr__(self, key):
if key.endswith("_") and key[:-1] in self._dict_keys + self._list_keys + self._str_keys:
return getattr(self, key[:-1])
Expand Down Expand Up @@ -738,8 +706,7 @@ def __delslice__(self, i, j):
else:
raise AttributeError("__delitem__")

def find_all(self, identifier, *args, **kwargs):
to_return = NodeList([])
def find_iter(self, identifier, *args, **kwargs):
if "recursive" in kwargs:
recursive = kwargs["recursive"]
kwargs = kwargs.copy()
Expand All @@ -748,33 +715,29 @@ def find_all(self, identifier, *args, **kwargs):
recursive = True

if self._node_match_query(self, identifier, *args, **kwargs):
to_return.append(self)

if not recursive:
return to_return

for kind, key, _ in filter(
lambda x: x[0] in ("list", "formatting") or (x[0] == "key" and isinstance(getattr(self, x[1]), Node)),
self._render()):
if kind == "key":
i = getattr(self, key)
if not i:
continue
yield self

if recursive:
for (kind, key, _) in self._render():
if kind == "key":
node = getattr(self, key)
if not isinstance(node, Node):
continue
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node
elif kind in ("list", "formatting"):
nodes = getattr(self, key)
if isinstance(nodes, ProxyList):
nodes = nodes.node_list
for node in nodes:
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node

to_return += i.find_all(identifier, *args, **kwargs)

elif kind in ("list", "formatting"):
if isinstance(getattr(self, key), ProxyList):
for i in getattr(self, key).node_list:
to_return += i.find_all(identifier, *args, **kwargs)
else:
for i in getattr(self, key):
to_return += i.find_all(identifier, *args, **kwargs)

else:
raise Exception()
def find(self, identifier, *args, **kwargs):
return next(self.find_iter(identifier, *args, **kwargs), None)

return to_return
def find_all(self, identifier, *args, **kwargs):
return NodeList(list(self.find_iter(identifier, *args, **kwargs)))

findAll = find_all
__call__ = find_all
Expand Down Expand Up @@ -869,6 +832,7 @@ def _get_helpers(self):
'find',
'findAll',
'find_all',
'find_iter',
'fst',
'help',
'next_generator',
Expand Down
7 changes: 7 additions & 0 deletions tests/test_initial_parsing.py
Expand Up @@ -920,6 +920,13 @@ def test_default_test_value_find_all():
red = RedBaron("badger\nmushroom\nsnake")
assert red("name", "snake") == red("name", value="snake")

def test_find_comment_node():
red = RedBaron("def f():\n #a\n pass\n#b")
assert red.find('comment').value == '#a'

def test_find_all_comment_nodes():
red = RedBaron("def f():\n #a\n pass\n#b")
assert [x.value for x in red.find_all('comment')] == ['#a', '#b']

def test_default_test_value_find_def():
red = RedBaron("def a(): pass\ndef b(): pass")
Expand Down