In [None]:
class RAGToolInput(BaseModel):
  # Argument schema for RAGTool: a single free-text query that is searched
  # against the vector database. (Documented with comments rather than a
  # docstring so the generated schema is unchanged.)
  query: str = Field(..., description='Search query in vector database.')

In [None]:
# Criando a ferramenta.

class RAGTool(BaseTool):
  '''
  LangChain tool that answers questions from a fixed set of PDFs about
  generative AI in psychotherapy.

  Pipeline: the PDFs are read, chunked and indexed in a Chroma collection
  when the tool is built. At query time the query is expanded by the LLM,
  candidate chunks are retrieved from Chroma, re-ranked with a
  cross-encoder, and the best chunks are given to the LLM as context for
  the final (Portuguese) answer.
  '''

  name: str = 'RAG Tool'
  description: str = '''
                      Tool for retrieving and generating answers using
                      RAG from a vector store.
                      '''
  # PDFs to index. When left empty, _load_files falls back to the default
  # paper set, so RAGTool(files_path=[...]) can point at other documents.
  files_path: List[str] = []

  args_schema: Type[BaseModel] = RAGToolInput

  # Runtime state filled in by __init__. BaseTool is a pydantic model, so
  # attributes assigned on self must be declared as fields; ``object``
  # lets pydantic accept any value for them.
  model: object = None
  texts: List[str] = []
  vector_db: object = None
  cross_encoder: object = None

  def __init__(self, model: object = None, **kwargs) -> None:
    '''
    Load and chunk the PDFs, index them in Chroma and load the
    cross-encoder used for re-ranking.

    Args:
      model: LLM used by the query-expansion and answer chains. When None,
        a module-level ``model`` is looked up at call time (the original
        signature defaulted to that global).
      **kwargs: forwarded to BaseTool (e.g. ``files_path``).

    Raises:
      ValueError: if no model is given and no module-level ``model`` exists.
    '''
    # Pydantic has to initialise the instance before any field is assigned.
    super().__init__(**kwargs)

    if model is None:
      # Resolved now rather than at class-definition time, so a later
      # re-assignment of the global ``model`` is picked up.
      model = globals().get('model')
    if model is None:
      raise ValueError('RAGTool needs an LLM: pass model=... or define a global `model`.')

    self.model = model
    self.texts = self._split_text(self._load_files())
    self.vector_db = self._vector_db()
    self.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

  def _load_files(self) -> list:
    '''
    Read every PDF in ``self.files_path`` and return the text of each page
    that has extractable text, in file order and then page order.

    If ``files_path`` is empty it is set to the default paper set first.
    '''
    if not self.files_path:
      self.files_path = ['/content/psiAI/BroadViewofEffectsOfIntroducingGenerativeAIonPsychotherpy.pdf',
                         '/content/psiAI/ConversationalBotsForPsychotherapy.pdf',
                         '/content/psiAI/TheEvaluationOfGenerativeAIinPyschotherapy.pdf']

    all_text = []
    for file_path in self.files_path:
      reader = PdfReader(file_path)
      # Skip pages whose extract_text() result is empty.
      all_text.extend(page_text for page_text in
                      (page.extract_text() for page in reader.pages)
                      if page_text)

    return all_text

  def _split_text(self, all_text : list) -> list:
    '''
    Split the page texts into chunks for indexing.

    Two passes: a recursive character split (2000 chars, 150 overlap,
    preferring paragraph/line/sentence boundaries) over all pages joined
    by blank lines, then a token split of each piece into chunks of at
    most 256 tokens so they fit the embedding model's input.

    Returns:
      The list of chunk strings, in document order.
    '''
    recursive_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=150,
            length_function=len,
            separators=['\n\n', '\n', '.', ' ', '']
        )
    text_splitted_rc = recursive_splitter.split_text('\n\n'.join(all_text))

    token_splitter = SentenceTransformersTokenTextSplitter(
            tokens_per_chunk=256,
            chunk_overlap=0
        )
    token_split_text = []
    for text in text_splitted_rc:
      token_split_text.extend(token_splitter.split_text(text))

    return token_split_text

  def _vector_db(self):
    '''
    Create (or reuse) the Chroma collection and index ``self.texts`` in it.

    Returns:
      The chromadb collection, holding one document per chunk with ids
      '0', '1', ... in chunk order.
    '''
    embedding_function = SentenceTransformerEmbeddingFunction()
    chroma_client = chromadb.Client()
    # create_collection raises if the name already exists on the client
    # (e.g. the tool is built twice in one session); get_or_create does not.
    chroma_db = chroma_client.get_or_create_collection(
        name='psychotherapy_and_AI',
        embedding_function=embedding_function
    )

    if self.texts:
      ids = [str(i) for i in range(len(self.texts))]
      # upsert so re-indexing overwrites existing ids instead of skipping
      # them as duplicates.
      chroma_db.upsert(ids=ids, documents=self.texts)

    return chroma_db

  def _argumented_multiple_query(self, query : str) -> str:
    '''
    Ask the LLM for up to seven short questions related to ``query``.

    Returns:
      The LLM's generated text (the prompt asks for one question per line).
    '''
    template = PromptTemplate(
            input_variables=['query'],
            template='''
            You are an experienced researcher on topics related to psychology and
            psychotherapy with solid knowledge in generative AI and is studying the
            existing and possible relationship of using generative AI in conducting
            psychotherapy.

            Your task is to suggest up to seven additional related questions
            to help them find the information they need for the provided question.
            Suggest only short questions without compound sentences.
            Output one question per line.

            Question: {query}

            Helpful Answer:
            '''
        )

    chain = LLMChain(llm=self.model, prompt=template)

    # invoke() returns a dict echoing the inputs; the generated text is
    # stored under the chain's output key.
    return chain.invoke(input=query)[chain.output_key]

  def _retrieve_documents(self, query : str, n_results: int = 10,
                          top_k: int = 5) -> list:
    '''
    Retrieve candidate chunks for ``query`` and re-rank them.

    The vector search uses the query expanded with LLM-generated related
    questions; the cross-encoder then scores each candidate against the
    original query only.

    Args:
      query: user question.
      n_results: maximum number of candidates requested from Chroma.
      top_k: number of re-ranked chunks to return.

    Returns:
      Up to ``top_k`` chunk strings, best first; an empty list when the
      collection is empty or nothing is retrieved.
    '''
    # Reuse the collection built in __init__; calling _vector_db() here
    # would rebuild and re-embed the whole index on every query.
    available = self.vector_db.count()
    if available == 0:
      return []

    expanded_query = f'{query}  {self._argumented_multiple_query(query)}'

    results = self.vector_db.query(query_texts=[expanded_query],
                                   n_results=min(n_results, available),
                                   include=['documents'])

    # 'documents' holds one list per query text; test the inner list,
    # because [[]] is truthy.
    documents = results.get('documents') or [[]]
    retrieved_documents = documents[0]
    if not retrieved_documents:
      return []

    pairs = [[query, doc] for doc in retrieved_documents]
    scores = self.cross_encoder.predict(pairs)

    ranked_documents = sorted(zip(retrieved_documents, scores),
                              key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked_documents[:top_k]]

  def rag_response(self, query : str) -> str:
    '''
    Answer ``query`` using only the retrieved chunks.

    If retrieval finds nothing, returns a fixed Portuguese message telling
    the agent to use a web-search tool instead of answering.
    '''
    retrieved_documents = self._retrieve_documents(query)

    if not retrieved_documents:
      # NOTE(review): this is a plain string, not an f-string, so the text
      # '{tavily_search_tool}' is returned literally. Kept unchanged; swap
      # in the real search-tool name if that was the intent.
      suggestion = '''
      Não há informações armazenadas no banco de dados vetorial para responder
      a essa query. UTILIZE a ferramenta {tavily_search_tool} para responder ao
      usuário de forma acurada.
      '''
      return suggestion

    else:

      information = '\n\n'.join(retrieved_documents)

      prompt_template = PromptTemplate(
          input_variables=['query', 'information'],
          template='''
          You are an experienced researcher on topics related to psychology and
          psychotherapy with solid knowledge in generative AI and is studying the
          existing and possible relationship of using generative AI in conducting
          psychotherapy.

          Your task is to respond clearly, didactically, and in detail to the user's query.
          Answer the user's question using only the provided information and in Portuguese.

          Question: {query}
          Information: {information}

          Helpful Answer:
          '''
          )

      chain = LLMChain(llm=self.model, prompt=prompt_template)

      response = chain.invoke(input={'query':query, 'information':information})

      # Return only the generated answer, not the dict that also echoes
      # the inputs.
      return response[chain.output_key]

  def _run(self, query: str) -> str:
    '''
    Run the tool: return the RAG answer (or the fallback message) for
    ``query``.
    '''
    return self.rag_response(query)

