Skip to content

Commit

Permalink
Revert "Added Host and amazon trace id"
Browse files Browse the repository at this point in the history
  • Loading branch information
Taylor Shaulis committed Mar 19, 2018
1 parent c81aacf commit 1bbfe24
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 77 deletions.
37 changes: 17 additions & 20 deletions solr/tests/unittests/test_solr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from werkzeug.security import gen_salt
from werkzeug.datastructures import MultiDict
from StringIO import StringIO
from solr.tests.mocks import MockSolrResponse
from ..mocks import MockSolrResponse
from views import SolrInterface
from models import Limits, Base

Expand Down Expand Up @@ -42,63 +42,60 @@ def test_cleanup_solr_request(self):
"""
si = SolrInterface()
payload = {}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['rows'], self.app.config.get('SOLR_SERVICE_MAX_ROWS', 100))
self.assertEqual(cleaned['fl'], 'id')

payload = {'rows': '1000000'}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['rows'], self.app.config.get('SOLR_SERVICE_MAX_ROWS', 100))

payload = {'rows': 1000000}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['rows'], self.app.config.get('SOLR_SERVICE_MAX_ROWS', 100))

payload = {'rows': '5'}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['rows'], 5)

payload = {'rows': ['5', '1000000']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['rows'], 5)

payload = {'rows': ['1', '0']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['rows'], 1)

payload = {'hl.snippets': 1000000, 'hl.fragsize': 1000000}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['hl.snippets'], self.app.config.get('SOLR_SERVICE_MAX_SNIPPETS', 4))
self.assertEqual(cleaned['hl.fragsize'], self.app.config.get('SOLR_SERVICE_MAX_FRAGSIZE', 100))

payload = {'hl.snippets': [2, 1000000], 'hl.fragsize': [3, 1000000]}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['hl.snippets'], self.app.config.get('SOLR_SERVICE_MAX_SNIPPETS', 2))
self.assertEqual(cleaned['hl.fragsize'], self.app.config.get('SOLR_SERVICE_MAX_FRAGSIZE', 3))

payload = {'fl': ['id,bibcode,title,volume']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['fl'], 'id,bibcode,title,volume')

payload = {'fl': ['id ', ' bibcode ', 'title ', ' volume']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['fl'], 'id,bibcode,title,volume')
self.assertEqual(cleaned['rows'], self.app.config.get('SOLR_SERVICE_MAX_ROWS', 100))

payload = {'fl': ['id', 'bibcode', '*']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertNotIn('*', cleaned['fl'])

payload = {'fl': ['id,bibcode,*']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertNotIn('*', cleaned['fl'])

payload = {'fq': ['pos(1,author:foo)']}
cleaned, headers = si.cleanup_solr_request(payload)
cleaned = si.cleanup_solr_request(payload)
self.assertEqual(cleaned['fq'], ['pos(1,author:foo)'])

self.assertEqual(headers,
{'Host': u'http://localhost:8983', 'Content-Type': 'application/x-www-form-urlencoded'})


def test_limits(self):
Expand All @@ -112,11 +109,11 @@ def test_limits(self):
self.assertTrue(len(session.query(Limits).filter_by(uid='9').all()) == 1)

payload = {'fl': ['id,bibcode,title,full,bar'], 'q': '*:*'}
cleaned, headers = si.cleanup_solr_request(payload, user_id='9')
cleaned = si.cleanup_solr_request(payload, user_id='9')
self.assertEqual(cleaned['fl'], u'id,bibcode,title,full')
self.assertEqual(cleaned['fq'], [u'bibstem:apj'])

cleaned, headers = si.cleanup_solr_request(
cleaned = si.cleanup_solr_request(
{'fl': ['id,bibcode,full'], 'fq': ['*:*']},
user_id='9')
self.assertEqual(cleaned['fl'], u'id,bibcode,full')
Expand All @@ -127,7 +124,7 @@ def test_limits(self):
session.add(Limits(uid='9', field='bar', filter='bibstem:apr'))
session.commit()

cleaned, headers = si.cleanup_solr_request(
cleaned = si.cleanup_solr_request(
{'fl': ['id,bibcode,fuLL,BAR'], 'fq': ['*:*']},
user_id='9')
self.assertEqual(cleaned['fl'], u'id,bibcode,full,bar')
Expand Down
80 changes: 23 additions & 57 deletions solr/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,14 @@ class StatusView(Resource):
def get(self):
    """Report the application name together with an 'online' status (HTTP 200)."""
    body = {'app': current_app.name, 'status': 'online'}
    return body, 200


class SolrInterface(Resource):
"""Base class that responsible for forwarding a query to Solr"""
handler = 'SOLR_SERVICE_URL'

def __init__(self, *args, **kwargs):
    """Set up the resource, forwarding every argument to ``Resource.__init__``."""
    Resource.__init__(self, *args, **kwargs)
    # Cache slot for the solr host string; it is read by get_host() and
    # assigned by _get_host().
    self._host = None

def get(self):
query, headers = self.cleanup_solr_request(dict(request.args))

query = self.cleanup_solr_request(dict(request.args), request.headers.get('X-Adsws-Uid', 'default'))
headers = dict()
headers['Content-Type'] = 'application/x-www-form-urlencoded'
r = requests.post(
current_app.config[self.handler],
data=query,
Expand Down Expand Up @@ -73,48 +69,37 @@ def apply_protective_filters(self, payload, user_id, protected_fields):
session.commit()


def cleanup_solr_request(self, payload, user_id=None):
def cleanup_solr_request(self, payload, user_id='default'):
"""
Sanitizes a request before it is passed to solr
:param payload: dict, raw request payload. Warning: we'll
modify the dictionary directly
:kwarg user_id: string, identifying the user
:return: tuple - (sanitized payload, headers for solr)
:param payload: raw request payload
:return: sanitized payload
"""

if not user_id:
user_id = request.headers.get('X-Adsws-Uid', 'default')

headers = {}
headers['Content-Type'] = request.headers.get('Content-Type', 'application/x-www-form-urlencoded') or 'application/x-www-form-urlencoded'

# trace id and Host header are important for proper routing/logging
headers['Host'] = self.get_host(current_app.config.get(self.handler))

if 'x-amzn-trace-id' in request.headers:
payload['x-amzn-trace-id'] = request.headers['x-amzn-trace-id']
def safe_int(val, default=0):
if isinstance(val, (list, tuple)):
val = val[0]
try:
return int(val)
except (ValueError, TypeError):
return default

payload['wt'] = 'json'
max_rows = current_app.config.get('SOLR_SERVICE_MAX_ROWS', 100)
max_rows *= int(
request.headers.get('X-Adsws-Ratelimit-Level', 1)
)



# Ensure there is a single rows value and that it does not bypass the max rows limit
rows = max_rows
if 'rows' in payload:
rows = _safe_int(payload['rows'], default=max_rows)
rows = safe_int(payload['rows'], default=max_rows)
rows = min(rows, max_rows)
payload['rows'] = rows

# Ensure there is a single start value
start = 0
if 'start' in payload:
start = _safe_int(payload['start'], default=0)
start = safe_int(payload['start'], default=0)
payload['start'] = start

# we disallow 'return everything'
Expand Down Expand Up @@ -148,23 +133,12 @@ def cleanup_solr_request(self, payload, user_id=None):
for k,v in payload.items():
if 'hl.' in k:
if '.snippets' in k:
payload[k] = max(0, min(_safe_int(v, default=max_hl), max_hl))
payload[k] = max(0, min(safe_int(v, default=max_hl), max_hl))
elif '.fragsize' in k:
payload[k] = max(1, min(_safe_int(v, default=max_frag), max_frag)) #0 would return whole field

return payload, headers

def get_host(self, url):
    """Return the cached host when one is set, otherwise derive it from *url*.

    :param url: string, passed on to ``_get_host`` when nothing is cached
    :return: the truthy value of ``self._host`` or the result of
        ``self._get_host(url)``
    """
    cached = self._host
    if cached:
        return cached
    return self._get_host(url)

def _get_host(self, url):
    """Derive ``scheme://host[:port]`` from *url* and cache it on the instance.

    A URL whose first ``/``-separated segment is exactly ``http:`` or
    ``https:`` (case-insensitive) keeps that scheme, lower-cased. Anything
    else is treated as a bare ``host[:port][/path]`` and gets ``http://``
    prepended.

    :param url: string, URL (or bare host) to extract the host from
    :return: string, the extracted host; also stored in ``self._host``
    """
    parts = url.split('/')
    scheme = parts[0].lower()
    # Compare the whole first segment instead of searching for the substring
    # 'http': a scheme-less host such as 'httpbin.org/solr' would otherwise
    # be mistaken for a scheme and fail on parts[2].
    if scheme in ('http:', 'https:'):
        self._host = scheme + '//' + parts[2]
    else:
        self._host = 'http://' + parts[0]
    return self._host
payload[k] = max(1, min(safe_int(v, default=max_frag), max_frag)) #0 would return whole field

return payload


class Tvrh(SolrInterface):
"""Exposes the solr term-vector histogram endpoint"""
Expand Down Expand Up @@ -200,8 +174,9 @@ class BigQuery(SolrInterface):
def post(self):
payload = dict(request.form)
payload.update(request.args)
headers = dict(request.headers)

query, headers = self.cleanup_solr_request(payload)
query = self.cleanup_solr_request(payload, headers.get('X-Adsws-Uid', 'default'))

if request.files and \
sum([len(i) for i in request.files.listvalues()]) > 1:
Expand Down Expand Up @@ -235,12 +210,3 @@ def post(self):
else:
return json.dumps({'error': "malformed request"}), 400
return r.text, r.status_code, r.headers


def _safe_int(val, default=0):
    """Coerce *val* to an int, returning *default* when that is impossible.

    :param val: value to convert; for a list or tuple only the first element
        is considered
    :param default: value returned when *val* is an empty sequence or cannot
        be converted by ``int()``
    :return: int, or *default*
    """
    if isinstance(val, (list, tuple)):
        # An empty sequence has no first element to fall back on.
        if not val:
            return default
        val = val[0]
    try:
        return int(val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers int(float('inf')).
        return default

0 comments on commit 1bbfe24

Please sign in to comment.