# Quackling — Prev Next Augmentation

In [1]:
# Requirements for this example, installed quietly (-qq) via the %pip magic.
# The kernel may need a restart afterwards before the new packages can be imported.
%pip install -qq \
    quackling \
    python-dotenv \
    llama-index-embeddings-huggingface \
    llama-index-llms-huggingface-api \
    llama-index-vector-stores-milvus

Note: you may need to restart the kernel to use updated packages.


In [2]:
# --- Notebook parameters -------------------------------------------------
# Inputs to ingest; each entry is either a local file path or a URL.
FILE_PATHS = [
    # "/path/to/local/pdf",  # file path
    "https://arxiv.org/pdf/2206.01062",  # URL (DocLayNet paper)
]

# Prompt sent to the LLM; `{context_str}` and `{query_str}` are the placeholders
# filled in at query time. Written as adjacent literals, one prompt line each.
TEXT_QA_TEMPLATE_STR = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {query_str}\n"
    "Answer:\n"
)

# Question asked in both RAG runs further down.
QUERY = "How many pages were human annotated?"

# Number of nodes retrieved per query (passed as `similarity_top_k`).
TOP_K = 1

In [3]:
import os
from tempfile import TemporaryDirectory

from dotenv import load_dotenv
from pydantic import TypeAdapter
from rich.pretty import pprint

# Load settings from a .env file (if any) into the process environment, so the
# os.environ lookups below can pick them up.
load_dotenv()

# embeddings:
HF_EMBED_MODEL_ID = "BAAI/bge-small-en-v1.5"
# vector store:
# Create the throwaway directory only when MILVUS_URL is not configured, so that
# no unused temporary directory is left behind when an explicit URL is provided.
# `tmp_dir` is always defined: it holds the TemporaryDirectory when one was
# created (keeping it referenced for the rest of the session) and None otherwise.
tmp_dir = None
MILVUS_URL = os.environ.get("MILVUS_URL")
if MILVUS_URL is None:
    tmp_dir = TemporaryDirectory()
    MILVUS_URL = f"{tmp_dir.name}/milvus_demo.db"
MILVUS_COLL_NAME = os.environ.get("MILVUS_COLL_NAME", "quackling_prev_next_aug")
# Optional extra keyword arguments for the vector store, given as a JSON object.
MILVUS_KWARGS = TypeAdapter(dict).validate_json(os.environ.get("MILVUS_KWARGS", "{}"))
# LLM:
# HF_API_KEY is None when the variable is not set.
HF_API_KEY = os.environ.get("HF_API_KEY")
HF_LLM_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

In [4]:
import warnings

warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")

## Initialization

In [5]:
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI

# Register the embedding model and the LLM on the global LlamaIndex Settings object.
Settings.embed_model = HuggingFaceEmbedding(
    model_name=HF_EMBED_MODEL_ID,
)
Settings.llm = HuggingFaceInferenceAPI(model_name=HF_LLM_MODEL_ID, token=HF_API_KEY)

# Embed a short probe string and take its length as the embedding dimension;
# the value is passed as `dim` when the vector store is created.
embed_dim = len(
    Settings.embed_model.get_text_embedding("hi"),
)

## Ingestion

In [6]:
from quackling.llama_index.node_parsers import HierarchicalJSONNodeParser
from quackling.llama_index.readers import DoclingPDFReader

# Reader configured with the JSON parse type, paired with the hierarchical
# JSON node parser that is later used as the first ingestion transformation.
reader = DoclingPDFReader(
    parse_type=DoclingPDFReader.ParseType.JSON,
)
node_parser = HierarchicalJSONNodeParser()

In [7]:
# Load every input listed in FILE_PATHS through the reader.
docs = reader.load_data(file_path=FILE_PATHS)

# Truncated preview of the result (limited length, string size and nesting depth).
pprint(
    docs,
    max_depth=4,
    max_length=2,
    max_string=250,
)

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

