# Implementing multimodal retrieval using Cortex Search Service

Welcome! This tutorial walks through a lightweight example in which a customer has two long PDFs and wants to search them and ask natural-language questions about them. At a high level, this tutorial demonstrates how to:

- Convert long PDF files to document screenshots (images).
- (Optional but highly recommended) Run parse_document on PDFs for auxiliary text retrieval to further improve quality.
- Embed document screenshots using EMBED_IMAGE_1024 (PrPr), which runs `voyage-multimodal-3` under the hood.
- Create a Cortex Search Service using multimodal embeddings and OCR text.
- Retrieve top pages using Cortex Search.
- Get a natural-language answer with multimodal RAG!

To start with, make sure you have PDFs stored in a stage. The two PDF files used in this demo can be found [here](https://drive.google.com/drive/folders/1bExhPiJlF9aNushnXeLLBR4m9EMaShHw?usp=sharing).

In [None]:
-- Reference setup, left commented out: it creates the schema, an external stage
-- backed by S3, and an internal stage with directory tables and SNOWFLAKE_SSE
-- encryption, then copies the raw PDFs from the external stage's raw_pdf/ folder
-- into the internal stage's raw_pdf/ folder. Uncomment and adapt the names to
-- your own account if the stages do not exist yet.
-- CREATE SCHEMA IF NOT EXISTS CORTEX_SEARCH_DB.PYU;
-- CREATE OR REPLACE STAGE CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO
-- STORAGE_INTEGRATION = ML_DEV
-- URL = 's3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/pyu/multimodal/demo/'
-- DIRECTORY = (ENABLE = TRUE);

-- CREATE OR REPLACE STAGE CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL DIRECTORY = (ENABLE = TRUE) ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE');

-- COPY FILES INTO @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/
-- FROM @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO/raw_pdf/;

-- List the raw PDFs in the internal stage to confirm they are in place.
LS @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/;

Now let's run some Python code:

The purpose is to split the raw PDFs into individual pages, in both image and PDF format. Images are used for multimodal retrieval, while the per-page PDFs are used for better OCR quality (optional). As long as you set up the `Config` correctly, you are good to go!

```
@dataclass
class Config:
    # Stage folder that holds the raw input PDFs.
    input_stage: str = "@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/"
    # Base stage for the per-page outputs; must be an internal stage.
    output_stage: str = "@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/"
    input_path: str = "raw_pdf"
    # Sub-folders of output_stage for the per-page PDFs and PNGs.
    output_pdf_path: str = "paged_pdf"
    output_image_path: str = "paged_image"
    # None means "use the default", which __post_init__ fills in below.
    allowed_extensions: List[str] = None
    max_dimension: int = 1500  # Maximum dimension in pixels before scaling
    dpi: int = 300  # Default DPI for image conversion

    def __post_init__(self):
        # A fresh list per instance avoids sharing one mutable default.
        if self.allowed_extensions is None:
            self.allowed_extensions = [".pdf"]
```

**Make sure the output_stage is an internal stage**, because `embed_image_1024` only works with internal stages at the moment.

In [None]:
# Import python packages
import os
import sys
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List
from typing import Tuple

import pdfplumber
import PyPDF2
import snowflake.snowpark.session as session
import streamlit as st


def print_info(msg: str) -> None:
    """Write *msg* to standard error with an ``INFO:`` prefix."""
    line = f"INFO: {msg}"
    print(line, file=sys.stderr)


def print_error(msg: str) -> None:
    """Report an error on standard error and, when available, in Streamlit.

    The message is always written to stderr with an ``ERROR:`` prefix. It is
    additionally passed to ``st.error`` if the imported ``st`` module exposes
    that attribute.
    """
    line = f"ERROR: {msg}"
    print(line, file=sys.stderr)
    if not hasattr(st, "error"):
        return
    st.error(msg)


def print_warning(msg: str) -> None:
    """Write *msg* to standard error with a ``WARNING:`` prefix."""
    line = f"WARNING: {msg}"
    print(line, file=sys.stderr)


@dataclass
class Config:
    """Settings for splitting staged PDFs into per-page PDF and PNG files."""

    # Stage folder that is listed and downloaded from.
    input_stage: str = "@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/"
    # Base output stage without subdirectories.
    output_stage: str = "@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/"
    # NOTE(review): not read by any function in this cell - confirm it is needed.
    input_path: str = "raw_pdf"
    # Sub-folders (under output_stage) for the per-page PDFs and PNGs.
    output_pdf_path: str = "paged_pdf"
    output_image_path: str = "paged_image"
    # File suffixes accepted when listing; None is replaced in __post_init__.
    allowed_extensions: List[str] = None
    # Maximum dimension in pixels before scaling.
    max_dimension: int = 1500
    # Default DPI for image conversion.
    dpi: int = 300

    def __post_init__(self):
        """Fill in the default extension list when none was supplied."""
        if self.allowed_extensions is not None:
            return
        self.allowed_extensions = [".pdf"]


class PDFProcessingError(Exception):
    """Root of the exception hierarchy used by the PDF splitting code."""


class FileDownloadError(PDFProcessingError):
    """Signals that a file could not be fetched from the stage."""


class PDFConversionError(PDFProcessingError):
    """Signals that a PDF could not be split or rendered."""


@contextmanager
def managed_temp_file(suffix: str = None) -> str:
    """Create a named temporary file and yield its path.

    The file is created on disk with the given ``suffix`` and is NOT removed
    when the ``with`` block exits; the caller is responsible for deleting it
    (see ``cleanup_temp_file``).

    Note: because of ``@contextmanager`` the call returns a context manager;
    the value bound by ``with ... as`` is the path string.
    """
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    # Only the path is handed out, so release our own handle immediately.
    # Keeping it open leaks a descriptor for the whole ``with`` body.
    temp_file.close()
    yield temp_file.name


def cleanup_temp_file(file_path: str) -> None:
    """Delete *file_path* if it exists; never raise on deletion failure.

    A missing file is silently ignored. If the deletion itself fails with an
    ``OSError`` a warning is printed instead of propagating the error.
    """
    if not os.path.exists(file_path):
        return
    try:
        os.remove(file_path)
    except OSError as err:
        print_warning(f"Failed to delete temporary file {file_path}: {err}")


def list_pdf_files(session: session.Session, config: Config) -> List[dict]:
    """Return the files in ``config.input_stage`` that have an allowed extension.

    Runs a ``LIST`` command against the stage and keeps the entries whose
    lower-cased base name ends with one of ``config.allowed_extensions``.

    Returns:
        A list of dicts with keys ``RELATIVE_PATH`` (the base file name only)
        and ``SIZE`` (the listed size, or 0 when no size is reported).

    Raises:
        Exception: any error from the query is reported and re-raised as is.
    """

    def has_allowed_extension(name: str) -> bool:
        lowered = name.lower()
        return any(lowered.endswith(ext) for ext in config.allowed_extensions)

    try:
        # LIST is used here rather than the DIRECTORY table function.
        query = f"""
        LIST {config.input_stage}
        """
        rows = session.sql(query).collect()

        matches = []
        for row in rows:
            # The listing reports a full path; only the base name is kept.
            name = os.path.basename(row["name"])
            if not has_allowed_extension(name):
                continue
            # NOTE(review): assumes `in` on the row checks column names - confirm.
            size = row["size"] if "size" in row else 0
            matches.append({"RELATIVE_PATH": name, "SIZE": size})

        print_info(f"Found {len(matches)} PDF files in the stage")
        return matches
    except Exception as err:
        print_error(f"Failed to list files: {err}")
        raise


def download_file_from_stage(
    session: session.Session, file_path: str, config: Config
) -> str:
    """Download one file from ``config.input_stage`` into a new temp directory.

    Args:
        session: Snowpark session used for ``session.file.get``.
        file_path: Path of the file relative to ``config.input_stage``.
        config: Supplies ``input_stage``.

    Returns:
        Local path of the downloaded file, inside a directory created with
        ``tempfile.mkdtemp``. The caller owns the file and directory and must
        remove them.

    Raises:
        FileDownloadError: if the download does not report ``DOWNLOADED``, the
            file is missing afterwards, or any other error occurs (the
            original error is chained as ``__cause__``).
    """
    # Local import: shutil is only needed for the failure-path cleanup.
    import shutil

    temp_dir = tempfile.mkdtemp()
    try:
        # Ensure there are no double slashes in the path
        stage_path = f"{config.input_stage.rstrip('/')}/{file_path.lstrip('/')}"

        get_result = session.file.get(stage_path, temp_dir)
        if not get_result or get_result[0].status != "DOWNLOADED":
            raise FileDownloadError(f"Failed to download file: {file_path}")

        # The file lands in temp_dir under its base name.
        local_path = os.path.join(temp_dir, os.path.basename(file_path))
        if not os.path.exists(local_path):
            raise FileDownloadError(f"Downloaded file not found at: {local_path}")

        return local_path
    except Exception as e:
        print_error(f"Error downloading {file_path}: {e}")
        # Nothing usable was produced, so drop the temporary directory.
        try:
            shutil.rmtree(temp_dir)
        except Exception as cleanup_error:
            print_warning(f"Failed to clean up temporary directory: {cleanup_error}")
        # Errors raised above are already FileDownloadError; re-wrapping them
        # would duplicate the "Failed to download file:" prefix.
        if isinstance(e, FileDownloadError):
            raise
        raise FileDownloadError(f"Failed to download file: {e}") from e


def convert_pdf_to_images(pdf_path: str, config: Config) -> List[Tuple[str, int]]:
    """Render every page of a PDF to its own temporary PNG file.

    Pages are rendered at ``config.dpi``. If that would make the longer side
    of a page exceed ``config.max_dimension`` pixels, the resolution for that
    page is lowered so the rendered image fits. (Previously the scaled size
    was computed but never applied, so ``max_dimension`` had no effect.)

    Args:
        pdf_path: Local path of the PDF to render.
        config: Supplies ``dpi`` and ``max_dimension``.

    Returns:
        ``(png_path, page_number)`` tuples with 1-based page numbers. The
        caller owns the files and must delete them.

    Raises:
        PDFConversionError: on any failure; PNGs created so far are deleted
            and the original error is chained as ``__cause__``.
    """
    temp_files = []  # Keep track of temporary files for cleanup
    try:
        with pdfplumber.open(pdf_path) as pdf:
            print_info(f"Converting PDF to {len(pdf.pages)} images")

            segments = []
            for i, page in enumerate(pdf.pages):
                # Page sizes are in PDF points (1/72 inch); convert to the
                # pixel size the page would have at the configured DPI.
                longest_side_px = (
                    max(float(page.width), float(page.height)) * config.dpi / 72
                )

                resolution = config.dpi
                if longest_side_px > config.max_dimension:
                    # Rounding down keeps the result within max_dimension.
                    resolution = max(
                        1, int(config.dpi * config.max_dimension / longest_side_px)
                    )

                # Reserve a unique path and release the handle right away so
                # the renderer can write to it and no descriptor is leaked.
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=".png"
                ) as temp_file:
                    image_path = temp_file.name
                temp_files.append(image_path)

                page.to_image(resolution=resolution).save(image_path)

                segments.append((image_path, i + 1))

            return segments
    except Exception as e:
        print_error(f"Error converting PDF to images: {e}")
        # Clean up any temporary files created so far
        for temp_file in temp_files:
            cleanup_temp_file(temp_file)
        raise PDFConversionError(f"Failed to convert PDF to images: {e}") from e


