In [3]:
import base64
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_openai import OpenAIEmbeddings

In [None]:
def create_multi_vector_retriever(vectorstore, text_summaries, texts, tables_summaries, tables, image_summaries, images):
    """
    Create a retriever that indexes summaries but returns the raw content.

    Every summary is embedded into ``vectorstore`` as a ``Document`` whose
    metadata carries a freshly generated id. The matching raw item (text,
    table, or base64 image) is stored in an in-memory docstore under the same
    id, so a similarity hit on a summary yields the original content.

    Args:
        vectorstore: vector store the summary documents are added to.
        text_summaries: summaries of ``texts`` (index-aligned).
        texts: raw text items returned on retrieval.
        tables_summaries: summaries of ``tables`` (index-aligned).
        tables: raw table items returned on retrieval.
        image_summaries: summaries of ``images`` (index-aligned).
        images: raw (base64) images returned on retrieval.

    Returns:
        The populated ``MultiVectorRetriever``.

    Raises:
        ValueError: if a non-empty summaries list and its contents list
            differ in length (ids would otherwise be misaligned).
    """
    store = InMemoryStore()
    id_key = "doc_id"

    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    def _add_documents(doc_summaries, doc_contents):
        # Index summaries in the vector store and raw contents in the docstore,
        # linked by a shared per-item uuid stored under `id_key`.
        if len(doc_summaries) != len(doc_contents):
            raise ValueError(
                f"{len(doc_summaries)} summaries but {len(doc_contents)} contents; "
                "summaries and contents must be index-aligned"
            )
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for summary, doc_id in zip(doc_summaries, doc_ids)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Skip categories with no summaries (e.g. a document without tables).
    if text_summaries:
        _add_documents(text_summaries, texts)
    if tables_summaries:
        _add_documents(tables_summaries, tables)
    if image_summaries:
        _add_documents(image_summaries, images)

    return retriever
    
# Vector store holding the embedded summaries.
# NOTE(review): if the default Chroma client is shared within the process,
# re-running this cell appends the summaries to the "mm_rag" collection a
# second time (duplicate hits) — confirm and reset the collection if needed.
vectorstore = Chroma(
    collection_name="mm_rag",
    embedding_function=OpenAIEmbeddings(),
)

# Summaries are searched; raw texts / tables / base64 images are returned.
retriever_multi_vector_img = create_multi_vector_retriever(
    vectorstore=vectorstore,
    text_summaries=text_summaries,
    texts=Text,
    tables_summaries=tables_summaries,
    tables=Table,
    image_summaries=image_summaries,
    images=img_base64_list,
)

In [None]:
# Inspect the retriever built above (a bare last expression is rendered as the cell output).
retriever_multi_vector_img

In [None]:
import io,re
from IPython.display import HTML,display
from PIL import Image



In [None]:
def plt_img_base64(img_base64):
    """
    Render a base64-encoded image inline in the notebook output.

    Args:
        img_base64: base64 text of the image bytes, without a ``data:`` prefix.
    """
    # A data URL is `data:<mime>;base64,<payload>` and the whole URL must be
    # inside the quoted src attribute; the payload goes after the comma.
    image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'

    display(HTML(image_html))

In [None]:
# Preview one extracted image. `plt_img_base64` is already defined in the
# previous cell; redefining it here would silently shadow that definition.
# NOTE(review): index 1 assumes at least two images were extracted.
plt_img_base64(img_base64_list[1])


In [None]:
def looks_like_base64(sb):
    """
    Cheap syntactic test for base64 text (no decoding is attempted).

    Returns True when ``sb`` is one or more characters from the base64
    alphabet (``A-Z``, ``a-z``, ``0-9``, ``+``, ``/``) followed by at most two
    ``=`` padding characters. Because ``$`` also matches just before a
    trailing newline, a single trailing ``\\n`` is accepted too.
    """
    base64_pattern = "^[A-Za-z0-9+/]+[=]{0,2}$"
    match = re.match(base64_pattern, sb)
    return bool(match)

In [None]:
def is_image_data(b64data):
    """
    Check whether base64 data decodes to a JPEG, PNG, GIF or WebP image.

    Only the leading bytes of the decoded payload are compared against the
    formats' magic numbers.

    Args:
        b64data: base64-encoded ``str`` or ``bytes``.

    Returns:
        True if the decoded header matches a known image signature; False
        otherwise, including when ``b64data`` is not decodable base64.
    """
    image_signatures = (
        b"\xFF\xD8\xFF",                      # jpg
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A",  # png
        b"\x47\x49\x46\x38",                  # gif ("GIF8")
    )
    try:
        # 12 bytes: longest signature above plus the RIFF....WEBP check below.
        header = base64.b64decode(b64data)[:12]
    except (ValueError, TypeError):
        # binascii.Error (a ValueError subclass) / non-ASCII str => bad base64;
        # TypeError => not str/bytes. Anything else is a real bug and propagates.
        return False

    if any(header.startswith(sig) for sig in image_signatures):
        return True

    # WebP is a RIFF container: b"RIFF" <4-byte size> b"WEBP". Matching only
    # b"RIFF" would also accept WAV/AVI payloads.
    return header[:4] == b"RIFF" and header[8:12] == b"WEBP"


