Skip to content

Commit

Permalink
Merge e15e3de into 0290cf8
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoErcolanelli committed Jun 14, 2016
2 parents 0290cf8 + e15e3de commit d7ae488
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 178 deletions.
188 changes: 54 additions & 134 deletions algoliasearch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,39 +22,21 @@
THE SOFTWARE.
"""

import os
import json
import hmac
import hashlib
import base64
import random
import sys

APPENGINE = 'APPENGINE_RUNTIME' in os.environ
SSL_CERTIFICATE_DOMAIN = 'algolia.net'

try:
from urllib import urlencode
except ImportError:
from urllib.parse import urlencode

if APPENGINE:
from google.appengine.api import urlfetch
APPENGINE_METHODS = {
'GET' : urlfetch.GET,
'POST' : urlfetch.POST,
'PUT' : urlfetch.PUT,
'DELETE' : urlfetch.DELETE
}

from requests import Session
from requests import exceptions

from .version import VERSION
from .index import Index

from .helpers import AlgoliaException
from .helpers import CustomJSONEncoder
from .transport import Transport
from .helpers import deprecated
from .helpers import safe
from .helpers import urlify
Expand All @@ -67,49 +49,62 @@ class Client(object):
start using Algolia Search API.
"""

def __init__(self, app_id, api_key, hosts_array=None):
def __init__(self, app_id, api_key, hosts=None, _transport=None):
"""
Algolia Search Client initialization
@param app_id the application ID you have in your admin interface
@param api_key a valid API key for the service
@param hosts_array the list of hosts that you have received for the service
"""
if not hosts_array:
self._transport = Transport() if _transport is None else _transport

if not hosts:
fallbacks = [
'%s-1.algolianet.com' % app_id,
'%s-2.algolianet.com' % app_id,
'%s-3.algolianet.com' % app_id,
]
random.shuffle(fallbacks)

self.read_hosts = ['%s-dsn.algolia.net' % app_id]
self.read_hosts.extend(fallbacks)
self.write_hosts = ['%s.algolia.net' % app_id]
self.write_hosts.extend(fallbacks)

self._transport.read_hosts = ['%s-dsn.algolia.net' % app_id]
self._transport.read_hosts.extend(fallbacks)
self._transport.write_hosts = ['%s.algolia.net' % app_id]
self._transport.write_hosts.extend(fallbacks)
else:
self.read_hosts = hosts_array
self.write_hosts = hosts_array
self._transport.read_hosts = hosts
self._transport.write_hosts = hosts

self._app_id = app_id
self._api_key = api_key
self.timeout = (1, 30)
self.search_timeout = (1, 5)

self._session = Session()
self._session.verify = os.path.join(os.path.dirname(__file__),
'resources/ca-bundle.crt')
self._session.headers = {
'X-Algolia-API-Key': self.api_key,
'X-Algolia-Application-Id': self.app_id,
self._transport.headers = {
'X-Algolia-API-Key': api_key,
'X-Algolia-Application-Id': app_id,
'Content-Type': 'gzip',
'Accept-Encoding': 'gzip',
'User-Agent': 'Algolia Search for Python %s' % VERSION
}

self._app_id = app_id
self._api_key = api_key

# Fix for AppEngine bug when using urlfetch_stub
if 'google.appengine.api.apiproxy_stub_map' in sys.modules.keys():
self._session.headers.pop('Accept-Encoding', None)
self.headers.pop('Accept-Encoding', None)

@property
def timeout(self):
return self._transport.timeout

@timeout.setter
def timeout(self, t):
self._transport.timeout = t

@property
def search_timeout(self):
return self._transport.search_timeout

@search_timeout.setter
def search_timeout(self, t):
self._transport.search_timeout = t

@property
def app_id(self):
Expand Down Expand Up @@ -185,7 +180,7 @@ def set_extra_headers(self, **kwargs):

@property
def headers(self):
return self._session.headers
return self._transport.headers

@deprecated
def set_timeout(self, connect_timeout, read_timeout, search_timeout=5):
Expand Down Expand Up @@ -227,17 +222,16 @@ def multiple_queries(self, queries,
'params': urlencode(urlify(query))
})

