Need a short video of Bugs Bunny the cartoon character saying “That’s All Folks”

In [21]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to local modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

- Load Vector Database
- Run Multi-Modal Chain
- Display Answer and Relevant Docs

In [22]:
import os
import fitz 
import io
import requests
import io
import re
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from IPython.display import HTML, display
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from PIL import Image
import json
import base64
from langchain_core.messages import HumanMessage
from PIL import Image,ImageFile
from DocumentLoader import DocumentChunk,DocumentExtract
ImageFile.LOAD_TRUNCATED_IMAGES = True
from langchain.vectorstores import FAISS
from langchain.chat_models import AzureChatOpenAI
from langchain.vectorstores import Chroma
import uuid
from langchain.embeddings import AzureOpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.documents import Document

In [23]:
# Azure OpenAI connection settings, presumably picked up from the environment
# by the Azure clients created later (they are not passed explicitly).
# Both values are blank placeholders: fill them in locally before running and
# never commit real endpoints or keys.
# NOTE: the endpoint line previously opened a triple-quoted string (`"""`),
# which swallowed the next line and made the whole cell a SyntaxError.
os.environ["AZURE_OPENAI_ENDPOINT"] = ""
os.environ["AZURE_OPENAI_API_KEY"] = ""

In [24]:
# Settings for the Azure-hosted embedding deployment
_embedding_settings = {
    "azure_deployment": "text-embedding-ada-002",
    "openai_api_version": "2023-05-15",
}

# Embedding client; later handed to the vector store as its embedding function
embed_model = AzureOpenAIEmbeddings(**_embedding_settings)

In [25]:
def plt_img_base64(img_base64):
    """
    Display a base64-encoded image in the notebook output.

    The string is embedded in an HTML ``<img>`` tag through a
    ``data:image/jpeg;base64,...`` URI (the MIME type is always ``image/jpeg``,
    whatever the real image format is) and rendered with IPython's ``display``.
    Nothing is returned.
    """
    markup = '<img src="data:image/jpeg;base64,{}" />'.format(img_base64)
    display(HTML(markup))


def looks_like_base64(sb):
    """
    Tell whether ``sb`` has the textual shape of base64 data.

    The string must consist of one or more characters from the standard
    base64 alphabet (letters, digits, ``+``, ``/``) followed by at most two
    ``=`` padding characters. Only the shape is checked; the string is not
    decoded, and its length is not required to be a multiple of four.
    """
    base64_shape = re.compile("^[A-Za-z0-9+/]+[=]{0,2}$")
    return bool(base64_shape.match(sb))


def is_image_data(b64data):
    """
    Check whether base64 data decodes to a supported image format.

    The payload is decoded and its leading bytes are compared against the
    magic numbers of JPEG, PNG, GIF and WebP.

    Args:
        b64data: Base64-encoded payload (``str`` or bytes-like).

    Returns:
        ``True`` when the decoded bytes start with a known image signature,
        ``False`` otherwise, including when the input cannot be decoded.
    """
    # Formats identified purely by a fixed prefix
    prefix_signatures = (
        b"\xFF\xD8\xFF",  # jpg
        b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A",  # png
        b"\x47\x49\x46\x38",  # gif
    )
    try:
        # The whole payload is decoded (not just a prefix) so that malformed
        # base64 is rejected here rather than later when the image is opened.
        header = base64.b64decode(b64data)[:12]
    except Exception:
        # Best-effort check: anything that cannot be decoded is not an image
        return False

    if header.startswith(prefix_signatures):
        return True

    # WebP lives in a RIFF container: "RIFF" + 4-byte size + "WEBP".
    # Matching "RIFF" alone would also accept WAV/AVI and other RIFF files.
    return header[:4] == b"\x52\x49\x46\x46" and header[8:12] == b"WEBP"


def resize_base64_image(base64_string, size=(128, 128)):
    """
    Resize a base64-encoded image and return it base64-encoded again.

    The image is decoded, resized to exactly ``size`` (width, height) with the
    LANCZOS filter, and saved in the same format the source image was detected
    as. The result is returned as a UTF-8 ``str``.
    """
    source = Image.open(io.BytesIO(base64.b64decode(base64_string)))

    # Write the resized copy into memory, keeping the source image's format
    out = io.BytesIO()
    source.resize(size, Image.LANCZOS).save(out, format=source.format)

    return base64.b64encode(out.getvalue()).decode("utf-8")


