# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval

In [None]:
# NOTE: An OpenAI API key must be set for application initialization, even if not in use.
# If you're not using OpenAI models, a placeholder string (e.g., "not_used") is enough.
import os

# Keep a real key if one is already exported; otherwise fall back to the placeholder.
os.environ.setdefault("OPENAI_API_KEY", "not_used")

1) **Building**: RAPTOR recursively embeds, clusters, and summarizes chunks of text to construct a tree with varying levels of summarization from the bottom up. You can create a tree from the text in `demo/sample.txt` using `RA.add_documents(text)`.

2) **Querying**: At inference time, the RAPTOR model retrieves information from this tree, integrating data across lengthy documents at different abstraction levels. You can perform queries on the tree with `RA.answer_question`.

### Building the tree

In [None]:
from raptor import RetrievalAugmentation 

## Using other Open Source Models for Summarization/QA/Embeddings

If you want to use other models, such as Llama or Mistral, you can define your own model classes and plug them into RAPTOR.

In [None]:
import torch
from raptor import BaseSummarizationModel, BaseQAModel, BaseEmbeddingModel, RetrievalAugmentationConfig
from transformers import AutoTokenizer, pipeline

In [None]:
from huggingface_hub import login
import os

# Authenticate with the Hugging Face Hub only when a token is supplied through the
# environment (export HF_TOKEN=...). Never hard-code tokens in a notebook: the token
# that was previously written here has been exposed in source and should be revoked.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)

## Building the Summarization Model

In [None]:
import requests

class SummarizationModel(BaseSummarizationModel):
    """Summarization model backed by an OpenAI-compatible chat-completions HTTP endpoint.

    Each ``summarize`` call sends one single-turn prompt to ``url`` and returns the
    assistant's reply. Failures are reported as an ``"Error: ..."`` or
    ``"Request error: ..."`` string instead of being raised.
    NOTE(review): such strings are returned like a normal summary, so the caller
    will presumably store them as node text; consider raising instead.

    Args:
        url: Chat-completions endpoint to POST to.
        timeout: ``requests`` timeout in seconds, or a ``(connect, read)`` tuple.
            ``None`` waits indefinitely (the previous behaviour).
        verify: Passed to ``requests.post``. ``False`` (the original setting)
            disables TLS certificate checks for https endpoints.
    """

    def __init__(self, url="http://a0221.nhr.fau.de:5000/v1/chat/completions",
                 timeout=(10, 600), verify=False):
        super().__init__()
        self.url = url
        self.headers = {
            "Content-Type": "application/json"
        }
        # Messages sent by the most recent request (overwritten on every call).
        self.history = []
        self.timeout = timeout
        self.verify = verify

    def summarize(self, context, max_tokens=150):
        """Summarize ``context`` via the chat endpoint.

        Args:
            context: Text to summarize.
            max_tokens: Accepted for compatibility with ``BaseSummarizationModel``.
                It is NOT forwarded to the server, so the server's own length
                limit applies.

        Returns:
            The stripped summary on HTTP 200; otherwise a string starting with
            ``"Error:"`` (non-200 status) or ``"Request error:"`` (network
            failure, timeout, or a body that is not valid JSON).
        """
        # Each summary is an independent single-turn request.
        self.history = [{
            "role": "user",
            "content": f"Write a summary of the following, including as many key details as possible: {context}:",
        }]
        data = {
            "mode": "instruct",
            "temperature": 0.0,
            "messages": self.history,
        }

        try:
            response = requests.post(self.url, headers=self.headers, json=data,
                                     verify=self.verify, timeout=self.timeout)
            if response.status_code != 200:
                return f"Error: {response.status_code} {response.text}"
            # Parse once; invalid JSON raises a RequestException subclass handled below.
            payload = response.json()
        except requests.exceptions.RequestException as e:
            return f"Request error: {e}"

        print(payload)
        assistant_message = payload['choices'][0]['message']['content']
        print(assistant_message)
        return assistant_message.strip()


## Building the QA Model

In [None]:
import requests

class QAModel(BaseQAModel):
    """Question-answering model backed by an OpenAI-compatible chat-completions endpoint.

    ``answer_question`` first condenses the retrieved context with a summarization
    model, then asks the endpoint to answer the question from that summary.

    Args:
        url: Chat-completions endpoint to POST to.
        timeout: ``requests`` timeout in seconds, or a ``(connect, read)`` tuple.
            ``None`` waits indefinitely (the previous behaviour).
        verify: Passed to ``requests.post``. ``False`` (the original setting)
            disables TLS certificate checks for https endpoints.
        summarization_model: Object with a ``summarize(context)`` method used to
            condense the context. Defaults to ``SummarizationModel(url=url)`` so
            both steps talk to the same server (previously the summarizer always
            used the hard-coded default URL, even when ``url`` was overridden).
    """

    def __init__(self, url="http://a0221.nhr.fau.de:5000/v1/chat/completions",
                 timeout=(10, 600), verify=False, summarization_model=None):
        super().__init__()
        self.url = url
        self.headers = {
            "Content-Type": "application/json"
        }
        # Messages sent by the most recent request (overwritten on every call).
        self.history = []
        self.timeout = timeout
        self.verify = verify
        if summarization_model is None:
            summarization_model = SummarizationModel(url=url)
        self.summarization_model = summarization_model

    def answer_question(self, context, question):
        """Answer ``question`` using the retrieved ``context``.

        Args:
            context: Text retrieved from the RAPTOR tree.
            question: The user's question.

        Returns:
            The stripped answer on HTTP 200; otherwise a string starting with
            ``"Error:"`` (non-200 status) or ``"Request error:"`` (network
            failure, timeout, or a body that is not valid JSON).
        """
        summarized_content = self.summarization_model.summarize(context)

        # Each question is an independent single-turn request.
        self.history = [{
            "role": "user",
            "content": f"context: {summarized_content}\n\nQuestion: {question}\n\nPlease provide a detailed and informative answer based on context above.",
        }]
        data = {
            "mode": "instruct",
            "temperature": 0.7,
            "messages": self.history,
        }

        try:
            response = requests.post(self.url, headers=self.headers, json=data,
                                     verify=self.verify, timeout=self.timeout)
            if response.status_code != 200:
                return f"Error: {response.status_code} {response.text}"
            # Parse once; invalid JSON raises a RequestException subclass handled below.
            payload = response.json()
        except requests.exceptions.RequestException as e:
            return f"Request error: {e}"

        print(payload)
        assistant_message = payload['choices'][0]['message']['content']
        print(assistant_message)
        return assistant_message.strip()


