In [1]:
import os
# Create the package folder up front so the %%writefile cells below have somewhere to write.
os.makedirs(name="modules", exist_ok=True)

In [2]:
%%writefile modules/utils.py
import os
import warnings
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Reduce TensorFlow's native (C++) log output; "3" is understood to be the quietest
# level — verify against the TensorFlow docs if this matters.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# NOTE(review): this is not warning suppression — it presumably limits CUDA libraries in
# this process to GPU 0 and hides any other GPUs. Confirm this is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Silences *all* Python warnings process-wide (including library deprecation warnings).
warnings.filterwarnings("ignore")

# Load variables from a .env file into os.environ so get_api_key() can read them.
load_dotenv()

def get_api_key(key_name="OPENROUTER_API_KEY"):
    """
    Return the value of the environment variable ``key_name``.

    Args:
        key_name: name of the environment variable holding the key.

    Returns:
        The non-empty variable value.

    Raises:
        ValueError: if the variable is unset or set to an empty string.
    """
    secret = os.getenv(key_name)
    if secret:
        return secret
    raise ValueError(f"Invalid API key: {key_name} not found in environment variables")

def initialize_llm(model_name="meta-llama/llama-3.3-70b-instruct",
                   temperature=0.4,
                   use_streaming=True,
                   base_url="https://openrouter.ai/api/v1",
                   api_key_name="OPENROUTER_API_KEY"):
    """
    Create a ChatOpenAI client for an OpenAI-compatible endpoint (OpenRouter by default).

    Args:
        model_name: model identifier understood by the endpoint.
        temperature: sampling temperature.
        use_streaming: stream tokens; when True they are also echoed to stdout.
        base_url: API base URL; defaults to OpenRouter.
        api_key_name: environment variable holding the API key.

    Returns:
        A configured ChatOpenAI instance.

    Raises:
        ValueError: if the API key variable is missing (raised by get_api_key).
    """
    api_key = get_api_key(api_key_name)
    # The stdout printer only has a purpose for streamed tokens, so attach it only then.
    callbacks = [StreamingStdOutCallbackHandler()] if use_streaming else None

    llm = ChatOpenAI(
        model_name=model_name,
        temperature=temperature,
        streaming=use_streaming,
        callbacks=callbacks,
        openai_api_key=api_key,
        openai_api_base=base_url,
    )

    return llm


Writing modules/utils.py


In [9]:
%%writefile modules/retriever.py
import os
from typing import List, Optional
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.vectorstores import FAISS, Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.document_compressors import (
    EmbeddingsFilter, 
    LLMChainFilter, 
    LLMChainExtractor, 
    DocumentCompressorPipeline
)
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings

from utils import initialize_llm

def load_documents(docs_dir: str = "books") -> List:
    """
    Load every ``*.txt`` file under ``docs_dir`` (recursively) as LangChain documents.

    Args:
        docs_dir: directory to scan.

    Returns:
        The list of documents produced by ``TextLoader``.

    Raises:
        ValueError: if ``docs_dir`` does not exist or is not a directory.
    """
    # isdir rather than exists: a regular file path would otherwise pass this check
    # and fail later inside DirectoryLoader with a far less helpful error.
    if not os.path.isdir(docs_dir):
        raise ValueError(
            f"The specified directory {docs_dir} does not exist or is not a directory. "
            "Please enter a valid directory"
        )

    loader = DirectoryLoader(
        docs_dir,
        glob="**/*.txt",
        loader_cls=TextLoader
    )
    documents = loader.load()
    print(f"Loaded {len(documents)} documents")
    return documents

