# Retrieval-Augmented Generation

In [None]:
import itertools
from glob import glob
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
from langchain_openai import OpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from operator import itemgetter

### Load documents

In [None]:
# Load every path matched by ./documents/* with PyMuPDFLoader. `docs` ends up
# holding one list of loaded Document objects per matched file; the
# "Total pages" figure is the combined length of those lists.
document_paths = glob("./documents/*")
docs = [PyMuPDFLoader(path).load() for path in document_paths]
total_pages = sum(map(len, docs))
print(f"Number of documents: {len(docs)}")
print(f"Total pages: {total_pages}")

### Split into chunks

In [None]:
# Cut the loaded pages into overlapping chunks for embedding.
# NOTE(review): chunk_size / chunk_overlap are presumably character counts
# (the splitter's default length function) — confirm against the LangChain docs.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", ".\n", ".", " ", "\n"],
    chunk_overlap=200,
    chunk_size=1000,
)
# `docs` is a list of per-file Document lists, so flatten it by one level
# while feeding the splitter.
chunks = text_splitter.split_documents(
    page for file_pages in docs for page in file_pages
)
print(f"Total chunks: {len(chunks)}")

### Vector Store

In [None]:
# Embedding model built with HuggingFaceEmbeddings' default settings; Chroma
# uses it to embed and index every chunk.
embeddings = HuggingFaceEmbeddings()
vectorstore = Chroma.from_documents(
    embedding=embeddings,
    documents=chunks,
)

### RAG chain

In [None]:
def format_docs(docs):
    """Concatenate the text of several documents into one context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        The ``page_content`` values in input order, separated by one blank
        line; an empty string when ``docs`` is empty.
    """
    page_texts = [document.page_content for document in docs]
    return "\n\n".join(page_texts)

In [None]:
# Prompt with two placeholders, {question} and {context}. The template is
# spelled out line by line so that its exact whitespace — including the single
# space that follows each placeholder — is visible in the source.
prompt = PromptTemplate.from_template(
    "You are an AI assistant, that helps people find information in documents.\n"
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Use three sentences maximum and keep the answer concise.\n"
    "\n"
    "Question: {question} \n"
    "\n"
    "Context: {context} \n"
    "\n"
    "Answer:"
)

# LLM client pointed at the /v1 endpoint served on localhost:8080.
# NOTE(review): api_key="dev" looks like a placeholder for a local server that
# does not check credentials — confirm before pointing this elsewhere.
chat_llm = OpenAI(
    model="mistral",
    base_url="http://localhost:8080/v1",
    api_key="dev",
    max_tokens=128,  # cap on the number of generated tokens per call
    temperature=0,
)

# Retriever over the vector store, using as_retriever()'s default settings.
retriever = vectorstore.as_retriever()

# Question string in, answer string out:
#   1. in parallel, run the retriever on the question (joining the hits with
#      format_docs) and pass the question itself through unchanged;
#   2. fill the prompt, call the LLM, and parse its output into a string.
rag_chain = (
    RunnableParallel(
        context=retriever | format_docs,
        question=RunnablePassthrough(),
    )
    | prompt
    | chat_llm
    | StrOutputParser()
)

In [None]:
# One blocking call through the whole chain. The chain ends in
# StrOutputParser, so the value of this expression is the answer text
# (displayed as the cell result when run in a notebook).
rag_chain.invoke("What to do if the dishwasher doesn't run?")

In [None]:
# Stream the answer and echo each piece as soon as it arrives.
response = rag_chain.stream("What is covered by the warranty?")
for token in response:
    # flush=True: with end="" no newline is ever written, so a line-buffered
    # stdout would otherwise hold the text back until the stream is over.
    print(token, end="", flush=True)
print()  # finish the streamed line so later output starts on a fresh one

In [None]:
# Stream the answer and echo each piece as soon as it arrives.
response = rag_chain.stream("What is the normal wash temperature?")
for token in response:
    # flush=True: with end="" no newline is ever written, so a line-buffered
    # stdout would otherwise hold the text back until the stream is over.
    print(token, end="", flush=True)
print()  # finish the streamed line so later output starts on a fresh one

### RAG chain with references

In [None]:
# Variant of the RAG chain that also returns the retrieved documents.
# Stage 1 maps the question string to {"documents": <retriever hits>,
# "question": <question>}. Stage 2 forwards "documents" unchanged and computes
# "answer" by formatting those same documents into the prompt context, so the
# final result is a dict with the keys "documents" and "answer".
rag_chain_with_ref = RunnableParallel(
    documents=retriever,
    question=RunnablePassthrough(),
) | RunnableParallel(
    documents=itemgetter("documents"),
    answer=(
        RunnableParallel(
            context=RunnableLambda(itemgetter("documents")) | format_docs,
            question=itemgetter("question"),
        )
        | prompt
        | chat_llm
        | StrOutputParser()
    ),
)

In [None]:
# Ask one question, then print the answer followed by every retrieved chunk
# with a short citation line (title, page, author).
result = rag_chain_with_ref.invoke("What is the normal wash temperature?")

print("Answer:")
print(result["answer"].strip())
for i, doc in enumerate(result["documents"], 1):
    meta = doc.metadata
    # Title and author are optional document properties: read them with
    # .get() and fall back to a placeholder instead of raising KeyError (or
    # printing an empty field) when a file does not carry them.
    title = str(meta.get("title") or "").strip() or "Untitled"
    author = str(meta.get("author") or "").strip() or "Unknown author"
    # PyMuPDFLoader records a zero-based page index; display it one-based so
    # the citation matches the page number a reader sees in the document.
    page = meta.get("page")
    page_label = page + 1 if isinstance(page, int) else "?"
    print('-' * 80)
    print()
    print(f"[{i}] {title}, Page {page_label}, {author}")
    print(doc.page_content.strip())