diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 19cfd59b..f941c999 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -2,10 +2,10 @@ from ragna.core import Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class Ai21LabsAssistant(ApiAssistant): +class Ai21LabsAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "AI21_API_KEY" _MODEL_TYPE: str diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index f5f4c538..4b44ce8e 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -3,10 +3,10 @@ from ragna.core import PackageRequirement, RagnaException, Requirement, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class AnthropicApiAssistant(ApiAssistant): +class AnthropicApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY" _MODEL: str diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py index 01041c2b..296476b8 100644 --- a/ragna/assistants/_api.py +++ b/ragna/assistants/_api.py @@ -11,12 +11,10 @@ from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement, Source -class ApiAssistant(Assistant): - _API_KEY_ENV_VAR: str - +class _ApiAssistant(Assistant): @classmethod def requirements(cls) -> list[Requirement]: - return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()] + return [] @classmethod def _extra_requirements(cls) -> list[Requirement]: @@ -27,7 +25,6 @@ def __init__(self) -> None: headers={"User-Agent": f"{ragna.__version__}/{self}"}, timeout=60, ) - self._api_key = os.environ[self._API_KEY_ENV_VAR] async def answer( self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 @@ -58,3 +55,20 @@ async def _assert_api_call_is_success(self, response: Response) -> None: response_status_code=response.status_code, response_content=content, ) + + +class AuthenticatedApiAssistant(_ApiAssistant): + _API_KEY_ENV_VAR: str + + @classmethod + def requirements(cls) -> list[Requirement]: + return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()] + + def __init__(self) -> None: + super().__init__() + self._api_key = os.environ[self._API_KEY_ENV_VAR] + + +class UnauthenticatedApiAssistant(_ApiAssistant): + def __init__(self) -> None: + super().__init__() diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index a93a264f..0525481b 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -3,10 +3,10 @@ from ragna.core import RagnaException, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class CohereApiAssistant(ApiAssistant): +class CohereApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "COHERE_API_KEY" _MODEL: str diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index afbb829a..92998bb3 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -3,7 +3,7 @@ from ragna._compat import anext from ragna.core import PackageRequirement, Requirement, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant # ijson does not support reading from an (async) iterator, but only from file-like @@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes: return await anext(self._ait, b"") # type: ignore[call-arg] -class GoogleApiAssistant(ApiAssistant): +class GoogleApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "GOOGLE_API_KEY" _MODEL: str diff --git a/ragna/assistants/_mosaicml.py b/ragna/assistants/_mosaicml.py index 64edac1b..5e024d63 100644 --- a/ragna/assistants/_mosaicml.py +++ b/ragna/assistants/_mosaicml.py @@ -2,10 +2,10 @@ from ragna.core import Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class MosaicmlApiAssistant(ApiAssistant): +class MosaicmlApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "MOSAICML_API_KEY" _MODEL: str diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index a9dad3ee..4a1b0a46 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -3,10 +3,10 @@ from ragna.core import PackageRequirement, Requirement, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class OpenaiApiAssistant(ApiAssistant): +class OpenaiApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "OPENAI_API_KEY" _MODEL: str diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 97961456..5c84f3cf 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -4,7 +4,11 @@ from ragna import assistants from ragna._compat import anext -from ragna.assistants._api import ApiAssistant +from ragna.assistants._api import ( + AuthenticatedApiAssistant, + UnauthenticatedApiAssistant, + _ApiAssistant, +) from ragna.core import RagnaException from tests.utils import skip_on_windows @@ -12,8 +16,10 @@ assistant for assistant in assistants.__dict__.values() if isinstance(assistant, type) - and issubclass(assistant, ApiAssistant) - and assistant is not ApiAssistant + and issubclass(assistant, _ApiAssistant) + and assistant is not _ApiAssistant + and assistant is not AuthenticatedApiAssistant + and assistant is not UnauthenticatedApiAssistant ]