In [1]:
# %% Imports
import os, json, uuid, hashlib, base64, pickle
from base64 import b64decode
from pathlib import Path
from typing import List, Dict
from tqdm import tqdm

from unstructured.partition.pdf import partition_pdf

from langchain.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI as VisionModel
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema.document import Document
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage

In [2]:
# --- CONFIG ---
SOURCE_DIR = Path("source_docs")
HASH_FILE = Path("output/hashes.json")
VSTORE_DIR = Path("output/vectorstore")
DOCSTORE_PATH = Path("output/docstore/docstore.pkl")

# Create necessary directories
VSTORE_DIR.mkdir(parents=True, exist_ok=True)
DOCSTORE_PATH.parent.mkdir(parents=True, exist_ok=True)
HASH_FILE.parent.mkdir(parents=True, exist_ok=True)
EMBEDDINGS = OpenAIEmbeddings()

  EMBEDDINGS = OpenAIEmbeddings()


In [3]:
def get_file_hash(filepath: Path) -> str:
    hasher = hashlib.sha256()
    with open(filepath, "rb") as f:
        hasher.update(f.read())
    return hasher.hexdigest()

def load_hashes(json_path: Path) -> dict:
    if json_path.exists():
        with open(json_path, "r") as f:
            return json.load(f)
    return {}

def save_hashes(hashes: dict, json_path: Path):
    with open(json_path, "w") as f:
        json.dump(hashes, f, indent=2)


In [4]:
def parse_pdf_elements(filepath: str):
    chunks = partition_pdf(
        filename=filepath,
        infer_table_structure=True,
        strategy="hi_res",
        extract_image_block_types=["Image"],
        extract_image_block_to_payload=True,
        chunking_strategy="by_title",
        max_characters=10000,
        combine_text_under_n_chars=2000,
        new_after_n_chars=6000,
    )
    tables, texts, images = [], [], []
    for chunk in chunks:
        if "CompositeElement" in str(type(chunk)):
            for el in chunk.metadata.orig_elements:
                if "Table" in str(type(el)):
                    tables.append(el)
                elif "Image" in str(type(el)):
                    images.append(el.metadata.image_base64)
            texts.append(chunk)
    return texts, tables, images

In [5]:
def get_text_table_chain():
    prompt = ChatPromptTemplate.from_template("""
    You are an assistant tasked with summarizing tables and text.
    Give a concise summary of the table or text.
    Respond only with the summary, no additional comment.
    Table or text chunk: {element}
    """)
    model = ChatOpenAI(temperature=0.5, model="gpt-4.1-mini")
    return {"element": lambda x: x} | prompt | model | StrOutputParser()

In [6]:
def get_image_chain():
    prompt = ChatPromptTemplate.from_messages([
        ("user", [
            {"type": "text", "text": "Describe the image in detail. For context, it's from a trust fund report. Be specific."},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{image}"}}
        ])
    ])
    model = VisionModel(model="gpt-4.1-mini")
    return prompt | model | StrOutputParser()

In [7]:
def load_vectorstore(vstore_dir):
    if not any(Path(vstore_dir).glob("*")):
        return Chroma(collection_name="multi_modal_rag", persist_directory=str(vstore_dir), embedding_function=EMBEDDINGS)
    try:
        return Chroma(collection_name="multi_modal_rag", persist_directory=str(vstore_dir), embedding_function=EMBEDDINGS)
    except Exception:
        import shutil
        shutil.rmtree(vstore_dir, ignore_errors=True)
        return Chroma(collection_name="multi_modal_rag", persist_directory=str(vstore_dir), embedding_function=EMBEDDINGS)

def load_docstore(docstore_path):
    if docstore_path.exists():
        with open(docstore_path, "rb") as f:
            return pickle.load(f)
    return InMemoryStore()

def save_docstore(docstore, docstore_path):
    with open(docstore_path, "wb") as f:
        pickle.dump(docstore, f)

In [8]:
def add_documents_to_retriever(retriever, elements, summaries, filename, id_key="doc_id"):
    if not elements or not summaries or len(elements) == 0 or len(summaries) == 0:
        print(f"⚠️ Skipping empty documents for: {filename} | {id_key}")
        return

    doc_ids = [str(uuid.uuid4()) for _ in elements]
    docs = [
        Document(
            page_content=summaries[i],
            metadata={id_key: doc_ids[i], "source_file": filename},
        )
        for i in range(len(elements))
    ]

    if len(docs) == 0:
        print(f"⚠️ Skipped adding empty document list to retriever for {filename}")
        return

    retriever.vectorstore.add_documents(docs)
    retriever.docstore.mset(list(zip(doc_ids, elements)))


In [9]:
# Load stores
vectorstore = load_vectorstore(VSTORE_DIR)
docstore = load_docstore(DOCSTORE_PATH)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=docstore,
    id_key="doc_id"
)

  return Chroma(collection_name="multi_modal_rag", persist_directory=str(vstore_dir), embedding_function=EMBEDDINGS)


