In [None]:
###################################
############# LOADER ##############
###################################

In [None]:
!wget https://sgp.fas.org/crs/misc/IF10244.pdf

In [None]:
from langchain_community.document_loaders import UnstructuredPDFLoader
from unstructured.partition.utils.constants import PartitionStrategy


# Configure the PDF loader for the downloaded report. Later cells rely on two of these
# options: table elements expose `text_as_html` metadata, and extracted images are read
# back from ./figures.
loader = UnstructuredPDFLoader(
    file_path="./IF10244.pdf",
    strategy=PartitionStrategy.HI_RES,
    infer_table_structure=True,
    extract_images_in_pdf=True,
    chunking_strategy="by_title",
    new_after_n_chars=4000,  # Soft-max
    max_characters=4000,  # Hard-max
    combine_text_under_n_chars=2000,  # Combine chunks of < 2000 chars
    mode='elements',  # Split the documents into elements such as Title and NarrativeText.
)
# Run the loader; `data` is the list of element documents inspected in the next cells.
data = loader.load()

In [None]:
len(data)

In [None]:
data[0]

In [None]:
data[0].page_content

In [None]:
data[0].metadata["category"]

In [None]:
[doc.metadata['category'] for doc in data]

In [None]:
data[2]  # Table

In [None]:
data[2].page_content

In [None]:
data[2].metadata['text_as_html']

In [None]:
from IPython.display import display, Markdown

display(Markdown(data[2].metadata['text_as_html']))

In [None]:
### Split data into text and tables list
from htmltabletomd import convert_table

# Bucket the loaded elements by category. Composite text chunks are kept unchanged;
# table elements get their page_content replaced by convert_table() applied to the
# table's HTML before being collected.
text = []
tables = []

for doc in data:
    category = doc.metadata['category']
    if category == 'Table':
        doc.page_content = convert_table(doc.metadata['text_as_html'])
        tables.append(doc)
    elif category == 'CompositeElement':
        text.append(doc)

print(f"Total number of text records: {len(text)}")
print(f"Total number of table records: {len(tables)}")

In [None]:
from pprint import pprint

pprint(tables[0].page_content)

In [None]:
!ls -ltrh

In [None]:
!ls -ltrh ./figures

In [None]:
from IPython.display import Image

display(Image('./figures/figure-1-2.jpg'))

In [None]:
###################################
############ SUMMARY ##############
###################################

In [None]:
from getpass import getpass

OPENAI_API_KEY = getpass('Enter OpenAI Key: ')

In [None]:
from langchain_openai import ChatOpenAI

CHAT_MODEL = ChatOpenAI(model_name='gpt-4o-mini', api_key=OPENAI_API_KEY, temperature=0)

In [None]:
# Summarize TEXT and TABLES

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough

summarize_prompt_text = """
You are an assistant summarizing tables or text for retrieval purposes. Create a concise, detailed summary 
optimized for retrieval. If it's a table, include a brief description of its content along with the summary. 
Do not add extra labels like "Summary."

Content to summarize:
{context}
"""
# Build the summarization pipeline: the raw input is passed through as {context},
# rendered into the prompt, sent to the chat model, and reduced to a plain string.
summarize_prompt = ChatPromptTemplate.from_template(summarize_prompt_text)

summarize_chain = (
    {"context": RunnablePassthrough()}
    | summarize_prompt
    | CHAT_MODEL
    | StrOutputParser()
)

# Raw page contents to summarize (tables were already converted in the split cell).
text_docs = [chunk.page_content for chunk in text]
table_docs = [tbl.page_content for tbl in tables]

# Summarize each list in one batch call, with max_concurrency set to 5.
text_summaries = summarize_chain.batch(text_docs, config={"max_concurrency": 5})
table_summaries = summarize_chain.batch(table_docs, config={"max_concurrency": 5})

In [None]:
text_summaries

In [None]:
table_summaries

In [None]:
# Summarize Image

In [None]:
import base64
import os
from langchain_core.messages import HumanMessage

def image_to_base64(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def summarize_image(base64_image, prompt_text):
    """
    Ask CHAT_MODEL to describe one image and return the content of its reply.

    The request is a single human message holding two parts: the instruction text
    and the image, passed inline as a ``data:image/jpeg;base64,...`` URL.
    """
    text_part = {"type": "text", "text": prompt_text}
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
    }
    message = HumanMessage(content=[text_part, image_part])
    reply = CHAT_MODEL.invoke([message])
    return reply.content

def generate_image_summary(image_folder):
    """
    Encode and summarize every ``.jpg`` file found directly in ``image_folder``.

    Files are processed in sorted filename order. Returns a pair of parallel lists
    ``(base64_images, summaries)``: the base64 string of each image and the summary
    produced for it by ``summarize_image``.
    """
    summary_prompt = """
    You are tasked with summarizing images for retrieval. 
    The images may contain charts, tables, or graphs. 
    Create a detailed summary optimized for retrieval without extra labels like 'Summary:'
    """

    # Only files with a ".jpg" suffix are considered; everything else is ignored.
    jpg_names = [name for name in sorted(os.listdir(image_folder)) if name.endswith(".jpg")]

    base64_images = []
    summaries = []
    for name in jpg_names:
        encoded_image = image_to_base64(os.path.join(image_folder, name))
        base64_images.append(encoded_image)
        summaries.append(summarize_image(encoded_image, summary_prompt))

    return base64_images, summaries

