Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entrypoints plugin #192

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions plugins/vikings/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[project]
name = "vikings"
description = "A collection of bearded and aggressive assistants."
version = "0.1.0"

[tool.setuptools.packages.find]
where = ["src"]

[project.entry-points."ragna.assistants"]
assistants = "vikings.assistants"
Empty file.
20 changes: 20 additions & 0 deletions plugins/vikings/src/vikings/assistants/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
__all__ = ["Ivar"]

from ragna.core import Assistant, Source


class IvarTheBoneless(Assistant):
"""Ivar the Boneless"""

@classmethod
def display_name(cls) -> str:
return "Vikings/IvarTheBoneless"

@property
def max_input_size(self) -> int:
return 873

def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> str:
return "I am Ivar the Boneless! "
19 changes: 19 additions & 0 deletions ragna/_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rich.table import Table

import ragna
from ragna._compat import importlib_metadata_entry_points
from ragna.core import (
Assistant,
Config,
Expand Down Expand Up @@ -237,6 +238,24 @@ def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None:
def _wizard_common() -> Config:
config = _wizard_builtin()

if questionary.confirm(
"Do you want to install any ragna assistant plugins?",
default=False,
qmark=QMARK,
).unsafe_ask():
plugin_modules = [
plugin.load()
for plugin in importlib_metadata_entry_points(group="ragna.assistants")
]
for plugin_module in plugin_modules:
plugin_assistants = _select_components(
"assistants",
plugin_module,
Assistant, # type: ignore[type-abstract]
)
for assistant in plugin_assistants:
config.core.assistants.append(assistant)

config.local_cache_root = Path(
questionary.path(
"Where should local files be stored?",
Expand Down
41 changes: 39 additions & 2 deletions ragna/_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
import sys
from typing import Callable, Iterable, Iterator, Mapping, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Mapping,
Protocol,
TypeVar,
)

__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions"]
if TYPE_CHECKING:
if sys.version_info[:2] >= (3, 10):
from importlib.metadata import EntryPoints
else:
from importlib_metadata import EntryPoints


__all__ = [
"itertools_pairwise",
"importlib_metadata_package_distributions",
"importlib_metadata_entry_points",
]

T = TypeVar("T")

Expand Down Expand Up @@ -38,3 +58,20 @@ def _importlib_metadata_package_distributions() -> (


importlib_metadata_package_distributions = _importlib_metadata_package_distributions()


class EntryPointsCallable(Protocol):
def __call__(self, **kwargs: Any) -> "EntryPoints":
...


def _importlib_metadata_entry_points() -> EntryPointsCallable:
if sys.version_info[:2] >= (3, 10):
from importlib.metadata import entry_points
else:
from importlib_metadata import entry_points

return entry_points


importlib_metadata_entry_points = _importlib_metadata_entry_points()