# Setup

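Create a `.env` file containing the line below. Note that this notebook does not load the file itself, so make sure `OPENAI_API_KEY` ends up in the environment of the Jupyter kernel (for example, via your editor's `.env` support, or by exporting it before starting Jupyter):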
```
OPENAI_API_KEY="<your key here>"
```

Install `dewy-client`, `langchain`, and `langchain-openai` as shown below:

In [None]:
%pip install dewy-client langchain langchain-openai

# Example LangChain without RAG
This example shows a simple LangChain application which attempts to answer questions without retrieval.

In [None]:
from langchain_openai import ChatOpenAI
# Chat model used by the chains in this notebook.
# To try a different model, set MODEL to e.g. "gpt-4-0125-preview" instead.
MODEL = "gpt-3.5-turbo"
llm = ChatOpenAI(model_name=MODEL, temperature=0.9)

# Baseline: ask the model directly, with no retrieved context.
llm.invoke("What is RAG useful for?")

# Example LangChain with RAG (using Dewy)
This example shows what the previous chain looks like using Dewy to retrieve relevant chunks.

## Create the Dewy Client
The following cell creates the Dewy client. It assumes you wish to connect to a Dewy service running on your local machine on port 8000. Change the URL as appropriate to your situation.

In [None]:
from dewy_client import Client
# Dewy client used by every Dewy call below; points at a local service on port 8000.
client = Client(base_url="http://localhost:8000")

In [None]:
# The following retrieves the ID of the Dewy collection named "main".
# In general use you could hard-code the collection ID.
# This may switch to using the names directly.
from dewy_client.api.default import list_collections

matching_collections = list_collections.sync(name="main", client=client)
if not matching_collections:
    # Fail with an actionable message instead of an opaque IndexError/TypeError
    # caused by indexing an empty or missing result.
    raise LookupError(
        'No Dewy collection named "main" was found; create it before running this notebook.'
    )
collection = matching_collections[0]
print(f"Collection: {collection.to_dict()}")
collection_id = collection.id

## Retrieving documents in a chain

In [None]:
# Langchain retriever using Dewy.
#
# This will be added to Dewy or LangChain.
from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from typing import Any, Coroutine, List

from dewy_client.api.default import retrieve_chunks
from dewy_client.models import RetrieveRequest, TextResult

class DewyRetriever(BaseRetriever):
    """LangChain retriever that fetches text chunks from a Dewy collection.

    Requests are sent through the module-level ``client`` defined earlier in
    this notebook.
    """

    # ID of the Dewy collection to search.
    collection_id: int

    def _make_request(self, query: str) -> RetrieveRequest:
        """Build the Dewy retrieve request for ``query``, excluding image chunks."""
        return RetrieveRequest(
            collection_id=self.collection_id,
            query=query,
            include_image_chunks=False,
        )

    def _make_document(self, chunk: TextResult) -> Document:
        """Convert a Dewy text result into a ``Document``, keeping its chunk ID as metadata."""
        return Document(page_content=chunk.text, metadata={"chunk_id": chunk.chunk_id})

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Synchronously retrieve the text chunks Dewy returns for ``query``."""
        retrieved = retrieve_chunks.sync(client=client, body=self._make_request(query))
        return [self._make_document(chunk) for chunk in retrieved.text_results]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously retrieve the text chunks Dewy returns for ``query``."""
        # An ``async def`` is annotated with the awaited result type; the
        # coroutine wrapper is implied and must not appear in the annotation.
        retrieved = await retrieve_chunks.asyncio(client=client, body=self._make_request(query))
        return [self._make_document(chunk) for chunk in retrieved.text_results]

In [None]:
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Retriever bound to the collection looked up earlier in this notebook.
retriever = DewyRetriever(collection_id=collection_id)
# Two-message prompt: a system message carrying the retrieved `{context}` and
# a human message carrying the user's `{question}`.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """
            You're a helpful AI assistant. Given a user question and some retrieved content, answer the user question.
            If none of the articles answer the question, just say you don't know.

            Here is the retrieved content:
            {context}
            """,
        ),
        ("human", "{question}"),
    ]
)

def format_chunks(chunks):
    """Return the ``page_content`` of every chunk, joined by blank lines."""
    texts = (chunk.page_content for chunk in chunks)
    return "\n\n".join(texts)

# RAG pipeline: the incoming question is sent to the retriever, whose chunks are
# flattened into one string for `{context}`, and is also passed through
# unchanged as `{question}`. The filled prompt then goes to the LLM and the
# reply is parsed into a plain string.
chain = (
    {
        "context": retriever | format_chunks,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)

In [None]:
# Same question as the no-retrieval example, now answered with retrieved context.
chain.invoke("What is RAG useful for?")

## LangChain with Citations
Based on https://python.langchain.com/docs/use_cases/question_answering/citations#cite-documents.

In [None]:
from langchain_core.pydantic_v1 import BaseModel, Field
from operator import itemgetter
from langchain_core.runnables import (
    RunnableLambda,
)

# Schema for the "cited_answer" tool that the model is asked to call.
# The class name is kept lowercase, presumably because the same string
# "cited_answer" is used as `tool_choice` and as the output parser's `key_name`
# in this notebook.
# NOTE(review): the docstring and Field descriptions are presumably forwarded to
# the model as part of the tool schema — treat them as prompt text and confirm
# before rewording.
class cited_answer(BaseModel):
    """Answer the user question based only on the given sources, and cite the sources used."""

    # Free-text answer produced by the model.
    answer: str = Field(
        ...,
        description="The answer to the user question, which is based only on the given sources.",
    )
    # Integer source IDs backing the answer; sources are labelled with their
    # `chunk_id` by `format_docs_with_id`.
    citations: List[int] = Field(
        ...,
        description="The integer IDs of the SPECIFIC sources which justify the answer.",
    )

def format_docs_with_id(docs: List[Document]) -> str:
    """Render each document as a "Source ID / Article Snippet" entry.

    Entries are separated by blank lines, and the result always starts with a
    blank-line separator (even when ``docs`` is empty).
    """
    entries = []
    for doc in docs:
        source_id = doc.metadata["chunk_id"]
        entries.append(f"Source ID: {source_id}\nArticle Snippet: {doc.page_content}")
    return "\n\n" + "\n\n".join(entries)

# Turns the retrieved `docs` into the ID-annotated text used as `{context}`.
# Named `format_context` rather than `format` so the builtin `format()` is not
# shadowed for the rest of the notebook session.
format_context = itemgetter("docs") | RunnableLambda(format_docs_with_id)

# Setup a "cited_answer" tool.
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
output_parser = JsonOutputKeyToolsParser(key_name="cited_answer", return_single=True)

# Force the model to respond by calling the "cited_answer" tool.
llm_with_tool = llm.bind_tools(
    [cited_answer],
    tool_choice="cited_answer",
)
answer = prompt | llm_with_tool | output_parser

# Retrieve docs alongside the question, derive the context from the docs,
# produce the cited answer, then keep only that answer in the output.
citation_chain = (
    RunnableParallel(docs=retriever, question=RunnablePassthrough())
    .assign(context=format_context)
    .assign(cited_answer=answer)
    # Can't include `docs` because they're not JSON serializable.
    .pick(["cited_answer"])
)

In [None]:
# Same question again; the result is the `cited_answer` picked by the chain.
citation_chain.invoke("What is RAG useful for?")

## Bonus: Adding documents to the collection

In [None]:
from dewy_client.api.default import add_document
from dewy_client.models import AddDocumentRequest
# Submit a PDF by URL so that Dewy adds it to the collection used above.
add_document.sync(
    client=client,
    body=AddDocumentRequest(
        collection_id=collection_id,
        url="https://arxiv.org/pdf/2305.14283.pdf",
    ),
)