# RAG


In [None]:
import re
from collections import defaultdict
from datetime import datetime
from typing import Iterable, cast

import gel
import gel.ai
import tiktoken
from pydantic import BaseModel, TypeAdapter
from pydantic_ai import Agent
from pydantic_ai.messages import (
    FinalResultEvent,
    FunctionToolCallEvent,
    FunctionToolResultEvent,
    PartDeltaEvent,
    PartStartEvent,
    TextPart,
    TextPartDelta,
    ToolCallPartDelta,
)
from settings import (
    EMBEDDING_MODEL_MAX_TOKENS,
    EMBEDDING_MODEL_NAME,
    NANAPI_CLIENT_ID,
    PYDANTIC_AI_DEFAULT_MODEL_NAME,
    PYDANTIC_AI_MODEL_CLS,
    PYDANTIC_AI_PROVIDER,
)

## Load


In [None]:
# Async Gel client with the `client_id` global set from settings. `cast` is a
# runtime no-op; it only pins the static type used by the rest of the notebook.
_unscoped_client = gel.create_async_client()
client = cast(gel.AsyncIOClient, _unscoped_client.with_globals(client_id=NANAPI_CLIENT_ID))

In [None]:
# Author of a message, as nested in the JSON payload validated into `Message`.
class Author(BaseModel):
    id: str
    username: str
    # Display name; `format_message` falls back to `username` when this is
    # empty or None.
    global_name: str | None
    # Bot-authored messages are dropped at load time; a missing/None flag is
    # treated as "not a bot" by that filter.
    bot: bool | None = None


# One chat message, validated from the JSON held in the `data` field of the
# rows returned by the `discord::Message` query.
class Message(BaseModel):
    id: str
    # Used to group messages per channel before they are split into pages.
    channel_id: str
    content: str
    # Sort key within a channel; also rendered into each formatted line.
    timestamp: datetime
    author: Author


# A chunk ("page") of a single channel's conversation.
class MessagePage(BaseModel):
    # Concatenated formatted message lines of the page.
    context: str
    channel_id: str
    # Messages the page was built from. Optional: defaults to None when the
    # validated payload carries no messages.
    messages: list[Message] | None = None

In [None]:
resp1 = await client.query(
    r"""
    select discord::Message { * }
    filter not exists .deleted_at
    and .timestamp > <datetime>"2024-01-01T00:00:00+00:00"
    order by .timestamp asc
    """
)

In [None]:
# Validate each row's JSON payload into a Message and keep only the messages
# whose author is not flagged as a bot.
parsed_messages = (Message.model_validate_json(row.data) for row in resp1)
all_messages = [msg for msg in parsed_messages if not msg.author.bot]

## Split


In [None]:
# Bucket the messages by channel; within a bucket the order of `all_messages`
# is preserved.
all_channel_messages: defaultdict[str, list[Message]] = defaultdict(list)
for message in all_messages:
    channel_bucket = all_channel_messages[message.channel_id]
    channel_bucket.append(message)
# Debug aid (message count per channel):
# print({k: len(v) for k, v in all_channel_messages.items()})

In [None]:
# Matches any run of whitespace; `format_message` uses it to collapse message
# content (including newlines) onto a single line.
SPACE_REG = re.compile(r'\s+')
# Tokenizer looked up for the embedding model; used to count the tokens of
# formatted lines against EMBEDDING_MODEL_MAX_TOKENS.
encoding = tiktoken.encoding_for_model(EMBEDDING_MODEL_NAME)


def format_message(message: Message) -> str | None:
    """Render a message as one `Author: ...; Timestamp: ...; Content: ...` line.

    Whitespace runs in the content are collapsed to single spaces and the
    result is stripped. The author is shown as `global_name (username)` when a
    global name is set, otherwise as the bare username.

    Returns:
        The newline-terminated line, or None when no content is left after
        whitespace normalisation.
    """
    text = SPACE_REG.sub(' ', message.content).strip()
    if not text:
        return None

    sender = message.author
    if sender.global_name:
        label = f'{sender.global_name} ({sender.username})'
    else:
        label = sender.username
    stamp = f'{message.timestamp:%Y-%m-%d %H:%M:%S}'
    return f'Author: {label}; Timestamp: {stamp}; Content: {text}\n'


