In [54]:
from azure.storage.filedatalake.aio import FileSystemClient
from azure.identity.aio import (
    AzureDeveloperCliCredential,
    ManagedIdentityCredential,
    get_bearer_token_provider,
)
from io import BytesIO

In [2]:
AZURE_USERSTORAGE_ACCOUNT="usersttlhc65hyz52oy"
AZURE_USERSTORAGE_CONTAINER="user-content"


In [None]:
user_oid = "63e8dafc-039c-447b-9df0-6b2c9aec1df2"
original_user_query = "Extract tables from 2013_contoso_products.pdf file"

In [16]:
azure_credential = ManagedIdentityCredential()

In [17]:
user_blob_container_client = FileSystemClient(
            f"https://{AZURE_USERSTORAGE_ACCOUNT}.dfs.core.windows.net",
            AZURE_USERSTORAGE_CONTAINER,
            credential=azure_credential,
        )

In [6]:
from fuzzywuzzy import fuzz
import os

# Custom stopwords to ignore while matching
custom_stopwords = {'merger', 'agreement', 'document', 'memo', 'file', 'note'}

def clean_text(text):
    # Lowercase, remove extension, replace _ with space, and remove stopwords
    text = os.path.splitext(text)[0].replace('_', ' ').lower()
    words = text.split()
    return ' '.join([word for word in words if word not in custom_stopwords])

def extract_matching_filename(input_string, filenames, threshold=70):
    input_clean = clean_text(input_string)
    matches = []

    for filename in filenames:
        file_clean = clean_text(filename)
        similarity = fuzz.partial_ratio(file_clean, input_clean)

        if similarity >= threshold:
            matches.append((filename, similarity))

    matches.sort(key=lambda x: x[1], reverse=True)
    return matches[0][0] if matches else None




In [51]:
def get_file_name(user_query, all_paths):
    files = []
    try:
        for path in all_paths:
            print(path)
            files.append(path.name.split('/')[-1])
    except Exception as error:
        print("Error listing uploaded files", error)
    file_mentioned_in_query = extract_matching_filename(user_query, files)
    return file_mentioned_in_query

In [None]:

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient

# Authenticate using DefaultAzureCredential
credential = DefaultAzureCredential()

account_url = f"https://{AZURE_USERSTORAGE_ACCOUNT}.blob.core.windows.net"
blob_service_client = BlobServiceClient(account_url=account_url, credential=credential)

# Get the container and blob clients
container_client = blob_service_client.get_container_client(AZURE_USERSTORAGE_CONTAINER)




In [52]:
# all_paths = user_blob_container_client.get_paths(path=user_oid)
all_paths = container_client.list_blobs()

file_mentioned_by_client = get_file_name(original_user_query, all_paths)
file_mentioned_by_client

{'name': '63e8dafc-039c-447b-9df0-6b2c9aec1df2', 'container': 'user-content', 'snapshot': None, 'version_id': None, 'is_current_version': None, 'blob_type': <BlobType.BLOCKBLOB: 'BlockBlob'>, 'metadata': {}, 'encrypted_metadata': None, 'last_modified': datetime.datetime(2025, 4, 29, 12, 3, 17, tzinfo=datetime.timezone.utc), 'etag': '0x8DD8715D4056180', 'size': 0, 'content_range': None, 'append_blob_committed_block_count': None, 'is_append_blob_sealed': None, 'page_blob_sequence_number': None, 'server_encrypted': True, 'copy': {'id': None, 'source': None, 'status': None, 'progress': None, 'completion_time': None, 'status_description': None, 'incremental_copy': None, 'destination_snapshot': None}, 'content_settings': {'content_type': 'application/octet-stream', 'content_encoding': None, 'content_language': None, 'content_md5': None, 'content_disposition': None, 'cache_control': None}, 'lease': {'status': 'unlocked', 'state': 'available', 'duration': None}, 'blob_tier': 'Hot', 'rehydrate_

'2013_contoso_products.pdf'

In [53]:
blob_client = container_client.get_blob_client('63e8dafc-039c-447b-9df0-6b2c9aec1df2/2013_contoso_products.pdf')

In [55]:

# Download the blob content to a buffer
stream = BytesIO()
blob_client.download_blob().readinto(stream)

# Reset the buffer's position to the beginning
stream.seek(0)

0

In [56]:
import base64
import hashlib
import logging
import os
import re
import tempfile
from abc import ABC
from collections.abc import AsyncGenerator
from glob import glob
from typing import IO, Optional, Union

class File:
    """
    Represents a file stored either locally or in a data lake storage account
    This file might contain access control information about which users or groups can access it
    """

    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None):
        self.content = content
        self.acls = acls or {}
        self.url = url

    def filename(self):
        return os.path.basename(self.content.name)

    def file_extension(self):
        return os.path.splitext(self.content.name)[1]

    def filename_to_id(self):
        filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
        filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
        acls_hash = ""
        if self.acls:
            acls_hash = base64.b16encode(str(self.acls).encode("utf-8")).decode("ascii")
        return f"file-{filename_ascii}-{filename_hash}{acls_hash}"

    def close(self):
        if self.content:
            self.content.close()


In [None]:
File(content=stream, acls={"oids": [user_oid]}, url=blob_client.url)

