In [1]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
# from langchain_experimental.utilities import PythonREPL
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode

from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    ToolMessage,
)

from dotenv import load_dotenv
load_dotenv()

import nest_asyncio
nest_asyncio.apply()

import os
from langchain.tools import Tool
import tiktoken
from datetime import datetime
from multiprocessing import Process, Queue
from difflib import unified_diff
from IPython.display import display, HTML

from llama_cloud_services import LlamaParse

from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import Settings
import qdrant_client
from IPython.display import Markdown, display
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core import StorageContext
from langchain_nomic.embeddings import NomicEmbeddings


from langchain_ollama import ChatOllama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

In [2]:
# Stream generated tokens to stdout as they arrive.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# The model tag must match the one used by the rest of the notebook
# (`local_llm = "llama3.1"`); the previous value "llama 3.1" contained a space,
# which does not look like a valid Ollama tag.
llm = ChatOllama(
    model="llama3.1",
    callback_manager=callback_manager,
    temperature=1
)

In [3]:
# Read the LlamaCloud key from the environment (load_dotenv() has already run
# in the first cell) instead of embedding the secret in the notebook.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be revoked and replaced.
LLAMA_CLOUD_API_KEY = os.environ["LLAMA_CLOUD_API_KEY"]

parser = LlamaParse(
    api_key=LLAMA_CLOUD_API_KEY,
    result_type="markdown"  # "markdown" and "text" are available
)

# use SimpleDirectoryReader to parse our file; PDFs are routed through LlamaParse
file_extractor = {".pdf": parser}
documents = SimpleDirectoryReader(input_files=['/mnt/MainDrive/Codes/Heckeer/electronics-13-03233.pdf'], file_extractor=file_extractor).load_data()
print(documents)

