Permalink
Browse files

first cut at extensible view predicates via config.add_view_predicate…

…; still requires testing of predicates themselves
  • Loading branch information...
mcdonc committed Aug 3, 2012
1 parent fc3f23c commit a00621e45ef29cde34469798144156c80a17a1e9
@@ -353,6 +353,8 @@ def setup_registry(self,
for name, renderer in DEFAULT_RENDERERS:
self.add_renderer(name, renderer)
+ self.add_default_view_predicates()
+
if exceptionresponse_view is not None:
exceptionresponse_view = self.maybe_dotted(exceptionresponse_view)
self.add_view(exceptionresponse_view, context=IExceptionResponse)
@@ -0,0 +1,218 @@
+import re
+
+from pyramid.compat import is_nonstr_iter
+
+from pyramid.exceptions import ConfigurationError
+
+from pyramid.traversal import (
+ find_interface,
+ traversal_path,
+ )
+
+from pyramid.urldispatch import _compile_route
+
+from .util import as_sorted_tuple
+
+class XHRPredicate(object):
+ def __init__(self, val):
+ self.val = bool(val)
+
+ def __text__(self):
+ return 'xhr = True'
+
+ def __phash__(self):
+ return 'xhr:%r' % (self.val,)
+
+ def __call__(self, context, request):
+ return request.is_xhr
+
+
+class RequestMethodPredicate(object):
+ def __init__(self, val):
+ self.val = as_sorted_tuple(val)
+
+ def __text__(self):
+ return 'request method = %r' % (self.val,)
+
+ def __phash__(self):
+ L = []
+ for v in self.val:
+ L.append('request_method:%r' % v)
+ return L
+
+ def __call__(self, context, request):
+ return request.method in self.val
+
+class PathInfoPredicate(object):
+ def __init__(self, val):
+ self.orig = val
+ try:
+ val = re.compile(val)
+ except re.error as why:
+ raise ConfigurationError(why.args[0])
+ self.val = val
+
+ def __text__(self):
+ return 'path_info = %s' % (self.orig,)
+
+ def __phash__(self):
+ return 'path_info:%r' % (self.orig,)
+
+ def __call__(self, context, request):
+ return self.val.match(request.upath_info) is not None
+
+class RequestParamPredicate(object):
+ def __init__(self, val):
+ name = val
+ v = None
+ if '=' in name:
+ name, v = name.split('=', 1)
+ if v is None:
+ self.text = 'request_param %s' % (name,)
+ else:
+ self.text = 'request_param %s = %s' % (name, v)
+ self.name = name
+ self.val = v
+
+ def __text__(self):
+ return self.text
+
+ def __phash__(self):
+ return 'request_param:%r=%r' % (self.name, self.val)
+
+ def __call__(self, context, request):
+ if self.val is None:
+ return self.name in request.params
+ return request.params.get(self.name) == self.val
+
+
+class HeaderPredicate(object):
+ def __init__(self, val):
+ name = val
+ v = None
+ if ':' in name:
+ name, v = name.split(':', 1)
+ try:
+ v = re.compile(v)
+ except re.error as why:
+ raise ConfigurationError(why.args[0])
+ if v is None:
+ self.text = 'header %s' % (name,)
+ else:
+ self.text = 'header %s = %s' % (name, v)
+ self.name = name
+ self.val = v
+
+ def __text__(self):
+ return self.text
+
+ def __phash__(self):
+ return 'header:%r=%r' % (self.name, self.val)
+
+ def __call__(self, context, request):
+ if self.val is None:
+ return self.name in request.headers
+ val = request.headers.get(self.name)
+ if val is None:
+ return False
+ return self.val.match(val) is not None
+
+class AcceptPredicate(object):
+ def __init__(self, val):
+ self.val = val
+
+ def __text__(self):
+ return 'accept = %s' % (self.val,)
+
+ def __phash__(self):
+ return 'accept:%r' % (self.val,)
+
+ def __call__(self, context, request):
+ return self.val in request.accept
+
+class ContainmentPredicate(object):
+ def __init__(self, val):
+ self.val = val
+
+ def __text__(self):
+ return 'containment = %s' % (self.val,)
+
+ def __phash__(self):
+ return 'containment:%r' % hash(self.val)
+
+ def __call__(self, context, request):
+ ctx = getattr(request, 'context', context)
+ return find_interface(ctx, self.val) is not None
+
+class RequestTypePredicate(object):
+ def __init__(self, val):
+ self.val = val
+
+ def __text__(self):
+ return 'request_type = %s' % (self.val,)
+
+ def __phash__(self):
+ return 'request_type:%r' % hash(self.val)
+
+ def __call__(self, context, request):
+ return self.val.providedBy(request)
+
+class MatchParamPredicate(object):
+ def __init__(self, val):
+ if not is_nonstr_iter(val):
+ val = (val,)
+ val = sorted(val)
+ self.val = val
+ self.reqs = [
+ (x.strip(), y.strip()) for x, y in [ p.split('=', 1) for p in val ]
+ ]
+
+ def __text__(self):
+ return 'match_param %s' % (self.val,)
+
+ def __phash__(self):
+ L = []
+ for k, v in self.reqs:
+ L.append('match_param:%r=%r' % (k, v))
+ return L
+
+ def __call__(self, context, request):
+ for k, v in self.reqs:
+ if request.matchdict.get(k) != v:
+ return False
+ return True
+
+class CustomPredicate(object):
+ def __init__(self, func):
+ self.func = func
+
+ def __text__(self):
+ return getattr(self.func, '__text__', repr(self.func))
+
+ def __phash__(self):
+ return 'custom:%r' % hash(self.func)
+
+ def __call__(self, context, request):
+ return self.func(context, request)
+
+
+class TraversePredicate(object):
+ def __init__(self, val):
+ _, self.tgenerate = _compile_route(val)
+ self.val = val
+
+ def __text__(self):
+ return 'traverse matchdict pseudo-predicate'
+
+ def __phash__(self):
+ return ''
+
+ def __call__(self, context, request):
+ if 'traverse' in context:
+ return True
+ m = context['match']
+ tvalue = self.tgenerate(m)
+ m['traverse'] = traversal_path(tvalue)
+ return True
+
+
View
@@ -324,14 +324,26 @@ def __init__(
self.names = []
self.req_before = set()
self.req_after = set()
+ self.name2before = {}
+ self.name2after = {}
self.name2val = {}
self.order = []
self.default_before = default_before
self.default_after = default_after
self.first = first
self.last = last
-
+
+ def remove(self, name):
+ if name in self.names:
+ self.names.remove(name)
+ del self.name2val[name]
+ for u in self.name2after.get(name, []):
+ self.order.remove((u, name))
+ for u in self.name2before.get(name, []):
+ self.order.remove((name, u))
+
def add(self, name, val, after=None, before=None):
+ self.remove(name)
self.names.append(name)
self.name2val[name] = val
if after is None and before is None:
@@ -340,11 +352,13 @@ def add(self, name, val, after=None, before=None):
if after is not None:
if not is_nonstr_iter(after):
after = (after,)
+ self.name2after[name] = after
self.order += [(u, name) for u in after]
self.req_after.add(name)
if before is not None:
if not is_nonstr_iter(before):
before = (before,)
+ self.name2before[name] = before
self.order += [(name, o) for o in before]
self.req_before.add(name)
@@ -432,3 +446,43 @@ def __str__(self):
L.append('%r sorts before %r' % (dependent, dependees))
msg = 'Implicit ordering cycle:' + '; '.join(L)
return msg
+
+class PredicateList(object):
+ def __init__(self):
+ self.sorter = TopologicalSorter()
+
+ def add(self, name, factory, weighs_more_than=None, weighs_less_than=None):
+ self.sorter.add(name, factory, after=weighs_more_than,
+ before=weighs_less_than)
+
+ def make(self, **kw):
+ ordered = self.sorter.sorted()
+ phash = md5()
+ weights = []
+ predicates = []
+ for order, (name, predicate_factory) in enumerate(ordered):
+ vals = kw.pop(name, None)
+ if vals is None:
+ continue
+ if not isinstance(vals, SequenceOfPredicateValues):
+ vals = (vals,)
+ for val in vals:
+ predicate = predicate_factory(val)
+ hashes = predicate.__phash__()
+ if not is_nonstr_iter(hashes):
+ hashes = [hashes]
+ for h in hashes:
+ phash.update(bytes_(h))
+ predicate = predicate_factory(val)
+ weights.append(1 << order)
+ predicates.append(predicate)
+ if kw:
+ raise ConfigurationError('Unknown predicate values: %r' % (kw,))
+ score = 0
+ for bit in weights:
+ score = score | bit
+ order = (MAX_ORDER - score) / (len(predicates) + 1)
+ return order, predicates, phash.hexdigest()
+
+class SequenceOfPredicateValues(tuple):
+ pass
Oops, something went wrong.

0 comments on commit a00621e

Please sign in to comment.