# Multimodal RAG using Bedrock Titan and Claude3

## Install dependencies

In [None]:
# Refresh the package index, then install the system-level packages.
# NOTE(review): poppler-utils and tesseract-ocr are presumably the native
# backends for the pdf2image / pytesseract packages installed in the next
# cell, and ffmpeg/libsm6/libxext6 presumably satisfy native imaging
# dependencies of the PDF partitioning stack — confirm before trimming.
!sudo apt-get update -y
!sudo apt -y install poppler-utils tesseract-ocr
!sudo apt install ffmpeg libsm6 libxext6  -y

In [None]:
!pip install pdf2image
!pip install pytesseract
!pip install -U langchain langchain-experimental langchain-aws
!pip install "unstructured[all-docs]==0.10.19" pillow pydantic lxml pillow matplotlib tiktoken open_clip_torch torch
!pip install -U faiss-cpu tiktoken

## Download and process dataset

In [None]:
import os
import shutil

# Download the source PDF into the current working directory.
# NOTE(review): --no-check-certificate disables TLS certificate verification
# for this download — confirm it is still required.
!wget "https://www.getty.edu/publications/resources/virtuallibrary/0892360224.pdf" --no-check-certificate
# NOTE(review): shutil.move renames the PDF to a file called "Data" unless a
# directory named "Data" already exists, in which case the PDF is moved inside
# it. The partitioning cell passes path + 'Data' as the PDF filename, so it
# presumably relies on the rename behavior — confirm no "Data" directory exists.
shutil.move("0892360224.pdf","Data")

In [None]:
# Working directory of the workshop; later cells use it both as the location of
# the PDF and as the output directory for extracted images.
# NOTE(review): hard-coded absolute path — adjust when running elsewhere.
path = "/home/sagemaker-user/Bedrock-Claude-Deep-Dive-Workshop/03_Multimodal_RAG/"
# Despite the singular name, this holds the list of entry names found in `path`.
file_name = os.listdir(path)

In [None]:
# Display the directory listing gathered in the previous cell.
file_name

In [None]:
# Extract images, tables, and chunk text
from unstructured.partition.pdf import partition_pdf

# Partition the PDF into elements, chunked by title, while extracting images
# and table structure. Extracted images are written to `path`.
raw_pdf_elements = partition_pdf(
    # os.path.join keeps the filename correct whether or not `path` ends with
    # a separator (the image listing cell builds its paths the same way).
    filename=os.path.join(path, "Data"),
    extract_images_in_pdf=True,
    infer_table_structure=True,
    chunking_strategy="by_title",
    # Character-count limits handed to the chunking strategy.
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=path,
)

In [None]:
from unstructured.documents.elements import CompositeElement, Table

# Sort the partitioned elements into table strings and text-chunk strings.
# isinstance() replaces the earlier substring match on str(type(element)),
# which depended on the exact textual form of the class path.
tables = []
texts = []
for element in raw_pdf_elements:
    if isinstance(element, Table):
        tables.append(str(element))
    elif isinstance(element, CompositeElement):
        texts.append(str(element))

# Report how many tables and text chunks were found.
print(len(tables))
print(len(texts))

In [None]:
from PIL import Image
# Preview one extracted figure. Extracted images are written to `path`, so
# resolve the file there instead of relying on the current working directory.
Image.open(os.path.join(path, "figure-26-1.jpg"))

## Import text and image embeddings into a FAISS vector database

In [None]:
# Keep only text chunks longer than 20 characters, then report how many remain.
texts = list(filter(lambda chunk: len(chunk) > 20, texts))
print(len(texts))

In [None]:
# Sorted full paths of every ".jpg" entry found directly inside `path`.
image_uris = sorted(
    os.path.join(path, entry)
    for entry in os.listdir(path)
    if entry.endswith(".jpg")
)

In [None]:
import boto3

# Bedrock runtime
# Region used for every Bedrock call made in this notebook.
REGION_NAME = "us-west-2"

# Shared runtime client: generate_embeddings calls it directly, and it is also
# handed to the BedrockEmbeddings and ChatBedrock wrappers further down.
client = boto3.client(
    "bedrock-runtime",
    region_name=REGION_NAME,
)

