From 202168f06e7b4ff84590da70cda89a09b851fbcd Mon Sep 17 00:00:00 2001 From: Qingzhuo Zhen Date: Sun, 5 Jun 2022 21:55:41 -0700 Subject: [PATCH] feat: add connection pool to support connection reuse --- README.md | 2 +- src/amplitude_experiment/client.py | 49 ++++-- src/amplitude_experiment/connection_pool.py | 185 ++++++++++++++++++++ tests/client_test.py | 22 ++- tests/factory_test.py | 1 + 5 files changed, 233 insertions(+), 26 deletions(-) create mode 100644 src/amplitude_experiment/connection_pool.py diff --git a/README.md b/README.md index fdda0ff..ef2faee 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ user = User(device_id="abcdefg", user_id="user@company.com", user_properties={ # # To fetch synchronous variants = experiment.fetch(user) -variant = variants['sdk-ci-test'] +variant = variants['YOUR-FLAG-KEY'] if variant: if variant.value == 'on': # Flag is on diff --git a/src/amplitude_experiment/client.py b/src/amplitude_experiment/client.py index 7a6eabc..93b0235 100644 --- a/src/amplitude_experiment/client.py +++ b/src/amplitude_experiment/client.py @@ -1,13 +1,15 @@ +import json +import logging import threading import time from time import sleep +from typing import Any + from .config import Config -from .version import __version__ -from .variant import Variant +from .connection_pool import HTTPConnectionPool from .user import User -import http.client -import json -import logging +from .variant import Variant +from .version import __version__ class Client: @@ -31,6 +33,7 @@ def __init__(self, api_key, config=None): self.logger.addHandler(logging.StreamHandler()) if self.config.debug: self.logger.setLevel(logging.DEBUG) + self.__setup_connection_pool() def fetch(self, user: User): """ @@ -72,9 +75,9 @@ def __fetch_async_internal(self, user, callback): def __fetch_internal(self, user): self.logger.debug(f"[Experiment] Fetching variants for user: {user}") try: - return self.__do_fetch(user, self.config.fetch_timeout_millis) + return self.__do_fetch(user) except Exception as e: - self.logger.error(f"Experiment] Fetch failed: {e}") + self.logger.error(f"[Experiment] Fetch failed: {e}") return self.__retry_fetch(user) def __retry_fetch(self, user): @@ -86,7 +89,7 @@ def __retry_fetch(self, user): for i in range(self.config.fetch_retries): sleep(delay_millis / 1000.0) try: - return self.__do_fetch(user, self.config.fetch_timeout_millis) + return self.__do_fetch(user) except Exception as e: self.logger.error(f"[Experiment] Retry failed: {e}") err = e @@ -94,32 +97,46 @@ def __retry_fetch(self, user): self.config.fetch_retry_backoff_max_millis) raise err - def __do_fetch(self, user, timeout_millis): + def __do_fetch(self, user): start = time.time() user_context = self.__add_context(user) headers = { 'Authorization': f"Api-Key {self.api_key}", 'Content-Type': 'application/json;charset=utf-8' } - scheme, _, host = self.config.server_url.split('/', 3) - Connection = http.client.HTTPConnection if scheme == 'http:' else http.client.HTTPSConnection - conn = Connection(host, timeout=timeout_millis / 1000) - conn.connect() + conn = self._connection_pool.acquire() body = user_context.to_json().encode('utf8') if len(body) > 8000: self.logger.warning(f"[Experiment] encoded user object length ${len(body)} " f"cannot be cached by CDN; must be < 8KB") self.logger.debug(f"[Experiment] Fetch variants for user: {str(user_context)}") - conn.request('POST', '/sdk/vardata', body, headers) - response = conn.getresponse() + response = conn.request('POST', '/sdk/vardata', body, headers) + self._connection_pool.release(conn) elapsed = '%.3f' % ((time.time() - start) * 1000) self.logger.debug(f"[Experiment] Fetch complete in {elapsed} ms") json_response = json.loads(response.read().decode("utf8")) variants = self.__parse_json_variants(json_response) self.logger.debug(f"[Experiment] Fetched variants: {json.dumps(variants, default=str)}") - conn.close() return variants + def __setup_connection_pool(self): + scheme, _, host = self.config.server_url.split('/', 3) + timeout = int(self.config.fetch_timeout_millis / 1000) + self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30, + read_timeout=timeout, scheme=scheme) + + def close(self) -> None: + """ + Close resource like connection pool with client + """ + self._connection_pool.close() + + def __enter__(self) -> 'Client': + return self + + def __exit__(self, *exit_info: Any) -> None: + self.close() + def __add_context(self, user): user = user or {} user.library = user.library or f"experiment-python-server/{__version__}" diff --git a/src/amplitude_experiment/connection_pool.py b/src/amplitude_experiment/connection_pool.py new file mode 100644 index 0000000..4ddb72c --- /dev/null +++ b/src/amplitude_experiment/connection_pool.py @@ -0,0 +1,185 @@ +import threading +import time +from typing import Any + +from http.client import HTTPConnection, HTTPResponse, HTTPSConnection + + +class WrapperHTTPConnection: + + def __init__(self, pool: 'HTTPConnectionPool', conn: HTTPConnection) -> None: + """ + Wrapped Http Connection, used with connection pool + :param pool: Connection pool this connection belongs to + :param conn: Wrapped HTTPConnection + """ + self.pool = pool + self.conn = conn + self.response = None + self.last_time = time.time() + self.is_available = True + + def __enter__(self) -> 'WrapperHTTPConnection': + return self + + def __exit__(self, *exit_info: Any) -> None: + if not self.response.will_close and not self.response.is_closed(): + self.close() + self.pool.release(self) + + def request(self, *args: Any, **kwargs: Any) -> HTTPResponse: + self.conn.request(*args, **kwargs) + self.response = self.conn.getresponse() + return self.response + + def close(self) -> None: + self.conn.close() + self.is_available = False + + +class HTTPConnectionPool: + + def __init__(self, host: str, port: int = None, max_size: int = None, idle_timeout: int = None, + read_timeout: int = None, scheme: str = 'https') -> None: + """ + A simple connection pool to reuse the http connections + :param host: pass + :param port: pass + :param max_size: Max connection allowed + :param idle_timeout: Idle timeout to clear the connection + :param read_timeout: Read timeout with connection + :param scheme: http or https + """ + self.host = host + self.port = port + self.max_size = max_size + self.idle_timeout = idle_timeout + self.read_timeout = read_timeout + self.scheme = scheme + self._lock = threading.Condition() + self._pool = [] + self.conn_num = 0 + self.is_closed = False + self._clearer = None + self.start_clear_conn() + + def acquire(self, blocking: bool = True, timeout: int = None) -> WrapperHTTPConnection: + if self.is_closed: + raise ConnectionPoolClosed + with self._lock: + if self.max_size is None or not self.is_full(): + if self.is_pool_empty(): + self._put_connection(self._create_connection()) + else: + if not blocking: + if self.is_pool_empty(): + raise EmptyPoolError + elif timeout is None: + while self.is_pool_empty(): + self._lock.wait() + elif timeout < 0: + raise ValueError("'timeout' must be a non-negative number") + else: + end_time = time.time() + timeout + while self.is_pool_empty(): + remaining = end_time - time.time() + if remaining <= 0: + raise EmptyPoolError + self._lock.wait(remaining) + return self._get_connection() + + def release(self, conn: WrapperHTTPConnection) -> None: + if self.is_closed: + conn.close() + return + with self._lock: + if not conn.is_available: + conn.close() + self.conn_num -= 1 + conn = self._create_connection() + self._put_connection(conn) + self._lock.notify() + + def _get_connection(self) -> WrapperHTTPConnection: + try: + return self._pool.pop() + except IndexError: + raise EmptyPoolError + + def _put_connection(self, conn: WrapperHTTPConnection) -> None: + conn.last_time = time.time() + self._pool.append(conn) + + def _create_connection(self) -> WrapperHTTPConnection: + self.conn_num += 1 + connection = HTTPConnection if self.scheme == 'http:' else HTTPSConnection + return WrapperHTTPConnection(self, connection(self.host, self.port, timeout=self.read_timeout)) + + def is_pool_empty(self) -> bool: + return len(self._pool) == 0 + + def is_full(self) -> bool: + if self.max_size is None: + return False + return self.conn_num >= self.max_size + + def close(self) -> None: + if self.is_closed: + return + self.is_closed = True + self.stop_clear_conn() + pool, self._pool = self._pool, None + for conn in pool: + conn.close() + + def clear_idle_conn(self) -> None: + if self.is_closed: + raise ConnectionPoolClosed + # Staring a thread to clear idle connections + threading.Thread(target=self._clear_idle_conn).start() + + def _clear_idle_conn(self) -> None: + if not self._lock.acquire(timeout=self.idle_timeout): + return + current_time = time.time() + if self.is_pool_empty(): + pass + elif current_time - self._pool[-1].last_time >= self.idle_timeout: + self.conn_num -= len(self._pool) + self._pool.clear() + else: + left, right = 0, len(self._pool) - 1 + while left < right: + mid = (left + right) // 2 + if current_time - self._pool[mid].last_time >= self.idle_timeout: + left = mid + 1 + else: + right = mid + self._pool = self._pool[left:] + self.conn_num -= left + self._lock.release() + + def start_clear_conn(self) -> None: + if self.idle_timeout is None: + return + self.clear_idle_conn() + self._clearer = threading.Timer(self.idle_timeout, self.start_clear_conn) + self._clearer.start() + + def stop_clear_conn(self) -> None: + if self._clearer is not None: + self._clearer.cancel() + + def __enter__(self) -> 'HTTPConnectionPool': + return self + + def __exit__(self, *exit_info: Any) -> None: + self.close() + + +class EmptyPoolError(Exception): + pass + + +class ConnectionPoolClosed(Exception): + pass diff --git a/tests/client_test.py b/tests/client_test.py index 1dccbd8..72af54f 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -8,28 +8,32 @@ class ClientTestCase(unittest.TestCase): + def setUp(self): + self.client = None + def test_initialize_raise_error(self): self.assertRaises(ValueError, Client, "") def test_fetch(self): - client = Client(API_KEY) - expected_variant = Variant('on', 'payload') - user = User(user_id='test_user') - variants = client.fetch(user) - variant_name = 'sdk-ci-test' - self.assertIn(variant_name, variants) - self.assertEqual(expected_variant, variants.get(variant_name)) + with Client(API_KEY) as client: + expected_variant = Variant('on', 'payload') + user = User(user_id='test_user') + variants = client.fetch(user) + variant_name = 'sdk-ci-test' + self.assertIn(variant_name, variants) + self.assertEqual(expected_variant, variants.get(variant_name)) def callback_for_async(self, user, variants): expected_variant = Variant('on', 'payload') variant_name = 'sdk-ci-test' self.assertIn(variant_name, variants) self.assertEqual(expected_variant, variants.get(variant_name)) + self.client.close() def test_fetch_async(self): - client = Client(API_KEY, Config(debug=True)) + self.client = Client(API_KEY, Config(debug=True)) user = User(user_id='test_user') - client.fetch_async(user, self.callback_for_async) + self.client.fetch_async(user, self.callback_for_async) if __name__ == '__main__': diff --git a/tests/factory_test.py b/tests/factory_test.py index 2777164..c4ac532 100644 --- a/tests/factory_test.py +++ b/tests/factory_test.py @@ -9,6 +9,7 @@ def test_singleton_instance(self): client1 = Experiment.initialize(API_KEY) client2 = Experiment.initialize(API_KEY) self.assertEqual(client1, client2) + client1.close() if __name__ == '__main__':