In [None]:
from operator import itemgetter

from dotenv import dotenv_values
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import format_document
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

In [None]:
# Credentials and endpoint come from the project-level .env file rather than
# being hard-coded in the notebook.
ENV_CONFIG = dotenv_values("../.env")

# Embedding client handed to FAISS when the persisted index is restored.
# NOTE(review): presumably this must be the same embedding model that built
# the index under ../data/db_index — confirm.
embedding_client = OpenAIEmbeddings(
    api_key=ENV_CONFIG.get("API_KEY"),
    base_url=ENV_CONFIG.get("BASE_URL"),
)

# Restore the FAISS index saved on disk.
vectorstore = FAISS.load_local("../data/db_index", embedding_client)

# Retriever over the vector store with its default search settings.
retriever = vectorstore.as_retriever()

In [None]:
# Sanity check: ask the vector store for the single closest chunk to a sample query.
vectorstore.similarity_search(query="RAG的本质是什么", k=1)

In [None]:
# Load the raw text file as LangChain documents; they feed the BM25 retriever
# built in the next cell.
text_loader = TextLoader(file_path="../data/text.txt")
text_docs = text_loader.load()

In [None]:
# Keyword (BM25) retriever over the loaded text; k=1 returns only the best match.
# NOTE(review): the queries in this notebook are Chinese; the retriever's default
# tokenisation may not segment Chinese text well — confirm retrieval quality.
bm25_retriever = BM25Retriever.from_documents(text_docs, k=1)


In [None]:
# Embedding-based retriever restricted to the single nearest neighbour.
emb_retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))

# Display the BM25 retriever's repr as this cell's output.
bm25_retriever

In [None]:
# Hybrid retriever that merges the keyword (BM25) and embedding results.
# The keyword argument is `weights` (plural): one weight per retriever, in the
# same order as `retrievers`. Both sources contribute equally here.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, emb_retriever],
    weights=[0.5, 0.5],
)

In [None]:
# Prompt that restricts the answer to the retrieved context.
# Placeholders: {context} (retrieved documents) and {question} (user query).
template = """
answer the question based only on the following context:
{context}

Question: {question}
"""

# Wrap the raw string as a chat prompt so it can be piped into a chat model.
prompt = ChatPromptTemplate.from_template(template=template)

In [None]:
# Chat model shared by the chains in this notebook. It is defined here, at its
# first use, so that a fresh kernel (Restart & Run All) has `model` in scope;
# the configuration mirrors the other ChatOpenAI instances in this notebook.
model = ChatOpenAI(
    api_key=ENV_CONFIG.get("API_KEY"),
    base_url=ENV_CONFIG.get("BASE_URL"),
)

