## Multi-Modal Retrieval-Augmented Generation (RAG) System

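This project implements a multi-modal Retrieval-Augmented Generation (RAG) system for PDF documents. It extracts text, tables, and images from a PDF and uses Gemini to write a short summary of each element. Those summaries are indexed in a Chroma vector store. To answer a question, the system retrieves the matching raw text and images and passes them back to Gemini.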

### Generate API Key

To call Gemini through `ChatGoogleGenerativeAI`, you need an API key. To generate one:

1. Go to [Google AI Studio API Key Generation](https://aistudio.google.com/app/apikey).
2. Follow the instructions to create a key, then paste it into `gemini_api_key` in the configuration cell below.

### Install Poppler and Tesseract

The `unstructured` PDF pipeline needs two system packages. Poppler renders PDF pages to images, and Tesseract performs OCR. Install both by following the official instructions:

- **Poppler:** [Installation Instructions](https://pdf2image.readthedocs.io/en/latest/installation.html)
- **Tesseract:** [Installation Instructions](https://tesseract-ocr.github.io/tessdoc/Installation.html)

#### For Colab or Ubuntu

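In an Ubuntu terminal, run the commands below as-is. In a Colab cell, prefix each line with `!` (for example, `!sudo apt-get install -y poppler-utils`).

```bash
sudo apt-get update
sudo apt-get install -y poppler-utils
sudo apt-get install -y tesseract-ocr libtesseract-dev
```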
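Before you run the notebook, you can check that the binaries are on `PATH`. The following is a minimal Python sketch. It assumes that `pdftoppm` and `pdfinfo` (installed by `poppler-utils`) and `tesseract` (installed by `tesseract-ocr`) are the executables that matter. `check_system_dependencies` is a hypothetical helper:

```python
import shutil
from typing import Dict, Iterable

# Command names installed by poppler-utils (pdftoppm, pdfinfo) and tesseract-ocr (tesseract)
REQUIRED_BINARIES = ("pdftoppm", "pdfinfo", "tesseract")


def check_system_dependencies(binaries: Iterable[str] = REQUIRED_BINARIES) -> Dict[str, bool]:
    """Return {binary_name: True if it is found on PATH} for each executable."""
    return {name: shutil.which(name) is not None for name in binaries}


for name, found in check_system_dependencies().items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```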

[![Open in Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SS-Keval/Multimodal-RAG-Meetup/blob/main/multimodel_rag.ipynb)

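#### Install the Python dependencies

```python
# Notebook cell: the leading "!" runs each line as a shell command
!pip install -U pdfminer.six unstructured pillow-heif pi_heif unstructured_inference pytesseract unstructured.pytesseract "unstructured[all-docs]"
!pip install -U chromadb langchain langchain_huggingface langchain-google-genai langchain-chroma nltk
# Recent NLTK releases look for "punkt_tab" instead of "punkt"; downloading both covers old and new versions
!python -m nltk.downloader punkt punkt_tab
```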

```python
import base64
import io
import os
import re
import uuid
from typing import Dict, List, Union

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from PIL import Image
from unstructured.partition.pdf import partition_pdf
```

```python
# Replace both placeholders before running; never commit a real API key
gemini_api_key = "[GEMINI_API_KEY]"
document_name = "[DOCUMENT_NAME_OR_PATH]"
```
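Hard-coding the key in a notebook makes it easy to leak. A minimal alternative is to read the key from an environment variable. This sketch assumes a variable named `GOOGLE_API_KEY`; that name is our choice, and `load_gemini_api_key` is a hypothetical helper, not part of any library:

```python
import os


def load_gemini_api_key(var_name: str = "GOOGLE_API_KEY") -> str:
    """Read the Gemini API key from an environment variable, failing loudly if it is unset."""
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise RuntimeError(f"Set the {var_name} environment variable to your Gemini API key.")
    return key
```

After exporting the variable, you can replace the placeholder with `gemini_api_key = load_gemini_api_key()`.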

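```python
def extract_pdf_content_with_metadata(file_path: str):
    """
    Extract text chunks, tables, and images from a PDF file.

    Args:
        file_path (str): The path to the PDF file to be processed.

    Returns:
        list[Element]: Chunked elements (text chunks and tables). Image and table blocks
                    are also saved as image files in the `figures` directory.
    """
    return partition_pdf(
        filename=file_path,
        strategy="hi_res",                             # Layout model + OCR; needed for image/table block extraction
        extract_images_in_pdf=True,
        extract_image_block_types=["Image", "Table"],  # Save both images and tables as images
        extract_image_block_output_dir="figures",      # Read later by generate_media_summaries
        chunking_strategy="by_title",                  # Start a new chunk at each title
        max_characters=4000,                           # Hard cap on chunk size
        combine_text_under_n_chars=2000,               # Combine consecutive small sections
    )
```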
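With `chunking_strategy="by_title"`, each title starts a new chunk, consecutive small sections are combined, and chunk size is capped. The following sketch is a simplified, hypothetical re-implementation that shows how `max_characters` and `combine_text_under_n_chars` interact. It is not unstructured's actual algorithm, which handles many more element types and tracks metadata. `chunk_by_title` and its `(kind, text)` input format are assumptions made for illustration:

```python
from typing import List, Tuple


def chunk_by_title(
    elements: List[Tuple[str, str]],
    max_characters: int = 4000,
    combine_text_under_n_chars: int = 2000,
) -> List[str]:
    """Simplified by-title chunking over (kind, text) pairs, where kind is "Title" or "Text"."""
    # 1. Group elements into sections: every Title opens a new section.
    sections: List[str] = []
    for kind, text in elements:
        if kind == "Title" or not sections:
            sections.append(text)
        else:
            sections[-1] += "\n" + text

    # 2. While the running chunk is shorter than combine_text_under_n_chars,
    #    absorb the next section, as long as the result stays within max_characters.
    merged: List[str] = []
    for section in sections:
        if (
            merged
            and len(merged[-1]) < combine_text_under_n_chars
            and len(merged[-1]) + 1 + len(section) <= max_characters
        ):
            merged[-1] += "\n" + section
        else:
            merged.append(section)

    # 3. Hard-split anything that still exceeds max_characters.
    chunks: List[str] = []
    for chunk in merged:
        chunks.extend(chunk[i : i + max_characters] for i in range(0, len(chunk), max_characters))
    return chunks


doc = [("Title", "Intro"), ("Text", "a" * 10), ("Title", "Methods"), ("Text", "b" * 10)]
print(chunk_by_title(doc, max_characters=1000, combine_text_under_n_chars=100))
# ['Intro\naaaaaaaaaa\nMethods\nbbbbbbbbbb']
print(chunk_by_title(doc, max_characters=1000, combine_text_under_n_chars=0))
# ['Intro\naaaaaaaaaa', 'Methods\nbbbbbbbbbb']
```

A larger `combine_text_under_n_chars` produces fewer, longer chunks. `max_characters` still bounds every chunk.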

```python
from unstructured.documents.elements import CompositeElement


def categorize_pdf_elements_by_type(raw_pdf_elements: list) -> List[str]:
    """
    Collect the text chunks from the extracted PDF elements.

    Args:
        raw_pdf_elements (list): Elements returned by `partition_pdf`.

    Returns:
        List[str]: The text of every `CompositeElement` chunk. Table chunks are skipped here
        because tables are summarized from their extracted images instead.
    """
    # Chunking "by_title" merges narrative text into CompositeElement chunks
    return [str(element) for element in raw_pdf_elements if isinstance(element, CompositeElement)]
```

```python
# Extract elements from the PDF
raw_pdf_elements = extract_pdf_content_with_metadata(document_name)

# Categorize elements into text content
texts = categorize_pdf_elements_by_type(raw_pdf_elements)
```

```python
texts
```

```python
# Initialize the Google Generative AI model
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",   # Model version
    temperature=0,              # Minimize randomness in responses
    max_output_tokens=2048,     # Maximum tokens per response
    timeout=None,               # No timeout for API calls
    max_retries=2,              # Retry failed calls up to 2 times
    api_key=gemini_api_key,     # Your API key
)
```
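The `max_retries` setting asks the client to retry failed calls. Retries of this kind are commonly paired with exponential backoff. The helper below is a generic sketch of that idea. `call_with_retries` is hypothetical and does not reproduce `langchain_google_genai`'s internal retry logic. In this sketch, `max_retries` counts retries after the first attempt, so `max_retries=2` allows up to three calls; the library may count attempts differently:

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_retries: int = 2,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on any exception with delays base_delay, 2*base_delay, 4*base_delay, ..."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_retries:
                raise  # Out of retries: surface the last error
            sleep(base_delay * 2 ** attempt)
    raise AssertionError("unreachable")


# Usage: a stand-in call that fails twice, then succeeds; delays are recorded instead of slept
delays = []
attempts = {"n": 0}


def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"


result = call_with_retries(flaky, max_retries=2, base_delay=1.0, sleep=delays.append)
print(result, delays)  # ok [1.0, 2.0]
```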

#### Summary Generation

```python
# Generate concise summaries for text elements
def summarize_text_elements(text_elements: list):
    """
    Generate concise summaries of text elements for efficient retrieval.

    Args:
        text_elements (list): List of strings representing the text elements to be summarized.

    Returns:
        List[str]: A list of concise summaries optimized for retrieval systems.
    """

    # Prompt for summarization
    summary_prompt = """You are an assistant specializing in generating concise and accurate
    summaries for retrieval purposes. The summary should capture the essential details of the
    provided text, making it easily searchable and optimized for information retrieval.
    Please provide a concise summary for the following element: {element}"""

    prompt_template = ChatPromptTemplate.from_template(summary_prompt)

    # Chain for processing and summarizing text elements
    summary_chain = {"element": lambda x: x} | prompt_template | llm | StrOutputParser()

    # Summarize all elements, with at most 5 requests in flight at once
    summaries = summary_chain.batch(text_elements, {"max_concurrency": 5})

    return summaries


# Generate summaries for the provided text elements
text_element_summaries = summarize_text_elements(text_elements=texts)
```
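`summary_chain.batch(..., {"max_concurrency": 5})` runs the chain over all inputs with at most five calls in flight and returns the results in input order. The following stdlib sketch shows the same idea. It assumes a thread pool is a fair stand-in for LangChain's executor. `bounded_batch` and `fake_summarize` are hypothetical, and `fake_summarize` only truncates text in place of a real LLM call:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def bounded_batch(fn: Callable[[str], str], inputs: List[str], max_concurrency: int = 5) -> List[str]:
    """Apply fn to every input using at most max_concurrency worker threads."""
    # executor.map yields results in input order, regardless of completion order
    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        return list(executor.map(fn, inputs))


# Track how many calls run at the same time
active = 0
peak = 0
lock = threading.Lock()


def fake_summarize(text: str) -> str:
    global active, peak
    with lock:
        active += 1
        peak = max(peak, active)
    time.sleep(0.01)  # Simulate network latency
    with lock:
        active -= 1
    return text[:10]


summaries = bounded_batch(fake_summarize, [f"element {i} ..." for i in range(12)])
print(len(summaries), peak)
```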

```python
text_element_summaries
```

```python
def encode_media_to_base64(media_path: str) -> str:
    """
    Encode media (image/table) to a base64 string.

    Args:
        media_path (str): The file path to the media.

    Returns:
        str: The base64-encoded string of the media.
    """
    with open(media_path, "rb") as media_file:
        return base64.b64encode(media_file.read()).decode("utf-8")


def summarize_media(base64_media: str, prompt: str) -> str:
    """
    Generate a summary for media (image/table) using the Google Generative AI model.

    Args:
        base64_media (str): Base64-encoded media string.
        prompt (str): The prompt text to provide context for summarization.

    Returns:
        str: The summary of the media content.
    """

    msg = llm.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_media}"},
                    },
                ]
            )
        ]
    )
    return msg.content  # type: ignore


def generate_media_summaries(media_dir: str):
    """
    Generate summaries and base64-encoded strings for the media (images/tables) in a directory.

    Args:
        media_dir (str): Directory containing the extracted .jpg files (for example, "figures").

    Returns:
        Tuple[List[str], List[str]]: A tuple of the base64-encoded media and their
        summaries, in the same order.
    """

    # Store base64-encoded media and summaries
    base64_media_list = []
    media_summaries = []

    # Prompt for summarizing media (images/tables)
    summary_prompt = """You are an assistant tasked with summarizing images and tables for retrieval.
    These summaries will be embedded and used to retrieve the raw media.
    Provide concise summaries optimized for retrieval."""

    # Process each .jpg file in the directory, in sorted order
    for media_file in sorted(os.listdir(media_dir)):
        if media_file.endswith(".jpg"):  # Only .jpg is handled; summarize_media labels images as JPEG
            media_path = os.path.join(media_dir, media_file)
            base64_media = encode_media_to_base64(media_path)
            base64_media_list.append(base64_media)
            media_summaries.append(summarize_media(base64_media, summary_prompt))

    return base64_media_list, media_summaries


# Generate media summaries and base64-encoded strings
base64_media_list, media_summaries = generate_media_summaries(media_dir="figures")
```
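`generate_media_summaries` only picks up lowercase `.jpg` files, and `summarize_media` labels every image as `image/jpeg`. If you add other formats later, the following sketch keeps the extension filter case-insensitive. It also derives each MIME type from the file name with the standard `mimetypes` module. `MEDIA_EXTENSIONS`, `list_media_files`, and `to_data_url` are hypothetical names, and the extension list is an assumption:

```python
import base64
import mimetypes
import os
from typing import List

# Hypothetical extension list; the original code only handles .jpg
MEDIA_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_media_files(media_dir: str) -> List[str]:
    """Return sorted paths of files whose extension matches MEDIA_EXTENSIONS (case-insensitive)."""
    return sorted(
        os.path.join(media_dir, name)
        for name in os.listdir(media_dir)
        if name.lower().endswith(MEDIA_EXTENSIONS)
    )


def to_data_url(path: str) -> str:
    """Encode a file as a data URL, taking the MIME type from its extension."""
    mime, _ = mimetypes.guess_type(path)
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime or 'application/octet-stream'};base64,{encoded}"
```

The resulting string can go straight into `{"type": "image_url", "image_url": {"url": ...}}`.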

```python
media_summaries
```

```python
base64_media_list
```

#### Retriever

```python
from langchain_huggingface import HuggingFaceEmbeddings

# Initialize HuggingFace embeddings using the specified model
embedding_model_name = "BAAI/bge-m3"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)

# Now you can use `embeddings` to generate vector representations for your text data
```
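Once `embeddings` is set up, the vector store ranks stored summaries by how close their vectors are to the query vector. The toy example below shows that ranking step, using cosine similarity on hand-made 3-dimensional vectors. Real `BAAI/bge-m3` vectors are much longer. Chroma's default metric is L2 distance, which gives the same ranking as cosine for unit-length vectors. `cosine_similarity`, `top_k`, and the vectors are illustrative assumptions:

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: List[float], index: Dict[str, List[float]], k: int = 2) -> List[Tuple[str, float]]:
    """Return the k (doc_id, score) pairs most similar to the query, best first."""
    scored = [(doc_id, cosine_similarity(query_vec, vec)) for doc_id, vec in index.items()]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:k]


# Toy "embeddings" of three summaries
index = {
    "text-summary": [1.0, 0.0, 0.0],
    "table-summary": [0.7, 0.7, 0.0],
    "image-summary": [0.0, 0.0, 1.0],
}
print([doc_id for doc_id, _ in top_k([1.0, 0.1, 0.0], index)])  # ['text-summary', 'table-summary']
```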

```python
def create_multi_vector_retriever(
    vectorstore: Chroma,
    text_summaries: List[str],
    raw_texts: List[str],
    image_summaries: List[str],
    raw_images: List[str]
) -> MultiVectorRetriever:
    """
    Create a retriever that indexes text and image summaries and returns raw text or image content.

    Args:
        vectorstore (Chroma): The vectorstore used to index document summaries.
        text_summaries (List[str]): Summaries of text content to be indexed.
        raw_texts (List[str]): Raw text documents corresponding to the text summaries.
        image_summaries (List[str]): Summaries of image content to be indexed.
        raw_images (List[str]): Base64-encoded image data corresponding to the image summaries.

    Returns:
        MultiVectorRetriever: A retriever that can fetch raw texts or images based on their summaries.
    """
    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    def add_documents(
        retriever: MultiVectorRetriever,
        summaries: List[str],
        contents: List[Union[str, bytes]]
    ):
        """
        Add documents to the vectorstore and docstore.

        Args:
            retriever (MultiVectorRetriever): The retriever instance.
            summaries (List[str]): Document summaries to be added to the vectorstore.
            contents (List[Union[str, bytes]]): Corresponding raw content (text or image) to be added to the docstore.
        """
        # Each summary is linked to its raw content through a shared doc_id
        if len(summaries) != len(contents):
            raise ValueError("Each summary needs exactly one matching raw content item.")
        doc_ids = [str(uuid.uuid4()) for _ in contents]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_ids[i]})
            for i, summary in enumerate(summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, contents)))

    # Add text and image summaries if available
    if text_summaries and raw_texts:
        add_documents(retriever, text_summaries, raw_texts)

    if image_summaries and raw_images:
        add_documents(retriever, image_summaries, raw_images)

    return retriever
```
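The idea behind `MultiVectorRetriever` is that the short summaries are searched, but the raw text or base64 image is returned. The two are linked through a shared `doc_id`. The following dependency-free sketch shows that pattern, with a word-overlap score standing in for embeddings. `ToyMultiVectorIndex` and `word_overlap` are hypothetical and are not LangChain APIs:

```python
import uuid
from typing import Callable, Dict, List, Tuple


class ToyMultiVectorIndex:
    """Hypothetical stand-in: summaries are searched, raw content is returned."""

    def __init__(self, score: Callable[[str, str], float]):
        self.score = score
        self.summaries: List[Tuple[str, str]] = []  # (doc_id, summary): plays the vectorstore role
        self.docstore: Dict[str, str] = {}          # doc_id -> raw content: plays the InMemoryStore role

    def add(self, summaries: List[str], contents: List[str]) -> None:
        if len(summaries) != len(contents):
            raise ValueError("Each summary needs exactly one raw content item.")
        for summary, content in zip(summaries, contents):
            doc_id = str(uuid.uuid4())
            self.summaries.append((doc_id, summary))
            self.docstore[doc_id] = content

    def search(self, query: str, k: int = 1) -> List[str]:
        # Rank summaries by score, then swap each hit for its raw content
        ranked = sorted(self.summaries, key=lambda item: self.score(query, item[1]), reverse=True)
        return [self.docstore[doc_id] for doc_id, _ in ranked[:k]]


def word_overlap(query: str, summary: str) -> float:
    """Jaccard similarity of the two word sets (a crude stand-in for embedding similarity)."""
    q, s = set(query.lower().split()), set(summary.lower().split())
    return len(q & s) / len(q | s) if q | s else 0.0


index = ToyMultiVectorIndex(word_overlap)
index.add(
    ["quarterly revenue table", "bar chart of user growth"],
    ["<raw table text>", "<base64 image bytes>"],
)
print(index.search("user growth chart"))  # ['<base64 image bytes>']
```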

```python
# Initialize the vectorstore with embeddings
vectorstore = Chroma(
    collection_name="mm_rag",
    embedding_function=embeddings
)

# Create the multi-vector retriever
retriever_multi_vector = create_multi_vector_retriever(
    vectorstore=vectorstore,
    text_summaries=text_element_summaries,
    raw_texts=texts,
    image_summaries=media_summaries,
    raw_images=base64_media_list,
)
```

#### Final Answer Generation

```python
from typing import Any


def is_base64_encoded(data: str) -> bool:
    """
    Check if the string is a valid base64 encoded data.

    Args:
        data (str): String to be checked.

    Returns:
        bool: True if the string looks like base64 encoded data, otherwise False.
    """
    return re.match(r"^[A-Za-z0-9+/]+[=]{0,2}$", data) is not None


def is_media_base64(data: str) -> bool:
    """
    Check if the base64 data represents a media item.

    Args:
        data (str): Base64-encoded media data.

    Returns:
        bool: True if the data represents a media item, otherwise False.
    """
    # Leading "magic bytes" of common image formats
    media_signatures = {
        b"\xff\xd8\xff": "jpg",
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",  # "RIFF" container; also matches WAV/AVI
    }
    try:
        header = base64.b64decode(data)[:8]  # Decode and get the first 8 bytes
        return any(header.startswith(sig) for sig in media_signatures)
    except Exception:
        return False


def resize_base64_media(base64_string: str, size: tuple = (128, 128)) -> str:
    """
    Resize a base64-encoded media item.

    Args:
        base64_string (str): Base64-encoded media string.
        size (tuple): New (width, height) for the media item.

    Returns:
        str: Base64-encoded string of the resized media item.
    """
    media_data = base64.b64decode(base64_string)
    media = Image.open(io.BytesIO(media_data))
    resized_media = media.resize(size, Image.LANCZOS)

    buffered = io.BytesIO()
    resized_media.save(buffered, format=media.format)

    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def split_media_and_texts(docs: List[str]) -> Dict[str, List[str]]:
    """
    Split base64-encoded media and texts from a list of documents.

    Args:
        docs (List[str]): List of documents which might be base64-encoded media or texts.

    Returns:
        Dict[str, List[str]]: Dictionary with keys "media" and "texts" containing lists of base64 media and texts respectively.
    """
    media_list = []
    texts = []

    for doc in docs:
        if is_base64_encoded(doc) and is_media_base64(doc):
            # A fixed 1300x600 target stretches images with a different aspect ratio
            media_list.append(resize_base64_media(doc, size=(1300, 600)))
        else:
            texts.append(doc)

    return {"media": media_list, "texts": texts}


def format_prompt(data: Dict[str, Any]) -> List[HumanMessage]:
    """
    Format data into a prompt for multi-modal analysis.

    Args:
        data (Dict[str, Any]): Dictionary with a "context" dict (holding "texts" and "media" lists) and a "question" string.

    Returns:
        List[HumanMessage]: List of messages formatted for the LLM.
    """
    formatted_texts = "\n".join(data["context"]["texts"])
    messages = []

    # Attach each retrieved image; the MIME type assumes the .jpg files extracted earlier
    for media in data["context"]["media"]:
        media_message = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{media}"},
        }
        messages.append(media_message)

    text_message = {
        "type": "text",
        "text": (
            "You are an expert tasked with providing analysis and insights based on mixed media inputs.\n"
            "You will receive a combination of text, tables, and media items, including charts and graphs.\n"
            "Your goal is to provide a detailed analysis or answer based on the provided data.\n"
            f"User question: {data['question']}\n\n"
            "Text and/or tables:\n"
            f"{formatted_texts}"
        ),
    }
    messages.append(text_message)

    return [HumanMessage(content=messages)]
```
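Two details in the helpers above could be tightened. First, every image is sent as `image/jpeg` even though `is_media_base64` also accepts PNG, GIF, and WebP. Second, `resize_base64_media(doc, size=(1300, 600))` stretches any image whose aspect ratio differs from 1300:600. The following stdlib sketch addresses both. `sniff_image_mime` and `fit_within` are hypothetical helpers, and `fit_within` deliberately never upscales, unlike the original resize:

```python
from typing import Optional, Tuple

# Magic-number prefixes mapped to MIME types
SIGNATURES = {
    b"\xff\xd8\xff": "image/jpeg",
    b"\x89PNG\r\n\x1a\n": "image/png",
    b"GIF8": "image/gif",
}


def sniff_image_mime(data: bytes) -> Optional[str]:
    """Return the image MIME type implied by the leading bytes, or None if unrecognized."""
    for sig, mime in SIGNATURES.items():
        if data.startswith(sig):
            return mime
    # WebP is a RIFF container with "WEBP" at bytes 8-11; checking both rules out WAV/AVI
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    return None


def fit_within(size: Tuple[int, int], box: Tuple[int, int] = (1300, 600)) -> Tuple[int, int]:
    """Scale (width, height) to fit inside box, keeping the aspect ratio and never upscaling."""
    width, height = size
    scale = min(box[0] / width, box[1] / height, 1.0)
    return max(1, round(width * scale)), max(1, round(height * scale))


print(fit_within((2600, 600)))  # (1300, 300)
print(fit_within((1000, 1000)))  # (600, 600)
```

Inside `resize_base64_media`, you could pass `fit_within(media.size)` to `media.resize`. In `format_prompt`, you could build the data URL from `sniff_image_mime(base64.b64decode(media))`.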

```python
from langchain_core.runnables import Runnable


def create_multi_modal_rag_chain(retriever) -> Runnable:
    """
    Create a multi-modal RAG chain for processing.

    Args:
        retriever (MultiVectorRetriever): The retriever instance.

    Returns:
        Runnable: The RAG chain (retrieve -> split -> format prompt -> LLM -> string).
    """
    chain = (
        {
            # Retrieve raw documents, then separate images from texts
            "context": retriever | RunnableLambda(split_media_and_texts),
            # Forward the user's question unchanged
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(format_prompt)
        | llm
        | StrOutputParser()
    )

    return chain
```
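The LCEL pipeline in `create_multi_modal_rag_chain` can be read as a few plain function calls. The following sketch spells out that data flow with stub callables. `run_rag_chain` and the stubs are hypothetical. Unlike LangChain's dict step, which runs its branches as a `RunnableParallel`, this sketch evaluates them one after the other:

```python
from typing import Any, Callable, Dict, List


def run_rag_chain(
    question: str,
    retrieve: Callable[[str], List[str]],
    split: Callable[[List[str]], Dict[str, List[str]]],
    build_prompt: Callable[[Dict[str, Any]], Any],
    generate: Callable[[Any], str],
) -> str:
    # Step 1: the dict step sends the same question down both branches
    data = {
        "context": split(retrieve(question)),  # retriever | RunnableLambda(split_media_and_texts)
        "question": question,                   # RunnablePassthrough()
    }
    # Step 2: build the prompt (format_prompt); Step 3: call the model and return text
    return generate(build_prompt(data))


answer = run_rag_chain(
    "What grew?",
    retrieve=lambda q: ["Revenue grew 10%."],
    split=lambda docs: {"media": [], "texts": docs},
    build_prompt=lambda d: f"Q: {d['question']}\n" + "\n".join(d["context"]["texts"]),
    generate=lambda prompt: prompt.upper(),  # Stand-in for llm | StrOutputParser()
)
print(answer)
```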

#### Create RAG Chain for Q & A

```python
# Create RAG chain
chain_multimodal_rag = create_multi_modal_rag_chain(retriever_multi_vector)
```

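```python
# Define the query for the retriever
query = "[QUERY]"

# `invoke` has no `limit` argument; the number of summaries fetched from the
# vector store is controlled by the retriever's `search_kwargs["k"]`
retriever_multi_vector.search_kwargs = {"k": 5}

# Retrieve the raw documents whose summaries best match the query
docs = retriever_multi_vector.invoke(query)
```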

```python
docs
```

```python
# Define the query for the final answer generation or Q&A
query = "[QUERY]"

# Invoke the multimodal RAG chain with the query to generate a final answer
print(chain_multimodal_rag.invoke(query))
```