Skip to content

Commit

Permalink
Merge pull request #126 from PyCQA/feature/find_iter
Browse files Browse the repository at this point in the history
Issue #117: Fix bug with finding comment node
  • Loading branch information
ibizaman committed Dec 22, 2016
2 parents ee74b0d + 2866320 commit b0d4b1c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 66 deletions.
104 changes: 38 additions & 66 deletions redbaron/base_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,13 @@ def __getattr__(self, key):
def __setitem__(self, key, value):
self.data[key] = self._convert_input_to_node_object(value, parent=self.parent, on_attribute=self.on_attribute)

def find_iter(self, identifier, *args, **kwargs):
for node in self.data:
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node

def find_all(self, identifier, *args, **kwargs):
to_return = NodeList([])
for i in self.data:
to_return += i.find_all(identifier, *args, **kwargs)
return to_return
return NodeList(list(self.find_iter(identifier, *args, **kwargs)))

findAll = find_all
__call__ = find_all
Expand Down Expand Up @@ -668,40 +670,6 @@ def _get_list_attribute_is_member_off(self):

return in_list

def find(self, identifier, *args, **kwargs):
if "recursive" in kwargs:
recursive = kwargs["recursive"]
kwargs = kwargs.copy()
del kwargs["recursive"]
else:
recursive = True

if self._node_match_query(self, identifier, *args, **kwargs):
return self

if not recursive:
return None

for kind, key, _ in filter(lambda x: x[0] in ("list", "key"), self._render()):
if kind == "key":
i = getattr(self, key)
if not i:
continue

found = i.find(identifier, *args, **kwargs)
if found is not None:
return found

elif kind == "list":
attr = getattr(self, key).node_list if isinstance(getattr(self, key), ProxyList) else getattr(self, key)
for i in attr:
found = i.find(identifier, *args, **kwargs)
if found is not None:
return found

else:
raise Exception()

def __getattr__(self, key):
if key.endswith("_") and key[:-1] in self._dict_keys + self._list_keys + self._str_keys:
return getattr(self, key[:-1])
Expand Down Expand Up @@ -762,8 +730,7 @@ def __delslice__(self, i, j):
else:
raise AttributeError("__delitem__")

def find_all(self, identifier, *args, **kwargs):
to_return = NodeList([])
def find_iter(self, identifier, *args, **kwargs):
if "recursive" in kwargs:
recursive = kwargs["recursive"]
kwargs = kwargs.copy()
Expand All @@ -772,33 +739,29 @@ def find_all(self, identifier, *args, **kwargs):
recursive = True

if self._node_match_query(self, identifier, *args, **kwargs):
to_return.append(self)
yield self

if recursive:
for (kind, key, _) in self._render():
if kind == "key":
node = getattr(self, key)
if not isinstance(node, Node):
continue
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node
elif kind in ("list", "formatting"):
nodes = getattr(self, key)
if isinstance(nodes, ProxyList):
nodes = nodes.node_list
for node in nodes:
for matched_node in node.find_iter(identifier, *args, **kwargs):
yield matched_node

if not recursive:
return to_return

for kind, key, _ in filter(
lambda x: x[0] in ("list", "formatting") or (x[0] == "key" and isinstance(getattr(self, x[1]), Node)),
self._render()):
if kind == "key":
i = getattr(self, key)
if not i:
continue

to_return += i.find_all(identifier, *args, **kwargs)

elif kind in ("list", "formatting"):
if isinstance(getattr(self, key), ProxyList):
for i in getattr(self, key).node_list:
to_return += i.find_all(identifier, *args, **kwargs)
else:
for i in getattr(self, key):
to_return += i.find_all(identifier, *args, **kwargs)

else:
raise Exception()
def find(self, identifier, *args, **kwargs):
return next(self.find_iter(identifier, *args, **kwargs), None)

return to_return
def find_all(self, identifier, *args, **kwargs):
return NodeList(list(self.find_iter(identifier, *args, **kwargs)))

findAll = find_all
__call__ = find_all
Expand Down Expand Up @@ -888,6 +851,7 @@ def generate_identifiers(klass):

def _get_helpers(self):
not_helpers = set([
'at',
'copy',
'decrease_indentation',
'dumps',
Expand All @@ -897,25 +861,33 @@ def _get_helpers(self):
'findAll',
'find_by_path',
'find_by_position',
'at',
'find_iter',
'from_fst',
'fst',
'fst',
'generate_identifiers',
'get_absolute_bounding_box_of_attribute',
'get_indentation_node',
'get_indentation_node',
'has_render_key',
'help',
'help',
'increase_indentation',
'indentation_node_is_direct',
'indentation_node_is_direct',
'index_on_parent',
'index_on_parent_raw',
'insert_after',
'insert_before',
'next_generator',
'next_generator',
'parent_find',
'parent_find',
'parse_code_block',
'parse_decorators',
'path',
'path',
'previous_generator',
'previous_generator',
'replace',
'to_python',
Expand Down
7 changes: 7 additions & 0 deletions tests/test_initial_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,13 @@ def test_default_test_value_find_all():
red = RedBaron("badger\nmushroom\nsnake")
assert red("name", "snake") == red("name", value="snake")

def test_find_comment_node():
red = RedBaron("def f():\n #a\n pass\n#b")
assert red.find('comment').value == '#a'

def test_find_all_comment_nodes():
red = RedBaron("def f():\n #a\n pass\n#b")
assert [x.value for x in red.find_all('comment')] == ['#a', '#b']

def test_default_test_value_find_def():
red = RedBaron("def a(): pass\ndef b(): pass")
Expand Down

0 comments on commit b0d4b1c

Please sign in to comment.