In [None]:
!pip install llama_index openai llama-index-tools-tavily-research

In [None]:
import os
from getpass import getpass
from pprint import pprint

import nest_asyncio
from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

# Keep secrets out of the notebook: reuse keys that are already exported in the
# environment and only prompt (without echoing) when one is missing. Assigning
# an empty string here would wipe a key that was configured outside the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

tavily_ai_api_key = os.environ.get("TAVILY_API_KEY") or getpass("Tavily API key: ")

# Configure the default LLM and embedding model used by LlamaIndex.
Settings.llm = OpenAI(model="gpt-4o-mini", temperature=0.1)
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-3-small"
)

# Patch asyncio so nested event loops are allowed; later cells call
# asyncio.run() from inside the notebook.
nest_asyncio.apply()

In [None]:
import pandas as pd
from llama_index.core import Document, VectorStoreIndex

# Location of korean_webtext.csv; adjust if the file is not in the working directory.
CSV_PATH = "korean_webtext.csv"

# Load the CSV and discard its first column
# (presumably a saved index column — TODO confirm against the file).
data = pd.read_csv(CSV_PATH).iloc[:, 1:]

# Convert every row's `text` value into a LlamaIndex Document.
# A title could additionally be attached with extra_info={'title': ...}
# if the CSV provides a `title` column.
documents = [Document(text=text) for text in data["text"]]

In [None]:
# Preview only the first few documents; echoing the whole corpus floods the cell output.
documents[:3]

In [None]:
"""Corrective RAG LlamaPack class."""

from typing import Any, Dict, List

from llama_index.core import VectorStoreIndex, SummaryIndex
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import Document, NodeWithScore
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.prompts import PromptTemplate

# Grader prompt: shows the LLM one retrieved document ({context_str}) and the
# user question ({query_str}) and asks for a binary 'yes'/'no' relevance score.
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

# Query-rewriting prompt: asks the LLM to rephrase {query_str} into a concise,
# search-oriented query and to reply with the rewritten query only.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)


class CorrectiveRAGPack(BaseLlamaPack):
    """Baseline Corrective RAG pack.

    ``run`` retrieves nodes from a vector index, asks an LLM to grade every
    node as relevant ('yes') or not ('no') and, when a node is rejected or
    nothing relevant is left, rewrites the query and adds Tavily web-search
    results before synthesising the final answer.
    """

    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Create the LLM pipelines, the vector index and the web-search tool.

        Args:
            documents: Documents embedded into the ``VectorStoreIndex``.
            tavily_ai_apikey: API key passed to ``TavilyToolSpec``.
        """
        # NOTE(review): this baseline uses "gpt-4" while the later variants of
        # this class in the notebook use "gpt-4o-mini" — confirm it is intended.
        llm = OpenAI(model="gpt-4")
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )

        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)

    def get_modules(self) -> Dict[str, Any]:
        """Return the pack's main components keyed by name (``llm`` and ``index``)."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve candidate nodes for ``query_str`` from the vector index.

        ``kwargs`` are forwarded unchanged to ``index.as_retriever``.
        """
        retriever = self.index.as_retriever(**kwargs)
        return retriever.retrieve(query_str)

    @staticmethod
    def _normalize_verdict(raw_answer: str) -> str:
        """Collapse a free-form grader reply into exactly 'yes' or 'no'.

        Only the first whitespace-separated token is inspected, after removing
        surrounding quotes, punctuation and markdown markers, so replies such
        as ``Yes.`` or ``'yes'`` count as 'yes'. Every other reply is treated
        as 'no', which makes ``run`` fall back to web search.
        """
        tokens = (raw_answer or "").lower().split()
        first_token = tokens[0].strip("'\"`*.,:;!()") if tokens else ""
        return "yes" if first_token == "yes" else "no"

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Grade every retrieved node against the query.

        Returns:
            One verdict per node, in retrieval order; each is 'yes' or 'no'.
        """
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            verdict = self._normalize_verdict(relevancy.message.content)
            relevancy_results.append(verdict)
            print(f'Relevancy Result: {verdict}')
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Join (newline-separated) the text of every node graded 'yes'."""
        relevant_texts = [
            node.text
            for node, verdict in zip(retrieved_nodes, relevancy_results)
            if verdict == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Run a Tavily web search (max 5 results) and join the result texts.

        TODO: make it visible whether a piece of context came from the web
        search or from the retrieved documents.
        """
        search_results = self.tavily_tool.search(query_str, max_results=5)
        return "\n".join(result.text for result in search_results)

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Answer ``query_str`` over the combined relevant and web-search text.

        Both texts are merged into one ``Document`` that is queried through a
        temporary ``SummaryIndex``.
        """
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the corrective RAG pipeline for ``query_str``.

        Args:
            query_str: The user question.
            **kwargs: Forwarded to ``retrieve_nodes``.

        Returns:
            The response object produced by ``get_result``.
        """
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)
        # TODO: also display the retrieved context here to make runs easier to read.

        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        # Web search is the corrective step: use it when any node was rejected
        # and also when no relevant text survived at all (e.g. empty retrieval).
        search_text = ""
        if "no" in relevancy_results or not relevant_text:
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            search_text = self.search_with_transformed_query(transformed_query_str)

        # search_text stays "" when no web search was needed.
        return self.get_result(relevant_text, search_text, query_str)