def yield_pages(messages: Iterable[Message]) -> Iterable[tuple[list[Message], str]]:
    """Split a stream of messages into overlapping pages that fit the token budget.

    Each message is rendered with `format_message`; messages that render to
    nothing are skipped. A page is closed once it holds 100 messages or when
    the next line would push it past EMBEDDING_MODEL_MAX_TOKENS. The following
    page is seeded with the tail returned by `overlap`, so some context is
    shared across the boundary.

    Args:
        messages: messages of one conversation, in the order they should
            appear in the pages.

    Yields:
        `(page_messages, page_text)` tuples, where `page_text` is the
        concatenation of the formatted lines of `page_messages`.
    """
    page_messages: list[Message] = []
    page_lines: list[str] = []
    page_tokens = 0
    # Number of messages in the current page that no yielded page contains
    # yet. It distinguishes a page with new content from one that only holds
    # lines carried over from the previous page.
    fresh = 0

    for message in messages:
        line = format_message(message)
        if not line:
            continue
        line_tokens = len(encoding.encode(line))

        if line_tokens > EMBEDDING_MODEL_MAX_TOKENS:
            # A single line over the budget cannot share a page: flush pending
            # content first, then emit the line as a page of its own. Only
            # flush when there is new content, otherwise the carried-over tail
            # would be emitted again as a duplicate page.
            # NOTE(review): the single-line page is still over the budget; how
            # the embedding step treats it is not visible from here.
            if fresh:
                yield page_messages, ''.join(page_lines)
                page_messages, page_lines, page_tokens = overlap(page_messages, page_lines)
                fresh = 0
            yield [message], line
            continue

        page_is_full = (
            len(page_messages) >= 100
            or page_tokens + line_tokens > EMBEDDING_MODEL_MAX_TOKENS
        )
        if fresh and page_is_full:
            yield page_messages, ''.join(page_lines)
            page_messages, page_lines, page_tokens = overlap(page_messages, page_lines)
            fresh = 0

        # The carried-over tail is chosen by line count, not by size, so tail
        # plus new line may still exceed the budget. Drop the oldest carried
        # lines until the new line fits. Slicing builds new lists, so lists
        # already handed to the consumer are never mutated.
        while page_lines and page_tokens + line_tokens > EMBEDDING_MODEL_MAX_TOKENS:
            page_messages = page_messages[1:]
            page_lines = page_lines[1:]
            page_tokens = len(encoding.encode(''.join(page_lines)))

        page_messages.append(message)
        page_lines.append(line)
        page_tokens += line_tokens
        fresh += 1

    # Emit the trailing page only if it holds something not yet emitted.
    if fresh:
        yield page_messages, ''.join(page_lines)


def overlap(
    lines_messages: list[Message], lines: list[str]
) -> tuple[list[Message], list[str], int]:
    """Return the tail of a finished page that seeds the next page.

    Roughly the last 20% of the page is kept. A page is never carried over in
    its entirety: for a single-line page the tail is empty.

    Args:
        lines_messages: messages of the finished page.
        lines: formatted lines of the page, parallel to `lines_messages`.

    Returns:
        `(messages, lines, token_count)` for the carried-over tail. Both lists
        are new objects (slices); the token count is measured on the joined
        tail text.
    """
    assert len(lines_messages) == len(lines), 'messages and lines must be parallel'
    # int(1 * 0.8) == 0 would keep a one-line page whole, i.e. a 100% overlap
    # that makes the next page start already full. Force the cut to leave at
    # least one line behind whenever the page is non-empty.
    start = max(int(len(lines) * 0.8), min(len(lines), 1))
    messages_overlap = lines_messages[start:]
    lines_overlap = lines[start:]
    return messages_overlap, lines_overlap, len(encoding.encode(''.join(lines_overlap)))

In [None]:
# Build the pages channel by channel. Each channel's list is sorted in place
# by timestamp before being split.
pages: list[MessagePage] = []
for channel_id, channel_messages in all_channel_messages.items():
    channel_messages.sort(key=lambda msg: msg.timestamp)
    pages.extend(
        MessagePage(context=page_text, channel_id=channel_id, messages=page_msgs)
        for page_msgs, page_text in yield_pages(channel_messages)
    )

len(pages)

## Embed and store


In [None]:
PAGE_INSERT_QUERY = r"""
with
    context := <str>$context,
    channel_id := <str>$channel_id,
    message_ids := <array<str>>$message_ids,
    messages := (
        select discord::Message
        filter .client = global client and .message_id in array_unpack(message_ids)
    )
insert discord::MessagePage {
    client := global client,
    context := context,
    channel_id := channel_id,
    messages := messages,
}
"""

for page in pages:
    await client.query(
        PAGE_INSERT_QUERY,
        context=page.context,
        channel_id=page.channel_id,
        message_ids=[m.id for m in page.messages],
    )