def create_vectorstore(documents, embeddings=None, store_type: str = "faiss", persist_directory: Optional[str] = None,
                       chunk_size: int = 1000, chunk_overlap: int = 200,
                       embedding_model: str = "BAAI/bge-base-en-v1.5"):
    """
    Split ``documents`` into chunks and index them in a FAISS or Chroma vector store.

    Args:
        documents: LangChain documents to index.
        embeddings: embeddings object to use; when None a HuggingFaceEmbeddings
            instance for ``embedding_model`` is created.
        store_type: "faiss" or "chroma" (case-insensitive).
        persist_directory: where to save the store; falsy means keep it in memory only.
        chunk_size: maximum characters per chunk.
        chunk_overlap: characters shared between consecutive chunks.
        embedding_model: HuggingFace model name used when ``embeddings`` is None.

    Returns:
        The populated vector store.

    Raises:
        ValueError: for an unknown ``store_type`` — checked before any splitting or
            model loading, so a typo fails fast.
    """
    kind = store_type.lower()
    if kind not in ("faiss", "chroma"):
        raise ValueError(f"Unknown vector store type {store_type}")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_documents(documents)
    print(f"Split the documents into {len(chunks)} chunks")

    if embeddings is None:
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    if kind == "faiss":
        vector_store = FAISS.from_documents(chunks, embeddings)
        if persist_directory:
            vector_store.save_local(persist_directory)
    else:
        # A single call covers both cases; "" is normalised to None so an empty string
        # keeps meaning "in memory", as the truthiness checks elsewhere treat it.
        vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=embeddings,
            persist_directory=persist_directory or None,
        )
        if persist_directory:
            # NOTE(review): newer chromadb versions persist automatically; this call is
            # kept for older versions — confirm against the installed version.
            vector_store.persist()

    return vector_store

def _has_persisted_index(persist_directory: Optional[str], store_type: str) -> bool:
    """Return True only if ``persist_directory`` actually holds a saved store of ``store_type``."""
    if not persist_directory or not os.path.isdir(persist_directory):
        return False
    kind = store_type.lower()
    if kind == "faiss":
        # FAISS.save_local writes "<index_name>.faiss"; this module uses the default name "index".
        return os.path.isfile(os.path.join(persist_directory, "index.faiss"))
    if kind == "chroma":
        # An empty folder (e.g. one pre-created by the caller) is not a Chroma store.
        return bool(os.listdir(persist_directory))
    return False


def _load_vectorstore(persist_directory: str, store_type: str, embeddings):
    """Load a previously saved store of ``store_type``; return None for an unknown type."""
    kind = store_type.lower()
    if kind == "faiss":
        # allow_dangerous_deserialization unpickles the saved docstore: acceptable only
        # because this script wrote the index itself — never point this at untrusted data.
        return FAISS.load_local(persist_directory, embeddings, allow_dangerous_deserialization=True)
    if kind == "chroma":
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)
    return None


def initialize_retriever(docs_dir: str = "books",
                         store_type: str = "faiss",
                         persist_directory: Optional[str] = "vector",
                         similarity_threshold=0.4,
                         k: int = 5,
                         embedding_model: str = "BAAI/bge-base-en-v1.5"):
    """
    Build a contextual-compression retriever over the documents in ``docs_dir``.

    A store saved in ``persist_directory`` is reused only when an index is actually
    present there (an empty folder is ignored); otherwise the documents are loaded,
    indexed and saved to ``persist_directory``. The top ``k`` similarity hits are then
    passed through an EmbeddingsFilter (``similarity_threshold``) and an LLM extractor.

    Args:
        docs_dir: directory of ``*.txt`` source documents.
        store_type: "faiss" or "chroma".
        persist_directory: where the vector store is saved/loaded; None disables persistence.
        similarity_threshold: minimum embedding similarity kept by the EmbeddingsFilter.
        k: number of chunks fetched from the vector store per query.
        embedding_model: HuggingFace embedding model name.

    Returns:
        A ContextualCompressionRetriever, or None if ``docs_dir`` yields no documents.

    Raises:
        ValueError: from load_documents / create_vectorstore for a bad directory or store type.
    """
    # One embedding model instance serves loading/indexing and the EmbeddingsFilter,
    # instead of loading the same model several times.
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)

    vector_store = None
    if _has_persisted_index(persist_directory, store_type):
        print(f"Loading vector store from {persist_directory}")
        vector_store = _load_vectorstore(persist_directory, store_type, embeddings)

    # No usable saved store: index the source documents.
    if vector_store is None:
        documents = load_documents(docs_dir)
        if not documents:
            print("No documents in the directory")
            return None

        vector_store = create_vectorstore(
            documents,
            embeddings=embeddings,
            store_type=store_type,
            persist_directory=persist_directory
        )

    base_retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k}
    )

    # Drop chunks whose embedding similarity to the query is below the threshold.
    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings,
        similarity_threshold=similarity_threshold
    )

    # Then let the LLM extract only the relevant parts of each remaining chunk.
    llm = initialize_llm()
    llm_extractor = LLMChainExtractor.from_llm(llm=llm)

    compression_pipeline = DocumentCompressorPipeline(
        transformers=[embeddings_filter, llm_extractor]
    )

    retriever = ContextualCompressionRetriever(
        base_compressor=compression_pipeline,
        base_retriever=base_retriever
    )

    return retriever

