Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
first cut at extensible view predicates via config.add_view_predicate…
…; still requires testing of predicates themselves
  • Loading branch information
mcdonc committed Aug 3, 2012
1 parent fc3f23c commit a00621e
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 44 deletions.
2 changes: 2 additions & 0 deletions pyramid/config/__init__.py
Expand Up @@ -353,6 +353,8 @@ def setup_registry(self,
for name, renderer in DEFAULT_RENDERERS: for name, renderer in DEFAULT_RENDERERS:
self.add_renderer(name, renderer) self.add_renderer(name, renderer)


self.add_default_view_predicates()

if exceptionresponse_view is not None: if exceptionresponse_view is not None:
exceptionresponse_view = self.maybe_dotted(exceptionresponse_view) exceptionresponse_view = self.maybe_dotted(exceptionresponse_view)
self.add_view(exceptionresponse_view, context=IExceptionResponse) self.add_view(exceptionresponse_view, context=IExceptionResponse)
Expand Down
218 changes: 218 additions & 0 deletions pyramid/config/predicates.py
@@ -0,0 +1,218 @@
import re

from pyramid.compat import is_nonstr_iter

from pyramid.exceptions import ConfigurationError

from pyramid.traversal import (
find_interface,
traversal_path,
)

from pyramid.urldispatch import _compile_route

from .util import as_sorted_tuple

class XHRPredicate(object):
def __init__(self, val):
self.val = bool(val)

def __text__(self):
return 'xhr = True'

def __phash__(self):
return 'xhr:%r' % (self.val,)

def __call__(self, context, request):
return request.is_xhr


class RequestMethodPredicate(object):
def __init__(self, val):
self.val = as_sorted_tuple(val)

def __text__(self):
return 'request method = %r' % (self.val,)

def __phash__(self):
L = []
for v in self.val:
L.append('request_method:%r' % v)
return L

def __call__(self, context, request):
return request.method in self.val

class PathInfoPredicate(object):
def __init__(self, val):
self.orig = val
try:
val = re.compile(val)
except re.error as why:
raise ConfigurationError(why.args[0])
self.val = val

def __text__(self):
return 'path_info = %s' % (self.orig,)

def __phash__(self):
return 'path_info:%r' % (self.orig,)

def __call__(self, context, request):
return self.val.match(request.upath_info) is not None

class RequestParamPredicate(object):
def __init__(self, val):
name = val
v = None
if '=' in name:
name, v = name.split('=', 1)
if v is None:
self.text = 'request_param %s' % (name,)
else:
self.text = 'request_param %s = %s' % (name, v)
self.name = name
self.val = v

def __text__(self):
return self.text

def __phash__(self):
return 'request_param:%r=%r' % (self.name, self.val)

def __call__(self, context, request):
if self.val is None:
return self.name in request.params
return request.params.get(self.name) == self.val


class HeaderPredicate(object):
def __init__(self, val):
name = val
v = None
if ':' in name:
name, v = name.split(':', 1)
try:
v = re.compile(v)
except re.error as why:
raise ConfigurationError(why.args[0])
if v is None:
self.text = 'header %s' % (name,)
else:
self.text = 'header %s = %s' % (name, v)
self.name = name
self.val = v

def __text__(self):
return self.text

def __phash__(self):
return 'header:%r=%r' % (self.name, self.val)

def __call__(self, context, request):
if self.val is None:
return self.name in request.headers
val = request.headers.get(self.name)
if val is None:
return False
return self.val.match(val) is not None

class AcceptPredicate(object):
def __init__(self, val):
self.val = val

def __text__(self):
return 'accept = %s' % (self.val,)

def __phash__(self):
return 'accept:%r' % (self.val,)

def __call__(self, context, request):
return self.val in request.accept

class ContainmentPredicate(object):
def __init__(self, val):
self.val = val

def __text__(self):
return 'containment = %s' % (self.val,)

def __phash__(self):
return 'containment:%r' % hash(self.val)

def __call__(self, context, request):
ctx = getattr(request, 'context', context)
return find_interface(ctx, self.val) is not None

class RequestTypePredicate(object):
def __init__(self, val):
self.val = val

def __text__(self):
return 'request_type = %s' % (self.val,)

def __phash__(self):
return 'request_type:%r' % hash(self.val)

def __call__(self, context, request):
return self.val.providedBy(request)

class MatchParamPredicate(object):
def __init__(self, val):
if not is_nonstr_iter(val):
val = (val,)
val = sorted(val)
self.val = val
self.reqs = [
(x.strip(), y.strip()) for x, y in [ p.split('=', 1) for p in val ]
]

def __text__(self):
return 'match_param %s' % (self.val,)

def __phash__(self):
L = []
for k, v in self.reqs:
L.append('match_param:%r=%r' % (k, v))
return L

def __call__(self, context, request):
for k, v in self.reqs:
if request.matchdict.get(k) != v:
return False
return True

class CustomPredicate(object):
def __init__(self, func):
self.func = func

def __text__(self):
return getattr(self.func, '__text__', repr(self.func))

def __phash__(self):
return 'custom:%r' % hash(self.func)

def __call__(self, context, request):
return self.func(context, request)


class TraversePredicate(object):
def __init__(self, val):
_, self.tgenerate = _compile_route(val)
self.val = val

def __text__(self):
return 'traverse matchdict pseudo-predicate'

def __phash__(self):
return ''

def __call__(self, context, request):
if 'traverse' in context:
return True
m = context['match']
tvalue = self.tgenerate(m)
m['traverse'] = traversal_path(tvalue)
return True


56 changes: 55 additions & 1 deletion pyramid/config/util.py
Expand Up @@ -324,14 +324,26 @@ def __init__(
self.names = [] self.names = []
self.req_before = set() self.req_before = set()
self.req_after = set() self.req_after = set()
self.name2before = {}
self.name2after = {}
self.name2val = {} self.name2val = {}
self.order = [] self.order = []
self.default_before = default_before self.default_before = default_before
self.default_after = default_after self.default_after = default_after
self.first = first self.first = first
self.last = last self.last = last


def remove(self, name):
if name in self.names:
self.names.remove(name)
del self.name2val[name]
for u in self.name2after.get(name, []):
self.order.remove((u, name))
for u in self.name2before.get(name, []):
self.order.remove((name, u))

def add(self, name, val, after=None, before=None): def add(self, name, val, after=None, before=None):
self.remove(name)
self.names.append(name) self.names.append(name)
self.name2val[name] = val self.name2val[name] = val
if after is None and before is None: if after is None and before is None:
Expand All @@ -340,11 +352,13 @@ def add(self, name, val, after=None, before=None):
if after is not None: if after is not None:
if not is_nonstr_iter(after): if not is_nonstr_iter(after):
after = (after,) after = (after,)
self.name2after[name] = after
self.order += [(u, name) for u in after] self.order += [(u, name) for u in after]
self.req_after.add(name) self.req_after.add(name)
if before is not None: if before is not None:
if not is_nonstr_iter(before): if not is_nonstr_iter(before):
before = (before,) before = (before,)
self.name2before[name] = before
self.order += [(name, o) for o in before] self.order += [(name, o) for o in before]
self.req_before.add(name) self.req_before.add(name)


Expand Down Expand Up @@ -432,3 +446,43 @@ def __str__(self):
L.append('%r sorts before %r' % (dependent, dependees)) L.append('%r sorts before %r' % (dependent, dependees))
msg = 'Implicit ordering cycle:' + '; '.join(L) msg = 'Implicit ordering cycle:' + '; '.join(L)
return msg return msg

class PredicateList(object):
def __init__(self):
self.sorter = TopologicalSorter()

def add(self, name, factory, weighs_more_than=None, weighs_less_than=None):
self.sorter.add(name, factory, after=weighs_more_than,
before=weighs_less_than)

def make(self, **kw):
ordered = self.sorter.sorted()
phash = md5()
weights = []
predicates = []
for order, (name, predicate_factory) in enumerate(ordered):
vals = kw.pop(name, None)
if vals is None:
continue
if not isinstance(vals, SequenceOfPredicateValues):
vals = (vals,)
for val in vals:
predicate = predicate_factory(val)
hashes = predicate.__phash__()
if not is_nonstr_iter(hashes):
hashes = [hashes]
for h in hashes:
phash.update(bytes_(h))
predicate = predicate_factory(val)
weights.append(1 << order)
predicates.append(predicate)
if kw:
raise ConfigurationError('Unknown predicate values: %r' % (kw,))
score = 0
for bit in weights:
score = score | bit
order = (MAX_ORDER - score) / (len(predicates) + 1)
return order, predicates, phash.hexdigest()

class SequenceOfPredicateValues(tuple):
pass

0 comments on commit a00621e

Please sign in to comment.