<a href="https://colab.research.google.com/github/Sudip-8345/Corrective-RAG/blob/main/Corrective_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q langchain langchain-google-genai langchain_community langchain-core langchain-chroma langchain-text-splitters langchain-huggingface langchain-groq

In [None]:
!pip install -q sentence-transformers

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma

# Blog posts that make up the knowledge base
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# `docs` holds one list of loaded documents per URL; `docs_list` is the flattened version
docs = [WebBaseLoader(page_url).load() for page_url in urls]
docs_list = [document for per_url_docs in docs for document in per_url_docs]

# Split with chunk_size=250 and no overlap, then report how many chunks were produced
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=0, chunk_size=250)
doc_splits = text_splitter.split_documents(docs_list)
print(len(doc_splits))

In [None]:
from langchain_huggingface import HuggingFaceEmbeddings

# Embedding model used when the document chunks are added to the vector store
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

In [None]:
# Index the chunks in the "rag-chroma" collection, then expose it as a retriever
vectorstore = Chroma.from_documents(
    embedding=embeddings,
    collection_name="rag-chroma",
    documents=doc_splits,
)
retriever = vectorstore.as_retriever()

In [None]:
import os

# Placeholder credentials: replace with real keys before running
# (and avoid committing real keys to version control).
GOOGLE_API_KEY = "your gemini api key"
GROQ_API_KEY = "your groq api key"

# Export both keys as environment variables
os.environ.update(
    GOOGLE_API_KEY=GOOGLE_API_KEY,
    GROQ_API_KEY=GROQ_API_KEY,
)

In [None]:
### Retrieval Grader
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_groq import ChatGroq

# Data model
class GradeDocuments(BaseModel):
    """Relevance score between 0 and 1 for a retrieved document."""

    # NOTE: the class docstring and the field description become part of the
    # schema handed to the LLM via with_structured_output, so they must describe
    # a float score rather than a binary one.
    score: float = Field(
        description="Document's relevancy to the question between 0 and 1 "
    )


# LLM with function call
# temperature=0 matches the other ChatGroq instances in this notebook and keeps
# the relevance scores as repeatable as the model allows; the scores are compared
# against fixed thresholds later, so run-to-run jitter would change routing.
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)

# Prompt
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
    If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant. \n
    Give a floating score from 0 to 1 by answering only the float score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)

# Grader chain: prompt -> structured-output LLM returning a GradeDocuments instance
retrieval_grader = grade_prompt | structured_llm_grader

# Smoke test: grade the first retrieved chunk for a sample question.
# `question` and `docs` are reused by later cells.
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[0].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}).score)
print(len(docs))

In [None]:
### Generate

from langsmith import Client
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq

# Prompt
# Pull the "rlm/rag-prompt" prompt through a LangSmith client.
client = Client()
prompt = client.pull_prompt("rlm/rag-prompt")

# LLM
# Generation model; rebinds the module-level `llm` name with temperature=0.
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)