Writing modules/retriever.py


In [10]:
%%writefile modules/prompts.py
from langchain.prompts import PromptTemplate

# Router prompt.
# Classifies {question} for the graph's entry point: the model must reply with a JSON
# object {"datasource": "vectorstore" | "web_search"} (parsed by a JsonOutputParser chain).
# NOTE(review): the Llama-3 special tokens (<|begin_of_text|>, <|start_header_id|>, ...) are
# plain text here; sent through an OpenAI-compatible chat API they arrive inside a single
# message, so presumably the provider applies its own chat template — verify they are not
# doubled up or ignored.
# NOTE(review): unlike the other prompts in this module, the question is placed inside the
# system turn (there is no user header) — confirm this is intentional.
ROUTER_PROMPT = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a specialized router that determines the appropriate data source for user queries.

# Available Data Sources:
1. Vectorstore - Contains only:
   * Your personal biography including any question about Chris Olande
   * The complete text of "Frankenstein" by Mary Shelley
   * The complete text of "Romeo and Juliet" by William Shakespeare

2. Web Search - For all other information needs

# Routing Rules:
- Use 'vectorstore' ONLY for questions specifically about:
  * Your personal biographical information
  * Details, quotes, characters, themes, or analysis of "Frankenstein"
  * Details, quotes, characters, themes, or analysis of "Romeo and Juliet"

- Use 'web_search' for:
  * All other questions
  * Current events and news
  * General knowledge questions
  * Any topic not directly related to your biography or the two literary works

# Output Format:
Return ONLY a JSON object with the key 'datasource' and value of either 'vectorstore' or 'web_search'.
Do not include any explanations, preambles, or additional text.

Question to route: {question}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question"],
)

# Generation prompt.
# Fills {question} and {context} and asks for an answer of at most three sentences drawn
# only from the context, replying "I don't know" when the context lacks the answer.
# The model output is consumed as plain text (StrOutputParser).
# NOTE(review): {context} is rendered with str() of whatever the caller passes — confirm
# callers pass plain text rather than a list of Document objects.
GENERATION_PROMPT = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a retrieval-augmented AI assistant that provides precise answers using only the supplied context.

## Response Guidelines:
- Answer using ONLY information from the provided context
- Keep responses to three sentences maximum
- Format important points in **bold** when appropriate
- Provide direct, factual answers without speculation
- If the context doesn't contain the answer, respond only with "I don't know"
- Do not reference the context or your instructions in your answer

## Remember:
- Never invent information or draw conclusions beyond what's explicitly stated
- Prioritize accuracy over completeness
- Use simple, clear language
<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}

Context:
{context}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "context"],
)

# Retrieval grader prompt.
# Judges one retrieved {document} against {question}; the model must reply with JSON
# {"score": "yes" | "no"} and is told to lean towards "yes" when unsure.
RETRIEVAL_GRADER_PROMPT = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a precision document relevance evaluator. Your task is to determine if a retrieved document contains information relevant to answering a user's question.

## Evaluation Criteria:
- A document is "relevant" if it contains:
  * Direct answers to the question
  * Key concepts, terminology, or facts related to the question
  * Contextual information that would help form a complete answer