## Building the Embedding Model with Sentence Transformers

In [None]:
from sentence_transformers import SentenceTransformer
class SBertEmbeddingModel(BaseEmbeddingModel):
    """Embedding model that encodes text with a Sentence-Transformers checkpoint.

    Args:
        model_name: Model id (or local path) handed to ``SentenceTransformer``.
    """

    def __init__(self, model_name="sentence-transformers/multi-qa-mpnet-base-cos-v1"):
        encoder = SentenceTransformer(model_name)
        self.model = encoder

    def create_embedding(self, text):
        """Return whatever ``SentenceTransformer.encode`` yields for ``text``, with the progress bar disabled."""
        embedding = self.model.encode(text, show_progress_bar=False)
        return embedding


In [None]:
# Plug the custom summarization, QA and embedding models into the RAPTOR config.
RAC = RetrievalAugmentationConfig(
    summarization_model=SummarizationModel(),
    qa_model=QAModel(),
    embedding_model=SBertEmbeddingModel(),
)

In [None]:
# Create the RAPTOR pipeline from the config above; documents are added in a later cell.
RA = RetrievalAugmentation(
    config=RAC,
)

## Building the tree from the knowledge base

In [None]:
# Read the knowledge base, build the RAPTOR tree from it, and persist the tree.
with open('demo/sample.txt', encoding='utf-8') as kb_file:
    text = kb_file.read()

# Show the first characters as a quick check that the right file was loaded.
print(text[:100])
RA.add_documents(text)

SAVE_PATH = "demo/tech_txt_tree_structure"
RA.save(SAVE_PATH)

## Testing the response by asking questions

In [None]:
# Ask the freshly built tree a sample question.
question = "what is time"
answer = RA.answer_question(question=question)

if answer is None:
    print("No answer found")
else:
    print("Answer: ", answer)

## Visualizing Tree Structure

In [None]:
# Grab the tree built above; it is used by the visualization below.
# (The bare `tree.root_nodes` expression that followed was removed: mid-cell it had no effect.)
tree = RA.tree
def print_tree_layers(root_nodes, all_nodes=None):
    """
    Print the tree breadth-first, one level at a time, starting from the root nodes.

    For every node its index and text are printed; children are resolved by index
    through ``all_nodes``. A child index that appears under several parents is
    printed only once per level, and indices missing from ``all_nodes`` are skipped.

    Args:
      root_nodes: A dictionary mapping node index to Node objects.
      all_nodes: A dictionary mapping node index to Node objects for the whole tree.
        Defaults to ``tree.all_nodes`` of the notebook's global ``tree``
        (the previous, implicit behaviour).
    """
    if all_nodes is None:
        all_nodes = tree.all_nodes

    current_layer = list(root_nodes.values())
    level = 0
    while current_layer:
        print(f"================= Level {level} ================= ")
        next_layer = []
        seen = set()
        for node in current_layer:
            print(f"Index: {node.index}, Text: {node.text}\n")
            for child_index in node.children:
                # A shared child would otherwise be printed (with its whole subtree) once per parent.
                if child_index in seen:
                    continue
                child = all_nodes.get(child_index)
                if child is None:
                    # Dangling index: skip instead of failing on `None.index` in the next level.
                    continue
                seen.add(child_index)
                next_layer.append(child)

        current_layer = next_layer
        level += 1

# Dump the whole tree, level by level, starting from its roots.
print_tree_layers(root_nodes=tree.root_nodes)


## Loading the saved tree structure

In [None]:
# Reload the tree persisted by the "Building the tree" cell instead of rebuilding it.
# This path must match the SAVE_PATH that cell saved to (it previously pointed at
# "demo/sample_txt_tree_structure", which this notebook never writes).
SAVE_PATH = "demo/tech_txt_tree_structure"
# NOTE(review): RAPTOR deserializes the saved tree from this path (presumably via
# pickle -- verify); only load files you created yourself.
RA = RetrievalAugmentation(config=RAC, tree=SAVE_PATH)
question = "how is the time cindrella living"
answer = RA.answer_question(question=question)
# Display the result; the cell previously ended in an assignment, so nothing was shown.
print("Answer: ", answer)