In [5]:
import openai
import os
from dotenv import load_dotenv
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

from llama_index.llms.openai import OpenAI
llm = OpenAI(model="gpt-4")

In [29]:

from llama_index.core.workflow import Event
from llama_index.core.schema import NodeWithScore


class RetrieverEvent(Event):
    """Result of running retrieval"""

    nodes: list[NodeWithScore]


class CreateCitationsEvent(Event):
    """Add citations to the nodes."""

    nodes: list[NodeWithScore]

from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)


In [50]:
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    load_index_from_storage
)
from llama_index.embeddings.openai import OpenAIEmbedding

def load_or_create_index(directory_path, persist_dir):
        if os.path.exists(persist_dir):
            print("Loading existing index...")
            storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
            index = load_index_from_storage(storage_context)
        else:
            print("Creating new index...")
            documents = SimpleDirectoryReader(directory_path, recursive=True).load_data()
            index = VectorStoreIndex.from_documents(
                documents=documents,
                embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
            )
            index.storage_context.persist(persist_dir=persist_dir)
        return index

index = load_or_create_index(
            "data/get-started/",
            "storage"
        )

Creating new index...


In [51]:
print(index)

<llama_index.core.indices.vector_store.base.VectorStoreIndex object at 0x176ff31a0>


In [52]:
from llama_index.core.prompts import PromptTemplate

CITATION_QA_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. Below are several numbered sources of information:"
    "\n------\n"
    "{context_str}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)

CITATION_REFINE_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. "
    "We have provided an existing answer: {existing_answer}"
    "Below are several numbered sources of information. "
    "Use them to refine the existing answer. "
    "If the provided sources are not helpful, you will repeat the existing answer."
    "\nBegin refining!"
    "\n------\n"
    "{context_msg}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)

DEFAULT_CITATION_CHUNK_SIZE = 512
DEFAULT_CITATION_CHUNK_OVERLAP = 20

In [80]:
from llama_index.core.schema import (
    MetadataMode,
    NodeWithScore,
    TextNode,
)

from llama_index.core.response_synthesizers import (
    ResponseMode,
    get_response_synthesizer,
)

from typing import Union, List
from llama_index.core.node_parser import SentenceSplitter


class CitationQueryEngineWorkflow(Workflow):
    @step(pass_context=True)
    async def retrieve(
        self, ctx: Context, ev: StartEvent
    ) -> RetrieverEvent:
        "Entry point for RAG, triggered by a StartEvent with `query`."
        query = ev.get("query")
        if not query:
            return None

        print(f"Query the database with: {query}")

        # store the query in the global context
        ctx.data["query"] = query

        if ev.index is None:
            print("Index is empty, load some documents before querying!")
            return None
        else:
            print("Index is not empty, proceed with querying!")

        retriever = ev.index.as_retriever(similarity_top_k=2)
        nodes = retriever.retrieve(query)
        print(f"Retrieved {len(nodes)} nodes.")
        return RetrieverEvent(nodes=nodes)
    
    @step(pass_context=True)
    async def create_citation_nodes(
        self, ctx: Context, ev: RetrieverEvent
    ) -> CreateCitationsEvent:
        """
        Modify retrieved nodes to create granular sources for citations.

        Takes a list of NodeWithScore objects and splits their content
        into smaller chunks, creating new NodeWithScore objects for each chunk.
        Each new node is labeled as a numbered source, allowing for more precise
        citation in query results.

        Args:
            nodes (List[NodeWithScore]): A list of NodeWithScore objects to be processed.

        Returns:
            List[NodeWithScore]: A new list of NodeWithScore objects, where each object
            represents a smaller chunk of the original nodes, labeled as a source.
        """
        nodes = ev.nodes

        new_nodes: List[NodeWithScore] = []

        text_splitter = SentenceSplitter(
            chunk_size=DEFAULT_CITATION_CHUNK_SIZE,
            chunk_overlap=DEFAULT_CITATION_CHUNK_OVERLAP,
        )

        for node in nodes:
            text_chunks = text_splitter.split_text(
                node.node.get_content(metadata_mode=MetadataMode.NONE)
            )

            for text_chunk in text_chunks:
                text = f"Source {len(new_nodes)+1}:\n{text_chunk}\n"

                new_node = NodeWithScore(
                    node=TextNode.parse_obj(node.node), score=node.score
                )
                new_node.node.text = text
                new_nodes.append(new_node)
        return CreateCitationsEvent(nodes=new_nodes)

    @step(pass_context=True)
    async def synthesize(
        self, ctx: Context, ev: CreateCitationsEvent
    ) -> StopEvent:
        """Return a streaming response using the retrieved nodes."""
        llm = OpenAI(model="gpt-4")
        query = ctx.data.get("query")
        print(f"Synthesizing response for query: {query}")

        synthesizer = get_response_synthesizer(
            llm=llm,
            text_qa_template=CITATION_QA_TEMPLATE,
            refine_template=CITATION_REFINE_TEMPLATE,
            response_mode=ResponseMode.COMPACT,
            use_async=True,
        )

        response = await synthesizer.asynthesize(query, nodes=ev.nodes)
        print(f"Response: {response}")
        return StopEvent(result=response)