In [None]:
# Create the corrective RAG pack from the loaded documents and the Tavily key.
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_apikey=tavily_ai_api_key)

In [None]:
from IPython.display import Markdown, display

# run() requires a query string; this uses the first sample question that the
# later cells of this notebook also ask.
response = corrective_rag_pack.run("캐나다가 주목받는 국가인 이유가 뭘까?")
display(Markdown(str(response)))

In [None]:
# Inspect the first source node behind the latest response.
pprint(response.source_nodes[0].text)

In [None]:
# Inspect the second source node behind the latest response.
pprint(response.source_nodes[1].text)

In [None]:

# run() requires a query string; this uses the second sample question that the
# later cells of this notebook also ask.
response = corrective_rag_pack.run("분당과 같은 도시가 외국에서도 통하려면 어떻게해야되?")
display(Markdown(str(response)))

In [None]:
# Show the text of the first source node attached to the latest response.
pprint(response.source_nodes[0].text)

In [None]:
# Show the text of the second source node attached to the latest response.
pprint(response.source_nodes[1].text)

In [None]:
# run() requires a query string; this uses the third sample question that the
# later cells of this notebook also ask.
response = corrective_rag_pack.run("일본과 러시아 전쟁 중 유명한게 뭐가 있지?")
display(Markdown(str(response)))

- 기본 Retrieval 성능 개선이 필요하다

In [None]:
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine

In [None]:
"""Corrective RAG LlamaPack class."""

from typing import Any, Dict, List

from llama_index.core import VectorStoreIndex, SummaryIndex
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import Document, NodeWithScore
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.prompts import PromptTemplate

# Grader prompt: shows the LLM one retrieved document ({context_str}) and the
# user question ({query_str}) and asks for a binary 'yes'/'no' relevance score.
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

# Query-rewriting prompt: asks the LLM to rephrase {query_str} into a concise,
# search-oriented query and to reply with the rewritten query only.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)