In [None]:
def resize_base64_image(base64_string,size=(128,128)):
    """
    Resize a base64-encoded image and return the result re-encoded as base64.

    Args:
        base64_string: base64 text of the image bytes.
        size: target ``(width, height)``; the aspect ratio is not preserved.

    Returns:
        A UTF-8 base64 string of the resized image, saved in the same format
        the source image was decoded as.
    """
    source = Image.open(io.BytesIO(base64.b64decode(base64_string)))
    target = source.resize(size, Image.LANCZOS)

    # Re-encode in the source's own format so the MIME type stays the same.
    out_buffer = io.BytesIO()
    target.save(out_buffer, format=source.format)

    return base64.b64encode(out_buffer.getvalue()).decode("utf-8")

In [None]:
def split_image_text_types(docs, size=(1300, 600)):
    """
    Split retrieved items into base64-encoded images and everything else.

    Args:
        docs: iterable of ``Document`` objects and/or raw values, e.g. the
            output of the multi-vector retriever.
        size: ``(width, height)`` every detected image is resized to before
            being returned. Defaults to ``(1300, 600)``.

    Returns:
        dict with ``"images"`` (resized base64 strings) and ``"texts"``
        (all remaining items), each preserving input order.
    """
    b64_images = []
    texts = []

    for doc in docs:
        # Unwrap Document objects to their raw content.
        if isinstance(doc, Document):
            doc = doc.page_content

        # Cheap regex check first; only then decode to sniff image magic bytes.
        if looks_like_base64(doc) and is_image_data(doc):
            b64_images.append(resize_base64_image(doc, size=size))
        else:
            texts.append(doc)

    return {"images": b64_images, "texts": texts}

In [None]:
def img_prompt_func(data_dict):
    """
    Build the multimodal chat prompt from retrieved context and the question.

    Args:
        data_dict: mapping with ``"context"`` (a dict holding ``"texts"`` and
            ``"images"`` lists) and ``"question"`` (the user query).

    Returns:
        A one-element list holding a ``HumanMessage`` whose content is one
        ``image_url`` part per image followed by a single ``text`` part.
    """
    context = data_dict["context"]
    joined_texts = "\n".join(context["texts"])

    # A falsy "images" entry contributes no image parts.
    content_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img}"},
        }
        for img in (context["images"] or [])
    ]

    prompt_text = (
        "You are a helpful assistant.\n"
        "You will be given a mixed info(s).\n"
        "Use this information to provide relevant information to the user question.\n"
        f"User-provided question: {data_dict['question']}\n\n"
        "Text and/or tables:\n"
        f"{joined_texts}"
    )
    content_parts.append({"type": "text", "text": prompt_text})

    return [HumanMessage(content=content_parts)]

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI


def multi_model_rag_chain(retriever, model_name="gpt-4o"):
    """
    Build a multimodal RAG chain on top of ``retriever``.

    The chain works in five steps:

    1. Retrieve content for the incoming question.
    2. Split the content into images and texts (``split_image_text_types``).
    3. Build a multimodal prompt (``img_prompt_func``).
    4. Call an OpenAI chat model.
    5. Return the reply as a plain string.

    Args:
        retriever: a LangChain retriever whose results may mix base64 images
            and text, e.g. the multi-vector retriever.
        model_name: OpenAI chat model id; must accept image inputs.

    Returns:
        A runnable that takes the question string and returns the answer.
    """
    # temperature=0 minimises sampling randomness; max_tokens caps reply length.
    model = ChatOpenAI(temperature=0, model=model_name, max_tokens=1024)

    # The dict step fans the same input (the question) out to both branches.
    chain = (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )

    return chain

In [None]:
# Build the multimodal RAG chain on top of the multi-vector retriever.
chain_multi_model_rag = multi_model_rag_chain(retriever=retriever_multi_vector_img)


In [None]:
# Inspect the chain's structure (a bare last expression is rendered as the cell output).
chain_multi_model_rag

In [None]:
# Question passed to the RAG chain in the next cell.
query = "whats the paper about ?"

In [None]:
# Run the full pipeline: retrieve raw content via the summary index, build the
# multimodal prompt, and query the OpenAI chat model (makes API calls).
chain_multi_model_rag.invoke(query)