#### Self-RAG
Self-reflection can enhance RAG, enabling correction of poor quality retrieval or generations.
Several recent papers focus on this theme, but implementing the ideas can be tricky.

Self RAG details

Self-RAG is a recent paper that introduces an interesting approach for self-reflective RAG. It poses four questions, each tied to a reflection token:

1. Should I retrieve from retriever `R`?
   - Token: `Retrieve`
   - Input: `x (question)` or `x (question)`, `y (generation)`
   - Decides when to retrieve `D` chunks with `R`
   - Output: `yes, no, continue`
2. Are the retrieved passages `D` relevant to the question `x`?
   - Token: `ISREL`
   - Input: (`x (question)`, `d (chunk)`) for `d` in `D`
   - `d` provides useful information to solve `x`
   - Output: `relevant, irrelevant`
3. Is the LLM generation from each chunk in `D` supported by that chunk (hallucinations, etc.)?
   - Token: `ISSUP`
   - Input: `x (question)`, `d (chunk)`, `y (generation)` for `d` in `D`
   - All of the verification-worthy statements in `y (generation)` are supported by `d`
   - Output: `{fully supported, partially supported, no support}`
4. Is the LLM generation from each chunk in `D` a useful response to `x (question)`?
   - Token: `ISUSE`
   - Input: `x (question)`, `y (generation)` for `d` in `D`
   - `y (generation)` is a useful response to `x (question)`
   - Output: `{5, 4, 3, 2, 1}`
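
The last three decisions above can be sketched in plain Python. This is a simplified illustration only, not the paper's scoring method: it assumes the chunks have already been retrieved, and the function `self_rag_answer`, the `SUPPORT_RANK` table, and the four callables passed in are hypothetical stand-ins for the model calls.

```python
from typing import Callable, List, Optional

# Rank the three ISSUP outputs so they can be compared numerically.
SUPPORT_RANK = {"fully supported": 2, "partially supported": 1, "no support": 0}


def self_rag_answer(
    question: str,
    chunks: List[str],
    generate: Callable[[str, str], str],
    is_relevant: Callable[[str, str], bool],
    support: Callable[[str, str, str], str],
    usefulness: Callable[[str, str], int],
) -> Optional[str]:
    """Return the best per-chunk generation, or None if nothing survives."""
    candidates = []
    for d in chunks:
        if not is_relevant(question, d):  # ISREL: drop irrelevant chunks
            continue
        y = generate(question, d)  # one generation per relevant chunk
        s = SUPPORT_RANK[support(question, d, y)]  # ISSUP: is y grounded in d?
        if s == 0:
            continue  # unsupported generations are discarded
        u = usefulness(question, y)  # ISUSE: 1 (worst) to 5 (best)
        candidates.append((s, u, y))
    if not candidates:
        return None
    # Prefer better-supported answers, then more useful ones.
    best = max(candidates, key=lambda c: (c[0], c[1]))
    return best[2]
```

Returning `None` when no candidate survives corresponds to the case where the question would need to be rewritten and retrieval retried.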

We can represent this as a graph

In [1]:
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_community.embeddings import GPT4AllEmbeddings

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)

doc_splits = text_splitter.split_documents(docs_list)

vectorstores = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=GPT4AllEmbeddings()
)

retriever = vectorstores.as_retriever()

bert_load_from_file: gguf version     = 2
bert_load_from_file: gguf alignment   = 32
bert_load_from_file: gguf data offset = 695552
bert_load_from_file: model name           = BERT
bert_load_from_file: model architecture   = bert
bert_load_from_file: model file type      = 1
bert_load_from_file: bert tokenizer vocab = 30522


#### State
We will define a graph.
Our state will be a `dict`.
We can access this from any graph node as `state['keys']`.
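
As a small illustration of that access pattern, here is a sketch of a node reading `state['keys']` and returning an updated state. The names `SketchState` and `add_note` are hypothetical and only mirror the shape of the state used in this notebook.

```python
from typing import Any, Dict, TypedDict


class SketchState(TypedDict):
    """Same shape as the graph state: one dict stored under 'keys'."""
    keys: Dict[str, Any]


def add_note(state: SketchState) -> SketchState:
    """Toy node: read the inner dict and return a new state with one more key."""
    state_dict = state["keys"]
    # Build a fresh dict rather than mutating the incoming state.
    return {"keys": {**state_dict, "note": "seen"}}


state: SketchState = {"keys": {"question": "What is agent memory?"}}
new_state = add_note(state)
```

The nodes defined later follow the same pattern, except that they list the keys to return explicitly.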



In [2]:
from typing import Dict, TypedDict, Any
from langchain_core.messages import BaseMessage


class GraphState(TypedDict):
    """
    Represents the state of our graph
    """
    keys: Dict[str, Any]

#### Nodes and Edges
Each `node` will modify the `state`
Each `edge` will choose which `node` to call next
We can lay out `self-RAG` as a graph
Here is our graph flow
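
In plain Python terms, and independent of LangGraph, the flow amounts to a loop that runs a node and then asks an edge function which node comes next. The sketch below is a toy under stated assumptions: `run_graph`, the `toy_*` nodes, and the `rewritten` flag are hypothetical and only mimic the retrieve / transform / generate part of the flow.

```python
from typing import Any, Callable, Dict, List, Tuple

State = Dict[str, Any]
END = "__end__"


def toy_retrieve(state: State) -> State:
    # Pretend a document is only found once the question has been rewritten.
    docs = ["doc about memory"] if state["rewritten"] else []
    return {**state, "documents": docs}


def toy_transform_query(state: State) -> State:
    return {**state, "question": state["question"] + " (rephrased)", "rewritten": True}


def toy_generate(state: State) -> State:
    return {**state, "generation": f"answer from {len(state['documents'])} doc(s)"}


def toy_decide_to_generate(state: State) -> str:
    # Conditional edge: returns the NAME of the next node.
    return "generate" if state["documents"] else "transform_query"


NODES: Dict[str, Callable[[State], State]] = {
    "retrieve": toy_retrieve,
    "transform_query": toy_transform_query,
    "generate": toy_generate,
}
EDGES: Dict[str, Callable[[State], str]] = {
    "retrieve": toy_decide_to_generate,
    "transform_query": lambda s: "retrieve",
    "generate": lambda s: END,
}


def run_graph(state: State, entry: str = "retrieve", max_steps: int = 10) -> Tuple[State, List[str]]:
    """Run nodes until END (or max_steps) and record the order they ran in."""
    current, trace = entry, []
    for _ in range(max_steps):
        if current == END:
            break
        state = NODES[current](state)    # the node modifies the state
        trace.append(current)
        current = EDGES[current](state)  # the edge picks the next node
    return state, trace


final_state, trace = run_graph({"question": "q", "rewritten": False})
```

The `max_steps` cap is a design choice of this sketch: a graph with cycles, like the one here, can otherwise loop indefinitely if a grader never passes.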

In [3]:
import json
import operator
from typing import Annotated, Sequence, TypedDict

from langchain import hub
from langchain.output_parsers.openai_tools import PydanticToolsParser
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores.chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import BaseMessage, FunctionMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai.chat_models import ChatOpenAI
from langchain_community.embeddings import GPT4AllEmbeddings

In [20]:
def retrieve(state: GraphState):
    """ 
    Retrieve documents
    """
    print("--retrieve--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {
        "documents": documents,
        "question": question
    }}

In [66]:
def generate(state: GraphState):
    """ 
    Generate answer
    """
    print("--generate--")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    
    prompt = hub.pull("rlm/rag-prompt")
    
    llm = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    
    rag_chain = prompt | llm | StrOutputParser()
    
    # Pass the documents as plain text rather than as Document objects
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
    return {
        "keys": {
            "documents": documents,
            "generation": generation,
            "question": question,
        }
    }

In [67]:
def grade_documents(state: GraphState):
    """
    Determine whether the retrieved documents are relevant to the question
    """
    print("check relevance")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    
    class Grade(BaseModel):
        """ Binary score for relevance check """
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")
        
    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    
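    # Convert the Pydantic schema to an OpenAI tool spec (once)
    grade_tool_oai = convert_to_openai_tool(Grade)
    
    # Force the model to call the tool; the name must match the class name
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {
            "name": "Grade"
        }}
    )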
    
    parser_tool = PydanticToolsParser(tools=[Grade])
    
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n 
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"]
    )
    
    chain = prompt | llm_with_tool | parser_tool
    
    filtered_docs = []
    
    for d in documents:
        try:
            score = chain.invoke({"question": question, "context": d.page_content})
            grade = score[0].binary_score
        except Exception as e:
            # Grading failed (e.g. a malformed tool call): keep the document by default
            grade = "yes"
        
        if grade == "yes":
            print("document relevant")
            filtered_docs.append(d)
        else:
            print("document not relevant")
            continue
    
    return {"keys": {
        "documents": filtered_docs,
        "question": question
    }}

In [68]:
def transform_query(state: GraphState):
    """ 
    Transform the query to produce a better question
    """
    print("transform query")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    
    
    prompt = PromptTemplate(
        template="""You are generating a question that is well optimized for retrieval. \n 
        Look at the input and try to reason about the underlying semantic intent / meaning. \n 
        Here is the initial question:
        \n ------- \n
        {question} 
        \n ------- \n
        Formulate an improved question: """,
        input_variables=["question"]
    )
    
    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    
    chain = prompt | model | StrOutputParser()
    
    better_question = chain.invoke({"question": question})
    return {
        "keys": {
            "documents": documents,
            "question": better_question
        }
    }
    
    
    
def prepare_for_final_grade(state: GraphState):
    """
    Passthrough state for final grade.
    """
    print("final grade")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]
    
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }

In [69]:
def decide_to_generate(state: GraphState):
    """ 
    Determine whether to generate an answer, or re-generate a question
    """
    print("decide to generate")
    state_dict = state["keys"]
    question = state_dict["question"]
    filtered_documents = state_dict["documents"]
    
    if not filtered_documents:
        print("transform query")
        return "transform_query"
    
    else:
        print("generate")
        return "generate"

In [70]:
def grade_generate_v_documents(state: GraphState):
    """
    Determines whether the generation is grounded in the document
    """
    print("generation vs documents")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]
    
    
    class Grade(BaseModel):
        """ 
        Binary score for support check
        """
        binary_score: str = Field(description="Supported score 'yes' or 'no'")
        
    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
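    # Convert the Pydantic schema to an OpenAI tool spec (once)
    grade_tool_oai = convert_to_openai_tool(Grade)
    
    # Force the model to call the tool; the name must match the class name
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={"type": "function", "function": {
            "name": "Grade"
        }}
    )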
    
    parser_tool = PydanticToolsParser(tools=[Grade])
    
    prompt = PromptTemplate(
        template="""You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n 
        Here are the facts:
        \n ------- \n
        {documents} 
        \n ------- \n
        Here is the answer: {generation}
        Give a binary score 'yes' or 'no' to indicate whether the answer is grounded in / supported by a set of facts.""",
        input_variables=["generation", "documents"]

    )
    
    chain = prompt | llm_with_tool | parser_tool
    
    score = chain.invoke({"generation": generation, "documents": documents, "question": question})
    
    grade = score[0].binary_score
    
    if grade == "yes":
        print("move to final grade")
        return "supported"
    
    else:
        print("not supported, generate again")
        return "not supported"
    
    

In [75]:
def grade_generate_v_question(state: GraphState):
    """
    Determine whether the generation addresses the question
    """
    print("grade generation vs question")
    state_dict = state["keys"]
    
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]
    
    class Grade(BaseModel):
        """
        Binary score for usefulness check
        """
        binary_score: str = Field(description="Useful score 'yes' or 'no'")
        
    model = ChatOpenAI(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
    
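    # Convert the Pydantic schema to an OpenAI tool spec (once)
    grade_tool_oai = convert_to_openai_tool(Grade)
    
    # Force the model to call the tool; the name must match the class name
    llm_with_tool = model.bind(
        tools=[grade_tool_oai],
        tool_choice={
            "type": "function",
            "function": {
                "name": "Grade"
            }
        }
    )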
    
    parser_tool = PydanticToolsParser(tools=[Grade])
    
    prompt = PromptTemplate(
        template=""" 
        You are a grader assessing whether an answer is useful to resolve a question. \n 
        Here is the answer:
        \n ------- \n
        {generation} 
        \n ------- \n
        Here is the question: {question}
        Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question.
        """,
        input_variables=["generation", "question"]
    )
    
    chain = prompt | llm_with_tool | parser_tool
    
    try:
        score = chain.invoke({"generation": generation, "question": question})
        grade = score[0].binary_score
    except Exception as e:
        grade = "yes"
    
    if grade == "yes":
        print("decision useful")
        return "useful"
    
    else:
        print("not useful")
        return "not useful"

##### Build Graph

This just follows the flow we outlined in the figure above.

In [76]:
import pprint

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("prepare_for_final_grade", prepare_for_final_grade)

In [77]:
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents", 
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    }
)


workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generate_v_documents,
    {
        "supported": "prepare_for_final_grade",
        "not supported": "generate",
    }
)

workflow.add_conditional_edges(
    "prepare_for_final_grade",
    grade_generate_v_question,
    {
        "useful": END,
        "not useful": "transform_query"
    }
)

app = workflow.compile()

In [78]:
# Run
inputs = {"keys": {"question": "Explain how the different types of agent memory work?", "documents": []}}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        pprint.pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

# Final generation
pprint.pprint(value["keys"]["generation"])

--retrieve--
"Node 'retrieve':"
'\n---\n'
check relevance
document relevant
document relevant
document relevant
document relevant
"Node 'grade_documents':"
'\n---\n'
decide to generate
generate
--generate--
"Node 'generate':"
'\n---\n'
generation vs documents
move to final grade
final grade
"Node 'prepare_for_final_grade':"
'\n---\n'
grade generation vs question
decision useful
"Node '__end__':"
'\n---\n'
(' In a LLM-powered autonomous agent system, memory is a key component that '
 'works alongside the large language model (LLM) brain of the agent. There are '
 'different types of memory:\n'
 '\n'
 '1. Sensory Memory: This is the earliest stage, providing the ability to '
 'retain impressions of sensory information for up to a few seconds.\n'
 "2. Long-term memory, called the 'memory stream', is an external database "
 'that records a comprehensive list of agents’ experience in natural '
 'language.\n'
 '\n'
 'Additional memory mechanisms include a retrieval model for surfacing '
 'co