def extract_pdf_pages(pdf_path: str, config: Config) -> List[Tuple[str, int]]:
    """Split a PDF into one temporary single-page PDF per page using PyPDF2.

    Args:
        pdf_path: Local path of the PDF to split.
        config: Unused here; kept so the signature matches
            ``convert_pdf_to_images``.

    Returns:
        ``(page_pdf_path, page_number)`` tuples with 1-based page numbers.
        The caller owns the files and must delete them.

    Raises:
        PDFConversionError: on any failure; page files created so far are
            deleted and the original error is chained as ``__cause__``.
    """
    temp_files = []  # Keep track of temporary files for cleanup
    try:
        with open(pdf_path, "rb") as file:
            pdf_reader = PyPDF2.PdfReader(file)
            num_pages = len(pdf_reader.pages)
            print_info(f"Extracting {num_pages} pages from PDF")

            segments = []
            for i in range(num_pages):
                # Reserve a unique path and close the handle immediately: the
                # file is reopened below, and an open handle would leak a
                # descriptor (and blocks reopening on some platforms).
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=".pdf"
                ) as temp_file:
                    page_path = temp_file.name
                temp_files.append(page_path)

                # Create a new PDF with just this page
                pdf_writer = PyPDF2.PdfWriter()
                pdf_writer.add_page(pdf_reader.pages[i])

                with open(page_path, "wb") as output_file:
                    pdf_writer.write(output_file)

                segments.append((page_path, i + 1))

            return segments
    except Exception as e:
        print_error(f"Error extracting PDF pages: {e}")
        # Clean up any temporary files created so far
        for temp_file in temp_files:
            cleanup_temp_file(temp_file)
        raise PDFConversionError(f"Failed to extract PDF pages: {e}") from e


