## 多模态-rag

Advanced RAG - 02. Multi-Modal RAG
The previous tutorial 01. RAG on Semi-structured data introduced RAG development on semi-structured data, for example texts and tables in PDF documents.

BUT, it still can't read images.

Let's learn how to enable image recognition in RAG by employing multi-modal models.

What is a multi-modal model?

A multi-modal model can process and analyze data from multiple modalities (for example, text and images) and so provide a more complete and accurate understanding of the underlying data.

GPT-4V

GPT-4V is a multi-modal model that takes in both text and images and responds with text output. Please refer to GPT-4 Vision for an introduction and the API guide.

In this tutorial, let's use the GPT-4V model to implement a multi-modal RAG application that can understand the images embedded in a PDF document and answer relevant questions.

The PDF document we use in this example is the JP Morgan - Weekly Market Recap. It's a small PDF file containing several tables, which makes it a good example for quick data processing and a clear demonstration.

Prepare Environment
Let's install the necessary Python packages.

In [None]:
# Quote the extras spec so the shell passes "[all-docs]" to pip literally instead of treating it as a glob pattern.
!pip install langchain "unstructured[all-docs]" pydantic lxml openai chromadb tiktoken -q -U


In [ ]:
# !apt-get install poppler-utils tesseract-ocr


In [ ]:

from typing import Any

from pydantic import BaseModel
from unstructured.partition.pdf import partition_pdf

# Directory handed to the partitioner as the output location for extracted images.
images_path = "./images"

# Partitioning options, grouped in one mapping so each setting can carry a note.
pdf_partition_options = {
    "filename": "weekly_market_recap.pdf",
    # Pull embedded images out of the PDF and write them under images_path.
    "extract_images_in_pdf": True,
    "image_output_dir_path": images_path,
    # Ask the partitioner to recover table structure.
    "infer_table_structure": True,
    # Chunking: group content by title, with size limits expressed in characters.
    "chunking_strategy": "by_title",
    "max_characters": 4000,
    "new_after_n_chars": 3800,
    "combine_text_under_n_chars": 2000,
}
raw_pdf_elements = partition_pdf(**pdf_partition_options)

In [ ]:
from IPython.display import Image

# Preview one image file from the images directory (presumably written there by the partitioning step).
Image('images/figure-1-1.jpg')


In [ ]:
# Image summarizer

import base64
import os

from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import HumanMessage


class ImageSummarizer:
    """Summarize a single image file with the GPT-4 vision chat model.

    The summary is meant to be embedded for retrieval, while the base64
    payload of the image is returned alongside it so the raw image can be
    stored and looked up later.
    """

    def __init__(self, image_path) -> None:
        """Store the image location and the default summarization prompt.

        Args:
            image_path: Path of the image file to summarize.
        """
        self.image_path = image_path
        self.prompt = """
                    You are an assistant tasked with summarizing images for retrieval.
                    These summaries will be embedded and used to retrieve the raw image.
                    Give a concise summary of the image that is well optimized for retrieval.
                    """

    def base64_encode_image(self):
        """Read the image file and return its content as base64 text."""
        with open(self.image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")

    def summarize(self, prompt=None):
        """Ask the vision model for a summary of the image.

        Args:
            prompt: Optional instruction text; when empty or None the default
                prompt set in ``__init__`` is used.

        Returns:
            A ``(base64_image_data, summary_text)`` tuple.
        """
        encoded_image = self.base64_encode_image()
        instruction = prompt or self.prompt
        vision_model = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1000)

        # gpt4 vision api doc - https://platform.openai.com/docs/guides/vision
        # The message carries two parts: the instruction text and the image as a data URI.
        text_part = {"type": "text", "text": instruction}
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
        }
        request = HumanMessage(content=[text_part, image_part])
        reply = vision_model.invoke([request])
        return encoded_image, reply.content

In [ ]:
# Base64 payloads and their summaries, kept index-aligned.
image_data_list = []
image_summary_list = []

