Skip to content

Commit

Permalink
Optimize comprehensions
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Jun 23, 2016
1 parent 2db3fab commit 1f32bfe
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 4 deletions.
150 changes: 150 additions & 0 deletions Cheetah/legacy_compiler.py
Expand Up @@ -21,6 +21,8 @@
from Cheetah.ast_utils import get_argument_names
from Cheetah.ast_utils import get_imported_names
from Cheetah.ast_utils import get_lvalues
from Cheetah.legacy_parser import brace_ends
from Cheetah.legacy_parser import brace_starts
from Cheetah.legacy_parser import CheetahVar
from Cheetah.legacy_parser import LegacyParser
from Cheetah.SettingsManager import SettingsManager
Expand Down Expand Up @@ -69,7 +71,155 @@ def _cheetah_var_to_text(
)


def _process_comprehensions(expr_parts):
"""Comprehensions are a unique part of python's syntax which
references variables earlier in the source than they are declared.
Because of this, we need to do some pre-processing to predict local
variables introduced by comprehensions.
For instance, the following is "legal" cheetah syntax:
#py y = [$x for x in (1, 2, 3) if $x]
Naively, $x is compiled as a cheetah variable and attempts a lookup.
However, this variable is guaranteed to come from locals.
We'll use the python3 rules here for determining local variables. That is
the scope of a variable that is an lvalue in a `for ... in` comprehension
is only the comprehension and not the rest of the function as it is in
python2.
The ast defines each of the comprehensions as follows:
ListComp(expr elt, comprehension* generators)
comprehension(expr target, expr iter, expr* ifs)
(set / dict / generator expressions are similar)
Consider:
[elt_expr for x in x_iter if x_if1 if x_if2 for y in y_iter if y_if]
Each `for ... in ...` introduces names which are available in:
- `elt`
- the `ifs` for that part
- Any `for ... in ...` after that
In the above, the expressions have the following local variables
introduced by the comprehensions:
- elt: [x, y]
- x_iter: []
- x_if1 / x_if2: [x]
- y_iter: [x]
- y_if: [x, y]
The approximate algorithm:
Search for a `for` token.
Search left for a brace, if there is none abandon -- this is a for
loop and not a comprehension.
While searching left, if a `for` token is encountered, record its
position
Search forward for `in`, `for`, `if` and the right boundary
Process `for ... in` + (): pass looking for introduced locals
For example, 'for (x, y) in' will look for locals in
`for (x, y) in (): pass` and finds `x` and `y` as lvalues
Process tokens in `elt` and the rest of the expression (if applicable)
For each CheetahVar encountered, if it is in the locals detected
replace it with the raw variable
"""
def _search(parts, index, direction):
"""Helper for searching forward / backward.
Yields (index, token, brace_depth)
"""
assert direction in (1, -1), direction

def in_bounds(index):
return index >= 0 and index < len(parts)

if direction == 1:
starts = brace_starts
ends = brace_ends
else:
starts = brace_ends
ends = brace_starts

brace_depth = 0
index += direction
while in_bounds(index) and brace_depth >= 0:
token = parts[index]
if token in starts:
brace_depth += 1
elif token in ends:
brace_depth -= 1
yield index, token, brace_depth
index += direction

expr_parts = list(expr_parts)
for i in range(len(expr_parts)):
if expr_parts[i] != 'for':
continue

# A diagram of the below indices:
# (Considering the first `for`)
# [(x, y) for x in (1,) if x for y in (2,)]
# | | | | |
# | | | +- next_index +- right_boundary
# | | +- in_index
# | +- for_index + first_for_index
# +- left_boundary
#
# (Considering the second `for`)
# [(x, y) for x in (1,) if x for y in (2,)]
# | | | | |
# | | | | +- right_boundary
# | +- first_for_index | +- in_index
# +- left_boundary +- for_index
# (next_index is None)

first_for_index = for_index = i

# Search for the left boundary or abandon (if it is a for loop)
for i, token, depth in _search(expr_parts, for_index, direction=-1):
if depth == 0 and token == 'for':
first_for_index = i
elif depth == -1:
left_boundary = i
break
else:
continue

in_index = None
next_index = None
for i, token, depth in _search(expr_parts, for_index, direction=1):
if in_index is None and depth == 0 and token == 'in':
in_index = i
elif next_index is None and depth == 0 and token in {'if', 'for'}:
next_index = i
elif depth == -1:
right_boundary = i
break

# Defensive assertion is required, slicing with [:None] is valid
assert in_index is not None, in_index
lvalue_expr = ''.join(expr_parts[for_index:in_index]) + 'in (): pass'
lvalue_expr = lvalue_expr.replace('\n', ' ')
lvalues = get_lvalues(lvalue_expr)

replace_ranges = [range(left_boundary, first_for_index)]
if next_index is not None:
replace_ranges.append(range(next_index, right_boundary))

for replace_range in replace_ranges:
for i in replace_range:
token = expr_parts[i]
if isinstance(token, CheetahVar) and token.name in lvalues:
expr_parts[i] = token.name

return tuple(expr_parts)


def _expr_to_text(expr_parts, **kwargs):
expr_parts = _process_comprehensions(expr_parts)
return ''.join(
_cheetah_var_to_text(part, **kwargs)
if isinstance(part, CheetahVar) else
Expand Down
6 changes: 4 additions & 2 deletions Cheetah/legacy_parser.py
Expand Up @@ -29,6 +29,8 @@
}

