From 85680e263ced99c05455f2abc3b7eaa4fa71c40a Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 1 Apr 2024 16:57:17 -0700 Subject: [PATCH 1/5] Add `AuthenticatedApiAssistant` and override requirements classmethod --- ragna/assistants/_api.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py index 01041c2b..c9936021 100644 --- a/ragna/assistants/_api.py +++ b/ragna/assistants/_api.py @@ -1,7 +1,6 @@ import abc import contextlib import json -import os from typing import AsyncIterator import httpx @@ -12,11 +11,9 @@ class ApiAssistant(Assistant): - _API_KEY_ENV_VAR: str - @classmethod def requirements(cls) -> list[Requirement]: - return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()] + return [] @classmethod def _extra_requirements(cls) -> list[Requirement]: @@ -27,7 +24,6 @@ def __init__(self) -> None: headers={"User-Agent": f"{ragna.__version__}/{self}"}, timeout=60, ) - self._api_key = os.environ[self._API_KEY_ENV_VAR] async def answer( self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 @@ -58,3 +54,11 @@ async def _assert_api_call_is_success(self, response: Response) -> None: response_status_code=response.status_code, response_content=content, ) + + +class AuthenticatedApiAssistant(ApiAssistant): + _API_KEY_ENV_VAR: str + + @classmethod + def requirements(cls) -> list[Requirement]: + return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()] From 0cd617454de53c981cc662651d4b7ccd1242515d Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 1 Apr 2024 16:58:56 -0700 Subject: [PATCH 2/5] Override `__init__()` for `AuthenticatedApiAssistant` --- ragna/assistants/_api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py index c9936021..0b211be7 100644 --- a/ragna/assistants/_api.py +++ b/ragna/assistants/_api.py @@ -1,6 +1,7 @@ import abc import contextlib import json +import os from typing import AsyncIterator import httpx @@ -62,3 +63,7 @@ class AuthenticatedApiAssistant(ApiAssistant): @classmethod def requirements(cls) -> list[Requirement]: return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()] + + def __init__(self) -> None: + super().__init__() + self._api_key = os.environ[self._API_KEY_ENV_VAR] From 73951316c41ccd8b2c48082e2836cc5632a5111c Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 1 Apr 2024 17:44:24 -0700 Subject: [PATCH 3/5] Use `AuthenticatedApiAssistant` --- ragna/assistants/_ai21labs.py | 4 ++-- ragna/assistants/_anthropic.py | 4 ++-- ragna/assistants/_api.py | 4 ++-- ragna/assistants/_cohere.py | 4 ++-- ragna/assistants/_google.py | 4 ++-- ragna/assistants/_mosaicml.py | 4 ++-- ragna/assistants/_openai.py | 4 ++-- tests/assistants/test_api.py | 7 ++++--- 8 files changed, 18 insertions(+), 17 deletions(-) diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 19cfd59b..f941c999 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -2,10 +2,10 @@ from ragna.core import Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class Ai21LabsAssistant(ApiAssistant): +class Ai21LabsAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "AI21_API_KEY" _MODEL_TYPE: str diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index f5f4c538..4b44ce8e 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -3,10 +3,10 @@ from ragna.core import PackageRequirement, RagnaException, Requirement, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class AnthropicApiAssistant(ApiAssistant): +class AnthropicApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY" _MODEL: str diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py index 0b211be7..575d4359 100644 --- a/ragna/assistants/_api.py +++ b/ragna/assistants/_api.py @@ -11,7 +11,7 @@ from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement, Source -class ApiAssistant(Assistant): +class _ApiAssistant(Assistant): @classmethod def requirements(cls) -> list[Requirement]: return [] @@ -57,7 +57,7 @@ async def _assert_api_call_is_success(self, response: Response) -> None: ) -class AuthenticatedApiAssistant(ApiAssistant): +class AuthenticatedApiAssistant(_ApiAssistant): _API_KEY_ENV_VAR: str @classmethod diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index a93a264f..0525481b 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -3,10 +3,10 @@ from ragna.core import RagnaException, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class CohereApiAssistant(ApiAssistant): +class CohereApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "COHERE_API_KEY" _MODEL: str diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index afbb829a..92998bb3 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -3,7 +3,7 @@ from ragna._compat import anext from ragna.core import PackageRequirement, Requirement, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant # ijson does not support reading from an (async) iterator, but only from file-like @@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes: return await anext(self._ait, b"") # type: ignore[call-arg] -class GoogleApiAssistant(ApiAssistant): +class GoogleApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "GOOGLE_API_KEY" _MODEL: str diff --git a/ragna/assistants/_mosaicml.py b/ragna/assistants/_mosaicml.py index 64edac1b..5e024d63 100644 --- a/ragna/assistants/_mosaicml.py +++ b/ragna/assistants/_mosaicml.py @@ -2,10 +2,10 @@ from ragna.core import Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class MosaicmlApiAssistant(ApiAssistant): +class MosaicmlApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "MOSAICML_API_KEY" _MODEL: str diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index a9dad3ee..4a1b0a46 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -3,10 +3,10 @@ from ragna.core import PackageRequirement, Requirement, Source -from ._api import ApiAssistant +from ._api import AuthenticatedApiAssistant -class OpenaiApiAssistant(ApiAssistant): +class OpenaiApiAssistant(AuthenticatedApiAssistant): _API_KEY_ENV_VAR = "OPENAI_API_KEY" _MODEL: str diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 97961456..013e5e67 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -4,7 +4,7 @@ from ragna import assistants from ragna._compat import anext -from ragna.assistants._api import ApiAssistant +from ragna.assistants._api import AuthenticatedApiAssistant, _ApiAssistant from ragna.core import RagnaException from tests.utils import skip_on_windows @@ -12,8 +12,9 @@ assistant for assistant in assistants.__dict__.values() if isinstance(assistant, type) - and issubclass(assistant, ApiAssistant) - and assistant is not ApiAssistant + and issubclass(assistant, _ApiAssistant) + and assistant is not _ApiAssistant + and assistant is not AuthenticatedApiAssistant ] From 917e958697e75d9547550e602c4bacaaeeb3e7cb Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 1 Apr 2024 17:47:04 -0700 Subject: [PATCH 4/5] Add `UnauthenticatedApiAssistant` --- ragna/assistants/_api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py index 575d4359..296476b8 100644 --- a/ragna/assistants/_api.py +++ b/ragna/assistants/_api.py @@ -67,3 +67,8 @@ def requirements(cls) -> list[Requirement]: def __init__(self) -> None: super().__init__() self._api_key = os.environ[self._API_KEY_ENV_VAR] + + +class UnauthenticatedApiAssistant(_ApiAssistant): + def __init__(self) -> None: + super().__init__() From aa1f9856660935308a92090efef04abb1e4d54c0 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 1 Apr 2024 17:51:34 -0700 Subject: [PATCH 5/5] Add `UnuthenticatedApiAssistant` to tests --- tests/assistants/test_api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 013e5e67..5c84f3cf 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -4,7 +4,11 @@ from ragna import assistants from ragna._compat import anext -from ragna.assistants._api import AuthenticatedApiAssistant, _ApiAssistant +from ragna.assistants._api import ( + AuthenticatedApiAssistant, + UnauthenticatedApiAssistant, + _ApiAssistant, +) from ragna.core import RagnaException from tests.utils import skip_on_windows @@ -15,6 +19,7 @@ and issubclass(assistant, _ApiAssistant) and assistant is not _ApiAssistant and assistant is not AuthenticatedApiAssistant + and assistant is not UnauthenticatedApiAssistant ]