Skip to content

Commit

Permalink
Safe variable substitution in DB queries.
Browse files Browse the repository at this point in the history
The variable substitution in db queries uses eval, which is not very
secure. Implemented a custom parser and evaluator to avoid eval.
  • Loading branch information
anandology committed May 10, 2014
1 parent 73f1119 commit 331f13c
Showing 1 changed file with 193 additions and 4 deletions.
197 changes: 193 additions & 4 deletions web/db.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
except ImportError: except ImportError:
datetime = None datetime = None


try:
import ast
except ImportError:
ast = None

try:
from tokenize import tokenprog
except ImportError:
tokenprog = None

try: set try: set
except NameError: except NameError:
from sets import Set as set from sets import Set as set
Expand Down Expand Up @@ -91,6 +101,9 @@ def __radd__(self, other):


def __str__(self): def __str__(self):
return str(self.value) return str(self.value)

def __eq__(self, other):
return isinstance(other, SQLParam) and other.value == self.value


def __repr__(self): def __repr__(self):
return '<param: %s>' % repr(self.value) return '<param: %s>' % repr(self.value)
Expand Down Expand Up @@ -153,9 +166,10 @@ def __add__(self, other):
def __radd__(self, other): def __radd__(self, other):
if isinstance(other, basestring): if isinstance(other, basestring):
items = [other] items = [other]
elif isinstance(other, SQLQuery):
items = other.items
else: else:
return NotImplemented return NotImplemented

return SQLQuery(items + self.items) return SQLQuery(items + self.items)


def __iadd__(self, other): def __iadd__(self, other):
Expand All @@ -169,6 +183,9 @@ def __iadd__(self, other):


def __len__(self): def __len__(self):
return len(self.query()) return len(self.query())

def __eq__(self, other):
return isinstance(other, SQLQuery) and other.items == self.items


def query(self, paramstyle=None): def query(self, paramstyle=None):
""" """
Expand Down Expand Up @@ -226,10 +243,12 @@ def join(items, sep=' ', prefix=None, suffix=None, target=None):
target_items.append(prefix) target_items.append(prefix)


for i, item in enumerate(items): for i, item in enumerate(items):
if i != 0: if i != 0 and sep != "":
target_items.append(sep) target_items.append(sep)
if isinstance(item, SQLQuery): if isinstance(item, SQLQuery):
target_items.extend(item.items) target_items.extend(item.items)
elif item == "": # joins with empty strings
continue
else: else:
target_items.append(item) target_items.append(item)


Expand Down Expand Up @@ -267,7 +286,7 @@ def __init__(self, v):
self.v = v self.v = v


def __repr__(self): def __repr__(self):
return self.v return "<literal: %r>" % self.v


sqlliteral = SQLLiteral sqlliteral = SQLLiteral


Expand Down Expand Up @@ -295,10 +314,11 @@ def reparam(string_, dictionary):
>>> reparam("s IN $s", dict(s=[1, 2])) >>> reparam("s IN $s", dict(s=[1, 2]))
<sql: 's IN (1, 2)'> <sql: 's IN (1, 2)'>
""" """
return SafeEval().safeeval(string_, dictionary)

dictionary = dictionary.copy() # eval mucks with it dictionary = dictionary.copy() # eval mucks with it
# disable builtins to avoid risk for remote code exection. # disable builtins to avoid risk for remote code exection.
dictionary['__builtins__'] = object() dictionary['__builtins__'] = object()
vals = []
result = [] result = []
for live, chunk in _interpolate(string_): for live, chunk in _interpolate(string_):
if live: if live:
Expand Down Expand Up @@ -1276,6 +1296,175 @@ def matchorfail(text, pos):
chunks.append((0, format[pos:])) chunks.append((0, format[pos:]))
return chunks return chunks


class _Node(object):
def __init__(self, type, first, second=None):
self.type = type
self.first = first
self.second = second

def __eq__(self, other):
return (isinstance(other, _Node)
and self.type == other.type
and self.first == other.first
and self.second == other.second)

def __repr__(self):
return "Node(%r, %r, %r)" % (self.type, self.first, self.second)