- A document is "not relevant" if it:
  * Contains no information related to the question
  * Only mentions keywords in an unrelated context
  * Addresses a completely different topic

## Output Requirements:
- Return ONLY a JSON object with the key 'score' and a value of either 'yes' or 'no'
- Do not include any explanations, reasoning, or additional text
- Be generous in assessing relevance - when in doubt, mark as relevant ('yes')
<|eot_id|><|start_header_id|>user<|end_header_id|>
USER QUESTION: {question}

RETRIEVED DOCUMENT:
{document}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "document"],
)

# Hallucination grader prompt.
# Checks whether {generation} is supported by {documents}; the model must reply with JSON
# {"score": "yes" | "no"}.
# NOTE(review): "every claim must have direct evidence" and "consider implicit facts that are
# logically derivable ... as supported" pull in opposite directions — consider reconciling.
# NOTE(review): {documents} is rendered with str() of whatever the caller passes — confirm
# callers pass plain text rather than a list of Document objects.
HALLUCINATION_GRADER_PROMPT = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a factual accuracy validator that determines if a generated answer is fully supported by the provided reference documents.

## Evaluation Guidelines:
- Score 'yes' if ALL claims and statements in the answer are explicitly supported by information in the reference documents
- Score 'no' if ANY part of the answer:
  * Contains information not present in the documents
  * Makes assertions beyond what can be directly verified from the documents
  * Contradicts information in the documents
  * Presents speculative or uncertain information as definitive
  * Extends or extrapolates from the documents without clear support

## Key Assessment Principles:
- Be strict and precise - every claim must have direct evidence
- Focus on factual statements rather than phrasing or organization
- Consider implicit facts that are logically derivable from the documents as supported
- If uncertainty exists and the answer presents information with appropriate qualifiers, this is acceptable

## Output Format:
- Return ONLY a JSON object with the key 'score' and value of either 'yes' or 'no'
- Do not include explanations or reasoning in your output
<|eot_id|><|start_header_id|>user<|end_header_id|>
REFERENCE DOCUMENTS:
---
{documents}
---

GENERATED ANSWER:
{generation}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["generation", "documents"]
)

# Answer grader prompt.
# Judges whether {generation} usefully addresses {question}; the model must reply with JSON
# {"score": "yes" | "no"}. The source documents are not included, so factual correctness can
# only be judged from the model's general knowledge (as the criteria themselves state).
ANSWER_GRADER_PROMPT = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a specialized answer quality evaluator that assesses whether a response effectively addresses a user's question.

## Evaluation Criteria:
A "useful" answer (score: 'yes') must:
- Directly address the core intent of the question
- Provide substantive, relevant information
- Be clear and comprehensible
- Contain sufficient detail to satisfy the basic information need

An answer is "not useful" (score: 'no') if it:
- Is off-topic or addresses a different question
- Contains only vague, general statements without specific information
- Is factually incorrect (based on common knowledge)
- Is too incomplete to provide value
- Is unintelligible or incoherent
- Merely restates the question without providing new information

## Context Considerations:
- Consider both explicit and implicit information needs
- A partial answer that addresses the main point can still be "useful"
- The length of the answer is less important than its relevance and substance
- Technical accuracy matters more for technical questions

## Output Format:
- Return ONLY a JSON object with the key 'score' and value of either 'yes' or 'no'
- Do not include explanations or reasoning in your output
<|eot_id|><|start_header_id|>user<|end_header_id|>
USER QUESTION:
{question}

GENERATED ANSWER:
---
{generation}
---
<|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["generation", "question"]
)

Writing modules/prompts.py


In [11]:
%%writefile modules/chains.py
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain.schema import Document
import os

from utils import initialize_llm, get_api_key
from prompts import (
    ROUTER_PROMPT, 
    GENERATION_PROMPT, 
    RETRIEVAL_GRADER_PROMPT,
    HALLUCINATION_GRADER_PROMPT,
    ANSWER_GRADER_PROMPT
)