def upload_file_to_stage(
    session: session.Session, file_path: str, output_path: str, config: Config
) -> str:
    """Upload a local file to ``config.output_stage`` under a chosen name.

    The staged file takes its name from the local file, so the source is first
    copied to a private temporary directory under the base name of
    ``output_path`` and that copy is uploaded.

    Args:
        session: Snowpark session used for ``session.file.put``.
        file_path: Local file whose content is uploaded.
        output_path: Target path relative to ``config.output_stage``; its
            directory part selects the stage sub-folder, its base name the
            staged file name.
        config: Supplies ``output_stage``.

    Returns:
        A human-readable success message.

    Raises:
        Exception: if the upload returns no result or a status other than
            ``UPLOADED``/``SKIPPED``, or on any I/O error. The error is
            reported via ``print_error`` and re-raised.
    """
    # Local import: shutil is only used for the copy and cleanup here.
    import shutil

    staging_dir = None
    try:
        output_dir = os.path.dirname(output_path)
        base_name = os.path.basename(output_path)

        # Create the full stage path with subdirectory
        stage_path = f"{config.output_stage.rstrip('/')}/{output_dir.lstrip('/')}"

        # A private directory avoids clashing with same-named files in the
        # shared temp dir, and copyfile avoids loading the file into memory.
        staging_dir = tempfile.mkdtemp()
        temp_file_path = os.path.join(staging_dir, base_name)
        shutil.copyfile(file_path, temp_file_path)

        # Compression is disabled so the staged file keeps its exact name.
        put_result = session.file.put(
            temp_file_path, stage_path, auto_compress=False, overwrite=True
        )

        if not put_result:
            raise Exception(f"Failed to upload file: {base_name}")

        if put_result[0].status not in ["UPLOADED", "SKIPPED"]:
            raise Exception(f"Upload failed with status: {put_result[0].status}")

        return f"Successfully uploaded {base_name} to {stage_path}"
    except Exception as e:
        print_error(f"Error uploading file: {e}")
        raise
    finally:
        # Remove the renamed copy on failure as well as on success.
        if staging_dir is not None:
            shutil.rmtree(staging_dir, ignore_errors=True)