In [None]:
import json
import base64

def base64_encode_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf8")

class EmbedError(Exception):
    """Raised when the embedding model reports an error in its response body."""


def generate_embeddings(image_base64=None, text_description=None):
    """Request an embedding for an image, a text, or both from the Titan model.

    Args:
        image_base64: Base64-encoded image sent as ``inputImage``; optional.
        text_description: Text sent as ``inputText``; optional.

    Returns:
        The value of the ``embedding`` field of the model response
        (``None`` if the response does not contain that field).

    Raises:
        ValueError: If neither ``image_base64`` nor ``text_description`` is given.
        EmbedError: If the response body contains a ``message`` field.
    """
    # Only include the fields that were actually supplied.
    input_data = {}
    if image_base64 is not None:
        input_data["inputImage"] = image_base64
    if text_description is not None:
        input_data["inputText"] = text_description

    if not input_data:
        raise ValueError("At least one of image_base64 or text_description must be provided")

    response = client.invoke_model(
        body=json.dumps(input_data),
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())

    # A "message" field in the body is treated as an error report.
    error_message = response_body.get("message")
    if error_message is not None:
        raise EmbedError(f"Embeddings generation error: {error_message}")

    return response_body.get("embedding")

In [None]:
# Build (image path, embedding) pairs by sending each image, base64-encoded,
# to the Titan multimodal embedding model.
images_embeddings = [
    (image_uri, generate_embeddings(image_base64=base64_encode_image(image_uri)))
    for image_uri in image_uris
]

In [None]:
import os
import uuid

import numpy as np
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import BedrockEmbeddings
from PIL import Image as _PILImage


# Text embedder backed by the same model id that generate_embeddings uses for
# images, so text and image vectors come from one model.
embeddings = BedrockEmbeddings(client=client, region_name=REGION_NAME, model_id='amazon.titan-embed-image-v1')

# Create Faiss vector store
# It is seeded with the filtered text chunks; image vectors are added later.
vectorstore = FAISS.from_texts(texts, embeddings)

In [None]:
# Print the vector dimension and vector count before the image embeddings are added.
dimension = vectorstore.index.d
print(f"Dimension of vectors in the index: {dimension}")
print("Vector counts:", vectorstore.index.ntotal)

In [None]:
# Add the image embeddings. Each entry is an (image path, vector) pair, so the
# image path is what gets stored as the document text for an image.
vectorstore.add_embeddings(images_embeddings)

# Vector count after the images were added.
print("Vector counts:", vectorstore.index.ntotal)

# Retriever over the combined text + image index, used by the RAG chain below.
retriever = vectorstore.as_retriever()

In [None]:
# Disabled maintenance snippet: the code below sits inside a string literal, so
# it is not executed. Move it out of the string to delete vectors.
"""
# Delete vectors
print("count before:", vectorstore.index.ntotal)

for i in range(34):
    vectorstore.delete([vectorstore.index_to_docstore_id[i]])
"""

In [None]:
# Try a similarity search across both text and image entries.
docs_and_scores = vectorstore.similarity_search_with_score("Moses and the Messengers from Canaan")
# Display the results (presumably document/score pairs, per the method name).
docs_and_scores

## Build Multimodal RAG

In [None]:
import base64
import io
from io import BytesIO

import numpy as np
from PIL import Image


def resize_base64_image(image_path, size=(128, 128)):
    """
    Load an image file, resize it, and return the result as a Base64 string.

    Args:
    image_path (str): Path of the image file to load.
    size (tuple): Desired size of the image as (width, height).

    Returns:
    str: Base64 string of the resized image, saved in the source file's format.
    """
    # Open via a context manager so the underlying file handle is released
    # as soon as the pixel data has been read by resize().
    with Image.open(image_path) as img:
        image_format = img.format
        resized_img = img.resize(size, Image.LANCZOS)

    # Save the resized image to an in-memory buffer in the original format
    buffered = io.BytesIO()
    resized_img.save(buffered, format=image_format)

    # Encode the resized image to Base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def is_image(s):
    """Return True when ``s`` is a string ending in ".jpg", else False.

    In this notebook images are stored in the vector store by file path, so a
    ".jpg" suffix is what marks a retrieved document as an image. Any value
    that is not a string yields False.
    """
    return isinstance(s, str) and s.endswith(".jpg")