Started parsing the file under job_id a7197595-f651-43fb-bbc7-159837a937e3
[Document(id_='295d76f6-2cd2-406e-bbec-4e5d4353b185', embedding=None, metadata={'file_path': '/mnt/MainDrive/Codes/Heckeer/electronics-13-03233.pdf', 'file_name': 'electronics-13-03233.pdf', 'file_type': 'application/pdf', 'file_size': 957053, 'creation_date': '2025-02-19', 'last_modified_date': '2025-02-19'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={}, metadata_template='{key}: {value}', metadata_separator='\n', text_resource=MediaResource(embeddings=None, data=None, text='# Next-Gen Dynamic Hand Gesture Recognition: MediaPipe, Inception-v3 and LSTM-Based Enhanced Deep Learning Model\n\nYaseen1, Oh-Jin Kwon1,*, Jaeho Kim2, Sonain Jamil3, Jinhee Lee1 and Faiz Ullah1\n\n1Department o

In [4]:
# Qdrant client used by the vector store below; connects to localhost:6333.
client = qdrant_client.QdrantClient(
    # you can use :memory: mode for fast and light-weight experiments,
    # it does not require to have Qdrant deployed anywhere
    # but requires qdrant-client >= 1.1.1
    # location=":memory:"
    # otherwise set Qdrant instance address with:
    # url="http://<host>:<port>"
    # otherwise set Qdrant instance with host and port:
    host="localhost",
    port=6333
    # set API KEY for Qdrant Cloud
    # api_key="",
)

In [5]:
# Set the embedding model on the llama_index `Settings` object before the index
# is built (presumably picked up by VectorStoreIndex below — TODO confirm).
Settings.embed_model = NomicEmbeddings(model="nomic-embed-text-v1.5")
# Store vectors in the "Hecker" collection of the Qdrant client created above.
vector_store = QdrantVectorStore(client=client, collection_name="Hecker")

storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Build the index from the parsed `documents`, backed by the Qdrant store.
index = VectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
    # embedding=NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local"),
)


In [6]:
# Model name passed as `model=` to every ChatOllama instance created by the helper functions below.
local_llm = "llama3.1"

In [7]:
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Count the tokens `string` encodes to under the named tiktoken encoding."""
    return len(tiktoken.get_encoding(encoding_name).encode(string))

def chunk_text_by_sentence(text, chunk_size=2048):
    """Return the leading sentence-aligned chunk of `text`.

    Sentences (split on '. ') are accumulated while the running chunk stays
    within `chunk_size` tokens; the first chunk is returned.

    Args:
        text: The text to cut.
        chunk_size: Token budget for the returned chunk.

    Returns:
        The first chunk as a string. If the very first sentence alone exceeds
        `chunk_size`, that sentence is returned as-is (previously an empty
        string was returned because an empty chunk was flushed first).
    """
    curr_chunk = []
    for sentence in text.split('. '):
        # Always accept the first sentence so the result is never an empty
        # placeholder chunk; afterwards only accept sentences that still fit.
        if not curr_chunk or (
            num_tokens_from_string(". ".join(curr_chunk))
            + num_tokens_from_string(sentence) + 2 <= chunk_size
        ):
            curr_chunk.append(sentence)
        else:
            # Only the first chunk is ever returned, so stop as soon as it is full.
            break
    return ". ".join(curr_chunk)

def chunk_text_front(text, chunk_size = 2048):
    '''
    Return approximately the first `chunk_size` tokens of `text`.

    The cut is made by character count, scaled by the ratio of `chunk_size`
    to the total token count, so the result is an estimate rather than an
    exact token boundary. Text shorter than `chunk_size` tokens is returned whole.
    '''
    token_count = num_tokens_from_string(text)
    if token_count < chunk_size:
        return text
    keep_chars = int(len(text) * (float(chunk_size) / token_count))
    return text[:keep_chars]

def chunk_texts(text, chunk_size = 2048):
    '''
    Cut `text` into consecutive parts of near-equal character length.

    The number of parts is derived from the token count
    (int(tokens / chunk_size) + 1); text under `chunk_size` tokens comes back
    as a single-element list. Returns a list of strings.
    '''
    token_count = num_tokens_from_string(text)
    if token_count < chunk_size:
        return [text]

    piece_count = int(token_count/chunk_size) + 1
    # The first `remainder` pieces take one extra character each.
    base_len, remainder = divmod(len(text), piece_count)

    pieces = []
    cursor = 0
    for idx in range(piece_count):
        stop = cursor + base_len + (1 if idx < remainder else 0)
        pieces.append(text[cursor:stop])
        cursor = stop
    return pieces

In [8]:
from datetime import datetime
import os

def get_draft(question):
    """Generate the initial draft answer for `question` with the local model.

    Args:
        question: The user's question or instruction.

    Returns:
        The text content of the model's reply.
    """
    # `\\n\\n` is escaped so the model sees the two-character sequence `\n\n`
    # in the instruction rather than a blank line inside the backticks.
    draft_prompt = '''
IMPORTANT:
Try to answer this question/instruction with step-by-step thoughts and make the answer more structural.
Use `\\n\\n` to split the answer into several paragraphs.
Just respond to the instruction directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    # No `format="json"`: the prompt asks for free-form paragraphs and the
    # caller splits the reply as plain text, which JSON mode would prevent.
    draft_agent = ChatOllama(model=local_llm, callback_manager=callback_manager, temperature=1)
    # Chat models take a list of messages (or a string / prompt value), not a
    # {"messages": [...]} dict.
    response = draft_agent.invoke([
        HumanMessage(content=f"{question}\n{draft_prompt}")
    ])

    return response.content

In [9]:
from llama_cloud_services import LlamaParse
from llama_index.core import SimpleDirectoryReader
import tiktoken
from typing import List, Optional

class DocumentParser:
    """Loads a document via LlamaParse and cuts its text into token-bounded chunks."""

    def __init__(self, result_type="markdown"):
        """Create the LlamaParse parser and register it for `.pdf` files."""
        self.parser = LlamaParse(result_type=result_type)
        self.file_extractor = {".pdf": self.parser}

    def load_document(self, file_path: str) -> List[str]:
        """Read `file_path` and return the text of every document produced."""
        reader = SimpleDirectoryReader(
            input_files=[file_path],
            file_extractor=self.file_extractor,
        )
        texts = []
        for doc in reader.load_data():
            texts.append(doc.text)
        return texts

    def chunk_document(self, content: str, chunk_size: int = 2048) -> List[str]:
        """Split `content` into pieces of at most `chunk_size` cl100k_base tokens."""
        codec = tiktoken.get_encoding("cl100k_base")
        pieces = []
        pending = []

        for token_id in codec.encode(content):
            # Flush the pending tokens once one more would exceed the budget.
            if len(pending) + 1 > chunk_size:
                pieces.append(codec.decode(pending))
                pending = []
            pending.append(token_id)

        if pending:
            pieces.append(codec.decode(pending))

        return pieces

    def parse_and_chunk(self, file_path: str) -> Optional[List[str]]:
        """Load `file_path` and return all of its chunks in order.

        Returns None when nothing was loaded or when any step raises; the
        error is printed rather than propagated.
        """
        try:
            texts = self.load_document(file_path)
            if not texts:
                return None
            return [piece for text in texts for piece in self.chunk_document(text)]
        except Exception as e:
            print(f"Error parsing document: {e}")
            return None

In [10]:
def split_draft(draft, split_char = '\n\n'):
    """Split `draft` on `split_char` and keep only pieces longer than 5 characters."""
    kept = []
    for piece in draft.split(split_char):
        if len(piece) > 5:
            kept.append(piece)
    return kept

def split_draft_openai(question, answer, NUM_PARAGRAPHS = 4):
    """Ask the local model to re-split `answer` into paragraphs separated by `##`.

    Args:
        question: The original question.
        answer: The draft answer to split.
        NUM_PARAGRAPHS: Upper bound on the paragraph count quoted in the prompt.

    Returns:
        A list of paragraph strings (pieces of 5 characters or fewer are dropped
        by `split_draft`).
    """
    split_prompt = f'''
Split the answer of the question into multiple paragraphs with each paragraph containing a complete thought.
The answer should be splited into less than {NUM_PARAGRAPHS} paragraphs.
Use ## as splitting char to seperate the paragraphs.
So you should output the answer with ## to split the paragraphs.
**IMPORTANT**
Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    # No `format="json"`: the reply must be plain text delimited by `##`.
    Ollama_client = ChatOllama(model=local_llm, callback_manager=callback_manager, temperature=1)
    # Chat models take a list of messages, not a {"messages": [...]} dict.
    splited_answer = Ollama_client.invoke([
        HumanMessage(content=f"##Question: {question}\n\n##Response: {answer}\n\n##Instruction: {split_prompt}")
    ])
    # Split the reply *text*; the message object itself has no `.split`, and
    # the resulting list has no `.content`.
    return split_draft(splited_answer.content, split_char = '##')

In [11]:
def get_query(question, answer):
    """Ask the local model for a short search query that verifies `answer`.

    Args:
        question: The original question.
        answer: The answer text accumulated so far; the prompt focuses on its
            last sentences.

    Returns:
        The text content of the model's reply (the query).
    """
    query_prompt = '''
I want to verify the content correctness of the given question, especially the last sentences.
Please summarize the content with the corresponding question.
This summarization will be used as a query to search with Bing search engine.
The query should be short but need to be specific to promise Bing can find related knowledge or pages.
You can also use search syntax to make the query short and clear enough for the search engine to find relevant language data.
Try to make the query as relevant as possible to the last few sentences in the content.
**IMPORTANT**
Just output the query directly. DO NOT add additional explanations or introducement in the answer unless you are asked to.
'''
    # No `format="json"`: the prompt asks for the bare query text.
    Ollama_client = ChatOllama(model=local_llm, callback_manager=callback_manager, temperature=1)
    # Chat models take a list of messages, not a {"messages": [...]} dict.
    query = Ollama_client.invoke([
        HumanMessage(content=f"##Question: {question}\n\n##Response: {answer}\n\n##Instruction: {query_prompt}")
    ])
    return query.content

In [12]:
def get_revise_answer(question, answer, content):
    """Ask the local model to revise `answer` using the reference text `content`.

    Args:
        question: The original question.
        answer: The current answer to check and improve.
        content: Reference text the revision should be grounded in.

    Returns:
        The text content of the model's reply (the revised answer).
    """
    # `\\n\\n` is escaped so the instruction names the separator literally
    # instead of embedding a blank line in the middle of the sentence.
    revise_prompt = '''
I want to revise the answer according to retrieved related text of the question in WIKI pages.
You need to check whether the answer is correct.
If you find some errors in the answer, revise the answer to make it better.
If you find some necessary details are ignored, add it to make the answer more plausible according to the related text.
If you find the answer is right and do not need to add more details, just output the original answer directly.
**IMPORTANT**
Try to keep the structure (multiple paragraphs with its subtitles) in the revised answer and make it more structual for understanding.
Add more details from retrieved text to the answer.
Split the paragraphs with \\n\\n characters.
Just output the revised answer directly. DO NOT add additional explanations or annoucement in the revised answer unless you are asked to.
'''
    # No `format="json"`: the revised answer is free-form text.
    Ollama_client = ChatOllama(model=local_llm, callback_manager=callback_manager, temperature=1)
    # Chat models take a list of messages, not a {"messages": [...]} dict.
    revised_answer = Ollama_client.invoke([
        HumanMessage(content=f"##Existing Text in Wiki Web: {content}\n\n##Question: {question}\n\n##Answer: {answer}\n\n##Instruction: {revise_prompt}")
    ])
    return revised_answer.content

def get_reflect_answer(question, answer):
    """Ask the local model to add a title and per-paragraph subtitles to `answer`.

    Args:
        question: The original question.
        answer: The answer to restructure.

    Returns:
        The text content of the model's reply (the answer in markdown form).
    """
    # `\\n\\n` is escaped so the instruction names the separator literally
    # instead of embedding a blank line in the middle of the sentence.
    reflect_prompt = '''
Give a title for the answer of the question.
And add a subtitle to each paragraph in the answer and output the final answer using markdown format. 
This will make the answer to this question look more structured for better understanding.
**IMPORTANT**
Try to keep the structure (multiple paragraphs with its subtitles) in the response and make it more structual for understanding.
Split the paragraphs with \\n\\n characters.
Just output the revised answer directly. DO NOT add additional explanations or annoucement in the revised answer unless you are asked to.
'''

    # No `format="json"`: the prompt asks for markdown output.
    Ollama_client = ChatOllama(model=local_llm, callback_manager=callback_manager, temperature=1)
    # Chat models take a list of messages, not a {"messages": [...]} dict.
    reflected_answer = Ollama_client.invoke([
        HumanMessage(content=f"##Question:\n{question}\n\n##Answer:\n{answer}\n\n##Instruction:\n{reflect_prompt}")
    ])
    return reflected_answer.content


In [13]:

def get_query_wrapper(q, question, answer):
    """Run `get_query` and put its result on queue `q`."""
    q.put(get_query(question, answer))

# def get_content_wrapper(q, query):
#     result = get_content(query)
#     q.put(result)

def get_revise_answer_wrapper(q, question, answer, content):
    """Run `get_revise_answer` and put its result on queue `q`."""
    q.put(get_revise_answer(question, answer, content))

def get_reflect_answer_wrapper(q, question, answer):
    """Run `get_reflect_answer` and put its result on queue `q`."""
    q.put(get_reflect_answer(question, answer))

from multiprocessing import Process, Queue
def run_with_timeout(func, timeout, *args, **kwargs):
    """Run `func(q, *args, **kwargs)` in a child process and wait for one result.

    `func` receives a multiprocessing queue as its first argument and is
    expected to put exactly one result on it.

    Args:
        func: Callable executed in the child process.
        timeout: Seconds to wait for the result; None waits indefinitely.
        *args, **kwargs: Forwarded to `func` after the queue.

    Returns:
        The object `func` put on the queue, or None when the time limit was
        reached or the child exited without producing a result.
    """
    import time
    from queue import Empty  # raised by Queue.get() when its timeout elapses

    q = Queue()
    p = Process(target=func, args=(q, *args), kwargs=kwargs)
    p.start()

    poll_interval = 0.1
    deadline = None if timeout is None else time.monotonic() + timeout
    result = None
    received = False
    timed_out = False

    # Read from the queue *before* joining the child: a child that has queued
    # data may not exit until that data is consumed, so join-then-get can stall.
    while True:
        # Sample liveness before the read so that a result flushed just before
        # the child exited is still picked up by the read that follows.
        worker_was_alive = p.is_alive()
        if deadline is None:
            wait = poll_interval
        else:
            wait = max(0.0, min(poll_interval, deadline - time.monotonic()))
        try:
            result = q.get(timeout=wait)
            received = True
            break
        except Empty:
            pass
        if not worker_was_alive:
            # Child already gone and nothing queued: it failed before q.put().
            # The previous implementation blocked forever on q.get() here.
            break
        if deadline is not None and time.monotonic() >= deadline:
            timed_out = True
            break

    if received:
        print(f"{datetime.now()} [INFO] Function {str(func)} executed successfully.")
        p.join(1)
        if p.is_alive():
            p.terminate()
            p.join()
    elif timed_out:
        print(f"{datetime.now()} [INFO] Function {str(func)} running timeout ({timeout}s), terminating...")
        p.terminate()
        p.join()
    else:
        p.join()
        print(f"{datetime.now()} [INFO] Function {str(func)} exited without a result (exit code {p.exitcode}).")
    return result

In [14]:
from difflib import unified_diff
from IPython.display import display, HTML
import gradio as gr


def generate_diff_html(text1, text2):
    """Render a unified diff of two texts as colour-coded HTML.

    Lines starting with '+' are green, '-' red and '@' blue; all other lines
    are emitted as plain text followed by `<br>`.

    Args:
        text1: The original text.
        text2: The revised text.

    Returns:
        An HTML fragment string; empty when the texts are identical.
    """
    from html import escape

    diff = unified_diff(text1.splitlines(keepends=True),
                        text2.splitlines(keepends=True),
                        fromfile='text1', tofile='text2')

    fragments = []
    for line in diff:
        # The diffed text is model output; escape it so stray markup is shown
        # literally instead of being interpreted by the browser.
        safe = escape(line.rstrip())
        if line.startswith('+'):
            fragments.append(f"<div style='color:green;'>{safe}</div>")
        elif line.startswith('-'):
            fragments.append(f"<div style='color:red;'>{safe}</div>")
        elif line.startswith('@'):
            fragments.append(f"<div style='color:blue;'>{safe}</div>")
        else:
            fragments.append(f"{safe}<br>")
    return "".join(fragments)

# Newline held in a variable so it can be used inside f-string expressions
# (e.g. `query.replace(newline_char, ' ')` in `rat`) without writing a backslash there.
newline_char = '\n'

def rat(question: str, document_path: str):
    """Answer `question` with a draft that is revised section by section against a document.

    Steps: parse and chunk the document, generate a draft, split it into
    sections, then for each section revise the accumulated answer against
    document chunks, and finally restructure the answer with titles.
    Every model call after the draft runs with a 30 s limit and is skipped
    when no result comes back.

    Args:
        question: The user's question or instruction.
        document_path: Path of the document to ground the answer in.

    Returns:
        A `(draft, answer)` tuple of strings, or
        `("Error: Could not parse document", "")` when the document is missing
        or cannot be parsed.
    """
    # Maximum number of document chunks used per section.
    LIMIT = 2

    if document_path is None:
        return "Error: Could not parse document", ""
    # NOTE(review): assumes the upload widget may pass either a path string or
    # a file wrapper exposing `.name` — confirm against the installed Gradio version.
    document_path = getattr(document_path, "name", document_path)

    # Initialize document parser
    doc_parser = DocumentParser()

    print(f"{datetime.now()} [INFO] Parsing document...")
    chunks = doc_parser.parse_and_chunk(document_path)
    if not chunks:
        return "Error: Could not parse document", ""
    # Revise against the parsed document. Previously `chunks` was never used
    # and the draft paragraphs themselves were fed back in as reference text.
    # NOTE(review): the generated query is not used to select chunks yet; the
    # leading chunks are taken as-is.
    retrieved = chunks[:LIMIT]

    print(f"{datetime.now()} [INFO] Generating draft...")
    draft = get_draft(question)
    print(f"{datetime.now()} [INFO] Return draft.")

    print(f"{datetime.now()} [INFO] Processing draft ...")
    draft_paragraphs = split_draft_openai(question, draft)

    print(f"{datetime.now()} [INFO] Draft is splitted into {len(draft_paragraphs)} sections.")

    answer = ""
    for i, p in enumerate(draft_paragraphs):
        print(f"{datetime.now()} [INFO] Revising {i+1}/{len(draft_paragraphs)} sections ...")
        # The answer grows by one draft section per iteration.
        answer = answer + '\n\n' + p

        print(f"{datetime.now()} [INFO] Generating query ...")
        res = run_with_timeout(get_query_wrapper, 30, question, answer)
        if not res:
            print(f"{datetime.now()} [INFO] Generating query timeout, skipping...")
            continue
        else:
            query = res
        print(f">>> {i}/{len(draft_paragraphs)} Query: {query.replace(newline_char, ' ')}")

        print(f"{datetime.now()} [INFO] Crawling network pages ...")

        for j, c in enumerate(retrieved):
            print(f"{datetime.now()} [INFO] Revising answers with retrieved network pages...[{j}/{len(retrieved)}]")
            res = run_with_timeout(get_revise_answer_wrapper, 30, question, answer, c)
            if not res:
                print(f"{datetime.now()} [INFO] Revising answers timeout, skipping ...")
                continue
            else:
                # Show what the revision changed before adopting it.
                diff_html = generate_diff_html(answer, res)
                display(HTML(diff_html))
                answer = res
            print(f"{datetime.now()} [INFO] Answer revised [{j}/{len(retrieved)}]")

    res = run_with_timeout(get_reflect_answer_wrapper, 30, question, answer)
    if not res:
        print(f"{datetime.now()} [INFO] Reflecting answers timeout, skipping next steps...")
    else:
        answer = res
    return draft, answer

# Browser title and introductory Markdown for the Gradio demo page built below.
page_title = "RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation"
page_md = """
# RAT: Retrieval Augmented Thoughts Elicit Context-Aware Reasoning in Long-Horizon Generation

We explore how iterative revising a chain of thoughts with the help of information retrieval significantly improves large language models' reasoning and generation ability in long-horizon generation tasks, while hugely mitigating hallucination. In particular, the proposed method — retrieval-augmented thoughts (RAT) — revises each thought step one by one with retrieved information relevant to the task query, the current and the past thought steps, after the initial zero-shot CoT is generated.

Applying RAT to various base models substantially improves their performances on various long-horizon generation tasks; on average of relatively increasing rating scores by 13.63% on code generation, 16.96% on mathematical reasoning, 19.2% on creative writing, and 42.78% on embodied task planning.

Feel free to try our demo!

"""

def clear_func():
    """Return three empty strings, one for each output box the clear button resets."""
    blank = ""
    return blank, blank, blank


# Gradio UI: document upload, instruction box, draft / final answer boxes and
# the submit / clear / regenerate buttons wired to `rat` and `clear_func`.
with gr.Blocks(title = page_title) as demo:

    gr.Markdown(page_md)

    with gr.Row():
        file_input = gr.File(label="Upload Document")

    with gr.Row():
        chatgpt_box = gr.Textbox(
            label = "ChatGPT",
            placeholder = "Response from ChatGPT with zero-shot chain-of-thought.",
            elem_id = "chatgpt"
        )

    with gr.Row():
        stream_box = gr.Textbox(
            label = "Streaming",
            placeholder = "Interactive response with RAT...",
            elem_id = "stream",
            lines = 10,
            visible = False
        )

    with gr.Row():
        rat_box = gr.Textbox(
            label = "RAT",
            placeholder = "Final response with RAT ...",
            elem_id = "rat",
            lines = 6
        )

    with gr.Column(elem_id="instruction_row"):
        with gr.Row():
            instruction_box = gr.Textbox(
                label = "instruction",
                placeholder = "Enter your instruction here",
                lines = 2,
                elem_id="instruction",
                interactive=True,
                visible=True
            )
        # with gr.Row():
        #     model_radio = gr.Radio(["gpt-3.5-turbo", "gpt-4", "GPT-4-turbo"], elem_id="model_radio", value="gpt-3.5-turbo", 
        #                         label='GPT model', 
        #                         show_label=True,
        #                         interactive=True, 
        #                         visible=True) 
        #     openai_api_key_textbox = gr.Textbox(
        #         label='OpenAI API key',
        #         placeholder="Paste your OpenAI API key (sk-...) and hit Enter", 
        #         show_label=True, 
        #         lines=1, 
        #         type='password')

    # openai_api_key_textbox.change(set_openai_api_key,
    #     inputs=[openai_api_key_textbox],
    #     outputs=[])

    with gr.Row():
        submit_btn = gr.Button(
            value="submit", visible=True, interactive=True
        )
        clear_btn = gr.Button(
            value="clear", visible=True, interactive=True
        )
        regenerate_btn = gr.Button(
            value="regenerate", visible=True, interactive=True
        )

    submit_btn.click(
        fn = rat,
        inputs = [instruction_box, file_input],
        outputs = [chatgpt_box, rat_box]
    )

    clear_btn.click(
        fn = clear_func,
        inputs = [],
        outputs = [instruction_box, chatgpt_box, rat_box]
    )

    # `rat` takes (question, document_path); regenerate must pass the uploaded
    # file as well, otherwise the call fails with a missing-argument error.
    regenerate_btn.click(
        fn = rat,
        inputs = [instruction_box, file_input],
        outputs = [chatgpt_box, rat_box]
    )

    examples = gr.Examples(
        examples=[
            # "I went to the supermarket yesterday.", 
            # "Helen is a good swimmer."
            "Write a survey of retrieval-augmented generation in Large Language Models.",
            "Introduce Jin-Yong's life and his works.",
            "Summarize the American Civil War according to the timeline.",
            "Describe the life and achievements of Marie Curie"
            ],
        inputs=[instruction_box]
        )

# Start the Gradio server. server_name="0.0.0.0" binds to all network
# interfaces, so the demo is reachable from other machines on the network.
# NOTE(review): with debug=True this call appears to block the cell until interrupted — confirm.
demo.launch(server_name="0.0.0.0", debug=True)

  from .autonotebook import tqdm as notebook_tqdm
Exception in thread Thread-7 (run):
Traceback (most recent call last):
  File "/home/neutrino/miniconda3/envs/Ml/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/home/neutrino/miniconda3/envs/Ml/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/neutrino/miniconda3/envs/Ml/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/home/neutrino/miniconda3/envs/Ml/lib/python3.11/site-packages/uvicorn/server.py", line 65, in run
    return asyncio.run(self.serve(sockets=sockets))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neutrino/miniconda3/envs/Ml/lib/python3.11/site-packages/nest_asyncio.py", line 26, in run
    loop = asyncio.get_event_loop()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/neutrino/miniconda3/envs/Ml/lib/python3.11/site-packages/nest_asyncio.py

KeyboardInterrupt: 