## Retrieve & Generate


In [None]:
# Gel AI client bound to the database client above. `model` is left empty:
# the only use of `rag` visible in this notebook (`retrieve`) calls
# `generate_embeddings` with an explicit embedding model.
rag = await gel.ai.create_async_rag_client(client, model='')

In [None]:
# Demo question sent to the agent, in French
# (roughly: "What does bidon think about censorship?").
QUESTION = """
Que pense bidon de la censure ?
"""

# Semantic search over the stored pages using a precomputed embedding vector.
# The query itself applies no `order by` and no `limit`.
RAG_QUERY = r"""
with
    embeddings := <array<float32>>$embeddings
select ext::ai::search(discord::MessagePage { * }, embeddings)
"""


# One element of the JSON returned for RAG_QUERY.
class SearchResult(BaseModel):
    # The matched page.
    object: MessagePage
    # Distance reported by the search for this page.
    # NOTE(review): presumably smaller means closer to the query — confirm
    # against the ext::ai::search reference.
    distance: float


# Validates the JSON returned for RAG_QUERY into a list of SearchResult.
search_adapter = TypeAdapter(list[SearchResult])

# Model class, model name and provider all come from settings.
model = PYDANTIC_AI_MODEL_CLS(
    PYDANTIC_AI_DEFAULT_MODEL_NAME,
    provider=PYDANTIC_AI_PROVIDER,
)

# The system prompt instructs the model to retrieve context before answering;
# `retrieve` below is registered on this agent as a tool.
agent = Agent(model, system_prompt='The assistant should retrieve context before answering.')


@agent.tool_plain
async def retrieve(search_query: str) -> str:
    """Retrieve chat sections based on a search query in French."""
    # NOTE(review): the docstring above is kept verbatim on purpose; it is
    # presumably what the agent framework exposes to the model as the tool
    # description — confirm before rewording it.
    #
    # Embeds `search_query`, runs RAG_QUERY with the resulting vector and
    # returns the contexts of the 50 closest pages joined by a dashed
    # separator line.
    print(search_query)
    embeddings = await rag.generate_embeddings(search_query, model=EMBEDDING_MODEL_NAME)
    # NOTE(review): the vector is cut to its first 2000 components, presumably
    # to match the dimensionality of the stored index — confirm.
    resp = await client.query_json(RAG_QUERY, embeddings=embeddings[:2000])
    results = search_adapter.validate_json(resp)
    # RAG_QUERY has no `order by`, so the order of `results` is not
    # guaranteed. Rank by distance before keeping the top 50, otherwise the
    # slice could keep arbitrary pages. `sorted` is stable, so an already
    # ordered result is left untouched.
    closest = sorted(results, key=lambda result: result.distance)[:50]
    return '\n-------------------------\n'.join(result.object.context for result in closest)


# Run the agent on QUESTION node by node, printing the model's text output to
# stdout as it is streamed. Branches whose body is `...` are placeholders that
# currently do nothing.
async with agent.iter(QUESTION) as run:
    async for node in run:
        if Agent.is_user_prompt_node(node):
            # A user prompt node => The user has provided input
            ...
        elif Agent.is_model_request_node(node):
            # A model request node => We can stream tokens from the model's request
            async with node.stream(run.ctx) as request_stream:
                async for event in request_stream:
                    if isinstance(event, PartStartEvent):
                        # Start of a new part: print its initial text, if it is a text part
                        if isinstance(event.part, TextPart):
                            print(event.part.content, end='')
                    elif isinstance(event, PartDeltaEvent):
                        # Incremental update: print text deltas; tool-call deltas are ignored
                        if isinstance(event.delta, TextPartDelta):
                            print(event.delta.content_delta, end='')
                        elif isinstance(event.delta, ToolCallPartDelta):
                            ...
                    elif isinstance(event, FinalResultEvent):
                        ...
                # Terminate the line once the stream of this request is exhausted
                print()
        elif Agent.is_call_tools_node(node):
            # A handle-response node => The model returned some data, potentially calls a tool
            # The stream is consumed to the end, but its events are not acted upon here
            async with node.stream(run.ctx) as handle_stream:
                async for event in handle_stream:
                    if isinstance(event, FunctionToolCallEvent):
                        ...
                    elif isinstance(event, FunctionToolResultEvent):
                        ...
        elif Agent.is_end_node(node):
            # Sanity check: the run's result must match the output carried by the End node
            assert run.result and run.result.output == node.data.output
            # Once an End node is reached, the agent run is complete
            ...