In [None]:
import os
import openai
import chromadb
import warnings
import nest_asyncio
import re

from llama_parse import LlamaParse
from llama_index.core import Document, VectorStoreIndex, get_response_synthesizer, StorageContext, QueryBundle
from llama_index.core.retrievers import VectorIndexRetriever, BaseRetriever
from llama_index.core.node_parser import SemanticDoubleMergingSplitterNodeParser, LanguageConfig
from llama_index.core.schema import NodeWithScore
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from typing import Literal, List, Optional
from dotenv import load_dotenv
from tqdm.auto import tqdm
from context_cite import ContextCiter

resource module not available on Windows


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load variables from a local .env file into the environment; the API keys
# below are read with os.getenv.
load_dotenv()
# Patch asyncio so nested event loops work inside the notebook kernel.
nest_asyncio.apply()
# NOTE(review): this silences every warning for the rest of the session,
# including deprecation notices from the libraries used below.
warnings.filterwarnings("ignore")

In [3]:
# The OpenAI key comes from the environment; os.getenv yields None when unset.
openai.api_key = os.getenv("OPENAI_API_KEY")

# spaCy model backing the semantic splitter -- it must be downloaded first.
config = LanguageConfig(spacy_model="en_core_web_md", language="english")

# Embedding model, constructed with the library defaults.
embed_model = OpenAIEmbedding()

# Semantic double-merging splitter; chunks are capped at 1024 characters.
splitter = SemanticDoubleMergingSplitterNodeParser(
    language_config=config,
    max_chunk_size=1024,
    initial_threshold=0.7,
    appending_threshold=0.9,
    merging_threshold=0.9,
)

In [4]:
# LlamaParse client that returns parsed content as markdown; the API key is
# read from the LLAMACLOUD_API_KEY environment variable.
parser = LlamaParse(
    result_type="markdown",
    api_key=os.getenv("LLAMACLOUD_API_KEY"),
    show_progress=True,
    num_workers=8,
)

In [5]:
# Checkpoint choice: 3.2 1B Instruct for faster inference, 3.1 8B for better
# performance. The hub id is derived from the short name so the two cannot drift.
MODEL_NAME = "Llama-3.2-1B-Instruct"
model_name = f"meta-llama/{MODEL_NAME}"

# NOTE(review): no visible cell reads `document`; the PDF that actually gets
# parsed is the one assigned to `file` in the next cell.
document = 'documents/attention.pdf'

In [6]:
# Path of the PDF handed to the parser; fail fast with a clear error if it
# does not exist.
file = "documents/sample.pdf"
if not os.path.exists(file):
    raise FileNotFoundError(f"File {file} not found")

In [7]:
# Parse the PDF with LlamaParse; the result is split into nodes further down.
documents = parser.load_data(file)

Started parsing the file under job_id 7d8f1997-7ffc-4e1f-8a29-105941ae2698


In [8]:
# Chunk the parsed documents with the semantic splitter, showing a progress bar.
nodes = splitter.get_nodes_from_documents(documents, show_progress=True)

Parsing nodes: 100%|██████████| 6/6 [00:04<00:00,  1.20it/s]


In [9]:
# Sanity check: report how many nodes came out of how many parsed documents.
print(f"Splitted {len(nodes)} nodes from {len(documents)} documents")

Splitted 156 nodes from 6 documents


In [10]:
# Default storage context, with the nodes registered in its docstore.
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)

# Pass the embedding model configured above explicitly. Without it the index
# falls back to the library-wide default and `embed_model` is never used.
vector_index = VectorStoreIndex(
    nodes=nodes,
    embed_model=embed_model,
    insert_batch_size=1024,
    storage_context=storage_context,
    show_progress=True,
)

Generating embeddings: 100%|██████████| 156/156 [00:04<00:00, 31.97it/s]


In [11]:
# Dense (vector index) and sparse (BM25) retrievers over the same nodes,
# each returning its top 5 hits.
dense_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5)
sparse_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=5)