# Only .jpg files are summarized; sorting the listing fixes the processing order.
jpg_names = [name for name in sorted(os.listdir(images_path)) if name.endswith(".jpg")]
for jpg_name in jpg_names:
    encoded, description = ImageSummarizer(os.path.join(images_path, jpg_name)).summarize()
    image_data_list.append(encoded)
    image_summary_list.append(description)


In [ ]:
class Element(BaseModel):
    """Small record describing one chunk taken from the partitioned PDF."""

    # Category label for the chunk; this file uses "table" and "text".
    type: str
    # Chunk content; this file stores the string form of the source element here.
    text: Any


# Imported here so the cell runs on its own.
from unstructured.documents.elements import CompositeElement, Table

# Sort the partitioned PDF elements into tables and text chunks.
# Elements that are neither a table nor a composite text chunk are ignored.
# isinstance() checks the element classes directly and also accepts their subclasses.
table_elements = []
text_elements = []
for element in raw_pdf_elements:
    if isinstance(element, Table):
        table_elements.append(Element(type="table", text=str(element)))
    elif isinstance(element, CompositeElement):
        text_elements.append(Element(type="text", text=str(element)))


In [ ]:

from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Prompt asking for a short summary of a single table or text chunk.
prompt_text = """
  You are responsible for concisely summarizing table or text chunk:

  {element}
"""
prompt = ChatPromptTemplate.from_template(prompt_text)

# Pipeline: raw input -> {"element": input} -> prompt -> chat model -> plain string.
summary_llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
summarize_chain = (
    {"element": lambda value: value}
    | prompt
    | summary_llm
    | StrOutputParser()
)


In [ ]:

# Summarize every table chunk, running at most five requests concurrently.
tables = [table.text for table in table_elements]
table_summaries = summarize_chain.batch(tables, config={"max_concurrency": 5})

# Same for the text chunks.
texts = [chunk.text for chunk in text_elements]
text_summaries = summarize_chain.batch(texts, config={"max_concurrency": 5})


Use the LangChain MultiVectorRetriever to associate the summaries of tables, texts and images
with the original data chunks in a parent-child relationship.


In [ ]:
import uuid

from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma

id_key = "doc_id"

# The retriever (empty to start). Summaries are embedded in the vector store;
# the original chunks are kept in the docstore under the same id.
retriever = MultiVectorRetriever(
    vectorstore=Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings()),
    docstore=InMemoryStore(),
    id_key=id_key,
)


