# Simple _**agentic**_ RAG demo

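Or, more precisely: **CRAG** (_Corrective RAG_) – a technique that makes RAG a little smarter by grading the retrieved documents and falling back to another source (here a web search) when none of them are relevant.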

<a target="_blank" href="https://githubtocolab.com/IT-HUSET/ai-agenter-2025/blob/main/exercises/langgraph/extras/simple-rag-agent-demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a><br/>

![Corrective RAG](https://github.com/IT-HUSET/ai-workshop-250121/blob/main/images/crag-flow.png?raw=true)

## Setup

### Install dependencies

In [None]:
%pip install openai~=2.0 httpx~=0.28.1 --upgrade --quiet
%pip install python-dotenv~=1.0 --upgrade --quiet
%pip install python-dotenv~=1.0 docarray~=0.41.0 pypdf~=6.1 --upgrade --quiet
%pip install chromadb~=1.1.1 lark~=1.3 --upgrade --quiet
%pip install langchain~=0.3 langchain_openai~=0.3 langchain_community~=0.3.31 langchain-chroma~=0.2.6 --upgrade --quiet
%pip install langgraph~=0.6 --upgrade --quiet

# If running locally, you can do this instead:
#%uv sync

### Load environment variables

In [None]:
import os

# Detect Google Colab: the google.colab package is only importable there
try:
    from google.colab import userdata
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

if IN_COLAB:
    # Get API key from Colab secrets. Reading a secret can fail (e.g. the secret is
    # missing or notebook access is not enabled), and os.environ does not accept None,
    # so handle both here and let the verification below print setup instructions
    # instead of aborting the cell with a traceback.
    try:
        colab_api_key = userdata.get('OPENAI_API_KEY')
    except Exception as err:  # Colab-specific error types; broad catch is deliberate here
        colab_api_key = None
        print(f"⚠️ Could not read Colab secret 'OPENAI_API_KEY': {err!r}")
    if colab_api_key:
        os.environ["OPENAI_API_KEY"] = colab_api_key
        print("✅ Running in Google Colab - API key loaded from secrets")
else:
    # Load from .env file for local development
    try:
        from dotenv import load_dotenv, find_dotenv
    except ImportError:
        print("⚠️ python-dotenv not installed. Install with: pip install python-dotenv")
    else:
        # load_dotenv() returns False when nothing was loaded (e.g. no .env file was found)
        if load_dotenv(find_dotenv()):
            print("✅ Running locally - API key loaded from .env file")
        else:
            print("⚠️ Running locally - no variables loaded from a .env file")

# Verify API key is set
if not os.environ.get("OPENAI_API_KEY"):
    print("❌ OPENAI_API_KEY not found!")
    if IN_COLAB:
        print("   → Click the key icon (🔑) in the left sidebar")
        print("   → Add a secret named 'OPENAI_API_KEY'")
        print("   → Toggle 'Notebook access' to enable it")
    else:
        print("   → Create a .env file with: OPENAI_API_KEY=your-key-here")
else:
    print("✅ API key configured!")

### Setup Chat Model

In [None]:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Chat model used by the graph nodes below; temperature 0 for (mostly) deterministic output
llm = ChatOpenAI(
    temperature=0.0,
    model="gpt-4o-mini",
)
# Embedding model for the Chroma vector store, used both when ingesting and when querying
embedding_model = OpenAIEmbeddings(
    model="text-embedding-3-large",
)

## Setup ingestion / retrieval pipeline

### Setup vector DB (Chroma)

In [None]:
from langchain_chroma import Chroma

# On-disk location of the Chroma database, so the index survives between notebook runs.
# NOTE: "simle" looks like a typo of "simple"; renaming it (here or in collection_name)
# would point at a new, empty database/collection and require re-ingestion.
persist_directory = './db/simle_rag_agent_demo/'

# To start over with an empty index, delete the directory before creating the store:
#   import shutil
#   shutil.rmtree(persist_directory, ignore_errors=True)

# Vector store that embeds chunks with the embedding model configured earlier
vectordb: Chroma = Chroma(
    persist_directory=persist_directory,  # optional - omit to keep the DB in memory only
    embedding_function=embedding_model,
    collection_name="simle_rag_agent_demo",
)

# Retriever interface over the store (similarity search), used by the retrieval node
retriever = vectordb.as_retriever()

### Setup a text splitter

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Split pages into chunks of up to ~1000 characters; neighbouring chunks may share up
# to 80 characters, so text cut at a chunk boundary keeps some surrounding context.
text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=80, chunk_size=1000)

### Setup documents to load

In [None]:
# Documents to load: one DocInfo (document id + PDF URL) per source document
from dataclasses import dataclass

@dataclass
class DocInfo:
    '''A source document to ingest.

    Attributes:
        id: Identifier of the document; stored as ``doc_id`` metadata on its chunks.
        url: URL of the document's PDF file.
    '''
    id: str
    url: str

documents_to_load: list[DocInfo] = [
    # "Förordning om artificiell intelligens" (Regulation on artificial intelligence), 2020/21:FPM109
    DocInfo("1", "https://data.riksdagen.se/fil/B9E2F955-31EA-4E9E-91EB-9AE0A3A8FFA7"),
    # "Risker och möjligheter med artificiell intelligens" (Risks and opportunities of AI), 2022/23:374
    DocInfo("2", "https://data.riksdagen.se/fil/C40BB689-7E23-4593-BDC6-DBEE327C00C6"),
    # "Direktiv om skadeståndsansvar gällande artificiell intelligens" (Directive on liability for AI)
    DocInfo("3", "https://data.riksdagen.se/fil/4C47740C-D13E-4E22-80CB-43DD1E101080"),
    # "VITBOK Om artificiell intelligens - en EU-strategi för spetskompetens och förtroende"
    # (White paper on AI - an EU approach to excellence and trust), 2022/23:FPM8
    #DocInfo("10", "https://data.riksdagen.se/fil/BECC9F0F-3DA1-4F44-9417-02DF027DA29C"),
]

### Ingest - split and add to vector index

In [None]:
from langchain_community.document_loaders import PyPDFLoader
import time

def ingest_documents(doc_info: DocInfo, batch_size: int = 10, delay_seconds: float = 0.1):
    '''Helper function to ingest a document into the vector database.

    Loads the PDF at ``doc_info.url``, tags every page with ``doc_id`` metadata, splits
    the pages with ``text_splitter`` and adds the chunks to ``vectordb`` in batches.
    A document whose ``doc_id`` is already present in the index is skipped, so
    re-running the cell does not add duplicates.

    Args:
        doc_info: Id and URL of the PDF to ingest.
        batch_size: Number of chunks passed to each ``add_documents`` call.
        delay_seconds: Pause between batches, to avoid rate limiting.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    '''
    if batch_size < 1:
        # range() raises for a step of 0 and silently yields nothing for negative steps
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Check if document already exists - a single matching chunk is enough, so don't
    # fetch every stored chunk (text + metadata) of the document.
    # NOTE(review): if an earlier run failed midway, the partially ingested document is
    # treated as complete here; delete the DB directory to re-ingest from scratch.
    existing = vectordb.get(where={"doc_id": doc_info.id}, limit=1)
    if existing["ids"]:
        print(f"Document {doc_info.id} already exists in index")
        return

    # Load
    print(f"Loading document {doc_info.id} ({doc_info.url})...")
    pages = PyPDFLoader(doc_info.url).load()
    for page in pages:
        # Tag pages before splitting so every chunk carries the doc_id used by the check above
        page.metadata["doc_id"] = doc_info.id

    # Split
    doc_splits = text_splitter.split_documents(pages)
    total = len(doc_splits)

    # Add to index
    print(f"Adding document {doc_info.id} ({doc_info.url}) to index...")

    # Add in batches, with delay, to avoid rate limiting
    for start in range(0, total, batch_size):
        batch = doc_splits[start:start + batch_size]
        vectordb.add_documents(documents=batch)
        end = start + len(batch)  # exclusive; the last batch may be shorter than batch_size
        print(f"Added splits {start} to {end - 1} (of {total})")
        if end < total:  # no need to wait after the final batch
            time.sleep(delay_seconds)

    print(f"Added document {doc_info.id} ({doc_info.url}) ({len(pages)} pages) - {total} splits")


for doc_info in documents_to_load:
    ingest_documents(doc_info)

## Setup query graph / pipeline

### Graph state

In [None]:
from typing import  List

from langchain_core.documents import Document
from langgraph.graph import MessagesState


class GraphState(MessagesState):
    '''Shared state flowing through the CRAG graph nodes.

    Inherits the ``messages`` key from MessagesState; the nodes in this notebook
    only read and write the keys declared below.
    '''

    # The user's question, supplied when the graph is invoked
    question: str
    # Context documents: retrieved, then filtered by the grader (or replaced by web search)
    documents: List[Document]
    # True when the grader judged every retrieved document irrelevant
    irrelevant_docs: bool
    # The generated answer
    answer: str
    irrelevant_docs: bool
    answer: str


### Nodes

#### Retrieval (Vector Store similarity search)

In [None]:
class RetrievalNode:
    '''Graph node: fetch candidate documents for the question from the vector store.'''

    def __call__(self, state: GraphState):
        print("---RETRIEVE---")

        # Similarity search via the module-level retriever
        found_docs = retriever.invoke(state["question"])
        print(f"---RETRIEVED {len(found_docs)} DOCS---")

        return {"documents": found_docs}

#### Retrieval Grader

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.output_parsers import StrOutputParser

class RetrievalGraderNode:
    '''Graph node that asks the LLM to grade each retrieved document for relevance.

    Returns only the documents graded relevant, plus an ``irrelevant_docs`` flag that
    is True when none of them were (the graph uses it to decide on the fallback).
    '''

    system_template = """You are a grader assessing relevance of a retrieved document to a user question.
    If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
    Give a binary score '1' or '0' score to indicate whether the document is relevant to the question.
    Respond with ONLY the single digit 1 (relevant) or 0 (not relevant), without any explanation.

    **Retrieved document:** \n\n {document}
    """

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            ("human", "{question}"),
        ]
    )

    chain: Runnable

    def __init__(self):
        self.chain = self.prompt | llm | StrOutputParser()

    @staticmethod
    def _is_relevant(grade: str) -> bool:
        '''Interpret the grader's reply: relevant only if it starts with the digit 1.

        Leading whitespace, quotes, backticks and markdown bold markers are ignored.
        A plain substring test ("1" in grade) would also accept replies such as
        "0 (score 1 means relevant)".
        '''
        return grade.strip().lstrip("'\"`*").startswith("1")

    def __call__(self, state: GraphState):
        print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]

        # Score each doc. batch() sends the grading requests concurrently (thread pool
        # by default) and returns the replies in input order.
        grading_inputs = [{"question": question, "document": d.page_content} for d in documents]
        grades = self.chain.batch(grading_inputs) if grading_inputs else []

        filtered_docs = []
        for i, (d, grade) in enumerate(zip(documents, grades)):
            if self._is_relevant(grade):
                print(f"---GRADE: DOCUMENT {i} RELEVANT---")
                filtered_docs.append(d)
            else:
                print(f"---GRADE: DOCUMENT {i} NOT RELEVANT---")

        # No relevant document at all -> flag it so the graph can fall back to web search
        return {"documents": filtered_docs, "irrelevant_docs": not filtered_docs}

#### Web Search (fallback when no relevant docs were found)

As a fallback when none of the retrieved documents are relevant, we can use a web search tool to find more information. In this demo, the web search is faked with a plain LLM call.

In [None]:
class FakeWebSearchNode:
    '''Fallback graph node that stands in for a real web search.

    It simply asks the LLM the question and wraps the reply in a single Document,
    which replaces the (irrelevant) retrieved documents as generation context.
    '''

    system_template = """You are a helpful and cheerful assistant."""

    prompt = ChatPromptTemplate.from_messages(
        [("system", system_template), ("human", "{question}")]
    )

    chain: Runnable

    def __init__(self):
        # Higher temperature than the shared llm (0.0), for more varied "search results"
        creative_llm = llm.bind(temperature=1.0)
        self.chain = self.prompt | creative_llm | StrOutputParser()

    def __call__(self, state: GraphState):
        print("---FAKE WEB SEARCH---")
        question = state["question"]

        result_text = self.chain.invoke({"question": question})
        print(f"---FAKE WEB SEARCH RESULT: \n{result_text}")

        return {"documents": [Document(page_content=result_text)], "question": question}

#### RAG Generation (LLM call with factual/grounded context)

In [None]:
class RAGNode:
    '''Graph node that generates the final answer from the question and context documents.'''

    system_template = """You are an helpful assistant, expert in answering questions based on provided sources (snippets from documents) and citing the sources used to generate the answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible.
    ALWAYS respond in the SAME language as the original question.

    ** Context (snippets from documents): **

    {context}
    """

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_template),
            ("human", "{question}"),
        ]
    )

    chain: Runnable

    def __init__(self):
        self.chain = self.prompt | llm | StrOutputParser()

    @staticmethod
    def _format_context(documents) -> str:
        '''Join the document snippets into one context string, each under a source label.

        The system prompt asks the model to cite its sources, which is only possible if
        each snippet carries an identifier. The label always has a running number and
        adds the ``doc_id`` and page metadata when present (the fake web-search result
        has no metadata, so it only gets the number).
        '''
        snippets = []
        for number, doc in enumerate(documents, start=1):
            meta = doc.metadata or {}
            label_parts = [f"Source {number}"]
            if meta.get("doc_id") is not None:
                label_parts.append(f"document {meta['doc_id']}")
            # NOTE(review): assumes PDF-loader metadata ("page_label" / "page"); confirm per loader version
            page = meta.get("page_label", meta.get("page"))
            if page is not None:
                label_parts.append(f"page {page}")
            snippets.append(f"[{', '.join(label_parts)}]\n{doc.page_content}")
        return "\n\n".join(snippets)

    def __call__(self, state: GraphState):
        print("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

        # RAG generation - setup context (relevant document snippets, labelled by source)
        context = self._format_context(documents)

        # RAG generation - generate answer
        answer = self.chain.invoke({"question": question, "context": context})

        return {"documents": documents, "answer": answer}

### Conditional edges

In [None]:
def decide_to_generate(state):
    '''Conditional edge after grading: choose the next step of the graph.

    Returns "fallback" (mapped to the web-search node) when the grader flagged all
    retrieved documents as irrelevant, otherwise "generate".
    '''
    print("---ASSESS GRADED DOCUMENTS---")

    if not state["irrelevant_docs"]:
        print("---DECISION: GENERATE---")
        return "generate"

    print("---DECISION: USE WEB SEARCH---")
    return "fallback"

### Build Graph

In [None]:
#### Graph ####
from langgraph.graph import END, StateGraph, START
from IPython.display import Image, display

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", RetrievalNode())  # retrieve
workflow.add_node("grade_documents", RetrievalGraderNode())  # grade documents
workflow.add_node("web_search", FakeWebSearchNode())  # failed to find matches
workflow.add_node("generate", RAGNode())  # generate

workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")

# After grading: generate directly, or fall back to web search first
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "generate": "generate",
        "fallback": "web_search",
    },
)

workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
graph = workflow.compile()

# View - best effort. draw_mermaid_png() renders via the remote mermaid.ink service by
# default, so it can fail offline or when that service is down. The picture is only a
# visual aid, so fall back to printing the Mermaid source instead of failing the cell.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as err:
    print(f"⚠️ Could not render graph image ({err!r}) - Mermaid source:")
    print(graph.get_graph().draw_mermaid())


## Use Graph

In [None]:
# Run the graph on a sample question
# ("Has the Swedish parliament discussed risks around the use of AI?").
# Other questions to try:
#   "Vilka var nobelpristagarna 2023?" ("Who were the 2023 Nobel laureates?")
#       -> should result in web search
#   "Vad innebär vitboken om artificiell intelligens?" ("What does the AI white paper entail?")
#       -> should NOT result in web search (note: the VITBOK document itself is
#          commented out in documents_to_load)
inputs = dict(
    question="Har det i riksdagen diskuterats något om risker kring användningen av artificiell intelligens (AI)?",
)

# Execute graph, then show the generated answer
result = graph.invoke(inputs)
print(f"--- ANSWER: ---\n{result['answer']}")