Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ApiAssistant to AuthenticatedApiAssistant and UnauthenticatedApiAssistant #381

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ragna.core import Source

from ._api import ApiAssistant
from ._api import AuthenticatedApiAssistant


class Ai21LabsAssistant(ApiAssistant):
class Ai21LabsAssistant(AuthenticatedApiAssistant):
_API_KEY_ENV_VAR = "AI21_API_KEY"
_MODEL_TYPE: str

Expand Down
4 changes: 2 additions & 2 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from ragna.core import PackageRequirement, RagnaException, Requirement, Source

from ._api import ApiAssistant
from ._api import AuthenticatedApiAssistant


class AnthropicApiAssistant(ApiAssistant):
class AnthropicApiAssistant(AuthenticatedApiAssistant):
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
_MODEL: str

Expand Down
24 changes: 19 additions & 5 deletions ragna/assistants/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement, Source


class ApiAssistant(Assistant):
_API_KEY_ENV_VAR: str

class _ApiAssistant(Assistant):
@classmethod
def requirements(cls) -> list[Requirement]:
return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()]
return []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the default return if we don't override the method. Meaning, we can just leave it out here.


@classmethod
def _extra_requirements(cls) -> list[Requirement]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if it makes sense to have this on the superclass if for now only the authenticated subclass needs it.

Expand All @@ -27,7 +25,6 @@ def __init__(self) -> None:
headers={"User-Agent": f"{ragna.__version__}/{self}"},
timeout=60,
)
self._api_key = os.environ[self._API_KEY_ENV_VAR]

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
Expand Down Expand Up @@ -58,3 +55,20 @@ async def _assert_api_call_is_success(self, response: Response) -> None:
response_status_code=response.status_code,
response_content=content,
)


class AuthenticatedApiAssistant(_ApiAssistant):
_API_KEY_ENV_VAR: str

@classmethod
def requirements(cls) -> list[Requirement]:
return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()]

def __init__(self) -> None:
super().__init__()
self._api_key = os.environ[self._API_KEY_ENV_VAR]


class UnauthenticatedApiAssistant(_ApiAssistant):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is not needed. We can simply use the superclass for this.

def __init__(self) -> None:
super().__init__()
4 changes: 2 additions & 2 deletions ragna/assistants/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from ragna.core import RagnaException, Source

from ._api import ApiAssistant
from ._api import AuthenticatedApiAssistant


class CohereApiAssistant(ApiAssistant):
class CohereApiAssistant(AuthenticatedApiAssistant):
_API_KEY_ENV_VAR = "COHERE_API_KEY"
_MODEL: str

Expand Down
4 changes: 2 additions & 2 deletions ragna/assistants/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ragna._compat import anext
from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant
from ._api import AuthenticatedApiAssistant


# ijson does not support reading from an (async) iterator, but only from file-like
Expand All @@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes:
return await anext(self._ait, b"") # type: ignore[call-arg]


class GoogleApiAssistant(ApiAssistant):
class GoogleApiAssistant(AuthenticatedApiAssistant):
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
_MODEL: str

Expand Down
4 changes: 2 additions & 2 deletions ragna/assistants/_mosaicml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from ragna.core import Source

from ._api import ApiAssistant
from ._api import AuthenticatedApiAssistant


class MosaicmlApiAssistant(ApiAssistant):
class MosaicmlApiAssistant(AuthenticatedApiAssistant):
_API_KEY_ENV_VAR = "MOSAICML_API_KEY"
_MODEL: str

Expand Down
4 changes: 2 additions & 2 deletions ragna/assistants/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant
from ._api import AuthenticatedApiAssistant


class OpenaiApiAssistant(ApiAssistant):
class OpenaiApiAssistant(AuthenticatedApiAssistant):
_API_KEY_ENV_VAR = "OPENAI_API_KEY"
_MODEL: str

Expand Down
12 changes: 9 additions & 3 deletions tests/assistants/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@

from ragna import assistants
from ragna._compat import anext
from ragna.assistants._api import ApiAssistant
from ragna.assistants._api import (
AuthenticatedApiAssistant,
UnauthenticatedApiAssistant,
_ApiAssistant,
)
from ragna.core import RagnaException
from tests.utils import skip_on_windows

API_ASSISTANTS = [
assistant
for assistant in assistants.__dict__.values()
if isinstance(assistant, type)
and issubclass(assistant, ApiAssistant)
and assistant is not ApiAssistant
and issubclass(assistant, _ApiAssistant)
and assistant is not _ApiAssistant
and assistant is not AuthenticatedApiAssistant
and assistant is not UnauthenticatedApiAssistant
Comment on lines +19 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
and issubclass(assistant, _ApiAssistant)
and assistant is not _ApiAssistant
and assistant is not AuthenticatedApiAssistant
and assistant is not UnauthenticatedApiAssistant
and issubclass(assistant, AuthenticatedApiAssistant)
and assistant is not AuthenticatedApiAssistant

]


Expand Down