diff --git a/src/amplitude_experiment/client.py b/src/amplitude_experiment/client.py index 32e237c..7a6eabc 100644 --- a/src/amplitude_experiment/client.py +++ b/src/amplitude_experiment/client.py @@ -1,3 +1,4 @@ +import threading import time from time import sleep from .config import Config @@ -41,20 +42,42 @@ def fetch(self, user: User): Variants Dictionary. """ try: - return self.fetch_internal(user) + return self.__fetch_internal(user) except Exception as e: self.logger.error(f"[Experiment] Failed to fetch variants: {e}") return {} - def fetch_internal(self, user): + def fetch_async(self, user: User, callback=None): + """ + Fetch all variants for a user asynchronous. Will trigger callback after fetch complete + Parameters: + user (User): The Experiment User + callback (callable): Callback function, takes user and variants arguments + """ + thread = threading.Thread(target=self.__fetch_async_internal, args=(user, callback)) + thread.start() + + def __fetch_async_internal(self, user, callback): + try: + variants = self.__fetch_internal(user) + if callback: + callback(user, variants) + return variants + except Exception as e: + self.logger.error(f"[Experiment] Failed to fetch variants: {e}") + if callback: + callback(user, {}) + return {} + + def __fetch_internal(self, user): self.logger.debug(f"[Experiment] Fetching variants for user: {user}") try: - return self.do_fetch(user, self.config.fetch_timeout_millis) + return self.__do_fetch(user, self.config.fetch_timeout_millis) except Exception as e: self.logger.error(f"Experiment] Fetch failed: {e}") - return self.retry_fetch(user) + return self.__retry_fetch(user) - def retry_fetch(self, user): + def __retry_fetch(self, user): if self.config.fetch_retries == 0: return {} self.logger.debug("[Experiment] Retrying fetch") @@ -63,7 +86,7 @@ def retry_fetch(self, user): for i in range(self.config.fetch_retries): sleep(delay_millis / 1000.0) try: - return self.do_fetch(user, self.config.fetch_timeout_millis) + return self.__do_fetch(user, self.config.fetch_timeout_millis) except Exception as e: self.logger.error(f"[Experiment] Retry failed: {e}") err = e @@ -71,16 +94,16 @@ def retry_fetch(self, user): self.config.fetch_retry_backoff_max_millis) raise err - def do_fetch(self, user, timeout_millis): + def __do_fetch(self, user, timeout_millis): start = time.time() - user_context = self.add_context(user) + user_context = self.__add_context(user) headers = { 'Authorization': f"Api-Key {self.api_key}", 'Content-Type': 'application/json;charset=utf-8' } scheme, _, host = self.config.server_url.split('/', 3) Connection = http.client.HTTPConnection if scheme == 'http:' else http.client.HTTPSConnection - conn = Connection(host) + conn = Connection(host, timeout=timeout_millis / 1000) conn.connect() body = user_context.to_json().encode('utf8') if len(body) > 8000: @@ -92,17 +115,17 @@ def do_fetch(self, user, timeout_millis): elapsed = '%.3f' % ((time.time() - start) * 1000) self.logger.debug(f"[Experiment] Fetch complete in {elapsed} ms") json_response = json.loads(response.read().decode("utf8")) - variants = self.parse_json_variants(json_response) + variants = self.__parse_json_variants(json_response) self.logger.debug(f"[Experiment] Fetched variants: {json.dumps(variants, default=str)}") conn.close() return variants - def add_context(self, user): + def __add_context(self, user): user = user or {} user.library = user.library or f"experiment-python-server/{__version__}" return user - def parse_json_variants(self, json_response): + def __parse_json_variants(self, json_response): variants = {} for key, value in json_response.items(): variant_value = '' diff --git a/src/amplitude_experiment/config.py b/src/amplitude_experiment/config.py index 6b6e729..d10f0f0 100644 --- a/src/amplitude_experiment/config.py +++ b/src/amplitude_experiment/config.py @@ -14,7 +14,7 @@ def __init__(self, debug=False, """ Initialize a config Parameters: - debug (str): Set to true to log some extra information to the console. + debug (bool): Set to true to log some extra information to the console. server_url (str): The server endpoint from which to request variants. fetch_timeout_millis (int): The request timeout, in milliseconds, used when fetching variants triggered by calling start() or setUser(). diff --git a/tests/client_test.py b/tests/client_test.py index fcb3d75..1dccbd8 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -20,6 +20,17 @@ def test_fetch(self): self.assertIn(variant_name, variants) self.assertEqual(expected_variant, variants.get(variant_name)) + def callback_for_async(self, user, variants): + expected_variant = Variant('on', 'payload') + variant_name = 'sdk-ci-test' + self.assertIn(variant_name, variants) + self.assertEqual(expected_variant, variants.get(variant_name)) + + def test_fetch_async(self): + client = Client(API_KEY, Config(debug=True)) + user = User(user_id='test_user') + client.fetch_async(user, self.callback_for_async) + if __name__ == '__main__': unittest.main()