encoded_images, img_summaries = generate_image_summary('./figures')

In [None]:
img_summaries[1]

In [None]:
display(Image('./figures/figure-1-2.jpg'))

In [None]:
###################################
########### RETRIEVER #############
###################################

In [None]:
from langchain_openai import OpenAIEmbeddings


OPENAI_EMBEDDING_MODEL = OpenAIEmbeddings(model='text-embedding-ada-002', api_key=OPENAI_API_KEY)

In [None]:
from langchain_chroma import Chroma
from langchain_community.storage import RedisStore
from langchain_community.utilities.redis import get_client

# Vector index for the generated summaries (add_documents writes summary Documents
# here); the collection is configured with cosine as its HNSW space.
vectorstore = Chroma(
    collection_name='OSM-21-Oct-2024',
    embedding_function=OPENAI_EMBEDDING_MODEL,
    collection_metadata={"hnsw:space": "cosine"},
)

# Key-value store for the full contents behind each summary. This expects a Redis
# server reachable at localhost:6379.
docstore = RedisStore(client=get_client('redis://localhost:6379'))

In [None]:
from langchain.retrievers.multi_vector import MultiVectorRetriever

# Pair the summary index with the raw-content store. add_documents links the two by
# writing the same generated ID into each summary's `retriever.id_key` metadata and
# as the docstore key.
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=docstore,
)

In [None]:
import uuid

from langchain_core.documents import Document

def add_documents(retriever, summaries, contents):
    """
    Index summaries in the vector store and store the matching full contents.

    Every (summary, content) pair shares one freshly generated UUID: the summary is
    added to ``retriever.vectorstore`` as a Document whose ``retriever.id_key``
    metadata holds that UUID, and the content is written to ``retriever.docstore``
    under the same UUID.

    Args:
        retriever: object exposing ``vectorstore``, ``docstore`` and ``id_key``.
        summaries: iterable of summary strings, one per content item.
        contents: iterable of raw payloads, aligned index-by-index with ``summaries``.

    Raises:
        ValueError: if ``summaries`` and ``contents`` differ in length. Pairing them
            anyway would either fail midway or store contents that no summary
            points to.
    """
    # Materialize once so generators are accepted and the lengths can be compared.
    summaries = list(summaries)
    contents = list(contents)

    if len(summaries) != len(contents):
        raise ValueError(
            f"summaries and contents must have the same length "
            f"(got {len(summaries)} and {len(contents)})"
        )

    # Nothing to index (e.g. no images were extracted): do not call the stores
    # with an empty batch.
    if not contents:
        return

    # Generate unique IDs for each document
    document_ids = [str(uuid.uuid4()) for _ in contents]

    # Create documents using summaries and their associated IDs
    summarized_docs = [
        Document(page_content=summary, metadata={retriever.id_key: doc_id})
        for doc_id, summary in zip(document_ids, summaries)
    ]

    # Add summarized documents to the vector store
    retriever.vectorstore.add_documents(summarized_docs)

    # Map document IDs to their full contents in the docstore
    retriever.docstore.mset(list(zip(document_ids, contents)))

In [None]:
add_documents(retriever, text_summaries, text_docs)

In [None]:
add_documents(retriever, table_summaries, table_docs)

In [None]:
add_documents(retriever, img_summaries, encoded_images)

In [None]:
vectorstore.get(include=["metadatas", "documents", "embeddings"])

In [None]:
vectorstore._collection.count()

In [None]:
##### Test Multi-Vector Retriever #####

In [None]:
query = "Which year has the highest acres burned?"
# NOTE(review): `limit=5` is presumably not honoured here. MultiVectorRetriever takes
# its result count from the `search_kwargs` it was constructed with, and extra invoke()
# kwargs appear to be dropped. TODO confirm, and use search_kwargs={"k": 5} if five
# results are intended.
docs = retriever.invoke(query, limit=5)

In [None]:
len(docs), docs[0]

In [None]:
# Display Image

from IPython.display import HTML, display
from PIL import Image
import base64
from io import BytesIO

def display_base64_image(img_base64):
    """Decode a base64-encoded image and render it in the notebook output."""
    raw_bytes = base64.b64decode(img_base64)
    # PIL opens images from file-like objects, so wrap the bytes in an in-memory buffer.
    picture = Image.open(BytesIO(raw_bytes))
    display(picture)

In [None]:
display_base64_image(docs[0])

In [None]:
###################################
########### Synthesis #############
###################################

In [None]:
import re
import base64

def is_base64_encoded(string):
    """
    Return True when ``string`` is shaped like base64 text: one or more characters
    from the base64 alphabet followed by at most two ``=`` padding characters.
    """
    found = re.match("^[A-Za-z0-9+/]+[=]{0,2}$", string)
    if found is None:
        return False
    return True

