## Library Imports

In [None]:
%load_ext autoreload 
%autoreload 2
import os
import nest_asyncio
from IPython.display import display, Markdown, clear_output

# Patch asyncio so an event loop can be re-entered; the notebook kernel already
# runs one, and async-based clients used later would otherwise refuse to start.
nest_asyncio.apply()
# Clear whatever the setup lines above printed to this cell's output area.
clear_output()

In [None]:
# Service endpoints. Each value can be overridden through the environment and
# otherwise falls back to a local default.
# NOTE(review): the Ollama URL is read from OLLAMA_HOST and is used verbatim as a
# base URL, so it is expected to include the scheme — confirm for your deployment.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", 6333))
OLLAMA_BASE_URL = os.environ.get("OLLAMA_HOST", "http://localhost:11434")

In [None]:
def pretty_print(data):
    """Render ``data`` in the notebook output area.

    * ``str``  -> rendered as Markdown.
    * ``dict`` -> one Markdown line per entry, formatted ``**key:** value``.
    * anything else -> handed to ``display`` unchanged.
    """
    if isinstance(data, str):
        display(Markdown(data))
        return

    if not isinstance(data, dict):
        display(data)
        return

    for field_name in data:
        line = f"**{field_name}:** {data[field_name]}"
        display(Markdown(line))

## Traditional RAG

### Dataset Preparation

In [None]:
from datasets import load_dataset

# Step 1: Load the SQuAD dataset
dataset = load_dataset("squad")

# Step 2: Extract the unique contexts from the training split.
# Many questions share the same passage, so the raw list has duplicates.
data = [item["context"] for item in dataset["train"]]
# De-duplicate while keeping first-seen order. A plain `set` would shuffle the
# passages differently on every kernel start (string hashing is randomised), so
# index-based lookups such as `contexts[15]` would not be reproducible.
texts = list(dict.fromkeys(data))

### Embed Dataset

In [None]:
from llama_index.embeddings.ollama import OllamaEmbedding
from tqdm import tqdm