def process_pdf_files(config: Config) -> None:
    """Split every PDF in the input stage into per-page PDFs and PNGs.

    For each listed PDF the file is downloaded, split into single-page PDFs,
    rendered to PNGs, and each page pair is uploaded to
    ``{output_pdf_path}/{name}_page_{n}.pdf`` and
    ``{output_image_path}/{name}_page_{n}.png`` under the output stage.

    Failures are isolated: an error on one page or one file is reported and
    processing continues with the next. Local temporary files are removed
    whether or not processing succeeded.

    Raises:
        Exception: errors from obtaining the session or listing the stage are
            reported and re-raised.
    """
    # The notebook's import cell does not bring this name into scope.
    from snowflake.snowpark.context import get_active_session

    try:
        session = get_active_session()
        pdf_files = list_pdf_files(session, config)

        for file_info in pdf_files:
            file_path = file_info["RELATIVE_PATH"]
            print_info(f"Processing: {file_path}")

            local_pdf_path = None
            pdf_segments = []
            image_segments = []
            try:
                local_pdf_path = download_file_from_stage(session, file_path, config)

                # Get base filename without extension
                base_name = os.path.splitext(os.path.basename(file_path))[0]

                pdf_segments = extract_pdf_pages(local_pdf_path, config)
                image_segments = convert_pdf_to_images(local_pdf_path, config)

                # zip() stops at the shorter list, so flag a mismatch.
                if len(pdf_segments) != len(image_segments):
                    print_warning(
                        f"Page count mismatch for {file_path}: "
                        f"{len(pdf_segments)} PDF pages vs "
                        f"{len(image_segments)} images"
                    )

                for (pdf_segment, page_num), (image_segment, _) in zip(
                    pdf_segments, image_segments
                ):
                    try:
                        pdf_output_path = (
                            f"{config.output_pdf_path}/{base_name}_page_{page_num}.pdf"
                        )
                        image_output_path = f"{config.output_image_path}/{base_name}_page_{page_num}.png"

                        upload_file_to_stage(
                            session, pdf_segment, pdf_output_path, config
                        )
                        upload_file_to_stage(
                            session, image_segment, image_output_path, config
                        )
                    except Exception as e:
                        print_error(
                            f"Error processing page {page_num} of {file_path}: {e}"
                        )

            except Exception as e:
                print_error(f"Error processing {file_path}: {e}")
            finally:
                # Runs on every path, so page files are also removed when
                # conversion fails or the page counts differ.
                for segment_path, _ in pdf_segments + image_segments:
                    cleanup_temp_file(segment_path)

                if local_pdf_path is not None:
                    cleanup_temp_file(local_pdf_path)
                    # Best effort: rmdir only removes the download directory
                    # if it is empty, so unrelated files are never deleted.
                    try:
                        os.rmdir(os.path.dirname(local_pdf_path))
                    except OSError:
                        pass

    except Exception as e:
        print_error(f"Fatal error in process_pdf_files: {e}")
        raise