def setup_chains():
    """
    Build the LCEL chains used by the graph.

    Returns:
        (question_router, rag_chain, retrieval_grader, hallucination_grader, answer_grader).
        The router and graders end in JsonOutputParser (dict output); rag_chain ends
        in StrOutputParser (plain text).
    """
    # Answer generation keeps the defaults: streamed and echoed to stdout token by token.
    generation_llm = initialize_llm()
    # Routing/grading are classification steps that must emit strict JSON: temperature 0
    # makes the decisions repeatable, and disabling streaming keeps the raw JSON
    # verdicts from being echoed to stdout next to the user-facing answer.
    grading_llm = initialize_llm(temperature=0, use_streaming=False)

    question_router = ROUTER_PROMPT | grading_llm | JsonOutputParser()
    rag_chain = GENERATION_PROMPT | generation_llm | StrOutputParser()
    retrieval_grader = RETRIEVAL_GRADER_PROMPT | grading_llm | JsonOutputParser()
    hallucination_grader = HALLUCINATION_GRADER_PROMPT | grading_llm | JsonOutputParser()
    answer_grader = ANSWER_GRADER_PROMPT | grading_llm | JsonOutputParser()

    return question_router, rag_chain, retrieval_grader, hallucination_grader, answer_grader

def setup_web_search(max_results=3):
    """
    Create the Tavily web search tool.

    Args:
        max_results: maximum number of search results returned per query.

    Returns:
        A configured TavilySearchResults tool.

    Raises:
        ValueError: if TAVILY_API_KEY is missing (raised by get_api_key).
    """
    # get_api_key reads os.environ, so writing the value back mainly serves as a
    # fail-fast presence check for the key the tool needs.
    os.environ['TAVILY_API_KEY'] = get_api_key("TAVILY_API_KEY")

    # The tool's result-count parameter is `max_results`; the previous `k=3` is not a
    # declared field of TavilySearchResults, so the intended limit was not applied.
    web_search_tool = TavilySearchResults(max_results=max_results)

    return web_search_tool

def process_web_search_results(docs):
    """
    Merge web search results into a single Document.

    Args:
        docs: a list of result dicts carrying a "content" key, or a plain string
            (NOTE: the search tool may return a string, e.g. an error message,
            instead of a result list).

    Returns:
        A Document whose page_content is the result contents joined by newlines
        (or the string itself when ``docs`` is a string).
    """
    if isinstance(docs, str):
        return Document(page_content=docs)
    # Skip malformed entries instead of failing the whole search on one KeyError.
    contents = [
        d["content"]
        for d in docs
        if isinstance(d, dict) and d.get("content") is not None
    ]
    return Document(page_content="\n".join(contents))

Writing modules/chains.py


In [13]:
%%writefile modules/graph.py
from typing_extensions import TypedDict
from typing import List
from langgraph.graph import END, StateGraph
from pprint import pprint

from retriever import initialize_retriever
from chains import setup_chains, setup_web_search, process_web_search_results

# Define graph state
class GraphState(TypedDict):
    """
    Shared state passed between the LangGraph nodes.

    Fields:
        question: the user's question.
        generation: the latest answer produced by the ``generate`` node.
        web_search: "Yes"/"No" flag written by ``grade_documents``.
        documents: working set of LangChain Document objects (nodes read
            ``.page_content``), hence a plain ``List`` rather than ``List[str]``.
    """
    question: str
    generation: str
    web_search: str
    documents: List

