# RAG Retrieval

In [2]:
import os
from dotenv import load_dotenv
from openai import OpenAI
import chromadb
import chromadb.utils.embedding_functions as embedding_functions

In [3]:
class SearchEngine:
    """
    A search engine over a persisted ChromaDB collection.

    It can use a language model to extract key words from a query and then run
    keyword-filtered ("hybrid") searches alongside a plain similarity search.

    Attributes:
        extra_exembedding (bool): Whether the collection is opened with an OpenAI embedding function.
        enable_logging (bool): Whether progress messages are printed.
        collection: The ChromaDB collection queried by the search methods.
        llm (OpenAI): The language model client used for key-word extraction.
        llm_model (str): The name of the language model.
    """

    def __init__(self, db_path='db', collection_name='demo', extra_exembedding=True, enable_logging=True):
        """
        Initializes the SearchEngine with a database path and collection name.

        Configuration is read from environment variables after `load_dotenv()`:
        OPENAI_API_KEY (required), OPENAI_BASE_URL (optional),
        OPENAI_EMBEDDING_NAME (default "text-embedding-3-small") and
        OPENAI_MODEL_NAME (default "gpt-4o-mini").

        Args:
            db_path (str): The path to the persistent ChromaDB directory.
            collection_name (str): The name of an existing collection in the database.
            extra_exembedding (bool): If True, open the collection with an OpenAI embedding
                function built from the configured embedding model.
            enable_logging (bool): If True, print progress messages.

        Raises:
            ValueError: If `db_path` does not exist, the API key is missing, or the
                collection cannot be opened. In the last case the underlying error is chained.
        """
        if not os.path.exists(db_path):
            raise ValueError(f"Database path '{db_path}' does not exist.")

        self.extra_exembedding = extra_exembedding
        self.enable_logging = enable_logging

        load_dotenv()
        api_key = os.environ.get("OPENAI_API_KEY")
        base_url = os.environ.get("OPENAI_BASE_URL")
        embedding_model = os.environ.get("OPENAI_EMBEDDING_NAME", "text-embedding-3-small")

        if not api_key:
            raise ValueError("API key must be set in environment variables.")

        client = chromadb.PersistentClient(path=db_path)

        try:
            if self.extra_exembedding:
                ef_kwargs = {"api_key": api_key, "model_name": embedding_model}
                # Only override the endpoint when a custom base URL is configured.
                if base_url:
                    ef_kwargs["api_base"] = base_url
                openai_ef = embedding_functions.OpenAIEmbeddingFunction(**ef_kwargs)
                self.collection = client.get_collection(name=collection_name, embedding_function=openai_ef)
            else:
                self.collection = client.get_collection(name=collection_name)
        except Exception as e:
            # Chain the cause: a misconfigured embedding function or missing dependency
            # would otherwise be indistinguishable from a missing collection.
            raise ValueError(
                f"Collection '{collection_name}' does not exist in database at path '{db_path}'."
            ) from e

        if self.enable_logging:
            print(f"\nConnected to collection {collection_name} in database {db_path} for searching.")

        self.llm = OpenAI(api_key=api_key, base_url=base_url) if base_url else OpenAI(api_key=api_key)
        self.llm_model = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini")

    def search(self, query, n_results=1):
        """
        Searches the collection with the query text, without any document filter.

        Args:
            query (str): The search query.
            n_results (int): The number of results to return.

        Returns:
            dict: The Chroma query result ('ids', 'distances', 'metadatas', 'documents', ...),
            where each value holds one inner list for the single query text.
        """
        return self.collection.query(
            query_texts=[query],
            n_results=n_results
        )

    def hybrid_search(self, query, text_results=3, semantic_results=3):
        """
        Performs a hybrid search that merges keyword-filtered recalls with a plain search.

        Three queries may run:
          * text recall: filtered to documents containing one LLM-extracted key word;
          * semantic recall: filtered to documents containing any of several
            LLM-expanded (bilingual) key words;
          * a plain `search` for `text_results + semantic_results` results.
        A filtered recall is skipped when no usable key word could be extracted.

        Args:
            query (str): The search query.
            text_results (int): The number of text-recall results to request.
            semantic_results (int): The number of semantic-recall results to request.

        Returns:
            dict: The combined search results, de-duplicated by id (see `combine_objects`).

        Raises:
            ValueError: If `text_results + semantic_results` is less than 1.
        """
        if text_results + semantic_results < 1:
            raise ValueError("The sum of text_results and semantic_results must be greater than 0.")

        text_recall = None
        semantic_recall = None

        if text_results > 0:
            simple_key_word = self._extract_simple_key_word_from_query(query)

            if self.enable_logging:
                print(f"Simple key word for text recall: {simple_key_word}")

            # The extractor returns None on failure; Chroma rejects a None/empty $contains operand.
            if simple_key_word:
                text_recall = self.collection.query(
                    query_texts=[query],
                    n_results=text_results,
                    where_document={"$contains": simple_key_word}
                )

        if semantic_results > 0:
            key_words = self._extract_key_words_from_query(query)

            if self.enable_logging:
                print(f"Key words for semantic recall: {key_words}")

            where_document = self._build_contains_filter(key_words)
            if where_document is not None:
                semantic_recall = self.collection.query(
                    query_texts=[query],
                    n_results=semantic_results,
                    where_document=where_document
                )

        simple_search_recall = self.search(query, n_results=text_results + semantic_results)

        results = self.combine_objects(text_recall, semantic_recall)
        return self.combine_objects(results, simple_search_recall)

    @staticmethod
    def _build_contains_filter(key_words):
        """
        Builds a `where_document` filter matching documents that contain any of `key_words`.

        Empty strings and duplicates are dropped. Chroma requires `$or` to have at
        least two operands, so a single key word becomes a plain `$contains`.

        Returns:
            dict | None: The filter, or None when no usable key word remains.
        """
        unique_key_words = list(dict.fromkeys(kw for kw in (key_words or []) if kw))
        if not unique_key_words:
            return None
        if len(unique_key_words) == 1:
            return {"$contains": unique_key_words[0]}
        return {"$or": [{"$contains": kw} for kw in unique_key_words]}

    @staticmethod
    def _split_key_words(text):
        """
        Splits a comma-separated LLM answer into clean, de-duplicated key words.

        Both ASCII ',' and full-width '，' separators are accepted. Surrounding whitespace
        and quotes are stripped, and empty entries are dropped.
        """
        key_words = []
        for part in (text or "").replace("，", ",").split(","):
            key_word = part.strip().strip("'\"").strip()
            if key_word and key_word not in key_words:
                key_words.append(key_word)
        return key_words

    def _extract_simple_key_word_from_query(self, query):
        """
        Extracts a single key word from the query using the language model.

        Args:
            query (str): The search query.

        Returns:
            str | None: The key word, or None if the model call fails or returns nothing usable.
        """
        prompt = (
            f"Analyze the query below, identify a single key search term for recall.\n\n"
            f"{query}\n\n"
            f"Just return the final key word."
        )

        try:
            response = self.llm.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model,
                temperature=0.0,
                max_tokens=1024,
            )
            content = response.choices[0].message.content or ""
            # Models sometimes quote the answer; a quoted term would never match a $contains filter.
            key_word = content.strip().strip("'\"").strip()
            return key_word or None

        except Exception as e:
            print(f"Error extracting keywords: {e}")
            return None

    def _extract_key_words_from_query(self, query):
        """
        Extracts key words from the query using the language model, then asks it to
        translate them into both Chinese and English.

        Args:
            query (str): The search query.

        Returns:
            list: A list of key words. It is empty if extraction fails. If only the
            translation step fails or returns nothing, the untranslated key words are returned.
        """
        prompt = (
            f"Analyze the query below, identify key search terms for recall, and expand the query using synonyms, lemmatization and ensure to include both single words and phrases.\n\n"
            f"{query}\n\n"
            f"Just return the final key words in a comma-separated list like 'key_word1,key_word2,key_word3',at most five key words."
        )

        try:
            response = self.llm.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.llm_model,
                temperature=0.1,
                max_tokens=1024,
            )
            key_words = self._split_key_words(response.choices[0].message.content)
        except Exception as e:
            print(f"Error extracting keywords: {e}")
            return []

        # Nothing to translate; avoid a pointless second model call.
        if not key_words:
            return []

        prompt_translation = (
            f"Translate the key words below into both Chinese and English.\n\n"
            f"{key_words}\n\n"
            f"Just return the final translated key words in a comma-separated list like 'key_word1,key_word2,key_word3,关键词1,关键词2,关键词3'."
        )

        try:
            response = self.llm.chat.completions.create(
                messages=[{"role": "user", "content": prompt_translation}],
                model=self.llm_model,
                temperature=0.1,
                max_tokens=1024,
            )
            translated_key_words = self._split_key_words(response.choices[0].message.content)
        except Exception as e:
            # Best effort: keep the untranslated key words rather than discarding them.
            print(f"Error translating keywords: {e}")
            return key_words

        return translated_key_words or key_words

    @staticmethod
    def combine_objects(obj1, obj2):
        """
        Combines two search result objects without mutating either input.

        Entries of `obj2` whose id already appears in `obj1` (or earlier in `obj2`) are
        skipped. New entries are appended to the first inner list of each key.

        Args:
            obj1 (dict | None): The first search result object.
            obj2 (dict | None): The second search result object.

        Returns:
            dict | None: The other object unchanged if one input is None. Otherwise a new
            dict with the keys 'ids', 'distances', 'metadatas' and 'documents'.
        """
        if obj1 is None:
            return obj2
        if obj2 is None:
            return obj1

        keys = ['ids', 'distances', 'metadatas', 'documents']

        # Copy the inner lists so appending below does not alter the caller's obj1.
        result = {key: [list(inner) for inner in obj1[key]] for key in keys}
        seen_ids = set(result['ids'][0])

        for index, obj_id in enumerate(obj2['ids'][0]):
            if obj_id in seen_ids:
                continue
            seen_ids.add(obj_id)
            for key in keys:
                result[key][0].append(obj2[key][0][index])

        return result

