Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ user = User(device_id="abcdefg", user_id="user@company.com", user_properties={
#
# To fetch synchronous
variants = experiment.fetch(user)
variant = variants['sdk-ci-test']
variant = variants['YOUR-FLAG-KEY']
if variant:
if variant.value == 'on':
# Flag is on
Expand Down
49 changes: 33 additions & 16 deletions src/amplitude_experiment/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import json
import logging
import threading
import time
from time import sleep
from typing import Any

from .config import Config
from .version import __version__
from .variant import Variant
from .connection_pool import HTTPConnectionPool
from .user import User
import http.client
import json
import logging
from .variant import Variant
from .version import __version__


class Client:
Expand All @@ -31,6 +33,7 @@ def __init__(self, api_key, config=None):
self.logger.addHandler(logging.StreamHandler())
if self.config.debug:
self.logger.setLevel(logging.DEBUG)
self.__setup_connection_pool()

def fetch(self, user: User):
"""
Expand Down Expand Up @@ -72,9 +75,9 @@ def __fetch_async_internal(self, user, callback):
def __fetch_internal(self, user):
self.logger.debug(f"[Experiment] Fetching variants for user: {user}")
try:
return self.__do_fetch(user, self.config.fetch_timeout_millis)
return self.__do_fetch(user)
except Exception as e:
self.logger.error(f"Experiment] Fetch failed: {e}")
self.logger.error(f"[Experiment] Fetch failed: {e}")
return self.__retry_fetch(user)

def __retry_fetch(self, user):
Expand All @@ -86,40 +89,54 @@ def __retry_fetch(self, user):
for i in range(self.config.fetch_retries):
sleep(delay_millis / 1000.0)
try:
return self.__do_fetch(user, self.config.fetch_timeout_millis)
return self.__do_fetch(user)
except Exception as e:
self.logger.error(f"[Experiment] Retry failed: {e}")
err = e
delay_millis = min(delay_millis * self.config.fetch_retry_backoff_scalar,
self.config.fetch_retry_backoff_max_millis)
raise err

def __do_fetch(self, user, timeout_millis):
def __do_fetch(self, user):
start = time.time()
user_context = self.__add_context(user)
headers = {
'Authorization': f"Api-Key {self.api_key}",
'Content-Type': 'application/json;charset=utf-8'
}
scheme, _, host = self.config.server_url.split('/', 3)
Connection = http.client.HTTPConnection if scheme == 'http:' else http.client.HTTPSConnection
conn = Connection(host, timeout=timeout_millis / 1000)
conn.connect()
conn = self._connection_pool.acquire()
body = user_context.to_json().encode('utf8')
if len(body) > 8000:
self.logger.warning(f"[Experiment] encoded user object length ${len(body)} "
f"cannot be cached by CDN; must be < 8KB")
self.logger.debug(f"[Experiment] Fetch variants for user: {str(user_context)}")
conn.request('POST', '/sdk/vardata', body, headers)
response = conn.getresponse()
response = conn.request('POST', '/sdk/vardata', body, headers)
self._connection_pool.release(conn)
elapsed = '%.3f' % ((time.time() - start) * 1000)
self.logger.debug(f"[Experiment] Fetch complete in {elapsed} ms")
json_response = json.loads(response.read().decode("utf8"))
variants = self.__parse_json_variants(json_response)
self.logger.debug(f"[Experiment] Fetched variants: {json.dumps(variants, default=str)}")
conn.close()
return variants

def __setup_connection_pool(self):
scheme, _, host = self.config.server_url.split('/', 3)
timeout = int(self.config.fetch_timeout_millis / 1000)
self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30,
read_timeout=timeout, scheme=scheme)

def close(self) -> None:
"""
Close resource like connection pool with client
"""
self._connection_pool.close()

def __enter__(self) -> 'Client':
return self

def __exit__(self, *exit_info: Any) -> None:
self.close()

def __add_context(self, user):
user = user or {}
user.library = user.library or f"experiment-python-server/{__version__}"
Expand Down
185 changes: 185 additions & 0 deletions src/amplitude_experiment/connection_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import threading
import time
from typing import Any

from http.client import HTTPConnection, HTTPResponse, HTTPSConnection


class WrapperHTTPConnection:

def __init__(self, pool: 'HTTPConnectionPool', conn: HTTPConnection) -> None:
"""
Wrapped Http Connection, used with connection pool
:param pool: Connection pool this connection belongs to
:param conn: Wrapped HTTPConnection
"""
self.pool = pool
self.conn = conn
self.response = None
self.last_time = time.time()
self.is_available = True

def __enter__(self) -> 'WrapperHTTPConnection':
return self

def __exit__(self, *exit_info: Any) -> None:
if not self.response.will_close and not self.response.is_closed():
self.close()
self.pool.release(self)

def request(self, *args: Any, **kwargs: Any) -> HTTPResponse:
self.conn.request(*args, **kwargs)
self.response = self.conn.getresponse()
return self.response

def close(self) -> None:
self.conn.close()
self.is_available = False


class HTTPConnectionPool:

def __init__(self, host: str, port: int = None, max_size: int = None, idle_timeout: int = None,
read_timeout: int = None, scheme: str = 'https') -> None:
"""
A simple connection pool to reuse the http connections
:param host: pass
:param port: pass
:param max_size: Max connection allowed
:param idle_timeout: Idle timeout to clear the connection
:param read_timeout: Read timeout with connection
:param scheme: http or https
"""
self.host = host
self.port = port
self.max_size = max_size
self.idle_timeout = idle_timeout
self.read_timeout = read_timeout
self.scheme = scheme
self._lock = threading.Condition()
self._pool = []
self.conn_num = 0
self.is_closed = False
self._clearer = None
self.start_clear_conn()

def acquire(self, blocking: bool = True, timeout: int = None) -> WrapperHTTPConnection:
if self.is_closed:
raise ConnectionPoolClosed
with self._lock:
if self.max_size is None or not self.is_full():
if self.is_pool_empty():
self._put_connection(self._create_connection())
else:
if not blocking:
if self.is_pool_empty():
raise EmptyPoolError
elif timeout is None:
while self.is_pool_empty():
self._lock.wait()
elif timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
else:
end_time = time.time() + timeout
while self.is_pool_empty():
remaining = end_time - time.time()
if remaining <= 0:
raise EmptyPoolError
self._lock.wait(remaining)
return self._get_connection()

def release(self, conn: WrapperHTTPConnection) -> None:
if self.is_closed:
conn.close()
return
with self._lock:
if not conn.is_available:
conn.close()
self.conn_num -= 1
conn = self._create_connection()
self._put_connection(conn)
self._lock.notify()

def _get_connection(self) -> WrapperHTTPConnection:
try:
return self._pool.pop()
except IndexError:
raise EmptyPoolError

def _put_connection(self, conn: WrapperHTTPConnection) -> None:
conn.last_time = time.time()
self._pool.append(conn)

def _create_connection(self) -> WrapperHTTPConnection:
self.conn_num += 1
connection = HTTPConnection if self.scheme == 'http:' else HTTPSConnection
return WrapperHTTPConnection(self, connection(self.host, self.port, timeout=self.read_timeout))

def is_pool_empty(self) -> bool:
return len(self._pool) == 0

def is_full(self) -> bool:
if self.max_size is None:
return False
return self.conn_num >= self.max_size

def close(self) -> None:
if self.is_closed:
return
self.is_closed = True
self.stop_clear_conn()
pool, self._pool = self._pool, None
for conn in pool:
conn.close()

def clear_idle_conn(self) -> None:
if self.is_closed:
raise ConnectionPoolClosed
# Staring a thread to clear idle connections
threading.Thread(target=self._clear_idle_conn).start()

def _clear_idle_conn(self) -> None:
if not self._lock.acquire(timeout=self.idle_timeout):
return
current_time = time.time()
if self.is_pool_empty():
pass
elif current_time - self._pool[-1].last_time >= self.idle_timeout:
self.conn_num -= len(self._pool)
self._pool.clear()
else:
left, right = 0, len(self._pool) - 1
while left < right:
mid = (left + right) // 2
if current_time - self._pool[mid].last_time >= self.idle_timeout:
left = mid + 1
else:
right = mid
self._pool = self._pool[left:]
self.conn_num -= left
self._lock.release()

def start_clear_conn(self) -> None:
if self.idle_timeout is None:
return
self.clear_idle_conn()
self._clearer = threading.Timer(self.idle_timeout, self.start_clear_conn)
self._clearer.start()

def stop_clear_conn(self) -> None:
if self._clearer is not None:
self._clearer.cancel()

def __enter__(self) -> 'HTTPConnectionPool':
return self

def __exit__(self, *exit_info: Any) -> None:
self.close()


class EmptyPoolError(Exception):
pass


class ConnectionPoolClosed(Exception):
pass
22 changes: 13 additions & 9 deletions tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,32 @@

class ClientTestCase(unittest.TestCase):

def setUp(self):
self.client = None

def test_initialize_raise_error(self):
self.assertRaises(ValueError, Client, "")

def test_fetch(self):
client = Client(API_KEY)
expected_variant = Variant('on', 'payload')
user = User(user_id='test_user')
variants = client.fetch(user)
variant_name = 'sdk-ci-test'
self.assertIn(variant_name, variants)
self.assertEqual(expected_variant, variants.get(variant_name))
with Client(API_KEY) as client:
expected_variant = Variant('on', 'payload')
user = User(user_id='test_user')
variants = client.fetch(user)
variant_name = 'sdk-ci-test'
self.assertIn(variant_name, variants)
self.assertEqual(expected_variant, variants.get(variant_name))

def callback_for_async(self, user, variants):
expected_variant = Variant('on', 'payload')
variant_name = 'sdk-ci-test'
self.assertIn(variant_name, variants)
self.assertEqual(expected_variant, variants.get(variant_name))
self.client.close()

def test_fetch_async(self):
client = Client(API_KEY, Config(debug=True))
self.client = Client(API_KEY, Config(debug=True))
user = User(user_id='test_user')
client.fetch_async(user, self.callback_for_async)
self.client.fetch_async(user, self.callback_for_async)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions tests/factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_singleton_instance(self):
client1 = Experiment.initialize(API_KEY)
client2 = Experiment.initialize(API_KEY)
self.assertEqual(client1, client2)
client1.close()


if __name__ == '__main__':
Expand Down