Skip to content

Commit

Permalink
Merge pull request #254 from vmuriart/tests_str-format
Browse files Browse the repository at this point in the history
Add various tests and change to new style str-format
  • Loading branch information
vmuriart committed Jun 11, 2016
2 parents 00304af + 1fd3da4 commit 751933d
Show file tree
Hide file tree
Showing 19 changed files with 188 additions and 121 deletions.
8 changes: 4 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
master_doc = 'index'

# General information about the project.
project = u'python-sqlparse'
copyright = u'%s, Andi Albrecht' % datetime.date.today().strftime('%Y')
project = 'python-sqlparse'
copyright = '{:%Y}, Andi Albrecht'.format(datetime.date.today())

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
Expand Down Expand Up @@ -177,8 +177,8 @@
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
('index', 'python-sqlparse.tex', ur'python-sqlparse Documentation',
ur'Andi Albrecht', 'manual'),
('index', 'python-sqlparse.tex', 'python-sqlparse Documentation',
'Andi Albrecht', 'manual'),
]

# The name of an image file (relative to this directory) to place at the top of
Expand Down
4 changes: 2 additions & 2 deletions examples/column_defs_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ def extract_definitions(token_list):
columns = extract_definitions(par)

for column in columns:
print('NAME: %-12s DEFINITION: %s' % (column[0],
''.join(str(t) for t in column[1:])))
print('NAME: {name:10} DEFINITION: {definition}'.format(
name=column[0], definition=''.join(str(t) for t in column[1:])))
15 changes: 8 additions & 7 deletions examples/extract_table_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
# See:
# http://groups.google.com/group/sqlparse/browse_thread/thread/b0bd9a022e9d4895

sql = """
select K.a,K.b from (select H.b from (select G.c from (select F.d from
(select E.e from A, B, C, D, E), F), G), H), I, J, K order by 1,2;
"""

import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML
Expand Down Expand Up @@ -59,10 +54,16 @@ def extract_table_identifiers(token_stream):
yield item.value


def extract_tables():
def extract_tables(sql):
stream = extract_from_part(sqlparse.parse(sql)[0])
return list(extract_table_identifiers(stream))


if __name__ == '__main__':
print('Tables: %s' % ', '.join(extract_tables()))
sql = """
select K.a,K.b from (select H.b from (select G.c from (select F.d from
(select E.e from A, B, C, D, E), F), G), H), I, J, K order by 1,2;
"""

tables = ', '.join(extract_tables(sql))
print('Tables: {0}'.format(tables))
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
[wheel]
universal = 1

[pytest]
xfail_strict=true

[flake8]
exclude =
sqlparse/compat.py
Expand Down
7 changes: 3 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
def get_version():
"""Parse __init__.py for version number instead of importing the file."""
VERSIONFILE = 'sqlparse/__init__.py'
verstrline = open(VERSIONFILE, "rt").read()
VSRE = r'^__version__ = [\'"]([^\'"]*)[\'"]'
with open(VERSIONFILE) as f:
verstrline = f.read()
mo = re.search(VSRE, verstrline, re.M)
if mo:
return mo.group(1)
else:
raise RuntimeError('Unable to find version string in %s.'
% (VERSIONFILE,))
raise RuntimeError('Unable to find version in {fn}'.format(fn=VERSIONFILE))