def is_valid_image_data(encoded_data):
    """
    Report whether base64 ``encoded_data`` decodes to bytes that start with one of
    the recognised image signatures listed below.

    Any failure while decoding yields False instead of raising.
    """
    signatures = (
        b"\xff\xd8\xff",  # jpg
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",  # png
        b"\x47\x49\x46\x38",  # gif
        b"\x52\x49\x46\x46",  # webp
    )
    try:
        # Only the first 8 decoded bytes are needed to compare against the signatures.
        leading_bytes = base64.b64decode(encoded_data)[:8]
    except Exception:
        return False
    return leading_bytes.startswith(signatures)

def categorize_images_and_texts(documents):
    """
    Separate base64-encoded images and text elements (including tables) from documents.

    Each item may be a ``Document`` (its ``page_content`` is used) or a raw payload.
    Bytes-like payloads are decoded as UTF-8; payloads that are already text are used
    as they are, so the function works whether the docstore returns bytes or strings.

    Args:
        documents: iterable of retrieved items.

    Returns:
        dict with two lists: ``"images"`` (base64 strings recognised as image data)
        and ``"texts"`` (everything else, as strings).
    """
    images = []
    text_elements = []

    for doc in documents:
        # Unwrap Document objects; anything else is treated as the payload itself.
        payload = doc.page_content if isinstance(doc, Document) else doc

        # Only bytes-like payloads have .decode(); calling it on a str (which is what
        # Document.page_content holds) would raise AttributeError.
        if isinstance(payload, (bytes, bytearray)):
            content = payload.decode('utf-8')
        else:
            content = str(payload)

        # Classify as image or text
        if is_base64_encoded(content) and is_valid_image_data(content):
            images.append(content)
        else:
            text_elements.append(content)

    return {"images": images, "texts": text_elements}

In [None]:
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.messages import HumanMessage

def generate_multimodal_prompt(data):
    """
    Build the single human message sent to the chat model for a RAG answer.

    ``data`` must provide ``data["context"]["texts"]``, ``data["context"]["images"]``
    and ``data["question"]``. The message content lists one ``image_url`` part per
    image (each passed as an inline JPEG data URL), followed by one text part holding
    the instructions, the user question, and the newline-joined text context.

    Returns a one-element list containing that ``HumanMessage``.
    """
    context = data["context"]

    # Format text data
    combined_texts = "\n".join(context["texts"])

    # One part per retrieved image; a missing/empty image list contributes nothing.
    image_parts = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}
        for img in (context["images"] or [])
    ]

    # The text part always comes last, after any images.
    instruction_part = {
        "type": "text",
        "text": (
            f"""You are an analyst tasked with interpreting detailed information 
                and identifying trends from text documents, tables, and charts or graphs in images. 
                Below is the context, which includes a combination of text, tables, and images, often in the form of charts or graphs.
                Use this data to answer the user’s question without fabricating information. 
                Rely on the given context to respond accurately.
                
                User question:
                {data['question']}
                
                Context documents:
                {combined_texts}
                
                Answer:
            """
        ),
    }

    return [HumanMessage(content=image_parts + [instruction_part])]

In [None]:
# Answering step: take 'context' and 'input' from the incoming dict, build the
# multimodal prompt, call the chat model, and return the reply as a string.
multimodal_rag_chain = (
    {"context": itemgetter('context'), "question": itemgetter('input')}
    | RunnableLambda(generate_multimodal_prompt)
    | CHAT_MODEL
    | StrOutputParser()
)

# Retrieval step: take the question from 'input', fetch the stored documents, and
# split them into images and texts.
document_retriever = (
    itemgetter('input')
    | retriever
    | RunnableLambda(categorize_images_and_texts)
)

# Full pipeline: keep the input dict, add 'context' from retrieval, then add 'answer'.
mm_rag = RunnablePassthrough.assign(context=document_retriever).assign(
    answer=multimodal_rag_chain
)

In [None]:
query = "Which year has the highest acres burned?"
response = mm_rag.invoke({"input": query})

In [None]:
def format_response(resp):
    """
    Render an ``mm_rag`` response in the notebook.

    Shows the question (``resp['input']``), the answer (``resp['answer']``), then
    the retrieved sources from ``resp['context']``: texts/tables printed one per
    line and each image displayed via ``display_base64_image``.
    """
    # Question and answer share the same "heading, then printed value" layout.
    for heading, key in (("#### Input", 'input'), ("#### Output", 'answer')):
        display(Markdown(heading))
        print(resp[key])

    display(Markdown("#### Source Documents"))

    display(Markdown("##### Text & Tables"))
    source_texts = resp['context'].get('texts', [])
    print('\n'.join(source_texts))

    display(Markdown("##### Images"))
    source_images = resp['context'].get('images', [])
    for encoded in source_images:
        display_base64_image(encoded)

In [None]:
format_response(response)

In [None]:
query = "Give me the number of fires and the acres burned for the department of the interior in the year 2020"
response = mm_rag.invoke({"input": query})
format_response(response)

In [None]:
format_response(response)

In [None]:
query = "Tell me about the number of acres burned by wildfires for the forest service in 2021"
response = mm_rag.invoke({"input": query})
format_response(response)