Skip to content

Commit

Permalink
Merge pull request #233 from Yelp/better_def_parsing
Browse files Browse the repository at this point in the history
Improve argument parsing
  • Loading branch information
Buck Evan committed Jun 3, 2016
2 parents 7780dac + 246ed83 commit a72c72d
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 238 deletions.
7 changes: 1 addition & 6 deletions Cheetah/SourceReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,9 @@ def readToEOL(self, start=None, gobble=True):
pos = EOLmatch.start()
return self.readTo(to=pos, start=start)

def find(self, it, pos=None):
if pos is None:
pos = self._pos
def find(self, it, pos):
return self._src.find(it, pos)

def startswith(self, it, pos=None):
return self.find(it, pos) == self._pos

def findBOL(self, pos=None):
if pos is None:
pos = self._pos
Expand Down
33 changes: 33 additions & 0 deletions Cheetah/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from __future__ import unicode_literals

import ast
import collections
import operator

import six

Expand Down Expand Up @@ -70,3 +72,34 @@ def get_lvalues(expression):
visitor = TopLevelVisitor()
visitor.visit(ast_obj)
return visitor.targets_visitor.lvalues


_arg_to_name = operator.attrgetter('id' if six.PY2 else 'arg')

if six.PY2: # pragma: no cover (PY2)
def _vararg_to_name(arg):
return arg
else: # pragma: no cover (PY3)
def _vararg_to_name(arg):
return arg.arg


def get_argument_names(argspec):
ast_obj = ast.parse('def _({}): pass'.format(argspec)).body[0].args
names = [_arg_to_name(name) for name in ast_obj.args]
if ast_obj.vararg:
names.append(_vararg_to_name(ast_obj.vararg))
if ast_obj.kwarg:
names.append(_vararg_to_name(ast_obj.kwarg))
if hasattr(ast_obj, 'kwonlyargs'): # pragma: no cover: PY3
names.extend([arg.arg for arg in ast_obj.kwonlyargs])
# Raise a nice message on duplicate arguments (since ast doesn't)
counter = collections.Counter(names)
duplicate_arguments = sorted([
name for name, count in counter.items() if count > 1
])
if duplicate_arguments:
raise SyntaxError('Duplicate arguments: {}'.format(
', '.join(duplicate_arguments),
))
return set(names)
44 changes: 14 additions & 30 deletions Cheetah/legacy_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import six

from Cheetah.ast_utils import get_argument_names
from Cheetah.ast_utils import get_imported_names
from Cheetah.ast_utils import get_lvalues
from Cheetah.legacy_parser import escapedNewlineRE
Expand Down Expand Up @@ -60,22 +61,17 @@ def genNameMapperVar(nameChunks, auto_self):
return start + ('.' if tail else '') + tail


def _arg_chunk_to_text(chunk):
if chunk[1] is not None:
return '{}={}'.format(*chunk)
else:
return chunk[0]


def arg_string_list_to_text(arg_string_list):
return ', '.join(_arg_chunk_to_text(chunk) for chunk in arg_string_list)
def _prepare_argspec(argspec):
argspec = 'self, ' + argspec if argspec else 'self'
return argspec, get_argument_names(argspec)


class MethodCompiler(object):
def __init__(
self,
methodName,
class_compiler,
argspec,
initialMethodComment,
decorators=None,
):
Expand All @@ -86,8 +82,7 @@ def __init__(
self._methodBodyChunks = []
self._hasReturnStatement = False
self._isGenerator = False
self._arguments = [('self', None)]
self._local_vars = {'self'}
self._argspec, self._local_vars = _prepare_argspec(argspec)
self._decorators = decorators or []

def cleanupState(self):
Expand Down Expand Up @@ -293,17 +288,12 @@ def _addAutoCleanupCode(self):
self.addChunk('return NO_CONTENT')
self.dedent()

def addMethArg(self, name, val):
self._arguments.append((name, val))
self._local_vars.add(name.lstrip('*'))

def methodSignature(self):
arg_text = arg_string_list_to_text(self._arguments)
return ''.join((
''.join(
INDENT + decorator + '\n' for decorator in self._decorators
),
INDENT + 'def ' + self.methodName() + '(' + arg_text + '):'
INDENT + 'def ' + self.methodName() + '(' + self._argspec + '):'
))


Expand All @@ -319,7 +309,8 @@ def __init__(self, main_method_name):

self._main_method = self._spawnMethodCompiler(
main_method_name,
'## CHEETAH: main method generated for this template'
'',
'## CHEETAH: main method generated for this template',
)

def __getattr__(self, name):
Expand All @@ -336,10 +327,11 @@ def cleanupState(self):
def setMainMethodName(self, methodName):
self._main_method.setMethodName(methodName)

def _spawnMethodCompiler(self, methodName, initialMethodComment):
def _spawnMethodCompiler(self, methodName, argspec, initialMethodComment):
methodCompiler = self.methodCompilerClass(
methodName,
class_compiler=self,
argspec=argspec,
initialMethodComment=initialMethodComment,
decorators=self._decoratorsForNextMethod,
)
Expand All @@ -358,12 +350,7 @@ def _swallowMethodCompiler(self, methodCompiler):
self._finishedMethodsList.append(methodCompiler)
return methodCompiler

def startMethodDef(self, methodName, argsList, parserComment):
methodCompiler = self._spawnMethodCompiler(
methodName, parserComment,
)
for argName, defVal in argsList:
methodCompiler.addMethArg(argName, defVal)
startMethodDef = _spawnMethodCompiler

def addDecorator(self, decorator_expr):
"""Set the decorator to be used with the next method in the source.
Expand All @@ -376,13 +363,10 @@ def addDecorator(self, decorator_expr):
def addAttribute(self, attr_expr):
self._attrs.append(attr_expr)

def addSuper(self, argsList):
def addSuper(self, argspec):
methodName = self._getActiveMethodCompiler().methodName()
arg_text = arg_string_list_to_text(argsList)
self.addFilteredChunk(
'super({}, self).{}({})'.format(
CLASS_NAME, methodName, arg_text,
)
'super({}, self).{}({})'.format(CLASS_NAME, methodName, argspec),
)

def closeDef(self):
Expand Down

0 comments on commit a72c72d

Please sign in to comment.