Skip to content
This repository
Browse code

finished adding validators to ApiQueryForm class

  • Loading branch information...
commit b80597c02df5d8db5d15556e1b357a70f09e790c 1 parent d74c0fa
Jay Luker authored lbjay committed
118  adsabs/modules/api/forms.py
@@ -4,32 +4,112 @@
4 4
 @author: jluker
5 5
 '''
6 6
 
  7
+import re
  8
+
7 9
 from flask import g
8  
-from flask.ext.wtf import Form, fields, validators, ValidationError #@UnresolvedImport
  10
+from flask.ext.wtf import Form, fields as fields_, validators, ValidationError #@UnresolvedImport
9 11
 
10 12
 from adsabs.core.solr import SolrRequest
  13
+from .renderers import VALID_FORMATS
11 14
 from config import config
12 15
 
13 16
 MIN_QUERY_LENGTH = 2
14 17
 MAX_QUERY_LENGTH = 2048
  18
+SORT_DIRECTIONS = ['asc','desc']
15 19
 
16  
-def api_defaults(*args, **kwargs):
17  
-    pass
18 20
 
19  
-def validate_query(form, field):
20  
-    if len(field.data) < MIN_QUERY_LENGTH:
21  
-        raise ValidationError("'q' input must be at least %s characters" % MIN_QUERY_LENGTH)
22  
-    if len(field.data) > MAX_QUERY_LENGTH:
23  
-        raise ValidationError("'q' input must be at no more than %s characters" % MIN_QUERY_LENGTH)
24  
-    fields_queried = SolrRequest.parse_query_fields(field.data)
25  
-    
26 21
 class ApiQueryForm(Form):
27  
-    q = fields.TextField('query', [validators.Required(), validate_query])
28  
-    dev_key = fields.TextField('dev_key', default=None)
29  
-    rows = fields.IntegerField('rows', default=api_defaults)
30  
-    start = fields.IntegerField('start', default=api_defaults)
31  
-    sort = fields.TextField('sort', default=api_defaults)
32  
-    format = fields.TextField('format', default=api_defaults)
33  
-    facets = fields.BooleanField('facets', default=api_defaults)
34  
-    highlights = fields.BooleanField('highlights', default=api_defaults)
35  
-    
  22
+    q = fields_.TextField('query', [validators.Required()])
  23
+    dev_key = fields_.TextField('dev_key', [validators.Required()])
  24
+    rows = fields_.IntegerField('rows')
  25
+    start = fields_.IntegerField('start')
  26
+    sort = fields_.TextField('sort')
  27
+    flds = fields_.TextField('fields')
  28
+    facet = fields_.FieldList(fields_.TextField('facet'))
  29
+    fmt = fields_.TextField('format')
  30
+    filter = fields_.FieldList(fields_.TextField('filter'))
  31
+    hl = fields_.FieldList(fields_.TextField('hl'))
  32
+    
  33
+    def validate_q(self, field):
  34
+        """
  35
+        just checks for min/max length so far
  36
+        TODO: maybe do closer inspection of the query syntax?
  37
+        """
  38
+        if len(field.data) < MIN_QUERY_LENGTH:
  39
+            raise ValidationError("'q' input must be at least %s characters" % MIN_QUERY_LENGTH)
  40
+        if len(field.data) > MAX_QUERY_LENGTH:
  41
+            raise ValidationError("'q' input must be at no more than %s characters" % MAX_QUERY_LENGTH)
  42
+        
  43
+    def validate_flds(self, field):
  44
+        """
  45
+        checks that input is a comma-separated list of field names
  46
+        """
  47
+        if not len(field.data):
  48
+            return
  49
+        if re.search('[^a-z\.\_]', field.data, re.I):
  50
+            raise ValidationError("Invalid field selection: value must be a comma-separated (no whitespace) list of field names")
  51
+        for field in field.data.split(','):
  52
+            if field not in config.API_SOLR_FIELDS:
  53
+                raise ValidationError("Invalid field selection: %s is not a selectable field" % field)
  54
+    
  55
+    def validate_sort(self, field):
  56
+        """
  57
+        checks that sort input contains both a valid sorting option and direction
  58
+        """
  59
+        if not len(field.data):
  60
+            return
  61
+        try:
  62
+            sort, direction = field.data.split()
  63
+        except:
  64
+            raise ValidationError("Invalid sort option: you must specify a type (%s) and direction (%s)" % \
  65
+                                  (','.join(config.SOLR_SORT_OPTIONS, ','.join(SORT_DIRECTIONS))))
  66
+        if sort not in config.SOLR_SORT_OPTIONS:
  67
+            raise ValidationError("Invalid sort type. Valid options are %s" % ','.join(config.SOLR_SORT_OPTIONS))
  68
+        if direction not in SORT_DIRECTIONS:
  69
+            raise ValidationError("Invalid sort direction. Valid options are %s" % ','.join(SORT_DIRECTIONS))
  70
+    
  71
+    def validate_facet(self, field):
  72
+        for facet in field.data:
  73
+            if not len(facet):
  74
+                continue
  75
+            if re.search('[^a-z\_\:', facet, re.I):
  76
+                raise ValidationError("Invalid facet input: %s. Format is field[:limit[:min]].")
  77
+            facet = facet.split(':')
  78
+            if facet[0] not in config.API_SOLR_FACET_FIELDS:
  79
+                raise ValidationError("Invalid facet selection: %s is not a facetable field" % facet[0])
  80
+            if len(facet) > 1:
  81
+                for opt in facet[1:]:
  82
+                    if re.search('[^\d]', opt):
  83
+                        raise ValidationError("Invalid facet options: %s. Values for limit and min must be integers." % opt)
  84
+            
  85
+    
  86
+    def validate_fmt(self, field):
  87
+        if not len(field.data):
  88
+            return
  89
+        if field.data not in VALID_FORMATS:
  90
+            raise ValidationError("Invalid format: %s. Valid options are %s" % (field.data, ','.join(VALID_FORMATS)))
  91
+    
  92
+    def validate_hl(self, field):
  93
+        for hl in field.data:
  94
+            if not len(hl):
  95
+                continue
  96
+            if re.search('[^a-z\_\:', hl, re.I):
  97
+                raise ValidationError("Invalid highlight input: %s. Format is field[:count].")
  98
+            hl = hl.split(':')
  99
+            if hl[0] not in config.API_SOLR_FIELDS:
  100
+                raise ValidationError("Invalid highlight selection: %s is not a selectable field" % hl[0])
  101
+            if len(hl) > 1:
  102
+                if re.search('[^\d]', hl[1]):
  103
+                    raise ValidationError("Invalid highlight option: %s. Value for count must be integer." % hl[1])
  104
+    
  105
+    def validate_filter(self, field):
  106
+        for filter in field.data:
  107
+            if not len(filter):
  108
+                continue
  109
+            try:
  110
+                field,query = filter.split(':')
  111
+            except ValueError: # too many/few values to split
  112
+                raise ValidationError("Invalid filter: %s. Format should be 'field:value'" % filter)
  113
+            if field not in config.API_SOLR_FIELDS:
  114
+                raise ValidationError("Invalid filter field selection: %s is not a queryable field" % field)
  115
+            
55  adsabs/modules/api/permissions.py
@@ -4,8 +4,10 @@
4 4
 @author: jluker
5 5
 '''