# Post-processing
def format_docs(docs):
    """Return the ``page_content`` of every document, separated by blank lines."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)


# Chain: prompt -> LLM -> plain string output
rag_chain = prompt | llm | StrOutputParser()

# Run
# Pass the joined page text rather than the raw Document list, so the prompt's
# context slot receives plain text instead of the repr of Document objects.
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)

In [None]:
### Question Re-writer

# LLM
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)

# Prompt
# (system message reads "You are a question re-writer"; the verb was missing before)
system = """You are a question re-writer that converts an input question to a better version that is optimized \n
     for web search in one line. Look at the input and try to reason about the underlying semantic intent / meaning."""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here is the initial question: \n\n {question} \n Formulate an improved question.",
        ),
    ]
)

# Re-writer chain: prompt -> LLM -> plain string; the bare call below displays a
# sample rewrite as the cell output.
question_rewriter = re_write_prompt | llm | StrOutputParser()
question_rewriter.invoke({"question": question})

In [None]:
!pip install -q langchain-tavily tavily

In [None]:
# Placeholder Tavily credential exported as an environment variable; replace before running.
os.environ.update(TAVILY_API_KEY="your tavily api key")

In [None]:
### Search
from langchain_tavily import TavilySearch
# Tavily search tool configured for the "general" topic with max_results=5
web_search_tool = TavilySearch(topic="general", max_results=5)

In [None]:
from typing import TypedDict, Optional, Annotated, List
from langchain_core.documents import Document

In [None]:
class GraphState(TypedDict):
  """State shared by the nodes of the corrective-RAG graph."""
  # Current question; overwritten with the rewritten query by transform_question
  question : str
  # Answer text written by the generate node
  generation : str
  # "yes"/"no" flag written by grade_documents and read by decide_to_generate
  web_search : str
  # Document objects (nodes read .page_content from them), not plain strings
  documents : List[Document]

In [None]:
def retrieve(state: GraphState):
  """Look up documents for ``state['question']`` with the module-level retriever.

  Returns a state update of the form ``{"documents": <retriever result>}``.
  """
  query = state['question']
  return {"documents" : retriever.invoke(query)}

In [None]:
def generate(state: GraphState):
    """Produce an answer from the documents currently held in the state.

    When ``state["documents"]`` is empty, a warning is printed and a fixed
    "No relevant documents found." message is returned. Otherwise the page
    contents are joined with blank lines and passed, together with the
    question, to ``rag_chain``.

    Returns a state update of the form ``{"generation": <answer>}``.
    """
    context_docs = state["documents"]

    if len(context_docs) != 0:
        # Build the plain-text context the RAG prompt expects
        context_text = "\n\n".join([d.page_content for d in context_docs])
        chain_input = {"context": context_text, "question": state["question"]}
        return {"generation": rag_chain.invoke(chain_input)}

    print("⚠️ No documents to generate from")
    return {"generation": "No relevant documents found."}

In [None]:
def grade_documents(state: GraphState):
    """Score the retrieved documents and decide whether a web search is needed.

    Every document in ``state['documents']`` is graded against
    ``state['question']`` by ``retrieval_grader``. Documents scoring above 0.6
    are kept. ``web_search`` is set to "yes" only when no document scores
    above 0.70, which includes the case where nothing was retrieved at all;
    otherwise it is "no".

    Returns a state update with the filtered ``documents`` list and the
    ``web_search`` flag ("yes" or "no").
    """
    keep_threshold = 0.6        # minimum score for a document to be kept
    confident_threshold = 0.70  # at least one score above this skips web search

    filtered_docs = []
    scores = []
    for doc in state['documents']:
        grade = retrieval_grader.invoke({
            "question": state["question"],
            "document": doc.page_content
        })
        scores.append(grade.score)
        if grade.score > keep_threshold:
            filtered_docs.append(doc)

    # Web search is the corrective fallback: use it when retrieval produced
    # nothing clearly relevant, not when it already found a strong match.
    has_confident_match = any(score > confident_threshold for score in scores)
    web_search = "no" if has_confident_match else "yes"

    return {
        "documents": filtered_docs,
        "web_search": web_search
    }

In [None]:
def transform_question(state:GraphState):
  """Rewrite ``state['question']`` with ``question_rewriter``.

  Returns a state update of the form ``{'question': <rewritten question>}``.
  """
  # Pass the prompt variable by name, the same way the rewriter is invoked
  # elsewhere in this notebook, instead of handing it a bare string.
  refined_question = question_rewriter.invoke({"question": state['question']})
  return {'question' : refined_question}

In [None]:
def web_search(state:GraphState):
  """Search the web for ``state['question']`` and add the findings as one Document.

  The ``content`` of every entry in the tool response's ``"results"`` list is
  joined with newlines into a single Document, which is appended to a copy of
  the documents already in the state.

  Returns a state update of the form ``{'documents': <extended list>}``.
  """
  response = web_search_tool.invoke(state['question'])
  results = response["results"]
  web_res = "\n".join([d["content"] for d in results])
  # Work on a copy so the list object held by the incoming state is not mutated in place
  documents = list(state['documents'])
  documents.append(Document(page_content=web_res))
  return {'documents': documents}

In [None]:
def decide_to_generate(state:GraphState):
  """Choose the next node from the ``web_search`` flag in the state.

  Returns 'transform_query' when the flag equals 'yes', otherwise 'generate'.
  The decision is also printed.
  """
  needs_web_search = state['web_search'] == 'yes'
  if not needs_web_search:
    print("---DECISION: GENERATE---")
    return 'generate'

  print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---")
  return 'transform_query'

In [None]:
from langgraph.graph import StateGraph, START, END

# Corrective-RAG workflow:
#   retrieve -> grade -> generate
#                     \-> transform_query -> web_search_node -> generate
graph = StateGraph(GraphState)
graph.add_node('retrieve', retrieve)
graph.add_node('generate',generate)
graph.add_node('grade',grade_documents)
# The node is named 'web_search_node' rather than 'web_search' because GraphState
# already has a 'web_search' key, and LangGraph rejects node names that collide
# with state keys.
graph.add_node('web_search_node', web_search)
graph.add_node('transform_query', transform_question)

graph.add_edge(START, 'retrieve')
graph.add_edge('retrieve', 'grade')
# After grading, decide_to_generate routes to query rewriting or straight to generation
graph.add_conditional_edges(
    'grade',
    decide_to_generate,
     {
        "transform_query": "transform_query",
        'generate' : 'generate'
    }
)
graph.add_edge('transform_query', 'web_search_node')
graph.add_edge('web_search_node', 'generate')
graph.add_edge('generate', END)

# Compile, then display the compiled graph as the cell output
workflow = graph.compile()
workflow

In [None]:
# Run the compiled workflow once and print the state it returns
inputs = {"question": "agent memory?"}
final_state = workflow.invoke(inputs)
print('Final State : ', final_state, sep='\n')

In [None]:
# Stream the workflow and list the node names reported in each streamed chunk
inputs = {"question": "agent memory?"}
for ans in workflow.stream(inputs):
  for node_name in ans:
    print(f"Node '{node_name}':")
  print("\n---\n")
# `ans` still holds the last streamed chunk here
print('Final Answer : ')
print(ans)

In [None]:
# Stream the workflow for a second question and list the node names per chunk
inputs = {"question": "what is ensemble learning"}
for ans in workflow.stream(inputs):
  for node_name in ans:
    print(f"Node '{node_name}':")
  print("\n---\n")
# `ans` still holds the last streamed chunk here
print('Final Answer : ')
print(ans)