In [1]:
import os
import torch
import openai
import chromadb
import warnings
import nest_asyncio
import pandas as pd

from llama_parse import LlamaParse
from llama_index.core import Document, VectorStoreIndex, get_response_synthesizer, StorageContext, QueryBundle
from llama_index.core.retrievers import VectorIndexRetriever, BaseRetriever
from llama_index.core.node_parser import SemanticDoubleMergingSplitterNodeParser, LanguageConfig
from llama_index.core.schema import NodeWithScore
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from typing import Literal, List, Optional
from dotenv import load_dotenv
from tqdm.auto import tqdm
from context_cite import ContextCiter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Restrict this process to physical GPU 1. The mask only takes effect if it is
# set before CUDA is initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Once the mask is active the single visible GPU is renumbered to ordinal 0, so
# a hard-coded set_device(1) would be an invalid ordinal. If CUDA was already
# initialised (mask ignored) several devices are still visible and ordinal 1 is
# the intended card. Clamping to the last visible ordinal covers both cases.
if torch.cuda.is_available():
    torch.cuda.set_device(min(1, torch.cuda.device_count() - 1))

In [3]:
# Load key/value pairs from a local .env file into the process environment;
# the API keys read with os.getenv() in later cells are expected to come from there.
load_dotenv()
# Patch asyncio so an event loop can be re-entered from inside the notebook's own loop.
# NOTE(review): presumably required by LlamaParse's async client — confirm.
nest_asyncio.apply()
# Silences *every* Python warning for the rest of the session, including
# deprecation notices from the libraries used below.
warnings.filterwarnings("ignore")

In [4]:
# Hugging Face repository id of the generator used for attribution.
# 3.2 1B Instruct for faster inference, 3.1 8B for better performance.
model_name = "meta-llama/Llama-3.2-1B-Instruct"

# Bare model name: the repository id without its organisation prefix.
MODEL_NAME = model_name.rsplit("/", 1)[-1]

# Path to a sample PDF.
document = "documents/attention.pdf"

In [5]:
# The OpenAI key is taken from the environment rather than hard-coded.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Language settings for the semantic splitter; the spaCy model named here must
# already be installed locally.
config = LanguageConfig(language="english", spacy_model="en_core_web_md") # must download the model first
# OpenAI embedding model constructed with the library's default settings.
embed_model = OpenAIEmbedding()
# Semantic chunker that turns parsed documents into nodes.
# NOTE(review): the three thresholds are presumably similarity cut-offs in [0, 1]
# controlling when sentences/chunks are merged — confirm against the LlamaIndex docs.
splitter = SemanticDoubleMergingSplitterNodeParser(
    initial_threshold=0.6,
    appending_threshold=0.5,
    merging_threshold=0.6,
    language_config=config,
    max_chunk_size=1024,
)

In [6]:
# LlamaParse client used to convert the PDF into markdown documents.
# NOTE(review): the key is read from LLAMA_CLOUD_API_TOKEN; LlamaParse's own
# documented variable is LLAMA_CLOUD_API_KEY — confirm the .env file uses this name,
# otherwise os.getenv() returns None here.
parser = LlamaParse(
    api_key=os.getenv("LLAMA_CLOUD_API_TOKEN"),
    num_workers=8,
    show_progress=True,
    result_type="markdown"
)

In [7]:
# PDF that the rest of the notebook parses, indexes and queries.
file = "documents/intro_to_ml.pdf"
# Fail fast, before the remote parsing job is started. isfile() (rather than
# exists()) also rejects a directory that happens to have this name.
if not os.path.isfile(file):
    raise FileNotFoundError(f"File {file} not found")

In [8]:
# Parse the PDF with LlamaParse (runs as a remote job, see the job id in the output).
documents = parser.load_data(file)
# Split the parsed documents into semantic chunks ("nodes"). This is slow:
# the recorded run took about 1m46s for 640 parsed documents.
nodes = splitter.get_nodes_from_documents(documents, show_progress=True)

Started parsing the file under job_id b2b81d8b-3bef-4a11-90e6-c6264e38dec0
..

Parsing nodes: 100%|██████████| 640/640 [01:46<00:00,  5.98it/s]


In [9]:
# In-memory storage context; the nodes are registered in its docstore before
# the index is built on top of the same context.
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)

# Build the dense index. The embedding model is passed explicitly so the index
# uses the `embed_model` configured in this notebook instead of silently
# depending on whatever the library-wide default happens to be.
vector_index = VectorStoreIndex(
    nodes=nodes,
    embed_model=embed_model,
    insert_batch_size=1024,  # nodes embedded/inserted per batch (see progress bars)
    storage_context=storage_context,
    show_progress=True,
)

Generating embeddings: 100%|██████████| 1024/1024 [00:23<00:00, 43.57it/s]
Generating embeddings: 100%|██████████| 1024/1024 [00:22<00:00, 45.05it/s]
Generating embeddings: 100%|██████████| 919/919 [00:15<00:00, 58.72it/s]