def retrieve(state):
    """
    Graph node: fetch documents for ``state["question"]`` from the vector-store retriever.

    Returns a partial state update with ``documents`` and ``question``. When no
    retriever was built (initialize_retriever returns None if there are no source
    documents), an empty document list is returned instead of failing on ``None``.
    """
    print("---RETRIEVE---")
    question = state["question"]

    if retriever is None:
        print("---NO RETRIEVER AVAILABLE, RETURNING NO DOCUMENTS---")
        return {"documents": [], "question": question}

    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """
    Graph node: answer the question with the RAG chain over the current documents.

    Returns a partial state update with ``documents``, ``question`` and ``generation``.
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # Give the prompt plain text. Formatting the list directly would insert str(list),
    # i.e. the repr of each Document (class name, metadata, escaped newlines).
    context = "\n\n".join(getattr(doc, "page_content", str(doc)) for doc in documents)

    generation = rag_chain.invoke({"context": context, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """
    Graph node: keep only the documents the retrieval grader marks relevant.

    Sets ``web_search`` to "Yes" when any document is judged not relevant, or when
    no relevant document remains at all (including an empty retrieval), so the
    graph supplements with web search instead of generating from an empty context.
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    filtered_docs = []
    any_irrelevant = False
    for d in documents:
        score = retrieval_grader.invoke({"question": question, "document": d.page_content})
        # A missing or malformed 'score' counts as "not relevant" rather than raising.
        grade = str(score.get("score", "")).strip().lower() if isinstance(score, dict) else ""
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            any_irrelevant = True

    web_search = "Yes" if any_irrelevant or not filtered_docs else "No"
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def web_search(state):
    """
    Graph node: run a web search for the question and add the results as one Document.

    Returns a partial state update with the extended ``documents`` list and ``question``.
    """
    print("---WEB SEARCH---")
    question = state["question"]
    # Build a new list instead of appending in place: the list in ``state`` may be the
    # very object held in the graph's stored state. ``or []`` also covers a None value.
    documents = list(state.get("documents") or [])

    docs = web_search_tool.invoke({"query": question})
    documents.append(process_web_search_results(docs))
    return {"documents": documents, "question": question}