At this point, we define the *vector store*. Here, we showcase using Milvus, but any LlamaIndex-compatible vector store could be used just as well.

In [8]:
from llama_index.vector_stores.milvus import MilvusVectorStore

# Vector store holding the node embeddings; `dim` must equal the embedding size
# measured earlier. Extra options from MILVUS_KWARGS are forwarded unchanged.
# NOTE(review): overwrite=True presumably replaces an existing collection with
# the same name - confirm against the MilvusVectorStore documentation.
vector_store = MilvusVectorStore(
    collection_name=MILVUS_COLL_NAME,
    uri=MILVUS_URL,
    overwrite=True,
    dim=embed_dim,
    **MILVUS_KWARGS,
)

In [9]:
from llama_index.core import VectorStoreIndex

# Build the index directly on top of the vector store; the query engines
# further down are created from this index.
index = VectorStoreIndex.from_vector_store(vector_store=vector_store)

In [10]:
from llama_index.core.ingestion import IngestionPipeline

# Ingestion: parse the documents into nodes, embed them, and attach the vector
# store to the pipeline. The returned `nodes` are kept because the docstore for
# prev-next augmentation is populated from them later.
pipeline = IngestionPipeline(
    vector_store=vector_store,
    transformations=[
        node_parser,
        Settings.embed_model,
    ],
)
nodes = pipeline.run(documents=docs)

## RAG

### Without prev-next augmentation

In [11]:
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers.type import ResponseMode

# Baseline run: answer QUERY using only the TOP_K retrieved nodes.
query_engine = index.as_query_engine(
    similarity_top_k=TOP_K,
    text_qa_template=PromptTemplate(TEXT_QA_TEMPLATE_STR),
    # With the TREE_SUMMARIZE response mode the prompt is taken from
    # `summary_template`; `text_qa_template` alone would leave the custom
    # template unused, so the same template is supplied for both.
    summary_template=PromptTemplate(TEXT_QA_TEMPLATE_STR),
    response_mode=ResponseMode.TREE_SUMMARIZE,
)
query_res = query_engine.query(QUERY)
# Truncated preview of the response object.
pprint(query_res, max_length=3, max_string=250, max_depth=4)

### With prev-next augmentation

Here we create a `PrevNextNodePostprocessor` (currently beta) to augment the prompt with the nodes that follow the retrieved one (`mode="next"`; the postprocessor can alternatively add preceding nodes, or both):

In [12]:
from llama_index.core.postprocessor.node import PrevNextNodePostprocessor
from llama_index.core.storage.docstore import SimpleDocumentStore

# Document store the postprocessor uses to look up neighbours of a retrieved node.
docstore = SimpleDocumentStore()

# Clear the embedding on each node (in place) before adding it to the docstore.
# NOTE(review): presumably done to keep the docstore light, since the vectors
# were already handed to the vector store during ingestion - confirm.
for node in nodes:
    node.embedding = None
docstore.add_documents(nodes)

# mode="next" with num_nodes=4: expand each retrieved node forward only.
processor = PrevNextNodePostprocessor(
    num_nodes=4,
    mode="next",
    docstore=docstore,
)

In [13]:
# Augmented run: same retrieval settings as the baseline, plus the prev-next
# postprocessor that adds neighbouring nodes to the retrieved ones.
query_engine = index.as_query_engine(
    similarity_top_k=TOP_K,
    text_qa_template=PromptTemplate(TEXT_QA_TEMPLATE_STR),
    # With the TREE_SUMMARIZE response mode the prompt is taken from
    # `summary_template`; `text_qa_template` alone would leave the custom
    # template unused, so the same template is supplied for both.
    summary_template=PromptTemplate(TEXT_QA_TEMPLATE_STR),
    node_postprocessors=[processor],  # <==
    response_mode=ResponseMode.TREE_SUMMARIZE,
)
query_res = query_engine.query(QUERY)
# Truncated preview of the response object.
pprint(query_res, max_length=5, max_string=250, max_depth=5)