<a href="https://colab.research.google.com/github/TianYao12/Book-Recommender/blob/main/crag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python




In [52]:
import getpass
import os

def _set_env(key: str):
    if key not in os.environ:
        os.environ[key] = getpass.getpass(f"{key}:")


_set_env("OPENAI_API_KEY")
_set_env("TAVILY_API_KEY")

In [53]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

urls = [
    "https://www.gutenberg.org/files/1232/1232-h/1232-h.htm",
    # "https://www.gutenberg.org/files/1998/1998-h/1998-h.htm",
    # "https://www.gutenberg.org/files/1497/1497-h/1497-h.htm"
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

vectorstore = Chroma.from_documents(
    documents=doc_splits,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
)
retriever = vectorstore.as_retriever()

Use structured LLM function call to produce a binary (yes/no) score representing if a retrieved document is **relevant to a given user query**.


In [63]:
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI

class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

system = """You are a grader assessing relevance of a retrieved document to a user question. \n
    If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a binary score, either 'Yes' or 'No' to indicate that the document is relevant to the question. Don't be too strict, if it is related enough, say Yes."""

grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

retrieval_grader = grade_prompt | structured_llm_grader

## Retrieve documents according to question
question = "Who was Sicilian"
docs = retriever.get_relevant_documents(question)
for i, doc in enumerate(docs):
    print(f"Document {i+1}:\n{doc.page_content[:100]}...\n")
doc_txt = docs[1].page_content

## get evaluation
binary_score = retrieval_grader.invoke({"question": question, "document": doc_txt})
print(binary_score)

Document 1:
I was not intending to go beyond Italian and recent examples, but I am
unwilling to leave out Hiero...

Document 2:
I was not intending to go beyond Italian and recent examples, but I am
unwilling to leave out Hiero...

Document 3:
I was not intending to go beyond Italian and recent examples, but I am
unwilling to leave out Hiero...

Document 4:
I was not intending to go beyond Italian and recent examples, but I am
unwilling to leave out Hiero...

binary_score='Yes'


In [64]:
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

prompt = hub.pull("rlm/rag-prompt")
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

## joins multiple retrieved document texts into a single string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = prompt | llm | StrOutputParser()

generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)



Hiero, the Syracusan, was Sicilian.


In [65]:
llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)

system = """You are a question re-writer whose  goal is to refine the input question by improving clarity, relevance, and searchability while preserving its original intent.
            Consider synonyms, rephrasing, and structuring for optimal search engine results."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

question_rewriter = re_write_prompt | llm | StrOutputParser()
question_rewriter.invoke({"question": question})

'Who were notable historical figures from Sicily?'

Web Search Tool

In [66]:
from langchain_community.tools.tavily_search import TavilySearchResults
web_search_tool = TavilySearchResults(k=3)

Create Graph

In [67]:
from typing import List
from typing_extensions import TypedDict

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        web_search: whether to add search
        documents: list of documents
    """
    question: str
    generation: str
    web_search: str
    documents: List[str]

In [68]:
from langchain.schema import Document

def retrieve(state):
    """
    Retrieve documents
    Args: state (dict): current graph state
    Returns: state (dict): The key "documents" is added to state, which contains retrieved documents
    """

    print("Retrieving...")
    question = state["question"]
    documents = retriever.get_relevant_documents(question)
    return {"documents": documents, "question": question}

def generate(state):
    """
    Args: state (dict): The current graph state
    Returns: state (dict): New key added to state, generation, that contains LLM generation
    """
    print("Generating...")

    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.
    Args: state (dict): current graph state
    Returns: state (dict): Updates documents key with only filtered relevant documents
    """
    print("Checking document relevance to question...")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score.binary_score
        if grade == "yes":
            print("Grade: Relevant")
            filtered_docs.append(d)
        else:
            print("Grade: Irrelevant")
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}


def transform_query(state):
    """
    Transform the query to produce a better question.
    Args: state (dict): current graph state
    Returns: state (dict): Updates question key with a re-phrased question
    """

    print("Transforming query...")

    question = state["question"]
    documents = state["documents"]

    better_question = question_rewriter.invoke({"question": question})
    return {"documents": documents, "question": better_question}

def web_search(state):
    """
    Web search based on the re-phrased question.
    Args: state (dict): current graph state

    Returns: state (dict): Updates documents key with appended web results
    """

    print("Searching web...")
    question = state["question"]
    documents = state["documents"]

    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    documents.append(web_results)
    return {"documents": documents, "question": question}

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.
    Args: state (dict): current graph state
    Returns: str: Binary decision for next node to call
    """

    print("Assessing graded documents...")

    state["question"]
    web_search = state["web_search"]
    state["documents"]

    if web_search == "Yes":
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print(
            "---DECISION: All documents are not relevant to question. Transforming query..."
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

In [69]:
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

## Define nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search_node", web_search)

# Build graph
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()

In [70]:
from pprint import pprint

inputs = {"question": "Who was in Sicily"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Node '{key}':")
        # pprint.pprint(value["keys"], indent=2, width=80, depth=None) # full state at each node
    pprint("\n-------------\n")

# Final generation
pprint(value["generation"])

Retrieving...
"Node 'retrieve':"
'\n-------------\n'
Checking document relevance to question...
Grade: Irrelevant
Grade: Irrelevant
Grade: Irrelevant
Grade: Irrelevant
Assessing graded documents...
---DECISION: All documents are not relevant to question. Transforming query...
"Node 'grade_documents':"
'\n-------------\n'
Transforming query...
"Node 'transform_query':"
'\n-------------\n'
Searching web...
"Node 'web_search_node':"
'\n-------------\n'
Generating...
"Node 'generate':"
'\n-------------\n'
('Notable historical figures and groups present in Sicily include Frederick '
 'II, the Phoenicians, Carthaginians, Greeks, Romans, Vandals, Ostrogoths, '
 'Byzantines, Arabs, Normans, Aragonese, Spanish, Austrians, and British. '
 'Additionally, indigenous groups such as the Sicanians, Elymians, and Sicels, '
 'as well as the Greek-Siceliotes, played significant roles in its history.')