'https://usersttlhc65hyz52oy.blob.core.windows.net/user-content/63e8dafc-039c-447b-9df0-6b2c9aec1df2/2013_contoso_products.pdf'

In [68]:
from abc import ABC
from collections.abc import AsyncGenerator
from typing import IO

from collections.abc import Generator
from dataclasses import dataclass

from azure.ai.documentintelligence.aio import DocumentIntelligenceClient



In [64]:
class Page:
    """
    A single page from a document

    Attributes:
        page_num (int): Page number (0-indexed)
        offset (int): If the text of the entire Document was concatenated into a single string, the index of the first character on the page. For example, if page 1 had the text "hello" and page 2 had the text "world", the offset of page 2 is 5 ("hellow")
        text (str): The text of the page
    """

    def __init__(self, page_num: int, offset: int, text: str):
        self.page_num = page_num
        self.offset = offset
        self.text = text

class Parser(ABC):
    """
    Abstract parser that parses content into Page objects
    """

    async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
        if False:
            yield  # pragma: no cover - this is necessary for mypy to type check


class SplitPage:
    """
    A section of a page that has been split into a smaller chunk.

    Attributes:
        page_num (int): Page number (0-indexed)
        text (str): The text of the section
    """

    def __init__(self, page_num: int, text: str):
        self.page_num = page_num
        self.text = text


class TextSplitter(ABC):
    """
    Splits a list of pages into smaller chunks
    :param pages: The pages to split
    :return: A generator of SplitPage
    """

    def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
        if False:
            yield  # pragma: no cover - this is necessary for mypy to type check

@dataclass(frozen=True)
class FileProcessor:
    parser: Parser
    splitter: TextSplitter

class Section:
    """
    A section of a page that is stored in a search service. These sections are used as context by Azure OpenAI service
    """

    def __init__(self, split_page: SplitPage, content: File, category: Optional[str] = None):
        self.split_page = split_page
        self.content = content
        self.category = category

In [63]:
import logging
from abc import ABC
from collections.abc import Awaitable
from typing import Callable, Optional, Union
from urllib.parse import urljoin

import aiohttp
import tiktoken
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)
from typing_extensions import TypedDict

logger = logging.getLogger("scripts")


class EmbeddingBatch:
    """
    Represents a batch of text that is going to be embedded
    """

    def __init__(self, texts: list[str], token_length: int):
        self.texts = texts
        self.token_length = token_length


class ExtraArgs(TypedDict, total=False):
    dimensions: int


class OpenAIEmbeddings(ABC):
    """
    Contains common logic across both OpenAI and Azure OpenAI embedding services
    Can split source text into batches for more efficient embedding calls
    """

    SUPPORTED_BATCH_AOAI_MODEL = {
        "text-embedding-ada-002": {"token_limit": 8100, "max_batch_size": 16},
        "text-embedding-3-small": {"token_limit": 8100, "max_batch_size": 16},
        "text-embedding-3-large": {"token_limit": 8100, "max_batch_size": 16},
    }
    SUPPORTED_DIMENSIONS_MODEL = {
        "text-embedding-ada-002": False,
        "text-embedding-3-small": True,
        "text-embedding-3-large": True,
    }

    def __init__(self, open_ai_model_name: str, open_ai_dimensions: int, disable_batch: bool = False):
        self.open_ai_model_name = open_ai_model_name
        self.open_ai_dimensions = open_ai_dimensions
        self.disable_batch = disable_batch

    async def create_client(self) -> AsyncOpenAI:
        raise NotImplementedError

    def before_retry_sleep(self, retry_state):
        logger.info("Rate limited on the OpenAI embeddings API, sleeping before retrying...")

    def calculate_token_length(self, text: str):
        encoding = tiktoken.encoding_for_model(self.open_ai_model_name)
        return len(encoding.encode(text))

    def split_text_into_batches(self, texts: list[str]) -> list[EmbeddingBatch]:
        batch_info = OpenAIEmbeddings.SUPPORTED_BATCH_AOAI_MODEL.get(self.open_ai_model_name)
        if not batch_info:
            raise NotImplementedError(
                f"Model {self.open_ai_model_name} is not supported with batch embedding operations"
            )

        batch_token_limit = batch_info["token_limit"]
        batch_max_size = batch_info["max_batch_size"]
        batches: list[EmbeddingBatch] = []
        batch: list[str] = []
        batch_token_length = 0
        for text in texts:
            text_token_length = self.calculate_token_length(text)
            if batch_token_length + text_token_length >= batch_token_limit and len(batch) > 0:
                batches.append(EmbeddingBatch(batch, batch_token_length))
                batch = []
                batch_token_length = 0

            batch.append(text)
            batch_token_length = batch_token_length + text_token_length
            if len(batch) == batch_max_size:
                batches.append(EmbeddingBatch(batch, batch_token_length))
                batch = []
                batch_token_length = 0

        if len(batch) > 0:
            batches.append(EmbeddingBatch(batch, batch_token_length))

        return batches

    async def create_embedding_batch(self, texts: list[str], dimensions_args: ExtraArgs) -> list[list[float]]:
        batches = self.split_text_into_batches(texts)
        embeddings = []
        client = await self.create_client()
        for batch in batches:
            async for attempt in AsyncRetrying(
                retry=retry_if_exception_type(RateLimitError),
                wait=wait_random_exponential(min=15, max=60),
                stop=stop_after_attempt(15),
                before_sleep=self.before_retry_sleep,
            ):
                with attempt:
                    emb_response = await client.embeddings.create(
                        model=self.open_ai_model_name, input=batch.texts, **dimensions_args
                    )
                    embeddings.extend([data.embedding for data in emb_response.data])
                    logger.info(
                        "Computed embeddings in batch. Batch size: %d, Token count: %d",
                        len(batch.texts),
                        batch.token_length,
                    )

        return embeddings

    async def create_embedding_single(self, text: str, dimensions_args: ExtraArgs) -> list[float]:
        client = await self.create_client()
        async for attempt in AsyncRetrying(
            retry=retry_if_exception_type(RateLimitError),
            wait=wait_random_exponential(min=15, max=60),
            stop=stop_after_attempt(15),
            before_sleep=self.before_retry_sleep,
        ):
            with attempt:
                emb_response = await client.embeddings.create(
                    model=self.open_ai_model_name, input=text, **dimensions_args
                )
                logger.info("Computed embedding for text section. Character count: %d", len(text))

        return emb_response.data[0].embedding

    async def create_embeddings(self, texts: list[str]) -> list[list[float]]:

        dimensions_args: ExtraArgs = (
            {"dimensions": self.open_ai_dimensions}
            if OpenAIEmbeddings.SUPPORTED_DIMENSIONS_MODEL.get(self.open_ai_model_name)
            else {}
        )

        if not self.disable_batch and self.open_ai_model_name in OpenAIEmbeddings.SUPPORTED_BATCH_AOAI_MODEL:
            return await self.create_embedding_batch(texts, dimensions_args)

        return [await self.create_embedding_single(text, dimensions_args) for text in texts]


