Skip to content

Commit

Permalink
recompute propagates to children nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
mristin committed Oct 10, 2018
1 parent 8cf8888 commit 314eb27
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 17 deletions.
27 changes: 26 additions & 1 deletion icontract/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from typing import Any, Mapping, Dict, List, Optional, Union, Tuple, Set, Callable # pylint: disable=unused-import


class Placeholder:
"""Represent a placeholder for variables local to the lambda such as targets in generator expressions."""

def __repr__(self) -> str:
"""Represent the placeholder as <Placeholder>."""
return "<Placeholder>"


PLACEHOLDER = Placeholder()


class Visitor(ast.NodeVisitor):
"""
Traverse the abstract syntax tree and recompute the values of each node defined by the function frame.
Expand Down Expand Up @@ -111,7 +122,9 @@ def visit_Name(self, node: ast.Name) -> Any:
result = getattr(builtins, node.id)

if result is None and node.id != "None":
raise ValueError("Name not found in the variable lookup: {}".format(node.id))
# The variable refers to a name local of the lambda (e.g., a target in the generator expression).
# Since we evaluate generator expressions with runtime compilation, None is returned here as a placeholder.
return PLACEHOLDER

self.recomputed_values[node] = result
return result
Expand Down Expand Up @@ -353,27 +366,39 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
"""Compile the generator expression as a function and call it."""
result = self._execute_comprehension(node=node)

for generator in node.generators:
self.visit(generator.iter)

# Do not set the computed value of the node since its representation would be non-informative.
return result

def visit_ListComp(self, node: ast.ListComp) -> Any:
"""Compile the list comprehension as a function and call it."""
result = self._execute_comprehension(node=node)

for generator in node.generators:
self.visit(generator.iter)

self.recomputed_values[node] = result
return result

def visit_SetComp(self, node: ast.SetComp) -> Any:
"""Compile the set comprehension as a function and call it."""
result = self._execute_comprehension(node=node)

for generator in node.generators:
self.visit(generator.iter)

self.recomputed_values[node] = result
return result

def visit_DictComp(self, node: ast.DictComp) -> Any:
"""Compile the dictionary comprehension as a function and call it."""
result = self._execute_comprehension(node=node)

for generator in node.generators:
self.visit(generator.iter)

self.recomputed_values[node] = result
return result

Expand Down
31 changes: 15 additions & 16 deletions tests/test_represent.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,18 +361,17 @@ def some_func() -> List[Tuple[pathlib.Path, pathlib.Path]]:
icontract_violation_error = err

self.assertIsNotNone(icontract_violation_error)
self.assertEqual('all(single_res[1].is_absolute() for single_res in result): all(single_res[1].is_absolute() '
'for single_res in result) was False', str(icontract_violation_error))
self.assertEqual("all(single_res[1].is_absolute() for single_res in result):\n"
"all(single_res[1].is_absolute() for single_res in result) was False\n"
"result was [(PosixPath('/home/file1'), PosixPath('home/file2'))]",
str(icontract_violation_error))

def test_generator_expression_multiple_for(self):
lst = [1, 2, 3]
another_lst = [4, 5, 6]
lst = [[1, 2], [3]]

# yapf: disable
@icontract.pre(
lambda x: all(item == x or another_item == x
for item in lst if item % 2 == 0
for another_item in another_lst if another_item % 3 == 0)
lambda x: all(item == x for sublst in lst for item in sublst)
)
# yapf: enable
def func(x: int) -> int:
Expand All @@ -386,13 +385,9 @@ def func(x: int) -> int:

self.assertIsNotNone(icontract_violation_error)

self.assertEqual('all(item == x or another_item == x\n'
' for item in lst if item % 2 == 0\n'
' for another_item in another_lst if another_item % 3 == 0): '
'all(item == x or another_item == x\n'
' for item in lst if item % 2 == 0\n'
' for another_item in another_lst if another_item % 3 == 0) was False',
str(icontract_violation_error))
self.assertEqual('all(item == x for sublst in lst for item in sublst):\n'
'all(item == x for sublst in lst for item in sublst) was False\n'
'lst was [[1, 2], [3]]', str(icontract_violation_error))

def test_list_comprehension(self):
lst = [1, 2, 3]
Expand All @@ -408,8 +403,9 @@ def func(x: int) -> int:
icontract_violation_error = err

self.assertIsNotNone(icontract_violation_error)
self.assertEqual('[item < x for item in lst if item % x == 0] == []: '
'[item < x for item in lst if item % x == 0] was [False]', str(icontract_violation_error))
self.assertEqual('[item < x for item in lst if item % x == 0] == []:\n'
'[item < x for item in lst if item % x == 0] was [False]\n'
'lst was [1, 2, 3]', str(icontract_violation_error))

def test_set_comprehension(self):
lst = [1, 2, 3]
Expand All @@ -427,6 +423,7 @@ def func(x: int) -> int:
self.assertIsNotNone(icontract_violation_error)
self.assertEqual('len({item < x for item in lst if item % x == 0}) == 0:\n'
'len({item < x for item in lst if item % x == 0}) was 1\n'
'lst was [1, 2, 3]\n'
'{item < x for item in lst if item % x == 0} was {False}', str(icontract_violation_error))

def test_dict_comprehension(self):
Expand All @@ -443,6 +440,8 @@ def func(x: int) -> int:
self.assertIsNotNone(icontract_violation_error)
self.assertEqual('len({i: i**2 for i in range(x)}) == 0:\n'
'len({i: i**2 for i in range(x)}) was 2\n'
'range(x) was range(0, 2)\n'
'x was 2\n'
'{i: i**2 for i in range(x)} was {0: 0, 1: 1}', str(icontract_violation_error))


Expand Down

0 comments on commit 314eb27

Please sign in to comment.