In [66]:
import os
import sys
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain.vectorstores import FAISS

from pydantic import BaseModel, Field


In [67]:
# Relative path to the source PDF; it is passed to encode_pdf to build the index.
path = "../data/Understanding_Climate_Change.pdf"

In [68]:
def replace_t_with_space(list_of_documents):
    """
    Swap every tab character for a single space in each document's text.

    The documents are modified in place and the same container that was
    passed in is handed back to the caller.

    Args:
        list_of_documents: Iterable of objects exposing a string
            'page_content' attribute.

    Returns:
        The object received as list_of_documents, after its members' page
        content has been rewritten without tab characters.
    """
    tab_char, single_space = '\t', ' '

    for document in list_of_documents:
        # Splitting on tabs and re-joining with spaces is a one-for-one swap.
        pieces = document.page_content.split(tab_char)
        document.page_content = single_space.join(pieces)

    return list_of_documents

def encode_pdf(path, chunk_size=1000, chunk_overlap=200, embedding_model_name="nomic-embed-text"):
    """
    Encodes a PDF book into a FAISS vector store using Ollama embeddings.

    Args:
        path: The path to the PDF file.
        chunk_size: The desired size of each text chunk, measured with len().
        chunk_overlap: The amount of overlap between consecutive chunks.
        embedding_model_name: Name of the Ollama embedding model used to embed
            the chunks. Defaults to "nomic-embed-text".

    Returns:
        A FAISS vector store containing the encoded book content.
    """

    # Load PDF documents
    loader = PyPDFLoader(path)
    documents = loader.load()

    # Split documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )

    texts = text_splitter.split_documents(documents)
    # Tabs are replaced with spaces before embedding.
    cleaned_texts = replace_t_with_space(texts)

    # Create embeddings and vector store
    embedding_model = OllamaEmbeddings(model=embedding_model_name)
    vectorstore = FAISS.from_documents(cleaned_texts, embedding_model)

    return vectorstore

# Build the vector store once at module level; the self_rag call further down receives it.
vectorstore = encode_pdf(path)

In [69]:
from langchain.output_parsers import PydanticOutputParser

# Schema for the retrieval decision. Documented with a comment rather than a
# class docstring so the JSON schema embedded in the prompt stays unchanged.
class RetrievalResponse(BaseModel):
    response: str = Field(..., description="Output only 'Yes' or 'No'.")

# Setup parser. The name is specific to this chain so that no other cell can
# rebind the parser that run_retrieval_chain looks up at call time.
retrieval_parser = PydanticOutputParser(pydantic_object=RetrievalResponse)

# Prompt with format instructions
retrieval_prompt = PromptTemplate(
    input_variables=["query"],
    template=(
        "Given the query '{query}', determine if retrieval is necessary.\n"
        "Output only 'Yes' or 'No'.\n{format_instructions}"
    ),
    partial_variables={"format_instructions": retrieval_parser.get_format_instructions()},
)

llm = OllamaLLM(model="llama3")

def run_retrieval_chain(query: str) -> RetrievalResponse:
    """
    Ask the LLM whether retrieval is needed to answer the query.

    Args:
        query: The user question to classify.

    Returns:
        A RetrievalResponse parsed from the LLM output.

    Raises:
        Whatever retrieval_parser.parse raises when the LLM output cannot be
        parsed into a RetrievalResponse; the error is not caught here.
    """
    prompt = retrieval_prompt.format(query=query)
    output = llm.invoke(prompt)
    return retrieval_parser.parse(output)

In [70]:

# Schema for a generated answer. Documented with a comment rather than a class
# docstring so the JSON schema embedded in the prompt stays unchanged.
class GenerationResponse(BaseModel):
    response: str = Field(..., description="The generated response.")

# The name is specific to this chain so that no other cell can rebind the
# parser that run_generation_chain looks up at call time.
generation_parser = PydanticOutputParser(pydantic_object=GenerationResponse)

generation_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template=(
        "Given the query:\n{query}\n\n"
        "and the context:\n{context}\n\n"
        "Generate a helpful response.\n"
        "{format_instructions}"
    ),
    partial_variables={"format_instructions": generation_parser.get_format_instructions()}
)

llm = OllamaLLM(model="llama3")

def run_generation_chain(query: str, context: str) -> GenerationResponse:
    """
    Generate an answer to the query from the supplied context.

    Args:
        query: The user question.
        context: Text inserted into the prompt as supporting context.

    Returns:
        A GenerationResponse parsed from the LLM output.

    Raises:
        Whatever generation_parser.parse raises when the LLM output cannot be
        parsed into a GenerationResponse; the error is not caught here.
    """
    prompt = generation_prompt.format(query=query, context=context)
    output = llm.invoke(prompt)
    return generation_parser.parse(output)