# Run the splitting pipeline with the default stages, overriding only the
# rendering resolution (200 DPI instead of the Config default of 300).
config = Config(dpi=200)
process_pdf_files(config)

Now let's start the multimodal embedding part! We first create an intermediate table that holds relative file names of images, and then call `SNOWFLAKE.CORTEX.embed_image_1024` to turn them into vectors!

In [None]:
-- Build the list of page images: one row per distinct file under paged_image/.
-- FILE_NAME is the path relative to the stage root ('paged_image/<file>'),
-- rebuilt from the last path segment of metadata$filename; STAGE_PREFIX is the
-- stage name as a constant.
-- NOTE(review): GROUP BY presumably collapses the multiple rows a staged file
-- yields when queried directly - confirm.
CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_IMAGE_CORPUS AS
SELECT
    CONCAT('paged_image/', split_part(metadata$filename, '/', -1)) AS FILE_NAME,
    '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL' AS STAGE_PREFIX
FROM
    @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/paged_image/
GROUP BY FILE_NAME, STAGE_PREFIX
;

-- Preview a few rows.
SELECT * FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_IMAGE_CORPUS LIMIT 5;

In [None]:
-- Embed every page image with voyage-multimodal-3. The stage is passed as a
-- literal (the same value stored in STAGE_PREFIX) and FILE_NAME is the path
-- relative to that stage. The result is stored as IMAGE_VECTOR.
CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_VM3_VECTORS AS
SELECT FILE_NAME, STAGE_PREFIX, SNOWFLAKE.CORTEX.embed_image_1024('voyage-multimodal-3', '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL', FILE_NAME) AS IMAGE_VECTOR
FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_IMAGE_CORPUS;