6 6
 
  7
+import re
7 8
 from flask.ext.login import current_user #@UnresolvedImport
8 9
 
  10
+from config import config
9 11
 from .errors import ApiPermissionError
10 12
 
11 13
 class DevPermissions(object):
@@ -13,24 +15,59 @@ class DevPermissions(object):
13 15
     def __init__(self, perms):
14 16
         self.perms = perms
15 17
         
16  
-    def check_facets(self):
  18
+    def _facets_ok(self, req_facets):
17 19
         assert self.perms.get('facets', False), 'facets disabled'
18 20
         
19  
-    def check_max_rows(self, req_max_rows):
  21
+        excluded = self.perms.get('ex_fields', [])
  22
+        facet_limit_max = self.perms.get('facet_limit_max', 0)
  23
+        for facet in req_facets:
  24
+            # facet value format str[:limit[:mincount]], e.g., "author:100:10"
  25
+            facet = facet.strip().split(':')
  26
+            assert facet[0] not in excluded, 'disallowed facet: %s' % facet[0]
  27
+            if len(facet) > 1:
  28
+                assert facet_limit_max >= int(facet[1]), \
  29
+                    'facet limit value %d exceeds max allowed value: %d' % (int(facet[1]), facet_limit_max)
  30
+        
  31
+    def _max_rows_ok(self, req_rows):
20 32
         max_rows = self.perms.get('max_rows', 0)