In [10]:
# Load chains
text_table_chain = get_text_table_chain()
image_chain = get_image_chain()
file_hashes = load_hashes(HASH_FILE)

  model = ChatOpenAI(temperature=0.5, model="gpt-4.1-mini")


In [11]:
for filepath in tqdm(list(SOURCE_DIR.glob("*.pdf")), desc="📄 Processing PDFs"):
    file_hash = get_file_hash(filepath)
    if file_hashes.get(filepath.name) == file_hash:
        tqdm.write(f"✅ Skipping (already processed): {filepath.name}")
        continue

    tqdm.write(f"🧩 Processing new or updated: {filepath.name}")
    try:
        texts, tables, images = parse_pdf_elements(str(filepath))
    except Exception as e:
        tqdm.write(f"❌ Failed to parse {filepath.name}: {str(e)}")
        continue

    text_summaries = text_table_chain.batch(texts, {"max_concurrency": 3})
    table_summaries = text_table_chain.batch([t.metadata.text_as_html for t in tables], {"max_concurrency": 3})
    image_summaries = image_chain.batch(images)

    add_documents_to_retriever(retriever, texts, text_summaries, filepath.name)
    add_documents_to_retriever(retriever, tables, table_summaries, filepath.name)
    add_documents_to_retriever(retriever, images, image_summaries, filepath.name)

    file_hashes[filepath.name] = file_hash

📄 Processing PDFs:   0%|          | 0/5 [00:00<?, ?it/s]

📄 Processing PDFs:  40%|████      | 2/5 [00:00<00:00,  6.60it/s]

✅ Skipping (already processed): 2020TrustFundAnnualReports.pdf
✅ Skipping (already processed): 2021TrustFundAnnualReports.pdf


📄 Processing PDFs:  60%|██████    | 3/5 [00:00<00:00,  7.12it/s]

✅ Skipping (already processed): 2022TrustFundAnnualReports.pdf


📄 Processing PDFs: 100%|██████████| 5/5 [00:00<00:00,  6.50it/s]

✅ Skipping (already processed): 2023TrustFundAnnualReports.pdf
✅ Skipping (already processed): 2024TrustFundAnnualReports.pdf





In [12]:
vectorstore.persist()
save_docstore(docstore, DOCSTORE_PATH)
save_hashes(file_hashes, HASH_FILE)
print("✅ Done: Vectorstore and docstore saved.")

  vectorstore.persist()


✅ Done: Vectorstore and docstore saved.


In [13]:
def parse_docs(docs):
    b64, text = [], []
    for doc in docs:
        try:
            b64decode(doc)
            b64.append(doc)
        except Exception:
            text.append(doc)
    return {"images": b64, "texts": text}

def build_prompt(kwargs):
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    context_text = "".join([t.text for t in docs_by_type["texts"]])
    prompt_template = f"""
    Answer the question based only on the following context, which can include text, tables, and the below image.
    Context: {context_text}
    Question: {user_question}
    """

    prompt_content = [{"type": "text", "text": prompt_template.strip()}]
    for image in docs_by_type["images"]:
        prompt_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}})

    return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

In [14]:
def get_mm_rag_chain(retriever):
    return (
        {
            "context": retriever | RunnableLambda(parse_docs),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(build_prompt)
        | ChatOpenAI(model="gpt-4o-mini")
        | StrOutputParser()
    )

def get_mm_rag_chain_with_sources(retriever):
    return {
        "context": retriever | RunnableLambda(parse_docs),
        "question": RunnablePassthrough(),
    } | RunnablePassthrough().assign(
        response=(
            RunnableLambda(build_prompt)
            | ChatOpenAI(model="gpt-4o-mini")
            | StrOutputParser()
        )
    )


In [15]:
chain = get_mm_rag_chain(retriever)
chain_with_sources = get_mm_rag_chain_with_sources(retriever)

In [16]:
question = "tell me about major replenishments in FIFs"
response = chain.invoke(question)
print(f"\n🧠 Answer:\n{response}")


🧠 Answer:
Major replenishments in Financial Intermediary Funds (FIFs) have occurred in chronological order, with specific pledging sessions and amounts associated with each fund. Here’s a summary of the key replenishments:

1. **Global Fund**
   - **Pledging Session:** October 2019
   - **Replenishment Cycle Period:** FY2020–22
   - **Amount:** $14.0 billion
   - **Previous Replenishment:** FY2017–19

2. **Green Climate Fund (GCF)**
   - **Pledging Session:** October 2019
   - **Replenishment Cycle Period:** FY2020–23
   - **Amount:** $10.0 billion
   - **Previous Replenishment:** FY2015–19

3. **Global Agriculture and Food Security Program (GAFSP)**
   - **Pledging Session:** October 2020
   - **Replenishment Cycle Period:** FY2020–25
   - **Amount:** $1.5 billion
   - **Previous Replenishment:** FY2010–20

4. **Global Partnership for Education (GPE)**
   - **Pledging Session:** July 2021
   - **Replenishment Cycle Period:** CY2021–25
   - **Amount:** $4.0 billion
   - **Previous Rep