# RAG

Implement a basic RAG (retrieval-augmented generation) module in DSPy.
Given a question, retrieve the top-k most relevant documents from a collection of HTML documents, then pass them to an LLM as context for answering.

This follows the DSPy RAG tutorial: https://dspy.ai/tutorials/rag/.
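Conceptually, top-k retrieval embeds the question and every document as vectors and keeps the k documents whose vectors are most similar to the question's. The plain-Python sketch below illustrates that idea with cosine similarity over toy 3-dimensional vectors. It is not the internal implementation of `dspy.retrievers.Embeddings`, which works on real embeddings produced by the `embedder`, and the helper names `cosine` and `top_k` are hypothetical.

```python
import math

def cosine(u, v):
    # Cosine similarity between two equal-length vectors (0.0 if either is all zeros)
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)

def top_k(query_vec, doc_vecs, k=5):
    # Score every document, then return the indices of the k best, highest score first
    scores = [(cosine(query_vec, d), i) for i, d in enumerate(doc_vecs)]
    scores.sort(key=lambda pair: pair[0], reverse=True)
    return [i for _, i in scores[:k]]

# Toy "embeddings" standing in for real model output
docs = [[1.0, 0.0, 0.0], [0.7, 0.7, 0.0], [0.0, 0.0, 1.0]]
print(top_k([1.0, 0.1, 0.0], docs, k=2))  # [0, 1]
```

Scoring every document is fine for a few hundred pages like the corpus below; much larger corpora usually call for an approximate nearest-neighbour index instead.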


In [1]:
import os

import dspy
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer

# Load a very fast local embedding model for retrieval (runs on CPU)
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")

# Wrap the model's encode method as a DSPy embedder
embedder = dspy.Embedder(model.encode)

def read_html_files(directory):
    """Return the plain text of every .html file in `directory`."""
    texts = []
    for filename in sorted(os.listdir(directory)):  # sorted for a deterministic corpus order
        if filename.endswith(".html"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                soup = BeautifulSoup(file, "html.parser")
                texts.append(soup.get_text())
    return texts
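If you would rather avoid the BeautifulSoup dependency, or want explicit control over which tags are skipped and how text fragments are joined, the standard library's `html.parser.HTMLParser` is enough for a rough extractor. The sketch below is an alternative, not what the notebook uses; the names `TextExtractor` and `html_to_text` and the choice to skip `<script>` and `<style>` are assumptions for illustration.

```python
from html.parser import HTMLParser

class TextExtractor(HTMLParser):
    """Collect text fragments, ignoring anything inside <script> or <style>."""

    SKIP_TAGS = {"script", "style"}

    def __init__(self):
        super().__init__()
        self.fragments = []
        self._skip_depth = 0  # > 0 while inside a skipped tag

    def handle_starttag(self, tag, attrs):
        if tag in self.SKIP_TAGS:
            self._skip_depth += 1

    def handle_endtag(self, tag):
        if tag in self.SKIP_TAGS and self._skip_depth > 0:
            self._skip_depth -= 1

    def handle_data(self, data):
        # Keep non-blank text only when we are not inside a skipped tag
        if self._skip_depth == 0 and data.strip():
            self.fragments.append(data.strip())

def html_to_text(html):
    # Join fragments with a space so text from adjacent tags does not run together
    parser = TextExtractor()
    parser.feed(html)
    parser.close()
    return " ".join(parser.fragments)

print(html_to_text("<p>Link</p><p>Zelda</p><script>var x = 1;</script>"))  # Link Zelda
```

To use it in `read_html_files`, pass `file.read()` to `html_to_text` instead of building a `BeautifulSoup` object. Joining every fragment with a space is a trade-off: it keeps block-level text apart but also inserts a space where an inline tag splits a word.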

In [18]:
corpus = read_html_files("PragmatiCQA-sources/The Legend of Zelda")
print(f"Loaded {len(corpus)} documents. Will encode them below.")

Loaded 406 documents. Will encode them below.


In [19]:
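# Parameters for the retriever
max_characters = 10000  # truncate documents longer than this (>99th percentile of lengths)
topk_docs_to_retrieve = 5  # number of documents to retrieve per search query

# Apply the truncation so unusually long pages are cut to a manageable length before encoding
corpus = [doc[:max_characters] for doc in corpus]
search = dspy.retrievers.Embeddings(embedder=embedder, corpus=corpus, k=topk_docs_to_retrieve)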
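The comment on `max_characters` says the value should only truncate documents longer than the 99th percentile of the corpus. One way to check or derive that threshold is `statistics.quantiles` from the standard library. The helper `length_percentile` is hypothetical, and the toy documents stand in for the real `corpus`; with the notebook's data you would call `length_percentile(corpus)` before truncating.

```python
import statistics

def length_percentile(docs, pct=99):
    # statistics.quantiles(..., n=100) returns the 99 cut points between percentiles;
    # index pct - 1 is the pct-th percentile of the document lengths
    lengths = [len(doc) for doc in docs]
    return statistics.quantiles(lengths, n=100)[pct - 1]

# Toy corpus: 100 documents with lengths 1..100
toy_docs = ["x" * n for n in range(1, 101)]
threshold = length_percentile(toy_docs)
truncated = [doc[:int(threshold)] for doc in toy_docs]
print(threshold)  # close to 99.99 with the default "exclusive" method
```

`statistics.quantiles` needs at least two data points, and its default `method="exclusive"` interpolates between neighbouring lengths, so the threshold is generally not an exact document length.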



In [38]:
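# Alternative: a local model served by Ollama
# lm = dspy.LM('ollama_chat/devstral', api_base='http://localhost:11434', api_key='')

# Read the xAI API key from a local file (keep this file out of version control)
with open("xai_key.txt") as f:
    api_key = f.read().strip()

lm = dspy.LM('xai/grok-3-mini', api_key=api_key)
dspy.configure(lm=lm)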

In [39]:
class RAG(dspy.Module):
    def __init__(self):
        self.respond = dspy.ChainOfThought('context, question -> response')

    def forward(self, question):
        context = search(question).passages
        return self.respond(context=context, question=question)
    
rag = RAG()

In [40]:
answer = rag(question="What is the main plot of The Legend of Zelda?")  # Example query

print(answer.response)  # Print the response from the RAG model

The main plot of *The Legend of Zelda* revolves around a young hero named Link who must save Princess Zelda from the evil Ganon, the Prince of Darkness. Ganon has stolen the Triforce of Power and seeks the Triforce of Wisdom to conquer the kingdom of Hyrule. Zelda hides the eight fragments of the Triforce of Wisdom to prevent Ganon from obtaining it, and she sends her nursemaid Impa to find a brave warrior. Link embarks on a quest to collect these fragments, reassemble the Triforce, navigate treacherous dungeons, and ultimately defeat Ganon to restore peace to Hyrule.


In [41]:
q = 'What year did The Legend of Zelda come out?'

print(rag(question=q).response)



1986


In [43]:
import os
from functools import lru_cache

def load_topic_corpus(topic_name, sources_dir="PragmatiCQA-sources"):
    """Load and truncate the plain text of every HTML page for one PragmatiCQA topic."""
    topic_dir = os.path.join(sources_dir, topic_name)
    return [doc[:max_characters] for doc in read_html_files(topic_dir)]

@lru_cache(maxsize=None)
def get_topic_retriever(topic, top_k=5, sources_dir="PragmatiCQA-sources"):
    # Build the embedding index once per (topic, top_k, sources_dir) and reuse it;
    # rebuilding it on every call would re-encode the whole topic corpus each time
    corpus = load_topic_corpus(topic, sources_dir=sources_dir)
    return dspy.retrievers.Embeddings(embedder=embedder, corpus=corpus, k=top_k)

def retrieve_for_question(question, topic, top_k=5, sources_dir="PragmatiCQA-sources"):
    """Return the top_k passages for `question` from the given topic's corpus."""
    search = get_topic_retriever(topic, top_k=top_k, sources_dir=sources_dir)
    return search(question).passages
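A quick sanity check for `retrieve_for_question` is whether any retrieved passage contains a string you expect the answer to mention, such as `"1986"` for the release-year question above. The hypothetical helpers below do a case-insensitive substring check and average it over several questions; this is a crude proxy for retrieval recall, not a metric defined by DSPy or PragmatiCQA, and the toy passages stand in for real retriever output.

```python
def answer_in_passages(passages, expected):
    # True if any passage contains the expected string, ignoring case
    needle = expected.lower()
    return any(needle in passage.lower() for passage in passages)

def hit_rate(examples):
    # examples: list of (passages, expected) pairs; returns the fraction with a hit
    if not examples:
        return 0.0
    hits = sum(answer_in_passages(p, e) for p, e in examples)
    return hits / len(examples)

# Toy passages standing in for retrieve_for_question(...) output
passages = ["The Legend of Zelda was released in 1986.", "Link is the hero."]
print(answer_in_passages(passages, "1986"))  # True
print(hit_rate([(passages, "1986"), (passages, "Ganon")]))  # 0.5
```

With the notebook's data this would be `answer_in_passages(retrieve_for_question(q, "The Legend of Zelda"), "1986")`; whether it returns `True` depends on which pages are retrieved. Substring matching misses paraphrased answers, so treat a low hit rate as a prompt to inspect the passages rather than as a definitive score.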