brace_pairs = {'(': ')', '[': ']', '{': '}'}
brace_starts = set(brace_pairs)
brace_ends = set(brace_pairs.values())

escape_lookbehind = r'(?:(?<=\A)|(?<!\\))'
identRE = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*')
Expand Down Expand Up @@ -325,9 +327,9 @@ def _read_braced_expression(self, allow_cheetah_vars=True, force_variable=False)
parts.append(self.getc())
else:
token = self.getPyToken()
if token in '({[':
if token in brace_starts:
brace_stack.append(token)
elif token in ')}]':
elif token in brace_ends:
if brace_pairs[brace_stack[-1]] != token:
raise ParseError(
self,
Expand Down
93 changes: 91 additions & 2 deletions tests/Compiler_test.py
Expand Up @@ -4,14 +4,21 @@
import io
import os.path

import pytest

from Cheetah.cheetah_compile import compile_template
from Cheetah.compile import _create_module_from_source
from Cheetah.compile import compile_source
from Cheetah.compile import compile_to_class
from Cheetah.legacy_parser import brace_pairs
from Cheetah.Template import Template
from testing.util import run_python


def lookup(v):
return 'VFFNS("{}", locals(), globals(), NS)'.format(v)


def test_templates_runnable_using_env(tmpdir):
tmpl_filename = os.path.join(tmpdir.strpath, 'my_template.tmpl')
tmpl_py_filename = os.path.join(tmpdir.strpath, 'my_template.py')
Expand Down Expand Up @@ -57,7 +64,7 @@ def test_optimized_attributes_of_builtins_function_args():

def test_non_optimized_searchlist():
src = compile_source('$int($foo)')
assert ' _v = int(VFFNS("foo"' in src
assert ' _v = int({}) #'.format(lookup('foo')) in src


def test_optimization_still_prefers_locals():
Expand Down Expand Up @@ -181,7 +188,7 @@ class fooobj(object):
def test_optimization_removes_VFN():
src = compile_source(VFN_opt_src)
assert 'VFN(' not in src
assert ' _v = VFFNS("foo", locals(), globals(), NS).barvar[0].upper() #' in src
assert ' _v = {}.barvar[0].upper() #'.format(lookup('foo')) in src
cls = compile_to_class(VFN_opt_src)
assert cls({'foo': fooobj}).respond() == 'W'

Expand Down Expand Up @@ -229,3 +236,85 @@ def test_optimization_partial_template_functions():
assert foo(Template()).strip() == '25'
src = io.open('testing/templates/src/optimize_name.py').read()
assert ' _v = bar(5) #' in src


@pytest.mark.parametrize(('start', 'end'), tuple(brace_pairs.items()))
def test_optimize_comprehensions(start, end):
src = '#py {}$x for x in (1,) if $x{}'.format(start, end)
src = compile_source(src)
expected = ' {}x for x in (1,) if x{} #'.format(start, end)
assert expected in src


def test_optimize_dict_comprehension():
src = '#py {$x: $y for x, y in {1: 2}.items() if $x and $y}'
src = compile_source(src)
assert ' {x: y for x, y in {1: 2}.items() if x and y} #' in src


def test_optimize_for_loop_over_comprehension():
src = compile_source('#for y in ($x for x in (1,)): $y')
assert ' for y in (x for x in (1,)): #' in src


def test_optimize_multi_comprehension():
src = '#py [($x, $y) for x in (1,) for y in (2,)]'
src = compile_source(src)
assert ' [(x, y) for x in (1,) for y in (2,)] #' in src


def test_optimize_multi_comprehension_referenced_later():
src = '#py [($x, $y) for x in (1,) for y in ($x,) if $x and $y]'
src = compile_source(src)
assert ' [(x, y) for x in (1,) for y in (x,) if x and y] #' in src


def test_optimize_multi_comprehension_rename_ns_variable():
src = '#py [($x, $y) for x in (1,) if $x and $y for y in (2,)]'
src = compile_source(src)
expected = ' [(x, y) for x in (1,) if x and {} for y in (2,)] #'.format(
lookup('y'),
)
assert expected in src


def test_optimize_comprehension_multi_if():
src = compile_source('#py [$x for x in (1,) if $x if $x + 1]')
assert ' [x for x in (1,) if x if x + 1] #' in src


def test_optimize_comprehension_in_element():
src = compile_source('#py [[$y for y in ($x,)] for x in (1,)]')
assert ' [[y for y in (x,)] for x in (1,)] #' in src


def test_optimize_comprehension_in_ifs():
src = compile_source('#py [$x for x in (1,) if [$y for y in ($x,)]]')
assert ' [x for x in (1,) if [y for y in (x,)]] #' in src


def test_optimize_comprehension_rename_ns_variable():
src = compile_source('#py [$x for x in ($x,)]')
assert ' [x for x in ({},)] #'.format(lookup('x')) in src


def test_comprehension_with_newlines():
src = compile_source(
'#py [\n'
'$x\n'
'for\n'
'x\n'
'in\n'
'(1,)\n'
']'
)
expected = (
' [\n'
'x\n'
'for\n'
'x\n'
'in\n'
'(1,)\n'
'] #'
)
assert expected in src

0 comments on commit 1f32bfe

Please sign in to comment.