Permalink
Browse files

* decided api response class was unnecessary

* added first permission checkes
* expanded tests, added use of fixtures for faking solr responses
  • Loading branch information...
1 parent 1ff140d commit bff7304786d7c841a48ddfd0106b0d6fb3be6268 @lbjay lbjay committed Nov 16, 2012
View
@@ -80,14 +80,15 @@ def load_user(id):
try:
logger.debug("initializing mongodb")
mongodb.init_app(app) #@UndefinedVariable
- except:
- pass
+ except Exception, e:
+ logger.error("Failed to initialize mongoalchemy session: %s" % e.message)
+ raise
logger.debug("initializing solr connection")
solr.init_app(app) #@UndefinedVariable
logger.debug("initializing pushrod")
- pushrod.init_app(app)
+ pushrod.init_app(app) #@UndefinedVariable
def _configure_error_handlers(app):
"""
@@ -3,18 +3,65 @@
@author: jluker
'''
+import re
+import logging
from flask import g
+from simplejson import loads
from config import config
from .response import SolrResponse
+log = logging.getLogger(__name__)
+
+class SolrRequest(object):
+
+ @staticmethod
+ def parse_query_fields(q):
+ fields = []
+ p = re.compile('(?P<field>[a-z]+):\S')
+ return [match.group('field') for match in p.finditer(q)]
+
+ def __init__(self, q, **kwargs):
+ self.q = q
+ self.params = SolrParams(q=q, **kwargs)
+
+ def set_rows(self, rows):
+ self.rows = rows
+ self.params.rows = rows
+
+ def set_fields(self, fields):
+ self.fields = fields
+ self.params.fl = ','.join(fields)
+
+ def set_sort(self, sort, direction="asc"):
+ self.sort = sort
+ self.sort_direction = direction
+ self.params.sort = "%s %s" % (sort, direction)
+
+ def add_filter(self, field, value):
+ if not hasattr(self, 'filters'):
+ self.filters = {}
+ self.filters.setdefault('field', [])
+ self.filters['field'].append(value)
+ self.params.append('fq', '%s:%s' % (field, value))
+
+ def get_response(self):
+ try:
+ json = g.solr.raw_query(**self.params)
+ except:
+ log.error("Something blew up when querying solr")
+ raise
+
+ data = loads(json)
+ return SolrResponse(data)
+
+
class SolrParams(dict):
def __init__(self, *args, **kwargs):
# set default values
self.update(
config.SOLR_DEFAULT_PARAMS,
- fl=','.join(config.SOLR_DEFAULT_FIELDS_SEARCH),
sort=config.SOLR_DEFAULT_SORT,
rows=config.SOLR_DEFAULT_ROWS,
wt=config.SOLR_DEFAULT_FORMAT
@@ -39,41 +86,4 @@ def update(self, *args, **kwargs):
def append(self, key, val):
self.setdefault(key, [])
self[key].append(val)
-
-class SolrRequest(object):
-
- def __init__(self, q, **kwargs):
- self.q = q
- self.params = SolrParams(q=q, **kwargs)
-
- def get_response(self):
- json = g.solr.raw_query(**self.params)
- return SolrResponse.from_json(json, request=self)
-
- def set_format(self, format):
- self.format = format
- self.params.wt = format
-
- def set_rows(self, rows):
- self.rows = rows
- self.params.rows = rows
-
- def set_fields(self, fields):
- self.fields = fields
- self.params.fl = ','.join(fields)
-
- def set_sort(self, sort, direction="asc"):
- self.sort = sort
- self.sort_direction = direction
- self.params.sort = "%s %s" % (sort, direction)
-
- def add_filter(self, field, value):
- if not hasattr(self, 'filters'):
- self.filters = {}
- self.filters.setdefault('field', [])
- self.filters['field'].append(value)
- self.params.append('fq', '%s:%s' % (field, value))
-
-
-
@@ -9,21 +9,8 @@
class SolrResponse(object):
- @staticmethod
- def from_json(json, request=None):
- data = loads(json)
- docset = data['response']['docs']
- if 'facet_counts' in data['response']:
- facets = SolrFacets.from_dict(data['response']['facet_counts'])
- else:
- facets = None
-
- return SolrResponse(docset, data, facets, request)
-
- def __init__(self, docset=[], data={}, facets=None, request=None):
- self.docset = docset
- self.data = data
- self.facets = facets
+ def __init__(self, raw):
+ self.raw = raw
self.iter_idx = -1
def __iter__(self):
@@ -37,22 +24,42 @@ def next(self):
else:
raise StopIteration
- def get_docs(self):
- return self.docset
+ def search_response(self):
+ resp = {
+ 'meta': { 'errors': None },
+ 'results': {
+ 'count': self.get_count(),
+ 'docs': self.get_docset(),
+ 'facets': self.get_facets(),
+ }
+ }
+ return resp
- def get_doc_objects(self):
- return [SolrDocument(x) for x in self.docset]
+ def record_response(self, idx=0):
+ try:
+ return self.get_docset()[idx]
+ except IndexError:
+ return None
+
+ def get_docset(self):
+ return self.raw['response'].get('docs', [])
+
+ def get_docset_objects(self):
+ return [SolrDocument(x) for x in self.get_docset()]
+ def get_facets(self):
+ return self.raw['response'].get('facet_counts', {})
+
def get_query(self):
- return self.data['responseHeader']['params']['q']
+ return self.raw['responseHeader']['params']['q']
def get_count(self):
- return int(self.data['response']['numFound'])
+ return int(self.raw['response']['numFound'])
def get_qtime(self):
- return self.data['responseHeader']['QTime']
+ return self.raw['responseHeader']['QTime']
def as_json(self):
- return dumps(self.data)
+ return dumps(self.raw)
@@ -5,14 +5,16 @@
'''
from flask.ext.pushrod import pushrod_view #@UnresolvedImport
-from .response import ApiResponse
class ApiNotAuthenticatedError(Exception):
pass
class ApiInvalidRequest(Exception):
pass
+class ApiPermissionError(Exception):
+ pass
+
class ApiRecordNotFound(Exception):
pass
@@ -28,6 +30,12 @@ def not_authenticated(error):
def invalid_request(error):
msg = "API request invalid: %s" % error.message
return {'error': msg},401,None
+
+ @app.errorhandler(ApiPermissionError)
+ @pushrod_view(xml_template="error.xml")
+ def permission_error(error):
+ msg = "Permission error: %s " % error.message
+ return {'error': msg},401,None
@app.errorhandler(ApiRecordNotFound)
@pushrod_view(xml_template="error.xml")
@@ -4,15 +4,27 @@
@author: jluker
'''
-from flask.ext.wtf import Form, fields, validators #@UnresolvedImport
+from flask import g
+from flask.ext.wtf import Form, fields, validators, ValidationError #@UnresolvedImport
+from adsabs.core.solr import SolrRequest
from config import config
+MIN_QUERY_LENGTH = 2
+MAX_QUERY_LENGTH = 2048
+
def api_defaults(*args, **kwargs):
pass
+def validate_query(form, field):
+ if len(field.data) < MIN_QUERY_LENGTH:
+ raise ValidationError("'q' input must be at least %s characters" % MIN_QUERY_LENGTH)
+ if len(field.data) > MAX_QUERY_LENGTH:
+ raise ValidationError("'q' input must be at no more than %s characters" % MIN_QUERY_LENGTH)
+ fields_queried = SolrRequest.parse_query_fields(field.data)
+
class ApiQueryForm(Form):
- q = fields.TextField('query', validators=[validators.required(), validators.length(min=2, max=2048)])
+ q = fields.TextField('query', [validators.Required(), validate_query])
dev_key = fields.TextField('dev_key', default=None)
rows = fields.IntegerField('rows', default=api_defaults)
start = fields.IntegerField('start', default=api_defaults)
@@ -0,0 +1,38 @@
+'''
+Created on Nov 15, 2012
+
+@author: jluker
+'''
+
+from flask.ext.login import current_user #@UnresolvedImport
+
+from .errors import ApiPermissionError
+
+class DevPermissions(object):
+
+ def __init__(self, perms):
+ self.perms = perms
+
+ def check_facets(self):
+ assert self.perms.get('facets', False), 'facets disabled'
+
+ def check_max_rows(self, req_max_rows):
+ max_rows = self.perms.get('max_rows', 0)
+ assert max_rows >= req_max_rows, 'maximum rows allowed is %d' % max_rows
+
+ def check_max_start(self, req_max_start):
+ max_start = self.perms.get('max_start', 0)
+ assert max_start >= req_max_start, 'maximum start allowed is %d' % max_start
+
+ def check_permissions(self, form):
+
+ try:
+ if form.facets.data:
+ self.check_facets(form.facets.data)
+ self.check_max_rows(form.rows.data)
+ self.check_max_start(form.start.data)
+ except AssertionError, e:
+ raise ApiPermissionError(e.message)
+
+ return True
+
@@ -4,17 +4,36 @@
@author: jluker
'''
+from flask import g #@UnresolvedImport
+from adsabs.core.solr import SolrRequest
from .forms import ApiQueryForm
-
+from .permissions import DevPermissions
+from .errors import ApiPermissionError
+
class ApiSearchRequest(object):
- def __init__(self, flask_request):
- self.flask_request = flask_request
- self.form = ApiQueryForm(flask_request.values, csrf_enabled=False)
+ def __init__(self, request_vals, user=None):
+ self.form = ApiQueryForm(request_vals, csrf_enabled=False)
+ if user:
+ perms = user.get_dev_perms()
+ else:
+ perms = g.api_user.get_dev_perms()
+ self.perms = DevPermissions(perms)
def validate(self):
- return self.form.validate()
+ return self.form.validate() and self.perms.check_permissions(self.form)
+ def execute(self):
+ solr_req = SolrRequest(
+ self.form.q.data,
+ facets=self.form.facets.data
+ )
+ self.resp = solr_req.get_response()
+ return self.resp
+
+ def fields_queried(self):
+ return SolrRequest.parse_query_fields(self.form.q.data)
+
def query(self):
return self.form.q.data
@@ -31,4 +50,4 @@ def query(self):
return "identifier:%s" % self.query_id
def validate(self):
- return True
+ return True
Oops, something went wrong. Retry.

0 comments on commit bff7304

Please sign in to comment.