# Default response synthesizer; passed to the query engine in a later cell.
res_synth = get_response_synthesizer()

In [12]:
class HybridRetriever(BaseRetriever):
    """Retriever that merges the results of a dense and a sparse retriever.

    Both retrievers are run for every query and their results are combined by
    node id according to ``mode``:

    * ``"OR"``  -- every node returned by either retriever (union).
    * ``"AND"`` -- only nodes returned by both retrievers (intersection).

    When a node is returned by both retrievers, the ``NodeWithScore`` produced
    by the sparse retriever is the one kept. Scores are passed through as-is;
    no rescaling between the two retrievers is performed.

    The result order is deterministic: nodes appear in dense-retriever order,
    followed by sparse-only nodes in sparse-retriever order.
    """

    # NOTE: the defaults below are bound when the class statement runs, so
    # calling HybridRetriever() with no arguments reuses the notebook-level
    # retrievers that existed when this cell was executed. Pass the retrievers
    # explicitly to avoid depending on that hidden state.
    def __init__(self, dense_retriever: BaseRetriever = dense_retriever, 
                 sparse_retriever: BaseRetriever = sparse_retriever,
                 mode: Literal["AND", "OR"] = "OR",
                 **kwargs) -> None:
        """Store the two retrievers and the combination mode.

        Args:
            dense_retriever: Retriever queried first; fixes the primary order.
            sparse_retriever: Retriever whose entry wins when both return a node.
            mode: ``"OR"`` for the union of results, ``"AND"`` for the intersection.
            **kwargs: Forwarded unchanged to ``BaseRetriever.__init__``.
        """
        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.mode = mode

        super().__init__(**kwargs)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run both retrievers and merge their results according to ``self.mode``.

        Raises:
            ValueError: If ``self.mode`` is neither ``"AND"`` nor ``"OR"``.
        """
        # Validate first so a bad mode does not cost two retrieval calls.
        if self.mode not in ("AND", "OR"):
            raise ValueError("Invalid mode. Must be either 'AND' or 'OR'.")

        dense_res = self.dense_retriever.retrieve(query_bundle)
        sparse_res = self.sparse_retriever.retrieve(query_bundle)

        # node_id -> NodeWithScore. A dict keeps first-insertion order, so keys
        # follow dense order and then sparse-only nodes; update() replaces the
        # value for ids seen by both, which keeps the sparse entry.
        merged = {n.node.node_id: n for n in dense_res}
        merged.update({n.node.node_id: n for n in sparse_res})

        if self.mode == "OR":
            return list(merged.values())

        # "AND": keep only ids present in both result lists. Walking the
        # ordered dict (not a set) keeps the order stable between runs.
        dense_ids = {n.node.node_id for n in dense_res}
        sparse_ids = {n.node.node_id for n in sparse_res}
        return [
            node
            for node_id, node in merged.items()
            if node_id in dense_ids and node_id in sparse_ids
        ]

In [13]:
# Hybrid retriever over the two retrievers above; `mode` is left at its default ("OR").
hybrid_retriever = HybridRetriever(dense_retriever=dense_retriever, sparse_retriever=sparse_retriever)

In [14]:
# Query engine: hybrid retrieval followed by the default response synthesizer.
query_engine = RetrieverQueryEngine(retriever=hybrid_retriever, response_synthesizer=res_synth)

In [15]:
# Smoke test of the pipeline without re-ranking; the bare `res` on the last
# line makes the notebook display the full response object.
res = query_engine.query("What does printf mean in C?")
res

Response(response='printf in C stands for "print formatted".', source_nodes=[NodeWithScore(node=TextNode(id_='50e7c3ce-a0c0-48ca-9767-ddf41e899358', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='5a26bf52-9137-434e-9a39-c42c62ac00fd', node_type='4', metadata={}, hash='0e21741e5655b2140ca5424d788011fde501d9208075dbc9f3fbbe54b3e7d3dc'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='5169b0a8-8c22-4b01-9767-dd8f51517729', node_type='1', metadata={}, hash='7f5e520fd1f480c7fc6a6e39a67447d4d64372b4352c6d41aa202d14f38eb4eb'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='5f06cff0-f8f0-4d84-8a73-8fa02fce6b24', node_type='1', metadata={}, hash='2036882499f36a3d94a801aa6617ac45bfbc45247f59f32b1b5f6415b435e302')}, metadata_template='{key}: {value}', metadata_separator='\n', text='The function body consists of a single statement, a call to the printf() function

In [16]:
def rerank(query, nodes,
           model=model_name,
           top_k=5,
           device="cuda"):
    """Re-rank retrieved nodes by their ContextCite attribution scores.

    The contents of all nodes are joined into one context, a ContextCiter is
    built for ``query`` over that context, and each node is scored by summing
    the attribution scores of every context source whose stripped text occurs
    in the node's content.

    Args:
        query: The question to answer.
        nodes: Retrieved nodes; each must expose ``.node.get_content()`` and
            ``.node.node_id``.
        model: Model name or path passed to ``ContextCiter.from_pretrained``.
        top_k: Maximum number of re-ranked nodes to return; ``None`` returns
            all of them.
        device: Device string passed to ``ContextCiter.from_pretrained``.

    Returns:
        A tuple ``(reranked_nodes, response, node_scores)``: the nodes sorted
        by descending score and truncated to ``top_k``, the ContextCiter
        response, and a dict mapping node id to score for every input node
        (not only the returned ones).
    """
    context = "\n\n".join(node.node.get_content() for node in nodes)

    # Honor the caller's `model` argument rather than the notebook global.
    cc = ContextCiter.from_pretrained(
        model,
        context=context,
        query=query,
        device=device,
    )

    # NOTE(review): the scores are zipped positionally with cc.sources below,
    # so this relies on one score per source in source order -- presumably
    # `top_k` has no effect when as_dataframe=False; confirm in context_cite.
    attributions = cc.get_attributions(as_dataframe=False, top_k=len(nodes))

    # Strip each source once (not once per node) and drop blank ones.
    scored_segments = [
        (stripped, score)
        for stripped, score in (
            (seg.strip(), score) for seg, score in zip(cc.sources, attributions)
        )
        if stripped
    ]

    # A node's score is the sum of the scores of the sources found in its text.
    node_scores = {}
    for node in nodes:
        node_text = node.node.get_content()
        node_scores[node.node.node_id] = sum(
            (score for seg, score in scored_segments if seg in node_text),
            0.0,
        )

    # Re-rank the nodes by descending cumulative score.
    reranked_nodes = sorted(nodes, key=lambda x: node_scores.get(x.node.node_id, 0.0), reverse=True)
    if top_k is not None:
        reranked_nodes = reranked_nodes[:top_k]

    return reranked_nodes, cc.response, node_scores

In [17]:
def query_pdf(query):
    """Answer ``query`` from the indexed PDF and report the supporting context.

    Nodes are fetched through the query engine's retriever and re-ranked with
    ``rerank``. The returned string contains the re-ranker's response, the
    content of the five best-ranked nodes, and the per-node scores. If
    ``query_engine`` is ``None`` a short notice is returned instead.
    """
    # `query_engine` is only read here, so no `global` declaration is needed.
    if query_engine is None:
        return "Please upload a PDF first."

    candidates = query_engine.retriever.retrieve(QueryBundle(query))
    ranked, answer, scores = rerank(query, candidates)

    # Only the five best-ranked nodes are shown as context.
    context_blocks = []
    for ranked_node in ranked[:5]:
        context_blocks.append(ranked_node.node.get_content())
    final_context = "\n\n".join(context_blocks)

    return f"{answer}\n\nTop Ranked Context:\n{final_context}\n\nNode Scores:\n{scores}"

In [None]:
# End-to-end run: hybrid retrieval, re-ranking, and the formatted answer string.
query_pdf("What is the purpose of the printf function in C?")