In [81]:
from llama_index.utils.workflow import draw_all_possible_flows

draw_all_possible_flows(CitationQueryEngineWorkflow,filename="CitationQueryEngineWorkflow.html")

CitationQueryEngineWorkflow.html


In [82]:
w = CitationQueryEngineWorkflow()

In [83]:
result = await w.run(query="What is streamlit?", index=index)

Query the database with: What is streamlit?
Index is not empty, proceed with querying!
Retrieved 2 nodes.
Synthesizing response for query: What is streamlit?
Response: Streamlit is a tool for creating interactive apps using Python. It allows you to write apps in the same way you write plain Python scripts. You can add Streamlit commands into a normal Python script and run it with `streamlit run`. This will spin up a local Streamlit server and open your app in a new tab in your default web browser. The app can include charts, text, widgets, tables, and more. Streamlit's architecture has a unique data flow: any time something must be updated on the screen, Streamlit reruns your entire Python script from top to bottom [1][2].


In [84]:
from IPython.display import Markdown, display

display(Markdown(f"{result}"))

Streamlit is a tool for creating interactive apps using Python. It allows you to write apps in the same way you write plain Python scripts. You can add Streamlit commands into a normal Python script and run it with `streamlit run`. This will spin up a local Streamlit server and open your app in a new tab in your default web browser. The app can include charts, text, widgets, tables, and more. Streamlit's architecture has a unique data flow: any time something must be updated on the screen, Streamlit reruns your entire Python script from top to bottom [1][2].

In [86]:
print(result.source_nodes[1].node.get_text())

Source 2:
Data flow

Streamlit's architecture allows you to write apps the same way you write plain
Python scripts. To unlock this, Streamlit apps have a unique data flow: any
time something must be updated on the screen, Streamlit reruns your entire
Python script from top to bottom.

This can happen in two situations:

- Whenever you modify your app's source code.

- Whenever a user interacts with widgets in the app. For example, when dragging
  a slider, entering text in an input box, or clicking a button.

Whenever a callback is passed to a widget via the `on_change` (or `on_click`) parameter, the callback will always run before the rest of your script. For details on the Callbacks API, please refer to our Session State API Reference Guide.

And to make all of this fast and seamless, Streamlit does some heavy lifting
for you behind the scenes. A big player in this story is the
`@st.cache_data` decorator, which allows developers to skip certain
costly computations when their apps rer

In [72]:
retriever = index.as_retriever()
retrieved_nodes = retriever.retrieve("What is llama_index?")

In [73]:
print(retrieved_nodes)

[NodeWithScore(node=TextNode(id_='6da6d750-6260-4433-83b5-dabf42afc7d4', embedding=None, metadata={'file_path': '/Users/boringtao/Projects/AutoRAG/notebooks/data/get-started/fundamentals/_index.md', 'file_name': '_index.md', 'file_size': 1254, 'creation_date': '2024-08-21', 'last_modified_date': '2024-08-21'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='8225bafb-71a4-4a8e-94bf-5e67c9b5eeb0', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': '/Users/boringtao/Projects/AutoRAG/notebooks/data/get-started/fundamentals/_index.md', 'file_name': '_index.md', 'file_size': 1254, 'creation_date': '2024-08-21', 'last_modified_date': '2024-08-21'}, hash='869fdb122bfd92504d4cf9adb5958b115175df93719bd96aaf

In [76]:
from llama_index.core.data_structs import Node
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core import get_response_synthesizer

response_synthesizer = get_response_synthesizer(
    response_mode=ResponseMode.COMPACT
)

query_engine = index.as_query_engine(response_synthesizer=response_synthesizer)
response = query_engine.query("What is streamlit")

In [77]:
print(response)

Streamlit is a tool that allows users to easily create interactive web applications using Python scripts. By adding Streamlit commands to a Python script and running it with `streamlit run`, a local Streamlit server is started, and the app can be viewed in a web browser. Streamlit provides various commands like `st.text` for adding text and `st.line_chart` for drawing line charts, enabling users to create a wide range of visualizations and interactive elements in their apps.