21  
-        assert max_rows >= req_max_rows, 'maximum rows allowed is %d' % max_rows
  33
+        assert max_rows >= req_rows, 'rows=%s exceeds max allowed value: %d' % (req_rows, max_rows)
22 34
                    
23  
-    def check_max_start(self, req_max_start):
  35
+    def _max_start_ok(self, req_start):
24 36
         max_start = self.perms.get('max_start', 0)
25  
-        assert max_start >= req_max_start, 'maximum start allowed is %d' % max_start
  37
+        assert max_start >= req_start, 'start=%s exceeds max allowed value: %d' % (req_start, max_start)
  38
+    
  39
+    def _fields_ok(self, req_fields):
  40
+        req_fields = set(re.split('[,\s]+', req_fields.strip()))
  41
+        allowed = set(config.API_SOLR_FIELDS)
  42
+        excluded = set(self.perms.get('ex_fields', []))
  43
+        possible = allowed.difference(excluded)
  44
+        denied = req_fields.difference(possible)
  45
+        assert len(denied) == 0, 'disallowed fields: %s' % ','.join(denied)
26 46
                    
  47
+    def _highlight_ok(self, hl_fields):
  48
+        assert self.perms.get('highlight', False), 'highlighting disabled'
  49
+        
  50
+        excluded = self.perms.get('ex_highlight_fields', [])
  51
+        highlight_max = self.perms.get('highlight_max', 0)
  52
+        for hl in hl_fields:
  53
+            # highlight field format str[:count], e.g., "abstract:3"
  54
+            hl = hl.strip().split(':')
  55
+            assert hl[0] not in excluded, 'disallowed highlight field: %s' % hl[0]
  56
+            if len(hl) > 1:
  57
+                assert highlight_max >= int(hl[1]), \
  58
+                    'highlight count %d exceeds max allowed value: %d' % (int(hl[1]), highlight_max)
  59
+        
27 60
     def check_permissions(self, form):
28 61
         
29 62
         try:
30  
-            if form.facets.data:
31  
-                self.check_facets(form.facets.data)
32  
-            self.check_max_rows(form.rows.data)
33  
-            self.check_max_start(form.start.data)
  63
+            if len(form.facet.data):
  64
+                self._facets_ok(form.facet.data)
  65
+            self._max_rows_ok(form.rows.data)
  66
+            self._max_start_ok(form.start.data)
  67
+            if len(form.flds.data):
  68
+                self._fields_ok(form.flds.data)
  69
+            if len(form.hl.data):
  70
+                self._highlight_ok(form.hl.data)
34 71
         except AssertionError, e:
35 72
             raise ApiPermissionError(e.message)
36 73
         
10  adsabs/modules/api/renderers.py
@@ -10,6 +10,12 @@
10 10
 from flask.ext.pushrod.renderers import renderer #@UnresolvedImport
11 11
 from adsabs.extensions import pushrod
12 12
 
  13
+VALID_FORMATS = []
  14
+
  15
+def _register(renderer):
  16
+    VALID_FORMATS.extend(renderer.renderer_names)
  17
+    pushrod.register_renderer(renderer)
  18
+    
13 19
 # json is the default rendering method
14 20
 @renderer(name='json', mime_type=('application/json','text/html'))
15 21
 def json_renderer(unrendered, **kwargs):
@@ -26,5 +32,5 @@ def xml_renderer(unrendered, xml_template=None, **kwargs):
26 32
     else:
27 33
         return NotImplemented
28 34
     
29  
-pushrod.register_renderer(json_renderer)
30  
-pushrod.register_renderer(xml_renderer)
  35
+_register(json_renderer)
  36
+_register(xml_renderer)
13  adsabs/modules/api/request.py
@@ -14,19 +14,20 @@ class ApiSearchRequest(object):
14 14
     
15 15
     def __init__(self, request_vals, user=None):
16 16
         self.form = ApiQueryForm(request_vals, csrf_enabled=False)
17  
-        if user:
18  
-            perms = user.get_dev_perms()
19  
-        else:
20  
-            perms = g.api_user.get_dev_perms()
  17
+        if not user:
  18
+            user = g.api_user
  19
+        perms = user.get_dev_perms()
21 20
         self.perms = DevPermissions(perms)
22 21
         
23 22
     def validate(self):
24  
-        return self.form.validate() and self.perms.check_permissions(self.form)
  23
+        valid = self.form.validate()
  24
+        perms_ok = self.perms.check_permissions(self.form)
  25
+        return valid and perms_ok
25 26
     
26 27
     def execute(self):
27 28
         solr_req = SolrRequest(
28 29
             self.form.q.data,
29  
-            facets=self.form.facets.data
  30
+            facets=self.form.facet.data
30 31
             )
31 32
         self.resp = solr_req.get_response()
32 33
         return self.resp
9  adsabs/modules/api/ret_functions.py
... ...
@@ -1,9 +0,0 @@
1  
-from flask import make_response
2  
-
3  
-def ret_xml(str_text):
4  
-    """
5  
-    Function that creates a specific response object to return XML
6  
-    """
7  
-    response = make_response(str_text)
8  
-    response.headers['Content-Type'] = 'text/xml; charset=utf-8'
9  
-    return response
6  adsabs/modules/api/views.py
@@ -42,11 +42,11 @@ def search():
42 42
         
43 43
         
44 44
 @api_blueprint.route('/record/<identifier>', methods=['GET'])
45  
-@api_blueprint.route('/record/<identifier>/<field>', methods=['GET'])
  45
+@api_blueprint.route('/record/<identifier>/<operator>', methods=['GET'])
46 46
 @api_user_required
47 47
 @pushrod_view(xml_template="record.xml")
48  
-def record(identifier, field=None):
49  
-    record_req = ApiRecordRequest(identifier, field=field)
  48
+def record(identifier, operator=None):
  49
+    record_req = ApiRecordRequest(identifier, operator=operator)
50 50
     if record_req.validate():
51 51
         resp = solr.query(record_req.query())
52 52
         if not resp.get_count() > 0:
13  config.py
@@ -36,9 +36,9 @@ class AppConfig(object):
36 36
     SOLR_ROW_OPTIONS = [('20','20'),('50','50'),('100','100')]
37 37
     SOLR_DEFAULT_ROWS = '20'
38 38
     SOLR_DEFAULT_SORT = 'pubdate_sort desc'
  39
+    SOLR_SORT_OPTIONS = ['DATE','RELEVANCE','CITED','POPULARITY']
39 40
     SOLR_DEFAULT_PARAMS = [('fq', ['pubdate_sort:[* TO 20130000]'])]
40 41
     SOLR_DEFAULT_FORMAT = 'json'
41  
-    SOLR_ALLOWED_FIELDS = ['id','bibcode','title','author','pub','score','property','pubdate_sort']
42 42
     
43 43
     # copy logging.conf.dist -> logging.conf and uncomment
44 44
     LOGGING_CONFIG = os.path.join(_basedir, 'logging.conf')
@@ -47,6 +47,17 @@ class AppConfig(object):
47 47
     ADS_CLASSIC_BASEURL = 'http://adsabs.harvard.edu'
48 48
 
49 49
     API_DEFAULT_RESPONSE_FORMAT = 'json'
  50
+    # This is the full list of fields available.
  51
+    # Note that most API accounts will not have access to the full list of fields.
  52