def split_image_text_types(docs):
    """
    Partition retrieved items into base64 images and plain texts.

    ``Document`` instances are unwrapped to their ``page_content`` first. An
    item that both looks like base64 and decodes to a known image signature is
    resized to 1300x600 and collected under ``"images"``; everything else goes
    under ``"texts"`` unchanged. Input order is preserved within each list.
    """
    images, plain = [], []
    for entry in docs:
        payload = entry.page_content if isinstance(entry, Document) else entry
        if not (looks_like_base64(payload) and is_image_data(payload)):
            plain.append(payload)
            continue
        images.append(resize_base64_image(payload, size=(1300, 600)))
    return {"images": images, "texts": plain}


def img_prompt_func(data_dict):
    """
    Build the single multi-part human message sent to the model.

    ``data_dict`` carries the user's ``"question"`` and a ``"context"`` mapping
    with ``"images"`` (base64 strings) and ``"texts"``. Each image becomes an
    ``image_url`` part (always labelled ``image/jpeg``), followed by one text
    part holding the instructions, the question and the newline-joined texts.
    Returns a one-element list containing the ``HumanMessage``.
    """
    context = data_dict["context"]
    formatted_texts = "\n".join(context["texts"])

    # One image part per retrieved image; a falsy value yields no parts
    parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
        }
        for encoded in (context["images"] or ())
    ]

    # Instructions, question and textual context go last, after any images
    instructions = (
        "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
        "You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n"
        "Make sure to be concise and to the point in your answer\n"
        f"User-provided question: {data_dict['question']}\n\n"
        "Text and / or tables:\n"
        f"{formatted_texts}"
    )
    parts.append({"type": "text", "text": instructions})

    return [HumanMessage(content=parts)]


def multi_modal_rag_chain(retriever):
    """
    Assemble the multi-modal RAG pipeline around ``retriever``.

    The returned runnable takes the user's question, passes it to the
    retriever, splits the retrieved items into images and texts, builds one
    multi-part prompt, calls the chat model and parses the reply to a string.
    """
    # Chat model served from the "imageanalysis" Azure deployment
    llm = AzureChatOpenAI(
        temperature=0,
        model="gpt-4-vision-preview",
        max_tokens=1024,
        openai_api_version="2023-07-01-preview",
        azure_deployment="imageanalysis",
    )

    # The question feeds both branches: retrieval for context, and pass-through
    context_step = retriever | RunnableLambda(split_image_text_types)
    chain_inputs = {"context": context_step, "question": RunnablePassthrough()}

    return chain_inputs | RunnableLambda(img_prompt_func) | llm | StrOutputParser()


## Load Retrieval 

In [26]:
# Name of the Chroma collection to open; left blank here. It is also the stem
# of the JSON file read below.
collection_name = ''
id_key = "doc_id"

# Vector index persisted under ./chroma_db, using embed_model for embeddings
vectorstore = Chroma(
    persist_directory="./chroma_db",
    collection_name=collection_name,
    embedding_function=embed_model,
)
store = InMemoryStore()

# Create the multi-vector retriever over the vector index and the docstore
retriever = MultiVectorRetriever(
    docstore=store,
    vectorstore=vectorstore,
    id_key=id_key,
)

# Replace the in-memory docstore's contents with the mapping saved as JSON
with open(f'chroma_db/{collection_name}.json', "r") as json_file:
    loaded_data = json.load(json_file)
retriever.docstore.store = loaded_data

## Query Retrieval

In [27]:
# Question to ask; left blank here — set it before running the cells below.
query = ""

In [28]:
# Build the multi-modal RAG chain on top of the retriever.
chain_multi_model = multi_modal_rag_chain(retriever)

In [29]:
# Query the retriever directly (outside the chain) so the retrieved documents
# can be inspected on their own.
docs = retriever.get_relevant_documents(query)

In [None]:
# Run the full chain on the query; the chain ends in a string output parser,
# so the cell's value is the answer text.
chain_multi_model.invoke(query)

In [32]:
# Number of documents the retriever returned for the query.
len(docs)

4