def batch_iterate(lst, batch_size):
    """Yield successive slices of ``lst`` holding at most ``batch_size`` items.

    Args:
        lst: Any sequence supporting ``len()`` and slicing (list, tuple, str, ...).
        batch_size: Maximum number of items per yielded slice; must be positive.

    Yields:
        Consecutive slices of ``lst``; the final slice may be shorter.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking range() error.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    for start in range(0, len(lst), batch_size):
        yield lst[start : start + batch_size]


class EmbedData:
    """Embed a corpus of text passages with an Ollama-served embedding model.

    After :meth:`embed` has run, ``contexts[i]`` and ``embeddings[i]`` refer to
    the same passage.

    Attributes:
        embed_model_name: Name of the embedding model requested from Ollama.
        batch_size: Number of passages sent to the model per request.
        embed_model: The ``OllamaEmbedding`` client built from the settings above.
        contexts: Passages given to the last :meth:`embed` call.
        embeddings: One embedding vector per entry in ``contexts``.
    """

    def __init__(
        self,
        embed_model_name="hf.co/Qwen/Qwen3-Embedding-0.6B-GGUF:Q8_0",
        batch_size=32,
    ):
        self.embed_model_name = embed_model_name
        self.batch_size = batch_size
        self.embed_model = self._load_embed_model()
        # Both start empty so the attributes exist before embed() is called.
        self.contexts = []
        self.embeddings = []

    def _load_embed_model(self):
        """Build the embedding client pointed at ``OLLAMA_BASE_URL``."""
        return OllamaEmbedding(
            model_name=self.embed_model_name,
            base_url=OLLAMA_BASE_URL,
        )

    def generate_embeddings(self, text):
        """Return the embeddings for one batch of passages.

        Args:
            text: List of passages to embed.

        Returns:
            The list of embedding vectors produced by the model for ``text``.
        """
        # Deliberately does NOT assign to self.embeddings: doing so replaced the
        # accumulated results with the current batch on every call.
        return self.embed_model.get_text_embedding_batch(texts=text)

    def embed(self, contexts):
        """Embed every passage in ``contexts``, batch by batch.

        Results are stored on ``self.contexts`` / ``self.embeddings``; calling
        this again replaces the previous results instead of appending to them.

        Args:
            contexts: Iterable of passages to embed.
        """
        passages = list(contexts)
        # Ceiling division so a trailing partial batch is counted by the bar.
        total_batches = -(-len(passages) // self.batch_size)

        collected = []
        for batch in tqdm(
            batch_iterate(passages, self.batch_size),
            total=total_batches,
            desc="Embedding",
        ):
            collected.extend(self.generate_embeddings(batch))

        # Publish both together so they can never be observed out of step.
        self.contexts = passages
        self.embeddings = collected

In [None]:
# Embed every unique context, sending the passages to the model 32 at a time.
# This walks the whole corpus, so it is the slow cell of this section.
batch_size = 32
embeddata = EmbedData(batch_size=batch_size)
embeddata.embed(texts)

In [None]:
# Sanity check: show the first passage and a short summary of its embedding.
# Separate statements are used (rather than a tuple expression) so the cell does
# not echo `(None, None)`, and only a preview of the vector is shown.
first_embedding = embeddata.embeddings[0]
pretty_print(embeddata.contexts[0])
pretty_print(
    {
        "embedding_dim": len(first_embedding),
        "first_values": first_embedding[:8],
    }
)

### Vector DB

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.http import models


class QdrantVDB:
    """Small wrapper around a Qdrant collection used to store passage embeddings.

    Typical order of calls: ``define_client()`` -> ``create_collection()`` ->
    ``ingest_data(embeddata)``. ``define_client()`` must be called first because
    the other methods use ``self.client``.

    Attributes:
        collection_name: Name of the Qdrant collection to create and fill.
        vector_dim: Length of the vectors stored in the collection. It has to
            equal the length of the embeddings that are ingested.
        batch_size: Number of (context, embedding) pairs uploaded per request.
    """

    def __init__(self, collection_name, vector_dim=768, batch_size=512):
        self.collection_name = collection_name
        self.vector_dim = vector_dim
        self.batch_size = batch_size

    def define_client(self):
        """Create the HTTP client for the server at ``QDRANT_HOST:QDRANT_PORT``."""
        self.client = QdrantClient(
            url=f"http://{QDRANT_HOST}:{QDRANT_PORT}",
            prefer_grpc=False,
        )

    def create_collection(self):
        """Create the collection unless one with the same name already exists.

        An existing collection is left untouched, including its vector size.
        """
        if self.client.collection_exists(self.collection_name):
            return

        self.client.create_collection(
            collection_name=self.collection_name,
            # Similarity is the dot product, and vectors are kept on disk rather
            # than in memory to limit memory usage for large datasets.
            # NOTE(review): dot product only ranks like cosine similarity when
            # the embeddings are normalised — confirm for the chosen model.
            vectors_config=models.VectorParams(
                size=self.vector_dim, distance=models.Distance.DOT, on_disk=True
            ),
            # indexing_threshold=0 switches index building off while the bulk
            # upload runs; ingest_data() raises it again once the upload is done.
            optimizers_config=models.OptimizersConfigDiff(
                default_segment_number=9, indexing_threshold=0
            ),
        )

    def ingest_data(self, embeddata):
        """Upload every (context, embedding) pair held by ``embeddata``.

        Args:
            embeddata: Object exposing parallel ``contexts`` and ``embeddings``
                sequences. ``zip`` pairs them, so extra items in the longer of
                the two are ignored.
        """
        # Materialise the pairs so they can be counted and sliced into batches.
        paired_data = list(zip(embeddata.contexts, embeddata.embeddings))
        # Ceiling division: a trailing partial batch still counts as one step.
        total_batches = -(-len(paired_data) // self.batch_size)

        for batch in tqdm(
            batch_iterate(paired_data, self.batch_size),
            total=total_batches,
            desc="Ingesting in batches",
        ):
            # Split the batch back into parallel tuples of contexts and vectors.
            batch_contexts, batch_embeddings = zip(*batch)

            # Each vector is stored with a payload carrying its original text,
            # which is what retrieval later reads back.
            self.client.upload_collection(
                collection_name=self.collection_name,
                vectors=batch_embeddings,
                payload=[{"context": context} for context in batch_contexts],
            )

        # Re-enable index building now that the bulk upload is finished.
        # `optimizers_config` is the same keyword create_collection() uses.
        self.client.update_collection(
            collection_name=self.collection_name,
            optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),
        )

In [None]:
# Size the collection from the embeddings that were actually produced instead of
# relying on the class default, so the stored vector size always matches the
# embedding model in use.
# Note: create_collection() skips creation when "squad_collection" already
# exists, so a collection left over from an earlier run keeps its old size.
database = QdrantVDB(
    collection_name="squad_collection",
    vector_dim=len(embeddata.embeddings[0]),
)
database.define_client()
database.create_collection()
database.ingest_data(embeddata)
clear_output()

### Retrieval

- Encapsulate the logic for searching the vector db using a query (of string type)
- Using the embedding model and the vector db, we can retrieve the most relevant contexts for the query.

In [None]:
import time


class Retriever:
    """Embed a text query and fetch the closest stored passages from Qdrant."""

    def __init__(self, vector_db, embeddata):
        """
        Args:
            vector_db: Object exposing ``client`` (with ``query_points``) and
                ``collection_name``.
            embeddata: Object exposing ``embed_model.get_query_embedding``.
        """
        self.vector_db = vector_db
        self.embeddata = embeddata

    def search(self, query, limit=10):
        """Return the stored points most similar to ``query``.

        Args:
            query: The question or search text.
            limit: Maximum number of points to return. The default of 10 is the
                client's own default, so existing callers see no change.

        Returns:
            The response object from ``client.query_points``; the matches are
            available on its ``points`` attribute.
        """
        # Embed the query with the same model that embedded the corpus.
        query_embedding = self.embeddata.embed_model.get_query_embedding(query)

        # Time only the vector search itself. perf_counter() is monotonic, so
        # the measurement cannot be skewed by system clock adjustments.
        start_time = time.perf_counter()
        result = self.vector_db.client.query_points(
            collection_name=self.vector_db.collection_name,
            query=query_embedding,
            limit=limit,
            search_params=models.SearchParams(
                # ignore=True searches the original vectors rather than any
                # quantized copies. rescore/oversampling describe how a
                # quantized search would be refined and are kept as-is.
                # NOTE(review): with ignore=True they presumably have no
                # effect — confirm against the Qdrant documentation.
                quantization=models.QuantizationSearchParams(
                    ignore=True, rescore=True, oversampling=2.0
                )
            ),
            # Generous per-request timeout for slow, disk-backed collections.
            timeout=1000,
        )
        elapsed_time = time.perf_counter() - start_time

        print(f"Execution time for search: {elapsed_time:.2f} seconds")

        return result

In [None]:
# Smoke-test retrieval with a throwaway query; the matches are shown in the next cell.
result = Retriever(database, embeddata).search("Sample Query")

In [None]:
# Show the text of every retrieved passage.
# The loop variable is named `point` so it does not overwrite the `data` list
# created during dataset preparation.
for point in result.points:
    pretty_print(point.payload["context"])

### RAG

In [None]:
from llama_index.llms.ollama import Ollama


class RAG:
    """Retrieval-augmented generation: retrieve passages, then ask the LLM.

    Attributes:
        llm_name: Name of the Ollama model used for answering.
        llm: The ``Ollama`` client built for ``llm_name``.
        retriever: Object whose ``search(query)`` returns a result with ``points``.
        qa_prompt_tmpl_str: Prompt template with ``{context}`` and ``{query}``
            placeholders.
    """

    def __init__(self, retriever, llm_name="phi3:3.8b"):
        self.llm_name = llm_name
        self.llm = self._setup_llm()
        self.retriever = retriever
        # Built from adjacent string literals so the source indentation does not
        # leak into the prompt that is sent to the model.
        self.qa_prompt_tmpl_str = (
            "Context information is below.\n"
            "---------------------\n"
            "{context}\n"
            "---------------------\n"
            "\n"
            "Given the context information above I want you to think step by "
            "step to answer the query in a crisp manner. In case you don't know "
            "the answer, say 'I don't know!'\n"
            "\n"
            "---------------------\n"
            "Query: {query}\n"
            "---------------------\n"
            "Answer: "
        )

    def _setup_llm(self):
        """Build the LLM client pointed at ``OLLAMA_BASE_URL``."""
        return Ollama(model=self.llm_name, base_url=OLLAMA_BASE_URL)

    def generate_context(self, query):
        """Retrieve the passages relevant to ``query`` as one block of text.

        Args:
            query: The user question.

        Returns:
            The retrieved passages joined by a ``---`` separator, or the string
            ``"No relevant context found."`` when nothing was retrieved.
        """
        search_result = self.retriever.search(query)
        if not search_result.points:
            return "No relevant context found."

        # Each point's payload carries the original passage under "context".
        passages = [point.payload["context"] for point in search_result.points]
        return "\n\n --- \n\n".join(passages)

    def query(self, query):
        """Answer ``query`` using retrieved context.

        Retrieves context, fills the prompt template, and sends the prompt to
        the LLM.

        Args:
            query: The user question.

        Returns:
            A ``(context, answer)`` tuple: the context text that was put into
            the prompt and the text of the LLM's completion.
        """
        context = self.generate_context(query)
        prompt = self.qa_prompt_tmpl_str.format(context=context, query=query)
        response = self.llm.complete(prompt)
        return context, response.text

### Using RAG

In [None]:
# Wire the retriever and the LLM together into the RAG pipeline.
retriever = Retriever(database, embeddata)
rag = RAG(retriever, llm_name="phi3:3.8b")


# Look at one stored passage so that a question can be written against it.
# Which passage sits at index 15 depends on the order of `texts`.
pretty_print(embeddata.contexts[15])

In [None]:
# Ask a question and show both the retrieved context and the generated answer.
query = "The premium and VIP services in Airports are reserved for which type of passengers?"
context, response = rag.query(query)

# Separate statements (rather than a tuple expression) so the cell does not
# echo `(None, None)` after the rendered output.
pretty_print(context)
pretty_print(response)