In [1]:
!pip install llama-index==0.11.10
!pip install llama-index-llms-bedrock==0.2.1
!pip install llama-index-embeddings-bedrock==0.3.1

Collecting llama-index==0.11.10
  Using cached llama_index-0.11.10-py3-none-any.whl.metadata (11 kB)
Collecting llama-index-agent-openai<0.4.0,>=0.3.1 (from llama-index==0.11.10)
  Using cached llama_index_agent_openai-0.3.4-py3-none-any.whl.metadata (728 bytes)
Collecting llama-index-cli<0.4.0,>=0.3.1 (from llama-index==0.11.10)
  Using cached llama_index_cli-0.3.1-py3-none-any.whl.metadata (1.5 kB)
Collecting llama-index-core<0.12.0,>=0.11.10 (from llama-index==0.11.10)
  Using cached llama_index_core-0.11.14-py3-none-any.whl.metadata (2.4 kB)
Collecting llama-index-embeddings-openai<0.3.0,>=0.2.4 (from llama-index==0.11.10)
  Using cached llama_index_embeddings_openai-0.2.5-py3-none-any.whl.metadata (686 bytes)
Collecting llama-index-indices-managed-llama-cloud>=0.3.0 (from llama-index==0.11.10)
  Using cached llama_index_indices_managed_llama_cloud-0.4.0-py3-none-any.whl.metadata (3.8 kB)
Collecting llama-index-legacy<0.10.0,>=0.9.48 (from llama-index==0.11.10)
  Using cached llama

In [2]:
import boto3

from llama_index.llms.bedrock import Bedrock
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.core import ServiceContext, StorageContext, load_index_from_storage
from llama_index.core.node_parser.text.sentence import SentenceSplitter
from llama_index.core.prompts import PromptTemplate
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.prompts.prompt_type import PromptType
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import Settings

In [3]:
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.core.indices.utils import (
    default_format_node_batch_fn,
)
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.output_parsers.pydantic import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from pydantic import BaseModel, Field
from typing import Any, List, Optional, Type, Union

SUB_QUESTION_PROMPT_STR = """
Vous êtes un assistant utile qui génère plusieurs requêtes de recherche basées sur une seule requête d'entrée. 
Générer 2 requêtes de recherche, une sur chaque ligne, liées à la requête d'entrée suivante:
Requête: {query}
Requêtes:
"""

PYDANTIC_FORMAT_TMPL = """
Voici un schéma JSON à suivre:
{schema}

Générez un objet JSON valide mais ne répétez pas le schéma.
"""

CHOICE_SELECT_PROMPT_TMPL = """
Une liste de documents est présentée ci-dessous. Chaque document est accompagné d'un numéro et d'un résumé du document. Une question est également fournie.
Donne-moi uniquement les numéros des documents pertinents à la question en utilisant le format du tableau.
Voici quelques exemples: 
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

...

Document 10:
<résumé du document 10>

Question: <question>
Tableau des nombres des documents pertinents: [2, 4, 5, 6]

Essayons ceci maintenant :

{context_str}
Question: {query_str}
Tableau des nombres des documents pertinents:
"""

CHOICE_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents pertinents.
Ne répétez pas le schéma.
"""

NEGATIVE_FILTER_PROMPT_TMPL = """
Voici une liste de contextes pertinents pour une requête donnée. Vous devez filtrer cette liste pour exclure les contextes qui correspondent à un contexte négatif spécifique.
Donne-moi uniquement les numéros des documents en utilisant le format du tableau.
Voici quelques exemples: 
Requête principale : <question>
Contexte négatif : <contexte_negatif>
Liste des contextes :
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

Document 3:
<résumé du document 10>

Document 4:
<résumé du document 10>

Tableau des nombres des documents: [1, 2, 3]

Essayons ceci maintenant :

Requête principale : {query_str}
Contexte négatif : {contexte_negatif}
Liste des contextes : {contexte}
Tableau des nombres des documents:
"""

NEGATIVE_FILTER_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents.
Ne répétez pas le schéma.
"""

class DocumentNumberInfo(BaseModel):
    """Informations concernant un tableau structuré."""
    tableau_document: list = Field(
        ..., description="le tableau des nombres des documents"
    )

class SubQuestionEvent(Event):
    query: str
    neg_context: str
    nodes: list

class RerankerValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerBatchEvent(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list
    batch_nodes: list

class RerankerValidationDone(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerFlow(Workflow):
    def __init__(self, vector_retriever: VectorIndexRetriever, timeout = 120, verbose = True):
        self.llm = Bedrock(model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="eu-west-3")
        self.vector_retriever = vector_retriever
        self.batch_size = 3 # the divided size to rerank all the nodes 
        self.max_retries = 3 # the max retries for the prompt
        self.min_chunk = 10 # the minimum size of the chunks after the rerank
        super().__init__(timeout=timeout, verbose=verbose)

    # the state to generate sub questions and put inside nodes
    @step()
    async def sub_question(
        self, ev: StartEvent
    ) -> SubQuestionEvent:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        # subquestions = []
        # try:
        #     subquestions_response = Settings.llm.complete(SUB_QUESTION_PROMPT_STR.format(query=query))
        #     subquestions = subquestions_response.text.split("\n")
        # except Exception as e:
        #     print(e)
        # if len(subquestions) > 2:
        #     nodes1 = self.vector_retriever.retrieve(subquestions[-2])
        #     nodes2 = self.vector_retriever.retrieve(subquestions[-1])
        #     one_third_index = int(len(nodes)/3)
        #     two_third_index = int(len(nodes)/3*2)
        #     nodes_ids = []
        #     for node in nodes:
        #         nodes_ids.append(node.id_)
        #     indicator1 = 0
        #     indicator2 = 0
        #     for i in range(len(nodes1)):
        #         if nodes1[i].id_ not in nodes_ids:
        #             indicator1 += 1
        #             nodes_ids.append(nodes1[i].id_)
        #             if indicator1 == 1:
        #                 nodes.insert(one_third_index, nodes1[i])
        #             elif indicator1 == 2:
        #                 nodes.insert(one_third_index, nodes1[i])
        #             elif indicator1 == 3:
        #                 nodes.insert(two_third_index, nodes1[i])
        #             elif indicator1 == 4:
        #                 nodes.insert(two_third_index, nodes1[i])
        #             elif indicator1 == 5:
        #                 nodes.append(nodes1[i])
        #             elif indicator1 == 6:
        #                 nodes.append(nodes1[i])
        #                 break
        #     for i in range(len(nodes2)):
        #         if nodes2[i].id_ not in nodes_ids:
        #             indicator2 += 1
        #             nodes_ids.append(nodes2[i].id_)
        #             if indicator2 == 1:
        #                 nodes.insert(one_third_index, nodes2[i])
        #             elif indicator2 == 2:
        #                 nodes.insert(one_third_index, nodes2[i])
        #             elif indicator2 == 3:
        #                 nodes.insert(two_third_index, nodes2[i])
        #             elif indicator2 == 4:
        #                 nodes.insert(two_third_index, nodes2[i])
        #             elif indicator2 == 5:
        #                 nodes.append(nodes2[i])
        #             elif indicator2 == 6:
        #                 nodes.append(nodes2[i])
        #                 break
        return SubQuestionEvent(query=query, neg_context=neg_context, nodes=nodes)

    # the state to generate prompt for reranking
    @step(pass_context=True)
    async def reranker_prompt(
        self, ctx: Context, ev: Union[SubQuestionEvent, RerankerValidationErrorEvent, RerankerBatchEvent]
    ) -> Union[StopEvent, RerankerDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        window = int(len(nodes)/3)

        if isinstance(ev, SubQuestionEvent):
            rerank_nodes = []
            current_batch = ctx.data.get("batch", 0)
            ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        elif isinstance(ev, RerankerBatchEvent):
            rerank_nodes = ev.rerank_nodes
            current_batch = ctx.data.get("batch", 0)
            # iterate the nodes based on batch size
            if current_batch >= self.batch_size:
                print("Reranked nodes are extracted.")
                # if no negative context, output reranked nodes. otherwise, go the state of filter
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        # prompt error event
        elif isinstance(ev, RerankerValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            rerank_nodes = ev.rerank_nodes
            if current_retries >= self.max_retries:
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            current_batch = ctx.data.get("batch", 0)
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = CHOICE_REFLECTION_PROMPT_STR.format(error=ev.error)
        
        prompt = CHOICE_SELECT_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return RerankerDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes, batch_nodes=batch_nodes)

    # call LLM to rerank nodes
    @step()
    async def rerank_validate(
        self, ev: RerankerDone
    ) -> Union[StopEvent, RerankerValidationDone, RerankerValidationErrorEvent, RerankerBatchEvent]:
        try:
            choice_output_parser = PydanticOutputParser(output_cls=DocumentNumberInfo, pydantic_format_tmpl=PYDANTIC_FORMAT_TMPL)
            choice_program = LLMTextCompletionProgram.from_defaults(
                output_parser=choice_output_parser,
                llm=self.llm,
                prompt_template_str=ev.prompt,
            )
            output = choice_program(context_str=default_format_node_batch_fn(ev.batch_nodes), query_str=ev.query)
            new_nodes = []
            for i in output.tableau_document:
                if i-1 < len(ev.batch_nodes):
                    new_nodes.append(ev.batch_nodes[i-1])
        except Exception as e:
            print("Validation failed, retrying...")
            return RerankerValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        rerank_nodes = new_nodes + ev.rerank_nodes
        # if reranked nodes are less than min chunks, return to the state of batch event, otherwise output the reranked nodes.
        if len(rerank_nodes) > self.min_chunk:
            print("Reranked nodes are extracted.")
            if ev.neg_context == '':
                return StopEvent(result=rerank_nodes)
            else:
                return RerankerValidationDone(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)
        else:
            return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)

    # prepare the prompt to filter the reranked nodes
    @step(pass_context=True)
    async def filter_prompt(
        self, ctx: Context, ev: Union[RerankerValidationDone, FilterValidationErrorEvent]
    ) -> Union[StopEvent, FilterDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        rerank_nodes = ev.rerank_nodes

        if isinstance(ev, RerankerValidationDone):
            reflection_prompt = ""
        elif isinstance(ev, FilterValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            if current_retries >= self.max_retries:
                if len(rerank_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            reflection_prompt = NEGATIVE_FILTER_REFLECTION_PROMPT_STR.format(error=ev.error)
        prompt = NEGATIVE_FILTER_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return FilterDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)

    # call LLM to filter the reranked nodes
    @step(pass_context=True)
    async def filter_validate(
        self, ctx: Context, ev: FilterDone
    ) -> Union[StopEvent, FilterValidationErrorEvent, RerankerBatchEvent]:
        try:
            choice_output_parser = PydanticOutputParser(output_cls=DocumentNumberInfo, pydantic_format_tmpl=PYDANTIC_FORMAT_TMPL)
            filter_program = LLMTextCompletionProgram.from_defaults(
                output_parser=choice_output_parser,
                llm=self.llm,
                prompt_template_str=ev.prompt,
            )
            output = filter_program(context_str=default_format_node_batch_fn(ev.rerank_nodes), query_str=ev.query, contexte_negatif=ev.neg_context)
        except Exception as e:
            print("Validation failed, retrying...")
            return FilterValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        new_nodes = []
        for i in output.tableau_document:
            new_nodes.append(ev.rerank_nodes[i-1])
        print("Reranked nodes are filtered.")
        current_batch = ctx.data.get("batch", 0)
        if len(new_nodes) > self.min_chunk:
            return StopEvent(result=new_nodes)
        else:
            if current_batch >= self.batch_size:
                if len(new_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=new_nodes)
            else:
                return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=new_nodes)

In [4]:
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.retrievers import (
    BaseRetriever,
)
from llama_index.core.indices.keyword_table.base import BaseKeywordTableIndex
from llama_index.core.schema import NodeWithScore, QueryBundle

class KeywordTableLexiqueRetriever(BaseRetriever):
    """
    Extracts keywords from lexique using space separator.
    """
    def __init__(
        self,
        index: BaseKeywordTableIndex,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        self.index = index
        self.lexique_keywords = self._get_lexique_keywords()
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            verbose=verbose,
        )
    def _get_lexique_keywords(self):
        lexique_keywords = []
        for item in self.index.docstore.docs.values():
            if "keyword" in item.metadata:
                lexique_keywords.append(item.metadata["keyword"])
        return lexique_keywords

    def _get_keywords(self, query_str: str) -> List[str]:
        words = query_str.split()
        keywords = []

        for word in words:
            if word in self.lexique_keywords and word not in keywords:
                keywords.append(word)
            elif word[:-1] in self.lexique_keywords and word[:-1] not in keywords:
                keywords.append(word[:-1])
            elif word[1:] in self.lexique_keywords and word[1:] not in keywords:
                keywords.append(word[1:])
        return keywords

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Get nodes for response."""
        
        keywords = self._get_keywords(query_bundle.query_str)
        print(f"query keywords: {keywords}")

        sorted_nodes = []
        for node in self.index.docstore.docs.values():
            if "keyword" in node.metadata and node.metadata["keyword"] in keywords:
                sorted_nodes.append(node)

        return [NodeWithScore(node=node) for node in sorted_nodes]
    
    async def _aretrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Get nodes for response."""
        
        keywords = self._get_keywords(query_bundle.query_str)
        print(f"query keywords: {keywords}")

        sorted_nodes = []
        for node in self.index.docstore.docs.values():
            if "keyword" in node.metadata and node.metadata["keyword"] in keywords:
                sorted_nodes.append(node)

        return [NodeWithScore(node=node) for node in sorted_nodes]

In [5]:
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.indices.keyword_table import KeywordTableGPTRetriever
from typing import List
import threading
import asyncio
from queue import Queue

class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableGPTRetriever,
        embed_model: BedrockEmbedding,
        old_queries: List = [],
        mode: str = "OR",
        
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self.embed_model = embed_model
        self.reranker = RerankerFlow(vector_retriever=vector_retriever)
        self.old_queries = old_queries
        index_feedback = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="finetune/vector_persist_800_feedback"))
        self.feedback_retriever = VectorIndexRetriever(index=index_feedback, similarity_top_k=1)
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async not activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = self._weighted_aggre_embedding(new_query_str)
        vector_nodes = self._vector_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = await self._aweighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)

        feedback_nodes = self.feedback_retriever.retrieve(query_bundle)
        if feedback_nodes[0].score > 0.8 : 
            neg_context = feedback_nodes[0].metadata["message"]
        else:
            neg_context = ''
        vector_nodes = await self.reranker.run(query=query_bundle.query_str, nodes=vector_nodes_full, neg_context=neg_context)
        if vector_nodes == None:
            return keyword_nodes

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    def replace_abbreviations(self, question, keyword_nodes):
        """
        Replaces abbreviations in the lexique with their complete names.
        Assumes abbreviations are separated by either two spaces or one space + one punctuation mark.
        """
        lexique = {}
        for node in keyword_nodes:
            keyword = node.node.metadata['keyword']
            keyword_complete = node.text.split(': ')[1]
            lexique[keyword] = keyword_complete

        words = question.split()
        result = []

        for word in words:
            if word in lexique:
                result.append(lexique[word])
                result.append('('+word+')')
            elif word[:-1] in lexique:
                result.append(lexique[word[:-1]])
                result.append('('+word[:-1]+')')
                result.append(word[-1])
            elif word[1:] in lexique:
                result.append(word[0])
                result.append(lexique[word[1:]])
                result.append('('+word[1:]+')')
            else:
                result.append(word)

        return ' '.join(result)
    
    def _get_query_embedding_threaded(self, query):
        """
        Get the query embedding using a separate thread.
        """
        result_queue = Queue()  # Create a queue to store the result
        thread = threading.Thread(target=self._get_query_embedding_async_thread, args=(query, result_queue))
        thread.start()
        thread.join()  # Wait for the thread to finish
        return result_queue.get()  # Retrieve the result from the queue

    def _get_query_embedding_async_thread(self, query, result_queue):
        """
        Helper method to run the asynchronous call in a separate thread.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(self.embed_model.aget_query_embedding(query))
        loop.close()
        result_queue.put(result)  # Put the result in the queue
    
    def _weighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = self._get_query_embedding_threaded(new_query_str)
        print(current_embedding)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(self._get_query_embedding_threaded(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding
    async def _aweighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = await self.embed_model.aget_query_embedding(new_query_str)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(await self.embed_model.aget_query_embedding(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding

In [6]:
from IPython.display import Markdown, display
def display_nodes(nodes):
    nodes_output = ''
    for node in nodes:
        nodes_output += 'ID: ' + node.id_ + '\n'
        if 'page_label' in node.metadata:
            nodes_output += 'Pagelabel: ' + node.metadata['page_label'] + '\n'
        else:
            pass
        if 'file_name' in node.metadata:
            nodes_output += 'File name: ' + node.metadata['file_name'] + '\n'
        else:
            pass
        if 'file_path' in node.metadata:
            nodes_output += 'File path: ' + node.metadata['file_path'] + '\n'
        else:
            pass
        if 'metier' in node.metadata:
            nodes_output += 'Metier: ' + str(node.metadata['metier']) + '\n'
        else:
            pass
        nodes_output += 'Score: ' + str(node.score) + '\n'
        nodes_output += 'Text: ' + str(node.text) + '\n'
        nodes_output += '################\n'
    return nodes_output

In [7]:
DEFAULT_QUERY_PROMPT_TMPL = (
    " Les informations contextuelles sont"
    " ci-dessous.\n---------------------\n{context_str}\n---------------------\n"
    " Compte tenu des informations contextuelles, répondez à la requête sans reproduire la requète.\n"
    " Baser la réponse sur le contexte fourni, sans mentionner les noms des documents sources. Structurer la réponse comme suit :\n"
    " - Formater le texte avec une ponctuation appropriés entre les phrases.\n"
    " - Lorsqu'il y a une enumeration d'actions ou d'étapes dans la réponse, les précéder d'un tiret et les numéroter (1-, 2-, etc.)\n"
    " - Commencer chaque nouvelle phrase ou action par une majuscule.\n"
    " - Relire la réponse pour s'assurer qu'elle est claire, cohérente et bien formatée avant de la soumettre."
    " Requête: {query_str}\n"
    " Réponse:\n"
)

In [8]:
KEYWORD_EXTRACT_TEMPLATE_TMPL = (
    "Un texte est fourni ci-dessous. Étant donné le texte, extrayez jusqu'à {max_keywords} "
    "keywords du texte. Évitez les stopwords."
    "---------------------\n"
    "{text}\n"
    "---------------------\n"
    "Fournissez des keywords au format suivant: 'KEYWORDS: <keywords>'\n"
)
KEYWORD_EXTRACT_TEMPLATE = PromptTemplate(
    KEYWORD_EXTRACT_TEMPLATE_TMPL, prompt_type=PromptType.KEYWORD_EXTRACT
)

In [9]:
boto3_session = boto3.session.Session()
region_name = "eu-central-1"
bedrock_agent_client = boto3_session.client('bedrock-runtime', region_name=region_name)
Settings.embed_model = BedrockEmbedding(client=bedrock_agent_client, model_name="amazon.titan-embed-text-v1")
Settings.llm = Bedrock(model="anthropic.claude-3-haiku-20240307-v1:0", region_name="eu-west-3")
chunk_size = 800
Settings.node_parser = SentenceSplitter.from_defaults(chunk_size=chunk_size, chunk_overlap=50)
Settings.num_output = 2024

In [10]:
index_doc = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="vector_persist_800_main"))

In [11]:
nodes = index_doc.docstore.docs.values()
print(len(nodes))

26482


In [11]:
index_keyword = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="vector_persist_800_lexique"))

In [12]:
vector_retriever = VectorIndexRetriever(index=index_doc, similarity_top_k=30)
keyword_retriever = KeywordTableGPTRetriever(index=index_keyword, max_keywords_per_query = 20, keyword_extract_template=KEYWORD_EXTRACT_TEMPLATE)
custom_retriever = CustomRetriever(vector_retriever, keyword_retriever, embed_model=Settings.embed_model, old_queries=[])

In [70]:
import time
start_time = time.time()
nodes = await custom_retriever.aretrieve("Que veut dire CICM?")
for node in nodes:
    print(node.id_)
    print(node.score)
    print("#######")
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
14416d8b-d04e-463c-bde6-a4074dac106c
0.6292323751243815
#######
8cb8e165-c7f5-474b-ab8f-533ce6ee563e
0.6209355792525079
#######
ab038de4-5ce8-414b-b927-c5576cc0a8e4
0.6143688385279331
#######
4a955058-7ddc-4e34-9017-1e89994f927e
0.6304759656740551
#######
28e895b0-4337-45f5-b1ef-283df9c63e0c
0.6583704601825731
#######
756a78d7-ab8a-4964-9b1d-977e926a3ddc
0.6118743166167164
#######
29dc8db9-0d16-42

## old solution  -- 29.55 seconds / 4 questions

In [48]:
!pip install torch sentence-transformers

Collecting sentence-transformers
  Using cached sentence_transformers-3.1.1-py3-none-any.whl.metadata (10 kB)
Using cached sentence_transformers-3.1.1-py3-none-any.whl (245 kB)
Installing collected packages: sentence-transformers
Successfully installed sentence-transformers-3.1.1


In [49]:
from sentence_transformers import CrossEncoder
class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableGPTRetriever,
        embed_model: BedrockEmbedding,
        old_queries: List = [],
        mode: str = "OR",
        
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self.embed_model = embed_model
        self.old_queries = old_queries
        self.crossencoder_model = CrossEncoder('finetune/crossencoder-camembert-base-mmarcoFR')
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async not activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = self._weighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)
        vector_nodes = self.rerank_retrieve(query_bundle,vector_nodes_full)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        for node in keyword_nodes:
            print(node.id_)
            print(node.score)
            print(node.text)
            print("#######")
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = await self._aweighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)
        vector_nodes = self.rerank_retrieve(query_bundle,vector_nodes_full)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    def rerank_retrieve(self, query_bundle: QueryBundle, vector_nodes: List[NodeWithScore]):
        if len(vector_nodes) < 2:
            return vector_nodes
        pairs = []
        for node in vector_nodes:
            pairs.append((query_bundle.query_str, node.text))
        scores = self.crossencoder_model.predict(pairs)

        zipped = zip(scores, vector_nodes)
        zipped = sorted(zipped, key=lambda x: x[0], reverse=True)
        numbers, nodes = zip(*zipped)
        return nodes[:int(0.7 * len(nodes))]

    def replace_abbreviations(self, question, keyword_nodes):
        """
        Replaces abbreviations in the lexique with their complete names.
        Assumes abbreviations are separated by either two spaces or one space + one punctuation mark.
        """
        lexique = {}
        for node in keyword_nodes:
            keyword = node.node.metadata['keyword']
            keyword_complete = node.text.split(': ')[1]
            lexique[keyword] = keyword_complete

        words = question.split()
        result = []

        for word in words:
            if word in lexique:
                result.append(lexique[word])
                result.append('('+word+')')
            elif word[:-1] in lexique:
                result.append(lexique[word[:-1]])
                result.append('('+word[:-1]+')')
                result.append(word[-1])
            elif word[1:] in lexique:
                result.append(word[0])
                result.append(lexique[word[1:]])
                result.append('('+word[1:]+')')
            else:
                result.append(word)

        return ' '.join(result)
    
    def _get_query_embedding_threaded(self, query):
        """
        Get the query embedding using a separate thread.
        """
        result_queue = Queue()  # Create a queue to store the result
        thread = threading.Thread(target=self._get_query_embedding_async_thread, args=(query, result_queue))
        thread.start()
        thread.join()  # Wait for the thread to finish
        return result_queue.get()  # Retrieve the result from the queue

    def _get_query_embedding_async_thread(self, query, result_queue):
        """
        Helper method to run the asynchronous call in a separate thread.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(self.embed_model.aget_query_embedding(query))
        loop.close()
        result_queue.put(result)  # Put the result in the queue
    
    def _weighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = self._get_query_embedding_threaded(new_query_str)
        print(current_embedding)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(self._get_query_embedding_threaded(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding
    async def _aweighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        print("async activated")
        current_embedding = await self.embed_model.aget_query_embedding(new_query_str)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(await self.embed_model.aget_query_embedding(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding

  from tqdm.autonotebook import tqdm, trange
2024-09-27 11:41:14.926133: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [50]:
vector_retriever = VectorIndexRetriever(index=index_doc, similarity_top_k=15)
keyword_retriever = KeywordTableGPTRetriever(index=index_keyword, max_keywords_per_query = 20, keyword_extract_template=KEYWORD_EXTRACT_TEMPLATE)
custom_retriever = CustomRetriever(vector_retriever, keyword_retriever, embed_model=Settings.embed_model, old_queries=[])

In [51]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
qa_template = PromptTemplate(DEFAULT_QUERY_PROMPT_TMPL)
query_engine = RetrieverQueryEngine.from_args(retriever=custom_retriever,
                                              text_qa_template=qa_template,
                                              use_async=True,
                                              response_mode = ResponseMode.SIMPLE_SUMMARIZE,
                                              streaming=False
                                             )

In [52]:
import time
start_time = time.time()

response = await query_engine.aquery("Quelle norme de tube acier utiliser pour faire une liaison détente / comptage pour alimenter une chaufferie ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
a6d34e81-2270-41a0-bd53-29e9a97ff6d2
None
Ac : Acier joint Coulé
#######
1de44c9a-4dd1-4599-9903-764c2c365b2d
None
APE : Acier revetu Polyéthylène
#######
ebdb32c9-3c6f-4014-b615-c8a22564631b
None
A / Ac : Acier
#######
2442f36a-1370-4c42-b40b-a69da0d2eb48
None
Ast : Acier Joint Standard
#######
7ced680c-d34c-4f2f-8047-0446ff4493f1
None
DPBA : Dispositif Protection Branchement Acier
#######
a72f7bb9-fe08-44a9-95e4-77a962738124
None
Ac : Acier
#######
2ad7b4b6-788f-4894-aa1f-79796c324780
None
MBDI : Manchette de branchement à déclencheur intégré (réseau Acier)
#######
d78a7b5e-686e-4ed8-9f40-a22705584925
None
A : Acier joint soudé
#######
0cf43a38-1334-4ae0-9aa6-1eeaeb7deef7
None
Av : Acier joint vissé
#######
092d3436-9a5c-47d9-9f93-23d2aa74620c
None
AR : Acier revetu Brai
#######
async activated
Selon les informations contextuelles fournies, les tubes en acier doivent être assemblés soit par brasage capillaire ("fort" ou "tendre") pour les tubes de diamètre extérieur 

In [53]:
import time
start_time = time.time()

response = await query_engine.aquery("Quels sont les achats immobilisés pour le biométhane ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
7cf600ba-afae-43f8-9801-b7fc04ebd0e8
None
OMA : Logiciel de Gestion des Marchés d' Achats
#######
704986ef-6ef3-4b95-95b0-e277d30e9d9c
None
E-consult : Portail Internet de mise à disposition des appels d’offre pour les fournisseurs (Métier Achats)
#######
cd932db3-47d6-4d7c-ab1c-5d0c920277a3
None
GGO : Gestionnaire des certificats de Garantie d'Origine biométhane
#######
0a6d9496-7a81-45f2-8179-cea51bf3117d
None
GO : (biométhane) Garantie d'Origine Le registre des GO enregistre les acteurs, les sites et les mouvements de GO
#######
ec88550b-3525-4ab4-8882-fd9939c55bff
None
PTF Bio : Plateforme Digital Biométhane
#######
8e9c6c1c-a8a7-4e62-b623-77419fcb0e82
None
CRAB : Compte-Rendu Annuel Biométhane
#######
async activated
Il n'y a pas d'informations spécifiques sur les achats immobilisés pour le biométhane dans le contexte fourni. Les documents traitent plutôt des principes généraux de fonctionnement et d'exploitation des postes d'injection de biométhane, ainsi que des

In [54]:
import time
start_time = time.time()

response = await query_engine.aquery("Que veut dire CICM ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
8471c5fa-7fb0-4b30-93ff-71297c693945
None
CICM : Conduite d'Immeuble, Conduite Montante
#######
async activated
CICM signifie "Conduite d'Immeuble, Conduite Montante". C'est un terme utilisé pour désigner les tuyauteries verticales (conduites montantes) et horizontales (conduites d'immeuble) qui alimentent en gaz les différents niveaux d'un bâtiment d'habitation collectif.
Execution Time: 27.767505645751953 seconds


In [55]:
import time
start_time = time.time()

response = await query_engine.aquery("Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
829d4a92-9fae-49f7-b58b-c4071df163bd
None
VSR : Véhicule de Surveillance des Réseaux
#######
828b100a-66f1-4187-b294-2799918bfceb
None
TARG : Téléchargement des Applications Réseaux Gaz.
#######
9b7ab968-6d93-4ffb-8e28-7d9219fb066a
None
ATLAS : Logiciel de Cartographie des Réseaux à Grande Echelle
#######
0b941bdb-afdd-4a9c-ba35-f7771ff06570
None
AIPR : Autorisation d’intervention à proximité des réseaux
#######
31de9c4b-5387-4da6-b9f2-d63c7db3ab17
None
GRHYD : Gestion des Réseaux par l'injection d'Hydrogène pour Décarboner les énergies. (ADEME+Engie)
#######
9da32904-5699-4675-9be4-c51e6e190faf
None
VGD : Voirie et Réseaux Divers
#######
19a8994b-bbfd-4799-be18-9e3332113ff4
None
PHARE : Portail Habilitations et Accès  aux Réseaux informatiques d'Entreprises
#######
b513bcf5-06b2-43ee-8c1d-51f4e79dfa1e
None
ROR : Relations Opérateurs de Réseaux
#######
1b1406e5-6383-496c-8b8e-fcc28ccbf019
None
ATRD : Ou "Tarifs ATRD" : tarifs d’Accès des Tiers aux Réseaux de Distributi

## Claude Sonnet -- avec sousquestion -- 22 seconds / 4 questions

In [66]:
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.core.indices.utils import (
    default_format_node_batch_fn,
)
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.output_parsers.pydantic import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from pydantic import BaseModel, Field
from typing import Any, List, Optional, Type, Union

SUB_QUESTION_PROMPT_STR = """
Vous êtes un assistant utile qui génère plusieurs requêtes de recherche basées sur une seule requête d'entrée. 
Générer 2 requêtes de recherche, une sur chaque ligne, liées à la requête d'entrée suivante:
Requête: {query}
Requêtes:
"""

PYDANTIC_FORMAT_TMPL = """
Voici un schéma JSON à suivre:
{schema}

Générez un objet JSON valide mais ne répétez pas le schéma.
"""

CHOICE_SELECT_PROMPT_TMPL = """
Une liste de documents est présentée ci-dessous. Chaque document est accompagné d'un numéro et d'un résumé du document. Une question est également fournie.
Donne-moi uniquement les numéros des documents pertinents à la question en utilisant le format du tableau.
Voici quelques exemples: 
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

...

Document 10:
<résumé du document 10>

Question: <question>
Tableau des nombres des documents pertinents: [2, 4, 5, 6]

Essayons ceci maintenant :

{context_str}
Question: {query_str}
Tableau des nombres des documents pertinents:
"""

CHOICE_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents pertinents.
Ne répétez pas le schéma.
"""

NEGATIVE_FILTER_PROMPT_TMPL = """
Voici une liste de contextes pertinents pour une requête donnée. Vous devez filtrer cette liste pour exclure les contextes qui correspondent à un contexte négatif spécifique.
Donne-moi uniquement les numéros des documents en utilisant le format du tableau.
Voici quelques exemples: 
Requête principale : <question>
Contexte négatif : <contexte_negatif>
Liste des contextes :
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

Document 3:
<résumé du document 10>

Document 4:
<résumé du document 10>

Tableau des nombres des documents: [1, 2, 3]

Essayons ceci maintenant :

Requête principale : {query_str}
Contexte négatif : {contexte_negatif}
Liste des contextes : {contexte}
Tableau des nombres des documents:
"""

NEGATIVE_FILTER_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents.
Ne répétez pas le schéma.
"""

class DocumentNumberInfo(BaseModel):
    """Informations concernant un tableau structuré."""
    tableau_document: list = Field(
        ..., description="le tableau des nombres des documents"
    )

class SubQuestionEvent(Event):
    query: str
    neg_context: str
    nodes: list

class RerankerValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerBatchEvent(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list
    batch_nodes: list

class RerankerValidationDone(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerFlow(Workflow):
    def __init__(self, vector_retriever: VectorIndexRetriever, timeout = 120, verbose = True):
        self.llm = Bedrock(model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="eu-west-3")
        self.vector_retriever = vector_retriever
        self.batch_size = 3 # the divided size to rerank all the nodes 
        self.max_retries = 3 # the max retries for the prompt
        self.min_chunk = 10 # the minimum size of the chunks after the rerank
        super().__init__(timeout=timeout, verbose=verbose)

    # the state to generate sub questions and put inside nodes
    @step()
    async def sub_question(
        self, ev: StartEvent
    ) -> SubQuestionEvent:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        subquestions = []
        try:
            subquestions_response = Settings.llm.complete(SUB_QUESTION_PROMPT_STR.format(query=query))
            subquestions = subquestions_response.text.split("\n")
        except Exception as e:
            print(e)
        if len(subquestions) > 2:
            nodes1 = self.vector_retriever.retrieve(subquestions[-2])
            nodes2 = self.vector_retriever.retrieve(subquestions[-1])
            one_third_index = int(len(nodes)/3)
            two_third_index = int(len(nodes)/3*2)
            nodes_ids = []
            for node in nodes:
                nodes_ids.append(node.id_)
            indicator1 = 0
            indicator2 = 0
            for i in range(len(nodes1)):
                if nodes1[i].id_ not in nodes_ids:
                    indicator1 += 1
                    nodes_ids.append(nodes1[i].id_)
                    if indicator1 == 1:
                        nodes.insert(one_third_index, nodes1[i])
                    elif indicator1 == 2:
                        nodes.insert(one_third_index, nodes1[i])
                    elif indicator1 == 3:
                        nodes.insert(two_third_index, nodes1[i])
                    elif indicator1 == 4:
                        nodes.insert(two_third_index, nodes1[i])
                    elif indicator1 == 5:
                        nodes.append(nodes1[i])
                    elif indicator1 == 6:
                        nodes.append(nodes1[i])
                        break
            for i in range(len(nodes2)):
                if nodes2[i].id_ not in nodes_ids:
                    indicator2 += 1
                    nodes_ids.append(nodes2[i].id_)
                    if indicator2 == 1:
                        nodes.insert(one_third_index, nodes2[i])
                    elif indicator2 == 2:
                        nodes.insert(one_third_index, nodes2[i])
                    elif indicator2 == 3:
                        nodes.insert(two_third_index, nodes2[i])
                    elif indicator2 == 4:
                        nodes.insert(two_third_index, nodes2[i])
                    elif indicator2 == 5:
                        nodes.append(nodes2[i])
                    elif indicator2 == 6:
                        nodes.append(nodes2[i])
                        break
        return SubQuestionEvent(query=query, neg_context=neg_context, nodes=nodes)

    # the state to generate prompt for reranking
    @step(pass_context=True)
    async def reranker_prompt(
        self, ctx: Context, ev: Union[SubQuestionEvent, RerankerValidationErrorEvent, RerankerBatchEvent]
    ) -> Union[StopEvent, RerankerDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        window = int(len(nodes)/3)

        if isinstance(ev, SubQuestionEvent):
            rerank_nodes = []
            current_batch = ctx.data.get("batch", 0)
            ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        elif isinstance(ev, RerankerBatchEvent):
            rerank_nodes = ev.rerank_nodes
            current_batch = ctx.data.get("batch", 0)
            # iterate the nodes based on batch size
            if current_batch >= self.batch_size:
                print("Reranked nodes are extracted.")
                # if no negative context, output reranked nodes. otherwise, go the state of filter
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        # prompt error event
        elif isinstance(ev, RerankerValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            rerank_nodes = ev.rerank_nodes
            if current_retries >= self.max_retries:
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            current_batch = ctx.data.get("batch", 0)
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = CHOICE_REFLECTION_PROMPT_STR.format(error=ev.error)
        
        prompt = CHOICE_SELECT_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return RerankerDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes, batch_nodes=batch_nodes)

    # call LLM to rerank nodes
    @step()
    async def rerank_validate(
        self, ev: RerankerDone
    ) -> Union[StopEvent, RerankerValidationDone, RerankerValidationErrorEvent, RerankerBatchEvent]:
        try:
            choice_output_parser = PydanticOutputParser(output_cls=DocumentNumberInfo, pydantic_format_tmpl=PYDANTIC_FORMAT_TMPL)
            choice_program = LLMTextCompletionProgram.from_defaults(
                output_parser=choice_output_parser,
                llm=self.llm,
                prompt_template_str=ev.prompt,
            )
            output = choice_program(context_str=default_format_node_batch_fn(ev.batch_nodes), query_str=ev.query)
            new_nodes = []
            for i in output.tableau_document:
                if i-1 < len(ev.batch_nodes):
                    new_nodes.append(ev.batch_nodes[i-1])
        except Exception as e:
            print("Validation failed, retrying...")
            return RerankerValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        rerank_nodes = new_nodes + ev.rerank_nodes
        # if reranked nodes are less than min chunks, return to the state of batch event, otherwise output the reranked nodes.
        if len(rerank_nodes) > self.min_chunk:
            print("Reranked nodes are extracted.")
            if ev.neg_context == '':
                return StopEvent(result=rerank_nodes)
            else:
                return RerankerValidationDone(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)
        else:
            return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)

    # prepare the prompt to filter the reranked nodes
    @step(pass_context=True)
    async def filter_prompt(
        self, ctx: Context, ev: Union[RerankerValidationDone, FilterValidationErrorEvent]
    ) -> Union[StopEvent, FilterDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        rerank_nodes = ev.rerank_nodes

        if isinstance(ev, RerankerValidationDone):
            reflection_prompt = ""
        elif isinstance(ev, FilterValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            if current_retries >= self.max_retries:
                if len(rerank_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            reflection_prompt = NEGATIVE_FILTER_REFLECTION_PROMPT_STR.format(error=ev.error)
        prompt = NEGATIVE_FILTER_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return FilterDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)

    # call LLM to filter the reranked nodes
    @step(pass_context=True)
    async def filter_validate(
        self, ctx: Context, ev: FilterDone
    ) -> Union[StopEvent, FilterValidationErrorEvent, RerankerBatchEvent]:
        try:
            choice_output_parser = PydanticOutputParser(output_cls=DocumentNumberInfo, pydantic_format_tmpl=PYDANTIC_FORMAT_TMPL)
            filter_program = LLMTextCompletionProgram.from_defaults(
                output_parser=choice_output_parser,
                llm=self.llm,
                prompt_template_str=ev.prompt,
            )
            output = filter_program(context_str=default_format_node_batch_fn(ev.rerank_nodes), query_str=ev.query, contexte_negatif=ev.neg_context)
        except Exception as e:
            print("Validation failed, retrying...")
            return FilterValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        new_nodes = []
        for i in output.tableau_document:
            new_nodes.append(ev.rerank_nodes[i-1])
        print("Reranked nodes are filtered.")
        current_batch = ctx.data.get("batch", 0)
        if len(new_nodes) > self.min_chunk:
            return StopEvent(result=new_nodes)
        else:
            if current_batch >= self.batch_size:
                if len(new_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=new_nodes)
            else:
                return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=new_nodes)

In [67]:
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.indices.keyword_table import KeywordTableGPTRetriever
from typing import List
import threading
import asyncio
from queue import Queue

class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableGPTRetriever,
        embed_model: BedrockEmbedding,
        old_queries: List = [],
        mode: str = "OR",
        
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self.embed_model = embed_model
        self.reranker = RerankerFlow(vector_retriever=vector_retriever)
        self.old_queries = old_queries
        index_feedback = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="finetune/vector_persist_800_feedback"))
        self.feedback_retriever = VectorIndexRetriever(index=index_feedback, similarity_top_k=1)
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async not activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = self._weighted_aggre_embedding(new_query_str)
        vector_nodes = self._vector_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = await self._aweighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)

        feedback_nodes = self.feedback_retriever.retrieve(query_bundle)
        if feedback_nodes[0].score > 0.8 : 
            neg_context = feedback_nodes[0].metadata["message"]
        else:
            neg_context = ''
        vector_nodes = await self.reranker.run(query=query_bundle.query_str, nodes=vector_nodes_full, neg_context=neg_context)
        if vector_nodes == None:
            return keyword_nodes

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    def replace_abbreviations(self, question, keyword_nodes):
        """
        Replaces abbreviations in the lexique with their complete names.
        Assumes abbreviations are separated by either two spaces or one space + one punctuation mark.
        """
        lexique = {}
        for node in keyword_nodes:
            keyword = node.node.metadata['keyword']
            keyword_complete = node.text.split(': ')[1]
            lexique[keyword] = keyword_complete

        words = question.split()
        result = []

        for word in words:
            if word in lexique:
                result.append(lexique[word])
                result.append('('+word+')')
            elif word[:-1] in lexique:
                result.append(lexique[word[:-1]])
                result.append('('+word[:-1]+')')
                result.append(word[-1])
            elif word[1:] in lexique:
                result.append(word[0])
                result.append(lexique[word[1:]])
                result.append('('+word[1:]+')')
            else:
                result.append(word)

        return ' '.join(result)
    
    def _get_query_embedding_threaded(self, query):
        """
        Get the query embedding using a separate thread.
        """
        result_queue = Queue()  # Create a queue to store the result
        thread = threading.Thread(target=self._get_query_embedding_async_thread, args=(query, result_queue))
        thread.start()
        thread.join()  # Wait for the thread to finish
        return result_queue.get()  # Retrieve the result from the queue

    def _get_query_embedding_async_thread(self, query, result_queue):
        """
        Helper method to run the asynchronous call in a separate thread.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(self.embed_model.aget_query_embedding(query))
        loop.close()
        result_queue.put(result)  # Put the result in the queue
    
    def _weighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = self._get_query_embedding_threaded(new_query_str)
        print(current_embedding)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(self._get_query_embedding_threaded(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding
    async def _aweighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = await self.embed_model.aget_query_embedding(new_query_str)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(await self.embed_model.aget_query_embedding(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding

In [68]:
vector_retriever = VectorIndexRetriever(index=index_doc, similarity_top_k=30)
keyword_retriever = KeywordTableGPTRetriever(index=index_keyword, max_keywords_per_query = 20, keyword_extract_template=KEYWORD_EXTRACT_TEMPLATE)
custom_retriever = CustomRetriever(vector_retriever, keyword_retriever, embed_model=Settings.embed_model, old_queries=[])

In [69]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
qa_template = PromptTemplate(DEFAULT_QUERY_PROMPT_TMPL)
query_engine = RetrieverQueryEngine.from_args(retriever=custom_retriever,
                                              text_qa_template=qa_template,
                                              use_async=True,
                                              response_mode = ResponseMode.SIMPLE_SUMMARIZE,
                                              streaming=False
                                             )

In [70]:
import time
start_time = time.time()

response = await query_engine.aquery("Quelle norme de tube acier utiliser pour faire une liaison détente / comptage pour alimenter une chaufferie ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Reranked nodes are filtered.
Step filter_validate produced event StopEvent
Selon les informations fournies, pour faire une liaison détente/comptage pour alimenter une chaufferie, il faut utiliser des tubes en acier conformes aux normes suivantes :

1- Les tubes noirs doivent avoir un r

In [71]:
import time
start_time = time.time()

response = await query_engine.aquery("Quels sont les achats immobilisés pour le biométhane ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
Les informations contextuelles ne mentionnent pas d'achats immobilisés spécifiques au biométhane. Cependant, elles indiquent que le poste réseau d'injection de biométhane comprend différents équipements tels que :

1- Des organes de coupure accessibles depuis le domaine public.
2- Un limiteur de débit dynamique pour moduler l'injection en fonction de la pression amont.
3- Des capteurs de pression et de débit pour l'instrumentation du réseau.
4- Des équipements pour le contrôle de la qualité du gaz, son odorisation, sa pression et la régulation de son débit

In [72]:
import time
start_time = time.time()

response = await query_engine.aquery("Que veut dire CICM ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
CICM signifie "Conduite d'Immeuble, Conduite Montante". C'est un terme utilisé pour désigner les tuyauteries verticales qui alimentent différents niveaux d'un bâtiment d'habitation collectif, ainsi que les conduites reliant ces tuyauteries verticales à la conduite d'immeuble.
Execution Time: 18.33278775215149 seconds


In [73]:
import time
start_time = time.time()

response = await query_engine.aquery("Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Reranked nodes are filtered.
Step filter_validate produced event StopEvent
Oui, la surveillance des réseaux doit inclure le suivi des branchements. Voici les principales actions à entreprendre :

1- Effectuer un contrôle de cohérence entre les données figurant sur les plans et les élém

## Claude Sonnet -- sans sousquestion  -- 15.9 seconds / 4 questions

In [None]:
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.core.indices.utils import (
    default_format_node_batch_fn,
)
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.output_parsers.pydantic import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from pydantic import BaseModel, Field
from typing import Any, List, Optional, Type, Union

SUB_QUESTION_PROMPT_STR = """
Vous êtes un assistant utile qui génère plusieurs requêtes de recherche basées sur une seule requête d'entrée. 
Générer 2 requêtes de recherche, une sur chaque ligne, liées à la requête d'entrée suivante:
Requête: {query}
Requêtes:
"""

PYDANTIC_FORMAT_TMPL = """
Voici un schéma JSON à suivre:
{schema}

Générez un objet JSON valide mais ne répétez pas le schéma.
"""

CHOICE_SELECT_PROMPT_TMPL = """
Une liste de documents est présentée ci-dessous. Chaque document est accompagné d'un numéro et d'un résumé du document. Une question est également fournie.
Donne-moi uniquement les numéros des documents pertinents à la question en utilisant le format du tableau.
Voici quelques exemples: 
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

...

Document 10:
<résumé du document 10>

Question: <question>
Tableau des nombres des documents pertinents: [2, 4, 5, 6]

Essayons ceci maintenant :

{context_str}
Question: {query_str}
Tableau des nombres des documents pertinents:
"""

CHOICE_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents pertinents.
Ne répétez pas le schéma.
"""

NEGATIVE_FILTER_PROMPT_TMPL = """
Voici une liste de contextes pertinents pour une requête donnée. Vous devez filtrer cette liste pour exclure les contextes qui correspondent à un contexte négatif spécifique.
Donne-moi uniquement les numéros des documents en utilisant le format du tableau.
Voici quelques exemples: 
Requête principale : <question>
Contexte négatif : <contexte_negatif>
Liste des contextes :
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

Document 3:
<résumé du document 10>

Document 4:
<résumé du document 10>

Tableau des nombres des documents: [1, 2, 3]

Essayons ceci maintenant :

Requête principale : {query_str}
Contexte négatif : {contexte_negatif}
Liste des contextes : {contexte}
Tableau des nombres des documents:
"""

NEGATIVE_FILTER_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents.
Ne répétez pas le schéma.
"""

class DocumentNumberInfo(BaseModel):
    """Informations concernant un tableau structuré."""
    tableau_document: list = Field(
        ..., description="le tableau des nombres des documents"
    )

class SubQuestionEvent(Event):
    query: str
    neg_context: str
    nodes: list

class RerankerValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerBatchEvent(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list
    batch_nodes: list

class RerankerValidationDone(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerFlow(Workflow):
    def __init__(self, vector_retriever: VectorIndexRetriever, timeout = 120, verbose = True):
        self.llm = Bedrock(model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="eu-west-3")
        self.vector_retriever = vector_retriever
        self.batch_size = 3 # the divided size to rerank all the nodes 
        self.max_retries = 3 # the max retries for the prompt
        self.min_chunk = 10 # the minimum size of the chunks after the rerank
        super().__init__(timeout=timeout, verbose=verbose)

    # the state to generate sub questions and put inside nodes
    @step()
    async def sub_question(
        self, ev: StartEvent
    ) -> SubQuestionEvent:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        # subquestions = []
        # try:
        #     subquestions_response = Settings.llm.complete(SUB_QUESTION_PROMPT_STR.format(query=query))
        #     subquestions = subquestions_response.text.split("\n")
        # except Exception as e:
        #     print(e)
        # if len(subquestions) > 2:
        #     nodes1 = self.vector_retriever.retrieve(subquestions[-2])
        #     nodes2 = self.vector_retriever.retrieve(subquestions[-1])
        #     one_third_index = int(len(nodes)/3)
        #     two_third_index = int(len(nodes)/3*2)
        #     nodes_ids = []
        #     for node in nodes:
        #         nodes_ids.append(node.id_)
        #     indicator1 = 0
        #     indicator2 = 0
        #     for i in range(len(nodes1)):
        #         if nodes1[i].id_ not in nodes_ids:
        #             indicator1 += 1
        #             nodes_ids.append(nodes1[i].id_)
        #             if indicator1 == 1:
        #                 nodes.insert(one_third_index, nodes1[i])
        #             elif indicator1 == 2:
        #                 nodes.insert(one_third_index, nodes1[i])
        #             elif indicator1 == 3:
        #                 nodes.insert(two_third_index, nodes1[i])
        #             elif indicator1 == 4:
        #                 nodes.insert(two_third_index, nodes1[i])
        #             elif indicator1 == 5:
        #                 nodes.append(nodes1[i])
        #             elif indicator1 == 6:
        #                 nodes.append(nodes1[i])
        #                 break
        #     for i in range(len(nodes2)):
        #         if nodes2[i].id_ not in nodes_ids:
        #             indicator2 += 1
        #             nodes_ids.append(nodes2[i].id_)
        #             if indicator2 == 1:
        #                 nodes.insert(one_third_index, nodes2[i])
        #             elif indicator2 == 2:
        #                 nodes.insert(one_third_index, nodes2[i])
        #             elif indicator2 == 3:
        #                 nodes.insert(two_third_index, nodes2[i])
        #             elif indicator2 == 4:
        #                 nodes.insert(two_third_index, nodes2[i])
        #             elif indicator2 == 5:
        #                 nodes.append(nodes2[i])
        #             elif indicator2 == 6:
        #                 nodes.append(nodes2[i])
        #                 break
        return SubQuestionEvent(query=query, neg_context=neg_context, nodes=nodes)

    # the state to generate prompt for reranking
    @step(pass_context=True)
    async def reranker_prompt(
        self, ctx: Context, ev: Union[SubQuestionEvent, RerankerValidationErrorEvent, RerankerBatchEvent]
    ) -> Union[StopEvent, RerankerDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        window = int(len(nodes)/3)

        if isinstance(ev, SubQuestionEvent):
            rerank_nodes = []
            current_batch = ctx.data.get("batch", 0)
            ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        elif isinstance(ev, RerankerBatchEvent):
            rerank_nodes = ev.rerank_nodes
            current_batch = ctx.data.get("batch", 0)
            # iterate the nodes based on batch size
            if current_batch >= self.batch_size:
                print("Reranked nodes are extracted.")
                # if no negative context, output reranked nodes. otherwise, go the state of filter
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        # prompt error event
        elif isinstance(ev, RerankerValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            rerank_nodes = ev.rerank_nodes
            if current_retries >= self.max_retries:
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            current_batch = ctx.data.get("batch", 0)
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = CHOICE_REFLECTION_PROMPT_STR.format(error=ev.error)
        
        prompt = CHOICE_SELECT_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return RerankerDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes, batch_nodes=batch_nodes)

    # call LLM to rerank nodes
    @step()
    async def rerank_validate(
        self, ev: RerankerDone
    ) -> Union[StopEvent, RerankerValidationDone, RerankerValidationErrorEvent, RerankerBatchEvent]:
        try:
            choice_output_parser = PydanticOutputParser(output_cls=DocumentNumberInfo, pydantic_format_tmpl=PYDANTIC_FORMAT_TMPL)
            choice_program = LLMTextCompletionProgram.from_defaults(
                output_parser=choice_output_parser,
                llm=self.llm,
                prompt_template_str=ev.prompt,
            )
            output = choice_program(context_str=default_format_node_batch_fn(ev.batch_nodes), query_str=ev.query)
            new_nodes = []
            for i in output.tableau_document:
                if i-1 < len(ev.batch_nodes):
                    new_nodes.append(ev.batch_nodes[i-1])
        except Exception as e:
            print("Validation failed, retrying...")
            return RerankerValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        rerank_nodes = new_nodes + ev.rerank_nodes
        # if reranked nodes are less than min chunks, return to the state of batch event, otherwise output the reranked nodes.
        if len(rerank_nodes) > self.min_chunk:
            print("Reranked nodes are extracted.")
            if ev.neg_context == '':
                return StopEvent(result=rerank_nodes)
            else:
                return RerankerValidationDone(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)
        else:
            return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)

    # prepare the prompt to filter the reranked nodes
    @step(pass_context=True)
    async def filter_prompt(
        self, ctx: Context, ev: Union[RerankerValidationDone, FilterValidationErrorEvent]
    ) -> Union[StopEvent, FilterDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        rerank_nodes = ev.rerank_nodes

        if isinstance(ev, RerankerValidationDone):
            reflection_prompt = ""
        elif isinstance(ev, FilterValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            if current_retries >= self.max_retries:
                if len(rerank_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            reflection_prompt = NEGATIVE_FILTER_REFLECTION_PROMPT_STR.format(error=ev.error)
        prompt = NEGATIVE_FILTER_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return FilterDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)

    # call LLM to filter the reranked nodes
    @step(pass_context=True)
    async def filter_validate(
        self, ctx: Context, ev: FilterDone
    ) -> Union[StopEvent, FilterValidationErrorEvent, RerankerBatchEvent]:
        try:
            choice_output_parser = PydanticOutputParser(output_cls=DocumentNumberInfo, pydantic_format_tmpl=PYDANTIC_FORMAT_TMPL)
            filter_program = LLMTextCompletionProgram.from_defaults(
                output_parser=choice_output_parser,
                llm=self.llm,
                prompt_template_str=ev.prompt,
            )
            output = filter_program(context_str=default_format_node_batch_fn(ev.rerank_nodes), query_str=ev.query, contexte_negatif=ev.neg_context)
        except Exception as e:
            print("Validation failed, retrying...")
            return FilterValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        new_nodes = []
        for i in output.tableau_document:
            new_nodes.append(ev.rerank_nodes[i-1])
        print("Reranked nodes are filtered.")
        current_batch = ctx.data.get("batch", 0)
        if len(new_nodes) > self.min_chunk:
            return StopEvent(result=new_nodes)
        else:
            if current_batch >= self.batch_size:
                if len(new_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=new_nodes)
            else:
                return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=new_nodes)

In [None]:
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.indices.keyword_table import KeywordTableGPTRetriever
from typing import List
import threading
import asyncio
from queue import Queue

class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableGPTRetriever,
        embed_model: BedrockEmbedding,
        old_queries: List = [],
        mode: str = "OR",
        
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self.embed_model = embed_model
        self.reranker = RerankerFlow(vector_retriever=vector_retriever)
        self.old_queries = old_queries
        index_feedback = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="finetune/vector_persist_800_feedback"))
        self.feedback_retriever = VectorIndexRetriever(index=index_feedback, similarity_top_k=1)
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async not activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = self._weighted_aggre_embedding(new_query_str)
        vector_nodes = self._vector_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = await self._aweighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)

        feedback_nodes = self.feedback_retriever.retrieve(query_bundle)
        if feedback_nodes[0].score > 0.8 : 
            neg_context = feedback_nodes[0].metadata["message"]
        else:
            neg_context = ''
        vector_nodes = await self.reranker.run(query=query_bundle.query_str, nodes=vector_nodes_full, neg_context=neg_context)
        if vector_nodes == None:
            return keyword_nodes

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    def replace_abbreviations(self, question, keyword_nodes):
        """
        Replaces abbreviations in the lexique with their complete names.
        Assumes abbreviations are separated by either two spaces or one space + one punctuation mark.
        """
        lexique = {}
        for node in keyword_nodes:
            keyword = node.node.metadata['keyword']
            keyword_complete = node.text.split(': ')[1]
            lexique[keyword] = keyword_complete

        words = question.split()
        result = []

        for word in words:
            if word in lexique:
                result.append(lexique[word])
                result.append('('+word+')')
            elif word[:-1] in lexique:
                result.append(lexique[word[:-1]])
                result.append('('+word[:-1]+')')
                result.append(word[-1])
            elif word[1:] in lexique:
                result.append(word[0])
                result.append(lexique[word[1:]])
                result.append('('+word[1:]+')')
            else:
                result.append(word)

        return ' '.join(result)
    
    def _get_query_embedding_threaded(self, query):
        """
        Get the query embedding using a separate thread.
        """
        result_queue = Queue()  # Create a queue to store the result
        thread = threading.Thread(target=self._get_query_embedding_async_thread, args=(query, result_queue))
        thread.start()
        thread.join()  # Wait for the thread to finish
        return result_queue.get()  # Retrieve the result from the queue

    def _get_query_embedding_async_thread(self, query, result_queue):
        """
        Helper method to run the asynchronous call in a separate thread.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(self.embed_model.aget_query_embedding(query))
        loop.close()
        result_queue.put(result)  # Put the result in the queue
    
    def _weighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = self._get_query_embedding_threaded(new_query_str)
        print(current_embedding)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(self._get_query_embedding_threaded(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding
    async def _aweighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = await self.embed_model.aget_query_embedding(new_query_str)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(await self.embed_model.aget_query_embedding(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding

In [14]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
qa_template = PromptTemplate(DEFAULT_QUERY_PROMPT_TMPL)
query_engine = RetrieverQueryEngine.from_args(retriever=custom_retriever,
                                              text_qa_template=qa_template,
                                              use_async=True,
                                              response_mode = ResponseMode.SIMPLE_SUMMARIZE,
                                              streaming=False
                                             )

In [124]:
# with sub_question
import time
start_time = time.time()

response = await query_engine.aquery("Quelle norme de tube acier utiliser pour faire une liaison détente / comptage pour alimenter une chaufferie ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Reranked nodes are filtered.
Step filter_validate produced event StopEvent
Selon les informations fournies, pour réaliser une liaison détente/comptage pour alimenter une chaufferie, les normes de tubes acier à utiliser sont les suivantes :

1- Les tubes acier doivent être conformes aux

In [15]:
# without sub_question
import time
start_time = time.time()

response = await query_engine.aquery("Quelle norme de tube acier utiliser pour faire une liaison détente / comptage pour alimenter une chaufferie ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Reranked nodes are extracted.
Step reranker_prompt produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Reranked nodes are filtered.
Step filter_validate produced event StopEvent
Selon les informations fournies, pour les tubes en acier, les raccords doivent être conformes aux spécifications ATG B 521. De 

In [153]:
import time
start_time = time.time()

response = await query_engine.aquery("Quels sont les achats immobilisés pour le biométhane ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
Il n'y a pas d'informations spécifiques sur les achats immobilisés pour le biométhane dans le contexte fourni. Les documents traitent principalement des aspects techniques et réglementaires liés à l'injection de biométhane dans le réseau de distribution de gaz naturel. Ils abordent des sujets tels que :

- Les caractéristiques techniques des postes d'injection de biométhane, notamment les systèmes

In [154]:
import time
start_time = time.time()

response = await query_engine.aquery("Que veut dire CICM ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
CICM signifie "Conduite d'Immeuble, Conduite Montante". C'est un terme utilisé pour désigner les tuyauteries verticales qui alimentent différents niveaux d'un bâtiment d'habitation collectif, ainsi que les conduites d'immeuble qui raccordent ces tuyauteries verticales à la conduite principale.
Execution Time: 12.03987431526184 seconds


In [155]:
import time
start_time = time.time()

response = await query_engine.aquery("Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Reranked nodes are filtered.
Step filter_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Reranked nodes are extracted.
Step reranker_prompt produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step

## Claude haiku -- avec sousquestion  -- 22.3 seconds / 4 questions

In [39]:
llm_sonnet = Bedrock(model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="eu-west-3")
Response_haiku = Settings.llm.complete(CHOICE_SELECT_PROMPT_TMPL.format(context_str=default_format_node_batch_fn(nodes), query_str="Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?"))
print(Response_haiku.text)
print("======================")
Response_sonnet = llm_sonnet.complete(CHOICE_SELECT_PROMPT_TMPL.format(context_str=default_format_node_batch_fn(nodes), query_str="Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?"))
print(Response_sonnet.text)

Selon les informations fournies, les documents pertinents à la question "Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?" sont :

[6, 7]

Le document 6 traite du branchement particulier et indique que chaque branchement particulier doit être muni d'un organe de coupure individuelle (OCI) qui doit être accessible, bien signalé, facilement manœuvrable et identifié. Cela suggère que la surveillance des branchements fait partie de la surveillance des réseaux.

Le document 7 décrit la conduite d'immeuble, qui fait suite au branchement d'immeuble collectif. Cela indique également que les branchements font partie intégrante du réseau à surveiller.
D'après les résumés des documents fournis, les documents pertinents pour répondre à la question "Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?" sont les suivants :

[6, 17]

Le document 6 mentionne que "Le branchement particulier est la partie d'ouvrage située immédiatement en

In [40]:
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
import re
from llama_index.core.indices.utils import (
    default_format_node_batch_fn,
)
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.output_parsers.pydantic import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from pydantic import BaseModel, Field
from typing import Any, List, Optional, Type, Union

SUB_QUESTION_PROMPT_STR = """
Vous êtes un assistant utile qui génère plusieurs requêtes de recherche basées sur une seule requête d'entrée. 
Générer 3 requêtes de recherche, une sur chaque ligne, liées à la requête d'entrée suivante:
Requête: {query}
Requêtes:
"""

CHOICE_SELECT_PROMPT_TMPL = """
Une liste de documents est présentée ci-dessous. Chaque document est accompagné d'un numéro et d'un résumé du document. Une question est également fournie.
Donne-moi uniquement les numéros des documents pertinents à la question en utilisant le format du tableau.
Voici quelques exemples: 
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

...

Document 10:
<résumé du document 10>

Question: <question>
Tableau des nombres des documents pertinents: [2, 4, 5, 6]

Essayons ceci maintenant :

{context_str}
Question: {query_str}
Tableau des nombres des documents pertinents:
"""

CHOICE_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents pertinents.
Ne répétez pas le schéma.
"""

NEGATIVE_FILTER_PROMPT_TMPL = """
Voici une liste de contextes pertinents pour une requête donnée. Vous devez filtrer cette liste pour exclure les contextes qui correspondent à un contexte négatif spécifique.
Donne-moi uniquement les numéros des documents en utilisant le format du tableau.
Voici quelques exemples: 
Requête principale : <question>
Contexte négatif : <contexte_negatif>
Liste des contextes :
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

Document 3:
<résumé du document 10>

Document 4:
<résumé du document 10>

Tableau des nombres des documents: [1, 2, 3]

Essayons ceci maintenant :

Requête principale : {query_str}
Contexte négatif : {contexte_negatif}
Liste des contextes : {contexte}
Tableau des nombres des documents:
"""

NEGATIVE_FILTER_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents.
Ne répétez pas le schéma.
"""

class SubQuestionEvent(Event):
    query: str
    neg_context: str
    nodes: list

class RerankerValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerBatchEvent(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list
    batch_nodes: list

class RerankerValidationDone(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerFlow(Workflow):
    def __init__(self, vector_retriever: VectorIndexRetriever, timeout = 120, verbose = True):
        self.vector_retriever = vector_retriever
        self.batch_size = 3 # the divided size to rerank all the nodes 
        self.max_retries = 3 # the max retries for the prompt
        self.min_chunk = 10 # the minimum size of the chunks after the rerank
        super().__init__(timeout=timeout, verbose=verbose)

    # the state to generate sub questions and put inside nodes
    @step()
    async def sub_question(
        self, ev: StartEvent
    ) -> SubQuestionEvent:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        subquestions = []
        try:
            subquestions_response = Settings.llm.complete(SUB_QUESTION_PROMPT_STR.format(query=query))
            subquestions = subquestions_response.text.split("\n")
        except Exception as e:
            print(e)
        if len(subquestions) > 3:
            nodes1 = self.vector_retriever.retrieve(subquestions[-3])
            nodes2 = self.vector_retriever.retrieve(subquestions[-2])
            nodes3 = self.vector_retriever.retrieve(subquestions[-1])
            one_third_index = int(len(nodes)/3)
            two_third_index = int(len(nodes)/3*2)
            nodes_ids = []
            for node in nodes:
                nodes_ids.append(node.id_)
            indicator1 = 0
            indicator2 = 0
            indicator3 = 0
            for i in range(len(nodes1)):
                if nodes1[i].id_ not in nodes_ids:
                    indicator1 += 1
                    nodes_ids.append(nodes1[i].id_)
                    if indicator1 == 1:
                        nodes.insert(one_third_index, nodes1[i])
                    elif indicator1 == 2:
                        nodes.insert(one_third_index, nodes1[i])
                    elif indicator1 == 3:
                        nodes.insert(two_third_index, nodes1[i])
                    elif indicator1 == 4:
                        nodes.insert(two_third_index, nodes1[i])
                    elif indicator1 == 5:
                        nodes.append(nodes1[i])
                    elif indicator1 == 6:
                        nodes.append(nodes1[i])
                        break
            for i in range(len(nodes2)):
                if nodes2[i].id_ not in nodes_ids:
                    indicator2 += 1
                    nodes_ids.append(nodes2[i].id_)
                    if indicator2 == 1:
                        nodes.insert(one_third_index, nodes2[i])
                    elif indicator2 == 2:
                        nodes.insert(one_third_index, nodes2[i])
                    elif indicator2 == 3:
                        nodes.insert(two_third_index, nodes2[i])
                    elif indicator2 == 4:
                        nodes.insert(two_third_index, nodes2[i])
                    elif indicator2 == 5:
                        nodes.append(nodes2[i])
                    elif indicator2 == 6:
                        nodes.append(nodes2[i])
                        break
            for i in range(len(nodes3)):
                if nodes3[i].id_ not in nodes_ids:
                    indicator3 += 1
                    nodes_ids.append(nodes3[i].id_)
                    if indicator3 == 1:
                        nodes.insert(one_third_index, nodes3[i])
                    elif indicator3 == 2:
                        nodes.insert(one_third_index, nodes3[i])
                    elif indicator3 == 3:
                        nodes.insert(two_third_index, nodes3[i])
                    elif indicator3 == 4:
                        nodes.insert(two_third_index, nodes3[i])
                    elif indicator3 == 5:
                        nodes.append(nodes3[i])
                    elif indicator3 == 6:
                        nodes.append(nodes3[i])
                        break
        return SubQuestionEvent(query=query, neg_context=neg_context, nodes=nodes)

    # the state to generate prompt for reranking
    @step(pass_context=True)
    async def reranker_prompt(
        self, ctx: Context, ev: Union[SubQuestionEvent, RerankerValidationErrorEvent, RerankerBatchEvent]
    ) -> Union[StopEvent, RerankerDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        window = int(len(nodes)/3)

        if isinstance(ev, SubQuestionEvent):
            rerank_nodes = []
            current_batch = ctx.data.get("batch", 0)
            ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        elif isinstance(ev, RerankerBatchEvent):
            rerank_nodes = ev.rerank_nodes
            current_batch = ctx.data.get("batch", 0)
            # iterate the nodes based on batch size
            if current_batch >= self.batch_size:
                print("Reranked nodes are extracted.")
                # if no negative context, output reranked nodes. otherwise, go the state of filter
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        # prompt error event
        elif isinstance(ev, RerankerValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            rerank_nodes = ev.rerank_nodes
            if current_retries >= self.max_retries:
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            current_batch = ctx.data.get("batch", 0)
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = CHOICE_REFLECTION_PROMPT_STR.format(error=ev.error)
        
        prompt = CHOICE_SELECT_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return RerankerDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes, batch_nodes=batch_nodes)

    # call LLM to rerank nodes
    @step()
    async def rerank_validate(
        self, ev: RerankerDone
    ) -> Union[StopEvent, RerankerValidationDone, RerankerValidationErrorEvent, RerankerBatchEvent]:
        try:
            Response_haiku = Settings.llm.complete(ev.prompt.format(context_str=default_format_node_batch_fn(ev.batch_nodes), query_str=ev.query))
            match = re.search(r'\[(.*?)\]', Response_haiku.text)
            if match:
                output = list(map(int, match.group(1).split(',')))
            else:
                output = []
            new_nodes = []
            for i in output:
                if i-1 < len(ev.batch_nodes):
                    new_nodes.append(ev.batch_nodes[i-1])
        except Exception as e:
            print("Validation failed, retrying...")
            return RerankerValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        rerank_nodes = new_nodes + ev.rerank_nodes
        # if reranked nodes are less than min chunks, return to the state of batch event, otherwise output the reranked nodes.
        if len(rerank_nodes) > self.min_chunk:
            print("Reranked nodes are extracted.")
            if ev.neg_context == '':
                return StopEvent(result=rerank_nodes)
            else:
                return RerankerValidationDone(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)
        else:
            return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)

    # prepare the prompt to filter the reranked nodes
    @step(pass_context=True)
    async def filter_prompt(
        self, ctx: Context, ev: Union[RerankerValidationDone, FilterValidationErrorEvent]
    ) -> Union[StopEvent, FilterDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        rerank_nodes = ev.rerank_nodes

        if isinstance(ev, RerankerValidationDone):
            reflection_prompt = ""
        elif isinstance(ev, FilterValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            if current_retries >= self.max_retries:
                if len(rerank_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            reflection_prompt = NEGATIVE_FILTER_REFLECTION_PROMPT_STR.format(error=ev.error)
        prompt = NEGATIVE_FILTER_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return FilterDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)

    # call LLM to filter the reranked nodes
    @step(pass_context=True)
    async def filter_validate(
        self, ctx: Context, ev: FilterDone
    ) -> Union[StopEvent, FilterValidationErrorEvent, RerankerBatchEvent]:
        try:
            Response_haiku = Settings.llm.complete(ev.prompt.format(context_str=default_format_node_batch_fn(ev.rerank_nodes), query_str=ev.query, contexte_negatif=ev.neg_context))
            match = re.search(r'\[(.*?)\]', Response_haiku.text)
            if match:
                output = list(map(int, match.group(1).split(',')))
            else:
                output = []
        except Exception as e:
            print("Validation failed, retrying...")
            return FilterValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        new_nodes = []
        for i in output:
            new_nodes.append(ev.rerank_nodes[i-1])
        print("Reranked nodes are filtered.")
        current_batch = ctx.data.get("batch", 0)
        if len(new_nodes) > self.min_chunk:
            return StopEvent(result=new_nodes)
        else:
            if current_batch >= self.batch_size:
                if len(new_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=new_nodes)
            else:
                return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=new_nodes)

In [41]:
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.indices.keyword_table import KeywordTableGPTRetriever
from typing import List
import threading
import asyncio
from queue import Queue

class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableGPTRetriever,
        embed_model: BedrockEmbedding,
        old_queries: List = [],
        mode: str = "OR",
        
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self.embed_model = embed_model
        self.reranker = RerankerFlow(vector_retriever=vector_retriever)
        self.old_queries = old_queries
        index_feedback = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="finetune/vector_persist_800_feedback"))
        self.feedback_retriever = VectorIndexRetriever(index=index_feedback, similarity_top_k=1)
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async not activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = self._weighted_aggre_embedding(new_query_str)
        vector_nodes = self._vector_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = await self._aweighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)

        feedback_nodes = self.feedback_retriever.retrieve(query_bundle)
        if feedback_nodes[0].score > 0.8 : 
            neg_context = feedback_nodes[0].metadata["message"]
        else:
            neg_context = ''
        vector_nodes = await self.reranker.run(query=query_bundle.query_str, nodes=vector_nodes_full, neg_context=neg_context)
        if vector_nodes == None:
            return keyword_nodes

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    def replace_abbreviations(self, question, keyword_nodes):
        """
        Replaces abbreviations in the lexique with their complete names.
        Assumes abbreviations are separated by either two spaces or one space + one punctuation mark.
        """
        lexique = {}
        for node in keyword_nodes:
            keyword = node.node.metadata['keyword']
            keyword_complete = node.text.split(': ')[1]
            lexique[keyword] = keyword_complete

        words = question.split()
        result = []

        for word in words:
            if word in lexique:
                result.append(lexique[word])
                result.append('('+word+')')
            elif word[:-1] in lexique:
                result.append(lexique[word[:-1]])
                result.append('('+word[:-1]+')')
                result.append(word[-1])
            elif word[1:] in lexique:
                result.append(word[0])
                result.append(lexique[word[1:]])
                result.append('('+word[1:]+')')
            else:
                result.append(word)

        return ' '.join(result)
    
    def _get_query_embedding_threaded(self, query):
        """
        Get the query embedding using a separate thread.
        """
        result_queue = Queue()  # Create a queue to store the result
        thread = threading.Thread(target=self._get_query_embedding_async_thread, args=(query, result_queue))
        thread.start()
        thread.join()  # Wait for the thread to finish
        return result_queue.get()  # Retrieve the result from the queue

    def _get_query_embedding_async_thread(self, query, result_queue):
        """
        Helper method to run the asynchronous call in a separate thread.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(self.embed_model.aget_query_embedding(query))
        loop.close()
        result_queue.put(result)  # Put the result in the queue
    
    def _weighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = self._get_query_embedding_threaded(new_query_str)
        print(current_embedding)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(self._get_query_embedding_threaded(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding
    async def _aweighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = await self.embed_model.aget_query_embedding(new_query_str)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(await self.embed_model.aget_query_embedding(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding

In [42]:
vector_retriever = VectorIndexRetriever(index=index_doc, similarity_top_k=30)
keyword_retriever = KeywordTableGPTRetriever(index=index_keyword, max_keywords_per_query = 20, keyword_extract_template=KEYWORD_EXTRACT_TEMPLATE)
custom_retriever = CustomRetriever(vector_retriever, keyword_retriever, embed_model=Settings.embed_model, old_queries=[])

In [43]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
qa_template = PromptTemplate(DEFAULT_QUERY_PROMPT_TMPL)
query_engine = RetrieverQueryEngine.from_args(retriever=custom_retriever,
                                              text_qa_template=qa_template,
                                              use_async=True,
                                              response_mode = ResponseMode.SIMPLE_SUMMARIZE,
                                              streaming=False
                                             )

In [44]:
import time
start_time = time.time()

response = await query_engine.aquery("Quelle norme de tube acier utiliser pour faire une liaison détente / comptage pour alimenter une chaufferie ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Reranked nodes are extracted.
Step reranker_prompt produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validation failed, retrying...
Step filter_validate produced event FilterValidationErrorEvent
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validat

In [45]:
import time
start_time = time.time()

response = await query_engine.aquery("Quels sont les achats immobilisés pour le biométhane ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
Le contexte fourni ne mentionne pas d'achats immobilisés pour le biométhane. Les informations se concentrent plutôt sur les spécifications techniques et le fonctionnement des postes d'injection de biométhane sur le réseau de distribution de gaz naturel. Le texte décrit en détail les différents équipements et systèmes qui composent ces postes, tels que :

- Le système d'odorisation du biométhane, a

In [46]:
import time
start_time = time.time()

response = await query_engine.aquery("Que veut dire CICM ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event StopEvent
CICM signifie "Conduite d'Immeuble, Conduite Montante". C'est un terme utilisé pour désigner les canalisations de gaz qui alimentent les immeubles, notamment la conduite d'immeuble située à l'intérieur du bâtiment et la conduite montante qui traverse les différents étages.
Execution Time: 16.471271991729736 seconds


In [47]:
import time
start_time = time.time()

response = await query_engine.aquery("Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validation failed, retrying...
Step filter_validate produced event FilterValidationErrorEvent
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validation failed, retrying...
Step filter_validate produced event FilterValidationErrorEvent
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validation failed, retrying...
Step filter_validate produc

## Claude haiku -- sans sousquestion  -- 11.35 seconds / 4 questions

In [13]:
from llama_index.core.workflow import Event
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
import re
from llama_index.core.indices.utils import (
    default_format_node_batch_fn,
)
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.output_parsers.pydantic import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from pydantic import BaseModel, Field
from typing import Any, List, Optional, Type, Union

SUB_QUESTION_PROMPT_STR = """
Vous êtes un assistant utile qui génère plusieurs requêtes de recherche basées sur une seule requête d'entrée. 
Générer 3 requêtes de recherche, une sur chaque ligne, liées à la requête d'entrée suivante:
Requête: {query}
Requêtes:
"""

CHOICE_SELECT_PROMPT_TMPL = """
Une liste de documents est présentée ci-dessous. Chaque document est accompagné d'un numéro et d'un résumé du document. Une question est également fournie.
Donne-moi uniquement les numéros des documents pertinents à la question en utilisant le format du tableau.
Voici quelques exemples: 
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

...

Document 10:
<résumé du document 10>

Question: <question>
Tableau des nombres des documents pertinents: [2, 4, 5, 6]

Essayons ceci maintenant :

{context_str}
Question: {query_str}
Tableau des nombres des documents pertinents:
"""

CHOICE_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents pertinents.
Ne répétez pas le schéma.
"""

NEGATIVE_FILTER_PROMPT_TMPL = """
Voici une liste de contextes pertinents pour une requête donnée. Vous devez filtrer cette liste pour exclure les contextes qui correspondent à un contexte négatif spécifique.
Donne-moi uniquement les numéros des documents en utilisant le format du tableau.
Voici quelques exemples: 
Requête principale : <question>
Contexte négatif : <contexte_negatif>
Liste des contextes :
Document 1:
<résumé du document 1>

Document 2:
<résumé du document 2>

Document 3:
<résumé du document 10>

Document 4:
<résumé du document 10>

Tableau des nombres des documents: [1, 2, 3]

Essayons ceci maintenant :

Requête principale : {query_str}
Contexte négatif : {contexte_negatif}
Liste des contextes : {contexte}
Tableau des nombres des documents:
"""

NEGATIVE_FILTER_REFLECTION_PROMPT_STR = """
Cela a provoqué l'erreur : {error}

Réessayez, la réponse doit contenir uniquement les numéros des documents.
Ne répétez pas le schéma.
"""

class SubQuestionEvent(Event):
    query: str
    neg_context: str
    nodes: list

class RerankerValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerBatchEvent(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list
    batch_nodes: list

class RerankerValidationDone(Event):
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterValidationErrorEvent(Event):
    error: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class FilterDone(Event):
    prompt: str
    query: str
    neg_context: str
    nodes: list
    rerank_nodes: list

class RerankerFlow(Workflow):
    def __init__(self, vector_retriever: VectorIndexRetriever, timeout = 120, verbose = True):
        self.vector_retriever = vector_retriever
        self.batch_size = 3 # the divided size to rerank all the nodes 
        self.max_retries = 3 # the max retries for the prompt
        self.min_chunk = 10 # the minimum size of the chunks after the rerank
        super().__init__(timeout=timeout, verbose=verbose)

    # the state to generate sub questions and put inside nodes
    @step()
    async def sub_question(
        self, ev: StartEvent
    ) -> SubQuestionEvent:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        # subquestions = []
        # try:
        #     subquestions_response = Settings.llm.complete(SUB_QUESTION_PROMPT_STR.format(query=query))
        #     subquestions = subquestions_response.text.split("\n")
        # except Exception as e:
        #     print(e)
        # if len(subquestions) > 3:
        #     nodes1 = self.vector_retriever.retrieve(subquestions[-3])
        #     nodes2 = self.vector_retriever.retrieve(subquestions[-2])
        #     nodes3 = self.vector_retriever.retrieve(subquestions[-1])
        #     one_third_index = int(len(nodes)/3)
        #     two_third_index = int(len(nodes)/3*2)
        #     nodes_ids = []
        #     for node in nodes:
        #         nodes_ids.append(node.id_)
        #     indicator1 = 0
        #     indicator2 = 0
        #     indicator3 = 0
        #     for i in range(len(nodes1)):
        #         if nodes1[i].id_ not in nodes_ids:
        #             indicator1 += 1
        #             nodes_ids.append(nodes1[i].id_)
        #             if indicator1 == 1:
        #                 nodes.insert(one_third_index, nodes1[i])
        #             elif indicator1 == 2:
        #                 nodes.insert(one_third_index, nodes1[i])
        #             elif indicator1 == 3:
        #                 nodes.insert(two_third_index, nodes1[i])
        #             elif indicator1 == 4:
        #                 nodes.insert(two_third_index, nodes1[i])
        #             elif indicator1 == 5:
        #                 nodes.append(nodes1[i])
        #             elif indicator1 == 6:
        #                 nodes.append(nodes1[i])
        #                 break
        #     for i in range(len(nodes2)):
        #         if nodes2[i].id_ not in nodes_ids:
        #             indicator2 += 1
        #             nodes_ids.append(nodes2[i].id_)
        #             if indicator2 == 1:
        #                 nodes.insert(one_third_index, nodes2[i])
        #             elif indicator2 == 2:
        #                 nodes.insert(one_third_index, nodes2[i])
        #             elif indicator2 == 3:
        #                 nodes.insert(two_third_index, nodes2[i])
        #             elif indicator2 == 4:
        #                 nodes.insert(two_third_index, nodes2[i])
        #             elif indicator2 == 5:
        #                 nodes.append(nodes2[i])
        #             elif indicator2 == 6:
        #                 nodes.append(nodes2[i])
        #                 break
        #     for i in range(len(nodes3)):
        #         if nodes3[i].id_ not in nodes_ids:
        #             indicator3 += 1
        #             nodes_ids.append(nodes3[i].id_)
        #             if indicator3 == 1:
        #                 nodes.insert(one_third_index, nodes3[i])
        #             elif indicator3 == 2:
        #                 nodes.insert(one_third_index, nodes3[i])
        #             elif indicator3 == 3:
        #                 nodes.insert(two_third_index, nodes3[i])
        #             elif indicator3 == 4:
        #                 nodes.insert(two_third_index, nodes3[i])
        #             elif indicator3 == 5:
        #                 nodes.append(nodes3[i])
        #             elif indicator3 == 6:
        #                 nodes.append(nodes3[i])
        #                 break
        return SubQuestionEvent(query=query, neg_context=neg_context, nodes=nodes)

    # the state to generate prompt for reranking
    @step(pass_context=True)
    async def reranker_prompt(
        self, ctx: Context, ev: Union[SubQuestionEvent, RerankerValidationErrorEvent, RerankerBatchEvent]
    ) -> Union[StopEvent, RerankerDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        window = int(len(nodes)/3)

        if isinstance(ev, SubQuestionEvent):
            rerank_nodes = []
            current_batch = ctx.data.get("batch", 0)
            ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        elif isinstance(ev, RerankerBatchEvent):
            rerank_nodes = ev.rerank_nodes
            current_batch = ctx.data.get("batch", 0)
            # iterate the nodes based on batch size
            if current_batch >= self.batch_size:
                print("Reranked nodes are extracted.")
                # if no negative context, output reranked nodes. otherwise, go the state of filter
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["batch"] = current_batch + 1
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = ""
        # prompt error event
        elif isinstance(ev, RerankerValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            rerank_nodes = ev.rerank_nodes
            if current_retries >= self.max_retries:
                if neg_context == '':
                    if len(rerank_nodes) == 0:
                        return StopEvent(result=None)
                    else:
                        return StopEvent(result=rerank_nodes)
                else:
                    return RerankerValidationDone(query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            current_batch = ctx.data.get("batch", 0)
            batch_nodes = nodes[window*current_batch:window*(current_batch+1)]
            reflection_prompt = CHOICE_REFLECTION_PROMPT_STR.format(error=ev.error)
        
        prompt = CHOICE_SELECT_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return RerankerDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes, batch_nodes=batch_nodes)

    # call LLM to rerank nodes
    @step()
    async def rerank_validate(
        self, ev: RerankerDone
    ) -> Union[StopEvent, RerankerValidationDone, RerankerValidationErrorEvent, RerankerBatchEvent]:
        try:
            Response_haiku = Settings.llm.complete(ev.prompt.format(context_str=default_format_node_batch_fn(ev.batch_nodes), query_str=ev.query))
            match = re.search(r'\[(.*?)\]', Response_haiku.text)
            if match:
                output = list(map(int, match.group(1).split(',')))
            else:
                output = []
            new_nodes = []
            for i in output:
                if i-1 < len(ev.batch_nodes):
                    new_nodes.append(ev.batch_nodes[i-1])
        except Exception as e:
            print("Validation failed, retrying...")
            return RerankerValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        rerank_nodes = new_nodes + ev.rerank_nodes
        # if reranked nodes are less than min chunks, return to the state of batch event, otherwise output the reranked nodes.
        if len(rerank_nodes) > self.min_chunk:
            print("Reranked nodes are extracted.")
            if ev.neg_context == '':
                return StopEvent(result=rerank_nodes)
            else:
                return RerankerValidationDone(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)
        else:
            return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=rerank_nodes)

    # prepare the prompt to filter the reranked nodes
    @step(pass_context=True)
    async def filter_prompt(
        self, ctx: Context, ev: Union[RerankerValidationDone, FilterValidationErrorEvent]
    ) -> Union[StopEvent, FilterDone]:
        query = ev.query
        nodes = ev.nodes
        neg_context = ev.neg_context
        rerank_nodes = ev.rerank_nodes

        if isinstance(ev, RerankerValidationDone):
            reflection_prompt = ""
        elif isinstance(ev, FilterValidationErrorEvent):
            current_retries = ctx.data.get("prompt_retries", 0)
            if current_retries >= self.max_retries:
                if len(rerank_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=rerank_nodes)
            else:
                ctx.data["prompt_retries"] = current_retries + 1
            reflection_prompt = NEGATIVE_FILTER_REFLECTION_PROMPT_STR.format(error=ev.error)
        prompt = NEGATIVE_FILTER_PROMPT_TMPL
        if reflection_prompt:
            prompt += reflection_prompt
        return FilterDone(prompt=prompt, query=query, neg_context=neg_context, nodes=nodes, rerank_nodes=rerank_nodes)

    # call LLM to filter the reranked nodes
    @step(pass_context=True)
    async def filter_validate(
        self, ctx: Context, ev: FilterDone
    ) -> Union[StopEvent, FilterValidationErrorEvent, RerankerBatchEvent]:
        try:
            Response_haiku = Settings.llm.complete(ev.prompt.format(context_str=default_format_node_batch_fn(ev.rerank_nodes), query_str=ev.query, contexte_negatif=ev.neg_context))
            match = re.search(r'\[(.*?)\]', Response_haiku.text)
            if match:
                output = list(map(int, match.group(1).split(',')))
            else:
                output = []
        except Exception as e:
            print("Validation failed, retrying...")
            return FilterValidationErrorEvent(
                error=str(e), query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=ev.rerank_nodes
            )
        new_nodes = []
        for i in output:
            new_nodes.append(ev.rerank_nodes[i-1])
        print("Reranked nodes are filtered.")
        current_batch = ctx.data.get("batch", 0)
        if len(new_nodes) > self.min_chunk:
            return StopEvent(result=new_nodes)
        else:
            if current_batch >= self.batch_size:
                if len(new_nodes) == 0:
                    return StopEvent(result=None)
                else:
                    return StopEvent(result=new_nodes)
            else:
                return RerankerBatchEvent(query=ev.query, neg_context=ev.neg_context, nodes=ev.nodes, rerank_nodes=new_nodes)

In [14]:
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.core.indices.keyword_table import KeywordTableGPTRetriever
from typing import List
import threading
import asyncio
from queue import Queue

class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableGPTRetriever,
        embed_model: BedrockEmbedding,
        old_queries: List = [],
        mode: str = "OR",
        
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        self.embed_model = embed_model
        self.reranker = RerankerFlow(vector_retriever=vector_retriever)
        self.old_queries = old_queries
        index_feedback = load_index_from_storage(storage_context=StorageContext.from_defaults(persist_dir="finetune/vector_persist_800_feedback"))
        self.feedback_retriever = VectorIndexRetriever(index=index_feedback, similarity_top_k=1)
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async not activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = self._weighted_aggre_embedding(new_query_str)
        vector_nodes = self._vector_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""
        print("async activated.")
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
        new_query_str = self.replace_abbreviations(query_bundle.query_str, keyword_nodes)
        query_bundle.embedding = await self._aweighted_aggre_embedding(new_query_str)
        vector_nodes_full = self._vector_retriever.retrieve(query_bundle)

        feedback_nodes = self.feedback_retriever.retrieve(query_bundle)
        if feedback_nodes[0].score > 0.8 : 
            neg_context = feedback_nodes[0].metadata["message"]
        else:
            neg_context = ''
        vector_nodes = await self.reranker.run(query=query_bundle.query_str, nodes=vector_nodes_full, neg_context=neg_context)
        if vector_nodes == None:
            return keyword_nodes

        vector_ids = {n.node.node_id for n in vector_nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(keyword_ids)
        else:
            retrieve_ids = vector_ids.union(keyword_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes

    def replace_abbreviations(self, question, keyword_nodes):
        """
        Replaces abbreviations in the lexique with their complete names.
        Assumes abbreviations are separated by either two spaces or one space + one punctuation mark.
        """
        lexique = {}
        for node in keyword_nodes:
            keyword = node.node.metadata['keyword']
            keyword_complete = node.text.split(': ')[1]
            lexique[keyword] = keyword_complete

        words = question.split()
        result = []

        for word in words:
            if word in lexique:
                result.append(lexique[word])
                result.append('('+word+')')
            elif word[:-1] in lexique:
                result.append(lexique[word[:-1]])
                result.append('('+word[:-1]+')')
                result.append(word[-1])
            elif word[1:] in lexique:
                result.append(word[0])
                result.append(lexique[word[1:]])
                result.append('('+word[1:]+')')
            else:
                result.append(word)

        return ' '.join(result)
    
    def _get_query_embedding_threaded(self, query):
        """
        Get the query embedding using a separate thread.
        """
        result_queue = Queue()  # Create a queue to store the result
        thread = threading.Thread(target=self._get_query_embedding_async_thread, args=(query, result_queue))
        thread.start()
        thread.join()  # Wait for the thread to finish
        return result_queue.get()  # Retrieve the result from the queue

    def _get_query_embedding_async_thread(self, query, result_queue):
        """
        Helper method to run the asynchronous call in a separate thread.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        result = loop.run_until_complete(self.embed_model.aget_query_embedding(query))
        loop.close()
        result_queue.put(result)  # Put the result in the queue
    
    def _weighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = self._get_query_embedding_threaded(new_query_str)
        print(current_embedding)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = self._get_query_embedding_threaded(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(self._get_query_embedding_threaded(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding
    async def _aweighted_aggre_embedding(self, new_query_str):
        """
        A weighted aggregation embedding method which can integrate the preivous queries in the current query.
        """
        current_embedding = await self.embed_model.aget_query_embedding(new_query_str)
        if len(self.old_queries) == 0:
            return current_embedding
        elif len(self.old_queries) == 1:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i])
            return new_embedding
        else:
            first_embedding = await self.embed_model.aget_query_embedding(self.old_queries[0])
            other_embeddings = []
            for query in self.old_queries[1:]:
                other_embeddings.append(await self.embed_model.aget_query_embedding(query))
            other_embedding = list(np.array(other_embeddings).mean(axis=0))
            new_embedding = []
            for i in range(len(current_embedding)):
                new_embedding.append(0.8 * current_embedding[i] + 0.2 * first_embedding[i] + 0.1 * other_embedding[i])
            return new_embedding

In [15]:
vector_retriever = VectorIndexRetriever(index=index_doc, similarity_top_k=30)
keyword_retriever = KeywordTableGPTRetriever(index=index_keyword, max_keywords_per_query = 20, keyword_extract_template=KEYWORD_EXTRACT_TEMPLATE)
custom_retriever = CustomRetriever(vector_retriever, keyword_retriever, embed_model=Settings.embed_model, old_queries=[])

In [22]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
qa_template = PromptTemplate(DEFAULT_QUERY_PROMPT_TMPL)
query_engine = RetrieverQueryEngine.from_args(retriever=custom_retriever,
                                              text_qa_template=qa_template,
                                              use_async=True,
                                              response_mode = ResponseMode.SIMPLE_SUMMARIZE,
                                              streaming=True
                                             )

In [24]:
import time
start_time = time.time()

response = query_engine.aquery("Quelle norme de tube acier utiliser pour faire une liaison détente / comptage pour alimenter une chaufferie ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async not activated.
[0.10205078125, -0.279296875, 0.00970458984375, -0.005462646484375, -0.302734375, 0.40234375, 0.07763671875, -0.00020503997802734375, -0.369140625, 0.29296875, -0.484375, 0.107421875, -0.1337890625, 0.058349609375, 0.408203125, 0.3515625, 0.38671875, -0.1083984375, 0.5234375, 0.00384521484375, 0.09716796875, -0.91796875, 0.28515625, 0.2001953125, -0.2021484375, 0.39453125, 0.1767578125, -0.057373046875, 0.054443359375, -0.4921875, -0.369140625, -0.2255859375, -0.119140625, -1.171875, 0.7109375, 0.1689453125, -0.05224609375, -0.2490234375, -0.37109375, 0.0693359375, 0.0927734375, 0.419921875, 0.154296875, -0.57421875, 0.2314453125, -0.02734375, -0.006500244140625, 0.279296875, 0.3671875, 0.14453125, -0.27734375, 0.36328125, 0.19921875, 0.2060546875, -0.298828125, -0.06982421875, 0.26953125, -0.18359375, 0.61328125, 0.2294921875, -0.61328125, 0.2294921875, 0.55859375, -0.330078125, 0.318359375, 0.10009765625, -0.1787109375, -0.427734375, -0.39453125, 0.287109375, -0.

AttributeError: 'StreamingResponse' object has no attribute 'response'

In [21]:
response.metadata

{'88d669e7-1330-4eef-a1c1-edf9f994d522': {'file_path': 'Lexique.csv',
  'file_name': 'Lexique.csv',
  'file_type': '.csv',
  'file_size': '180605',
  'creation_date': '2024-03-06',
  'last_modified_date': '2024-03-06 10:14:03',
  'last_accessed_date': '2024-03-07 13:37:50',
  'catégorie': 'vocabulaire',
  'keyword': 'DDC'},
 'fee16e37-172c-4353-aab3-2210390056af': {'file_path': 'Lexique.csv',
  'file_name': 'Lexique.csv',
  'file_type': '.csv',
  'file_size': '180605',
  'creation_date': '2024-03-06',
  'last_modified_date': '2024-03-06 10:14:03',
  'last_accessed_date': '2024-03-07 13:37:50',
  'catégorie': 'vocabulaire',
  'keyword': '6M'},
 '2ba547ac-8e18-4a65-87c6-e243d5968e8d': {'file_path': 'Lexique.csv',
  'file_name': 'Lexique.csv',
  'file_type': '.csv',
  'file_size': '180605',
  'creation_date': '2024-03-06',
  'last_modified_date': '2024-03-06 10:14:03',
  'last_accessed_date': '2024-03-07 13:37:50',
  'catégorie': 'vocabulaire',
  'keyword': 'JM'},
 '7c6a7785-2cbb-4cd9-8bd

In [61]:
import time
start_time = time.time()

response = await query_engine.aquery("Quels sont les achats immobilisés pour le biométhane ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Reranked nodes are extracted.
Step reranker_prompt produced event StopEvent
Les informations contextuelles ne mentionnent pas d'achats immobilisés pour le biométhane. Cependant, elles fournissent des détails sur les éléments suivants :

1- Le périmètre des prestations de maintenance des postes d'injection de biométhane, qui ne comprend pas l'exploitation des postes mais permet à GRDF de re

In [62]:
import time
start_time = time.time()

response = await query_engine.aquery("Que veut dire CICM ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Reranked nodes are extracted.
Step reranker_prompt produced event StopEvent
CICM signifie "Conduite d'Immeuble, Conduite Montante". C'est un terme utilisé pour désigner les canalisations de gaz qui alimentent les immeubles collectifs, notamment la conduite d'immeuble et la conduite montante.
Execution Time: 8.808188676834106 seconds


In [63]:
import time
start_time = time.time()

response = await query_engine.aquery("Dans le cadre de la surveillance des réseaux, doit-on surveiller les branchements ?")
print("========================================")
print(response.response)

end_time = time.time()
execution_time = end_time - start_time
print("========================================")
print(f"Execution Time: {execution_time} seconds")

async activated.
Running step sub_question
Step sub_question produced event SubQuestionEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Step rerank_validate produced event RerankerBatchEvent
Running step reranker_prompt
Step reranker_prompt produced event RerankerDone
Running step rerank_validate
Reranked nodes are extracted.
Step rerank_validate produced event RerankerValidationDone
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validation failed, retrying...
Step filter_validate produced event FilterValidationErrorEvent
Running step filter_prompt
Step filter_prompt produced event FilterDone
Running step filter_validate
Validation failed, retrying...
Step filter_validate produced event FilterValidationErrorEve