In [4]:
# Open the persisted "demo" collection under ./db (requires OPENAI_API_KEY in the environment or a .env file).
se = SearchEngine(db_path='db', collection_name='demo')


Connected to collection demo in database db for searching.


In [5]:
# Plain `search` (no keyword filter) for an English query, asking for 3 results.
se.search("Show me how to meditate" , n_results=3)

{'ids': [['5f49075d270bc6b893f27710f247fb81d8ac82b49832a48e9ddd4c1d9c0a4fbf',
   '7cb5944269b1fa2985ea8e460141469f6ec31ffd19ea428062964fd2344dbfc0',
   'de8051914982b15f0a2449e2c14b1b3f957f41b594b2bb9c070b646d32dbecd7']],
 'distances': [[0.8542461668006612, 0.8923088599274618, 0.9893264975740468]],
 'metadatas': [[{'Header 1': 'The Transformative Power of Meditation: A Path to Inner Peace',
    'Header 2': 'Getting Started with Meditation',
    'Header 3': 'Basic Steps to Meditate',
    'file': 'doc/example_doc.md'},
   {'Header 1': 'The Transformative Power of Meditation: A Path to Inner Peace',
    'Header 2': 'A Simple Meditation Practice: 10-Minute Guided Session',
    'file': 'doc/example_doc.md'},
   {'Header 1': 'The Transformative Power of Meditation: A Path to Inner Peace',
    'Header 2': 'Getting Started with Meditation',
    'Header 3': 'Creating a Meditation Space',
    'file': 'doc/example_doc.md'}]],
 'embeddings': None,
 'documents': [['### Basic Steps to Meditate  \n1.

In [6]:
# Plain `search` for a short definitional query, asking for 3 results.
se.search("what is RAG", n_results=3)

{'ids': [['0917d4f70bf429c14ef0bd82c019fa91b3c118a34d62a92870fac05f14ecbba0',
   '6380585168007facd2291bdfe045e9e8fb693ab78a90d4b625d4a6549764c3c9',
   'ba934fae265c2f969b9ac92793af578b8be398b6552c03c0187c615366d5d73b']],
 'distances': [[1.0776456932135008, 1.3660479009299706, 1.6847147858983502]],
 'metadatas': [[{'file': 'doc/example_doc_pdf_converted.md'},
   {'file': 'doc/example_doc_pdf_converted.md'},
   {'file': 'doc/example_doc_pdf_converted.md'}]],
 'embeddings': None,
 'documents': [['**Retrieval-Augmented Generation (RAG) with Large Language Models (LLMs)**  \n**1. Introduction**  \nRetrieval-Augmented Generation (RAG) is a powerful approach that combines the strengths of\ninformation retrieval and generative models. By integrating a retrieval mechanism with a large  \nlanguage model (LLM), RAG can provide more accurate and contextually relevant responses, especially\nin knowledge-intensive tasks.  \n**2. Objectives**  \nEnhance the accuracy of responses generated by LLMs.',

In [7]:
# Hybrid search on a Chinese query ("briefly introduce mathematics"): 2 single-keyword text-recall hits
# plus 2 expanded-keyword semantic-recall hits, merged with a plain search and de-duplicated by id.
se.hybrid_search("简要介绍数学", text_results=2,semantic_results=2)

Simple key word for text recall: 数学
Key words for semantic recall: ['mathematics', 'mathematical concepts', 'introduction to mathematics', 'mathematical foundation', 'mathematical knowledge', '数学', '数学概念', '数学简介', '数学基础', '数学知识']


{'ids': [['bdfd10149846d1e9176a4d84d075cf956291853d0d414d7dc16288a96a39758a',
   'd8c909c465e32586550e0da5b7d21d43ad1b8bb1d80e5d62af19c58d6dfd4614',
   '1282b6325307bf497b5578ecd28429c54ee1a1976fe4dd40a5ac5e376dd327d2',
   '4de0943f2fb85e89a4adde6cdd96806b44a81b1fa5845a3a1295983f96cc9a2c']],
 'distances': [[1.0612587436526446,
   1.0882368905727917,
   1.0962425463988101,
   1.0971993417315247]],
 'metadatas': [[{'file': 'doc/subdir/subdir_example_doc_pdf_converted.md'},
   {'file': 'doc/subdir/subdir_example_doc_pdf_converted.md'},
   {'file': 'doc/subdir/subdir_example_doc_pdf_converted.md'},
   {'file': 'doc/subdir/subdir_example_doc_pdf_converted.md'}]],
 'documents': [['**The Beauty of Mathematics**  \nMathematics is often described as the language of the universe, a discipline that transcends cultural  \nand linguistic boundaries. It is not only a tool for solving problems but also a profound way of\nunderstanding the world around us. This article explores the essence of mathemat

In [8]:
# Inspect the expanded, Chinese/English key words that `hybrid_search` uses for semantic recall.
se._extract_key_words_from_query("Show me how to meditate")

['meditate',
 'meditation',
 'mindfulness',
 'practice',
 'techniques',
 '冥想',
 '静坐',
 '正念',
 '练习',
 '技巧']

In [9]:
# Inspect the single key word that `hybrid_search` uses for text recall.
se._extract_simple_key_word_from_query("Show me how to meditate")

'meditate'

6