# Basic RAG pipeline: the incoming string is sent to the ensemble retriever to
# produce {context} and is passed through unchanged as {question}; the filled
# prompt goes to the chat model and the reply is reduced to a plain string.
chain = (
    {"context": ensemble_retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

In [None]:
# Inspect the exact prompt text the model would receive: run only the
# retrieval + prompt-formatting stages and print the first message's content.
prompt_preview_chain = (
    {"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt
)
rendered_prompt = prompt_preview_chain.invoke("Rag的本质是什么")
print(rendered_prompt.messages[0].content)

In [None]:
# Ask the chat model directly, without any retrieved context.
model.invoke(input="rag的本质是什么")


In [None]:
# Run the retrieval-augmented chain end to end on a sample question.
chain.invoke(input="rag的本质是什么？")


In [None]:
# Legacy-style question answering chain for comparison with the LCEL pipeline.
# The class is `RetrievalQA`; the "stuff" chain type places all retrieved
# documents into a single prompt. It uses the plain vector-store retriever.
qa = RetrievalQA.from_chain_type(
    llm=model,
    chain_type="stuff",
    retriever=retriever,
)

# RetrievalQA reads the question from the "query" input key.
qa.invoke({"query": "rag的本质是什么"})

In [None]:
# Same context-restricted prompt as before, with an extra {language} slot that
# tells the model which language to answer in.
template = """Answer the question based only on the following context:
{context}

Question: {question}

Answer in the following language: {language}
"""
prompt = ChatPromptTemplate.from_template(template=template)

# The chain now takes a dict input with "question" and "language" keys.
pick_question = itemgetter("question")
pick_language = itemgetter("language")

# Map the input dict onto the three prompt variables; only the question is
# sent through the retriever.
prompt_inputs = {
    "context": pick_question | ensemble_retriever,
    "question": pick_question,
    "language": pick_language,
}

chain = prompt_inputs | prompt | model | StrOutputParser()

In [None]:
# Ask a Chinese question and request the answer in English.
chain.invoke(input={"question": "RAG的本质是什么？", "language": "English"})

In [None]:
from langchain.prompts.prompt import PromptTemplate

# The chat history is used here to normalise the user's input: the follow-up
# is rewritten into a standalone question before retrieval.
# Placeholders: {chat_history} (prior conversation) and {question} (follow-up).
_template = """Given the following conversation and a follow up question, 
rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template=_template)

In [None]:
# Final answering prompt for the conversational chain: answer the standalone
# question using only the retrieved {context}.
template = """Answer the question based only on the following context:
{context}

Question: {question}
"""

ANSWER_PROMPT = ChatPromptTemplate.from_template(template=template)

In [None]:

# Default per-document rendering: only the document's page content.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
    """Render retrieved documents into one context string.

    The documents are reordered so that the second one is moved to the very
    end while the others keep their relative order, e.g. ``[d0, d1, d2, d3]``
    becomes ``[d0, d2, d3, d1]``.
    NOTE(review): presumably this places the two top-ranked documents at the
    edges of the context — confirm the intent.

    Args:
        docs: Sequence of documents returned by the retriever.
        document_prompt: Prompt used to format each document.
        document_separator: String inserted between formatted documents.

    Returns:
        The formatted documents joined by ``document_separator``. An empty
        input yields an empty string and a single document is returned as is.
    """
    docs = list(docs)

    # With fewer than two documents there is no "second" one to move, so the
    # order is left untouched instead of indexing past the end of the list.
    if len(docs) >= 2:
        ordered_docs = [docs[0], *docs[2:], docs[1]]
    else:
        ordered_docs = docs

    return document_separator.join(
        format_document(doc, document_prompt) for doc in ordered_docs
    )
# print('\n'.join([x.page_content for x in docs_test]))

In [None]:
def _history_as_text(payload):
    """Turn the message list under "chat_history" into a single string."""
    return get_buffer_string(payload["chat_history"])


# Model used to rewrite the follow-up into a standalone question (temperature 0).
condense_llm = ChatOpenAI(
    temperature=0,
    api_key=ENV_CONFIG.get("API_KEY"),
    base_url=ENV_CONFIG.get("BASE_URL"),
)

# Step 1: replace chat_history with its string form, fill the condense prompt,
# and let the model produce the standalone question as plain text.
standalone_question_chain = (
    RunnablePassthrough.assign(chat_history=_history_as_text)
    | CONDENSE_QUESTION_PROMPT
    | condense_llm
    | StrOutputParser()
)
_inputs = RunnableParallel(standalone_question=standalone_question_chain)

# Step 2: retrieve with the standalone question and merge the hits into one
# context string; the same standalone question becomes the prompt's question.
pick_standalone = itemgetter("standalone_question")
_context = {
    "context": pick_standalone | ensemble_retriever | _combine_documents,
    "question": pick_standalone,
}

# Model that produces the final answer from the context-filled prompt.
answer_llm = ChatOpenAI(
    api_key=ENV_CONFIG.get("API_KEY"),
    base_url=ENV_CONFIG.get("BASE_URL"),
)

# Full conversational pipeline: condense -> retrieve -> answer -> plain string.
conversational_qa_chain = (
    _inputs | _context | ANSWER_PROMPT | answer_llm | StrOutputParser()
)