class CorrectiveRAGPack(BaseLlamaPack):
    """Corrective RAG pack whose retrieval step uses HyDE.

    ``run`` retrieves nodes using the embedding of an LLM-written hypothetical
    answer document (HyDE), asks an LLM to grade every node as 'yes' or 'no'
    and, when a node is rejected or nothing relevant is left, rewrites the
    query and adds Tavily web-search results before synthesising the answer.
    """

    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Create the LLM pipelines, the vector index, the web tool and HyDE.

        Args:
            documents: Documents embedded into the ``VectorStoreIndex``.
            tavily_ai_apikey: API key passed to ``TavilyToolSpec``.
        """
        llm = OpenAI(model="gpt-4o-mini")
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )

        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)
        # HyDE transform (imported in the preceding notebook cell): writes a
        # hypothetical answer document whose embedding is used for retrieval.
        # include_original=True keeps the user's own query among the embedded strings.
        self.hyde = HyDEQueryTransform(llm=llm, include_original=True)

    def get_modules(self) -> Dict[str, Any]:
        """Return the pack's main components keyed by name (``llm`` and ``index``)."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve candidate nodes using the HyDE-expanded query.

        ``kwargs`` are forwarded unchanged to ``index.as_retriever``.
        """
        retriever = self.index.as_retriever(**kwargs)
        # The transform returns a query bundle carrying the hypothetical
        # document as the text to embed; the retriever accepts it directly.
        hyde_query_bundle = self.hyde.run(query_str)
        return retriever.retrieve(hyde_query_bundle)

    @staticmethod
    def _normalize_verdict(raw_answer: str) -> str:
        """Collapse a free-form grader reply into exactly 'yes' or 'no'.

        Only the first whitespace-separated token is inspected, after removing
        surrounding quotes, punctuation and markdown markers, so replies such
        as ``Yes.`` or ``'yes'`` count as 'yes'. Every other reply is treated
        as 'no', which makes ``run`` fall back to web search.
        """
        tokens = (raw_answer or "").lower().split()
        first_token = tokens[0].strip("'\"`*.,:;!()") if tokens else ""
        return "yes" if first_token == "yes" else "no"

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Grade every retrieved node against the query.

        Returns:
            One verdict per node, in retrieval order; each is 'yes' or 'no'.
        """
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            verdict = self._normalize_verdict(relevancy.message.content)
            relevancy_results.append(verdict)
            print(f'Relevancy Result: {verdict}')
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Join (newline-separated) the text of every node graded 'yes'."""
        relevant_texts = [
            node.text
            for node, verdict in zip(retrieved_nodes, relevancy_results)
            if verdict == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Run a Tavily web search (max 5 results), log it and return the joined texts."""
        print(f'Internet Search needed with query: {query_str}')
        search_results = self.tavily_tool.search(query_str, max_results=5)
        # Join once and reuse the string for both the log line and the return value.
        search_text = "\n".join(result.text for result in search_results)
        print(f'인터넷에서 찾은 검색결과: {search_text}')
        return search_text

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Answer ``query_str`` over the combined relevant and web-search text.

        Both texts are merged into one ``Document`` that is queried through a
        temporary ``SummaryIndex``.
        """
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the corrective RAG pipeline for ``query_str``.

        Args:
            query_str: The user question.
            **kwargs: Forwarded to ``retrieve_nodes``.

        Returns:
            The response object produced by ``get_result``.
        """
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)
        # Show every retrieved node; iterating avoids an IndexError when the
        # retriever returns fewer than two nodes.
        for position, node in enumerate(retrieved_nodes, start=1):
            pprint(f"Retrieved Node {position}: {node.text}")

        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        # Web search is the corrective step: use it when any node was rejected
        # and also when no relevant text survived at all (e.g. empty retrieval).
        search_text = ""
        if "no" in relevancy_results or not relevant_text:
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            search_text = self.search_with_transformed_query(transformed_query_str)

        # search_text stays "" when no web search was needed.
        return self.get_result(relevant_text, search_text, query_str)

In [None]:
# Instantiate the pack with the CorrectiveRAGPack class defined in the cell above.
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_apikey=tavily_ai_api_key)

In [None]:
# Question: why is Canada a country that attracts attention?
response = corrective_rag_pack.run("캐나다가 주목받는 국가인 이유가 뭘까?")
display(Markdown(str(response)))

In [None]:
# Question: what would it take for a city like Bundang to succeed abroad as well?
response = corrective_rag_pack.run("분당과 같은 도시가 외국에서도 통하려면 어떻게해야되?")
display(Markdown(str(response)))

In [None]:
# Question: what is well known from the war between Japan and Russia?
response = corrective_rag_pack.run("일본과 러시아 전쟁 중 유명한게 뭐가 있지?")
display(Markdown(str(response)))

- HyDE보다 retrieval 성능 개선해보자
- BM25+Dense Hybrid Search(RRF) + query decomposition

In [None]:
from typing import List
import asyncio
from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from tqdm.asyncio import tqdm

# Prompt (written in Korean) asking the LLM to expand a loosely written user
# question into {num_queries} search queries, one per line, exactly one of which
# must be a reversed question so it can serve as contrasting evidence.
query_gen_str = """
너는 사용자가 대충 쓴 질문에 대해서, 최대한 답변하기 위한 근거를 찾기 위한 다수의 서치 쿼리를 생성해 내야해.
생성하는 쿼리 중에 딱 하나는 역으로 된 질문으로 추가해서, 상대적 근거로 활용될 수 있도록 해야해.
{num_queries}개의 서치 쿼리를 만들어 내고, 하나당 한줄씩 사용해.
Query: {query}
Queries:
"""
# Wrap the raw string so {num_queries} and {query} can be filled in at call time.
query_gen_prompt = PromptTemplate(query_gen_str)