-- Preview a few rows.
SELECT * FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_VM3_VECTORS LIMIT 5;

Similarly, we call `SNOWFLAKE.CORTEX.PARSE_DOCUMENT` to extract text from the PDF pages. We have found that, although multimodal retrieval is powerful, augmenting it with text retrieval for keyword matching can improve quality on certain types of search tasks and queries.

In [None]:
-- Build the list of per-page PDFs: one row per distinct file under paged_pdf/,
-- with FILE_NAME relative to the stage root ('paged_pdf/<file>').
CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_PDF_CORPUS AS
SELECT
    CONCAT('paged_pdf/', split_part(metadata$filename, '/', -1)) AS FILE_NAME,
    '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL' AS STAGE_PREFIX
FROM
    @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/paged_pdf/
GROUP BY FILE_NAME, STAGE_PREFIX
;

-- Run PARSE_DOCUMENT in LAYOUT mode on every page PDF and keep only the
-- `content` field of its JSON output as PARSE_DOC_OUTPUT.
CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_PARSE_DOC AS
    SELECT
        FILE_NAME,
        STAGE_PREFIX,
        PARSE_JSON(TO_VARCHAR(SNOWFLAKE.CORTEX.PARSE_DOCUMENT(
            '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL',
            FILE_NAME,
            {'mode': 'LAYOUT'}
        ))):content AS PARSE_DOC_OUTPUT
    FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_PDF_CORPUS
;

-- Preview a few rows.
SELECT * FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_PARSE_DOC LIMIT 5;

Now we join image vectors and texts into a single table, and create a Cortex Search service!

In [None]:
-- Pair each page's image vector with its parsed text. The join key is the
-- page id: the file name with the 'paged_image/' or 'paged_pdf/' folder and
-- the '.png' or '.pdf' extension stripped (regex capture group 1), e.g.
-- '<document>_page_<n>'. It is also exposed as INCLUDE_PAGE_ID.
CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_JOINED_DATA AS
SELECT
    REGEXP_SUBSTR(v.FILE_NAME, 'paged_image/(.*)\\.png$', 1, 1, 'e', 1) AS INCLUDE_PAGE_ID,
    v.IMAGE_VECTOR AS VECTOR_MAIN,
    p.PARSE_DOC_OUTPUT AS TEXT
FROM
    CORTEX_SEARCH_DB.PYU.DEMO_SEC_VM3_VECTORS v
JOIN
    CORTEX_SEARCH_DB.PYU.DEMO_SEC_PARSE_DOC p
ON
    REGEXP_SUBSTR(v.FILE_NAME, 'paged_image/(.*)\\.png$', 1, 1, 'e', 1) = REGEXP_SUBSTR(p.FILE_NAME, 'paged_pdf/(.*)\\.pdf$', 1, 1, 'e', 1);


-- Create the search service over the joined table, with a text index on TEXT
-- and a vector index on the precomputed VECTOR_MAIN embeddings. TEXT is cast
-- to VARCHAR because it was extracted from a JSON value.
CREATE OR REPLACE CORTEX SEARCH SERVICE CORTEX_SEARCH_DB.PYU.DEMO_SEC_CORTEX_SEARCH_SERVICE
  TEXT INDEXES TEXT
  VECTOR INDEXES VECTOR_MAIN
  WAREHOUSE='SEARCH_L'
  TARGET_LAG='1 day'