LONG_DESCRIPTION = """
Expand Down
7 changes: 4 additions & 3 deletions sqlparse/engine/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ def _group_left_right(tlist, m, cls,

def _group_matching(tlist, cls):
"""Groups Tokens that have beginning and end."""
idx = 1 if imt(tlist, i=cls) else 0
idx = 1 if isinstance(tlist, cls) else 0

token = tlist.token_next_by(m=cls.M_OPEN, idx=idx)
while token:
end = find_matching(tlist, token, cls.M_OPEN, cls.M_CLOSE)
if end is not None:
token = tlist.group_tokens(cls, tlist.tokens_between(token, end))
tokens = tlist.tokens_between(token, end)
token = tlist.group_tokens(cls, tokens)
_group_matching(token, cls)
token = tlist.token_next_by(m=cls.M_OPEN, idx=token)

Expand Down Expand Up @@ -120,7 +121,7 @@ def group_period(tlist):
def group_arrays(tlist):
token = tlist.token_next_by(i=sql.SquareBrackets)
while token:
prev = tlist.token_prev(idx=token)
prev = tlist.token_prev(token)
if imt(prev, i=(sql.SquareBrackets, sql.Identifier, sql.Function),
t=(T.Name, T.String.Symbol,)):
tokens = tlist.tokens_between(prev, token)
Expand Down
2 changes: 1 addition & 1 deletion sqlparse/filters/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _process(self, stream, varname, has_nl):
def process(self, stmt):
self.count += 1
if self.count > 1:
varname = '%s%d' % (self.varname, self.count)
varname = '{f.varname}{f.count}'.format(f=self)
else:
varname = self.varname

Expand Down
8 changes: 4 additions & 4 deletions sqlparse/filters/reindent.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def _process_where(self, tlist):
self._process_default(tlist)

def _process_parenthesis(self, tlist):
is_DML_DLL = tlist.token_next_by(t=(T.Keyword.DML, T.Keyword.DDL))
is_dml_dll = tlist.token_next_by(t=(T.Keyword.DML, T.Keyword.DDL))
first = tlist.token_next_by(m=sql.Parenthesis.M_OPEN)

with indent(self, 1 if is_DML_DLL else 0):
tlist.tokens.insert(0, self.nl()) if is_DML_DLL else None
with indent(self, 1 if is_dml_dll else 0):
tlist.tokens.insert(0, self.nl()) if is_dml_dll else None
with offset(self, self._get_offset(first) + 1):
self._process_default(tlist, not is_DML_DLL)
self._process_default(tlist, not is_dml_dll)

def _process_identifierlist(self, tlist):
identifiers = list(tlist.get_identifiers())
Expand Down
2 changes: 1 addition & 1 deletion sqlparse/filters/right_margin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _process(self, group, stream):
indent = match.group()
else:
indent = ''
yield sql.Token(T.Whitespace, '\n%s' % indent)
yield sql.Token(T.Whitespace, '\n{0}'.format(indent))
self.line = indent
self.line += val
yield token
Expand Down
40 changes: 22 additions & 18 deletions sqlparse/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,61 +15,65 @@ def validate_options(options):
"""Validates options."""
kwcase = options.get('keyword_case')
if kwcase not in [None, 'upper', 'lower', 'capitalize']:
raise SQLParseError('Invalid value for keyword_case: %r' % kwcase)
raise SQLParseError('Invalid value for keyword_case: '
'{0!r}'.format(kwcase))

idcase = options.get('identifier_case')
if idcase not in [None, 'upper', 'lower', 'capitalize']:
raise SQLParseError('Invalid value for identifier_case: %r' % idcase)
raise SQLParseError('Invalid value for identifier_case: '
'{0!r}'.format(idcase))

ofrmt = options.get('output_format')
if ofrmt not in [None, 'sql', 'python', 'php']:
raise SQLParseError('Unknown output format: %r' % ofrmt)
raise SQLParseError('Unknown output format: '
'{0!r}'.format(ofrmt))

strip_comments = options.get('strip_comments', False)
if strip_comments not in [True, False]:
raise SQLParseError('Invalid value for strip_comments: %r'
% strip_comments)
raise SQLParseError('Invalid value for strip_comments: '
'{0!r}'.format(strip_comments))

space_around_operators = options.get('use_space_around_operators', False)
if space_around_operators not in [True, False]:
raise SQLParseError('Invalid value for use_space_around_operators: %r'
% space_around_operators)
raise SQLParseError('Invalid value for use_space_around_operators: '
'{0!r}'.format(space_around_operators))

strip_ws = options.get('strip_whitespace', False)
if strip_ws not in [True, False]:
raise SQLParseError('Invalid value for strip_whitespace: %r'
% strip_ws)
raise SQLParseError('Invalid value for strip_whitespace: '
'{0!r}'.format(strip_ws))

truncate_strings = options.get('truncate_strings')
if truncate_strings is not None:
try:
truncate_strings = int(truncate_strings)
except (ValueError, TypeError):
raise SQLParseError('Invalid value for truncate_strings: %r'
% truncate_strings)
raise SQLParseError('Invalid value for truncate_strings: '
'{0!r}'.format(truncate_strings))
if truncate_strings <= 1:
raise SQLParseError('Invalid value for truncate_strings: %r'
% truncate_strings)
raise SQLParseError('Invalid value for truncate_strings: '
'{0!r}'.format(truncate_strings))
options['truncate_strings'] = truncate_strings
options['truncate_char'] = options.get('truncate_char', '[...]')

reindent = options.get('reindent', False)
if reindent not in [True, False]:
raise SQLParseError('Invalid value for reindent: %r'
% reindent)
raise SQLParseError('Invalid value for reindent: '
'{0!r}'.format(reindent))
elif reindent:
options['strip_whitespace'] = True

reindent_aligned = options.get('reindent_aligned', False)
if reindent_aligned not in [True, False]:
raise SQLParseError('Invalid value for reindent_aligned: %r'
% reindent)
raise SQLParseError('Invalid value for reindent_aligned: '
'{0!r}'.format(reindent))
elif reindent_aligned:
options['strip_whitespace'] = True

indent_tabs = options.get('indent_tabs', False)
if indent_tabs not in [True, False]:
raise SQLParseError('Invalid value for indent_tabs: %r' % indent_tabs)
raise SQLParseError('Invalid value for indent_tabs: '
'{0!r}'.format(indent_tabs))
elif indent_tabs:
options['indent_char'] = '\t'
else:
Expand Down
15 changes: 4 additions & 11 deletions sqlparse/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from sqlparse import tokens
from sqlparse.keywords import SQL_REGEX
from sqlparse.compat import StringIO, string_types, text_type
from sqlparse.compat import StringIO, string_types, u
from sqlparse.utils import consume


Expand All @@ -37,17 +37,10 @@ def get_tokens(text, encoding=None):
``stack`` is the initial stack (default: ``['root']``)
"""
encoding = encoding or 'utf-8'

