# Exploring a RAG System with Sentence Transformers, ChromaDB, and Ollama

This notebook demonstrates the implementation of a **Retrieval-Augmented Generation (RAG) system** using:
- **Sentence Transformers** for embedding text,
- **ChromaDB** as the vector database, and
- **Ollama** as the LLM for generating responses.

The primary goal is to understand how these components interact to enhance information retrieval and generation. 

> **Note:** This notebook is for **educational purposes** only and is not intended for production use.

In [None]:
# Install the required packages. This notebook assumes the Ollama API is running on UBELIX.
# Before running the notebook, open a terminal in another tab and start the Ollama server with:
# singularity exec --nv /storage/research/dsl_shared/solutions/singularity/ollama.sif ollama serve &
!pip install pdfplumber==0.11.4
!pip install ollama==0.3.3
!pip install nltk==3.9.1
!pip install sentence-transformers==3.2.1
!pip install chromadb==0.5.18

In [None]:
import os
import pdfplumber
import nltk
import ollama
import chromadb
from tqdm import tqdm
from sentence_transformers import SentenceTransformer


# Fetch the 'punkt' and 'punkt_tab' NLTK resources without progress output
# (presumably the data nltk.sent_tokenize relies on — verify against the NLTK version in use).
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

In [None]:
# Ask the Ollama server which models it has; as the last expression of the cell the result is displayed.
ollama.list()

In [None]:
# Download the "llama3.2" model through Ollama so it is available for generation.
ollama.pull("llama3.2")

In [None]:
# Notebook-wide configuration. Adjust these values before running the cells that follow.

model_name = "Lajavaness/bilingual-embedding-large"  # embedding model; choose any model you prefer

vector_db = "chromaDB"  # allowed values: ['chromaDB', 'FAISS']. Only ChromaDB works for now

collection_name = "dsl_embeddings3"  # name of the ChromaDB collection the chunks are stored in

raw_db = "/storage/homefs/aa22x177/dsl_data"  # root directory where the raw documents are stored

# Language passed to the sentence tokenizer. Supported languages: ['czech', 'danish', 'dutch', 'english',
# 'estonian', 'finnish', 'french', 'german', 'greek', 'italian', 'norwegian', 'polish', 'portuguese',
# 'russian', 'slovene', 'spanish', 'swedish', 'turkish']
data_language = "english"

db_directory = os.path.join(os.path.expanduser('~'), '.db')  # default; change it to where you want to store the vector DB

chunk_size = 20  # number of sentences per chunk

llm_model = 'llama3.2:latest'  # select any model available on the Ollama site https://ollama.com/search


# Prompt template; it is filled with str.format using the placeholders {data} and {query}.
prompt = """
You are a helpful, polite assistant that works at the Data Science Lab (DSL). Given the following data about DSL: \n {data} \n
and the following query: \n{query}\n
generate a response. If there is no relevant information in the data, state that you don't have an answer and advise contacting DSL through info.dsl@unibe.ch or support.dsl@unibe.ch
"""


In [None]:
# Helper functions that read the raw data files (txt, pdf) into strings and split them into chunks for embedding

def get_file_paths(root_dir: str, file_extensions: list[str]) -> list[str]:
    """
    Retrieves the paths of all files with the specified extensions in the given root directory and its subdirectories.

    Args:
        root_dir (str): The root directory to search for files.
        file_extensions (list[str]): The file extensions to retrieve, with or without a leading dot.
            For example, ["txt", "pdf"] or [".txt", ".pdf"].

    Returns:
        list[str]: The paths of all matching files, sorted so the result does not depend on the order
        in which the file system lists directory entries. Empty when nothing matches.
    """
    # Normalise every extension to the ".ext" form so that "txt" and ".txt" are treated alike.
    suffixes = tuple(f".{ext.lstrip('.')}" for ext in file_extensions)
    if not suffixes:
        return []

    # str.endswith accepts a tuple, so one call checks all the requested extensions.
    return sorted(
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith(suffixes)
    )