In [84]:
# Schema describing the relevance verdict.
class RelevanceResponse(BaseModel):
    response: str = Field(..., description="Output only 'Relevant' or 'Irrelevant'.")

# Kept for callers that want structured parsing. run_relevance_chain does not
# use it: the prompt below asks for a bare word, not JSON. The name is specific
# to this chain so it does not rebind a parser that another chain relies on.
relevance_parser = PydanticOutputParser(pydantic_object=RelevanceResponse)

relevance_prompt = PromptTemplate(
    input_variables=["query", "context"],
    template="Given the query '{query}' and the context '{context}', determine if the context is relevant. Output only 'relevant' or 'irrelevant' only."
)


llm = OllamaLLM(model="llama3")

def run_relevance_chain(query: str, context: str) -> str:
    """
    Run relevance checking chain.

    Args:
        query: The user question.
        context: The retrieved passage to judge against the query.

    Returns:
        The LLM's raw text verdict with surrounding whitespace removed. The
        prompt requests 'relevant' or 'irrelevant', but the text is not
        validated here, so callers should normalize case before comparing.
    """
    prompt = relevance_prompt.format(query=query, context=context)
    output = llm.invoke(prompt)
    # Strip the whitespace that often surrounds a one-word answer so that
    # equality checks in callers are not defeated by a trailing newline.
    return output.strip()


In [87]:

# Schema for the support verdict; the single `response` field carries the label.
class SupportResponse(BaseModel):
    response: str = Field(..., description="Output 'Fully supported', 'Partially supported', or 'No support'.")

# Parses raw LLM text into a SupportResponse.
support_parser = PydanticOutputParser(pydantic_object=SupportResponse)

# This prompt spells out the expected JSON itself instead of injecting the
# parser's format instructions. The doubled braces in the JSON examples are
# escapes, so only {response} and {context} are treated as template variables.
support_prompt = PromptTemplate(
    input_variables=["response", "context"],
    template="""
You are a strict classifier. 
Given the response '{response}' and the context '{context}', determine if the response is supported. 

Return ONLY a valid JSON object, nothing else, in this exact format:

{{"response": "Fully supported"}}
OR
{{"response": "Partially supported"}}
OR
{{"response": "No support"}}
"""
)

# Schema for the utility rating; `response` is declared as an int, described as a 1-5 rating.
class UtilityResponse(BaseModel):
    response: int = Field(..., description="Rate the utility of the response from 1 to 5.")

# Parses raw LLM text into a UtilityResponse.
utility_parser = PydanticOutputParser(pydantic_object=UtilityResponse)

# This prompt spells out the expected JSON itself instead of injecting the
# parser's format instructions. The doubled braces in the JSON examples are
# escapes, so only {query} and {response} are treated as template variables.
utility_prompt = PromptTemplate(
    input_variables=["query", "response"],
    template="""
You are a strict evaluator. 
Given the query '{query}' and the response '{response}', rate the utility of the response. 

Return ONLY a valid JSON object, nothing else, in this exact format:

{{"response": 1}} OR {{"response": 2}} OR {{"response": 3}} OR {{"response": 4}} OR {{"response": 5}}
"""
)
# LLM client used by run_support and run_utility below.
llm = OllamaLLM(model="llama3")

def run_support(response: str, context: str) -> SupportResponse:
    """
    Classify how well the context supports a generated response.

    Args:
        response: The generated answer text to check.
        context: The passage the answer should be grounded in.

    Returns:
        A SupportResponse. When the LLM output cannot be parsed, the raw output
        is printed and a SupportResponse with "No support" is returned, so the
        return type is the same on both paths.
    """
    prompt = support_prompt.format(response=response, context=context)
    raw_output = llm.invoke(prompt)
    try:
        return support_parser.parse(raw_output)
    except Exception:
        # Deliberate best-effort: a malformed LLM reply must not abort the
        # whole pipeline, so fall back to the most conservative verdict.
        print("⚠️ Parse error, raw output:", raw_output)
        return SupportResponse(response="No support")  # fallback

def run_utility(query: str, response: str) -> UtilityResponse:
    """
    Rate how useful a generated response is for the query.

    Args:
        query: The user question.
        response: The generated answer text to rate.

    Returns:
        A UtilityResponse. When the LLM output cannot be parsed, the raw output
        is printed and a UtilityResponse with a score of 3 is returned, so the
        return type is the same on both paths.
    """
    prompt = utility_prompt.format(query=query, response=response)
    raw_output = llm.invoke(prompt)
    try:
        return utility_parser.parse(raw_output)
    except Exception:
        # Deliberate best-effort: a malformed LLM reply must not abort the
        # whole pipeline, so fall back to the middle score.
        print("⚠️ Parse error, raw output:", raw_output)
        return UtilityResponse(response=3)  # fallback


In [92]:
def _response_value(result):
    """Return the 'response' payload of a chain result (model, dict, or plain value)."""
    if isinstance(result, dict):
        return result.get("response")
    return getattr(result, "response", result)