+    API_SOLR_FIELDS = ['bibcode','bibstem','title','author','pub','score','property','abstract','keyword','references','full','ack','identifier']
  53
+    API_SOLR_FACET_FIELDS = {
  54
+        'bibstem': 'bibstem_facet',
  55
+        'author': 'author_facet',
  56
+        'property': 'property',
  57
+        'keyword': 'keyword_facet',
  58
+        'pubdate': 'pubdate',
  59
+        'pub': 'pub',
  60
+    }
50 61
 
51 62
 try:
52 63
     from local_config import LocalConfig
101  test/api_tests.py
@@ -16,6 +16,7 @@
16 16
 
17 17
 from adsabs.app import create_app
18 18
 from adsabs.modules.user import AdsUser
  19
+from adsabs.modules.api.permissions import DevPermissions as DP
19 20
 from adsabs.core.solr import SolrResponse
20 21
 from config import config
21 22
 from test.utils import SolrRawQueryFixture
@@ -76,21 +77,21 @@ def test_authorized_request(self):
76 77
 #        rv = self.client.get('/api/record/1234?dev_key=foo_dev_key')
77 78
 #        self.assertEqual(rv.status_code, 404)
78 79
         
79  
-#    def test_search_output(self):
80  
-#        
81  
-#        self.insert_dev_user("foo", "baz")
82  
-#        self.solr.set_data({
83  
-#            'response': {
84  
-#                'numFound': 1,
85  
-#                'docs': [],
86  
-#            }
87  
-#        })
88  
-#        rv = self.client.get('/api/search/?q=black+holes&dev_key=baz')
89  
-#        resp_data = loads(rv.data)
90  
-#        self.assertIn('meta', resp_data)
91  
-#        self.assertIn('results', resp_data)
92  
-#        self.assertTrue(resp_data['results']['count'] >= 1)
93  
-#        self.assertIsInstance(resp_data['results']['docs'], list)
  80
+    def test_search_output(self):
  81
+        
  82
+        self.insert_user("foo", developer=True)
  83
+        fixture = self.useFixture(SolrRawQueryFixture())
  84
+        rv = self.client.get('/api/search/?q=black+holes&dev_key=foo_dev_key')
  85
+        resp_data = loads(rv.data)
  86
+        self.assertIn('meta', resp_data)
  87
+        self.assertIn('results', resp_data)
  88
+        self.assertTrue(resp_data['results']['count'] >= 1)
  89
+        self.assertIsInstance(resp_data['results']['docs'], list)
  90
+        
  91
+        self.insert_user("bar", developer=True, dev_perms={'facets': True})
  92
+        rv = self.client.get('/api/search/?q=black+holes&dev_key=bar_dev_key')
  93
+        resp_data = loads(rv.data)
  94
+        self.assertIn('facets', resp_data['results'])
94 95
     
95 96
 #    def test_record_output(self):
96 97
 #        
@@ -122,23 +123,7 @@ def test_content_types(self):
122 123
         self.assertEqual(rv.status_code, 406)
123 124
         self.assertIn('renderer does not exist', rv.data)
124 125
     
125  
-#    def test_facet_output(self):
126  
-#        
127  
-#        perms = {'facets': False}
128  
-#        self.insert_user("facets_off", "foo", perms)
129  
-#        self.solr.set_data({
130  
-#            'response': {
131  
-#                'numFound': 1,
132  
-#                'docs': [],
133  
-#                'facets': [ ('bar', 1) ]
134  
-#            }
135  
-#        })
136  
-#        
137  
-#        rv = self.client.get('/api/search/?q=black+holes&dev_key=foo&facets=true')
138  
-#        resp_data = loads(rv.data)
139  
-#        self.assertIn('facets', resp_data)
140  
-        
141  
-    def test_permissions(self):
  126
+    def test_user_permissions(self):
142 127
         self.insert_user("a", "1")
143 128
         self.insert_user("b", "2")
144 129
         self.insert_user("c", "3")
@@ -146,7 +131,59 @@ def test_permissions(self):
146 131
         self.insert_user("e", "5")