def read_text_file(file_path: str) -> str:
    """
    Loads a UTF-8 encoded text file and hands back everything in it as one string.

    Args:
        file_path (str): Location of the .txt file to load.

    Returns:
        str: The full text of the file.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()


def read_pdf_file(file_path: str) -> str:
    """
    Extracts the text of every page of a PDF and returns it as one newline-separated string.

    Args:
        file_path (str): Location of the PDF file to load.

    Returns:
        str: The text of all pages that yielded any text, joined with newlines.
    """
    with pdfplumber.open(file_path) as document:
        extracted = [page.extract_text() for page in document.pages]

    # Pages for which no text came back (None or "") are left out of the result.
    return "\n".join(page_text for page_text in extracted if page_text)


def split_text_into_sentences(text: str, language: str) -> list[str]:
    """
    Breaks a text into its sentences with NLTK's sentence tokenizer.

    Args:
        text (str): The text to break up.
        language (str): The language name handed to the tokenizer.

    Returns:
        list[str]: The sentences found in the text.
    """
    return nltk.sent_tokenize(text, language=language)


def chunk_sentences(sentences: list[str], chunk_size: int) -> list[str]:
    """
    Groups a list of sentences into chunks, each containing up to `chunk_size` sentences.

    The sentences of a chunk are joined with a single space. The last chunk may hold
    fewer than `chunk_size` sentences.

    Args:
        sentences (list[str]): A list of sentences.
        chunk_size (int): The number of sentences per chunk. Must be at least 1.

    Returns:
        list[str]: A list of text chunks, each containing up to `chunk_size` sentences.
        Empty when `sentences` is empty.

    Raises:
        ValueError: If `chunk_size` is smaller than 1.
    """
    # A negative step would make range() yield nothing and silently drop all the text,
    # so reject non-positive sizes up front.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    return [
        " ".join(sentences[start:start + chunk_size])
        for start in range(0, len(sentences), chunk_size)
    ]

In [None]:
# Initialize ChromaDB client
client = chromadb.PersistentClient(path=db_directory)


print("\n--- Embedding and Storing Documents in ChromaDB ---")
print(f"Embedding Model: {model_name}")
print(f"Chunk Size (sentences per chunk): {chunk_size}")
print(f"Raw Data Directory: {raw_db}")
print(f"Vector Database Directory: {db_directory}\n")
print(f"Vector Database is: {vector_db}\n")

# Step 1: Load documents (txt and pdf)
file_paths = get_file_paths(raw_db, ["txt", "pdf"])
print(f"Found {len(file_paths)} files to process.\n")

# Initialize embedding model
embedding_model = SentenceTransformer(model_name, trust_remote_code=True)
# Typically 512 for older models. Newer ones have larger input size.
# encode() already truncates its input to this length, so it is not passed along below.
max_seq_length = embedding_model.max_seq_length

# Create or retrieve the collection in ChromaDB
collection = client.get_or_create_collection(collection_name)

stored_files = 0   # files that produced at least one chunk
stored_chunks = 0  # chunks written to the collection

for file_path in tqdm(file_paths, desc="Processing documents"):
    # Step 2: Read content based on file type
    if file_path.endswith('.txt'):
        text = read_text_file(file_path)
    elif file_path.endswith('.pdf'):
        text = read_pdf_file(file_path)
    else:
        print(f"Unsupported file type: {file_path}")
        continue

    # Step 3: Split text into sentences
    sentences = split_text_into_sentences(text, data_language)

    # Step 4: Chunk sentences into groups
    chunks = chunk_sentences(sentences, chunk_size)
    if not chunks:
        # Nothing to embed (e.g. an empty file or a PDF without extractable text)
        continue

    file_name = os.path.basename(file_path)
    # The IDs are built from the path relative to the raw data directory: files are collected
    # recursively, so the bare file name is not unique across subdirectories. For files located
    # directly in raw_db the relative path equals the file name.
    relative_path = os.path.relpath(file_path, raw_db)

    # Step 5: Embed all chunks of the file in one call
    embeddings = embedding_model.encode(chunks)

    # Step 6: Store the chunks. upsert (rather than add) lets a re-run of this cell refresh
    # chunks that are already stored under the same ID.
    collection.upsert(
        documents=chunks,
        embeddings=embeddings.tolist(),
        metadatas=[{"file_name": file_name, "chunk_id": i} for i in range(len(chunks))],
        ids=[f"{relative_path}_chunk_{i}" for i in range(len(chunks))],
    )

    stored_files += 1
    stored_chunks += len(chunks)

print("\n--- Embedding and Storage Complete ---")
print(f"Stored {stored_chunks} chunks from {stored_files} documents in ChromaDB.\n")

In [None]:
class ChromaRetriever:
    """
    A class for retrieving documents from a ChromaDB collection based on semantic similarity using embeddings.

    The constructor loads the embedding model, opens the persistent ChromaDB client and
    looks up the (already existing) collection.
    """
    def __init__(self, embedding_model: str, db_path: str, db_collection: str, n_results: int) -> None:
        """
        Initialize the retriever.

        Args:
            embedding_model: Name of the SentenceTransformer model used to embed queries.
            db_path: Directory of the persistent ChromaDB database.
            db_collection: Name of the collection to query.
            n_results: Number of results requested per query.
        """
        self.embedding_model = embedding_model
        self.db_path = db_path
        self.db_collection = db_collection
        self.n_results = n_results
        self.model = SentenceTransformer(self.embedding_model, trust_remote_code=True)
        self.client = chromadb.PersistentClient(path=self.db_path)
        self.collection = self.client.get_collection(name=self.db_collection)

    def retrieve(self, query: str):
        """
        Embeds the query and retrieves relevant documents from the collection.

        Args:
            query: The text to search for.

        Returns:
            The result of the collection query, or None if anything went wrong
            (the error is printed rather than raised).
        """
        try:
            embedded_query = self.model.encode(query)
            results = self.collection.query(
                query_embeddings=[embedded_query],
                n_results=self.n_results
            )
            return results
        except Exception as e:
            # Best effort: report the problem and let the caller deal with a None result.
            print(f"An error occurred during retrieval: {e}")
            return None

    def format_results_for_prompt(self, results):
        """
        Formats the retrieval results into a string suitable for the Responder's prompt.

        Args:
            results: The dictionary returned by the retrieve method (or None).

        Returns:
            A formatted string containing the retrieved data, or "No relevant data found."
            when there are no results or the query matched no documents.
        """
        if not results:
            return "No relevant data found."

        # Results are nested per query; only one query is sent, hence index 0.
        documents = results['documents'][0]
        metadatas = results['metadatas'][0]
        if not documents:
            return "No relevant data found."

        sections = []
        for idx, (doc, metadata) in enumerate(zip(documents, metadatas), start=1):
            metadata = metadata or {}  # tolerate entries stored without metadata
            chunk_id = metadata.get('chunk_id', 'N/A')
            file_name = metadata.get('file_name', 'N/A')
            sections.append(
                f"Document {idx}:\n"
                f"Document ID: {chunk_id}\n"
                f"File Name: {file_name}\n"
                f"Content:\n{doc}\n"
                + "-" * 80 + "\n"
            )

        return "".join(sections)

In [None]:
class Responder:
    """
    A class to generate responses using the Ollama LLM within a RAG framework.

    The prompt is built once in the constructor from the template, the retrieved data
    and the query; the generate/stream methods then send it to the model.
    """

    def __init__(self, data: str, model: str, prompt_template: str, query: str) -> None:
        """
        Initialize the Responder instance.

        Args:
            data: The output from the retriever to be added to the prompt
            model: The name of the LLM model to use.
            prompt_template: The template string for the prompt. It is filled with
                str.format using the placeholders {data} and {query}.
            query: The user's query.
        """
        self.data = data
        self.model = model
        self.prompt_template = prompt_template
        self.query = query

        self.prompt = prompt_template.format(data=self.data, query=self.query)

    def generate_response(self) -> str:
        """
        Generate a response based on the query and data.

        Returns:
            The response generated by the LLM.

        Raises:
            ValueError: If the model output lacks the 'response' key, or the model
                could not be found when trying to download it.
            RuntimeError: If listing/downloading the model or the generation fails.
        """
        self._check_model()
        try:
            model_output = ollama.generate(model=self.model, prompt=self.prompt)
            return model_output['response']
        except KeyError as e:
            raise ValueError(f"Response does not contain expected key: {e}") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred during response generation: {e}") from e

    def stream_response(self):
        """
        Stream a response based on the query and data for a chatbot environment.

        The pieces of the response are printed as they arrive.

        Returns:
            An empty string (the text itself is only printed).

        Raises:
            ValueError: If a streamed piece lacks the 'response' key, or the model
                could not be found when trying to download it.
            RuntimeError: If listing/downloading the model or the generation fails.
        """
        self._check_model()
        try:
            response_generator = ollama.generate(model=self.model, prompt=self.prompt, stream=True)

            for chunk in response_generator:
                print(chunk['response'], end='', flush=True)
            print("\n")
            return ""

        except KeyError as e:
            raise ValueError(f"Response does not contain expected key: {e}") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred during response generation: {e}") from e

    def stream_response_chunks(self):
        """
        Returns a generator that yields chunks of the response text.

        Because this is a generator, the model check and the request only happen once
        the first chunk is requested.

        Raises:
            ValueError: If a streamed piece lacks the 'response' key, or the model
                could not be found when trying to download it.
            RuntimeError: If listing/downloading the model or the generation fails.
        """
        self._check_model()
        try:
            response_generator = ollama.generate(model=self.model, prompt=self.prompt, stream=True)
            for chunk in response_generator:
                yield chunk['response']
        except KeyError as e:
            raise ValueError(f"Response does not contain expected key: {e}") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred during response generation: {e}") from e

    @staticmethod
    def _is_installed(model: str, installed_names: list[str]) -> bool:
        """
        Helper function to tell whether `model` is among the installed model names.

        A name without a tag (e.g. "llama3.2") also matches its ":latest" variant,
        which is how Ollama resolves untagged names.
        """
        if model in installed_names:
            return True
        # Only look for a tag after the last "/" so that a "host:port/" prefix is not mistaken for one.
        has_tag = ":" in model.rsplit("/", 1)[-1]
        return not has_tag and f"{model}:latest" in installed_names

    def _check_model(self):
        """
        Helper function to check if the specified model is available. If not, attempt to download it.

        Raises:
            RuntimeError: If the list of models cannot be retrieved or the download fails.
            ValueError: If Ollama answers the download request with a ResponseError.
        """
        try:
            models = ollama.list()['models']
            model_names = [model['name'] for model in models]
        except Exception as e:
            raise RuntimeError(f"Failed to retrieve the list of models: {e}") from e

        if self._is_installed(self.model, model_names):
            return

        print(f"Model '{self.model}' is not downloaded. Attempting to download...")
        try:
            ollama.pull(self.model)
            print(f"Successfully downloaded model '{self.model}'.")
        except ollama.ResponseError as e:
            raise ValueError(f"Model '{self.model}' does not exist in the Ollama repository. Please check the model name.") from e
        except Exception as e:
            raise RuntimeError(f"An error occurred while downloading the model '{self.model}': {e}") from e

In [None]:
# Let's try the retriever by itself first

results_numbers = 5  # number of results requested per query

retriever = ChromaRetriever(embedding_model=model_name,
                            db_path=db_directory,
                            db_collection=collection_name,
                            n_results=results_numbers)

while True:
    query = input("Type a query to search the DB. Type 'quit' to exit:  ")

    if query.strip().lower() == 'quit':
        break

    results = retriever.retrieve(query)
    if results is None:
        # retrieve() reports failures by printing them and returning None; ask for the next query
        continue

    # Print out the results
    print("\n--- Query Results ---\n")
    hits = zip(results['documents'][0], results['metadatas'][0], results['distances'][0])
    for idx, (doc, metadata, distance) in enumerate(hits):
        print(f"Result {idx + 1}:")
        print(f"Document ID: {metadata.get('chunk_id', 'N/A')}")
        print(f"File Name: {metadata.get('file_name', 'N/A')}")
        print(f"Distance: {distance}")
        print(f"Content:\n{doc}\n")
        print("-" * 80)


In [None]:
# Let's try the whole pipeline: retrieve the relevant chunks, then let the LLM answer with them

# Build the retriever once, outside the loop: constructing it loads the embedding model
# and opens the database, which does not need to be repeated for every question.
retriever = ChromaRetriever(embedding_model=model_name,
                            db_path=db_directory,
                            db_collection=collection_name,
                            n_results=5)

while True:
    user_query = input("Ask a question. Type quit to exit:  ")
    if user_query.strip().lower() == "quit":
        break

    print("Looking the DB for relevant information .......")
    # get the data for the RAG and put it in str format
    # (format_results_for_prompt also copes with a failed retrieval, i.e. None)
    search_results = retriever.retrieve(user_query)
    formated_result = retriever.format_results_for_prompt(search_results)

    responder = Responder(data=formated_result, model=llm_model, prompt_template=prompt, query=user_query)
    responder.stream_response()