# Module-level LLM instance. Note that generate_queries below takes its LLM as
# a parameter, so it does not read this variable implicitly.
llm = OpenAI(model="gpt-4o-mini")

def generate_queries(llm, query: str, num_queries: int = 4):
    """Ask the LLM for search queries derived from ``query``.

    Args:
        llm: Object exposing ``predict(prompt, **prompt_args)`` that returns text.
        query: The user's original question.
        num_queries: Number of queries requested in the prompt.

    Returns:
        The non-empty lines of the LLM reply, whitespace-stripped. The LLM is
        assumed to put each query on its own line; blank lines are dropped so
        that no empty query is sent to the retrievers.
    """
    response = llm.predict(
        query_gen_prompt, num_queries=num_queries, query=query
    )
    queries = [line.strip() for line in response.splitlines() if line.strip()]
    queries_str = "\n".join(queries)
    print(f"Generated queries:\n{queries_str}")
    return queries

async def run_queries(queries, retrievers):
    """Run every query against every retriever concurrently.

    Args:
        queries: Query strings to execute.
        retrievers: Retrievers exposing an awaitable ``aretrieve(query)``.

    Returns:
        Dict mapping ``(query, retriever_index)`` to that retriever's result
        list for that query. Identical query strings share one key, so a
        repeated query is counted only once per retriever.
    """
    # Record the key of each task as it is scheduled. Pairing results with
    # `queries` alone would drop all but the first len(queries) results,
    # because there is one task per (query, retriever) combination.
    keys = []
    tasks = []
    for query in queries:
        for i, retriever in enumerate(retrievers):
            keys.append((query, i))
            tasks.append(retriever.aretrieve(query))

    task_results = await tqdm.gather(*tasks)

    return dict(zip(keys, task_results))

def fuse_results(results_dict, similarity_top_k: int = 2, k: float = 60.0):
    """Fuse several ranked result lists with reciprocal rank fusion.

    Args:
        results_dict: Mapping whose values are lists of scored nodes (objects
            with a ``score`` attribute and a ``node.get_content()`` method).
        similarity_top_k: Number of fused nodes to return.
        k: Smoothing constant added to each 0-based rank; larger values reduce
            the impact of a single high ranking.

    Returns:
        The ``similarity_top_k`` nodes with the highest fused score, best
        first. Each returned node's ``score`` is overwritten with its fused
        score. Nodes are identified by their text content, so identical text
        coming from different lists is merged.
    """
    fused_scores = {}
    text_to_node = {}

    # Accumulate 1 / (rank + k) for every appearance of a text in any list.
    for nodes_with_scores in results_dict.values():
        ranked = sorted(
            nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True
        )
        for rank, node_with_score in enumerate(ranked):
            text = node_with_score.node.get_content()
            text_to_node[text] = node_with_score
            fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

    # Order texts by fused score and write that score back onto the node.
    reranked_nodes = []
    for text, score in sorted(
        fused_scores.items(), key=lambda item: item[1], reverse=True
    ):
        fused_node = text_to_node[text]
        fused_node.score = score
        reranked_nodes.append(fused_node)

    return reranked_nodes[:similarity_top_k]

