In [1]:
import os
from dotenv import load_dotenv
from mistralai import Mistral
# Load variables from a local .env file (if one exists) into the process environment.
load_dotenv()
# Subscript access (rather than os.getenv) raises KeyError right here if the key is missing.
mistral_api_key = os.environ["MISTRAL_API_KEY"]

In [2]:
# Model id sent as `model=` to the Mistral chat API when images are described.
model = "pixtral-12b-2409"

# Initialize the Mistral client with the API key read from the environment.
# Despite its name, `mistral_model` is the API client object, not a model.
mistral_model = Mistral(api_key=mistral_api_key)

Multimodal RAG APP

In [3]:
from typing import Any
from unstructured.partition.pdf import partition_pdf
import pytesseract

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Location of the Tesseract binary used for OCR. A TESSERACT_CMD environment variable
# (which may come from .env) takes precedence, so the notebook is not tied to one machine;
# the previous hard-coded Windows install path remains the fallback.
pytesseract.pytesseract.tesseract_cmd = os.environ.get(
    "TESSERACT_CMD",
    r"C:/Program Files/Tesseract-OCR/tesseract.exe",
)

# The PDF is read from the current working directory; extracted images go to ./figures.
input_path = os.getcwd()
output_path = os.path.join(input_path, "figures")
# Create the folder up front so listing it later cannot fail when no image was extracted.
os.makedirs(output_path, exist_ok=True)

# Partition the report into title-based chunks, also pulling out tables and images.
raw_pdf_elements = partition_pdf(
    filename=os.path.join(input_path, "startupai-financial-report-v2.pdf"),
    extract_images_in_pdf=True,
    infer_table_structure=True,
    chunking_strategy="by_title",
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=output_path,
)

In [5]:
# Display the partitioned elements to check what the chunking produced.
raw_pdf_elements

[<unstructured.documents.elements.CompositeElement at 0x1b1b9d71c30>]

Extracting the relevant info

Store the text, tables and images in three separate lists.
We need to send images to the model as base64-encoded strings.


In [6]:
import base64
# Accumulators for the three kinds of content taken from the PDF.
# text_elements and table_elements are filled later in this cell.
text_elements = []
table_elements = []
# Holds base64 strings of the extracted image files; filled in a later cell.
image_elements = []

def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 string.

    The file is opened in binary mode; the base64 output is decoded as UTF-8 so the
    result is a plain ``str`` that can be embedded in a JSON payload or data URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
    
# Split the partitioned elements by kind and keep only their text.
# The kind is detected from the class name in the type's repr: anything whose type
# mentions "CompositeElement" is text; otherwise anything mentioning "Table" is a table.
text_elements = [
    el.text
    for el in raw_pdf_elements
    if "CompositeElement" in str(type(el))
]
table_elements = [
    el.text
    for el in raw_pdf_elements
    if "CompositeElement" not in str(type(el)) and "Table" in str(type(el))
]

print(f"Number of table elements in the pdf file: {len(table_elements)}")
print(f"Number of text elements in the pdf file: {len(text_elements)}")

Number of table elements in the pdf file: 0
Number of text elements in the pdf file: 1


The extracted images are currently stored in the `figures` folder.
We will encode them using base64.

In [7]:
# Encode every extracted image file as a base64 string.
# - The list is rebuilt from scratch so re-running this cell cannot append duplicates.
# - File names are sorted because os.listdir() returns them in arbitrary order.
# - The extension check is case-insensitive and matches ".jpeg" (previously misspelled
#   ".jped", which silently skipped such files).
image_elements = [
    encode_image(os.path.join(output_path, image_file))
    for image_file in sorted(os.listdir(output_path))
    if image_file.lower().endswith((".png", ".jpg", ".jpeg"))
]
print(f"Number of image elements in the pdf file: {len(image_elements)}")


Number of image elements in the pdf file: 6


Create 3 functions to summarize images, tables and texts

In [8]:
from langchain_groq import ChatGroq
# Re-read .env so this cell also works when run on its own.
load_dotenv()
# os.getenv returns None (no exception) when the GROQ variable is not set.
groq_key = os.getenv("GROQ")
# Groq-hosted chat model used for the text and table summaries.
groq_model = ChatGroq(model="llama3-70b-8192", groq_api_key=groq_key)

In [19]:
from langchain.schema.messages import HumanMessage, AIMessage
def summarize_text(text_element):
    """Summarize one text chunk with the Groq chat model.

    Args:
        text_element: The text to summarize; it is interpolated into the prompt.

    Returns:
        The ``content`` of the model's reply.
    """
    request = HumanMessage(
        content=f"Summarize the following text:\n\n{text_element}\n\nSummary:"
    )
    reply = groq_model.invoke([request])
    return reply.content

def summarize_table(table_element):
    """Summarize one table (given as text) with the Groq chat model.

    Args:
        table_element: The table's text; it is interpolated into the prompt.

    Returns:
        The ``content`` of the model's reply.
    """
    request = HumanMessage(
        content=f"Summarize the following table:\n\n{table_element}\n\nSummary:"
    )
    reply = groq_model.invoke([request])
    return reply.content

