In [1]:
import os
from datasets import load_dataset
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid

In [2]:
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFacePipeline
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from transformers import pipeline

In [3]:
def build_rag_store(
        n: int = 1000,
        output_folder: str = "rag_triviaqa_store",
        embeddings_model_name: str = "mistralai/Mistral-Embed"
        ) -> None:

    if os.path.exists(output_folder):
        print(f"Vector store already exists at {output_folder}. Skipping build.")
        return

    # -----------------------
    # 1. Load TriviaQA (rc)
    # -----------------------
    print("Loading TriviaQA rc...")
    try:
        dataset = load_dataset("trivia_qa", "rc", split=f"validation[:{n}]", trust_remote_code=True)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    rag_docs = []

    # -----------------------
    # 2. Initialize text splitter
    # -----------------------
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=64,
        length_function=len
    )

    # -----------------------
    # 3. Extract RAG documents
    # -----------------------
    print("Extracting RAG documents...")

    for q_idx, row in enumerate(dataset):
        question_id = q_idx

        entity_pages = row.get("entity_pages", {})
        search_results = row.get("search_results", {})

        # --- Wikipedia pages ---
        ep_contexts = entity_pages.get("wiki_context", [])
        ep_titles = entity_pages.get("title", [])
        ep_urls = entity_pages.get("url", [])

        for i, text in enumerate(ep_contexts):
            if text and text.strip():
                doc = Document(
                    page_content=text,
                    metadata={
                        "id": str(uuid.uuid4()),
                        "question_id": question_id,
                        "title": ep_titles[i] if i < len(ep_titles) else "Unknown",
                        "url": ep_urls[i] if i < len(ep_urls) else "Unknown",
                        "source": "wiki"
                    }
                )
                # Split long pages into chunks
                chunks = splitter.split_documents([doc])
                for idx, chunk in enumerate(chunks):
                    chunk.metadata["chunk_index"] = idx
                    rag_docs.append(chunk)

        # --- Web search result snippets ---
        sr_contexts = search_results.get("search_context", [])
        sr_titles = search_results.get("title", [])
        sr_urls = search_results.get("url", [])

        for i, text in enumerate(sr_contexts):
            if text and text.strip():
                doc = Document(
                    page_content=text,
                    metadata={
                        "id": str(uuid.uuid4()),
                        "question_id": question_id,
                        "title": sr_titles[i] if i < len(sr_titles) else "Unknown",
                        "url": sr_urls[i] if i < len(sr_urls) else "Unknown",
                        "source": "web"
                    }
                )
                chunks = splitter.split_documents([doc])
                for idx, chunk in enumerate(chunks):
                    chunk.metadata["chunk_index"] = idx
                    rag_docs.append(chunk)

    print(f"Created {len(rag_docs)} RAG documents.")

    # -----------------------
    # 4. Build Vector Store
    # -----------------------
    print("Embedding and building FAISS index...")
    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

    if rag_docs:
        vectorstore = FAISS.from_documents(rag_docs, embeddings)
        vectorstore.save_local(output_folder)
        print(f"Saved RAG datastore to '{output_folder}'.")
    else:
        print("No documents found to index.")

In [9]:
class RAGPipeline:
    def __init__(
        self,
        model,
        tokenizer,
        embeddings_model_name,
        vectorstore_folder="rag_triviaqa_store",
        top_k=5,
        max_new_tokens=1024,
        temperature=0.1,
        do_sample=True,
        top_p=0.95
    ):
        self.top_k = top_k

        # Embeddings + FAISS
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embeddings_model_name
        )

        self.vectorstore = FAISS.load_local(
            vectorstore_folder,
            self.embeddings,
            allow_dangerous_deserialization=True
        )

        self.retriever = self.vectorstore.as_retriever(
            search_kwargs={"k": self.top_k}
        )

        # HF generation pipeline
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            return_full_text=False
        )

        self.llm = HuggingFacePipeline(pipeline=pipe)

        # Prompt
        self.prompt = ChatPromptTemplate.from_template("context: {context}\nquestion: {input}\nanswer:")

        # LCEL RAG chain
        self.chain = (
            {
                "context": self.retriever,
                "input": RunnablePassthrough()
            }
            | self.prompt
            | self.llm
        )

    def query(self, text: str):
        answer = self.chain.invoke(text)

        docs = self.retriever.invoke(text)

        return {
            "answer": str(answer),
            "sources": docs
        }

In [10]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Define the model name
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# For faster inference and reduced memory usage, we can quantize the model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Build the RAG store
# This will download the TriviaQA dataset and create a FAISS index.
# It might take a while the first time you run it.
# The 'n' parameter limits the number of validation samples to process.
build_rag_store(n=1000, embeddings_model_name="sentence-transformers/all-MiniLM-L6-v2")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'trivia_qa' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'trivia_qa' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading TriviaQA rc...


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Extracting RAG documents...
Created 284908 RAG documents.
Embedding and building FAISS index...


  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)


Saved RAG datastore to 'rag_triviaqa_store'.


In [11]:
# Initialize the RAG pipeline
rag_pipeline = RAGPipeline(
    model=model,
    tokenizer=tokenizer,
    embeddings_model_name="sentence-transformers/all-MiniLM-L6-v2",
    vectorstore_folder="rag_triviaqa_store",
    max_new_tokens=5, # Limit tokens to 5
    temperature=0.0, # For greedy decoding
    do_sample=False, # For greedy decoding
    top_p=1.0 # For greedy decoding
)

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [12]:
# Example query
query = "Who was the man behind The Chipmunks?"
result = rag_pipeline.query(query)

print("Answer:", result["answer"])
print("\nSources:")
for doc in result["sources"]:
    print(f"  - Title: {doc.metadata['title']}, URL: {doc.metadata['url']}")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Answer:  Dave Seville (J

Sources:
  - Title: Alvin and the Chipmunks (2007) - IMDb, URL: http://www.imdb.com/title/tt0952640/
  - Title: Alvin and the Chipmunks (2007) - IMDb, URL: http://www.imdb.com/title/tt0952640/
  - Title: Alvin and the Chipmunks (2007) - IMDb, URL: http://www.imdb.com/title/tt0952640/
  - Title: Alvin and the Chipmunks (2007) - IMDb, URL: http://www.imdb.com/title/tt0952640/
  - Title: Field Guide/Mammals/United States/Minnesota - Wikibooks ..., URL: https://en.wikibooks.org/wiki/Field_Guide/Mammals/United_States/Minnesota