class FusionRetriever(BaseRetriever):
    """Ensemble retriever: LLM query expansion plus reciprocal rank fusion."""

    def __init__(
        self,
        llm,
        retrievers: List[BaseRetriever],
        similarity_top_k: int = 2,
        num_queries: int = 4,
    ) -> None:
        """Store the configuration.

        Args:
            llm: LLM passed to ``generate_queries`` to expand the user query.
            retrievers: Retrievers that each generated query is run against.
            similarity_top_k: Number of fused nodes to return.
            num_queries: Number of search queries requested from the LLM.
        """
        self._retrievers = retrievers
        self._similarity_top_k = similarity_top_k
        self._num_queries = num_queries
        self._llm = llm
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Expand the query, run all retrievers and fuse their rankings."""
        queries = generate_queries(
            self._llm, query_bundle.query_str, num_queries=self._num_queries
        )
        # NOTE(review): asyncio.run() is called from synchronous code here; in
        # a notebook this presumably relies on the nest_asyncio.apply() call in
        # the setup cell — confirm before reusing outside the notebook.
        results = asyncio.run(run_queries(queries, self._retrievers))
        return fuse_results(results, similarity_top_k=self._similarity_top_k)

In [None]:
!pip install llama-index-retrievers-bm25

In [None]:
"""Corrective RAG LlamaPack class."""

from typing import Any, Dict, List

from llama_index.core import VectorStoreIndex, SummaryIndex
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import Document, NodeWithScore
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
# Grader prompt: shows the LLM one retrieved document ({context_str}) and the
# user question ({query_str}) and asks for a binary 'yes'/'no' relevance score.
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

# Query-rewriting prompt: asks the LLM to rephrase {query_str} into a concise,
# search-oriented query and to reply with the rewritten query only.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)


class CorrectiveRAGPack(BaseLlamaPack):
    """Corrective RAG pack with hybrid (dense + BM25) fusion retrieval.

    ``run`` retrieves nodes through a ``FusionRetriever`` that combines the
    vector index with a BM25 retriever, asks an LLM to grade every node as
    'yes' or 'no' and, when a node is rejected or nothing relevant is left,
    rewrites the query and adds Tavily web-search results before answering.
    """

    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Create the LLM pipelines, the vector index and the web-search tool.

        Args:
            documents: Documents embedded into the ``VectorStoreIndex``.
            tavily_ai_apikey: API key passed to ``TavilyToolSpec``.
        """
        llm = OpenAI(model="gpt-4o-mini")
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )

        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)
        # BM25 retrievers keyed by top-k, created on first use so the BM25
        # retriever is not rebuilt from the docstore on every query.
        self._bm25_retrievers = {}

    def get_modules(self) -> Dict[str, Any]:
        """Return the pack's main components keyed by name (``llm`` and ``index``)."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve nodes by fusing dense and BM25 rankings.

        Args:
            query_str: The user question.
            **kwargs: ``similarity_top_k`` (default 2) sets the top-k of both
                retrievers and of the fused result; remaining keyword
                arguments are forwarded to ``index.as_retriever``.
        """
        top_k = kwargs.pop("similarity_top_k", 2)

        # Dense retriever over the vector index.
        vector_retriever = self.index.as_retriever(similarity_top_k=top_k, **kwargs)

        # Sparse retriever over the same docstore.
        # NOTE(review): the default BM25 tokenisation may not suit Korean text — verify.
        if top_k not in self._bm25_retrievers:
            self._bm25_retrievers[top_k] = BM25Retriever.from_defaults(
                docstore=self.index.docstore, similarity_top_k=top_k
            )

        # FusionRetriever (defined in an earlier notebook cell) expands the
        # query with the LLM and merges both rankings.
        fusion_retriever = FusionRetriever(
            self.llm,
            [vector_retriever, self._bm25_retrievers[top_k]],
            similarity_top_k=top_k,
        )
        return fusion_retriever.retrieve(query_str)

    @staticmethod
    def _normalize_verdict(raw_answer: str) -> str:
        """Collapse a free-form grader reply into exactly 'yes' or 'no'.

        Only the first whitespace-separated token is inspected, after removing
        surrounding quotes, punctuation and markdown markers, so replies such
        as ``Yes.`` or ``'yes'`` count as 'yes'. Every other reply is treated
        as 'no', which makes ``run`` fall back to web search.
        """
        tokens = (raw_answer or "").lower().split()
        first_token = tokens[0].strip("'\"`*.,:;!()") if tokens else ""
        return "yes" if first_token == "yes" else "no"

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Grade every retrieved node against the query.

        Returns:
            One verdict per node, in retrieval order; each is 'yes' or 'no'.
        """
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            verdict = self._normalize_verdict(relevancy.message.content)
            relevancy_results.append(verdict)
            print(f'Relevancy Result: {verdict}')
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Join (newline-separated) the text of every node graded 'yes'."""
        relevant_texts = [
            node.text
            for node, verdict in zip(retrieved_nodes, relevancy_results)
            if verdict == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Run a Tavily web search (max 5 results), log it and return the joined texts."""
        print(f'Internet Search needed with query: {query_str}')
        search_results = self.tavily_tool.search(query_str, max_results=5)
        # Join once and reuse the string for both the log line and the return value.
        search_text = "\n".join(result.text for result in search_results)
        print(f'인터넷에서 찾은 검색결과: {search_text}')
        return search_text

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Answer ``query_str`` over the combined relevant and web-search text.

        Both texts are merged into one ``Document`` that is queried through a
        temporary ``SummaryIndex``.
        """
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the corrective RAG pipeline for ``query_str``.

        Args:
            query_str: The user question.
            **kwargs: Forwarded to ``retrieve_nodes``.

        Returns:
            The response object produced by ``get_result``.
        """
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)
        # Show every retrieved node; iterating avoids an IndexError when the
        # retriever returns fewer than two nodes.
        for position, node in enumerate(retrieved_nodes, start=1):
            pprint(f"Retrieved Node {position}: {node.text}")

        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        # Web search is the corrective step: use it when any node was rejected
        # and also when no relevant text survived at all (e.g. empty retrieval).
        search_text = ""
        if "no" in relevancy_results or not relevant_text:
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            search_text = self.search_with_transformed_query(transformed_query_str)

        # search_text stays "" when no web search was needed.
        return self.get_result(relevant_text, search_text, query_str)




In [None]:
# Instantiate the pack with the CorrectiveRAGPack class defined in the cell above.
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_apikey=tavily_ai_api_key)


In [None]:
# Question: why is Canada a country that attracts attention?
response = corrective_rag_pack.run("캐나다가 주목받는 국가인 이유가 뭘까?")
display(Markdown(str(response)))

In [None]:
# Question: what would it take for a city like Bundang to succeed abroad as well?
response = corrective_rag_pack.run("분당과 같은 도시가 외국에서도 통하려면 어떻게해야되?")
display(Markdown(str(response)))

In [None]:
# Question: what is well known from the war between Japan and Russia?
response = corrective_rag_pack.run("일본과 러시아 전쟁 중 유명한게 뭐가 있지?")
display(Markdown(str(response)))

In [None]:
#이제 Retrieval은 강화했는데 너무 기준이 빡세니까 관련성 크라이테리아를 좀 완화해보자
"""Corrective RAG LlamaPack class."""

from typing import Any, Dict, List

from llama_index.core import VectorStoreIndex, SummaryIndex
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import Document, NodeWithScore
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.prompts import PromptTemplate
from llama_index.retrievers.bm25 import BM25Retriever
# Relaxed grader prompt: shows the LLM one retrieved document ({context_str})
# and the user question ({query_str}) and asks for a binary 'yes'/'no' score.
# Compared with the earlier cells, the criteria accept partially or loosely
# related documents so that fewer retrievals are rejected. Editorial notes are
# kept out of the template text because everything inside it is sent to the LLM.
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Treat the document as relevant if it shares keywords, entities, or topics with the user's question, even when it only partially answers it.
    - Be lenient: background or loosely related information still counts as relevant. Only reject documents that are clearly unrelated to the question.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is at least partially relevant to the question, or 'no' only if it is clearly irrelevant.
    - When in doubt, answer 'yes'.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question. Respond with the single word only."""
)