return self._perform_request(
self.read_hosts, path, 'POST', params=params,
body={'requests': requests}, is_search=True)
data = {'requests': requests}
return self._req(True, path, 'POST', params, data)

def batch(self, requests):
"""Send a batch request targetting multiple indices."""
"""Send a batch request targeting multiple indices."""
if isinstance(requests, (list, tuple)):
requests = {'requests': requests}

return self._perform_request(self.write_hosts, '/1/indexes/*/batch',
'POST', body=requests)
path = '/1/indexes/*/batch'
return self._req(False, path, 'POST', data=requests)

@deprecated
def listIndexes(self):
Expand All @@ -250,7 +244,7 @@ def list_indexes(self):
{'items': [{ 'name': 'contacts', 'created_at': '2013-01-18T15:33:13.556Z'},
{'name': 'notes', 'created_at': '2013-01-18T15:33:13.556Z'}]}
"""
return self._perform_request(self.read_hosts, '/1/indexes', 'GET')
return self._req(True, '/1/indexes', 'GET')

@deprecated
def deleteIndex(self, index_name):
Expand All @@ -264,7 +258,7 @@ def delete_index(self, index_name):
@param index_name the name of index to delete
"""
path = '/1/indexes/%s' % safe(index_name)
return self._perform_request(self.write_hosts, path, 'DELETE')
return self._req(False, path, 'DELETE')

@deprecated
def moveIndex(self, src_index_name, dst_index_name):
Expand All @@ -280,8 +274,7 @@ def move_index(self, src_index_name, dst_index_name):
"""
path = '/1/indexes/%s/operation' % safe(src_index_name)
request = {'operation': 'move', 'destination': dst_index_name}
return self._perform_request(self.write_hosts, path, 'POST',
body=request)
return self._req(False, path, 'POST', data=request)

@deprecated
def copyIndex(self, src_index_name, dst_index_name):
Expand All @@ -297,8 +290,7 @@ def copy_index(self, src_index_name, dst_index_name):
"""
path = '/1/indexes/%s/operation' % safe(src_index_name)
request = {'operation': 'copy', 'destination': dst_index_name}
return self._perform_request(self.write_hosts, path, 'POST',
body=request)
return self._req(False, path, 'POST', data=request)

@deprecated
def getLogs(self, offset=0, length=10, type='all'):
Expand All @@ -313,13 +305,8 @@ def get_logs(self, offset=0, length=10, type='all'):
@param length Specify the maximum number of entries to retrieve
starting at offset. Maximum allowed value: 1000.
"""
params = {
'offset': offset,
'length': length,
'type': type
}
return self._perform_request(self.write_hosts, '/1/logs', 'GET',
params=params)
params = {'offset': offset, 'length': length, 'type': type}
return self._req(False, '/1/logs', 'GET', params)

@deprecated
def initIndex(self, index_name):
Expand All @@ -340,7 +327,7 @@ def listUserKeys(self):

def list_user_keys(self):
"""List all existing user keys with their associated ACLs."""
return self._perform_request(self.read_hosts, '/1/keys', 'GET')
return self._req(True, '/1/keys', 'GET')

@deprecated
def getUserKeyACL(self, key):
Expand All @@ -349,7 +336,7 @@ def getUserKeyACL(self, key):
def get_user_key_acl(self, key):
"""'Get ACL of a user key."""
path = '/1/keys/%s' % key
return self._perform_request(self.read_hosts, path, 'GET')
return self._req(True, path, 'GET')

@deprecated
def deleteUserKey(self, key):
Expand All @@ -358,7 +345,7 @@ def deleteUserKey(self, key):
def delete_user_key(self, key):
"""Delete an existing user key."""
path = '/1/keys/%s' % key
return self._perform_request(self.write_hosts, path, 'DELETE')
return self._req(False, path, 'DELETE')

