Permalink
Browse files

Safe variable substitution in DB queries.

The variable subsitution in db queries uses eval, which is not very
secure. Implented a custom parser and evaluator to avoid eval.
  • Loading branch information...
anandology committed May 10, 2014
1 parent 73f1119 commit 331f13cb43f338e537da19cef85529c7c7bbd7e9
Showing with 193 additions and 4 deletions.
  1. +193 −4 web/db.py
View
197 web/db.py
@@ -17,6 +17,16 @@
except ImportError:
datetime = None
try:
import ast
except ImportError:
ast = None
try:
from tokenize import tokenprog
except ImportError:
tokenprog = None
try: set
except NameError:
from sets import Set as set
@@ -91,6 +101,9 @@ def __radd__(self, other):
def __str__(self):
return str(self.value)
def __eq__(self, other):
return isinstance(other, SQLParam) and other.value == self.value
def __repr__(self):
return '<param: %s>' % repr(self.value)
@@ -153,9 +166,10 @@ def __add__(self, other):
def __radd__(self, other):
if isinstance(other, basestring):
items = [other]
elif isinstance(other, SQLQuery):
items = other.items
else:
return NotImplemented
return SQLQuery(items + self.items)
def __iadd__(self, other):
@@ -169,6 +183,9 @@ def __iadd__(self, other):
def __len__(self):
return len(self.query())
def __eq__(self, other):
return isinstance(other, SQLQuery) and other.items == self.items
def query(self, paramstyle=None):
"""
@@ -226,10 +243,12 @@ def join(items, sep=' ', prefix=None, suffix=None, target=None):
target_items.append(prefix)
for i, item in enumerate(items):
if i != 0:
if i != 0 and sep != "":
target_items.append(sep)
if isinstance(item, SQLQuery):
target_items.extend(item.items)
elif item == "": # joins with empty strings
continue
else:
target_items.append(item)
@@ -267,7 +286,7 @@ def __init__(self, v):
self.v = v
def __repr__(self):
return self.v
return "<literal: %r>" % self.v
sqlliteral = SQLLiteral
@@ -295,10 +314,11 @@ def reparam(string_, dictionary):
>>> reparam("s IN $s", dict(s=[1, 2]))
<sql: 's IN (1, 2)'>
"""
return SafeEval().safeeval(string_, dictionary)
dictionary = dictionary.copy() # eval mucks with it
# disable builtins to avoid risk for remote code exection.
dictionary['__builtins__'] = object()
vals = []
result = []
for live, chunk in _interpolate(string_):
if live:
@@ -1276,6 +1296,175 @@ def matchorfail(text, pos):
chunks.append((0, format[pos:]))
return chunks
class _Node(object):
def __init__(self, type, first, second=None):
self.type = type
self.first = first
self.second = second
def __eq__(self, other):
return (isinstance(other, _Node)
and self.type == other.type
and self.first == other.first
and self.second == other.second)
def __repr__(self):
return "Node(%r, %r, %r)" % (self.type, self.first, self.second)
class Parser:
"""Parser to parse string templates like "Hello $name".
Loosely based on <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
"""
namechars = "abcdefghijklmnopqrstuvwxyz" \
"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
def __init__(self):
self.reset()
def reset(self):
self.pos = 0
self.level = 0
self.text = ""
def parse(self, text):
"""Parses the given text and returns a parse tree.
"""
self.reset()
self.text = text
return self.parse_all()
def parse_all(self):
while True:
dollar = self.text.find("$", self.pos)
if dollar < 0:
break
nextchar = self.text[dollar + 1]
if nextchar in self.namechars:
yield _Node("text", self.text[self.pos:dollar])
self.pos = dollar+1
yield self.parse_expr()
# for supporting ${x.id}, for backward compataility
elif nextchar == '{':
saved_pos = self.pos
self.pos = dollar+2 # skip "${"
expr = self.parse_expr()
if self.text[self.pos] == '}':
self.pos += 1
yield _Node("text", self.text[self.pos:dollar])
yield expr
else:
self.pos = saved_pos
break
else:
yield _Node("text", self.text[self.pos:dollar+1])
self.pos = dollar + 1
# $$ is used to escape $
if nextchar == "$":
self.pos += 1
if self.pos < len(self.text):
yield _Node("text", self.text[self.pos:])
def match(self):
match = tokenprog.match(self.text, self.pos)
if match is None:
raise _ItplError(self.text, self.pos)
return match, match.end()
def is_literal(self, text):
return text and text[0] in "0123456789\"'"
def parse_expr(self):
match, pos = self.match()
if self.is_literal(match.group()):
expr = _Node("literal", match.group())
else:
expr = _Node("param", self.text[self.pos:pos])
self.pos = pos
while self.pos < len(self.text):
if self.text[self.pos] == "." and \
self.pos + 1 < len(self.text) and self.text[self.pos + 1] in self.namechars:
self.pos += 1
match, pos = self.match()
attr = match.group()
expr = _Node("getattr", expr, attr)
self.pos = pos
elif self.text[self.pos] == "[":
saved_pos = self.pos
self.pos += 1
key = self.parse_expr()
if self.text[self.pos] == ']':
self.pos += 1
expr = _Node("getitem", expr, key)
else:
self.pos = saved_pos
break
else:
break
return expr
class SafeEval(object):
"""Safe evaluator for binding params to db queries.
"""
def safeeval(self, text, mapping):
nodes = Parser().parse(text)
return SQLQuery.join([self.eval_node(node, mapping) for node in nodes], "")
def eval_node(self, node, mapping):
if node.type == "text":
return node.first
else:
return sqlquote(self.eval_expr(node, mapping))
def eval_expr(self, node, mapping):
if node.type == "literal":
return ast.literal_eval(node.first)
elif node.type == "getattr":
return getattr(self.eval_expr(node.first, mapping), node.second)
elif node.type == "getitem":
return self.eval_expr(node.first, mapping)[self.eval_expr(node.second, mapping)]
elif node.type == "param":
return mapping[node.first]
def test_parser():
def f(text, expected):
p = Parser()
nodes = list(p.parse(text))
print repr(text), nodes
assert nodes == expected, "Expected %r" % expected
f("Hello", [_Node("text", "Hello")])
f("Hello $name", [_Node("text", "Hello "), _Node("param", "name")])
f("Hello $name.foo", [
_Node("text", "Hello "),
_Node("getattr",
_Node("param", "name"),
"foo")])
f("WHERE id=$self.id LIMIT 1", [
_Node("text", "WHERE id="),
_Node('getattr',
_Node('param', 'self', None),
'id'),
_Node("text", " LIMIT 1")])
f("WHERE id=$self['id'] LIMIT 1", [
_Node("text", "WHERE id="),
_Node('getitem',
_Node('param', 'self', None),
_Node('literal', "'id'")),
_Node("text", " LIMIT 1")])
def test_safeeval():
def f(q, vars):
return SafeEval().safeeval(q, vars)
print f("WHERE id=$id", {"id": 1}).items
assert f("WHERE id=$id", {"id": 1}).items == ["WHERE id=", sqlparam(1)]
if __name__ == "__main__":
import doctest
doctest.testmod()
test_parser()
test_safeeval()

0 comments on commit 331f13c

Please sign in to comment.