Skip to content

Commit

Permalink
Moved the database work into instance method
Browse files Browse the repository at this point in the history
  • Loading branch information
romanchyla committed Aug 23, 2016
1 parent 08c9595 commit ae3ece6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
16 changes: 9 additions & 7 deletions solr/tests/unittests/test_solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,39 @@ def test_cleanup_solr_request(self):
"""
Simple test of the cleanup classmethod
"""
si = SolrInterface()
payload = {'fl': ['id,bibcode,title,volume']}
cleaned = SolrInterface.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['fl'], 'id,bibcode,title,volume')

payload = {'fl': ['id ', ' bibcode ', 'title ', ' volume']}
cleaned = SolrInterface.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['fl'], 'id,bibcode,title,volume')

payload = {'fl': ['id', 'bibcode', '*']}
cleaned = SolrInterface.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertNotIn('*', cleaned['fl'])

payload = {'fl': ['id,bibcode,*']}
cleaned = SolrInterface.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertNotIn('*', cleaned['fl'])


def test_limits(self):
"""
Prevent users from getting certain data
"""
si = SolrInterface()
db.session.add(Limits(uid='9', field='full', filter='bibstem:apj'))
db.session.commit()
self.assertTrue(len(db.session.query(Limits).filter_by(uid='9').all()) == 1)

payload = {'fl': ['id,bibcode,title,full,bar'], 'q': '*:*'}
cleaned = SolrInterface.cleanup_solr_request(payload, user_id='9')
cleaned = si.cleanup_solr_request(payload, user_id='9')
self.assertEqual(cleaned['fl'], u'id,bibcode,title,full')
self.assertEqual(cleaned['fq'], [u'bibstem:apj'])

cleaned = SolrInterface.cleanup_solr_request(
cleaned = si.cleanup_solr_request(
{'fl': ['id,bibcode,full'], 'fq': ['*:*']},
user_id='9')
self.assertEqual(cleaned['fl'], u'id,bibcode,full')
Expand All @@ -76,7 +78,7 @@ def test_limits(self):
db.session.add(Limits(uid='9', field='bar', filter='bibstem:apr'))
db.session.commit()

cleaned = SolrInterface.cleanup_solr_request(
cleaned = si.cleanup_solr_request(
{'fl': ['id,bibcode,fuLL,BAR'], 'fq': ['*:*']},
user_id='9')
self.assertEqual(cleaned['fl'], u'id,bibcode,full,bar')
Expand Down
15 changes: 7 additions & 8 deletions solr/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class SolrInterface(Resource):
"""Base class that responsible for forwarding a query to Solr"""

def get(self):
query = SolrInterface.cleanup_solr_request(dict(request.args), request.headers.get('X-Adsws-Uid', 'default'))
query = self.cleanup_solr_request(dict(request.args), request.headers.get('X-Adsws-Uid', 'default'))
headers = dict()
headers['Content-Type'] = 'application/x-www-form-urlencoded'
r = requests.post(
Expand All @@ -44,8 +44,7 @@ def set_cookies(request):
cookie = {cookie_name: request.cookies.get(cookie_name, 'session')}
return cookie if cookie[cookie_name] else None

@staticmethod
def apply_protective_filters(payload, user_id, protected_fields):
def apply_protective_filters(self, payload, user_id, protected_fields):
"""
Adds filters to the query that should limit results to conditions
that are associted with the user_id+protected_field. If a field is
Expand All @@ -66,10 +65,10 @@ def apply_protective_filters(payload, user_id, protected_fields):
fl = u'{0},{1}'.format(fl, f.field)
fq.append(unicode(f.filter))
payload['fl'] = fl
db.session.commit()


@staticmethod
def cleanup_solr_request(payload, user_id='default'):
def cleanup_solr_request(self, payload, user_id='default'):
"""
Sanitizes a request before it is passed to solr
:param payload: raw request payload
Expand Down Expand Up @@ -107,7 +106,7 @@ def cleanup_solr_request(payload, user_id='default'):
payload['fl'] = ','.join(fields)

if len(protected_fields) > 0:
SolrInterface.apply_protective_filters(payload, user_id, protected_fields)
self.apply_protective_filters(payload, user_id, protected_fields)

max_hl = current_app.config.get('SOLR_SERVICE_MAX_SNIPPETS', 4)
max_frag = current_app.config.get('SOLR_SERVICE_MAX_FRAGSIZE', 100)
Expand Down Expand Up @@ -145,7 +144,7 @@ class Qtree(SolrInterface):
handler = 'SOLR_SERVICE_QTREE_HANDLER'


class BigQuery(Resource):
class BigQuery(SolrInterface):
"""Exposes the bigquery endpoint"""
scopes = ['api']
rate_limit = [100, 60*60*24]
Expand All @@ -157,7 +156,7 @@ def post(self):
payload.update(request.args)
headers = dict(request.headers)

query = SolrInterface.cleanup_solr_request(payload, headers.get('X-Adsws-Uid', 'default'))
query = self.cleanup_solr_request(payload, headers.get('X-Adsws-Uid', 'default'))

if request.files and \
sum([len(i) for i in request.files.listvalues()]) > 1:
Expand Down

0 comments on commit ae3ece6

Please sign in to comment.