Skip to content

Commit

Permalink
Add some types
Browse files Browse the repository at this point in the history
  • Loading branch information
Josef-Friedrich committed Jun 22, 2022
1 parent cbfd483 commit 4b60af0
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 21 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"restructuredtext.syntaxHighlighting.disabled": true
}
6 changes: 3 additions & 3 deletions test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def setUp(self):
'track': 7,
}

def parseEqual(self, a, b):
def parseEqual(self, a: str, b: str):
self.assertEqual(tmep.parse(a, self.values), b)

# alphanum
Expand Down Expand Up @@ -273,7 +273,7 @@ def setUp(self):
'only_whitespaces': ' \t\n',
}

def parseEqual(self, a, b):
def parseEqual(self, a: str, b: str):
self.assertEqual(tmep.parse(a, self.values), b)

# empty_string
Expand Down Expand Up @@ -315,7 +315,7 @@ def setUp(self):
'only_whitespaces': ' \t\n',
}

def parseEqual(self, a, b):
def parseEqual(self, a: str, b: str):
self.assertEqual(tmep.parse(a, self.values), b)

# empty_string
Expand Down
3 changes: 2 additions & 1 deletion tmep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def __init__(self, values=None):
super(Functions, self).__init__(values)


def parse(template, values=None, additional_functions=None, functions=None):
def parse(template: str, values=None, additional_functions=None,
functions=None):
template_ = Template(template)
if not functions:
functions_ = Functions(values)
Expand Down
40 changes: 25 additions & 15 deletions tmep/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ class Functions:
additional context to the functions -- specifically, the Item being
evaluated.
"""
_prefix = 'tmpl_'
prefix = 'tmpl_'

values: Values

_func_names: typing.Dict[str, typing.Callable[..., str]]
func_names: typing.List[str]

def __init__(self, values: Values = None):
"""Parametrize the functions.
Expand All @@ -56,8 +56,8 @@ def functions(self) -> FunctionCollection:
and the values are Python functions.
"""
out: FunctionCollection = {}
for key in self._func_names:
out[key[len(self._prefix):]] = getattr(self, key)
for key in self.func_names:
out[key[len(self.prefix):]] = getattr(self, key)
return out

def tmpl_alpha(self, text: str) -> str:
Expand Down Expand Up @@ -109,7 +109,7 @@ def tmpl_delchars(text: str, chars: str) -> str:
return text

@staticmethod
def tmpl_deldupchars(text: str, chars=r'-_\.'):
def tmpl_deldupchars(text: str, chars: str = r'-_\.') -> str:
"""
* synopsis: ``%deldupchars{text,chars}``
* description: Search for duplicate characters and replace with only \
Expand All @@ -119,7 +119,8 @@ def tmpl_deldupchars(text: str, chars=r'-_\.'):
return re.sub(r'([' + chars + r'])\1*', r'\1', text)

@staticmethod
def tmpl_first(text: str, count=1, skip=0, sep='; ', join_str='; '):
def tmpl_first(text: str, count: int = 1, skip: int = 0, sep: str = '; ',
join_str: str = '; ') -> str:
"""
* synopsis: ``%first{text}`` or ``%first{text,count,skip}`` or \
``%first{text,count,skip,sep,join}``
Expand All @@ -140,7 +141,7 @@ def tmpl_first(text: str, count=1, skip=0, sep='; ', join_str='; '):
return join_str.join(text.split(sep)[skip:count])

@staticmethod
def tmpl_if(condition, trueval, falseval=''):
def tmpl_if(condition: str, trueval: str, falseval: str = '') -> str:
"""If ``condition`` is nonempty and nonzero, emit ``trueval``;
otherwise, emit ``falseval`` (if provided).
Expand All @@ -151,20 +152,23 @@ def tmpl_if(condition, trueval, falseval=''):
third argument if specified (or nothing if falsetext is left off).
"""
c: typing.Union[str, int]
c = condition
try:
int_condition = _int_arg(condition)
except ValueError:
if condition.lower() == "false":
return falseval
else:
condition = int_condition
c = int_condition

if condition:
if c:
return trueval
else:
return falseval

def tmpl_ifdef(self, field, trueval='', falseval=''):
def tmpl_ifdef(self, field: str, trueval: str = '',
falseval: str = '') -> str:
"""If field exists return trueval or the field (default) otherwise,
emit return falseval (if provided).
Expand All @@ -179,12 +183,13 @@ def tmpl_ifdef(self, field, trueval='', falseval=''):
:param falseval: The string if the condition is false
:return: The string, based on condition
"""
if field in self.values:
if self.values and field in self.values:
return trueval
else:
return falseval

def tmpl_ifdefempty(self, field, trueval='', falseval=''):
def tmpl_ifdefempty(self, field: str, trueval: str = '',
falseval: str = ''):
"""If field exists and is emtpy return trueval
otherwise, emit return falseval (if provided).
Expand All @@ -199,14 +204,17 @@ def tmpl_ifdefempty(self, field, trueval='', falseval=''):
:param falseval: The string if the condition is false
:return: The string, based on condition
"""
if not self.values:
return falseval
if field not in self.values or \
(field in self.values and not self.values[field]) or \
re.search(r'^\s*$', self.values[field]):
return trueval
else:
return falseval

def tmpl_ifdefnotempty(self, field, trueval='', falseval=''):
def tmpl_ifdefnotempty(self, field: str, trueval: str = '',
falseval: str = '') -> str:
"""If field is not emtpy return trueval or the field (default)
otherwise, emit return falseval (if provided).
Expand All @@ -221,6 +229,8 @@ def tmpl_ifdefnotempty(self, field, trueval='', falseval=''):
:param falseval: The string if the condition is false
:return: The string, based on condition
"""
if not self.values:
return trueval
if field not in self.values or \
(field in self.values and not self.values[field]) or \
re.search(r'^\s*$', self.values[field]):
Expand Down Expand Up @@ -372,6 +382,6 @@ def tmpl_upper(text: str) -> str:


# Get the name of tmpl_* functions in the above class.
Functions._func_names = \
Functions.func_names = \
[s for s in dir(Functions)
if s.startswith(Functions._prefix)]
if s.startswith(Functions.prefix)]
4 changes: 2 additions & 2 deletions tmep/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def cached(func):


@cached
def template(fmt):
def template(fmt: str):
return Template(fmt)


Expand All @@ -553,7 +553,7 @@ class Template:
"""A string template, including text, Symbols, and Calls.
"""

def __init__(self, template):
def __init__(self, template: str):
self.expr = _parse(template)
self.original = template
self.compiled = self.translate()
Expand Down

0 comments on commit 4b60af0

Please sign in to comment.