def summarize_image(encoded_image, max_attempts=3, retry_delay=6):
    """Ask the Mistral chat model to describe a base64-encoded image.

    The request is retried when the API call raises. Previously a failed call slept
    once and then crashed with an ``UnboundLocalError``, because the response variable
    was never assigned and the real error was discarded.

    Args:
        encoded_image: Base64 string of the image bytes.
        max_attempts: Total number of API calls to try (values below 1 count as 1).
        retry_delay: Seconds to wait between attempts. The default keeps the original
            6-second pause, presumably chosen to ride out rate limiting.

    Returns:
        The text content of the first choice in the chat response.

    Raises:
        RuntimeError: If every attempt failed; the last API error is chained as the cause.
    """
    import time

    # NOTE(review): the data URL always declares image/jpeg, although the notebook also
    # collects .png files. Confirm the API tolerates the mismatch.
    prompt = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the contents of this image"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encoded_image}"
                    }
                }
            ]
        }
    ]

    attempts = max(1, max_attempts)
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            chat_response = mistral_model.chat.complete(
                model=model,
                messages=prompt
            )
        except Exception as e:
            last_error = e
            # Only pause when another attempt will follow.
            if attempt < attempts:
                time.sleep(retry_delay)
            continue
        # Parsing stays outside the try block so a malformed response is not retried.
        return chat_response.choices[0].message.content

    raise RuntimeError(
        f"Image summarization failed after {attempts} attempt(s)"
    ) from last_error


In [20]:
# Summarize at most the first two text chunks.
text_summaries = []
for i, text_chunk in enumerate(text_elements[:2]):
    text_summaries.append(summarize_text(text_chunk))
    print(f"{i+1}th element of texts processed")

# Summarize at most the first two tables.
tabe_summaries = []
for i, table_chunk in enumerate(table_elements[:2]):
    tabe_summaries.append(summarize_table(table_chunk))
    print(f"{i+1}th element of table processed")

# Describe at most the first ten images.
image_summaries = []
for i, image_b64 in enumerate(image_elements[:10]):
    image_summaries.append(summarize_image(image_b64))
    print(f"{i+1}th element of image processed")


1th element of texts processed
1th element of image processed
2th element of image processed
3th element of image processed
4th element of image processed
5th element of image processed
6th element of image processed


Now we will proceed with the RAG technique

In [37]:
import uuid
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

# Vector index that holds the embedded summaries.
vectorstorev2 = Chroma(
    collection_name="summaries",
    embedding_function=HuggingFaceEmbeddings(),
)
# In-memory key/value store for the original contents behind each summary.
storev2 = InMemoryStore()
# Metadata key under which each summary document carries the id of its original.
id_key = "doc_id"

# Retriever configured with the summary index, the docstore, and k=3 search results.
retrieverv2 = MultiVectorRetriever(
    vectorstore=vectorstorev2,
    docstore=storev2,
    id_key=id_key,
    search_kwargs={"k": 3},
)

def add_documents_to_retriever(summaries, original_contents):
    """Index summaries in the vector store and file their originals in the docstore.

    Each (summary, original) pair gets a fresh UUID. The summary is embedded as a
    ``Document`` whose metadata stores that id under ``id_key``, and the original is
    stored in the docstore under the same id.

    Only positions present in both inputs are indexed, so a summary never ends up in
    the vector store without a matching docstore entry. When there is nothing to pair
    (for example, a PDF with no tables) the function returns without touching the
    stores, which avoids handing an empty batch to the vector store.

    Args:
        summaries: Summary strings, aligned by position with ``original_contents``.
        original_contents: The original items the summaries were produced from.

    Returns:
        None.
    """
    pairs = list(zip(summaries, original_contents))
    if not pairs:
        return

    doc_ids = [str(uuid.uuid4()) for _ in pairs]
    summary_docs = [
        Document(page_content=summary, metadata={id_key: doc_id})
        for doc_id, (summary, _) in zip(doc_ids, pairs)
    ]
    retrieverv2.vectorstore.add_documents(summary_docs)
    retrieverv2.docstore.mset(
        [(doc_id, original) for doc_id, (_, original) in zip(doc_ids, pairs)]
    )

  embedding_function=HuggingFaceEmbeddings())


In [38]:
# Index the text chunks under their summaries.
# Each call generates fresh UUIDs, so re-running this cell adds the same content again.
add_documents_to_retriever(text_summaries, text_elements)

# Table indexing is currently disabled; the sample PDF yielded no table elements.
# add_documents_to_retriever(
#     tabe_summaries, table_elements
# )

# Index the base64-encoded images under their summaries.
add_documents_to_retriever(image_summaries, image_elements)

Now we will retrieve

In [54]:
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Prompt that restricts the answer to the retrieved context.
template = """
Answer the question based only on the following context, 
which can include text, images and tables: {context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Chat model that answers the question (same settings as the summarization model).
groq_model = ChatGroq(model="llama3-70b-8192", groq_api_key=groq_key)

def retrieve_context(query):
    """Return the docstore entries that the retriever selects for ``query``.

    The returned items are whatever was stored as ``original_contents`` when the
    documents were added: raw text chunks and, for images, base64 strings.
    NOTE(review): base64 image strings may be very large for a text prompt; confirm
    this is the intended context for the answering model.
    """
    # Retrievers are Runnables; `invoke` replaces the deprecated `get_relevant_documents`.
    return retrieverv2.invoke(query)

# RAG pipeline: the incoming question is fanned out to both prompt variables.
# "context" runs it through the retriever and "question" forwards it unchanged.
# The filled prompt then goes to the chat model and the reply is reduced to a string.
chain = (
    {"context": RunnablePassthrough() | retrieve_context, "question": RunnablePassthrough()}
    | prompt
    | groq_model
    | StrOutputParser()
)



In [56]:
# Smoke test: ask a question whose answer should be in the report.
chain.invoke("What is the name of the company?")

'The name of the company is StartupAI.'