# Multimodal RAG for PDF files

This example builds on [Sudarshan Koirala's work](https://github.com/sudarshan-koirala/youtube-stuffs/blob/main/langchain/LangChain_Multi_modal_RAG.ipynb) to create a multimodal retrieval-augmented generation (RAG) pipeline for PDF files that contain tables, images and text paragraphs.  
The vector database used here is Amazon OpenSearch Serverless (AOSS).  
Refer to the [public documentation](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-vector-search.html) to set this up.  
Below is the data ingestion pipeline.  

![data ingestion](./diagrams/multimodal-rag.drawio.png)

Below is the inference pipeline.

![inference](./diagrams/multimodal-rag-inference.drawio.png)

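Install the prerequisites. `accelerate` and `bitsandbytes` are needed to load the image captioning model in 8-bit.

In [None]:
%pip install "unstructured[all-docs]" transformers torch accelerate bitsandbytes pillow opensearch-py boto3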


Load the sample PDF files.

In [1]:
import os
data_dir = './data'
# Only PDF files in data_dir are processed
target_files = [os.path.join(data_dir, file_name) for file_name in sorted(os.listdir(data_dir)) if file_name.lower().endswith('.pdf')]
image_output_dir = 'data-output'

Initialize the Bedrock runtime client and the prompt used to summarize tables and text.

In [2]:
import boto3
import json

bedrock_runtime_client = boto3.client('bedrock-runtime',region_name='us-west-2')

summary_prompt = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element} """

Define helper functions to invoke Bedrock foundation models (FMs):  
`invoke_model` uses Amazon Titan Embeddings to convert text into vector embeddings for search.  
`invoke_llm_model` uses Claude to summarize tables and to generate the final answer returned to the user.

In [3]:
def invoke_model(input):
    """Return the Titan text embedding (a list of floats) for the input string."""
    response = bedrock_runtime_client.invoke_model(
        body=json.dumps({
            'inputText': input
        }),
        modelId="amazon.titan-embed-text-v1",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body.get("embedding")

def invoke_llm_model(input):
    """Send the input to Claude v2.1 in the Text Completions format and return the completion text."""
    response = bedrock_runtime_client.invoke_model(
        body=json.dumps({
            "prompt": "\n\nHuman: {input}\n\nAssistant:".format(input=input),
            "max_tokens_to_sample": 300,
            "temperature": 0.5,
            "top_k": 250,
            "top_p": 1,
            "stop_sequences": [
                "\n\nHuman:"
            ],
            # "anthropic_version": "bedrock-2023-05-31"
        }),
        modelId="anthropic.claude-v2:1",
        accept="application/json",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())
    return response_body.get("completion")
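`invoke_llm_model` calls `anthropic.claude-v2:1` through the legacy Text Completions format (`prompt` in, `completion` out). Claude 3 and later models on Bedrock require the Messages API instead. The following is a minimal sketch of building a Messages request body and reading the response. It assumes the Bedrock Messages request shape (`anthropic_version: "bedrock-2023-05-31"`, `max_tokens`, `messages`). The helper names and the model ID in the comment are illustrative and are not part of the original notebook.

```python
import json


def build_messages_body(prompt, max_tokens=300, temperature=0.5):
    """Build a Bedrock request body in the Anthropic Messages format (hypothetical helper)."""
    return json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ],
    })


def parse_messages_response(response_body):
    """Concatenate the text blocks of an already JSON-decoded Messages API response."""
    return "".join(
        block["text"]
        for block in response_body.get("content", [])
        if block.get("type") == "text"
    )


# Hypothetical usage with the client defined above (the model ID is only an example):
# response = bedrock_runtime_client.invoke_model(
#     body=build_messages_body("Summarize this table: ..."),
#     modelId="anthropic.claude-3-haiku-20240307-v1:0",
#     accept="application/json",
#     contentType="application/json",
# )
# text = parse_messages_response(json.loads(response["body"].read()))
```

A Claude 3 model ID could then replace `anthropic.claude-v2:1` without changing the rest of the pipeline. Only the request body and response parsing differ.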

The BLIP-2 model (`Salesforce/blip2-opt-2.7b`) is used for image captioning. It is loaded in 8-bit, which requires a CUDA GPU.

In [None]:
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

# Loading in 8-bit (bitsandbytes) requires a CUDA GPU; device_map={"": 0} places the model on GPU 0
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA GPU is required to load BLIP-2 in 8-bit")
device = "cuda"

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
caption_model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
)
prompt_caption = "Question: What is in the image? Be specific about graphs, such as bar plots. Answer:"

def generate_image_captions(image_path, prompt):
    # Convert to RGB because images extracted from PDFs may be grayscale or CMYK
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device, dtype=torch.float16)
    generated_ids = caption_model.generate(**inputs, max_new_tokens=50)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return generated_text




The `unstructured` Python package is used to extract the tables, images and text from each PDF file.  
Extracted images are written to the directory passed as the `image_output_dir_path` argument of `partition_pdf`.  
Tables are summarized with Claude on Bedrock. Both the raw table text and its summary are stored.  
Images are described with the BLIP-2 captioning model.  
Text is split into chunks by section title (`chunking_strategy="by_title"`, at most 4000 characters per chunk) and stored without further processing.

In [None]:
from unstructured.partition.pdf import partition_pdf
from unstructured.documents.elements import CompositeElement, Table

extracted_elements_list = []
for target_file in target_files:
    # One image output folder per PDF, named after the file (e.g. data-output/wildfires)
    file_stem = os.path.splitext(os.path.basename(target_file))[0]
    image_output_dir_path = os.path.join(image_output_dir, file_stem)
    table_and_text_elements = partition_pdf(
        filename=target_file,
        extract_images_in_pdf=True,
        infer_table_structure=True,
        chunking_strategy="by_title",  # Uses title elements to identify sections within the document for chunking
        max_characters=4000,  # hard maximum size of a chunk
        new_after_n_chars=3800,  # start a new chunk once this size is reached
        combine_text_under_n_chars=2000,  # merge small sections into larger chunks
        image_output_dir_path=image_output_dir_path,
    )
    tables = []
    texts = []
    for element in table_and_text_elements:
        # Tables are summarized with Claude; both the raw text and the summary are kept
        if isinstance(element, Table):
            tables.append({'raw': str(element), 'summary': invoke_llm_model(summary_prompt.format(element=str(element)))})
        # CompositeElement is the text chunk type produced by chunking_strategy="by_title"
        elif isinstance(element, CompositeElement):
            texts.append(str(element))
    image_captions = []
    # Skip captioning if no image folder was created for this file
    if os.path.isdir(image_output_dir_path):
        for image_file in sorted(os.listdir(image_output_dir_path)):
            image_captions.append(generate_image_captions(os.path.join(image_output_dir_path, image_file), prompt_caption))

    extracted_elements_list.append({
        'source': target_file,
        'tables': tables,
        'texts': texts,
        'images': image_captions
    })
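When `infer_table_structure=True` is set, `unstructured` also stores an HTML rendering of each table in `element.metadata.text_as_html`. This keeps the row and column boundaries that `str(element)` flattens. The following is a minimal sketch of keeping both forms and summarizing the HTML when it is available. It assumes elements expose `.metadata.text_as_html`, which may be `None`. `make_table_record` and `FakeTable` are hypothetical names used only for illustration, and `summarize` stands in for a call such as `invoke_llm_model`.

```python
from types import SimpleNamespace


def make_table_record(element, summarize):
    """Return the raw text, HTML rendering (or None) and summary of a table element."""
    raw_text = str(element)
    html = getattr(element.metadata, "text_as_html", None)
    # Summarize the HTML when available because it keeps the row/column structure
    summary = summarize(html or raw_text)
    return {"raw": raw_text, "html": html, "summary": summary}


class FakeTable:
    """Hypothetical stand-in for an unstructured Table element, for illustration only."""

    def __init__(self, text, html=None):
        self.text = text
        self.metadata = SimpleNamespace(text_as_html=html)

    def __str__(self):
        return self.text


record = make_table_record(
    FakeTable("Year Acres 2020 10", "<table><tr><td>Year</td><td>Acres</td></tr></table>"),
    summarize=lambda s: f"summary of {len(s)} characters",
)
print(record["summary"])
```

In the partitioning loop, this could be used as `tables.append(make_table_record(element, lambda s: invoke_llm_model(summary_prompt.format(element=s))))`. The `raw` and `summary` keys stay compatible with the document preparation step.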

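The `opensearch-py` package (imported as `opensearchpy`) is used to connect to the AOSS collection. Replace `host` with your collection endpoint (without `https://`) and `index` with your index name.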

In [10]:
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
host = 'your collection id.region.aoss.amazonaws.com'
region = 'us-west-2'
service = 'aoss'
index = 'your collection index'
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region, service)

ospy_client = OpenSearch(
    hosts = [{'host': host, 'port': 443}],
    http_auth = auth,
    use_ssl = True,
    verify_certs = True,
    connection_class = RequestsHttpConnection,
    pool_maxsize = 20
)
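The k-NN query at the end of this notebook requires `processed_element_embedding` to be mapped as a `knn_vector` field. The notebook assumes the index already exists. The following is a minimal sketch of creating it. It assumes 1536-dimensional vectors, which is the output size of `amazon.titan-embed-text-v1`, and the FAISS HNSW method. `knn_index_body` and `create_knn_index` are hypothetical helpers, and the method settings should be checked against the AOSS vector search documentation before use.

```python
def knn_index_body(dimension=1536):
    """Index settings and mappings for the fields written by prep_document."""
    return {
        "settings": {"index.knn": True},
        "mappings": {
            "properties": {
                # Vector field searched by the k-NN query
                "processed_element_embedding": {
                    "type": "knn_vector",
                    "dimension": dimension,
                    "method": {"name": "hnsw", "engine": "faiss", "space_type": "l2"},
                },
                "processed_element": {"type": "text"},
                "raw_element": {"type": "text"},
                "raw_element_type": {"type": "keyword"},
                "src_doc": {"type": "keyword"},
            }
        },
    }


def create_knn_index(client, index_name, dimension=1536):
    """Create the index if it does not exist yet; return True if it was created."""
    if client.indices.exists(index=index_name):
        return False
    client.indices.create(index=index_name, body=knn_index_body(dimension))
    return True


# Usage with the client defined above:
# create_knn_index(ospy_client, index)
```

Mapping `raw_element_type` and `src_doc` as `keyword` fields would allow exact-match filtering, for example restricting a search to tables from one PDF.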

Define a function that prepares documents for insertion into the AOSS index. Each document has the following fields:

- `src_doc`: the original document (e.g. the path to `wildfires.pdf`)
- `raw_element`: the table, image or text chunk extracted by the `unstructured` package; for images this is the caption, since the image itself is not stored
- `raw_element_type` (passed as `doc_type`): one of `table`, `image` or `text`
- `processed_element`: the post-processed form of `raw_element` that is embedded, i.e. the table summary, the text chunk or the image caption
- `processed_element_embedding`: the vector embedding of `processed_element` produced by the embedding model

In [11]:
def prep_document(embedding,raw_element,processed_element,doc_type,src_doc):
    document = { 
        # "_id": str(hash(raw_element)),
        "processed_element_embedding": embedding,
        "processed_element": processed_element,
        "raw_element_type": doc_type,
        "raw_element": raw_element,
        "src_doc": src_doc
    }
    return document

Generate embeddings for the text chunks, table summaries and image captions, and create one document per element.

In [None]:
documents = []
for extracted_element in extracted_elements_list:
    texts = extracted_element['texts']
    tables = extracted_element['tables']
    image_captions = extracted_element['images']
    src_doc = extracted_element['source']
    for text in texts:
        embedding = invoke_model(text)
        document = prep_document(embedding,text,text,'text',src_doc)
        documents.append(document)
    for table in tables:
        table_raw = table['raw']
        table_summary = table['summary']
        embedding = invoke_model(table_summary)
        document = prep_document(embedding,table_raw,table_summary,'table',src_doc)
        documents.append(document)
    for image_caption in image_captions:
        embedding = invoke_model(image_caption)
        document = prep_document(embedding,image_caption,image_caption,'image',src_doc)
        documents.append(document)

Insert the documents into AOSS one at a time. Bulk indexing is not used here because of a [bug](https://repost.aws/zh-Hant/questions/QUxXol2_SQRb-7iYoouyjl8A/questions/QUxXol2_SQRb-7iYoouyjl8A/aws-opensearch-serverless-bulk-api-document-id-is-not-supported-in-create-index-operation-request?) where the AOSS Bulk API rejects document IDs in create and index operations.

In [None]:
for doc in documents:
    response = ospy_client.index(
        index = index,
        body = doc,
    )

Ask a question.  
Generate an embedding for the question.  
Run a k-nearest-neighbor (k-NN) search against AOSS to retrieve the most relevant documents.  
Here, `processed_element` is used as context, but `raw_element` could be used instead, e.g. the raw table instead of its summary.  
Pass these documents as context to Claude and obtain the answer.  
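An exact k-nearest-neighbor search ranks every stored embedding by its distance to the question embedding and keeps the `k` closest. OpenSearch typically approximates this ranking with an HNSW graph, so its results can differ slightly from the exact ones. The following is a minimal pure-Python sketch of the exact version. It assumes Euclidean (L2) distance, the OpenSearch k-NN default space type. `exact_knn` is a hypothetical helper for illustration only, and the toy 2-dimensional vectors stand in for 1536-dimensional Titan embeddings.

```python
import heapq
import math


def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def exact_knn(query_vector, documents, k=4):
    """Return the k documents whose embeddings are closest to query_vector."""
    return heapq.nsmallest(
        k,
        documents,
        key=lambda doc: l2_distance(query_vector, doc["processed_element_embedding"]),
    )


# Toy documents shaped like the ones built by prep_document
docs = [
    {"processed_element": "table summary", "processed_element_embedding": [0.0, 1.0]},
    {"processed_element": "image caption", "processed_element_embedding": [1.0, 0.0]},
    {"processed_element": "text chunk", "processed_element_embedding": [0.9, 0.1]},
]
top = exact_knn([1.0, 0.0], docs, k=2)
print([doc["processed_element"] for doc in top])  # ['image caption', 'text chunk']
```

Brute force scales linearly with the number of stored documents. This is why a vector index such as AOSS is used once the corpus grows beyond a few PDFs.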

In [12]:
question = 'your query'
embedding = invoke_model(question)
# k is the number of neighbors returned by each shard's k-NN search; size caps the total
# number of hits (default 10). Setting both to k returns k results in total.
k = 4
query = {
    "size": k,
    "query": {
        "knn": {
            "processed_element_embedding": {
                "vector": embedding,
                "k": k,
            }
        }
    },
}

response = ospy_client.search(
    body = query,
    index = index
)

hits = response['hits']['hits']
prompt_template = """
    The following is a friendly conversation between a human and an AI.
    The AI is talkative and provides lots of specific details from its context.
    If the AI does not know the answer to a question, it truthfully says it
    does not know.
    {context}
    Instruction: Based on the above documents, provide a detailed answer to the question: {question}
    Answer "don't know" if the answer is not present in the documents.
    Solution:"""
context = []
for hit in hits:
    context.append(hit['_source']['processed_element'])


llm_prompt = prompt_template.format(context='\n'.join(context),question=question)
output = invoke_llm_model(llm_prompt)
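As noted above, the `raw_element` field can be passed to Claude instead of `processed_element`, for example the original table text rather than its summary. The following is a minimal sketch of choosing the context field per element type and labeling each snippet with its source file. `build_context` is a hypothetical helper, and it assumes each hit's `_source` has the fields written by `prep_document`.

```python
def build_context(hits, raw_types=("table",)):
    """Join retrieved snippets into one context string.

    For element types listed in raw_types, the raw element is used instead of
    the processed one (e.g. the table text instead of its summary).
    """
    snippets = []
    for hit in hits:
        source = hit["_source"]
        if source["raw_element_type"] in raw_types:
            text = source["raw_element"]
        else:
            text = source["processed_element"]
        # Label each snippet so the answer can be traced back to its PDF
        snippets.append(f"[{source['raw_element_type']} from {source['src_doc']}]\n{text}")
    return "\n\n".join(snippets)


# Usage with the search results above:
# llm_prompt = prompt_template.format(context=build_context(hits), question=question)
```

Raw tables give Claude the exact figures, but they are longer than summaries, so `k` or `max_tokens_to_sample` may need adjusting.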

In [None]:
print(output)