# Multimodal RAG with LangChain

This cookbook shows how to perform RAG on the table and text output of the nv-ingest PDF extraction pipeline.

RAG over tables can be challenging because raw table data doesn't always work well with semantic similarity search. To account for this, we will generate summaries of the table data and run the similarity search over those summaries instead.


First, let's load the text and table JSON content from the extracted nv-ingest metadata:

In [1]:
import json
from pathlib import Path

text_data = json.loads(Path("./processed_docs/text/multimodal_test.pdf.metadata.json").read_text())
table_data = json.loads(Path("./processed_docs/structured/multimodal_test.pdf.metadata.json").read_text())

text_content = [doc["metadata"]["content"] for doc in text_data]
tables = [table["metadata"]["table_metadata"]["table_content"] for table in table_data]
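
The list comprehensions above raise a `KeyError` if any record lacks one of the nested keys. As a minimal sketch, the extraction could be made tolerant of missing or empty fields. The sample records below are hypothetical and shaped only by the keys the cell above reads; real nv-ingest output carries more fields, and `extract_text` and `extract_tables` are illustrative helper names, not part of any library.

```python
import json

# Hypothetical records: only the keys read by the loading cell are modeled.
sample_text_json = json.dumps([
    {"metadata": {"content": "A dog chases a squirrel."}},
    {"metadata": {"content": ""}},  # empty content is skipped
])
sample_table_json = json.dumps([
    {"metadata": {"table_metadata": {"table_content": "| animal | activity |"}}},
    {"metadata": {"table_metadata": {}}},  # missing table_content is skipped
])


def extract_text(records):
    """Collect non-empty text content, skipping records that lack it."""
    return [
        content
        for rec in records
        if (content := (rec.get("metadata") or {}).get("content"))
    ]


def extract_tables(records):
    """Collect non-empty table content, skipping records that lack it."""
    results = []
    for rec in records:
        table_meta = (rec.get("metadata") or {}).get("table_metadata") or {}
        content = table_meta.get("table_content")
        if content:
            results.append(content)
    return results


sample_text = extract_text(json.loads(sample_text_json))
sample_tables = extract_tables(json.loads(sample_table_json))
print(sample_text)
print(sample_tables)
```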

Then, we'll split the text content into smaller chunks, keeping each chunk large enough (and overlapping its neighbors) to avoid losing context. We won't do this with the tables, since splitting could break them apart and corrupt them.

In [2]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=2000, chunk_overlap=200, add_start_index=True
)
texts = text_splitter.split_documents([Document(page_content=text) for text in text_content])
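
To make the roles of `chunk_size` and `chunk_overlap` concrete, here is a simplified sketch of overlapping windows in plain Python. It is not how `RecursiveCharacterTextSplitter` works internally (that splitter breaks on separators such as paragraphs and newlines before merging pieces); `sliding_window_chunks` is a hypothetical helper that only illustrates how consecutive chunks share text and where each chunk starts, which is the kind of offset `add_start_index=True` records in each chunk's metadata.

```python
def sliding_window_chunks(text, chunk_size, chunk_overlap):
    """Return (start_index, chunk) pairs where neighbors share chunk_overlap characters."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be non-negative and smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append((start, text[start:start + chunk_size]))
        # Stop once a window has reached the end of the text.
        if start + chunk_size >= len(text):
            break
    return chunks


sample = "abcdefghijklmnopqrstuvwxyz"
windows = sliding_window_chunks(sample, chunk_size=10, chunk_overlap=3)
for start, chunk in windows:
    print(start, chunk)
```

Each window begins `chunk_size - chunk_overlap` characters after the previous one, so the last three characters of one chunk reappear at the start of the next.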

Next, we'll create a chain that uses an LLM to summarize our tables and text chunks.

In [3]:
import os
from langchain_nvidia_ai_endpoints import ChatNVIDIA

os.environ["NVIDIA_API_KEY"] = ""  # set your NVIDIA API key here

llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")

In [4]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element}"""
prompt = ChatPromptTemplate.from_template(prompt_text)

summarize_chain = {"element": lambda x: x} | prompt | llm | StrOutputParser()

Then we'll apply that chain to each of our tables and text chunks.

In [5]:
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
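# Pass each chunk's text rather than the Document object, so metadata doesn't leak into the prompt
text_summaries = summarize_chain.batch([t.page_content for t in texts], {"max_concurrency": 5})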
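
The `max_concurrency` setting caps how many inputs are processed at once while the results still come back in input order. The sketch below illustrates that idea with a standard-library thread pool; it is not LangChain's implementation, and `fake_summarize` and `batch_with_limit` are hypothetical stand-ins for the LLM call and for `batch`.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_summarize(element):
    """Hypothetical stand-in for the LLM call: keep only the first three words."""
    return " ".join(element.split()[:3])


def batch_with_limit(func, inputs, max_concurrency):
    """Apply func to every input using at most max_concurrency worker threads."""
    with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        # Executor.map yields results in the same order as the inputs.
        return list(pool.map(func, inputs))


elements = [
    "The dog is chasing a squirrel in the front yard",
    "Revenue grew in every region this quarter",
    "short text",
]
summaries = batch_with_limit(fake_summarize, elements, max_concurrency=2)
print(summaries)
```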

Next, we'll create a multi-vector retriever. It lets us embed the summaries we generated in a vector store while keeping the raw text and tables in a separate document store, so a similarity search over the summaries returns the original content.

In [6]:
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

vectorstore = Chroma(collection_name="summaries", embedding_function=NVIDIAEmbeddings())

store = InMemoryStore()
id_key = "doc_id"

retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

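Now we can add the raw text and tables, along with their summaries, to the retriever. The summaries go into the vector store and the raw content goes into the document store, linked by a shared `doc_id`.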

In [7]:
import uuid

doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [
    Document(page_content=s, metadata={id_key: doc_ids[i]})
    for i, s in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=s, metadata={id_key: table_ids[i]})
    for i, s in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))
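
The pattern in the cell above is a two-step lookup: search over summaries, then follow the shared ID back to the raw content. Here is a toy sketch of that idea in plain Python. It assumes a word-overlap score in place of embedding similarity and plain dictionaries in place of Chroma and `InMemoryStore`; `word_overlap` and `retrieve` are hypothetical names used only for illustration.

```python
import uuid


def word_overlap(query, text):
    """Toy relevance score: count of shared lowercase words (stand-in for embedding similarity)."""
    return len(set(query.lower().split()) & set(text.lower().split()))


raw_docs = [
    "| animal | activity | place |\n| dog | chasing a squirrel | front yard |",
    "Quarterly revenue grew in every region.",
]
summaries = [
    "table of animals with their activity and place",
    "text about revenue growth by region",
]

ids = [str(uuid.uuid4()) for _ in raw_docs]
# "Vector store": summaries tagged with the ID of the raw content they describe.
summary_index = [{"summary": s, "doc_id": i} for s, i in zip(summaries, ids)]
# "Docstore": raw content keyed by the same ID.
docstore = dict(zip(ids, raw_docs))


def retrieve(query, k=1):
    """Rank summaries against the query, then return the raw content behind the top k."""
    ranked = sorted(
        summary_index,
        key=lambda entry: word_overlap(query, entry["summary"]),
        reverse=True,
    )
    return [docstore[entry["doc_id"]] for entry in ranked[:k]]


hits = retrieve("which animals and what activity")
print(hits[0])
```

The query matches the table's summary, yet the caller receives the full table, which is what lets the downstream LLM answer from the unabridged data.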

Finally, we'll create a RAG chain that we can use to query our PDF in natural language.

In [8]:
from langchain_core.runnables import RunnablePassthrough

template = """You are an assistant for question-answering tasks. 
Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

model = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)
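
In the chain above, the retriever's output is placed into `{context}` as-is. Because the docstore holds both `Document` chunks and plain table strings, the prompt receives the string form of a mixed list. One optional refinement, sketched here under assumptions, is to flatten the retrieved items into plain text first. `format_context` is a hypothetical helper and `FakeDocument` is a stand-in for `langchain_core.documents.Document` so the example runs without LangChain installed.

```python
from dataclasses import dataclass, field


@dataclass
class FakeDocument:
    """Hypothetical stand-in for langchain_core.documents.Document."""
    page_content: str
    metadata: dict = field(default_factory=dict)


def format_context(items):
    """Join retrieved items into one string, accepting Document-like objects or plain strings."""
    parts = [getattr(item, "page_content", item) for item in items]
    return "\n\n".join(str(part) for part in parts)


retrieved = [
    FakeDocument("The dog is in the front yard.", {"start_index": 0}),
    "| animal | activity |",
]
context = format_context(retrieved)
print(context)
```

With LangChain's runnable composition, a function like this could be piped after the retriever, for example `{"context": retriever | format_context, "question": RunnablePassthrough()}`, so the model sees clean text instead of object representations.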

In [9]:
chain.invoke("What is the dog doing and where?")

'The dog is chasing a squirrel in the front yard.'