In [10]:
# Dense (embedding similarity) retriever over the vector index; returns 5 hits per query.
dense_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5)
# Sparse (BM25 keyword) retriever built over the same nodes; also 5 hits per query.
sparse_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=5)

# Response synthesizer with library defaults, handed to the query engine in a later cell.
res_synth = get_response_synthesizer()

In [11]:
class HybridRetriever(BaseRetriever):
    """Retriever that merges the hits of a dense and a sparse retriever.

    Both underlying retrievers are queried with the same ``QueryBundle`` and
    their hits are combined by node id:

    * ``mode="OR"``  - every node returned by either retriever (union).
    * ``mode="AND"`` - only nodes returned by both retrievers (intersection).

    The result order is deterministic: dense hits first, in the order the dense
    retriever returned them, followed by sparse-only hits in sparse order.
    Scores are *not* fused. When a node is returned by both retrievers, the
    ``NodeWithScore`` coming from the sparse retriever is the one that is kept.

    Args:
        dense_retriever: Retriever queried first. Defaults to the notebook-level
            ``dense_retriever`` object that exists when this class is defined.
        sparse_retriever: Retriever queried second. Defaults to the
            notebook-level ``sparse_retriever`` object that exists when this
            class is defined.
        mode: ``"AND"`` or ``"OR"``, see above.
        **kwargs: Forwarded unchanged to ``BaseRetriever.__init__``.

    Raises:
        ValueError: If ``mode`` is not ``"AND"`` or ``"OR"``.
    """

    def __init__(self, dense_retriever: BaseRetriever = dense_retriever, 
                 sparse_retriever: BaseRetriever = sparse_retriever,
                 mode: Literal["AND", "OR"] = "OR",
                 **kwargs) -> None:
        # Fail at construction time rather than on the first query.
        self._check_mode(mode)

        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.mode = mode

        super().__init__(**kwargs)

    @staticmethod
    def _check_mode(mode) -> None:
        """Raise ``ValueError`` unless ``mode`` is one of the supported values."""
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode. Must be either 'AND' or 'OR'.")

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Query both retrievers and combine their hits according to ``self.mode``.

        Args:
            query_bundle: The query passed to both underlying retrievers.

        Returns:
            The combined hits, without duplicates, in a deterministic order.

        Raises:
            ValueError: If ``self.mode`` was changed to an unsupported value.
        """
        # `mode` is a plain attribute and may have been reassigned after
        # construction; re-check it before paying for two retrieval calls.
        self._check_mode(self.mode)

        dense_res = self.dense_retriever.retrieve(query_bundle)
        sparse_res = self.sparse_retriever.retrieve(query_bundle)

        dense_ids = {hit.node.node_id for hit in dense_res}
        sparse_ids = {hit.node.node_id for hit in sparse_res}

        # A dict keeps the position of a key's first insertion while letting a
        # later assignment replace its value: the order follows the retriever
        # rankings (dense first), and for nodes found by both retrievers the
        # sparse hit overrides the dense one.
        merged = {}
        for hit in [*dense_res, *sparse_res]:
            merged[hit.node.node_id] = hit

        if self.mode == "AND":
            return [
                hit for node_id, hit in merged.items()
                if node_id in dense_ids and node_id in sparse_ids
            ]

        # mode == "OR"
        return list(merged.values())

In [12]:
# Hybrid retriever over the dense and sparse retrievers; `mode` is left at its default ("OR").
hybrid_retriever = HybridRetriever(dense_retriever=dense_retriever, sparse_retriever=sparse_retriever)
# Query engine that pairs the hybrid retriever with the response synthesizer.
query_engine = RetrieverQueryEngine(retriever=hybrid_retriever, response_synthesizer=res_synth)

In [19]:
def rerank(query, nodes,
           model=model_name,
           long=False,
           top_k=None):
    """Rerank retrieved nodes by how much ContextCite attributes the answer to them.

    The contents of all ``nodes`` are joined into one context, a response to
    ``query`` is generated with ``model``, and ContextCite scores the context
    sources for that response. Each node's score is the sum of the scores of
    the attributed sources whose text occurs in the node.

    Args:
        query: The user question.
        nodes: Retrieved hits; each must expose ``.node.get_content()`` and
            ``.node.node_id``.
        model: Model identifier passed to ``ContextCiter.from_pretrained``.
        long: Accepted for interface compatibility; currently unused.
        top_k: Number of top attributed sources to request from ContextCite.
            ``None`` requests as many sources as there are nodes.

    Returns:
        A 4-tuple ``(reranked_nodes, response, node_scores, attributions)``:
        the nodes sorted by descending score, the generated response, a dict
        mapping node id to its cumulative score, and the attribution object
        exactly as returned by ContextCite.

    Note:
        The model is loaded on every call, which makes each call expensive.
    """
    context = "\n\n".join(node.node.get_content() for node in nodes)

    cc = ContextCiter.from_pretrained(
        model,  # honour the caller's choice instead of the module-level default
        context=context,
        query=query,
        device="cuda",
        solver="elastic_net",
        num_ablations=128
    )

    attributions = cc.get_attributions(as_dataframe=True, top_k=len(nodes) if top_k is None else top_k)

    # The result may be a styled wrapper exposing the frame as `.data`.
    attributions_df = attributions.data if hasattr(attributions, "data") else attributions

    # The attribution frame is sorted by score and truncated to top_k rows, so
    # each score has to be paired with the "Source" text of its *own row*.
    # Pairing the scores with `cc.sources` (document order, untruncated) would
    # assign them to the wrong segments.
    scored_segments = []
    for source, score in zip(attributions_df["Source"], attributions_df["Score"]):
        segment = source.strip()
        if segment:
            scored_segments.append((segment, float(score)))

    node_scores = {}
    for node in nodes:
        node_text = node.node.get_content()
        node_scores[node.node.node_id] = sum(
            (score for segment, score in scored_segments if segment in node_text),
            0.0,
        )

    reranked_nodes = sorted(nodes, key=lambda hit: node_scores[hit.node.node_id], reverse=True)
    return reranked_nodes, cc.response, node_scores, attributions

In [23]:
def query_pdf(query: str, top_k: int = 5) -> tuple:
    """Answer a question about the indexed PDF and report supporting context.

    Retrieves candidate nodes with the notebook-level ``query_engine``'s
    retriever, reranks them with ``rerank`` and formats a report.

    Args:
        query: The question to ask.
        top_k: Number of top reranked nodes whose text is included in the report.

    Returns:
        A 2-tuple ``(final_response, attrs)``. ``final_response`` is a string
        holding the generated answer, the text of the top ranked nodes and the
        per-node scores. ``attrs`` is the attribution table (unwrapped from its
        ``.data`` attribute when present). If no query engine is available the
        tuple is ``("Please upload a PDF first.", None)``.
    """
    if query_engine is None:
        # Keep the two-element shape so `res, attrs = query_pdf(...)` still unpacks.
        return "Please upload a PDF first.", None

    long = len(nodes) > 100
    rnodes = query_engine.retriever.retrieve(QueryBundle(query))
    reranked_nodes, cc_response, node_scores, attrs = rerank(query, rnodes, long=long)

    top_nodes = reranked_nodes[:top_k]
    final_context = "\n\n".join(node.node.get_content() for node in top_nodes)

    final_response = (
        f"{cc_response}\n\nTop Ranked Context:\n{final_context}\n\nNode Scores:\n{node_scores}"
    )

    if hasattr(attrs, "data"):
        attrs = attrs.data

    return final_response, attrs

In [25]:
# Run the full retrieve -> rerank pipeline; `res` is the formatted report string,
# `attributions` the attribution table shown in the following cells.
res, attributions = query_pdf("What is machine learning?")

Attributed: Machine learning is a branch of artificial intelligence that enables computers to learn and improve their performance on a task without being explicitly programmed. It involves training algorithms to recognize patterns, make predictions, or take actions based on data, allowing machines to adapt to new situations and improve their performance over time.


100%|██████████| 128/128 [00:13<00:00,  9.22it/s]


In [24]:
# Display the attribution table (one row per source, with Score and Source columns).
attributions

Unnamed: 0,Score,Source
0,15.341433,Machine learning also helps us find solutions ...
1,6.912308,Contents Preface xvii 1.1 What Is Machine Lear...
2,2.784766,"To be intelligent, a system that is in a chang..."
3,1.166246,"At the same time, we know that a face image is..."
4,0.977079,If the system can learn and adapt to such chan...
5,0.64802,"In science, large amounts of data in physics, ..."
6,0.607022,But machine learning is not just a database pr...
7,0.594878,Because we are not able to explain our experti...


In [17]:
# Full text of the first row's source (the table display above truncates it).
attributions.iloc[0].Source

'Machine learning also helps us find solutions to many problems in vision, speech recognition, and robotics.'

In [18]:
# Full text of the second row's source.
attributions.iloc[1].Source

'Contents Preface xvii 1.1 What Is Machine Learning? 1 1.2 Examples of Machine Learning Applications 4 1.2.1 Learning Associations 4 1.2.4 Unsupervised Learning 11 2.1 Learning a Class from Examples 21 2.7 Model Selection and Generalization 37 2.8 Dimensions of a Supervised Machine Learning Algorithm 41 “Learning Logical Definitions from Relations.” Machine Learning 5:239–266. “An Introduction to MCMC for Machine Learning.” Machine Learning 50:5–43. “Q-learning.” Machine Learning 8:279–292. 1.1 What Is Machine Learning? abundant: In addition to retail, in finance banks analyze their past data to build models to use in credit applications, fraud detection, and the stock market.'