# **1. Importing Libraries**

In [1]:
import io
import re
import uuid
import base64
import chromadb
import numpy as np
from PIL import Image
from io import BytesIO
from operator import itemgetter
from IPython.display import HTML, display

from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# **2. Multimodal RAG**

**Load the vector store from local path**

In [2]:
# Open the persisted Chroma collection that holds the multimodal embeddings
vectorstore = Chroma(collection_name="multi_modal_rag",
                     embedding_function=OpenAIEmbeddings(),
                     persist_directory="chroma_langchain_db")

# NOTE(review): `id_key` and `store` are not consumed by the retriever below;
# they are kept only in case another cell relies on them - remove if none does.
id_key = "doc_id"
store = InMemoryStore()

# Retrieve directly from the vector store.
# This cell previously built a MultiVectorRetriever on top of the freshly
# created (empty) `store` and overwrote it on the very next statement, so that
# object was never used; only the effective assignment is kept.
retriever = vectorstore.as_retriever()

  embedding_function=OpenAIEmbeddings(),
  vectorstore = Chroma(collection_name="multi_modal_rag",


**RAG utils**

In [3]:
def plt_img_base64(img_base64):
    """Render a base64-encoded image inline in the notebook output."""
    # Wrap the payload in a JPEG data URI and hand the <img> tag to IPython.
    data_uri = f"data:image/jpeg;base64,{img_base64}"
    tag = f'<img src="{data_uri}" />'
    display(HTML(tag))


def looks_like_base64(sb):
    """Tell whether `sb` is made only of base64 alphabet characters plus up to two trailing '='."""
    found = re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb)
    if found is None:
        return False
    return True


def is_image_data(b64data):
    """Check whether base64 data decodes to a JPEG, PNG, GIF or WebP image.

    The decision is made purely from the magic bytes at the start of the
    decoded payload; the image itself is never parsed.

    Args:
        b64data: base64-encoded payload (str or bytes-like).

    Returns:
        True if the decoded data starts with a supported image signature,
        False otherwise - including when the input cannot be base64-decoded.
    """
    # Signatures that sit at offset 0 of the file.
    leading_signatures = (
        b"\xff\xd8\xff",                      # JPEG
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",  # PNG
        b"\x47\x49\x46\x38",                  # GIF ("GIF8")
    )
    try:
        # 12 bytes are enough for every check below (the WebP tag ends at byte 12).
        header = base64.b64decode(b64data)[:12]
    except (ValueError, TypeError):
        # binascii.Error (malformed base64) is a ValueError subclass;
        # TypeError covers inputs that are neither str nor bytes-like.
        return False

    if header.startswith(leading_signatures):
        return True

    # "RIFF" alone is a generic container tag shared with WAV, AVI, etc.
    # A WebP file additionally carries the form type "WEBP" at bytes 8-11.
    return header[:4] == b"\x52\x49\x46\x46" and header[8:12] == b"WEBP"


def resize_base64_image(base64_string, size=(128, 128)):
    """Decode a base64 image, resize it to `size`, and return it as base64 text again."""
    source = Image.open(io.BytesIO(base64.b64decode(base64_string)))

    # Serialise the resampled image using the format reported for the input image.
    out_buffer = io.BytesIO()
    source.resize(size, Image.LANCZOS).save(out_buffer, format=source.format)

    encoded = base64.b64encode(out_buffer.getvalue())
    return encoded.decode("utf-8")


def split_image_text_types(docs):
    """Partition retrieved items into resized base64 images and plain texts."""
    buckets = {"images": [], "texts": []}
    for item in docs:
        # Documents are unwrapped to their page content; other items are used as-is.
        content = item.page_content if isinstance(item, Document) else item
        if looks_like_base64(content) and is_image_data(content):
            buckets["images"].append(resize_base64_image(content, size=(1300, 600)))
        else:
            buckets["texts"].append(content)
    return buckets


def img_prompt_func(data_dict):
    """Build the multimodal prompt message from retrieved context and the question.

    Args:
        data_dict: mapping with
            - "context": a dict holding "images" (base64 strings) and
              "texts" (strings), as produced by `split_image_text_types`;
            - "question": the user's question.

    Returns:
        A one-element list containing a HumanMessage whose content is one
        `image_url` part per image followed by a single `text` part.
    """
    context = data_dict["context"]
    formatted_texts = "\n".join(context["texts"])

    # One image part per retrieved image; an empty/None image list adds nothing.
    content_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        }
        for image in (context["images"] or [])
    ]

    # Instruction, question and textual context go into one trailing text part.
    # The explicit "\n" after the instruction keeps it from running straight
    # into "User-provided question:" when the adjacent literals are joined.
    content_parts.append(
        {
            "type": "text",
            "text": (
                "Answer the question based on the following context, which can include text, tables, and images.\n"
                f"User-provided question: {data_dict['question']}\n\n"
                "Text and / or tables:\n"
                f"{formatted_texts}"
            ),
        }
    )
    return [HumanMessage(content=content_parts)]


def multi_modal_rag_chain(retriever):
    """Assemble the retrieve -> prompt -> chat model -> string pipeline."""
    llm = ChatOpenAI(temperature=0, model="gpt-4o-mini", max_tokens=1024, streaming=True)

    # The chain input feeds two branches: retrieval (split into images/texts)
    # under "context", and a passthrough under "question".
    context_branch = retriever | RunnableLambda(split_image_text_types)
    inputs = {"context": context_branch, "question": RunnablePassthrough()}

    prompt_step = RunnableLambda(img_prompt_func)
    return inputs | prompt_step | llm | StrOutputParser()

In [4]:
# Check retrieval
query = "What is dialated attention?"
docs = retriever.invoke(query, k=10)
print("Documents Retreived: ", len(docs))

# Generate a response, printing it piece by piece as it is produced.
# `.stream()` yields output chunks as they arrive; `.invoke()` returns one
# finished string, so looping over it only walks the completed answer
# character by character after generation has already ended.
chain_multimodal_rag = multi_modal_rag_chain(retriever)
for chunk in chain_multimodal_rag.stream(query):
    print(chunk, end="", flush=True)

Documents Retreived:  10
Dilated attention is a mechanism used in processing input data by dividing it into segments and applying a sparsification technique. It involves three main components: Query (Q), Key (K), and Value (V). The input is segmented into N parts, each of a specified length, and then sparsified by selecting rows at regular intervals.

Key features of dilated attention include:

- **Transformation**: It can be converted into dense attention through specific operations on the input and output.
- **Optimization**: This transformation allows for leveraging existing optimizations from traditional attention mechanisms, enhancing efficiency.
- **Efficiency**: Dilated attention significantly reduces computational costs, achieving a decrease by a factor of \( N r^2 \) compared to vanilla attention.

Additionally, dilated attention demonstrates almost constant latency when scaling up sequence lengths, making it capable of handling very large sequences efficiently, unlike vanilla