# Query-rewriting prompt: asks the LLM to rephrase {query_str} into a concise,
# search-oriented query and to reply with the rewritten query only.
DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)


class CorrectiveRAGPack(BaseLlamaPack):
    """Corrective RAG pack with hybrid (dense + BM25) fusion retrieval.

    ``run`` retrieves nodes through a ``FusionRetriever`` that combines the
    vector index with a BM25 retriever, asks an LLM to grade every node as
    'yes' or 'no' and, when a node is rejected or nothing relevant is left,
    rewrites the query and adds Tavily web-search results before answering.
    """

    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Create the LLM pipelines, the vector index and the web-search tool.

        Args:
            documents: Documents embedded into the ``VectorStoreIndex``.
            tavily_ai_apikey: API key passed to ``TavilyToolSpec``.
        """
        llm = OpenAI(model="gpt-4o-mini")
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )

        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)
        # BM25 retrievers keyed by top-k, created on first use so the BM25
        # retriever is not rebuilt from the docstore on every query.
        self._bm25_retrievers = {}

    def get_modules(self) -> Dict[str, Any]:
        """Return the pack's main components keyed by name (``llm`` and ``index``)."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve nodes by fusing dense and BM25 rankings.

        Args:
            query_str: The user question.
            **kwargs: ``similarity_top_k`` (default 2) sets the top-k of both
                retrievers and of the fused result; remaining keyword
                arguments are forwarded to ``index.as_retriever``.
        """
        top_k = kwargs.pop("similarity_top_k", 2)

        # Dense retriever over the vector index.
        vector_retriever = self.index.as_retriever(similarity_top_k=top_k, **kwargs)

        # Sparse retriever over the same docstore.
        # NOTE(review): the default BM25 tokenisation may not suit Korean text — verify.
        if top_k not in self._bm25_retrievers:
            self._bm25_retrievers[top_k] = BM25Retriever.from_defaults(
                docstore=self.index.docstore, similarity_top_k=top_k
            )

        # FusionRetriever (defined in an earlier notebook cell) expands the
        # query with the LLM and merges both rankings.
        fusion_retriever = FusionRetriever(
            self.llm,
            [vector_retriever, self._bm25_retrievers[top_k]],
            similarity_top_k=top_k,
        )
        return fusion_retriever.retrieve(query_str)

    @staticmethod
    def _normalize_verdict(raw_answer: str) -> str:
        """Collapse a free-form grader reply into exactly 'yes' or 'no'.

        Only the first whitespace-separated token is inspected, after removing
        surrounding quotes, punctuation and markdown markers, so replies such
        as ``Yes.`` or ``'yes'`` count as 'yes'. Every other reply is treated
        as 'no', which makes ``run`` fall back to web search.
        """
        tokens = (raw_answer or "").lower().split()
        first_token = tokens[0].strip("'\"`*.,:;!()") if tokens else ""
        return "yes" if first_token == "yes" else "no"

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Grade every retrieved node against the query.

        Returns:
            One verdict per node, in retrieval order; each is 'yes' or 'no'.
        """
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            verdict = self._normalize_verdict(relevancy.message.content)
            relevancy_results.append(verdict)
            print(f'Relevancy Result: {verdict}')
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Join (newline-separated) the text of every node graded 'yes'."""
        relevant_texts = [
            node.text
            for node, verdict in zip(retrieved_nodes, relevancy_results)
            if verdict == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Run a Tavily web search (max 5 results), log it and return the joined texts."""
        print(f'Internet Search needed with query: {query_str}')
        search_results = self.tavily_tool.search(query_str, max_results=5)
        # Join once and reuse the string for both the log line and the return value.
        search_text = "\n".join(result.text for result in search_results)
        print(f'인터넷에서 찾은 검색결과: {search_text}')
        return search_text

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Answer ``query_str`` over the combined relevant and web-search text.

        Both texts are merged into one ``Document`` that is queried through a
        temporary ``SummaryIndex``.
        """
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the corrective RAG pipeline for ``query_str``.

        Args:
            query_str: The user question.
            **kwargs: Forwarded to ``retrieve_nodes``.

        Returns:
            The response object produced by ``get_result``.
        """
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)
        # Show every retrieved node; iterating avoids an IndexError when the
        # retriever returns fewer than two nodes.
        for position, node in enumerate(retrieved_nodes, start=1):
            pprint(f"Retrieved Node {position}: {node.text}")

        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        # Web search is the corrective step: use it when any node was rejected
        # and also when no relevant text survived at all (e.g. empty retrieval).
        search_text = ""
        if "no" in relevancy_results or not relevant_text:
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            search_text = self.search_with_transformed_query(transformed_query_str)

        # search_text stays "" when no web search was needed.
        return self.get_result(relevant_text, search_text, query_str)




In [None]:
# Instantiate the pack with the CorrectiveRAGPack class defined in the cell above.
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_apikey=tavily_ai_api_key)

In [None]:
# Question: why is Canada a country that attracts attention?
response = corrective_rag_pack.run("캐나다가 주목받는 국가인 이유가 뭘까?")
display(Markdown(str(response)))

In [None]:
# Question: what would it take for a city like Bundang to succeed abroad as well?
response = corrective_rag_pack.run("분당과 같은 도시가 외국에서도 통하려면 어떻게해야되?")
display(Markdown(str(response)))

In [None]:
# Show the text of the first source node attached to the latest response.
pprint(response.source_nodes[0].text)

In [None]:
# Question: how many goals did Son Heung-min score when he won the Premier League Golden Boot?
response = corrective_rag_pack.run("손흥민이 프리미어리그 득점왕때 몇골을 넣었지?")
display(Markdown(str(response)))