class AzureOpenAIEmbeddingService(OpenAIEmbeddings):
    """
    Class for using Azure OpenAI embeddings
    To learn more please visit https://learn.microsoft.com/azure/ai-services/openai/concepts/understand-embeddings
    """

    def __init__(
        self,
        open_ai_service: Union[str, None],
        open_ai_deployment: Union[str, None],
        open_ai_model_name: str,
        open_ai_dimensions: int,
        open_ai_api_version: str,
        credential: Union[AsyncTokenCredential, AzureKeyCredential],
        open_ai_custom_url: Union[str, None] = None,
        disable_batch: bool = False,
    ):
        super().__init__(open_ai_model_name, open_ai_dimensions, disable_batch)
        self.open_ai_service = open_ai_service
        if open_ai_service:
            self.open_ai_endpoint = f"https://{open_ai_service}.openai.azure.com"
        elif open_ai_custom_url:
            self.open_ai_endpoint = open_ai_custom_url
        else:
            raise ValueError("Either open_ai_service or open_ai_custom_url must be provided")
        self.open_ai_deployment = open_ai_deployment
        self.open_ai_api_version = open_ai_api_version
        self.credential = credential

    async def create_client(self) -> AsyncOpenAI:
        class AuthArgs(TypedDict, total=False):
            api_key: str
            azure_ad_token_provider: Callable[[], Union[str, Awaitable[str]]]

        auth_args = AuthArgs()
        if isinstance(self.credential, AzureKeyCredential):
            auth_args["api_key"] = self.credential.key
        elif isinstance(self.credential, AsyncTokenCredential):
            auth_args["azure_ad_token_provider"] = get_bearer_token_provider(
                self.credential, "https://cognitiveservices.azure.com/.default"
            )
        else:
            raise TypeError("Invalid credential type")

        return AsyncAzureOpenAI(
            azure_endpoint=self.open_ai_endpoint,
            azure_deployment=self.open_ai_deployment,
            api_version=self.open_ai_api_version,
            **auth_args,
        )


class OpenAIEmbeddingService(OpenAIEmbeddings):
    """
    Class for using OpenAI embeddings
    To learn more please visit https://platform.openai.com/docs/guides/embeddings
    """

    def __init__(
        self,
        open_ai_model_name: str,
        open_ai_dimensions: int,
        credential: str,
        organization: Optional[str] = None,
        disable_batch: bool = False,
    ):
        super().__init__(open_ai_model_name, open_ai_dimensions, disable_batch)
        self.credential = credential
        self.organization = organization

    async def create_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(api_key=self.credential, organization=self.organization)


