## Setup

Before we begin, make sure you have set up the `.env` file in the project
directory as described in [`README.md`](README.md).

Next, we load the environment variables this notebook needs (e.g., API keys):

In [10]:
import os
from dotenv import load_dotenv

_ = load_dotenv()

assert os.environ.get("GOOGLE_API_KEY")
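
A bare `assert` fails without saying which variable is missing. As a minimal sketch, assuming a descriptive error is wanted, the check could be wrapped in a small helper. The name `require_env` is hypothetical and not part of `python-dotenv`:

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or raise a descriptive error.

    Hypothetical helper for illustration; call it after load_dotenv().
    """
    value = os.environ.get(name)
    if not value:
        # An empty string is treated the same as a missing variable
        raise RuntimeError(
            f"Environment variable {name} is not set; "
            "check the .env file in the project directory."
        )
    return value
```

In this notebook it would be called as `require_env("GOOGLE_API_KEY")` right after `load_dotenv()`.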

## Data Ingestion

### Partitioning and Chunking

In [4]:
from unstructured.ingest.v2.pipeline.pipeline import Pipeline
from unstructured.ingest.v2.interfaces import ProcessorConfig
from unstructured.ingest.v2.processes.connectors.local import (
    LocalIndexerConfig,
    LocalDownloaderConfig,
    LocalConnectionConfig,
    LocalUploaderConfig
)
from unstructured.ingest.v2.processes.partitioner import PartitionerConfig
from unstructured.ingest.v2.processes.chunker import ChunkerConfig


ingest_pipeline = Pipeline.from_configs(
    context=ProcessorConfig(
        verbose=True,
        tqdm=True,
    ),
    indexer_config=LocalIndexerConfig(
        input_path="data/input-docs/LLaVA-subset.pdf"
    ),
    downloader_config=LocalDownloaderConfig(),
    source_connection_config=LocalConnectionConfig(),
    partitioner_config=PartitionerConfig(
        partition_by_api=False,
        strategy="hi_res",        
        additional_partition_args={
            "languages": ["eng"],            
            "extract_images_in_pdf": True,
            "extract_image_block_types": ["Image", "Table"],
            "extract_image_block_output_dir": "data/ingest-output/images",
        }
    ),
    chunker_config=ChunkerConfig(
        chunking_strategy="by_title",
        # Chunking params to aggregate text blocks:
        # - hard cap of 4000 chars per chunk
        # - attempt to start a new chunk after 3800 chars
        # - attempt to combine small sections until a chunk reaches 2000 chars
        chunk_max_characters=4000,
        chunk_new_after_n_chars=3800,
        chunk_combine_text_under_n_chars=2000
    ),
    uploader_config=LocalUploaderConfig(
        output_dir="data/ingest-output/"
    ),
)

2024-07-15 18:09:44,323 MainProcess INFO     Created index with configs: {"input_path": "data/input-docs/LLaVA-subset.pdf", "recursive": false, "file_glob": null}, connection configs: {"access_config": {}}
2024-07-15 18:09:44,324 MainProcess INFO     Created download with configs: {"download_dir": null}, connection configs: {"access_config": {}}
2024-07-15 18:09:44,325 MainProcess INFO     Created partition with configs: {"strategy": "hi_res", "ocr_languages": null, "encoding": null, "additional_partition_args": {"languages": ["eng"], "extract_images_in_pdf": true, "extract_image_block_types": ["Image", "Table"], "extract_image_block_output_dir": "data/ingest-output/images"}, "skip_infer_table_types": null, "fields_include": ["element_id", "text", "type", "metadata", "embeddings"], "flatten_metadata": false, "metadata_exclude": [], "metadata_include": [], "partition_endpoint": "https://api.unstructured.io/general/v0/general", "partition_by_api": false, "api_key": null, "hi_res_model_na
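
To make the hard and soft size limits in `ChunkerConfig` more concrete, here is a simplified toy packer. It is not how `unstructured` implements `by_title` chunking: it ignores section titles, does not split a single oversized block, and leaves out `chunk_combine_text_under_n_chars` entirely. The function name `toy_chunk` and its signature are hypothetical and exist only for this illustration.

```python
def toy_chunk(
    blocks: list[str],
    max_characters: int = 4000,
    new_after_n_chars: int = 3800,
) -> list[str]:
    """Greedily pack text blocks into chunks (toy illustration only)."""
    chunks: list[str] = []
    current = ""
    for block in blocks:
        candidate = f"{current}\n\n{block}" if current else block
        if current and len(candidate) > max_characters:
            # Adding this block would exceed the hard cap: close the chunk first
            chunks.append(current)
            current = block
        else:
            current = candidate
        if len(current) >= new_after_n_chars:
            # Soft limit reached: close the chunk even though the cap allows more
            chunks.append(current)
            current = ""
    if current:
        chunks.append(current)
    return chunks


# Five blocks of 1500 chars cannot all fit under the 4000-char hard cap
demo_chunks = toy_chunk(["a" * 1500] * 5)
print([len(chunk) for chunk in demo_chunks])
```

The gap between the soft limit (3800) and the hard cap (4000) leaves room to finish a chunk at a block boundary rather than cutting text mid-block.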

In [5]:
print(ingest_pipeline)

index (LocalIndexer) -> download (LocalDownloader) -> partition (hi_res) -> chunk (by_title) -> upload (LocalUploader)


In [6]:
ingest_pipeline.run()

2024-07-15 18:09:47,048 MainProcess INFO     Running local pipline: index (LocalIndexer) -> download (LocalDownloader) -> partition (hi_res) -> chunk (by_title) -> upload (LocalUploader) with configs: {"reprocess": false, "verbose": true, "tqdm": true, "work_dir": "/Users/tclee/.cache/unstructured/ingest/pipeline", "num_processes": 2, "max_connections": null, "raise_on_error": false, "disable_parallelism": false, "preserve_downloads": false, "download_only": false, "max_docs": null, "re_download": false, "uncompress": false, "status": {}, "semaphore": null}
2024-07-15 18:09:47,153 MainProcess DEBUG    Generated file data: FileData(identifier='/Users/tclee/notebooks/multi_modal_rag/data/input-docs/LLaVA-subset.pdf', connector_type='local', source_identifiers=SourceIdentifiers(filename='LLaVA-subset.pdf', fullpath='/Users/tclee/notebooks/multi_modal_rag/data/input-docs/LLaVA-subset.pdf', rel_path='LLaVA-subset.pdf'), doc_type=<IndexDocType.FILE: 'file'>, metadata=DataSourceMetadata(url=N

In [1]:
from unstructured.staging.base import elements_from_json

elements = elements_from_json(
    filename="data/ingest-output/LLaVA-subset.pdf.json"
)

In [2]:
from collections import Counter

display(
    Counter(
        type(element) 
        for element 
        in elements
    )
)

Counter({unstructured.documents.elements.CompositeElement: 17,
         unstructured.documents.elements.Table: 4})

In [3]:
# Categorize elements by type
def categorize_elements(raw_pdf_elements):
    """
    Categorize extracted elements from a PDF into texts and tables.

    raw_pdf_elements: List of unstructured.documents.elements
    Returns a tuple (texts, tables) of element text strings.
    """
    tables, texts = ([], [])
    
    for element in raw_pdf_elements:
        if element.category == "Table":
            tables.append(element.text)
        elif element.category == "CompositeElement":
            texts.append(element.text)
            
    return texts, tables

In [4]:
# Get text, tables
texts, tables = categorize_elements(elements)

In [5]:
len(tables)

4

In [6]:
len(texts)

17

## Multi-vector retriever

Use the [multi-vector retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary) to index summaries of images (and/or texts and tables), while retrieving the raw images (along with the raw texts or tables).
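
The idea can be sketched without any library: summaries are what gets searched, and a shared ID maps each summary back to its raw element. The toy below is an assumption-laden illustration, not LangChain's implementation. The class `ToySummaryIndex` is hypothetical, and word overlap stands in for embedding similarity.

```python
import uuid


def tokenize(text: str) -> set[str]:
    """Lowercase whitespace tokenizer; a stand-in for an embedding model."""
    return set(text.lower().split())


class ToySummaryIndex:
    """Search over summaries, return the raw elements they point to."""

    def __init__(self) -> None:
        self.summary_tokens: dict[str, set[str]] = {}  # doc_id -> summary tokens
        self.docstore: dict[str, str] = {}  # doc_id -> raw element

    def add(self, summary: str, raw: str) -> str:
        # The shared doc_id links the searchable summary to the raw element
        doc_id = str(uuid.uuid4())
        self.summary_tokens[doc_id] = tokenize(summary)
        self.docstore[doc_id] = raw
        return doc_id

    def retrieve(self, query: str, k: int = 1) -> list[str]:
        query_tokens = tokenize(query)
        ranked = sorted(
            self.summary_tokens,
            key=lambda doc_id: len(query_tokens & self.summary_tokens[doc_id]),
            reverse=True,
        )
        # Return raw elements, never the summaries themselves
        return [self.docstore[doc_id] for doc_id in ranked[:k]]


index = ToySummaryIndex()
index.add(summary="table of benchmark accuracy results", raw="<raw table text>")
index.add(summary="text describing the model architecture", raw="<raw text chunk>")
print(index.retrieve("benchmark accuracy"))
```

Searching over concise summaries while returning the raw element lets the answering model see the full table or passage rather than a lossy summary.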

### Text and Table summaries

We'll use **Gemini 1.5 Flash** to produce table and text summaries.

The summaries are embedded and used to retrieve the raw tables and/or raw chunks of text.

In [9]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI


# Generate summaries of text and table elements
def generate_summaries(texts, tables):
    """
    Summarize text and table elements for retrieval.
    texts: List of str
    tables: List of str
    Returns a tuple (text_summaries, table_summaries).
    """

    # Prompt
    prompt_text = (
        "You are an assistant tasked with summarizing tables "
        "and text for retrieval. These summaries will be embedded "
        "and used to retrieve the raw text or table elements. "
        "Give a concise summary of the table or text that is "
        "well optimized for retrieval.\n\n"
        "Table or text:\n"
        "{element}"
    )
    prompt = ChatPromptTemplate.from_template(
        prompt_text
    )

    # Text summary chain
    model = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash", 
        temperature=0
    )    
    summarize_chain = (
        {"element": lambda x: x} 
        | prompt 
        | model 
        | StrOutputParser()
    )

    # Summarize texts and tables, running at most 5 requests concurrently

    text_summaries = summarize_chain.batch(
        inputs=texts, 
        config={
            "max_concurrency": 5
        }
    )

    table_summaries = summarize_chain.batch(
        inputs=tables, 
        config={
            "max_concurrency": 5
        }
    )    

    return (text_summaries, table_summaries)

In [17]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI


# Generate summaries of text or table elements
def generate_summaries(
    texts: list[str]
) -> list[str]:
    """
    Summarize text or table elements for retrieval.
    texts: List of str (raw text chunks or table text)
    """

    # Prompt
    prompt_text = (
        "You are an assistant tasked with summarizing tables "
        "and text for retrieval. These summaries will be embedded "
        "and used to retrieve the raw text or table elements. "
        "Give a concise summary of the table or text that is "
        "well optimized for retrieval.\n\n"
        "Table or text:\n"
        "{element}"
    )
    prompt = ChatPromptTemplate.from_template(
        template=prompt_text
    )

    # Text summary chain
    model = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash", 
        temperature=0
    )    
    summarize_chain = (
        {"element": lambda x: x} 
        | prompt 
        | model 
        | StrOutputParser()
    )

    # Summarize the inputs, running at most 5 requests concurrently

    summaries = summarize_chain.batch(
        inputs=texts, 
        config={
            "max_concurrency": 5
        }
    )

    return summaries

In [18]:
# The Gemini API free tier is limited to 15 RPM (requests per minute),
# so the 17 text chunks are summarized in two batches (15 + 2)
text_summaries_1 = generate_summaries(
    texts[:15]
)

In [24]:
text_summaries_2 = generate_summaries(
    texts[15:]
)

In [29]:
table_summaries = generate_summaries(
    tables    
)

In [37]:
text_summaries = text_summaries_1 + text_summaries_2
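
The manual split into `texts[:15]` and `texts[15:]` could be generalized with a small batching helper. This is a minimal sketch under two assumptions: each element costs one API request, and a 60-second pause between batches is enough to stay under the limit. The helper `run_in_rate_limited_batches` is hypothetical.

```python
import time
from typing import Callable, Sequence


def run_in_rate_limited_batches(
    items: Sequence[str],
    process_batch: Callable[[list[str]], list[str]],
    batch_size: int = 15,
    pause_seconds: float = 60.0,
    sleep: Callable[[float], None] = time.sleep,
) -> list[str]:
    """Process items in batches, pausing between batches (not after the last)."""
    results: list[str] = []
    for start in range(0, len(items), batch_size):
        if start > 0:
            # Wait before every batch except the first (assumed rate-limit window)
            sleep(pause_seconds)
        batch = list(items[start:start + batch_size])
        results.extend(process_batch(batch))
    return results
```

With the function defined in this notebook, the call would look like `text_summaries = run_in_rate_limited_batches(texts, generate_summaries)`. The `sleep` parameter is injectable so the helper can be tested without actually waiting.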

### Save summaries to JSON

Save the text, table, and image summaries to JSON files. Next time, we can load the summaries from these files instead of calling the LLM API again, which is expensive.

In [43]:
import json

def write_to_json(summaries: list[str], json_path: str):
    """
    Saves list of summaries to a JSON file.
    """
    with open(json_path, mode='w', encoding='utf-8') as f:
        json.dump(summaries, f, ensure_ascii=False, indent=4)


def read_from_json(json_path: str) -> list[str]:
    """
    Returns a list of summaries from a JSON file.
    """
    with open(json_path, mode='r', encoding='utf-8') as f:
        return json.load(f)

In [41]:
# write_to_json(
#     summaries=text_summaries,
#     json_path='data/summaries/text_summaries.json'
# )

In [34]:
# write_to_json(
#     summaries=table_summaries,
#     json_path='data/summaries/table_summaries.json'
# )

In [48]:
text_summaries = read_from_json('data/summaries/text_summaries.json')

In [49]:
table_summaries = read_from_json('data/summaries/table_summaries.json')
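
Instead of commenting the write calls in and out by hand, the save and load steps could be combined into one cache-style helper. The sketch below is one possible approach, not the notebook's method. The name `load_or_generate` is hypothetical, and it assumes the summaries are a JSON-serializable list of strings.

```python
import json
from pathlib import Path
from typing import Callable


def load_or_generate(
    json_path: str,
    generate: Callable[[], list[str]],
) -> list[str]:
    """Load summaries from json_path if it exists; otherwise generate and save them."""
    path = Path(json_path)
    if path.exists():
        # Cache hit: skip the expensive LLM calls
        with path.open(mode="r", encoding="utf-8") as f:
            return json.load(f)

    summaries = generate()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open(mode="w", encoding="utf-8") as f:
        json.dump(summaries, f, ensure_ascii=False, indent=4)
    return summaries
```

For example, `load_or_generate('data/summaries/table_summaries.json', lambda: generate_summaries(tables))` would call the LLM only when the file is missing. Delete the file to force regeneration.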