In [None]:
# System dependencies for PDF rendering and OCR (macOS / Homebrew).
# On Debian/Ubuntu use instead: !apt-get install -y poppler-utils tesseract-ocr libmagic-dev
!brew install poppler tesseract libmagic
%pip install -Uq "unstructured[pdf]" pillow lxml
%pip install -Uq chromadb tiktoken
%pip install -Uq langchain langchain-community langchain-openai
%pip install -Uq python-dotenv
%pip install -Uq pdf2image pytesseract

## Extract the data


### Partition PDF tables, text, and images

In [2]:
from unstructured.partition.pdf import partition_pdf
from dotenv import load_dotenv
import os

# Load environment variables (e.g. PDF_FILE_PATH, OPENAI_API_KEY) from a .env file.
load_dotenv()

# PDF to process: PDF_FILE_PATH if set and non-empty, otherwise the bundled sample paper.
file_path = os.getenv("PDF_FILE_PATH") or "./content/attention.pdf"
# Fail early with a clear message instead of an opaque error from deep inside partition_pdf.
if not os.path.isfile(file_path):
    raise FileNotFoundError(
        f"PDF not found at {file_path!r}. Set PDF_FILE_PATH in your .env file "
        "or place the PDF at ./content/attention.pdf."
    )

# Reference: https://docs.unstructured.io/open-source/core-functionality/chunking
chunks = partition_pdf(
    filename=file_path,
    infer_table_structure=True,            # extract tables
    strategy="hi_res",                     # mandatory to infer tables
    extract_image_block_types=["Image", "Table"],   # Add 'Table' to list to extract image of tables
    extract_image_block_to_payload=True,   # if true, will extract base64 for API usage
    chunking_strategy="by_title",          # or 'basic'
    max_characters=10000,                  # defaults to 500
    combine_text_under_n_chars=2000,       # defaults to 0
    new_after_n_chars=6000,                # start a new chunk once this many characters are reached
)

  from .autonotebook import tqdm as notebook_tqdm


### Storing Extracted Data

In [3]:
import os
import base64
from PIL import Image
import io
import json 

# Collect every Table element: tables nested inside CompositeElement chunks (through their
# orig_elements) as well as chunks that are themselves tables.
tables = []

for chunk in chunks:
    if "CompositeElement" in str(type(chunk)):
        # orig_elements may be None (e.g. if original elements were not retained) -- treat as empty
        tables.extend(
            el for el in (chunk.metadata.orig_elements or []) if "Table" in str(type(el))
        )
    elif "Table" in str(type(chunk)):
        tables.append(chunk)

# Output layout: extracted_content/{images,tables,table_images,text}
output_base = "extracted_content"
dirs = {
    name: os.path.join(output_base, name)
    for name in ("images", "tables", "table_images", "text")
}

# Create directories if they don't exist
for dir_path in dirs.values():
    os.makedirs(dir_path, exist_ok=True)

# Function to save base64 image
def save_base64_image(base64_str, output_path):
    """Decode a base64-encoded image and write it to ``output_path`` via PIL.

    PIL picks the output format from the file extension of ``output_path``, so the image
    may be re-encoded (e.g. a JPEG payload saved as ``.png``).

    Args:
        base64_str: Base64 image payload. Falsy values (None, "") are ignored.
        output_path: Destination file path.
    """
    if not base64_str:
        return
    img_data = base64.b64decode(base64_str)
    # The context manager closes the PIL image as soon as it has been saved.
    with Image.open(io.BytesIO(img_data)) as img:
        img.save(output_path)

# Function to extract table metadata
def _to_jsonable(value):
    """Recursively convert ``value`` into data that ``json.dump`` can serialize.

    JSON scalars pass through unchanged, list/tuple/set become lists, dict keys become
    strings, and any other object falls back to ``str(value)``.
    """
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    if isinstance(value, (list, tuple, set)):
        return [_to_jsonable(item) for item in value]
    if isinstance(value, dict):
        return {str(key): _to_jsonable(item) for key, item in value.items()}
    return str(value)


def extract_table_metadata(chunk):
    """Return the public, non-callable attributes of ``chunk.metadata`` as a JSON-safe dict.

    Args:
        chunk: Any element exposing a ``metadata`` object.

    Returns:
        Dict mapping attribute name to a JSON-serializable value. Private attributes
        (leading underscore) and methods are skipped. Nested containers are converted
        element-wise, so a list of element objects no longer breaks ``json.dump``.
    """
    metadata = {}
    for attr in dir(chunk.metadata):
        if attr.startswith("_"):  # Skip private attributes
            continue
        value = getattr(chunk.metadata, attr)
        if callable(value):  # Methods are not metadata
            continue
        metadata[attr] = _to_jsonable(value)
    return metadata

# Save each table's text, HTML, metadata and (when available) rendered image.
for i, table in enumerate(tables):
    table_dir = os.path.join(dirs["tables"], f"table_{i}")
    os.makedirs(table_dir, exist_ok=True)

    # Explicit UTF-8: PDF tables often contain non-ASCII symbols and the platform default may differ.
    with open(os.path.join(table_dir, "content.txt"), "w", encoding="utf-8") as f:
        f.write(table.text)

    # text_as_html may be None (no inferred structure); write an empty file rather than crash.
    with open(os.path.join(table_dir, "content.html"), "w", encoding="utf-8") as f:
        f.write(table.metadata.text_as_html or "")

    metadata = extract_table_metadata(table)
    with open(os.path.join(table_dir, "metadata.json"), "w", encoding="utf-8") as f:
        # default=str is a safety net for any value that is still not JSON-serializable.
        json.dump(metadata, f, indent=2, default=str)

    # save_base64_image ignores a missing (None) payload, so no separate existence check is needed.
    save_base64_image(
        getattr(table.metadata, "image_base64", None),
        os.path.join(dirs["table_images"], f"table_{i}.png"),
    )


# Save the text of each composite chunk and any images nested inside it.
for i, chunk in enumerate(chunks):
    if "CompositeElement" not in str(type(chunk)):
        continue

    with open(os.path.join(dirs["text"], f"text_{i}.txt"), "w", encoding="utf-8") as f:
        f.write(chunk.text)

    # orig_elements may be None -- treat as no nested elements
    for j, el in enumerate(chunk.metadata.orig_elements or []):
        if "Image" in str(type(el)):
            save_base64_image(
                el.metadata.image_base64,
                os.path.join(dirs["images"], f"image_{i}_{j}.png"),
            )

In [None]:
# Inspect the Table elements collected from the chunks
tables

### Extracting tables separately (alternative approach, commented out)

In [5]:
# # Extracting Tables
# elements = partition_pdf(filename=file_path,
#                          infer_table_structure=True,
#                          strategy='hi_res',
#            )

# tables = [el for el in elements if el.category == "Table"]

# print(tables[0].text)
# print(tables[0].metadata.text_as_html)
# print(tables)

### Storing Extracted Data (alternative version, commented out)

In [6]:
# import os
# import base64
# from PIL import Image
# import io
# import json

# # Create output directories
# output_base = "extracted_content"
# dirs = {
#     "images": os.path.join(output_base, "images"),
#     "tables": os.path.join(output_base, "tables"),
#     "table_images": os.path.join(output_base, "table_images"),
#     "text": os.path.join(output_base, "text")
# }

# # Create directories if they don't exist
# for dir_path in dirs.values():
#     os.makedirs(dir_path, exist_ok=True)

# # Function to save base64 image
# def save_base64_image(base64_str, output_path):
#     if base64_str:
#         img_data = base64.b64decode(base64_str)
#         img = Image.open(io.BytesIO(img_data))
#         img.save(output_path)

# # Function to extract table metadata
# def extract_table_metadata(chunk):
#     metadata = {}
#     # Extract all available metadata attributes
#     for attr in dir(chunk.metadata):
#         if not attr.startswith('_'):  # Skip private attributes
#             value = getattr(chunk.metadata, attr)
#             # Convert non-serializable objects to string representation
#             if not isinstance(value, (str, int, float, bool, list, dict, type(None))):
#                 value = str(value)
#             metadata[attr] = value
#     return metadata

# # Save tables
# for i, table in enumerate(tables):
#     table_dir = os.path.join(dirs["tables"], f"table_{i}")
#     os.makedirs(table_dir, exist_ok=True)
    
#     # Save table as text
#     with open(os.path.join(table_dir, "content.txt"), "w") as f:
#         f.write(table.text)
    
#     # Save table HTML
#     with open(os.path.join(table_dir, "content.html"), "w") as f:
#         f.write(table.metadata.text_as_html)
    
#     # Save table metadata
#     metadata = extract_table_metadata(table)
#     with open(os.path.join(table_dir, "metadata.json"), "w") as f:
#         json.dump(metadata, f, indent=2)
    
#     # Save table image if available
#     if hasattr(table.metadata, "image_base64"):
#         save_base64_image(
#             table.metadata.image_base64,
#             os.path.join(dirs["table_images"], f"table_{i}.png")
#         )

# # Save text and images from chunks
# for i, chunk in enumerate(chunks):
#     if "CompositeElement" in str(type(chunk)):
#         # Save text content
#         with open(os.path.join(dirs["text"], f"text_{i}.txt"), "w") as f:
#             f.write(chunk.text)
#         # Check for images in composite elements
#         chunk_els = chunk.metadata.orig_elements
#         for j, el in enumerate(chunk_els):
#             if "Image" in str(type(el)):
#                 save_base64_image(
#                     el.metadata.image_base64,
#                     os.path.join(dirs["images"], f"image_{i}_{j}.png")
#                 )

### Separate extracted elements into tables, text, and images

In [7]:
# Text chunks are the CompositeElement objects produced by by_title chunking.
texts = [chunk for chunk in chunks if "CompositeElement" in str(type(chunk))]

In [None]:
# Inspect the CompositeElement text chunks
texts

In [9]:
# Get the images from the CompositeElement objects
def get_images_base64(chunks):
    """Collect the base64 payloads of Image elements nested in CompositeElement chunks.

    Args:
        chunks: Chunked elements, e.g. as returned by partition_pdf(..., chunking_strategy=...).

    Returns:
        List of base64 strings in chunk order. Images without a payload are skipped, so
        downstream steps (summarization, vectorstore) never receive None.
    """
    images_b64 = []
    for chunk in chunks:
        if "CompositeElement" not in str(type(chunk)):
            continue
        # orig_elements may be None -- treat as no nested elements
        for el in chunk.metadata.orig_elements or []:
            if "Image" in str(type(el)) and el.metadata.image_base64:
                images_b64.append(el.metadata.image_base64)
    return images_b64

# Base64 payloads of the Image elements found inside the composite chunks
images = get_images_base64(chunks)

In [None]:
# Preview only: each entry is a full base64 image payload, which would flood the notebook output.
[f"{img[:40]}... ({len(img):,} chars)" if img else img for img in images]

## Summarize the data

### Table and image summaries

In [11]:

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from dotenv import load_dotenv
import os

# Load environment variables from .env file
load_dotenv()

# The LangChain OpenAI clients read OPENAI_API_KEY from the environment; check it up front.
# (The previous `openai.api_key = ...` raised NameError because `openai` was never imported.)
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY not found in environment variables. Please check your .env file.")

# Table summary prompt
prompt_table = """
You are an assistant tasked with summarizing tables.
Give a concise summary of the table.

Respond only with the summary, no additional comment.
Do not start your message by saying "Here is a summary" or anything like that.
Just give the summary as it is.

Table: {element}
"""
table_prompt = ChatPromptTemplate.from_template(prompt_table)

# Image summary prompt
prompt_image = """
You are an assistant tasked with describing images.
Give a concise description of what you see in the image.

Respond only with the description, no additional comment.
Do not start your message by saying "Here is a description" or anything like that.
Just give the description as it is.

For context, the image is part of a research paper explaining the transformers architecture. 
Be specific about graphs, such as bar plots.
"""

# Create the model and chains
model = ChatOpenAI(temperature=0.5, model="gpt-4o-mini")

# Table summarization chain: each input string is bound to {element}
table_chain = {"element": lambda x: x} | table_prompt | model | StrOutputParser()

# Prefer the HTML rendering (keeps row/column structure); fall back to plain text when absent
# instead of sending the literal string "None" to the model.
tables_html = [table.metadata.text_as_html or table.text for table in tables]
table_summaries = table_chain.batch(tables_html, {"max_concurrency": 3})

# Image summarization: each base64 string fills the {image} slot of the data URL
messages = [
    (
        "user",
        [
            {"type": "text", "text": prompt_image},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,{image}"},
            },
        ],
    )
]
image_prompt = ChatPromptTemplate.from_messages(messages)
image_chain = image_prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()
# Same concurrency cap as the table batch to stay within API rate limits
image_summaries = image_chain.batch(images, {"max_concurrency": 3})

In [None]:
# One summary per table, produced by table_chain.batch
table_summaries

In [None]:
# One description per extracted image, produced by image_chain.batch
image_summaries

In [None]:
# Print the first image description with its line breaks rendered (guarded: the PDF may have no images)
print(image_summaries[0] if image_summaries else "No images were summarized.")

## Load data and summaries to vectorstore

### Create the vectorstore

In [15]:
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever
import os
import uuid

# On-disk location of the Chroma collection
VECTOR_DB_DIR = "./vector_db"
os.makedirs(VECTOR_DB_DIR, exist_ok=True)

# The vectorstore indexes the child content (raw text, table and image summaries).
# With chromadb >= 0.4, Chroma writes to persist_directory itself; the deprecated persist()
# calls (which only emitted warnings) are therefore omitted.
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=OpenAIEmbeddings(),
    persist_directory=VECTOR_DB_DIR,
)

# The parent documents live only in the InMemoryStore below. Embeddings persisted by an
# earlier run would point at doc_ids that no longer resolve, and re-running this cell would
# duplicate every entry, so start each run from an empty collection.
stale_ids = vectorstore.get()["ids"]
if stale_ids:
    vectorstore.delete(ids=stale_ids)

# The storage layer for the parent documents
store = InMemoryStore()
id_key = "doc_id"

# The retriever searches the top 10 child hits and returns their parent documents
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
    search_kwargs={"k": 10},
)


def add_to_retriever(mv_retriever, page_contents, originals):
    """Embed ``page_contents`` and store the aligned ``originals`` as their parent documents.

    Args:
        mv_retriever: The MultiVectorRetriever to populate.
        page_contents: Strings to embed (raw text or summaries), aligned index-by-index
            with ``originals``.
        originals: Parent objects returned at query time (elements or base64 image strings).

    Pairs with empty content or a None/"" original are skipped, so every indexed
    embedding resolves to a stored parent document.
    """
    def _missing(obj):
        return obj is None or (isinstance(obj, str) and not obj)

    entries = [
        (str(uuid.uuid4()), content, original)
        for content, original in zip(page_contents, originals)
        if content and not _missing(original)
    ]
    if not entries:
        return
    mv_retriever.vectorstore.add_documents(
        [Document(page_content=content, metadata={id_key: doc_id}) for doc_id, content, _ in entries]
    )
    mv_retriever.docstore.mset([(doc_id, original) for doc_id, _, original in entries])


# Texts are indexed as-is; tables and images are indexed by their summaries
add_to_retriever(retriever, [text.text for text in texts], texts)
add_to_retriever(retriever, table_summaries, tables)
add_to_retriever(retriever, image_summaries, images)

  vectorstore = Chroma(
  retriever.vectorstore.persist()


## RAG pipeline

In [None]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from base64 import b64decode

def parse_docs(docs):
    """Split retrieved documents into base64 images and text documents.

    In this pipeline the docstore holds images as base64 strings and text/table content
    as element objects. Any non-empty string that decodes as strict base64 is therefore
    treated as an image, whatever its format (not only JPEG). Everything else, including
    non-base64 strings, is treated as text.

    Args:
        docs: Documents returned by the retriever.

    Returns:
        {"images": [base64 strings], "texts": [other documents]}, each in retrieval order.
    """
    def _is_base64(doc):
        if not isinstance(doc, str) or not doc:
            return False
        try:
            b64decode(doc, validate=True)
        except ValueError:  # binascii.Error subclasses ValueError
            return False
        return True

    images = []
    texts = []
    for doc in docs:
        (images if _is_base64(doc) else texts).append(doc)

    print(f"Retrieved {len(images)} images and {len(texts)} text documents")
    return {"images": images, "texts": texts}

def build_prompt(kwargs):
    """Build the multimodal chat prompt for the answer model.

    Args:
        kwargs: Mapping with "context" -- the dict returned by parse_docs
            ({"images": [...], "texts": [...]}) -- and "question", the user's question.

    Returns:
        ChatPromptTemplate holding a system message and a human message. The human message
        content is the text prompt followed by one high-detail image_url part per image.
    """
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]
    has_images = bool(docs_by_type["images"])

    def _doc_text(doc):
        # Elements carry their content in .text; plain strings are used as-is
        # (`doc.text` would raise AttributeError for them).
        return doc if isinstance(doc, str) else getattr(doc, "text", str(doc))

    def _image_mime(b64):
        # Recognize common formats from the base64 encoding of their magic bytes;
        # fall back to JPEG, the type previously assumed for every image.
        for prefix, mime in (
            ("/9j/", "image/jpeg"),
            ("iVBOR", "image/png"),
            ("R0lG", "image/gif"),
            ("UklG", "image/webp"),
        ):
            if b64.startswith(prefix):
                return mime
        return "image/jpeg"

    # Build context sections
    text_context = ""
    if docs_by_type["texts"]:
        text_context = "\n\nText Content:\n" + "\n".join(
            f"- {_doc_text(doc)}" for doc in docs_by_type["texts"]
        )

    # Assemble the prompt line by line so no source-code indentation leaks into it
    source_kind = "text and images" if has_images else "text"
    prompt_text = "\n".join([
        "Please answer the question based on the following context, which includes "
        f"text{' and images' if has_images else ''}.",
        text_context,
        "",
        f"Question: {user_question}",
        "",
        f"Please provide a detailed answer, referencing specific details from the {source_kind} where relevant.",
    ])

    prompt_content = [{"type": "text", "text": prompt_text}]

    # Add images with high detail setting
    for image in docs_by_type["images"]:
        prompt_content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:{_image_mime(image)};base64,{image}",
                "detail": "high",
            },
        })

    return ChatPromptTemplate.from_messages([
        SystemMessage(content="You are a helpful assistant that can analyze both text and images. When referencing images, be specific about visual details and explain their relevance to the question."),
        HumanMessage(content=prompt_content),
    ])

# End-to-end RAG chain: the retriever returns parent documents (element objects or base64
# image strings, as stored in the docstore), parse_docs splits them by type, build_prompt
# assembles the multimodal prompt, and the model reply is parsed into a plain string.
chain = (
    {
        "context": retriever | RunnableLambda(parse_docs),  # retrieved docs split into images/texts
        "question": RunnablePassthrough(),                  # the question string passed to invoke()
    }
    | RunnableLambda(build_prompt)
    | ChatOpenAI(model="gpt-4o-mini")
    | StrOutputParser()
)

# Example usage
# NOTE(review): this question does not appear to match the default ./content/attention.pdf;
# adjust it to the document configured via PDF_FILE_PATH.
question = "Arts and culture of berlin"
response = chain.invoke(question)
print(response)

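### Formatting the Answer

The helper below is commented out: it calls `chain_with_sources`, which is not defined above, so it must be adapted to `chain` before it can run.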

In [17]:
# from IPython.display import Image, display
# import base64

# def display_base64_image(base64_code):
#     """Display a base64 encoded image in Jupyter notebook"""
#     image_data = base64.b64decode(base64_code)
#     # Create IPython Image object and display it
#     display(Image(data=image_data))
    
    
# def display_rag_response(question):
#     """
#     Display RAG response in a formatted way with question, answer, and context
    
#     Args:
#         question (str): The question to ask the RAG system
#     """
#     # Execute chain and get response
#     response = chain_with_sources.invoke(question)
    
#     # Print formatted output
#     print("="*50)
#     print("🤔 Question:", question)
#     print("="*50)
#     print("\n📝 Response:", response['response'])
#     print("\n" + "="*50)
    
#     # Display text context
#     print("\n📚 Context:")
#     print("-"*50)
#     for i, text in enumerate(response['context']['texts'], 1):
#         print(f"\nText Source {i}:")
#         print("-"*20)
#         print(text.text)
#         print("\nPage:", text.metadata.page_number)
#         print("-"*50)
    
#     # Display images
#     if response['context']['images']:
#         print("\n🖼️ Image Sources:")
#         print("-"*50)
#         for i, image in enumerate(response['context']['images'], 1):
#             print(f"\nImage {i}:")
#             display_base64_image(image)
#             print("-"*50)