class ImageEmbeddings:
    """
    Class for using image embeddings from Azure AI Vision
    To learn more, please visit https://learn.microsoft.com/azure/ai-services/computer-vision/how-to/image-retrieval#call-the-vectorize-image-api
    """

    def __init__(self, endpoint: str, token_provider: Callable[[], Awaitable[str]]):
        self.token_provider = token_provider
        self.endpoint = endpoint

    async def create_embeddings(self, blob_urls: list[str]) -> list[list[float]]:
        endpoint = urljoin(self.endpoint, "computervision/retrieval:vectorizeImage")
        headers = {"Content-Type": "application/json"}
        params = {"api-version": "2023-02-01-preview", "modelVersion": "latest"}
        headers["Authorization"] = "Bearer " + await self.token_provider()

        embeddings: list[list[float]] = []
        async with aiohttp.ClientSession(headers=headers) as session:
            for blob_url in blob_urls:
                async for attempt in AsyncRetrying(
                    retry=retry_if_exception_type(Exception),
                    wait=wait_random_exponential(min=15, max=60),
                    stop=stop_after_attempt(15),
                    before_sleep=self.before_retry_sleep,
                ):
                    with attempt:
                        body = {"url": blob_url}
                        async with session.post(url=endpoint, params=params, json=body) as resp:
                            resp_json = await resp.json()
                            embeddings.append(resp_json["vector"])

        return embeddings

    def before_retry_sleep(self, retry_state):
        logger.info("Rate limited on the Vision embeddings API, sleeping before retrying...")


  code = compile(arg_to_compile, '<string>', 'eval')


In [65]:
async def parse_file(
    file: File,
    file_processors: dict[str, FileProcessor],
    category: Optional[str] = None,
    image_embeddings: Optional[ImageEmbeddings] = None,
) -> list[Section]:
    key = file.file_extension().lower()
    processor = file_processors.get(key)
    if processor is None:
        logger.info("Skipping '%s', no parser found.", file.filename())
        return []
    logger.info("Ingesting '%s'", file.filename())
    pages = [page async for page in processor.parser.parse(content=file.content)]
    logger.info("Splitting '%s' into sections", file.filename())
    if image_embeddings:
        logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.")
    sections = [
        Section(split_page, content=file, category=category) for split_page in processor.splitter.split_pages(pages)
    ]
    return sections


In [66]:
class FetchUserFileStrategy:
    """
    Strategy for parsing a file that has already been uploaded to a ADLS2 storage account
    """

    def __init__(
        self,
        file_processors: dict[str, FileProcessor],
    ):
        self.file_processors = file_processors

    async def fetch_file(self, file: File):
        if self.image_embeddings:
            logging.warning("Image embeddings are not currently supported for the user upload feature")
        sections = await parse_file(file, self.file_processors)
        return sections

In [67]:
import logging
from abc import ABC
from collections.abc import Generator

import tiktoken

logger = logging.getLogger("scripts")


class TextSplitter(ABC):
    """
    Splits a list of pages into smaller chunks
    :param pages: The pages to split
    :return: A generator of SplitPage
    """

    def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
        if False:
            yield  # pragma: no cover - this is necessary for mypy to type check


ENCODING_MODEL = "text-embedding-ada-002"

STANDARD_WORD_BREAKS = [",", ";", ":", " ", "(", ")", "[", "]", "{", "}", "\t", "\n"]

# See W3C document https://www.w3.org/TR/jlreq/#cl-01
CJK_WORD_BREAKS = [
    "、",
    "，",
    "；",
    "：",
    "（",
    "）",
    "【",
    "】",
    "「",
    "」",
    "『",
    "』",
    "〔",
    "〕",
    "〈",
    "〉",
    "《",
    "》",
    "〖",
    "〗",
    "〘",
    "〙",
    "〚",
    "〛",
    "〝",
    "〞",
    "〟",
    "〰",
    "–",
    "—",
    "‘",
    "’",
    "‚",
    "‛",
    "“",
    "”",
    "„",
    "‟",
    "‹",
    "›",
]

STANDARD_SENTENCE_ENDINGS = [".", "!", "?"]

# See CL05 and CL06, based on JIS X 4051:2004
# https://www.w3.org/TR/jlreq/#cl-04
CJK_SENTENCE_ENDINGS = ["。", "！", "？", "‼", "⁇", "⁈", "⁉"]

# NB: text-embedding-3-XX is the same BPE as text-embedding-ada-002
bpe = tiktoken.encoding_for_model(ENCODING_MODEL)

DEFAULT_OVERLAP_PERCENT = 10  # See semantic search article for 10% overlap performance
DEFAULT_SECTION_LENGTH = 1000  # Roughly 400-500 tokens for English