def _normalized_label(result):
    """Return a chain verdict as lowercase text without surrounding whitespace, quotes or periods."""
    return str(_response_value(result)).strip().strip("'\".").strip().lower()


def self_rag(query, vectorstore, top_k=3):
    """
    Answer a query with the Self-RAG loop: decide whether to retrieve, filter
    retrieved passages by relevance, generate one candidate per relevant
    passage, then pick the candidate with the best support/utility assessment.

    Args:
        query: The user question.
        vectorstore: Vector store exposing similarity_search(query, k=...).
        top_k: Number of passages to retrieve.

    Returns:
        The object returned by run_generation_chain for the chosen answer. All
        three exit paths (no retrieval needed, no relevant passage, best
        candidate) return this same kind of object.
    """
    print(f"\nProcessing query: {query}")

    # Step 1: Determine if retrieval is necessary
    print("Step 1: Determining if retrieval is necessary...")
    retrieval_decision = run_retrieval_chain(query).response.strip().lower()
    print(f"Retrieval decision: {retrieval_decision}")

    if retrieval_decision == 'yes':
        # Step 2: Retrieve relevant documents
        print("Step 2: Retrieving relevant documents...")
        docs = vectorstore.similarity_search(query, k=top_k)
        contexts = [doc.page_content for doc in docs]
        print(f"Retrieved {len(contexts)} documents")

        # Step 3: Evaluate relevance of retrieved documents
        print("Step 3: Evaluating relevance of retrieved documents...")
        relevant_contexts = []
        for i, context in enumerate(contexts):
            relevance = run_relevance_chain(query, context)
            print(f"Document {i+1} relevance: {relevance}")
            # Normalize before comparing: raw LLM text may carry whitespace,
            # quotes or a trailing period around the one-word verdict.
            if _normalized_label(relevance) == 'relevant':
                relevant_contexts.append(context)

        print(f"Number of relevant contexts: {len(relevant_contexts)}")
        # If no relevant contexts found, generate without retrieval
        if not relevant_contexts:
            print("No relevant contexts found. Generating without retrieval...")
            # Uses the same helper as every other path; the previous code
            # called an undefined `generation_chain` here.
            return run_generation_chain(query, "No relevant context found.")

        # Step 4: Generate response using relevant contexts
        print("Step 4: Generating responses using relevant contexts...")
        responses = []
        for i, context in enumerate(relevant_contexts):
            print(f"Generating response for context {i+1}...")
            response = run_generation_chain(query, context)
            # The assessment prompts need the answer text, not the repr of the
            # wrapper object.
            response_text = _response_value(response)

            # Step 5: Assess support
            print(f"Step 5: Assessing support for response {i+1}...")
            support = run_support(response_text, context)
            print(f"Support assessment: {support}")

            # Step 6: Evaluate utility
            print(f"Step 6: Evaluating utility for response {i+1}...")
            utility = run_utility(query, response_text)
            print(f"Utility score: {utility}")

            responses.append((response, support, utility))

        # Select the best response based on support and utility
        print("Selecting the best response...")
        # Rank on the unwrapped values: comparing the wrapper objects to a
        # string is always False, and ordering them raises TypeError.
        best_response = max(
            responses,
            key=lambda x: (
                _normalized_label(x[1]) == 'fully supported',
                int(_response_value(x[2])),
            ),
        )
        print(f"Best response support: {best_response[1]}, utility: {best_response[2]}")
        return best_response[0]

    else:
        # Generate without retrieval
        print("Generating without retrieval...")
        return run_generation_chain(query, "No retrieval necessary.")

In [93]:
# Run one question through the full pipeline and show whatever self_rag returns.
query = "What is the impact of climate change on the environment?"
response = self_rag(query=query, vectorstore=vectorstore)

# One print call with a newline separator emits the same two lines as before.
print("\nFinal response:", response, sep="\n")


Processing query: What is the impact of climate change on the environment?
Step 1: Determining if retrieval is necessary...
Retrieval decision: yes
Step 2: Retrieving relevant documents...
Retrieved 3 documents
Step 3: Evaluating relevance of retrieved documents...
Document 1 relevance: Relevant
Document 2 relevance: Relevant
Document 3 relevance: Relevant
Number of relevant contexts: 3
Step 4: Generating responses using relevant contexts...
Generating response for context 1...
Step 5: Assessing support for response 1...
Support assessment: response='Fully supported'
Step 6: Evaluating utility for response 1...
Utility score: response=5
Generating response for context 2...
Step 5: Assessing support for response 2...
Support assessment: response='Fully supported'
Step 6: Evaluating utility for response 2...
Utility score: response=5
Generating response for context 3...
Step 5: Assessing support for response 3...
Support assessment: response='Fully supported'
Step 6: Evaluating utility 