-
Notifications
You must be signed in to change notification settings - Fork 21
/
_components.py
231 lines (181 loc) · 6.38 KB
/
_components.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from __future__ import annotations
import abc
import enum
import functools
import inspect
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union
import pydantic
import pydantic.utils
from ._document import Document
from ._utils import RequirementsMixin, merge_models
class Component(RequirementsMixin):
    """Base class for RAG components.

    !!! tip "See also"

        - [ragna.core.SourceStorage][]
        - [ragna.core.Assistant][]
    """

    @classmethod
    def display_name(cls) -> str:
        """
        Returns:
            Component name.
        """
        return cls.__name__

    def __repr__(self) -> str:
        return self.display_name()

    # FIXME: rename this to reflect that these methods can be parametrized from the chat
    #  level
    # Names of the methods that make up the protocol of a component type. It is set
    # by the abstract subclasses and looked up through the MRO below.
    __ragna_protocol_methods__: list[str]

    @classmethod
    @functools.cache
    def _protocol_models(
        cls,
    ) -> dict[tuple[Type[Component], str], Type[pydantic.BaseModel]]:
        """Build one pydantic model per protocol method of this component.

        Each model holds the parameters that the concrete method on ``cls`` declares
        in addition to the ones of the protocol method, i.e. the method of the same
        name on the class that defines ``__ragna_protocol_methods__``. Parameters
        without a default are required fields.

        The result is cached per class.

        Returns:
            Mapping of ``(cls, method_name)`` to the generated model.

        Raises:
            StopIteration: If no class in the MRO of ``cls`` defines
                ``__ragna_protocol_methods__``.
        """
        # The first class in the MRO that defines the attribute in its own namespace
        # is the one whose method signatures act as the baseline.
        protocol_cls, protocol_methods = next(
            (cls_, cls_.__ragna_protocol_methods__)  # type: ignore[attr-defined]
            for cls_ in cls.__mro__
            if "__ragna_protocol_methods__" in cls_.__dict__
        )
        models = {}
        for method_name in protocol_methods:
            method = getattr(cls, method_name)
            concrete_params = inspect.signature(method).parameters
            protocol_params = inspect.signature(
                getattr(protocol_cls, method_name)
            ).parameters
            # Walk the concrete signature instead of a set difference of the names:
            # set iteration order of strings changes between interpreter runs, which
            # would make the field order of the generated model non-deterministic.
            extra_params = [
                param
                for name, param in concrete_params.items()
                if name not in protocol_params
            ]

            models[(cls, method_name)] = pydantic.create_model(  # type: ignore[call-overload]
                f"{cls.__name__}.{method_name}",
                **{
                    param.name: (
                        param.annotation,
                        # `...` marks the field as required
                        param.default
                        if param.default is not inspect.Parameter.empty
                        else ...,
                    )
                    for param in extra_params
                },
            )
        return models

    @classmethod
    @functools.cache
    def _protocol_model(cls) -> Type[pydantic.BaseModel]:
        """Combine all models of `_protocol_models` into one via `merge_models`.

        The result is cached per class and named after `display_name`.
        """
        return merge_models(cls.display_name(), *cls._protocol_models().values())
class Source(pydantic.BaseModel):
    """Data class for sources stored inside a source storage.

    Attributes:
        id: Unique ID of the source.
        document: Document this source belongs to.
        location: Location of the source inside the document.
        content: Content of the source.
        num_tokens: Number of tokens of the content.
    """

    # NOTE(review): presumably needed because `Document` is not a type pydantic can
    # validate natively -- confirm against ._document.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    # All fields are required; none declares a default.
    id: str
    document: Document
    location: str
    content: str
    num_tokens: int
class SourceStorage(Component, abc.ABC):
    """Abstract base class for components that store and retrieve sources.

    Subclasses have to implement `store` and `retrieve`, which together form the
    protocol of a source storage.
    """

    __ragna_protocol_methods__ = ["store", "retrieve"]

    @abc.abstractmethod
    def store(self, documents: list[Document]) -> None:
        """Store the content of the given documents.

        Args:
            documents: Documents whose content should be stored.
        """

    @abc.abstractmethod
    def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
        """Retrieve the sources that match a prompt.

        Args:
            documents: Documents the sources are retrieved from.
            prompt: Prompt the sources are retrieved for.

        Returns:
            Sources matching the prompt, ordered by relevance.
        """