147 132
         self.insert_user("f", "6")
148 133
         
  134
+class PermissionsTest(unittest2.TestCase):
149 135
     
  136
+    def test_permissions(self):
150 137
         
  138
+        p = DP({})
  139
+        self.assertRaises(AssertionError, p._facets_ok, ["author"])
  140
+        p = DP({'facets': True})
  141
+        self.assertIsNone(p._facets_ok(["author"]))
  142
+        p = DP({'ex_fields': ['author']})
  143
+        self.assertRaisesRegexp(AssertionError, 'facets disabled', p._facets_ok, ["author"])
  144
+        p = DP({'facets': True, 'ex_fields': ['author']})
  145
+        self.assertRaisesRegexp(AssertionError, 'disallowed facet', p._facets_ok, ["author"])
  146
+        p = DP({'facets': True, 'facet_limit_max': 10})
  147
+        self.assertIsNone(p._facets_ok(["author:9"]))
  148
+        self.assertIsNone(p._facets_ok(["author:10"]))
  149
+        self.assertIsNone(p._facets_ok(["author:10:100"]))
  150
+        self.assertRaisesRegexp(AssertionError, 'facet limit value 11 exceeds max', p._facets_ok, ["author:11"])
  151
+        
  152
+        p = DP({})
  153
+        self.assertRaises(AssertionError, p._max_rows_ok, 10)
  154
+        p = DP({'max_rows': 10})
  155
+        self.assertIsNone(p._max_rows_ok(9))
  156
+        self.assertIsNone(p._max_rows_ok(10))
  157
+        self.assertRaises(AssertionError, p._max_rows_ok, 11)
  158
+        self.assertRaisesRegexp(AssertionError, 'rows=11 exceeds max allowed value: 10', p._max_rows_ok, 11)
  159
+        
  160
+        p = DP({})
  161
+        self.assertRaises(AssertionError, p._max_start_ok, 100)
  162
+        p = DP({'max_start': 200})
  163
+        self.assertIsNone(p._max_start_ok(100))
  164
+        self.assertIsNone(p._max_start_ok(200))
  165
+        self.assertRaises(AssertionError, p._max_start_ok, 300)
  166
+        self.assertRaisesRegexp(AssertionError, 'start=300 exceeds max allowed value: 200', p._max_start_ok, 300)
  167
+        
  168
+        p = DP({})
  169
+        self.assertIsNone(p._fields_ok('bibcode,title'))
  170
+        p = DP({'ex_fields': ['full']})
  171
+        self.assertIsNone(p._fields_ok('bibcode,title'))
  172
+        self.assertRaises(AssertionError, p._fields_ok, 'bibcode,title,full')
  173
+        self.assertRaisesRegexp(AssertionError, 'disallowed fields: full', p._fields_ok, 'bibcode,title,full')
  174
+        
  175
+        p = DP({})
  176
+        self.assertRaises(AssertionError, p._highlight_ok, ["abstract"])
  177
+        p = DP({'highlight': True})
  178
+        self.assertIsNone(p._highlight_ok(['abstract']))
  179
+        p = DP({'ex_highlight_fields': ['abstract']})
  180
+        self.assertRaisesRegexp(AssertionError, 'highlighting disabled', p._highlight_ok, ["abstract"])
  181
+        p = DP({'ex_highlight_fields': ['abstract'], 'highlight': True})
  182
+        self.assertRaisesRegexp(AssertionError, 'disallowed highlight field: abstract', p._highlight_ok, ["abstract"])
  183
+        p = DP({'highlight': True, 'highlight_max': 3})
  184
+        self.assertIsNone(p._highlight_ok(["abstract:2"]))
  185
+        self.assertIsNone(p._highlight_ok(["abstract:3"]))
  186
+        self.assertRaisesRegexp(AssertionError, 'highlight count 4 exceeds max allowed value: 3', p._highlight_ok, ["abstract:4"])
  187
+    
151 188
 if __name__ == '__main__':
152 189
     unittest2.main()

0 notes on commit b80597c

Please sign in to comment.
Something went wrong with that request. Please try again.