In [None]:
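# Install these dependencies before running this notebook.
# !pip install python-dotenv==1.0.0
# !pip install requests
# !pip install sseclient-py==1.8.0
# !pip install pdf2image==1.17.0
# !pip install langchain-sambanova
# !pip install openai
# The cells below also import these packages, which are not listed above:
# !pip install "unstructured[pdf]" langchain langchain-chroma pillow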

In [None]:
# References

# https://github.com/Unstructured-IO/unstructured  
# https://docs.unstructured.io/open-source/core-functionality/chunking
# https://github.com/Unstructured-IO/unstructured?tab=readme-ov-file

In [None]:
# Extract Text, Tables and Images from PDF document
# Once the PDF has been downloaded, we use the unstructured library to process the document and extract its contents.

# _Extraction of Text and Image summaries_

In [None]:
from unstructured.partition.pdf import partition_pdf

In [None]:
# Partition the PDF into a list of unstructured elements (kept in `raw_pdf_elements`).
raw_pdf_elements=partition_pdf(
    filename="../ML_PRACTICE/data/NIPS-2017-attention-is-all-you-need-Paper.pdf",     # path of the PDF to process
    strategy="hi_res",
    extract_images_in_pdf=True,
    extract_image_block_types=["Image", "Table"],   # element types requested as image blocks
    extract_image_block_to_payload=False,
    # NOTE(review): this directory is relative to the notebook's working directory, while a later cell
    # reads the images from "../ML_PRACTICE/image_store" -- confirm both resolve to the same folder.
    extract_image_block_output_dir="image_store",
  )

In [None]:
raw_pdf_elements

### _Identify common elements_

In [None]:
# seen_classes = set()
# unique_elements = []

# for element in raw_pdf_elements:
#     # print(type(element))  # `elements` is your list of parsed elements
#     if type(element) not in seen_classes:
#         unique_elements.append(element)
#         seen_classes.add(type(element))


# seen_classes
# unique_elements

### _store the img,text and table data_

In [None]:
# One bucket per element class of interest; each collects the text of the matching elements.
Header = []
Footer = []
Title = []
NarrativeText = []
Text = []
ListItem = []

# Markers are tested in this order against the repr of the element's type;
# the first marker found decides the bucket, and unmatched elements are ignored.
_buckets_by_marker = (
    ("unstructured.documents.elements.Header", Header),
    ("unstructured.documents.elements.Footer", Footer),
    ("unstructured.documents.elements.Title", Title),
    ("unstructured.documents.elements.NarrativeText", NarrativeText),
    ("unstructured.documents.elements.Text", Text),
    ("unstructured.documents.elements.ListItem", ListItem),
)

for element in raw_pdf_elements:
    type_repr = str(type(element))
    for marker, bucket in _buckets_by_marker:
        if marker in type_repr:
            bucket.append(str(element))
            break




In [None]:
# Text of every element whose type repr contains the Image class path.
img = [
    str(element)
    for element in raw_pdf_elements
    if "unstructured.documents.elements.Image" in str(type(element))
]


img

In [None]:
# Print each image element's text next to its index.
for i, image_text in enumerate(img):
    print(i, image_text)

In [None]:
# Text of every element whose type repr contains the Table class path.
table = [
    str(element)
    for element in raw_pdf_elements
    if "unstructured.documents.elements.Table" in str(type(element))
]

# Not the last expression of the cell, so this value is not displayed.
len(table)
print(table)

In [None]:
# Print each table's text next to its index.
for i, table_text in enumerate(table):
    print(i, table_text)

In [None]:
# Rebuild the NarrativeText list from the parsed elements (replaces any earlier value).
NarrativeText = [
    str(element)
    for element in raw_pdf_elements
    if "unstructured.documents.elements.NarrativeText" in str(type(element))
]
# NarrativeText

# _Summary of image,text,and table_

In [None]:
import sys

# sys.path entries must be directories (or archives), not individual .py files,
# so add the folder that contains model_wrapper.py rather than the file itself.
_MODEL_WRAPPER_DIR = '../MULTIMODAL_RAG'
if _MODEL_WRAPPER_DIR not in sys.path:  # keep re-runs of this cell from stacking duplicates
    sys.path.append(_MODEL_WRAPPER_DIR)
import model_wrapper

In [None]:
from model_wrapper import SambaNovaCloud
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

In [None]:
import os
from getpass import getpass

# Never hard-code the key. Reuse SAMBANOVA_API_KEY when it is already exported,
# otherwise prompt for it without echoing. Overwriting the variable with a placeholder
# string would hide a real key and also stop a later load_dotenv() from supplying one.
sambanova_api_key = os.environ.get("SAMBANOVA_API_KEY") or getpass("Enter SAMBANOVA_API_KEY: ")
os.environ["SAMBANOVA_API_KEY"] = sambanova_api_key

In [None]:
from langchain_sambanova import ChatSambaNovaCloud

# Chat model used by the text and table summarization chains defined below.
llm = ChatSambaNovaCloud(
    model="Meta-Llama-3.3-70B-Instruct",
    max_tokens=2024,  # NOTE(review): possibly meant to be 2048 -- confirm
    temperature=0.7,
    top_p=0.01,  # NOTE(review): such a small top_p presumably leaves little room for temperature to matter -- confirm intent
)

In [None]:
# imports ChatPromptTemplate, a LangChain class that lets you define and format prompts for chat-based LLMs


#####Chaining
from langchain_core.prompts import ChatPromptTemplate


# Build the LangChain chat prompt templates directly from their text.
# `{element}` is the placeholder that receives the content to summarize.

prompt_text = ChatPromptTemplate.from_template(
    """You are a helpful assistant tasked with summarizing text.Give a concise summary of the NarrativeText.NarrativeText {element}"""
)

prompt_table = ChatPromptTemplate.from_template(
    """You are a helpful assistant tasked with summarizing tables .Give a concise summary of the table.Table {element}"""
)


In [None]:

# Each chain below is a pipeline in which every stage transforms its input and passes it to the next:
#  1. {"element": lambda x: x}: wraps the input under the key "element", which matches the {element} placeholder in the prompt templates.
#  2. | prompt_text or | prompt_table: fills the prompt template from that dictionary.
#  3. | llm: sends the prompt to the LLM and gets a response.
#  4. | StrOutputParser(): converts the LLM message object into a plain string.


text_summarize_chain = {"element": lambda x: x} | prompt_text | llm | StrOutputParser()
table_summarize_chain= {"element": lambda x: x} | prompt_table | llm | StrOutputParser()

In [None]:
# Summarize every narrative text with the chain's .batch() method.
# max_concurrency=1 means the inputs are processed one at a time (sequentially).
NarrativeText_summaries = []
if NarrativeText:
    _sequential_config = {'max_concurrency': 1}
    NarrativeText_summaries = text_summarize_chain.batch(NarrativeText, _sequential_config)

In [None]:
# Summarize every table, allowing up to three requests to run at once.
table_summaries = []
_table_batch_config = {"max_concurrency": 3}
table_summaries = table_summarize_chain.batch(table, _table_batch_config)

# _Image summary_

In [None]:
import openai
import base64
import os 
from langchain_core.messages import HumanMessage
from IPython.display import HTML, display

In [None]:
# OpenAI-compatible client pointed at the SambaNova endpoint.
# The key is read from the environment: the literal string "SAMBANOVA_API_KEY"
# is only the variable's name, not a usable credential.
client = openai.OpenAI(
    base_url="https://api.sambanova.ai/v1",
    api_key=os.environ["SAMBANOVA_API_KEY"],
)

In [None]:
def encode_image(image_path):
    """Read the file at `image_path` in binary mode and return its content as a base64 text string."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

In [None]:
def image_summarizer(prompt,image_base64):
    """
    Ask the vision model to respond to `prompt` about one image.

    prompt: instruction text sent alongside the image
    image_base64: base64 string of the image, embedded as a JPEG data URL
    Returns the text content of the first choice in the response.
    """
    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
    }
    user_message = {"role": "user", "content": [text_part, image_part]}

    response = client.chat.completions.create(
        model="Llama-4-Maverick-17B-128E-Instruct",
        messages=[user_message],
    )
    first_choice = response.choices[0]
    return first_choice.message.content

In [None]:

def generate_img_summaries(path):
    """
    Generate summaries and base64 encoded strings for images.

    path: directory containing the .jpg files extracted by Unstructured
    Returns a pair (img_base64_list, image_summaries); both lists follow the
    sorted file-name order, so entry i of one belongs to entry i of the other.
    """

    # Store base64 encoded images
    img_base64_list = []

    # Store image summaries
    image_summaries = []

    # Prompt
    prompt = """You are an assistant tasked with summarizing images for retrieval. \
    These summaries will be embedded and used to retrieve the raw image. \
    Give a concise summary of the image that is well optimized for retrieval."""

    # Apply to images
    for img_file in sorted(os.listdir(path)):
        if not img_file.endswith(".jpg"):
            continue
        img_path = os.path.join(path, img_file)
        base64_image = encode_image(img_path)
        # Query the model exactly once per image: a second call doubles the cost
        # and may store a summary that differs from the one printed.
        generated_summary = image_summarizer(prompt,base64_image)
        print(generated_summary)
        img_base64_list.append(base64_image)
        image_summaries.append(generated_summary)

    return img_base64_list, image_summaries


# Image summaries
# NOTE(review): the extraction cell writes image blocks to "image_store" relative to the working
# directory -- confirm that "../ML_PRACTICE/image_store" is the same folder.
img_base64_list, image_summaries = generate_img_summaries("../ML_PRACTICE/image_store")

In [None]:
image_summaries


In [None]:
img_base64_list

# _Adding to the Vector Store_


In [None]:
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_core.documents import Document
# from langchain_community.embeddings import SambaStudioEmbeddings
from langchain_sambanova import SambaNovaCloudEmbeddings
from dotenv import load_dotenv
load_dotenv()
import os


In [None]:
from langchain_sambanova import SambaNovaCloudEmbeddings

# Embedding model used to index the summaries.
# The key comes from the environment: the literal string "SAMBANOVA_API_KEY"
# is only the variable's name, not a usable credential.
embeddings = SambaNovaCloudEmbeddings(
    model="E5-Mistral-7B-Instruct",
    sambanova_api_key=os.environ["SAMBANOVA_API_KEY"],
)

In [None]:
# This function creates a MultiVectorRetriever that indexes embeddings of summaries (text, table, and image) in a vector store, while storing the original content in an in-memory docstore. When queried, the retriever uses vector similarity on the summaries to retrieve the most relevant results and maps them back to the original full documents. This enables fast, semantically accurate retrieval while preserving access to detailed source data.



def create_multi_vector_retriever(vectorstore,text_summaries, texts, table_summaries, table, image_summaries,img
):
    """
    Create a retriever that indexes the summary but returns the original image or text.

    vectorstore: vector store that receives the summary documents
    text_summaries / texts: summaries and the original texts they describe
    table_summaries / table: summaries and the original tables they describe
    image_summaries / img: summaries and the original images they describe
    Each summary list must be the same length as its content list, with entry i
    of one describing entry i of the other. Empty summary lists are skipped.

    Returns the populated MultiVectorRetriever.
    Raises ValueError when a summary list and its content list differ in length.
    """

    # Initialize the storage tier
    store = InMemoryStore()
    id_key = "doc_id"

    # The retriever (empty to start)
    retriever = MultiVectorRetriever(vectorstore=vectorstore,
                                     docstore=store,
                                     id_key=id_key)

    def add_documents(retriever, doc_summaries, doc_contents):
        """Index the summaries in the vector store and keep the originals under the same ids."""
        # Pairing is positional, so unequal lengths would either fail midway or
        # leave originals in the docstore that no summary points to.
        if len(doc_summaries) != len(doc_contents):
            raise ValueError(
                f"Got {len(doc_summaries)} summaries for {len(doc_contents)} documents; "
                "each document needs exactly one summary."
            )
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for doc_id, summary in zip(doc_ids, doc_summaries)
        ]

        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Add texts, tables, and images, skipping any kind that has no summaries
    if text_summaries:
        add_documents(retriever, text_summaries, texts)

    if table_summaries:
        add_documents(retriever, table_summaries, table)

    if image_summaries:
        add_documents(retriever, image_summaries, img)

    return retriever






In [None]:
# The vectorstore used to index the summaries
# NOTE(review): persist_directory keeps this collection on disk, whereas the docstore built inside
# create_multi_vector_retriever is an InMemoryStore. Re-running the notebook presumably adds the
# summaries again alongside older entries whose doc_ids no longer exist in the docstore -- confirm,
# and clear the collection between runs if so.
vectorstore = Chroma(
    collection_name='summaries',embedding_function=embeddings,persist_directory="sample-rag-multi-modal" )
print(vectorstore)

# Signature, for reference:
# def create_multi_vector_retriever(vectorstore,text_summaries, texts, table_summaries, table, image_summaries, img
# ):
# Create the retriever; the base64 image strings are stored as the "original" image content
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstore,
    NarrativeText_summaries,
    NarrativeText,
    table_summaries,
    table,
    image_summaries,
    img_base64_list,
)

retriever_multi_vector_img

# _RAG_Build_retriever_

In [None]:
import io
import re

from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image

In [None]:
# Function	               Purpose in Pipeline
# split_image_text_types:	   Separate raw data into images vs text
# resize_base64_image	:       Optimize image for LLM input
# looks_like_base64	:       Ensure clean, valid input
# img_prompt_func	   :        Construct a prompt that blends both formats
# plt_img_base64	  :          Optional — visualize/debug images

In [None]:
def display_base64_images(image_list):
    """Render every base64 image in `image_list` as a numbered heading followed by the image."""
    for position, encoded in enumerate(image_list, start=1):
        display(HTML(f"<h4>🖼️ Image {position}</h4><img src='data:image/jpeg;base64,{encoded}' width='500'/>"))


In [None]:
def plt_img_base64(img_base64):
    """Display a base64 encoded string as an image."""
    # Embed the string as a data URL in an <img> tag and render that HTML.
    display(HTML(f'<img src="data:image/jpeg;base64,{img_base64}" />'))

In [None]:
def looks_like_base64(sb):
    """Return True when `sb` consists of base64 alphabet characters with at most two trailing '='."""
    base64_shape = "^[A-Za-z0-9+/]+[=]{0,2}$"
    if re.match(base64_shape, sb):
        return True
    return False

In [None]:
def is_image_data(b64data):
    """
    Check if the base64 data is an image by looking at the start of the decoded data.

    b64data: base64 encoded string (or bytes)
    Returns True for JPEG, PNG, GIF and WebP signatures, False otherwise,
    including when the input cannot be decoded.
    """
    # Formats recognisable from their leading bytes alone: jpg, png, gif.
    prefix_signatures = (
        b"\xff\xd8\xff",
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",
        b"\x47\x49\x46\x38",
    )
    try:
        # 12 bytes are enough for every check below.
        header = base64.b64decode(b64data)[:12]
    except Exception:
        return False

    if header.startswith(prefix_signatures):
        return True
    # "RIFF" alone is a generic container also used by WAV and AVI;
    # a WebP file additionally carries "WEBP" at bytes 8-11.
    return header[:4] == b"\x52\x49\x46\x46" and header[8:12] == b"WEBP"

In [None]:
def resize_base64_image(base64_string, size=(128, 128)):
    """
    Resize an image encoded as a Base64 string.

    base64_string: base64 text of the source image
    size: target (width, height) passed to the resize call
    Returns the resized image, saved in the source image's format, as base64 text.
    """
    source = Image.open(io.BytesIO(base64.b64decode(base64_string)))

    # Resize with the LANCZOS filter and write into memory, keeping the source format.
    out_buffer = io.BytesIO()
    source.resize(size, Image.LANCZOS).save(out_buffer, format=source.format)

    return base64.b64encode(out_buffer.getvalue()).decode("utf-8")

In [None]:

def split_image_text_types(docs):
    """
    Split base64-encoded images and texts.

    docs: iterable of strings or Document objects (a Document contributes its page_content)
    Returns {"images": [...], "texts": [...]}; images are resized to 400x300 before being collected.
    """
    result = {"images": [], "texts": []}
    for doc in docs:
        content = doc.page_content if isinstance(doc, Document) else doc
        if looks_like_base64(content) and is_image_data(content):
            result["images"].append(resize_base64_image(content, size=(400, 300)))
        else:
            result["texts"].append(content)
    return result

In [None]:
def img_prompt_func(data_dict):
    """
    Build the multimodal prompt for the model.

    data_dict: mapping with "context" ({"texts": [...], "images": [...]}) and "question"
    Returns a one-element list holding a HumanMessage whose content lists every image
    (as a JPEG data URL) followed by one text part with the instructions, the question
    and the newline-joined context texts.
    """
    context = data_dict["context"]
    formatted_texts = "\n".join(context["texts"])

    # Image parts come first; nothing is added when there are no images.
    content_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        }
        for image in (context["images"] or [])
    ]

    # The text part closes the message.
    instructions = (
        "You are analyst and advice.\n"
        "You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
        "Use this information to provide answer related to the user question. \n"
        f"User-provided question: {data_dict['question']}\n\n"
        "Text and / or tables:\n"
        f"{formatted_texts}"
    )
    content_parts.append({"type": "text", "text": instructions})
    return [HumanMessage(content=content_parts)]

In [None]:
from langchain_sambanova import ChatSambaNovaCloud

# Multimodal chat model instance.
# NOTE(review): nothing visible in this notebook references `llm_1`; multi_modal_rag_chain
# constructs its own model with the same settings -- confirm whether this cell is still needed.
llm_1= ChatSambaNovaCloud(
    model="Llama-4-Maverick-17B-128E-Instruct",
    # max_tokens=2024,
    temperature=0.7,
    top_p=0.01,
)

##### `multi_modal_rag_chain(...)` builds the entire chain:

#### _Retrieves relevant text and images_

#### _Splits them using split_image_text_types_

#### _Formats a multimodal prompt with img_prompt_func_

#### _Passes it to the LLM (Llama-4-Maverick-17B-128E-Instruct)_

#### _Parses the result into a plain string_

In [None]:
def multi_modal_rag_chain(retriever):
    """
    Multi-modal RAG chain.

    retriever: runnable that returns the documents for a query
    Returns a chain that retrieves context, splits it into images and texts,
    formats the multimodal prompt, calls the model and parses the reply to a string.
    """
    # Multi-modal LLM
    model = ChatSambaNovaCloud(model="Llama-4-Maverick-17B-128E-Instruct",temperature=0.7,top_p=0.01)

    # The query feeds both branches: retrieval for the context, pass-through for the question.
    gather_inputs = {
        "context": retriever | RunnableLambda(split_image_text_types),
        "question": RunnablePassthrough(),
    }
    build_prompt = RunnableLambda(img_prompt_func)

    return gather_inputs | build_prompt | model | StrOutputParser()

In [None]:
def query_multimodal_rag(query_text):
    """
    Answer `query_text` with the multimodal RAG chain, then show the retrieved images.

    Uses the module-level `chain_multimodal_rag` and `retriever_multi_vector_img`,
    which must exist before this function is called. Prints the answer; returns nothing.
    """
    # Ask the chain and print the LLM's answer
    answer = chain_multimodal_rag.invoke(query_text)
    print("Answer:\n", answer)

    # Query the retriever again, separately from the chain, to get the documents themselves.
    # NOTE(review): confirm the retriever honours `limit` rather than ignoring it.
    retrieved = retriever_multi_vector_img.invoke(query_text, limit= 6)

    # Display every retrieved entry detected as a base64 image
    retrieved_images = split_image_text_types(retrieved)["images"]
    for encoded_image in retrieved_images:
        plt_img_base64(encoded_image)

In [None]:
# query_multimodal_rag reads the module-level chain, so build it here; otherwise a
# fresh top-to-bottom run fails with NameError because a later cell is the first to define it.
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)

# User can enter the query
query_text = input("Enter your query: ")
query_multimodal_rag(query_text)

In [None]:
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)
query_text = "What is transformer?"

# Run the multimodal RAG chain
response = chain_multimodal_rag.invoke(query_text)

# Print response (LLM's answer)
print("Answer:\n", response)

# Query the retriever again, separately from the chain, to get the documents themselves
docs = retriever_multi_vector_img.invoke(query_text, limit=6)

# Display all images retrieved and detected as base64 images.
# The loop variable is deliberately not called `img`: at module level that would
# overwrite the `img` list of image elements built earlier in the notebook.
for retrieved_image in split_image_text_types(docs)["images"]:
    plt_img_base64(retrieved_image)


In [None]:
# --- Inspect the documents retrieved for a text query (base64 images are rendered inline) ---



import base64
from io import BytesIO
from IPython.display import display, Image as IPyImage
import re

# Rebuild the chain and fetch the retriever's results for the query so they can be inspected below.
chain_multimodal_rag = multi_modal_rag_chain(retriever_multi_vector_img)
query_text = "What is mult-head-attention?"

# NOTE(review): confirm the retriever honours `limit` rather than ignoring it.
docs = retriever_multi_vector_img.invoke(query_text, limit=6)

def is_base64_image(text):
    """
    Return True when `text` is a string whose stripped content starts with the
    base64 prefix of a JPEG ("/9j") or PNG ("iVB") file, otherwise False.

    Always returns a bool; non-string input yields False.
    """
    if not isinstance(text, str):
        return False
    return text.strip().startswith(("/9j", "iVB"))  # JPEG or PNG signatures

# Show each retrieved entry next to its index: images inline, everything else as text.
for i, doc in enumerate(docs):
    print(f"{i}:", end=" ")

    if not is_base64_image(doc):
        print(doc)
        continue

    try:
        # Display inline image
        image_data = base64.b64decode(doc)
        display(IPyImage(data=image_data))
    except Exception as e:
        print(f"[Error displaying image]: {e}")
print(docs)

In [None]:
# --- Query by TEXT ---

# Ask the chain the question and print its answer under a heading line.
query_text = "What is mult-head-attention?"
response_text = chain_multimodal_rag.invoke(query_text)
print("Response to text query:", response_text, sep="\n")