In [15]:
import uuid
import base64
from IPython.display import Image, display
from unstructured.partition.pdf import partition_pdf
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser, Document
from langchain_milvus import Milvus
from langchain.storage import InMemoryStore
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph

In [16]:
def extract_pdf_content(file_path):
    chunks = partition_pdf(
        filename=file_path,
        infer_table_structure=True,
        strategy="hi_res",
        extract_image_block_types=["Image", "Table"],
        extract_image_block_to_payload=True,
        chunking_strategy="by_title",
        max_characters=10000,
        combine_text_under_n_chars=2000,
        new_after_n_chars=6000,
    )
    texts, tables, images = [], [], []
    for chunk in chunks:
        if "CompositeElement" in str(type(chunk)):
            texts.append(chunk.text)
            for el in chunk.metadata.orig_elements:
                if "Table" in str(type(el)):
                    tables.append(el.metadata.text_as_html)
                elif "Image" in str(type(el)):
                    images.append(el.metadata.image_base64)
    return texts, tables, images

In [17]:
def summarize_content(texts, tables, images):
    model = ChatOpenAI(model_name="gpt-4-turbo", temperature=0.5)
    
    prompt_text = """You are an assistant summarizing text and tables. Provide a concise summary of the given content:
    {element}"""
    
    prompt = ChatPromptTemplate.from_template(prompt_text)
    summarize_chain = prompt | model | StrOutputParser()

    text_summaries = summarize_chain.batch(texts, {"max_concurrency": 3})
    table_summaries = summarize_chain.batch(tables, {"max_concurrency": 3})
    
    image_prompt = ChatPromptTemplate.from_messages([
        ("user", [
            {"type": "text", "text": "Describe the image in detail."},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{image}"}},
        ])
    ])
    chain = image_prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()
    image_summaries = chain.batch(images)
    
    return text_summaries, table_summaries, image_summaries

In [18]:
def store_vectors(texts, text_summaries, tables, table_summaries, images, image_summaries):
    # Initialize Milvus and InMemoryStore
    vectorstore = Milvus(
        collection_name="multi_modal_rag",
        embedding_function=OpenAIEmbeddings(),
        connection_args={"uri": "http://localhost:19530"},
    )
    store = InMemoryStore()
    retriever = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key="doc_id")

    # Helper function to add data to the store
    def add_to_store(data, summaries):
        ids = [str(uuid.uuid4()) for _ in data]
        documents = [Document(page_content=summaries[i], metadata={"doc_id": ids[i], "id": ids[i]}) for i in range(len(data))]
        retriever.vectorstore.add_documents(documents, ids=ids)
        retriever.docstore.mset(list(zip(ids, data)))

    # Store text, tables, and images
    add_to_store(texts, text_summaries)
    add_to_store(tables, table_summaries)
    add_to_store(images, image_summaries)

    return retriever

In [19]:
def parse_docs(docs):
    b64, text = [], []
    for doc in docs:
        if isinstance(doc, Document):
            if "image_base64" in doc.metadata:  # Check if it's an image
                b64.append(doc.metadata["image_base64"])
            else:
                text.append(doc)
    return {"images": b64, "texts": text}

In [20]:
def display_base64_image(base64_code):
    if base64_code:
        image_data = base64.b64decode(base64_code)
        display(Image(data=image_data))

In [21]:
def build_prompt(kwargs):
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]
    context_text = "".join([text_element.page_content for text_element in docs_by_type["texts"]])

    prompt_content = [{"type": "text", "text": f"""
        Answer the question based only on the following context:
        Context: {context_text}
        Question: {user_question}
    """}]

    # Add images to the prompt
    if docs_by_type["images"]:
        for image in docs_by_type["images"]:
            prompt_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}})
    
    return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

In [30]:
class ChatbotState:
    def __init__(self, context=None, question=None, response=None, generated_prompt=None):
        self.context = context
        self.question = question
        self.response = response
        self.generated_prompt = generated_prompt


graph = StateGraph(ChatbotState)

# Node 1: Retrieve context from retriever
graph.add_node("retrieve", lambda state: ChatbotState(
    context={"docs": retriever.invoke(state.question)}, 
    question=state.question
))

# Node 2: Parse retrieved documents
graph.add_node("parse", lambda state: ChatbotState(
    context=state.context, 
    question=state.question,
    response=state.response,
    generated_prompt=parse_docs(state.context["docs"])["parsed_context"]
))

# Node 3: Build the prompt for LLM
graph.add_node("prompt", lambda state: ChatbotState(
    context=state.context, 
    question=state.question,
    response=state.response,
    generated_prompt=build_prompt({"context": state.prompt, "question": state.question})
))

# Node 4: Generate response using GPT-4o-mini
graph.add_node("generate", lambda state: ChatbotState(
    context=state.context, 
    question=state.question, 
    response=ChatOpenAI(model="gpt-4o-mini").invoke(state.prompt).text  # Fixed to use 'state.prompt'
))

# Define graph flow
graph.add_edge("retrieve", "parse")
graph.add_edge("parse", "prompt")
graph.add_edge("prompt", "generate")
graph.set_entry_point("retrieve")

# Compile chatbot
chatbot = graph.compile()

In [25]:
file_path = "Attention-is-all-you-need.pdf"
texts, tables, images = extract_pdf_content(file_path)
text_summaries, table_summaries, image_summaries = summarize_content(texts, tables, images)
retriever = store_vectors(texts, text_summaries, tables, table_summaries, images, image_summaries)

In [31]:
response = chatbot.invoke(ChatbotState(question="Explain the multihead attention?"))

# Print text response
print("Response:", response.response)

# Display images if available
if response.context.get("images"):
    for image in response.context["images"]:
        display_base64_image(image)

InvalidUpdateError: Must write to at least one of []