class SentenceTextSplitter(TextSplitter):
    """
    Class that splits pages into smaller chunks. This is required because embedding models may not be able to analyze an entire page at once
    """

    def __init__(self, max_tokens_per_section: int = 500):
        self.sentence_endings = STANDARD_SENTENCE_ENDINGS + CJK_SENTENCE_ENDINGS
        self.word_breaks = STANDARD_WORD_BREAKS + CJK_WORD_BREAKS
        self.max_section_length = DEFAULT_SECTION_LENGTH
        self.sentence_search_limit = 100
        self.max_tokens_per_section = max_tokens_per_section
        self.section_overlap = int(self.max_section_length * DEFAULT_OVERLAP_PERCENT / 100)

    def split_page_by_max_tokens(self, page_num: int, text: str) -> Generator[SplitPage, None, None]:
        """
        Recursively splits page by maximum number of tokens to better handle languages with higher token/word ratios.
        """
        tokens = bpe.encode(text)
        if len(tokens) <= self.max_tokens_per_section:
            # Section is already within max tokens, return
            yield SplitPage(page_num=page_num, text=text)
        else:
            # Start from the center and try and find the closest sentence ending by spiralling outward.
            # IF we get to the outer thirds, then just split in half with a 5% overlap
            start = int(len(text) // 2)
            pos = 0
            boundary = int(len(text) // 3)
            split_position = -1
            while start - pos > boundary:
                if text[start - pos] in self.sentence_endings:
                    split_position = start - pos
                    break
                elif text[start + pos] in self.sentence_endings:
                    split_position = start + pos
                    break
                else:
                    pos += 1

            if split_position > 0:
                first_half = text[: split_position + 1]
                second_half = text[split_position + 1 :]
            else:
                # Split page in half and call function again
                # Overlap first and second halves by DEFAULT_OVERLAP_PERCENT%
                middle = int(len(text) // 2)
                overlap = int(len(text) * (DEFAULT_OVERLAP_PERCENT / 100))
                first_half = text[: middle + overlap]
                second_half = text[middle - overlap :]
            yield from self.split_page_by_max_tokens(page_num, first_half)
            yield from self.split_page_by_max_tokens(page_num, second_half)

    def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
        def find_page(offset):
            num_pages = len(pages)
            for i in range(num_pages - 1):
                if offset >= pages[i].offset and offset < pages[i + 1].offset:
                    return pages[i].page_num
            return pages[num_pages - 1].page_num

        all_text = "".join(page.text for page in pages)
        if len(all_text.strip()) == 0:
            return

        length = len(all_text)
        if length <= self.max_section_length:
            yield from self.split_page_by_max_tokens(page_num=find_page(0), text=all_text)
            return

        start = 0
        end = length
        while start + self.section_overlap < length:
            last_word = -1
            end = start + self.max_section_length

            if end > length:
                end = length
            else:
                # Try to find the end of the sentence
                while (
                    end < length
                    and (end - start - self.max_section_length) < self.sentence_search_limit
                    and all_text[end] not in self.sentence_endings
                ):
                    if all_text[end] in self.word_breaks:
                        last_word = end
                    end += 1
                if end < length and all_text[end] not in self.sentence_endings and last_word > 0:
                    end = last_word  # Fall back to at least keeping a whole word
            if end < length:
                end += 1

            # Try to find the start of the sentence or at least a whole word boundary
            last_word = -1
            while (
                start > 0
                and start > end - self.max_section_length - 2 * self.sentence_search_limit
                and all_text[start] not in self.sentence_endings
            ):
                if all_text[start] in self.word_breaks:
                    last_word = start
                start -= 1
            if all_text[start] not in self.sentence_endings and last_word > 0:
                start = last_word
            if start > 0:
                start += 1

            section_text = all_text[start:end]
            yield from self.split_page_by_max_tokens(page_num=find_page(start), text=section_text)

            last_figure_start = section_text.rfind("<figure")
            if last_figure_start > 2 * self.sentence_search_limit and last_figure_start > section_text.rfind(
                "</figure"
            ):
                # If the section ends with an unclosed figure, we need to start the next section with the figure.
                start = min(end - self.section_overlap, start + last_figure_start)
                logger.info(
                    f"Section ends with unclosed figure, starting next section with the figure at page {find_page(start)} offset {start} figure start {last_figure_start}"
                )
            else:
                start = end - self.section_overlap

        if start + self.section_overlap < end:
            yield from self.split_page_by_max_tokens(page_num=find_page(start), text=all_text[start:end])


class SimpleTextSplitter(TextSplitter):
    """
    Class that splits pages into smaller chunks based on a max object length. It is not aware of the content of the page.
    This is required because embedding models may not be able to analyze an entire page at once
    """

    def __init__(self, max_object_length: int = 1000):
        self.max_object_length = max_object_length

    def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
        all_text = "".join(page.text for page in pages)
        if len(all_text.strip()) == 0:
            return

        length = len(all_text)
        if length <= self.max_object_length:
            yield SplitPage(page_num=0, text=all_text)
            return

        # its too big, so we need to split it
        for i in range(0, length, self.max_object_length):
            yield SplitPage(page_num=i // self.max_object_length, text=all_text[i : i + self.max_object_length])
        return


In [70]:
import html
import io
import logging
from collections.abc import AsyncGenerator
from enum import Enum
from typing import IO, Union

import pymupdf
from azure.ai.documentintelligence.aio import DocumentIntelligenceClient
from azure.ai.documentintelligence.models import (
    AnalyzeDocumentRequest,
    AnalyzeResult,
    DocumentFigure,
    DocumentTable,
)
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import HttpResponseError
from PIL import Image
from pypdf import PdfReader

In [71]:

class LocalPdfParser(Parser):
    """
    Concrete parser backed by PyPDF that can parse PDFs into pages
    To learn more, please visit https://pypi.org/project/pypdf/
    """

    async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
        logger.info("Extracting text from '%s' using local PDF parser (pypdf)", content.name)

        reader = PdfReader(content)
        pages = reader.pages
        offset = 0
        for page_num, p in enumerate(pages):
            page_text = p.extract_text()
            yield Page(page_num=page_num, offset=offset, text=page_text)
            offset += len(page_text)



In [72]:
import logging
from abc import ABC

import aiohttp
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import get_bearer_token_provider
from rich.progress import Progress
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

logger = logging.getLogger("scripts")


class MediaDescriber(ABC):

    async def describe_image(self, image_bytes) -> str:
        raise NotImplementedError  # pragma: no cover


class ContentUnderstandingDescriber:
    CU_API_VERSION = "2024-12-01-preview"

    analyzer_schema = {
        "analyzerId": "image_analyzer",
        "name": "Image understanding",
        "description": "Extract detailed structured information from images extracted from documents.",
        "baseAnalyzerId": "prebuilt-image",
        "scenario": "image",
        "config": {"returnDetails": False},
        "fieldSchema": {
            "name": "ImageInformation",
            "descriptions": "Description of image.",
            "fields": {
                "Description": {
                    "type": "string",
                    "description": "Description of the image. If the image has a title, start with the title. Include a 2-sentence summary. If the image is a chart, diagram, or table, include the underlying data in an HTML table tag, with accurate numbers. If the image is a chart, describe any axis or legends. The only allowed HTML tags are the table/thead/tr/td/tbody tags.",
                },
            },
        },
    }

    def __init__(self, endpoint: str, credential: AsyncTokenCredential):
        self.endpoint = endpoint
        self.credential = credential

    async def poll_api(self, session, poll_url, headers):

        @retry(stop=stop_after_attempt(60), wait=wait_fixed(2), retry=retry_if_exception_type(ValueError))
        async def poll():
            async with session.get(poll_url, headers=headers) as response:
                response.raise_for_status()
                response_json = await response.json()
                if response_json["status"] == "Failed":
                    raise Exception("Failed")
                if response_json["status"] == "Running":
                    raise ValueError("Running")
                return response_json

        return await poll()

    async def create_analyzer(self):
        logger.info("Creating analyzer '%s'...", self.analyzer_schema["analyzerId"])

        token_provider = get_bearer_token_provider(self.credential, "https://cognitiveservices.azure.com/.default")
        token = await token_provider()
        headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
        params = {"api-version": self.CU_API_VERSION}
        analyzer_id = self.analyzer_schema["analyzerId"]
        cu_endpoint = f"{self.endpoint}/contentunderstanding/analyzers/{analyzer_id}"
        async with aiohttp.ClientSession() as session:
            async with session.put(
                url=cu_endpoint, params=params, headers=headers, json=self.analyzer_schema
            ) as response:
                if response.status == 409:
                    logger.info("Analyzer '%s' already exists.", analyzer_id)
                    return
                elif response.status != 201:
                    data = await response.text()
                    raise Exception("Error creating analyzer", data)
                else:
                    poll_url = response.headers.get("Operation-Location")

            with Progress() as progress:
                progress.add_task("Creating analyzer...", total=None, start=False)
                await self.poll_api(session, poll_url, headers)

    async def describe_image(self, image_bytes: bytes) -> str:
        logger.info("Sending image to Azure Content Understanding service...")
        async with aiohttp.ClientSession() as session:
            token = await self.credential.get_token("https://cognitiveservices.azure.com/.default")
            headers = {"Authorization": "Bearer " + token.token}
            params = {"api-version": self.CU_API_VERSION}
            analyzer_name = self.analyzer_schema["analyzerId"]
            async with session.post(
                url=f"{self.endpoint}/contentunderstanding/analyzers/{analyzer_name}:analyze",
                params=params,
                headers=headers,
                data=image_bytes,
            ) as response:
                response.raise_for_status()
                poll_url = response.headers["Operation-Location"]

                with Progress() as progress:
                    progress.add_task("Processing...", total=None, start=False)
                    results = await self.poll_api(session, poll_url, headers)

                fields = results["result"]["contents"][0]["fields"]
                return fields["Description"]["valueString"]


In [None]:
class DocumentAnalysisParser(Parser):
    """
    Concrete parser backed by Azure AI Document Intelligence that can parse many document formats into pages
    To learn more, please visit https://learn.microsoft.com/azure/ai-services/document-intelligence/overview
    """

    def __init__(
        self,
        endpoint: str,
        credential: Union[AsyncTokenCredential, AzureKeyCredential],
        model_id="prebuilt-layout",
        use_content_understanding=True,
        content_understanding_endpoint: Union[str, None] = None,
    ):
        self.model_id = model_id
        self.endpoint = endpoint
        self.credential = credential
        self.use_content_understanding = use_content_understanding
        self.content_understanding_endpoint = content_understanding_endpoint

    async def parse(self, content: IO) -> AsyncGenerator[Page, None]:
        logger.info("Extracting text from '%s' using Azure Document Intelligence", content.name)

        async with DocumentIntelligenceClient(
            endpoint=self.endpoint, credential=self.credential
        ) as document_intelligence_client:
            file_analyzed = False
            if self.use_content_understanding:
                if self.content_understanding_endpoint is None:
                    raise ValueError("Content Understanding is enabled but no endpoint was provided")
                if isinstance(self.credential, AzureKeyCredential):
                    raise ValueError(
                        "AzureKeyCredential is not supported for Content Understanding, use keyless auth instead"
                    )
                cu_describer = ContentUnderstandingDescriber(self.content_understanding_endpoint, self.credential)
                content_bytes = content.read()
                try:
                    poller = await document_intelligence_client.begin_analyze_document(
                        model_id="prebuilt-layout",
                        analyze_request=AnalyzeDocumentRequest(bytes_source=content_bytes),
                        output=["figures"],
                        features=["ocrHighResolution"],
                        output_content_format="markdown",
                    )
                    doc_for_pymupdf = pymupdf.open(stream=io.BytesIO(content_bytes))
                    file_analyzed = True
                except HttpResponseError as e:
                    content.seek(0)
                    if e.error and e.error.code == "InvalidArgument":
                        logger.error(
                            "This document type does not support media description. Proceeding with standard analysis."
                        )
                    else:
                        logger.error(
                            "Unexpected error analyzing document for media description: %s. Proceeding with standard analysis.",
                            e,
                        )

            if file_analyzed is False:
                poller = await document_intelligence_client.begin_analyze_document(
                    model_id=self.model_id, analyze_request=content, content_type="application/octet-stream"
                )
            analyze_result: AnalyzeResult = await poller.result()

            offset = 0
            for page in analyze_result.pages:
                tables_on_page = [
                    table
                    for table in (analyze_result.tables or [])
                    if table.bounding_regions and table.bounding_regions[0].page_number == page.page_number
                ]
                figures_on_page = []
                if self.use_content_understanding:
                    figures_on_page = [
                        figure
                        for figure in (analyze_result.figures or [])
                        if figure.bounding_regions and figure.bounding_regions[0].page_number == page.page_number
                    ]

                class ObjectType(Enum):
                    NONE = -1
                    TABLE = 0
                    FIGURE = 1

                page_offset = page.spans[0].offset
                page_length = page.spans[0].length
                mask_chars: list[tuple[ObjectType, Union[int, None]]] = [(ObjectType.NONE, None)] * page_length
                # mark all positions of the table spans in the page
                for table_idx, table in enumerate(tables_on_page):
                    for span in table.spans:
                        # replace all table spans with "table_id" in table_chars array
                        for i in range(span.length):
                            idx = span.offset - page_offset + i
                            if idx >= 0 and idx < page_length:
                                mask_chars[idx] = (ObjectType.TABLE, table_idx)
                # mark all positions of the figure spans in the page
                for figure_idx, figure in enumerate(figures_on_page):
                    for span in figure.spans:
                        # replace all figure spans with "figure_id" in figure_chars array
                        for i in range(span.length):
                            idx = span.offset - page_offset + i
                            if idx >= 0 and idx < page_length:
                                mask_chars[idx] = (ObjectType.FIGURE, figure_idx)

                # build page text by replacing characters in table spans with table html
                page_text = ""
                added_objects = set()  # set of object types todo mypy
                for idx, mask_char in enumerate(mask_chars):
                    object_type, object_idx = mask_char
                    if object_type == ObjectType.NONE:
                        page_text += analyze_result.content[page_offset + idx]
                    elif object_type == ObjectType.TABLE:
                        if object_idx is None:
                            raise ValueError("Expected object_idx to be set")
                        if mask_char not in added_objects:
                            page_text += DocumentAnalysisParser.table_to_html(tables_on_page[object_idx])
                            added_objects.add(mask_char)
                    elif object_type == ObjectType.FIGURE:
                        if cu_describer is None:
                            raise ValueError("cu_describer should not be None, unable to describe figure")
                        if object_idx is None:
                            raise ValueError("Expected object_idx to be set")
                        if mask_char not in added_objects:
                            figure_html = await DocumentAnalysisParser.figure_to_html(
                                doc_for_pymupdf, figures_on_page[object_idx], cu_describer
                            )
                            page_text += figure_html
                            added_objects.add(mask_char)
                # We remove these comments since they are not needed and skew the page numbers
                page_text = page_text.replace("<!-- PageBreak -->", "")
                # We remove excess newlines at the beginning and end of the page
                page_text = page_text.strip()
                yield Page(page_num=page.page_number - 1, offset=offset, text=page_text)
                offset += len(page_text)

    @staticmethod
    async def figure_to_html(
        doc: pymupdf.Document, figure: DocumentFigure, cu_describer: ContentUnderstandingDescriber
    ) -> str:
        figure_title = (figure.caption and figure.caption.content) or ""
        logger.info("Describing figure %s with title '%s'", figure.id, figure_title)
        if not figure.bounding_regions:
            return f"<figure><figcaption>{figure_title}</figcaption></figure>"
        if len(figure.bounding_regions) > 1:
            logger.warning("Figure %s has more than one bounding region, using the first one", figure.id)
        first_region = figure.bounding_regions[0]
        # To learn more about bounding regions, see https://aka.ms/bounding-region
        bounding_box = (
            first_region.polygon[0],  # x0 (left)
            first_region.polygon[1],  # y0 (top
            first_region.polygon[4],  # x1 (right)
            first_region.polygon[5],  # y1 (bottom)
        )
        page_number = first_region["pageNumber"]  # 1-indexed
        cropped_img = DocumentAnalysisParser.crop_image_from_pdf_page(doc, page_number - 1, bounding_box)
        figure_description = await cu_describer.describe_image(cropped_img)
        return f"<figure><figcaption>{figure_title}<br>{figure_description}</figcaption></figure>"

    @staticmethod
    def table_to_html(table: DocumentTable):
        table_html = "<figure><table>"
        rows = [
            sorted([cell for cell in table.cells if cell.row_index == i], key=lambda cell: cell.column_index)
            for i in range(table.row_count)
        ]
        for row_cells in rows:
            table_html += "<tr>"
            for cell in row_cells:
                tag = "th" if (cell.kind == "columnHeader" or cell.kind == "rowHeader") else "td"
                cell_spans = ""
                if cell.column_span is not None and cell.column_span > 1:
                    cell_spans += f" colSpan={cell.column_span}"
                if cell.row_span is not None and cell.row_span > 1:
                    cell_spans += f" rowSpan={cell.row_span}"
                table_html += f"<{tag}{cell_spans}>{html.escape(cell.content)}</{tag}>"
            table_html += "</tr>"
        table_html += "</table></figure>"
        return table_html

    @staticmethod
    def crop_image_from_pdf_page(
        doc: pymupdf.Document, page_number: int, bbox_inches: tuple[float, float, float, float]
    ) -> bytes:
        """
        Crops a region from a given page in a PDF and returns it as an image.

        :param pdf_path: Path to the PDF file.
        :param page_number: The page number to crop from (0-indexed).
        :param bbox_inches: A tuple of (x0, y0, x1, y1) coordinates for the bounding box, in inches.
        :return: A PIL Image of the cropped area.
        """
        # Scale the bounding box to 72 DPI
        bbox_dpi = 72
        bbox_pixels = [x * bbox_dpi for x in bbox_inches]
        rect = pymupdf.Rect(bbox_pixels)
        # Assume that the PDF has 300 DPI,
        # and use the matrix to convert between the 2 DPIs
        page_dpi = 300
        page = doc.load_page(page_number)
        pix = page.get_pixmap(matrix=pymupdf.Matrix(page_dpi / bbox_dpi, page_dpi / bbox_dpi), clip=rect)

        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
        bytes_io = io.BytesIO()
        img.save(bytes_io, format="PNG")
        return bytes_io.getvalue()


In [None]:
def setup_file_processors(
    azure_credential: AsyncTokenCredential,
    document_intelligence_service: Union[str, None],
    document_intelligence_key: Union[str, None] = None,
    local_pdf_parser: bool = False,
    local_html_parser: bool = False,
    search_images: bool = False,
    use_content_understanding: bool = False,
    content_understanding_endpoint: Union[str, None] = None,
):
    sentence_text_splitter = SentenceTextSplitter()

    doc_int_parser: Optional[DocumentAnalysisParser] = None
    # check if Azure Document Intelligence credentials are provided
    if document_intelligence_service is not None:
        documentintelligence_creds: Union[AsyncTokenCredential, AzureKeyCredential] = (
            azure_credential if document_intelligence_key is None else AzureKeyCredential(document_intelligence_key)
        )
        doc_int_parser = DocumentAnalysisParser(
            endpoint=f"https://{document_intelligence_service}.cognitiveservices.azure.com/",
            credential=documentintelligence_creds,
            use_content_understanding=use_content_understanding,
            content_understanding_endpoint=content_understanding_endpoint,
        )

    pdf_parser: Optional[Parser] = None
    if local_pdf_parser or document_intelligence_service is None:
        pdf_parser = LocalPdfParser()
    elif document_intelligence_service is not None:
        pdf_parser = doc_int_parser
    else:
        logger.warning("No PDF parser available")

    html_parser: Optional[Parser] = None
    if local_html_parser or document_intelligence_service is None:
        html_parser = LocalHTMLParser()
    elif document_intelligence_service is not None:
        html_parser = doc_int_parser
    else:
        logger.warning("No HTML parser available")

    # These file formats can always be parsed:
    file_processors = {
        ".json": FileProcessor(JsonParser(), SimpleTextSplitter()),
        ".md": FileProcessor(TextParser(), sentence_text_splitter),
        ".txt": FileProcessor(TextParser(), sentence_text_splitter),
        ".csv": FileProcessor(CsvParser(), sentence_text_splitter),
    }
    # These require either a Python package or Document Intelligence
    if pdf_parser is not None:
        file_processors.update({".pdf": FileProcessor(pdf_parser, sentence_text_splitter)})
    if html_parser is not None:
        file_processors.update({".html": FileProcessor(html_parser, sentence_text_splitter)})
    # These file formats require Document Intelligence
    if doc_int_parser is not None:
        file_processors.update(
            {
                ".docx": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".pptx": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".xlsx": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".png": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".jpg": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".jpeg": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".tiff": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".bmp": FileProcessor(doc_int_parser, sentence_text_splitter),
                ".heic": FileProcessor(doc_int_parser, sentence_text_splitter),
            }
        )
    return file_processors


In [None]:
# Set up ingester
file_processors = setup_file_processors(
    azure_credential=azure_credential,
    document_intelligence_service=os.getenv("AZURE_DOCUMENTINTELLIGENCE_SERVICE"),
    local_pdf_parser=os.getenv("USE_LOCAL_PDF_PARSER", "").lower() == "true",
    local_html_parser=os.getenv("USE_LOCAL_HTML_PARSER", "").lower() == "true",
    search_images=USE_GPT4V,
)