class Parser(object):
    """Parser to parse string templates like "Hello $name".

    Loosely based on <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)

    ``parse`` yields ``_Node`` objects: "text" nodes for literal chunks and
    expression nodes ("param", "literal", "getattr", "getitem") for each
    ``$expr`` / ``${expr}`` occurrence.
    """
    namechars = "abcdefghijklmnopqrstuvwxyz" \
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget any previous input and rewind to position 0."""
        self.pos = 0
        self.level = 0
        self.text = ""

    def parse(self, text):
        """Parses the given text and returns a parse tree.

        The result is a generator of ``_Node`` objects; nodes are produced
        lazily while the generator is consumed.
        """
        self.reset()
        self.text = text
        return self.parse_all()

    def _at(self, char):
        """Return True if the character at the current position is `char`.

        Returns False (instead of raising IndexError) when the current
        position is at or past the end of the text.
        """
        return self.pos < len(self.text) and self.text[self.pos] == char

    def parse_all(self):
        """Yield the nodes for ``self.text`` starting at ``self.pos``."""
        while True:
            dollar = self.text.find("$", self.pos)
            if dollar < 0:
                break
            if dollar + 1 >= len(self.text):
                # A trailing "$" has nothing to substitute; leave it to be
                # emitted as part of the final text node below.
                break
            nextchar = self.text[dollar + 1]
            if nextchar in self.namechars:
                yield _Node("text", self.text[self.pos:dollar])
                self.pos = dollar + 1
                yield self.parse_expr()

            # for supporting ${x.id}, for backward compatibility
            elif nextchar == '{':
                saved_pos = self.pos
                self.pos = dollar + 2  # skip "${"
                expr = self.parse_expr()
                if self._at('}'):
                    self.pos += 1
                    # The text preceding "${" starts at the position saved
                    # before the expression was consumed, not at self.pos.
                    yield _Node("text", self.text[saved_pos:dollar])
                    yield expr
                else:
                    # Unterminated "${": rewind and treat the rest as text.
                    self.pos = saved_pos
                    break
            else:
                yield _Node("text", self.text[self.pos:dollar + 1])
                self.pos = dollar + 1
                # $$ is used to escape $
                if nextchar == "$":
                    self.pos += 1

        if self.pos < len(self.text):
            yield _Node("text", self.text[self.pos:])

    def match(self):
        """Match one token at the current position using ``tokenprog``.

        Returns a ``(match, end_position)`` pair and raises ``_ItplError``
        when no token matches there.
        """
        match = tokenprog.match(self.text, self.pos)
        if match is None:
            raise _ItplError(self.text, self.pos)
        return match, match.end()

    def is_literal(self, text):
        """Return True if `text` starts like a number or a quoted string."""
        return bool(text) and text[0] in "0123456789\"'"

    def parse_expr(self):
        """Parse one expression at ``self.pos`` and return its node.

        An expression is a name or literal followed by any number of
        ``.attr`` and ``[key]`` suffixes.  ``self.pos`` is left just past
        the last character that was consumed.
        """
        match, pos = self.match()
        if self.is_literal(match.group()):
            expr = _Node("literal", match.group())
        else:
            expr = _Node("param", self.text[self.pos:pos])
        self.pos = pos
        while self.pos < len(self.text):
            if self._at(".") and \
                    self.pos + 1 < len(self.text) and \
                    self.text[self.pos + 1] in self.namechars:
                self.pos += 1
                match, pos = self.match()
                attr = match.group()
                expr = _Node("getattr", expr, attr)
                self.pos = pos
            elif self._at("["):
                saved_pos = self.pos
                self.pos += 1
                key = self.parse_expr()
                if self._at(']'):
                    self.pos += 1
                    expr = _Node("getitem", expr, key)
                else:
                    # No closing "]" (possibly end of text): rewind so the
                    # "[" is not consumed as part of this expression.
                    self.pos = saved_pos
                    break
            else:
                break
        return expr

class SafeEval(object):
    """Safe evaluator for binding params to db queries.

    Walks the parse tree produced by ``Parser`` and resolves every
    expression against a mapping using only dictionary lookup, attribute
    access, item access and ``ast.literal_eval`` -- never ``eval``.
    """

    def safeeval(self, text, mapping):
        """Parse `text` and return the joined query.

        Text chunks are kept as-is; every ``$expr`` is evaluated against
        `mapping` and passed through ``sqlquote``.
        """
        nodes = Parser().parse(text)
        return SQLQuery.join([self.eval_node(node, mapping) for node in nodes], "")

    def eval_node(self, node, mapping):
        """Return the text of a "text" node, or the quoted value of an expression node."""
        if node.type == "text":
            return node.first
        return sqlquote(self.eval_expr(node, mapping))

    def eval_expr(self, node, mapping):
        """Evaluate an expression node against `mapping` and return its value.

        Raises ValueError for a node type that is not an expression, rather
        than silently returning None (which would then be quoted into the
        query by ``eval_node``).
        """
        if node.type == "literal":
            return ast.literal_eval(node.first)
        elif node.type == "getattr":
            return getattr(self.eval_expr(node.first, mapping), node.second)
        elif node.type == "getitem":
            return self.eval_expr(node.first, mapping)[self.eval_expr(node.second, mapping)]
        elif node.type == "param":
            return mapping[node.first]
        raise ValueError("unknown node type: %r" % (node.type,))

def test_parser():
def f(text, expected):
p = Parser()
nodes = list(p.parse(text))
print repr(text), nodes
assert nodes == expected, "Expected %r" % expected

f("Hello", [_Node("text", "Hello")])
f("Hello $name", [_Node("text", "Hello "), _Node("param", "name")])
f("Hello $name.foo", [
_Node("text", "Hello "),
_Node("getattr",
_Node("param", "name"),
"foo")])
f("WHERE id=$self.id LIMIT 1", [
_Node("text", "WHERE id="),
_Node('getattr',
_Node('param', 'self', None),
'id'),
_Node("text", " LIMIT 1")])

f("WHERE id=$self['id'] LIMIT 1", [
_Node("text", "WHERE id="),
_Node('getitem',
_Node('param', 'self', None),
_Node('literal', "'id'")),
_Node("text", " LIMIT 1")])

def test_safeeval():
def f(q, vars):
return SafeEval().safeeval(q, vars)

print f("WHERE id=$id", {"id": 1}).items
assert f("WHERE id=$id", {"id": 1}).items == ["WHERE id=", sqlparam(1)]

if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
doctest.testmod() doctest.testmod()
test_parser()
test_safeeval()

0 comments on commit 331f13c

Please sign in to comment.