From 153996678d69c0baaac2b2ce46b9396faa1d2f16 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Wed, 22 May 2024 18:20:39 +0300 Subject: [PATCH 1/4] fix: updated env var live --- ai21/ai21_env_config.py | 80 ++++++++++++++++++++----- tests/unittests/test_ai21_env_config.py | 34 +++++++++++ 2 files changed, 100 insertions(+), 14 deletions(-) create mode 100644 tests/unittests/test_ai21_env_config.py diff --git a/ai21/ai21_env_config.py b/ai21/ai21_env_config.py index b62a6464..fabcb34e 100644 --- a/ai21/ai21_env_config.py +++ b/ai21/ai21_env_config.py @@ -5,28 +5,80 @@ from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST +# Constants for environment variable keys +_ENV_API_KEY = "AI21_API_KEY" +_ENV_API_VERSION = "AI21_API_VERSION" +_ENV_API_HOST = "AI21_API_HOST" +_ENV_TIMEOUT_SEC = "AI21_TIMEOUT_SEC" +_ENV_NUM_RETRIES = "AI21_NUM_RETRIES" +_ENV_AWS_REGION = "AI21_AWS_REGION" +_ENV_LOG_LEVEL = "AI21_LOG_LEVEL" + @dataclass class _AI21EnvConfig: - api_key: Optional[str] = None - api_version: str = DEFAULT_API_VERSION - api_host: str = STUDIO_HOST - timeout_sec: Optional[int] = None - num_retries: Optional[int] = None - aws_region: Optional[str] = None - log_level: Optional[str] = None + _api_key: Optional[str] = None + _api_version: str = DEFAULT_API_VERSION + _api_host: str = STUDIO_HOST + _timeout_sec: Optional[int] = None + _num_retries: Optional[int] = None + _aws_region: Optional[str] = None + _log_level: Optional[str] = None @classmethod def from_env(cls) -> _AI21EnvConfig: return cls( - api_key=os.getenv("AI21_API_KEY"), - api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), - api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), - timeout_sec=os.getenv("AI21_TIMEOUT_SEC"), - num_retries=os.getenv("AI21_NUM_RETRIES"), - aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"), - log_level=os.getenv("AI21_LOG_LEVEL", "info"), + _api_key=os.getenv(_ENV_API_KEY), + _api_version=os.getenv(_ENV_API_VERSION, DEFAULT_API_VERSION), + _api_host=os.getenv(_ENV_API_HOST, STUDIO_HOST), + _timeout_sec=os.getenv(_ENV_TIMEOUT_SEC), + _num_retries=os.getenv(_ENV_NUM_RETRIES), + _aws_region=os.getenv(_ENV_AWS_REGION, "us-east-1"), + _log_level=os.getenv(_ENV_LOG_LEVEL, "info"), ) + @property + def api_key(self) -> str: + self._api_key = os.getenv(_ENV_API_KEY, self._api_key) + return self._api_key + + @property + def api_version(self) -> str: + self._api_version = os.getenv(_ENV_API_VERSION, self._api_version) + return self._api_version + + @property + def api_host(self) -> str: + self._api_host = os.getenv(_ENV_API_HOST, self._api_host) + return self._api_host + + @property + def timeout_sec(self) -> Optional[int]: + timeout_str = os.getenv(_ENV_TIMEOUT_SEC) + + if timeout_str is not None: + self._timeout_sec = int(timeout_str) + + return self._timeout_sec + + @property + def num_retries(self) -> Optional[int]: + retries_str = os.getenv(_ENV_NUM_RETRIES) + + if retries_str is not None: + self._num_retries = int(retries_str) + + return self._num_retries + + @property + def aws_region(self) -> Optional[str]: + self._aws_region = os.getenv(_ENV_AWS_REGION, self._aws_region) + return self._aws_region + + @property + def log_level(self) -> Optional[str]: + self._log_level = os.getenv(_ENV_LOG_LEVEL, self._log_level) + return self._log_level + AI21EnvConfig = _AI21EnvConfig.from_env() diff --git a/tests/unittests/test_ai21_env_config.py b/tests/unittests/test_ai21_env_config.py new file mode 100644 index 00000000..091a539b --- /dev/null +++ b/tests/unittests/test_ai21_env_config.py @@ -0,0 +1,34 @@ +import os +from ai21 import AI21Client + +_FAKE_API_KEY = "fake-key" +os.environ["AI21_API_KEY"] = _FAKE_API_KEY + + +def test_env_config__when_set_twice__should_be_updated(): + client = AI21Client() + + assert client._http_client._api_key == _FAKE_API_KEY + + new_api_key = "new-key" + os.environ["AI21_API_KEY"] = new_api_key + client2 = AI21Client() + assert client2._http_client._api_key == new_api_key + + +def test_env_config__when_set_via_init_and_env__should_be_taken_from_init(): + client = AI21Client() + assert client._http_client._api_key == _FAKE_API_KEY + + init_api_key = "init-key" + client2 = AI21Client(api_key=init_api_key) + + assert client2._http_client._api_key == init_api_key + + +def test_env_config__when_set_int__should_be_set(): + os.environ["AI21_TIMEOUT_SEC"] = "1" + + client = AI21Client() + + assert client._http_client._timeout_sec == 1 From 87a17c4a6b54c76efdaef6e84f084d780a78117b Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Wed, 22 May 2024 18:24:36 +0300 Subject: [PATCH 2/4] fix: test order --- tests/unittests/test_ai21_env_config.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/unittests/test_ai21_env_config.py b/tests/unittests/test_ai21_env_config.py index 091a539b..cbb0be29 100644 --- a/tests/unittests/test_ai21_env_config.py +++ b/tests/unittests/test_ai21_env_config.py @@ -5,6 +5,16 @@ os.environ["AI21_API_KEY"] = _FAKE_API_KEY +def test_env_config__when_set_via_init_and_env__should_be_taken_from_init(): + client = AI21Client() + assert client._http_client._api_key == _FAKE_API_KEY + + init_api_key = "init-key" + client2 = AI21Client(api_key=init_api_key) + + assert client2._http_client._api_key == init_api_key + + def test_env_config__when_set_twice__should_be_updated(): client = AI21Client() @@ -16,16 +26,6 @@ def test_env_config__when_set_twice__should_be_updated(): assert client2._http_client._api_key == new_api_key -def test_env_config__when_set_via_init_and_env__should_be_taken_from_init(): - client = AI21Client() - assert client._http_client._api_key == _FAKE_API_KEY - - init_api_key = "init-key" - client2 = AI21Client(api_key=init_api_key) - - assert client2._http_client._api_key == init_api_key - - def test_env_config__when_set_int__should_be_set(): os.environ["AI21_TIMEOUT_SEC"] = "1" From 757d5414e88daa33095714138e07344f053f6654 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Wed, 22 May 2024 18:27:00 +0300 Subject: [PATCH 3/4] fix: Added contextmanager set_env_var --- tests/unittests/test_ai21_env_config.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/unittests/test_ai21_env_config.py b/tests/unittests/test_ai21_env_config.py index cbb0be29..eb2db1fe 100644 --- a/tests/unittests/test_ai21_env_config.py +++ b/tests/unittests/test_ai21_env_config.py @@ -1,10 +1,19 @@ import os +from contextlib import contextmanager + from ai21 import AI21Client _FAKE_API_KEY = "fake-key" os.environ["AI21_API_KEY"] = _FAKE_API_KEY +@contextmanager +def set_env_var(key: str, value: str): + os.environ[key] = value + yield + del os.environ[key] + + def test_env_config__when_set_via_init_and_env__should_be_taken_from_init(): client = AI21Client() assert client._http_client._api_key == _FAKE_API_KEY @@ -21,14 +30,14 @@ def test_env_config__when_set_twice__should_be_updated(): assert client._http_client._api_key == _FAKE_API_KEY new_api_key = "new-key" - os.environ["AI21_API_KEY"] = new_api_key - client2 = AI21Client() - assert client2._http_client._api_key == new_api_key + with set_env_var("AI21_API_KEY", new_api_key): + client2 = AI21Client() + assert client2._http_client._api_key == new_api_key -def test_env_config__when_set_int__should_be_set(): - os.environ["AI21_TIMEOUT_SEC"] = "1" - client = AI21Client() +def test_env_config__when_set_int__should_be_set(): + with set_env_var("AI21_TIMEOUT_SEC", "1"): + client = AI21Client() - assert client._http_client._timeout_sec == 1 + assert client._http_client._timeout_sec == 1 From 966f9c3eda7eb216a22a9ab24ac54fa208acbad3 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Wed, 22 May 2024 18:33:39 +0300 Subject: [PATCH 4/4] fix: sagemaker tests --- tests/integration_tests/services/test_sagemaker.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/integration_tests/services/test_sagemaker.py b/tests/integration_tests/services/test_sagemaker.py index 9e90b438..29e61be2 100644 --- a/tests/integration_tests/services/test_sagemaker.py +++ b/tests/integration_tests/services/test_sagemaker.py @@ -1,13 +1,15 @@ +import os + import pytest -from ai21 import SageMaker, AI21EnvConfig +from ai21 import SageMaker def _add_or_remove_api_key(use_api_key: bool): if use_api_key: - AI21EnvConfig.api_key = "test" + os.environ["AI21_API_KEY"] = "test" else: - AI21EnvConfig.api_key = None + del os.environ["AI21_API_KEY"] @pytest.mark.parametrize(