#### Self RAG
Self-reflection can enhance RAG by enabling the correction of poor-quality retrievals or generations.
Several recent papers focus on this theme, but implementing the ideas can be tricky. This section uses Mistral models
(Mixtral-8x7B-Instruct, accessed through the OpenAI-compatible `ChatOpenAI` client) to implement Self-RAG.

In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.embeddings.gpt4all import GPT4AllEmbeddings

# Source document for the RAG demo.
url = "https://lilianweng.github.io/posts/2023-06-23-agent/"
docs = WebBaseLoader(url).load()

# Chunk by tiktoken-measured length: 500 per chunk with 100 of overlap, so
# passages that straddle a chunk boundary still appear whole in one chunk.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=100
)
all_splits = text_splitter.split_documents(docs)

# Embed the chunks with GPT4All and index them in a Chroma collection.
embedding = GPT4AllEmbeddings()
vectorstore = Chroma.from_documents(
    documents=all_splits, embedding=embedding, collection_name="rag-chroma"
)

# The variable keeps its original spelling because the `retrieve` node
# looks it up under this name.
retreiver = vectorstore.as_retriever()

bert_load_from_file: gguf version     = 2
bert_load_from_file: gguf alignment   = 32
bert_load_from_file: gguf data offset = 695552
bert_load_from_file: model name           = BERT
bert_load_from_file: model architecture   = bert
bert_load_from_file: model file type      = 1
bert_load_from_file: bert tokenizer vocab = 30522


##### State
Every node in our graph will modify `state`, which is a dict that contains values (`question`, `documents`, etc.) relevant to RAG. These values are stored under a single `keys` entry.

In [2]:
from typing import Dict, TypedDict, Any
from langchain_core.messages import BaseMessage

class GraphState(TypedDict):
    """
    State shared by every node of the Self-RAG graph.

    All values live under the single ``keys`` entry; the nodes in this
    notebook read and write ``question``, ``documents`` and ``generation``
    inside it.
    """

    # Free-form payload handed from node to node.
    keys: Dict[str, Any]

#### Nodes and Edges
Every node in the graph is a function.
Each node modifies the state in some way.
Each edge chooses which node to call next.

In [3]:
import json
import operator
from typing import Annotated, Sequence, TypedDict
from langchain import hub
from langchain.prompts import PromptTemplate
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_openai.chat_models import ChatOpenAI
from langchain_core.output_parsers import JsonOutputParser


def retrieve(state: GraphState):
    """
    Fetch documents for the current question.

    Reads ``question`` from ``state["keys"]``, queries the module-level
    ``retreiver`` with it, and returns a fresh state holding the retrieved
    documents together with the question.
    """
    print("--retrieve--")
    query = state["keys"]["question"]
    retrieved = retreiver.get_relevant_documents(query)
    return {"keys": {"documents": retrieved, "question": query}}
    
def generate(state: GraphState):
    """
    Generate an answer to the question from the retrieved documents.

    Reads ``question`` and ``documents`` from ``state["keys"]``, runs the
    ``rlm/rag-prompt`` prompt pulled from the LangChain hub through the
    Mixtral chat model, and returns a new state that additionally carries
    the ``generation`` string.
    """
    print("--generate--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    prompt = hub.pull("rlm/rag-prompt")
    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    def format_docs(docs):
        """Join the documents' text into one plain-text context block."""
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = prompt | llm | StrOutputParser()

    # Pass plain text, not the list of Document objects: otherwise the prompt
    # receives each Document's Python repr (metadata included) as context.
    generation = rag_chain.invoke(
        {"context": format_docs(documents), "question": question}
    )
    # The model's reply may carry surrounding whitespace; drop it.
    generation = generation.strip()

    return {
        "keys": {
            "documents": documents,
            "question": question,
            "generation": generation,
        }
    }
    
def grade_documents(state: GraphState):
    """
    Keep only the retrieved documents that the LLM grades as relevant.

    Each document in ``state["keys"]["documents"]`` is graded on its own by
    asking the Mixtral chat model for a JSON ``{"score": "yes"|"no"}``. The
    returned state holds the surviving documents and the question; an empty
    document list signals that retrieval failed.
    """
    print("--check relevance--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    prompt = PromptTemplate(
        template="""
        You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keywords related to the user question, grade it as relevant. \n
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
        Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
        """,
        input_variables=["question", "context"],
    )

    chain = prompt | llm | JsonOutputParser()

    filtered_docs = []
    for d in documents:
        score = chain.invoke({"question": question, "context": d.page_content})
        # Treat a missing key or non-dict output as "no", and accept case and
        # whitespace variants such as "Yes" or " yes ".
        grade = score.get("score", "") if isinstance(score, dict) else ""
        if str(grade).strip().lower() == "yes":
            print("-- document relevant --")
            filtered_docs.append(d)
        else:
            print("-- document not relevant --")

    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
        }
    }

In [4]:
def transform_query(state: GraphState):
    """
    Rewrite the question so that it is better suited to retrieval.

    Reads ``question`` and ``documents`` from ``state["keys"]`` and asks the
    Mixtral chat model for an improved phrasing of the question. Returns a new
    state carrying the rewritten question and the unchanged documents. If the
    model returns nothing usable, the original question is kept.
    """
    print("--transform query--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    prompt = PromptTemplate(
        template="""You are generating a question that is well optimized for retrieval. \n 
        Look at the input and try to reason about the underlying semantic intent / meaning. \n 
        Here is the initial question:
        \n ------- \n
        {question} 
        \n ------- \n
        Formulate an improved question:""",
        input_variables=["question"],
    )

    chain = prompt | llm | StrOutputParser()
    # The reply may carry surrounding whitespace; strip it so the rewrite is a
    # clean retrieval query.
    better_question = chain.invoke({"question": question}).strip()
    if not better_question:
        # Never replace the question with an empty query.
        better_question = question

    return {
        "keys": {
            "documents": documents,
            "question": better_question,
        }
    }

In [5]:
def prepare_for_final_grade(state: GraphState):
    """
    Forward the state unchanged ahead of the final usefulness check.

    Copies ``question``, ``documents`` and ``generation`` from
    ``state["keys"]`` into a fresh state dict; the graph attaches the
    question-vs-answer grader to the edge leaving this node.
    """
    print("final grade")
    current = state["keys"]
    question, documents, generation = (
        current["question"],
        current["documents"],
        current["generation"],
    )
    passthrough = {
        "documents": documents,
        "question": question,
        "generation": generation,
    }
    return {"keys": passthrough}

In [7]:
def decide_to_generate(state: GraphState):
    """
    Decide whether to generate an answer or rewrite the question.

    Inspects the documents that survived ``grade_documents``. If none
    remain, routes to ``transform_query`` so the question is rewritten and
    retrieval runs again; otherwise routes to ``generate``.

    Returns:
        "transform_query" or "generate" — keys of the conditional-edge map.
    """
    print("--decide to generate--")
    # Must inspect the filtered documents, not the question: the question is
    # always non-empty, so checking it would make the rewrite branch
    # unreachable.
    filtered_documents = state["keys"]["documents"]

    if not filtered_documents:
        print("--decision: transform query--")
        return "transform_query"

    print("--decision: generate--")
    return "generate"
    

def grade_generation_v_documents(state: GraphState):
    """
    Determine whether the generation is grounded in the retrieved documents.

    Asks the Mixtral chat model for a JSON ``{"score": "yes"|"no"}`` verdict
    on whether ``generation`` is supported by ``documents``.

    Returns:
        "supported" when the grader answers yes, otherwise "not supported".
    """
    print("grade generation vs documents")
    state_dict = state["keys"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]

    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    prompt = PromptTemplate(
        template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n 
        Here are the facts:
        \n ------- \n
        {documents} 
        \n ------- \n
        Here is the answer: {generation}
        Give a binary 'yes' or 'no' score to indicate whether the answer is grounded in / supported by a set of facts. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
        input_variables=["generation", "documents"],
    )

    chain = prompt | llm | JsonOutputParser()
    # Give the grader the documents' text rather than their Python repr.
    facts = "\n\n".join(doc.page_content for doc in documents)
    score = chain.invoke({"generation": generation, "documents": facts})
    # Treat a missing key or non-dict output as "no"; accept "Yes", " yes ".
    grade = score.get("score", "") if isinstance(score, dict) else ""

    if str(grade).strip().lower() == "yes":
        print("decision supported, move to final grade")
        return "supported"

    print("generate again")
    return "not supported"
    

In [8]:
def grade_generation_v_question(state: GraphState):
    """
    Determine whether the generation addresses the question.

    Asks the Mixtral chat model for a JSON ``{"score": "yes"|"no"}`` verdict
    on whether ``generation`` is useful for resolving ``question``.

    Returns:
        "useful" when the grader answers yes, otherwise "not useful".
    """
    print("--grade generation vs question--")
    state_dict = state["keys"]
    question = state_dict["question"]
    generation = state_dict["generation"]

    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")

    prompt = PromptTemplate(
        template="""You are a grader assessing whether an answer is useful to resolve a question. \n 
        Here is the answer:
        \n ------- \n
        {generation} 
        \n ------- \n
        Here is the question: {question}
        Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question. \n
        Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.""",
        input_variables=["generation", "question"],
    )

    chain = prompt | llm | JsonOutputParser()
    score = chain.invoke({"generation": generation, "question": question})
    # Treat a missing key or non-dict output as "no"; accept "Yes", " yes ".
    grade = score.get("score", "") if isinstance(score, dict) else ""

    if str(grade).strip().lower() == "yes":
        print("decision useful")
        return "useful"

    print("decision not useful")
    return "not useful"

##### Build the Graph
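This follows the flow outlined in the figure above. `retrieve` → `grade_documents` → either `generate` (if any relevant documents remain) or `transform_query` (which loops back to `retrieve`). Each generation is first checked against the documents; if unsupported, it is regenerated. It is then checked against the question: the run ends if the answer is useful, otherwise the query is rewritten.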

In [9]:
import pprint

from langgraph.graph import END, StateGraph


workflow = StateGraph(GraphState)

# Nodes: each is a function that takes and returns the graph state.
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("prepare_for_final_grade", prepare_for_final_grade)

workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")

# If no document survives grading, rewrite the question and retrieve again.
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")

# Regenerate until the answer is grounded in the documents. The map keys must
# match exactly the strings grade_generation_v_documents returns
# ("supported" / "not supported").
# NOTE(review): this loop has no retry cap of its own; confirm the graph's
# recursion limit is an acceptable bound.
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents,
    {
        "supported": "prepare_for_final_grade",
        "not supported": "generate",
    },
)

# Finish if the answer resolves the question; otherwise rewrite the query.
workflow.add_conditional_edges(
    "prepare_for_final_grade",
    grade_generation_v_question,
    {
        "useful": END,
        "not useful": "transform_query",
    },
)

app = workflow.compile()

#### Run 
Now we will run the graph with the provided inputs.

In [10]:
# Run the graph, printing the name of each node as it finishes.
SHOW_FULL_STATE = False  # set True to also dump the state after every node

inputs = {"keys": {"question": "Explain how the different types of agent memory work?"}}
final_state = None
for output in app.stream(inputs):
    for key, value in output.items():
        pprint.pprint(f"Node '{key}':")
        if SHOW_FULL_STATE:
            pprint.pprint(value["keys"], indent=2, width=80, depth=None)
        # Track the latest state explicitly instead of relying on the loop
        # variable leaking out of the loop.
        final_state = value
    pprint.pprint("\n---\n")

# Final generation, taken from the state emitted by the last node.
if final_state is None:
    raise RuntimeError("The graph produced no output.")
pprint.pprint(final_state["keys"]["generation"])

--retrieve--
"Node 'retrieve':"
'\n---\n'
--check relevance--
-- document relevant --
-- document relevant --
-- document relevant --
-- document relevant --
"Node 'grade_documents':"
'\n---\n'
--decide to generate--
--generate--
--generate--
"Node 'generate':"
'\n---\n'
grade generation vs documents
decision supported, move to final grade
final grade
"Node 'prepare_for_final_grade':"
'\n---\n'
--grade generation vs question--
decision useful
"Node '__end__':"
'\n---\n'
(' The different types of agent memory include short-term memory, which is '
 'used for in-context learning, and long-term memory, which retains and '
 'recalls information over extended periods. Short-term memory is like the '
 "model's working memory, while long-term memory is often supported by an "
 'external vector store and fast retrieval. Additionally, the agent can use '
 'tools to call external APIs for extra information that is missing from the '
 'model weights.')