def route_question(state):
    """
    Conditional entry point: ask the router chain where to send the question.

    Returns:
        "vectorstore" when the router's 'datasource' is 'vectorstore' (case-insensitive),
        otherwise "websearch". Missing keys, non-dict output or unexpected values fall
        back to web search so the entry-point mapping always receives a known key.
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    print(question)
    source = question_router.invoke({"question": question})
    print(source)
    datasource = source.get("datasource") if isinstance(source, dict) else None
    print(datasource)

    route = str(datasource).strip().lower()
    if route == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"
    if route != "web_search":
        print(f"---UNEXPECTED DATASOURCE {datasource!r}, DEFAULTING TO WEB SEARCH---")
    print("---ROUTE QUESTION TO WEB SEARCH---")
    return "websearch"

def decide_to_generate(state):
    """
    Conditional edge after ``grade_documents``: generate now, or add web search first.

    Returns:
        "websearch" when ``state["web_search"]`` is "Yes" (grade_documents flagged the
        retrieved set as needing supplementation), otherwise "generate".
    """
    print("---ASSESS GRADED DOCUMENTS---")

    if state.get("web_search") == "Yes":
        print("---DECISION: NOT ALL DOCUMENTS ARE RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
        return "websearch"

    print("---DECISION: GENERATE---")
    return "generate"

def grade_generation_v_documents_and_question(state):
    """
    Conditional edge after ``generate``: check grounding, then usefulness.

    Returns:
        "not supported" if the hallucination grader does not say 'yes';
        otherwise "useful" or "not useful" according to the answer grader.
        Grades are compared case-insensitively; a missing 'score' counts as 'no'.
    """
    def _says_yes(result):
        # Tolerate casing/whitespace and a missing key in the grader's JSON.
        return isinstance(result, dict) and str(result.get("score", "")).strip().lower() == "yes"

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    # Render documents as plain text rather than the repr() of a list of Document objects.
    doc_text = "\n\n".join(getattr(doc, "page_content", str(doc)) for doc in documents)
    score = hallucination_grader.invoke({"documents": doc_text, "generation": generation})

    if not _says_yes(score):
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

    print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
    print("---GRADE GENERATION vs QUESTION---")
    score = answer_grader.invoke({"question": question, "generation": generation})
    if _says_yes(score):
        print("---DECISION: GENERATION ADDRESSES QUESTION---")
        return "useful"

    print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
    return "not useful"

def create_graph():
    """
    Build and compile the adaptive-RAG LangGraph workflow.

    Flow: route_question sends the question to "retrieve" or "websearch";
    retrieve -> grade_documents, which continues to "generate" or "websearch";
    websearch -> generate; after generate the answer is graded: "useful" ends,
    "not supported" regenerates, "not useful" goes back through web search.

    NOTE(review): the "not supported" edge re-runs generate with unchanged inputs;
    nothing in this graph bounds those retries — presumably LangGraph's recursion
    limit ends the loop. Confirm this is acceptable.
    """
    graph = StateGraph(GraphState)

    node_table = (
        ("websearch", web_search),
        ("retrieve", retrieve),
        ("grade_documents", grade_documents),
        ("generate", generate),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.set_conditional_entry_point(
        route_question,
        {"websearch": "websearch", "vectorstore": "retrieve"},
    )

    graph.add_edge("retrieve", "grade_documents")
    graph.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {"websearch": "websearch", "generate": "generate"},
    )
    graph.add_edge("websearch", "generate")
    graph.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {"not supported": "generate", "useful": END, "not useful": "websearch"},
    )

    return graph.compile()

# Module-level handles read by the graph's node functions; they stay None until init_globals() runs.
(retriever, question_router, rag_chain, retrieval_grader,
 hallucination_grader, answer_grader, web_search_tool) = (None,) * 7

def init_globals():
    """
    Populate the module-level handles used by the graph nodes.

    Builds the retriever (persisted under "vector"), the five LLM chains returned
    by setup_chains(), and the web search tool.
    """
    global retriever, question_router, rag_chain, retrieval_grader
    global hallucination_grader, answer_grader, web_search_tool

    retriever = initialize_retriever(persist_directory="vector")

    (question_router, rag_chain, retrieval_grader,
     hallucination_grader, answer_grader) = setup_chains()

    web_search_tool = setup_web_search()
    

Overwriting modules/graph.py


In [None]:
%%writefile modules/main.py
import os
from pprint import pprint

from graph import create_graph, init_globals

def setup_environment():
    """
    Prepare the working directory by making sure the "vector" folder exists.

    NOTE(review): the folder is created empty here; confirm the retriever does not
    mistake an empty "vector" folder for a saved index.
    """
    os.makedirs(name="vector", exist_ok=True)

def process_query(query, app=None):
    """
    Run one question through the RAG workflow and return the final answer.

    Args:
        query: the user's question.
        app: an already-compiled graph to reuse; when None a new one is built with
            create_graph(). Passing one avoids recompiling the workflow per question.

    Returns:
        The ``generation`` value from the last node update streamed by the graph.

    Raises:
        RuntimeError: if the workflow streamed no final update containing a generation.
    """
    if app is None:
        app = create_graph()

    final_output = None
    # Each streamed item maps a finished node name to that node's state update.
    for output in app.stream({"question": query}):
        for node_name, update in output.items():
            pprint(f"Finished running: {node_name}:")
            final_output = update

    if not isinstance(final_output, dict) or "generation" not in final_output:
        raise RuntimeError("Workflow finished without producing a generation")
    return final_output["generation"]

def main():
    """
    Interactive entry point.

    Prepares the environment, builds the shared components, then answers questions
    read from stdin until the user types 'exit' or 'quit' or input ends (Ctrl-D / Ctrl-C).
    A failure while answering one question is reported and the prompt continues.
    """
    setup_environment()
    init_globals()

    while True:
        try:
            query = input("\nEnter your question (type 'exit' or 'quit' to end): ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C at the prompt: leave cleanly instead of with a traceback.
            print("\nExiting program. Goodbye!")
            break

        if query.lower() in ['exit', 'quit']:
            print("Exiting program. Goodbye!")
            break
        if not query:
            # Nothing to ask; routing an empty question would still cost an LLM call.
            continue

        try:
            response = process_query(query)
        except Exception as exc:  # keep the REPL alive when a single question fails
            print(f"\nError while processing the question: {exc}")
            continue

        print("\nFinal response:")
        print(response)


if __name__ == "__main__":
    main()