class MessageRole(enum.Enum):
    """Producer of a message.

    Attributes:
        SYSTEM: The system produced the message. Besides error messages, this
            covers the welcome message that is issued when
            [preparing a new chat][ragna.core.Chat.prepare].
        USER: The user produced the message.
        ASSISTANT: An assistant produced the message.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
class Message:
    """Data class for messages.

    Attributes:
        content: The content of the message. For messages created from a stream,
            it is only available after the message was iterated over or read.
        role: The message producer.
        sources: The sources used to produce the message.

    !!! tip "See also"

        - [ragna.core.Chat.prepare][]
        - [ragna.core.Chat.answer][]
    """

    def __init__(
        self,
        content: Union[str, AsyncIterable[str]],
        *,
        role: MessageRole = MessageRole.SYSTEM,
        sources: Optional[list[Source]] = None,
    ) -> None:
        """
        Args:
            content: Either the full content or an asynchronous stream of chunks.
            role: The message producer.
            sources: The sources used to produce the message. Defaults to no
                sources.
        """
        # Exactly one of the two attributes exists after construction. The presence
        # of `_content` is what the other methods use to tell whether the full
        # content is available.
        if isinstance(content, str):
            self._content: str = content
        else:
            self._content_stream: AsyncIterable[str] = content

        self.role = role
        self.sources = sources or []

    async def __aiter__(self) -> AsyncIterator[str]:
        """Iterate over the chunks of the content.

        If the full content is already available, it is yielded as single chunk.
        Otherwise the chunks of the stream are yielded and, once the stream is
        exhausted, joined and stored as the content.
        """
        if hasattr(self, "_content"):
            yield self._content
            return

        chunks = []
        async for chunk in self._content_stream:
            chunks.append(chunk)
            yield chunk

        self._content = "".join(chunks)

    async def read(self) -> str:
        """Read the full content, consuming the content stream if necessary.

        Returns:
            The full content of the message.
        """
        if not hasattr(self, "_content"):
            # Since self.__aiter__ is already setting the self._content attribute, we
            # only need to exhaust the content stream here.
            async for _ in self:
                pass
        return self._content

    @property
    def content(self) -> str:
        """The full content of the message.

        Raises:
            RuntimeError: If the content is streamed and was neither iterated over
                nor read yet.
        """
        if not hasattr(self, "_content"):
            raise RuntimeError(
                "Message content cannot be accessed without having iterated over it, "
                "e.g. `async for chunk in message`, or reading the content, e.g. "
                "`await message.read()`, first."
            )
        return self._content

    def __str__(self) -> str:
        return self.content

    def __repr__(self) -> str:
        # repr() is used by debuggers, logging, and tracebacks and thus should not
        # raise for a valid object. A streamed message that was not consumed yet has
        # no content to show, so a placeholder is used instead of propagating the
        # RuntimeError of the `content` property.
        content = (
            self.content if hasattr(self, "_content") else "<content not yet read>"
        )
        return (
            f"{type(self).__name__}("
            f"content={content}, role={self.role}, sources={self.sources}"
            f")"
        )
class Assistant(Component, abc.ABC):
    """Abstract base class for the assistants a [ragna.core.Chat][] uses.

    Subclasses have to implement `max_input_size` and `answer`; the latter forms
    the protocol of an assistant.
    """

    __ragna_protocol_methods__ = ["answer"]

    @property
    @abc.abstractmethod
    def max_input_size(self) -> int:
        """Maximum input size of the assistant."""
        # NOTE(review): the unit is not stated here; presumably tokens, cf.
        # `Source.num_tokens` -- confirm with the concrete assistants.

    @abc.abstractmethod
    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
        """Answer a prompt based on the given sources.

        Args:
            prompt: Prompt to be answered.
            sources: Sources to use when answering the prompt.

        Returns:
            Answer, as an iterator of strings.
        """