Skip to content

Commit

Permalink
Simplify extends logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Oct 15, 2014
1 parent 97d7ba6 commit 6882df1
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions Cheetah/legacy_compiler.py
Expand Up @@ -782,41 +782,43 @@ def genNameMapperVar(self, nameChunks):

return pythonCode

def setBaseClass(self, baseClassName):
def setBaseClass(self, extends_name):
self.setMainMethodName(self.setting('mainMethodNameForSubclasses'))

if baseClassName in self.importedVarNames():
if extends_name in self.importedVarNames():
raise AssertionError(
'yelp_cheetah only supports extends by module name'
)

# If the #extends directive contains a classname or modulename that isn't
# in self.importedVarNames() already, we assume that we need to add
# an implied 'from ModName import ClassName' where ModName == ClassName.
# - We also assume that the final . separates the classname from the
# module name. This might break if people do something really fancy
# with their dots and namespaces.
chunks = baseClassName.split('.')
# The #extends directive results in the base class being imported
# There are (basically) three cases:
# 1. #extends foo
# import added: from foo import foo
# baseclass: foo
# 2. #extends foo.bar
# import added: from foo.bar import bar
# baseclass: bar
# 3. #extends foo.bar.bar
# import added: from foo.bar import bar
# baseclass: bar
chunks = extends_name.split('.')
# Case 1
# If we only have one part, assume it's like from {chunk} import {chunk}
if len(chunks) == 1:
self._getActiveClassCompiler().setBaseClass(baseClassName)
modName = baseClassName
chunks *= 2

class_name = chunks[-1]
if class_name != chunks[-2]:
# Case 2
# we assume the class name to be the module name
# and that it's not a builtin:
importStatement = 'from {0} import {1}'.format(
modName, baseClassName
)
self.addImportStatement(importStatement)
self.addImportedVarNames((baseClassName,))
module = '.'.join(chunks)
else:
modName, finalClassName = '.'.join(chunks[:-1]), chunks[-1]
# if finalClassName != chunks[:-1][-1]:
if finalClassName != chunks[-2]:
# we assume the class name to be the module name
modName = '.'.join(chunks)
self._getActiveClassCompiler().setBaseClass(finalClassName)
importStatement = "from %s import %s" % (modName, finalClassName)
self.addImportStatement(importStatement)
self.addImportedVarNames([finalClassName])
# Case 3
module = '.'.join(chunks[:-1])
self._getActiveClassCompiler().setBaseClass(class_name)
importStatement = 'from {0} import {1}'.format(module, class_name)
self.addImportStatement(importStatement)
self.addImportedVarNames((class_name,))

def setCompilerSettings(self, settingsStr):
self.updateSettingsFromConfigStr(settingsStr)
Expand Down

0 comments on commit 6882df1

Please sign in to comment.