@deprecated
def addUserKey(self, obj,
Expand Down Expand Up @@ -416,8 +403,7 @@ def add_user_key(self, obj,
if indexes:
obj['indexes'] = indexes

return self._perform_request(self.write_hosts, '/1/keys', 'POST',
body=obj)
return self._req(False, '/1/keys', 'POST', data=obj)

def update_user_key(self, key, obj,
validity=None,
Expand Down Expand Up @@ -469,8 +455,7 @@ def update_user_key(self, key, obj,
obj['indexes'] = indexes

path = '/1/keys/%s' % key
return self._perform_request(self.write_hosts, path, 'PUT',
body=obj)
return self._req(False, path, 'PUT', data=obj)

@deprecated
def generateSecuredApiKey(self, private_api_key, tag_filters,
Expand Down Expand Up @@ -504,70 +489,5 @@ def generate_secured_api_key(self, private_api_key, queryParameters,
securedKey = hmac.new(private_api_key.encode('utf-8'), queryParameters.encode('utf-8'), hashlib.sha256).hexdigest()
return str(base64.b64encode(("%s%s" % (securedKey, queryParameters)).encode('utf-8')).decode('utf-8'))

def _perform_appengine_request(self, host, path, method, timeout, params=None, data=None):
"""
Perform an HTTPS request with AppEngine's urlfetch. SSL certificate will not validate when
the request is on a domain which is not a aloglia.net subdomain, a SNI is not available by
default on GAE. Hence, we do set validate_certificate to False when calling those domains.
"""
method = APPENGINE_METHODS.get(method)
if isinstance(timeout, tuple):
timeout = timeout[1]
url = 'https://%s%s' % (host, path)
url = params and '%s?%s' %(url, urlencode(urlify(params))) or url
res = urlfetch.fetch(url=url, method=method, payload=data,
headers=self.headers, deadline=timeout,
validate_certificate=host.endswith(SSL_CERTIFICATE_DOMAIN))
content = res.content != None and json.loads(res.content) or None
if (int(res.status_code / 100) == 2 and content):
return content
elif (int(res.status_code / 100) == 4):
message = "HttpCode: %d" % res.status_code
if content and content.get('message'):
message = content['message']
raise AlgoliaException(message)
else:
mesage = '%s Server Error: %s' % (res.status_code, res.content)
raise Exception(http_error_msg, response=res)

def _perform_session_request(self, host, path, method, timeout, params=None, data=None):
"""Perform an HTTPS request with request's Session."""
res = self._session.request(
method, 'https://%s%s' % (host, path),
params=params, data=data, timeout=timeout)
if (int(res.status_code / 100) == 2 and res.json != None):
return res.json()
elif (int(res.status_code / 100) == 4):
message = "HttpCode: %d" % res.status_code
if res.json != None and 'message' in res.json():
message = res.json()['message']
raise AlgoliaException(message)
res.raise_for_status()

def _perform_request(self, hosts, path, method, params=None, body=None,
is_search=False):
"""Perform an HTTPS request with retry logic."""
if params:
params = urlify(params)
if body:
body = json.dumps(body, cls=CustomJSONEncoder)
timeout = self.search_timeout if is_search else self.timeout
exceptions_hosts = {}
for i, host in enumerate(hosts):
if i > 1:
if isinstance(timeout, tuple):
timeout = (timeout[0] + 2, timeout[1] + 10)
else:
timeout += 10
try:
_request = APPENGINE and self._perform_appengine_request or self._perform_session_request
return _request(host, path, method, timeout, params=params, data=body)
except AlgoliaException as e:
raise e
except Exception as e:
exceptions_hosts[host] = "%s: %s" % (e.__class__.__name__, str(e))
pass

# Impossible to connect
raise AlgoliaException('%s %s' % ('Unreachable hosts:',
exceptions_hosts))
def _req(self, is_search, path, meth, params=None, data=None):
return self._transport.req(is_search, path, meth, params, data)
Loading

0 comments on commit d7ae488

Please sign in to comment.