if isinstance(text, string_types):
text = StringIO(text)

text = text.read()
if not isinstance(text, text_type):
try:
text = text.decode(encoding)
except UnicodeDecodeError:
text = text.decode('unicode-escape')
text = u(text, encoding)
elif isinstance(text, StringIO):
text = u(text.read(), encoding)

iterable = enumerate(text)
for pos, char in iterable:
Expand Down
48 changes: 15 additions & 33 deletions sqlparse/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_token_at_offset(self, offset):
idx = 0
for token in self.flatten():
end = idx + len(token.value)
if idx <= offset <= end:
if idx <= offset < end:
return token
idx = end

Expand Down Expand Up @@ -248,8 +248,6 @@ def token_prev(self, idx, skip_ws=True, skip_cm=False):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no previous token.
"""
if isinstance(idx, int):
idx += 1  # a lot of existing code pre-compensates for this
funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
(skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx, reverse=True)
Expand All @@ -260,8 +258,6 @@ def token_next(self, idx, skip_ws=True, skip_cm=False):
If *skip_ws* is ``True`` (the default) whitespace tokens are ignored.
``None`` is returned if there's no next token.
"""
if isinstance(idx, int):
idx += 1  # a lot of existing code pre-compensates for this
funcs = lambda tk: not ((skip_ws and tk.is_whitespace()) or
(skip_cm and imt(tk, t=T.Comment)))
return self._token_matching(funcs, idx)
Expand All @@ -283,34 +279,26 @@ def tokens_between(self, start, end, include_end=True):

def group_tokens(self, grp_cls, tokens, skip_ws=False, extend=False):
"""Replace tokens by an instance of *grp_cls*."""
if skip_ws:
while tokens and tokens[-1].is_whitespace():
tokens = tokens[:-1]

while skip_ws and tokens and tokens[-1].is_whitespace():
tokens = tokens[:-1]

left = tokens[0]
idx = self.token_index(left)

if extend:
if not isinstance(left, grp_cls):
grp = grp_cls([left])
self.tokens.remove(left)
self.tokens.insert(idx, grp)
left = grp
left.parent = self
tokens = tokens[1:]
left.tokens.extend(tokens)
left.value = str(left)

if extend and isinstance(left, grp_cls):
grp = left
grp.tokens.extend(tokens[1:])
else:
left = grp_cls(tokens)
left.parent = self
self.tokens.insert(idx, left)
grp = grp_cls(tokens)

for token in tokens:
token.parent = left
token.parent = grp
self.tokens.remove(token)

return left
self.tokens.insert(idx, grp)
grp.parent = self
return grp

def insert_before(self, where, token):
"""Inserts *token* before *where*."""
Expand All @@ -322,7 +310,7 @@ def insert_after(self, where, token, skip_ws=True):
if next_token is None:
self.tokens.append(token)
else:
self.tokens.insert(self.token_index(next_token), token)
self.insert_before(next_token, token)

def has_alias(self):
"""Returns ``True`` if an alias is present."""
Expand Down Expand Up @@ -435,19 +423,13 @@ def is_wildcard(self):
def get_typecast(self):
"""Returns the typecast or ``None`` of this object as a string."""
marker = self.token_next_by(m=(T.Punctuation, '::'))
if marker is None:
return None
next_ = self.token_next(marker, False)
if next_ is None:
return None
return next_.value
return next_.value if next_ else None

def get_ordering(self):
"""Returns the ordering or ``None`` as uppercase string."""
ordering = self.token_next_by(t=T.Keyword.Order)
if ordering is None:
return None
return ordering.normalized
return ordering.normalized if ordering else None

def get_array_indices(self):
"""Returns an iterator of index token lists"""
Expand Down
Loading

0 comments on commit 751933d

Please sign in to comment.