Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 66 additions & 14 deletions ai21/ai21_env_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,80 @@

from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST

# Constants for environment variable keys
_ENV_API_KEY = "AI21_API_KEY"
_ENV_API_VERSION = "AI21_API_VERSION"
_ENV_API_HOST = "AI21_API_HOST"
_ENV_TIMEOUT_SEC = "AI21_TIMEOUT_SEC"
_ENV_NUM_RETRIES = "AI21_NUM_RETRIES"
_ENV_AWS_REGION = "AI21_AWS_REGION"
_ENV_LOG_LEVEL = "AI21_LOG_LEVEL"


@dataclass
class _AI21EnvConfig:
api_key: Optional[str] = None
api_version: str = DEFAULT_API_VERSION
api_host: str = STUDIO_HOST
timeout_sec: Optional[int] = None
num_retries: Optional[int] = None
aws_region: Optional[str] = None
log_level: Optional[str] = None
_api_key: Optional[str] = None
_api_version: str = DEFAULT_API_VERSION
_api_host: str = STUDIO_HOST
_timeout_sec: Optional[int] = None
_num_retries: Optional[int] = None
_aws_region: Optional[str] = None
_log_level: Optional[str] = None

@classmethod
def from_env(cls) -> _AI21EnvConfig:
return cls(
api_key=os.getenv("AI21_API_KEY"),
api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION),
api_host=os.getenv("AI21_API_HOST", STUDIO_HOST),
timeout_sec=os.getenv("AI21_TIMEOUT_SEC"),
num_retries=os.getenv("AI21_NUM_RETRIES"),
aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"),
log_level=os.getenv("AI21_LOG_LEVEL", "info"),
_api_key=os.getenv(_ENV_API_KEY),
_api_version=os.getenv(_ENV_API_VERSION, DEFAULT_API_VERSION),
_api_host=os.getenv(_ENV_API_HOST, STUDIO_HOST),
_timeout_sec=os.getenv(_ENV_TIMEOUT_SEC),
_num_retries=os.getenv(_ENV_NUM_RETRIES),
_aws_region=os.getenv(_ENV_AWS_REGION, "us-east-1"),
_log_level=os.getenv(_ENV_LOG_LEVEL, "info"),
)

@property
def api_key(self) -> str:
self._api_key = os.getenv(_ENV_API_KEY, self._api_key)
return self._api_key

@property
def api_version(self) -> str:
self._api_version = os.getenv(_ENV_API_VERSION, self._api_version)
return self._api_version

@property
def api_host(self) -> str:
self._api_host = os.getenv(_ENV_API_HOST, self._api_host)
return self._api_host

@property
def timeout_sec(self) -> Optional[int]:
timeout_str = os.getenv(_ENV_TIMEOUT_SEC)

if timeout_str is not None:
self._timeout_sec = int(timeout_str)

return self._timeout_sec

@property
def num_retries(self) -> Optional[int]:
retries_str = os.getenv(_ENV_NUM_RETRIES)

if retries_str is not None:
self._num_retries = int(retries_str)

return self._num_retries

@property
def aws_region(self) -> Optional[str]:
self._aws_region = os.getenv(_ENV_AWS_REGION, self._aws_region)
return self._aws_region

@property
def log_level(self) -> Optional[str]:
self._log_level = os.getenv(_ENV_LOG_LEVEL, self._log_level)
return self._log_level


AI21EnvConfig = _AI21EnvConfig.from_env()
8 changes: 5 additions & 3 deletions tests/integration_tests/services/test_sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os

import pytest

from ai21 import SageMaker, AI21EnvConfig
from ai21 import SageMaker


def _add_or_remove_api_key(use_api_key: bool):
if use_api_key:
AI21EnvConfig.api_key = "test"
os.environ["AI21_API_KEY"] = "test"
else:
AI21EnvConfig.api_key = None
del os.environ["AI21_API_KEY"]


@pytest.mark.parametrize(
Expand Down
43 changes: 43 additions & 0 deletions tests/unittests/test_ai21_env_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
from contextlib import contextmanager

from ai21 import AI21Client

_FAKE_API_KEY = "fake-key"
os.environ["AI21_API_KEY"] = _FAKE_API_KEY


@contextmanager
def set_env_var(key: str, value: str):
os.environ[key] = value
yield
del os.environ[key]


def test_env_config__when_set_via_init_and_env__should_be_taken_from_init():
client = AI21Client()
assert client._http_client._api_key == _FAKE_API_KEY

init_api_key = "init-key"
client2 = AI21Client(api_key=init_api_key)

assert client2._http_client._api_key == init_api_key


def test_env_config__when_set_twice__should_be_updated():
client = AI21Client()

assert client._http_client._api_key == _FAKE_API_KEY

new_api_key = "new-key"

with set_env_var("AI21_API_KEY", new_api_key):
client2 = AI21Client()
assert client2._http_client._api_key == new_api_key


def test_env_config__when_set_int__should_be_set():
with set_env_var("AI21_TIMEOUT_SEC", "1"):
client = AI21Client()

assert client._http_client._timeout_sec == 1