AS (
    SELECT 
        TO_VARCHAR(TEXT) AS TEXT, 
        INCLUDE_PAGE_ID, 
        VECTOR_MAIN
    FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_JOINED_DATA
);

Note that multimodal retrieval is not yet GA in Cortex Search, so we cannot specify a multimodal embedding model when creating the service. Instead, we embed queries directly with `SNOWFLAKE.CORTEX.EMBED_TEXT_1024` and call the Cortex Search service with `experimental={'queryEmbedding': query_vector}`.

In [None]:

# Until it is rebound, `session` in this notebook is the module alias created
# by `import snowflake.snowpark.session as session`, which has no `.sql`.
# Bind the active Snowpark session explicitly before using it.
from snowflake.snowpark.context import get_active_session

session = get_active_session()

demo_query_text = "What was the overall operational cost incurred by Abbott Laboratories in 2023, and how much of this amount was allocated to research and development?"

# The instruction prefix and the question are embedded together as one text.
query_prompt = f"Represent the query for retrieving supporting documents:  {demo_query_text}"

# Pass the text as a bind parameter rather than splicing it into the SQL
# literal, so a question containing an apostrophe cannot break the statement.
sql_output = session.sql(
    "SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multimodal-3', ?)",
    params=[query_prompt],
).collect()

# The query returns a single row with a single column: the embedding.
query_vector = sql_output[0][0]
print(query_vector)

In [None]:
from snowflake.core import Root


# Make sure `session` is a Snowpark Session (not the module alias from the
# import cell) before handing it to Root.
from snowflake.snowpark.context import get_active_session

session = get_active_session()

root = Root(session)
# fetch service
my_service = (root
  .databases["CORTEX_SEARCH_DB"]
  .schemas["PYU"]
  .cortex_search_services["DEMO_SEC_CORTEX_SEARCH_SERVICE"]
)

# query service: the text query feeds the text index, and the precomputed
# multimodal embedding is supplied through the experimental argument.
resp = my_service.search(
  query=demo_query_text,
  columns=["TEXT", "INCLUDE_PAGE_ID"],
  limit=5,
  experimental={'queryEmbedding': query_vector}
)

# Convert once, and iterate over what was actually returned: the service may
# return fewer than `limit` results.
results = resp.to_dict()['results']
if not results:
    raise ValueError("Cortex Search returned no results for the demo query")

for rank, result in enumerate(results, start=1):
    print(f"rank {rank}: {result['INCLUDE_PAGE_ID']}")

top_page_id = results[0]['INCLUDE_PAGE_ID']

Let's see the top ranked page we found!

In [None]:
# get_active_session is not imported anywhere else in this notebook.
from snowflake.snowpark.context import get_active_session

session = get_active_session()

# Read the top-ranked page image straight from the stage.
image_stream = session.file.get_stream(
    f"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/paged_image/{top_page_id}.png",
    decompress=False,
)
try:
    image = image_stream.read()
finally:
    # Release the stream even if the read fails.
    image_stream.close()

st.image(image)

Finally, we can also perform multimodal retrieval-augmented generation (mRAG) by sending the query and the top page image to a multimodal LLM served on Snowflake Cortex to get a natural-language answer to our question.

In [None]:
-- Ask the multimodal model (pixtral-large) to answer the question from a page
-- image on the stage.
-- NOTE(review): the image path is hard-coded; presumably it is the top-ranked
-- page from the search step. Update it if the search returns a different page.
SELECT SNOWFLAKE.CORTEX.COMPLETE('pixtral-large',
    '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL', 'paged_image/abbott-laboratories-10-q-2024-10-31_page_4.png',
    'Answer the following question by referencing the document image: What was the overall operational cost incurred by Abbott Laboratories in 2023, and how much of this amount was allocated to research and development?');