def _index_summaries(summaries, originals):
    """Register summaries in the vector store and their originals in the docstore.

    Each summary/original pair shares one freshly generated UUID: the summary
    document carries it in its metadata under ``id_key`` and the original is
    stored in the docstore under the same key.

    Args:
        summaries: Summary strings, index-aligned with ``originals``.
        originals: The raw chunks (text, table text or base64 image data).

    Returns:
        The list of generated ids, empty when there was nothing to index.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    summaries = list(summaries)
    originals = list(originals)
    if len(summaries) != len(originals):
        raise ValueError(
            f"Got {len(summaries)} summaries for {len(originals)} original chunks"
        )
    # Nothing to index (e.g. a PDF without tables): skip the store calls,
    # which may reject an empty batch.
    if not summaries:
        return []

    ids = [str(uuid.uuid4()) for _ in originals]
    summary_docs = [
        Document(page_content=summary, metadata={id_key: chunk_id})
        for chunk_id, summary in zip(ids, summaries)
    ]
    retriever.vectorstore.add_documents(summary_docs)
    retriever.docstore.mset(list(zip(ids, originals)))
    return ids


# Add texts
doc_ids = _index_summaries(text_summaries, texts)

# Add tables
table_ids = _index_summaries(table_summaries, tables)

# Add images (the docstore keeps the base64 payload, the vector store its summary)
image_ids = _index_summaries(image_summary_list, image_data_list)


Image helper functions


In [ ]:
from PIL import Image
from IPython.display import HTML, display
import io
import re


def plt_img_base64(img_base64):
    """Render a base64-encoded image inline in the notebook output.

    Args:
        img_base64: Base64 text of the image. It is embedded in a
            ``data:image/jpeg`` URI inside an HTML ``<img>`` tag.
    """
    # An empty HTML snippet renders nothing, so the payload has to be wrapped in an <img> tag.
    image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
    display(HTML(image_html))


def is_image_data(b64data):
    """Tell whether base64 data decodes to a JPEG, PNG, GIF or WebP image.

    The check looks only at the leading bytes ("magic number") of the decoded
    payload.

    Args:
        b64data: Base64-encoded data as ``str`` or bytes-like object.

    Returns:
        True when the decoded data starts with a known image signature,
        False otherwise, including when the input cannot be base64-decoded.
    """
    magic_prefixes = (
        b"\xFF\xD8\xFF",  # JPEG
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A",  # PNG
        b"\x47\x49\x46\x38",  # GIF ("GIF8")
    )
    try:
        # 12 bytes are enough for every signature, including the WebP form type.
        header = base64.b64decode(b64data)[:12]
    except (TypeError, ValueError):
        # TypeError: not str/bytes-like; ValueError: non-ASCII text or malformed
        # base64 (binascii.Error is a ValueError subclass).
        return False

    if header.startswith(magic_prefixes):
        return True
    # RIFF is a generic container (also used by WAV and AVI); it is WebP only
    # when the form type at offset 8 says so.
    return header[:4] == b"RIFF" and header[8:12] == b"WEBP"


def split_image_text_types(docs):
    """
    Partition retrieved entries into base64 images and plain texts.

    Entries that are ``Document`` objects are replaced by their
    ``page_content`` before classification. Returns a dict with the keys
    "images" and "texts", each holding the matching entries in input order.
    """
    buckets = {"images": [], "texts": []}
    for entry in docs:
        payload = entry.page_content if isinstance(entry, Document) else entry
        target = "images" if is_image_data(payload) else "texts"
        buckets[target].append(payload)
    return buckets


def img_prompt_func(data_dict):
    """
    Build the single multimodal human message sent to the vision model.

    ``data_dict`` must provide ``data_dict["context"]["images"]`` (base64
    strings), ``data_dict["context"]["texts"]`` (strings) and
    ``data_dict["question"]``. Image parts come first, followed by one text
    part holding the instructions, the question and the joined texts.
    """
    context = data_dict["context"]

    # One image_url part per retrieved image; an empty or missing list adds none.
    parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
        }
        for encoded in (context["images"] or [])
    ]

    # The textual part: fixed instructions, then the question and the retrieved texts.
    joined_texts = "\n".join(context["texts"])
    instructions = (
        "You are financial analyst.\n"
        "You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
        "Use this information to answer the user question in the finance. \n"
    )
    question_block = f"Question: {data_dict['question']}\n\n"
    context_block = "Text and / or tables:\n" + joined_texts
    parts.append({"type": "text", "text": instructions + question_block + context_block})

    return [HumanMessage(content=parts)]

In [ ]:

from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

# Vision-capable chat model that writes the final answer.
model = ChatOpenAI(temperature=0, model="gpt-4-vision-preview", max_tokens=1024)

# RAG pipeline: retrieve -> split images/texts -> build multimodal prompt -> model -> string
retrieve_context = retriever | RunnableLambda(split_image_text_types)
build_prompt = RunnableLambda(img_prompt_func)
chain = (
    {"context": retrieve_context, "question": RunnablePassthrough()}
    | build_prompt
    | model
    | StrOutputParser()
)

In [ ]:
# Example question, answered by running it through the full RAG chain.
query = "Which year had the highest holiday sales growth?"
chain.invoke(query)


In [ ]:
# Fetch what the retriever returns for the same query so the entries can be inspected directly.
docs = retriever.get_relevant_documents(query)

# Number of entries retrieved.
len(docs)

In [ ]:
# Check whether the second retrieved entry is base64 image data (assumes at least two results).
is_image_data(docs[1])


In [ ]:
# Display the second retrieved entry as an image (assumes it is base64 image data).
plt_img_base64(docs[1])