def split_image_text_types(docs):
    """Separate retrieved documents into resized base64 images and plain texts.

    Each document's ``page_content`` is inspected: image paths are loaded,
    resized to 250x250 and base64-encoded; everything else is kept as text.
    Returns a dict with the keys "images" and "texts".
    """
    buckets = {"images": [], "texts": []}
    for document in docs:
        content = document.page_content
        if not is_image(content):
            buckets["texts"].append(content)
            continue
        # Shrink the image before it is embedded in the model request.
        buckets["images"].append(resize_base64_image(content, size=(250, 250)))
    return buckets

In [None]:
from operator import itemgetter

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough,RunnableParallel


def prompt_func(data_dict):
    """Build the single multimodal human message sent to the model.

    ``data_dict`` carries "context" (a dict with "texts" and "images") and
    "question". The first retrieved image, if any, is placed ahead of the
    instruction text, which embeds the question and the joined context texts.
    Returns a one-element list containing the HumanMessage.
    """
    context = data_dict["context"]
    joined_context = "\n".join(context["texts"])
    retrieved_images = context["images"]

    content_parts = []

    # Only the first retrieved image is attached.
    if retrieved_images:
        content_parts.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{retrieved_images[0]}"
                },
            }
        )

    instructions = (
        "As an expert art critic and historian, your task is to analyze and interpret images, "
        "considering their historical and cultural significance. Alongside the images, you will be "
        "provided with related text to offer context. Both will be retrieved from a vectorstore based "
        "on user-input keywords. Please use your extensive knowledge and analytical skills to provide a "
        "comprehensive summary that includes:\n"
        "- A detailed description of the visual elements in the image.\n"
        "- The historical and cultural context of the image.\n"
        "- An interpretation of the image's symbolism and meaning.\n"
        "- Connections between the image and the related text.\n\n"
        f"User-provided keywords: {data_dict['question']}\n\n"
        "Text and / or tables:\n"
        f"{joined_context}"
    )
    content_parts.append({"type": "text", "text": instructions})

    return [HumanMessage(content=content_parts)]

In [None]:
from IPython.display import HTML, display


def plt_img_base64(img_base64):
    """Show a base64-encoded JPEG inline by rendering an HTML <img> tag."""
    display(HTML(f'<img src="data:image/jpeg;base64,{img_base64}" />'))

In [None]:
from langchain_aws import ChatBedrock

# Using Bedrock Claude3
# Chat model wrapper that reuses the shared Bedrock runtime client.
model = ChatBedrock(
    client=client,
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name=REGION_NAME,
    model_kwargs={"temperature": 0.1, "max_tokens": 1024},
)

# RAG pipeline
# Stage 1 builds a dict from the input query: "context" is the retriever output
# split into images/texts by split_image_text_types, "question" is the query.
# Stage 2 runs two branches on that dict: "response" formats the prompt, calls
# the model and parses the output to a string; "context" forwards the retrieved
# context unchanged so it can be inspected next to the answer.
chain = (
    {
        "context": retriever | RunnableLambda(split_image_text_types),
        "question": RunnablePassthrough(),
    }
    | RunnableParallel({"response":prompt_func| model| StrOutputParser(),
                      "context": itemgetter("context"),})
)

In [None]:
# Retrieve related images and texts then invoke Claude3 to generate answer
response = chain.invoke("Madonna and Child with Two Saints and a Donor")
print(response['response'])

# Retrieval can return text chunks only, so show the first image only when one
# was retrieved; indexing an empty list would raise IndexError.
retrieved_images = response['context']['images']
if retrieved_images:
    plt_img_base64(retrieved_images[0])

In [None]:
# Inspect the retrieved context (base64 images and text chunks) that the chain
# returned alongside the answer.
print(response['context'])