# Text Data Processing

This notebook demonstrates an end-to-end pipeline for processing text data using Databricks, focusing on:

1. Loading text files (PDF, HTML, DOCX) from a source volume
2. Chunking the parsed text into manageable segments
3. Creating a vector search index for efficient retrieval

The notebook is structured in sequential steps, from data ingestion through to indexing, making it easy to understand and modify for your specific text processing needs. Each major section is clearly commented and includes relevant configuration parameters.

Key components used:
- Databricks Vector Search
- BGE embedding model for text vectorization


#### Setup and Configuration
 
This section defines key configuration parameters used throughout the notebook:
 
 - Unity Catalog settings (catalog, schema, volume names)
 - Model endpoints (Whisper AI, BGE embeddings) 
 - Delta table names
 - Vector search configuration


In [0]:
# Packages required by all code.
# Versions of Databricks code are not locked since Databricks ensures changes are backwards compatible.
# Versions of open source packages are locked since package authors often make backwards compatible changes
%pip install -qqqq -U \
  databricks-vectorsearch databricks-agents pydantic databricks-sdk mlflow mlflow-skinny `# For agent & data pipeline code` \
  pypdf==4.1.0  `# PDF parsing` \
  markdownify==0.12.1  `# HTML parsing` \
  pypandoc_binary==1.13  `# DOCX parsing` \
  transformers==4.41.1 torch==2.3.0 tiktoken==0.7.0 langchain-text-splitters==0.2.0. `# get_recursive_character_text_splitter`

# Restart to load the packages into the Python environment
dbutils.library.restartPython()

In [0]:
%run ../global_config

In [0]:
from typing import TypedDict
from datetime import datetime
import warnings
import io
import traceback
import os
from urllib.parse import urlparse

# PDF libraries
from pypdf import PdfReader

# HTML libraries
from markdownify import markdownify as md
import markdownify
import re

## DOCX libraries
import pypandoc
import tempfile

# Schema of the dict returned by `file_parser(...)`
class ParserReturnValue(TypedDict):
    """Schema of the dictionary returned by `file_parser(...)`.

    This notebook passes the class to `typed_dicts_to_spark_schema(...)` to build the
    Spark schema of the parsing UDF's output, so every field declared here becomes a
    field of the parsed-documents struct.
    """

    # DO NOT CHANGE THESE NAMES - these are required by Evaluation & Framework
    # Parsed content of the document
    doc_content: str  # do not change this name
    # The status of whether the parser succeeds or fails, used to exclude failed files downstream
    # (`apply_parsing_udf` keeps only rows whose status equals "SUCCESS")
    parser_status: str  # do not change this name
    # Unique ID of the document
    doc_uri: str  # do not change this name

    # OK TO CHANGE THESE NAMES
    # Optionally, you can add additional metadata fields here
    # Set by `file_parser` to "test" / "not test" depending on the document content
    example_metadata: str
    # Modification time of the source file, as passed to `file_parser`
    last_modified: datetime


# Parser function.  Replace this function to provide custom parsing logic.
def file_parser(
    raw_doc_contents_bytes: bytes,
    doc_path: str,
    modification_time: datetime,
    doc_bytes_length: int,
) -> ParserReturnValue:
    """
    Parses the content of a PDF, HTML or DOCX document into a string.

    The parser is chosen from the file extension of `doc_path` (compared case-insensitively):
    - `.pdf`: text is extracted page by page with PyPDF and joined with newlines
    - `.html`: the UTF-8 decoded markup is converted to Markdown with markdownify
    - `.docx`: the bytes are written to a temporary file and converted to Markdown with pandoc

    Parameters:
    - raw_doc_contents_bytes (bytes): The raw bytes of the document to be parsed (set by Spark when loading the file)
    - doc_path (str): The DBFS path of the document, used to verify the file extension (set by Spark when loading the file)
    - modification_time (timestamp): The last modification time of the document (set by Spark when loading the file)
    - doc_bytes_length (long): The size of the document in bytes (set by Spark when loading the file)

    Returns:
    - ParserReturnValue: A dictionary containing the parsed document content and the status of the parsing operation.
      On success it holds 'doc_content', 'parser_status' == "SUCCESS", 'doc_uri', 'example_metadata' and
      'last_modified'. On any failure (including an unsupported extension) nothing is raised: a warning is
      issued and only 'doc_content' (empty) and 'parser_status' ("ERROR: ...") are returned.
    """
    try:
        _, file_extension = os.path.splitext(doc_path)
        # Normalise the case so that e.g. `Report.PDF` is handled like `report.pdf`.
        file_extension = file_extension.lower()

        if file_extension == ".pdf":
            reader = PdfReader(io.BytesIO(raw_doc_contents_bytes))
            page_texts = [page.extract_text() for page in reader.pages]
            doc_content = "\n".join(page_texts)
        elif file_extension == ".html":
            html_content = raw_doc_contents_bytes.decode("utf-8")
            markdown_contents = markdownify.markdownify(
                html_content.strip(), heading_style=markdownify.ATX
            )
            # Collapse runs of 3+ newlines into a single blank line.
            doc_content = re.sub(r"\n{3,}", "\n\n", markdown_contents.strip())
        elif file_extension == ".docx":
            with tempfile.NamedTemporaryFile(delete=True) as temp_file:
                temp_file.write(raw_doc_contents_bytes)
                # pandoc opens the file by path, so the buffered bytes must be on disk
                # before the conversion starts.
                temp_file.flush()
                docx_markdown = pypandoc.convert_file(
                    temp_file.name, "markdown", format="docx"
                )
            doc_content = docx_markdown.strip()
        else:
            raise Exception(f"No supported parser for {doc_path}")

        parsed_document = {
            "doc_content": doc_content,
            "parser_status": "SUCCESS",
        }

        # Extract the required doc_uri
        # convert from `dbfs:/Volumes/catalog/schema/pdf_docs/filename.pdf` to `Volumes/catalog/schema/pdf_docs/filename.pdf`
        parsed_document["doc_uri"] = urlparse(doc_path).path.lstrip("/")

        # Sample metadata extraction logic
        if "test" in parsed_document["doc_content"]:
            parsed_document["example_metadata"] = "test"
        else:
            parsed_document["example_metadata"] = "not test"

        # Add the modified time
        parsed_document["last_modified"] = modification_time

        return parsed_document

    except Exception as e:
        # Best effort by design: a single unreadable file must not fail the whole Spark job.
        status = f"An error occurred: {e}\n{traceback.format_exc()}"
        warnings.warn(status)
        return {
            "doc_content": "",
            "parser_status": f"ERROR: {status}",
        }

#### `typed_dicts_to_spark_schema`

`typed_dicts_to_spark_schema` converts multiple TypedDicts into a Spark schema, allowing for the combination of multiple TypedDicts into a single Spark DataFrame schema. This function enables the resulting Delta Table to reflect the schema defined in `ParserReturnValue`.

Arguments:
- `*typed_dicts`: A variable number of TypedDict classes to be converted.

Returns:
- `StructType`: A Spark schema represented as a StructType object, which is a collection of StructField objects derived from the provided TypedDicts.


In [0]:
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    DoubleType,
    BooleanType,
    ArrayType,
    TimestampType,
    DateType,
)
from typing import TypedDict, get_type_hints, List
from datetime import datetime, date, time


def typed_dict_to_spark_fields(typed_dict: type[TypedDict]) -> List[StructField]:
    """
    Converts a TypedDict into a list of Spark StructField objects.

    This function maps Python types defined in a TypedDict to their corresponding
    Spark SQL data types, facilitating the creation of a Spark DataFrame schema
    from Python type annotations. Every generated field is nullable.

    Supported annotations: str, int, float, bool, datetime, date, list / List
    (array of strings), List[T] for any supported T, and Optional[T] / `T | None`
    for any supported T (mapped to the type of T).

    Parameters:
    - typed_dict (type[TypedDict]): The TypedDict class to be converted.

    Returns:
    - List[StructField]: One StructField per annotated key, in declaration order.

    Raises:
    - ValueError: If an unsupported type is encountered or if dictionary types are used.
    """
    # Imported here so the function does not rely on names missing from the cell's imports.
    import types
    from typing import Union, get_args, get_origin

    # Mapping of type names to Spark type objects
    type_mapping = {
        str: StringType(),
        int: IntegerType(),
        float: DoubleType(),
        bool: BooleanType(),
        list: ArrayType(StringType()),  # Default to StringType for arrays
        datetime: TimestampType(),
        date: DateType(),
    }

    # `Optional[T]` has origin `typing.Union`; `T | None` has origin `types.UnionType`
    # on interpreters that provide it.
    union_origins = (Union, getattr(types, "UnionType", Union))

    def get_spark_type(value_type):
        """
        Helper function to map a Python type to a Spark SQL data type.

        Parameters:
        - value_type: The Python type to be converted.

        Returns:
        - DataType: The corresponding Spark SQL data type.

        Raises:
        - ValueError: If the type is unsupported or if dictionary types are used.
        """
        if value_type in type_mapping:
            return type_mapping[value_type]

        origin = get_origin(value_type)
        args = get_args(value_type)

        if origin in union_origins:
            # Only the Optional form (exactly one non-None member) has an obvious Spark type.
            members = [arg for arg in args if arg is not type(None)]
            if len(members) == 1:
                return get_spark_type(members[0])
            raise ValueError(f"Unsupported type: {value_type}")

        if origin is list:
            # Un-parameterised `List` carries no element type: use the same default as `list`.
            if not args:
                return type_mapping[list]
            return ArrayType(get_spark_type(args[0]))

        if origin is dict:
            # Handle Dict[type, type] types (not fully supported)
            raise ValueError("Dict types are not fully supported")

        raise ValueError(f"Unsupported type: {value_type}")

    # Get the type hints for the TypedDict
    type_hints = get_type_hints(typed_dict)

    # Convert the type hints into a list of StructField objects
    return [
        StructField(key, get_spark_type(value), True)
        for key, value in type_hints.items()
    ]

In [0]:
def typed_dicts_to_spark_schema(*typed_dicts: type[TypedDict]) -> StructType:
    """
    Builds one Spark schema out of any number of TypedDict classes.

    The fields of each TypedDict are produced by `typed_dict_to_spark_fields` and
    concatenated in the order the classes are passed in.

    Parameters:
    - *typed_dicts: Variable number of TypedDict classes to be converted.

    Returns:
    - StructType: A Spark schema holding the StructField objects of all given TypedDicts.
    """
    combined_fields = [
        spark_field
        for td_class in typed_dicts
        for spark_field in typed_dict_to_spark_fields(td_class)
    ]
    return StructType(combined_fields)

In [0]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken
from typing import Callable, Tuple, Optional
import os
import re
from databricks.sdk import WorkspaceClient

# Constants
# Directory handed to `AutoTokenizer.from_pretrained(..., cache_dir=...)` below.
HF_CACHE_DIR = "/tmp/hf_cache/"

# Embedding Models Configuration
# Keyed by the model name that `extract_endpoint_type` reads from a serving endpoint
# (or that a caller passes as `embedding_model_name`). Each entry holds:
#   - "tokenizer": a zero-argument callable that builds the tokenizer when it is called,
#     so nothing is loaded while this dictionary is being defined
#   - "context_window": the token limit that `validate_chunk_size` checks chunk sizes against
#   - "type": "SENTENCE_TRANSFORMER" or "OPENAI"; selects the splitter constructor in
#     `get_recursive_character_text_splitter`
EMBEDDING_MODELS = {
    "gte-large-en-v1.5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "Alibaba-NLP/gte-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 8192,
        "type": "SENTENCE_TRANSFORMER",
    },
    "bge-large-en-v1.5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 512,
        "type": "SENTENCE_TRANSFORMER",
    },
    # Same settings as "bge-large-en-v1.5", registered under an underscore-separated key.
    "bge_large_en_v1_5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 512,
        "type": "SENTENCE_TRANSFORMER",
    },
    "text-embedding-ada-002": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-ada-002"),
        "type": "OPENAI",
    },
    "text-embedding-3-small": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-small"),
        "type": "OPENAI",
    },
    "text-embedding-3-large": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-large"),
        "type": "OPENAI",
    },
}


def get_workspace_client() -> WorkspaceClient:
    """Return a new `WorkspaceClient` built with no explicit arguments.

    A fresh instance is created on every call. In this notebook it is used by
    `detect_fmapi_embedding_model_type` to look up a serving endpoint by name.
    """
    return WorkspaceClient()


def get_embedding_model_config(endpoint_type: str) -> Optional[dict]:
    """
    Look up the `EMBEDDING_MODELS` entry registered under `endpoint_type`.

    Returns the entry itself (not a copy), or None when no such key exists.
    """
    model_config = EMBEDDING_MODELS.get(endpoint_type)
    return model_config


def extract_endpoint_type(llm_endpoint) -> Optional[str]:
    """
    Extract the endpoint type from the given llm_endpoint object.

    Only the first entry of `llm_endpoint.config.served_entities` is inspected. Its
    `external_model.name` is preferred; `foundation_model.name` is the fallback.

    Returns:
        The model name, or None when the endpoint has no config, no served entities,
        or the first served entity names neither an external nor a foundation model.
    """
    config = getattr(llm_endpoint, "config", None)
    served_entities = getattr(config, "served_entities", None)
    # A missing, None or empty list used to escape as TypeError / IndexError.
    if not served_entities:
        return None

    first_entity = served_entities[0]
    for model_attr in ("external_model", "foundation_model"):
        model = getattr(first_entity, model_attr, None)
        model_name = getattr(model, "name", None)
        if model_name is not None:
            return model_name
    return None


def detect_fmapi_embedding_model_type(
    model_serving_endpoint: str,
) -> Tuple[Optional[str], Optional[dict]]:
    """
    Detects the embedding model type and configuration for the given endpoint.

    The endpoint is fetched by name through the workspace client and its model name is
    read with `extract_endpoint_type`. A failed lookup is not fatal: a warning describing
    the failure is issued and the endpoint is treated as "type unknown", so callers can
    fall back to an explicitly supplied model name.

    Returns a tuple of (endpoint_type, embedding_config). `endpoint_type` is None when it
    could not be determined; `embedding_config` is None when the type is unknown or has
    no entry in `EMBEDDING_MODELS`.
    """
    # Local import: this cell's import lines do not bring `warnings` into scope themselves.
    import warnings

    client = get_workspace_client()

    try:
        llm_endpoint = client.serving_endpoints.get(name=model_serving_endpoint)
        endpoint_type = extract_endpoint_type(llm_endpoint)
    except Exception as e:
        # Keep the best-effort behaviour, but leave a trace of why detection failed.
        warnings.warn(
            f"Could not detect the embedding model behind endpoint "
            f"`{model_serving_endpoint}`: {e}"
        )
        endpoint_type = None

    embedding_config = (
        get_embedding_model_config(endpoint_type) if endpoint_type else None
    )
    return (endpoint_type, embedding_config)


def validate_chunk_size(chunk_spec: dict):
    """
    Validate the chunk size and overlap settings in chunk_spec.

    `chunk_spec` must contain the integer keys "chunk_size_tokens",
    "chunk_overlap_tokens" and "context_window".

    Raises ValueError if any condition is violated:
    - the chunk size is not positive, or the overlap is negative
    - chunk size + overlap exceeds the context window
    - the overlap is larger than the chunk size
    """
    chunk_size = chunk_spec["chunk_size_tokens"]
    chunk_overlap = chunk_spec["chunk_overlap_tokens"]
    context_window = chunk_spec["context_window"]

    if chunk_size <= 0:
        raise ValueError(
            f"`chunk_size_tokens` must be a positive number of tokens, got {chunk_size}."
        )

    if chunk_overlap < 0:
        raise ValueError(
            f"`chunk_overlap_tokens` must not be negative, got {chunk_overlap}."
        )

    if (chunk_overlap + chunk_size) > context_window:
        raise ValueError(
            f'Proposed chunk_size of {chunk_size} + overlap of {chunk_overlap} '
            f'is {chunk_overlap + chunk_size} which is greater than context '
            f'window of {context_window} tokens.'
        )

    if chunk_overlap > chunk_size:
        # The overlap is the offending value here, so that is what the message must point at.
        raise ValueError(
            f'Proposed `chunk_overlap_tokens` of {chunk_overlap} is greater than the '
            f'`chunk_size_tokens` of {chunk_size}. Reduce `chunk_overlap_tokens` or '
            f'increase `chunk_size_tokens`.'
        )


def get_recursive_character_text_splitter(
    model_serving_endpoint: str,
    embedding_model_name: str = None,
    chunk_size_tokens: int = None,
    chunk_overlap_tokens: int = 0,
) -> Callable[[str], list[str]]:
    """
    Build a token-aware text splitting function for an embedding model.

    The model is detected from `model_serving_endpoint`. When detection yields no known
    configuration, the caller-supplied `embedding_model_name` is looked up in
    `EMBEDDING_MODELS` instead.

    Parameters:
    - model_serving_endpoint: Name of the serving endpoint hosting the embedding model.
    - embedding_model_name: Key of `EMBEDDING_MODELS` to use when detection fails.
    - chunk_size_tokens: Chunk size in tokens; defaults to the model's context window.
    - chunk_overlap_tokens: Overlap between consecutive chunks, in tokens.

    Returns:
    - A function taking a document string and returning its list of chunk strings.

    Raises:
    - ValueError: If no model configuration can be resolved, or if the chunk size /
      overlap fail `validate_chunk_size`.
    """
    # Name reported in the error message; starts as the caller's value so that a failed
    # detection can no longer overwrite it.
    resolved_model_name = embedding_model_name

    try:
        # Detect the embedding model and its configuration
        detected_name, detected_spec = detect_fmapi_embedding_model_type(
            model_serving_endpoint
        )

        if detected_name is not None and detected_spec is not None:
            resolved_model_name = detected_name
            base_spec = detected_spec
        else:
            # Fall back to using provided embedding_model_name
            resolved_model_name = embedding_model_name or detected_name
            base_spec = EMBEDDING_MODELS.get(embedding_model_name)
            if base_spec is None:
                raise KeyError(resolved_model_name)

        # Work on a copy: the entry in `EMBEDDING_MODELS` is shared, and writing the chunk
        # settings into it would leak them into every other splitter built from that entry.
        chunk_spec = dict(base_spec)

        # Update chunk specification based on provided parameters
        chunk_spec["chunk_size_tokens"] = (
            chunk_size_tokens or chunk_spec["context_window"]
        )
        chunk_spec["chunk_overlap_tokens"] = chunk_overlap_tokens

        # Validate chunk size and overlap
        validate_chunk_size(chunk_spec)

        print(f'Chunk size in tokens: {chunk_spec["chunk_size_tokens"]}')
        print(f'Chunk overlap in tokens: {chunk_spec["chunk_overlap_tokens"]}')
        context_usage = (
            round(
                (chunk_spec["chunk_size_tokens"] + chunk_spec["chunk_overlap_tokens"])
                / chunk_spec["context_window"],
                2,
            )
            * 100
        )
        print(
            f'Using {context_usage}% of the {chunk_spec["context_window"]} token context window.'
        )

    except KeyError:
        raise ValueError(
            f"Embedding model `{resolved_model_name}` not found. Available models: {EMBEDDING_MODELS.keys()}"
        )

    def _build_splitter():
        """Create the splitter matching the model type recorded in `chunk_spec`."""
        tokenizer = chunk_spec["tokenizer"]()
        if chunk_spec["type"] == "SENTENCE_TRANSFORMER":
            return RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer,
                chunk_size=chunk_spec["chunk_size_tokens"],
                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
            )
        if chunk_spec["type"] == "OPENAI":
            return RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                tokenizer.name,
                chunk_size=chunk_spec["chunk_size_tokens"],
                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
            )
        raise ValueError(f"Unsupported model type: {chunk_spec['type']}")

    # Built on the first call and then reused, so the tokenizer is not re-created for
    # every document that is split.
    splitter = None

    def _recursive_character_text_splitter(text: str) -> list[str]:
        nonlocal splitter
        if splitter is None:
            splitter = _build_splitter()
        return splitter.split_text(text)

    return _recursive_character_text_splitter

In [0]:
import traceback
from datetime import datetime
from typing import Any, Callable, TypedDict, Dict
import os
from IPython.display import display_markdown
import warnings
import pyspark.sql.functions as func
from pyspark.sql.types import StructType
from pyspark.sql import DataFrame, SparkSession


def _parse_and_extract(
    raw_doc_contents_bytes: bytes,
    modification_time: datetime,
    doc_bytes_length: int,
    doc_path: str,
    parse_file_udf: Callable[[[dict, Any]], str],
) -> Dict[str, Any]:
    """Run the parser on one file and normalise any failure into an error record.

    `parse_file_udf` is called with keyword arguments. Its result is returned unchanged
    when its "parser_status" is "SUCCESS". Any other status, or any exception, produces a
    warning and a dictionary with empty "doc_content" / "doc_uri" and the failure text
    (message plus traceback) in "parser_status".
    """
    try:
        # Run the parser
        parsed = parse_file_udf(
            raw_doc_contents_bytes=raw_doc_contents_bytes,
            doc_path=doc_path,
            modification_time=modification_time,
            doc_bytes_length=doc_bytes_length,
        )

        reported_status = parsed.get("parser_status")
        if reported_status != "SUCCESS":
            # Route a reported failure through the same handler as a raised one.
            raise Exception(parsed.get("parser_status"))

        return parsed

    except Exception as err:
        failure = f"An error occurred: {err}\n{traceback.format_exc()}"
        warnings.warn(failure)
        error_record = {
            "doc_content": "",
            "doc_uri": "",
            "parser_status": failure,
        }
        return error_record


def _get_parser_udf(
    parse_file_udf: Callable[[[dict, Any]], str],
    spark_dataframe_schema: StructType,
):
    """Wrap the file parser in a Spark UDF so files are parsed in parallel.

    Arguments:
      - parse_file_udf: A function that takes the raw file and returns the parsed text.
      - spark_dataframe_schema: The resulting schema of the document delta table

    Returns:
      A Spark UDF taking (content bytes, modification time, byte length, path) columns and
      returning a struct with `spark_dataframe_schema`.
    """

    def _parse_row(raw_doc_contents_bytes, modification_time, doc_bytes_length, doc_path):
        # Delegates to the helper that also turns parser failures into error records.
        return _parse_and_extract(
            raw_doc_contents_bytes,
            modification_time,
            doc_bytes_length,
            doc_path,
            parse_file_udf,
        )

    return func.udf(_parse_row, returnType=spark_dataframe_schema)

def load_files_to_df(
    spark: SparkSession,
    source_path: str) -> DataFrame:
    """
    Load files from a directory into a Spark DataFrame.
    Each row in the DataFrame will contain the path, length, and content of the file; for more
    details, see https://spark.apache.org/docs/latest/sql-data-sources-binaryFile.html

    Sub-directories are searched as well (`recursiveFileLookup`).

    Raises:
        ValueError: If `source_path` does not exist.
        Exception: If no file was found under `source_path`.
    """

    if not os.path.exists(source_path):
        raise ValueError(
            f"{source_path} passed to `load_files_to_df` does not exist."
        )

    # Load the raw files
    raw_files_df = (
        spark.read.format("binaryFile").option("recursiveFileLookup", "true")
        .load(source_path)
    )

    # Check that files were present and loaded.
    # Counted once and reused: every `count()` triggers another scan of the source.
    num_files = raw_files_df.count()
    if num_files == 0:
        raise Exception(f"`{source_path}` does not contain any files.")

    print(f"Found {num_files} files in {source_path}.")
    raw_files_df.show()
    return raw_files_df


def apply_parsing_udf(raw_files_df: DataFrame, parse_file_udf: Callable[[[dict, Any]], str], parsed_df_schema: StructType) -> DataFrame:
    """
    Apply a file-parsing UDF to a DataFrame whose rows correspond to file content/metadata loaded via
    https://spark.apache.org/docs/latest/sql-data-sources-binaryFile.html
    Returns a DataFrame with the parsed content and metadata.

    Only rows whose `parser_status` is "SUCCESS" are returned, flattened to the fields of
    `parsed_df_schema`. Rows with errors, and successfully parsed rows with empty content,
    are reported and shown for review.

    Raises:
        ValueError: If every document failed to parse, or if every document parsed to
            empty content.
    """
    print("Running parsing & metadata extraction UDF in spark...")

    parser_udf = _get_parser_udf(parse_file_udf, parsed_df_schema)

    # Run the parsing
    parsed_files_staging_df = raw_files_df.withColumn(
        "parsing", parser_udf("content", "modificationTime", "length", "path")
    ).drop("content")

    status_col = func.col("parsing.parser_status")

    # Counting re-runs the parsing UDF, so the total is computed at most once and only
    # when one of the checks below actually needs it.
    total_docs = None

    # Check and warn on any errors
    errors_df = parsed_files_staging_df.filter(status_col != "SUCCESS")

    num_errors = errors_df.count()
    if num_errors > 0:
        display_markdown(
            f"### {num_errors} documents had parse errors. Please review.", raw=True
        )
        errors_df.show()

        total_docs = parsed_files_staging_df.count()
        if num_errors == total_docs:
            raise ValueError(
                "All documents produced an error during parsing. Please review."
            )

    # Check and warn on documents that parsed successfully but produced no text.
    # (Error rows always carry empty content and were reported above already.)
    empty_df = parsed_files_staging_df.filter(
        (status_col == "SUCCESS") & (func.col("parsing.doc_content") == "")
    )
    num_empty_content = empty_df.count()
    if num_empty_content > 0:
        display_markdown(
            f"### {num_empty_content} documents have no content. Please review.", raw=True
        )
        empty_df.show()

        if total_docs is None:
            total_docs = parsed_files_staging_df.count()
        if num_empty_content == total_docs:
            raise ValueError("All documents are empty. Please review.")

    # Filter for successfully parsed files
    # Change the schema to the resulting schema
    resulting_fields = [field.name for field in parsed_df_schema.fields]

    parsed_files_df = parsed_files_staging_df.filter(
        parsed_files_staging_df.parsing.parser_status == "SUCCESS"
    )

    parsed_files_df.show()
    parsed_files_df = parsed_files_df.select(
        *[func.col(f"parsing.{field}").alias(field) for field in resulting_fields]
    )
    return parsed_files_df


In [0]:
from typing import Dict, Any
import logging
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, size, split, cast
from pyspark.sql.types import StructType, DoubleType

# Configure logging
# Request INFO-level output and create the module-level logger that
# `process_text_data` below reports its progress through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def process_text_data(
    spark: SparkSession,
    source_path: str,
    parse_file_udf: Any,
    parsed_df_schema: StructType,
    output_table: str
) -> None:
    """
    Process text files and create a standardized table with renamed columns.

    The files under `source_path` are loaded, parsed with `parse_file_udf`, and written
    to `output_table`, replacing its contents and schema. The written columns are
    `doc_content`, `parser_status`, `path` (from `doc_uri`), `modificationTime`
    (from `last_modified`), `modality` (constant "docs") and `length` (word count of
    `doc_content`, as a double).

    Args:
        spark: SparkSession
        source_path: Path to the source text files
        parse_file_udf: UDF for parsing files
        parsed_df_schema: Schema for the parsed DataFrame
        output_table: Full name of the output table (catalog.schema.table)
    """
    logger.info(f"Processing text files from {source_path}")

    # Load raw files
    raw_files_df = load_files_to_df(
        spark=spark,
        source_path=source_path,
    )

    # Apply parsing UDF
    parsed_files_df = apply_parsing_udf(
        raw_files_df=raw_files_df,
        parse_file_udf=parse_file_udf,
        parsed_df_schema=parsed_df_schema
    )

    # First, create a DataFrame with the word count calculation
    word_count_df = parsed_files_df.withColumn(
        "word_count",
        size(split(col("doc_content"), "\\s+"))
    )

    # Now create the final DataFrame with all columns
    processed_df = word_count_df.select(
        col("doc_content"),  # Keep as is
        col("parser_status"),  # Keep as is
        col("doc_uri").alias("path"),  # Rename doc_uri to path
        col("last_modified").alias("modificationTime"),  # Rename last_modified to modificationTime
        lit("docs").alias("modality"),  # Add modality column with value "docs"
        col("word_count").cast("double").alias("length")  # Cast word count to double and rename to length
    )

    # Write to Delta table
    processed_df.write.mode("overwrite").option(
        "overwriteSchema", "true"
    ).saveAsTable(output_table)

    # Report on what was written by reading the table back. Counting or displaying
    # `processed_df` itself would re-run file loading and the parsing UDF each time.
    saved_df = spark.read.table(output_table)

    logger.info(f"Processed {saved_df.count()} documents and saved to {output_table}")

    # Display for debugging
    saved_df.display()

# Example usage:
"""
from your_module import load_files_to_df, apply_parsing_udf, typed_dicts_to_spark_schema, ParserReturnValue

# Configure the processor
process_text_data(
    spark=spark,
    source_path=f'/Volumes/{UC_CATALOG_NAME}/{UC_SCHEMA_NAME}/{UC_VOLUME_NAME}/text_data/',
    parse_file_udf=file_parser,
    parsed_df_schema=typed_dicts_to_spark_schema(ParserReturnValue),
    output_table=f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{DOCS_DATA_TABLE_NAME}"
)
""" 

In [0]:
# Parse every file under the `text_data/` folder of the configured volume with
# `file_parser` and write the result to the docs table.
# The UC_* names and DOCS_DATA_TABLE_NAME are not defined in this notebook —
# presumably they come from the `global_config` notebook run earlier; verify there.
process_text_data(
    spark=spark,
    source_path=f'/Volumes/{UC_CATALOG_NAME}/{UC_SCHEMA_NAME}/{UC_VOLUME_NAME}/text_data/',
    parse_file_udf=file_parser,
    # Spark schema of the parser output, derived from the `ParserReturnValue` TypedDict
    parsed_df_schema=typed_dicts_to_spark_schema(ParserReturnValue),
    output_table=f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{DOCS_DATA_TABLE_NAME}"
)

In [0]:
from typing import Literal, Optional, Any, Callable
from databricks.vector_search.client import VectorSearchClient
from pyspark.sql.functions import explode
import pyspark.sql.functions as func
from typing import Callable
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken
from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType
from pyspark.sql import SparkSession
import logging

# Configure logging
# Request INFO-level output and (re)bind the module-level logger used by
# `compute_chunks` and `process_video_chunks` in this cell.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def standardize_timestamp_format(df):
    """
    Convert the `modificationTime` column of a DataFrame to a Spark timestamp.

    Args:
        df: Input DataFrame
    Returns:
        The DataFrame with `modificationTime` passed through `to_timestamp`, or the
        input DataFrame itself when it has no `modificationTime` column.
    """
    if "modificationTime" not in df.columns:
        return df

    as_timestamp = func.to_timestamp(func.col("modificationTime"))
    return df.withColumn("modificationTime", as_timestamp)

def compute_chunks(
    docs_table: str,
    doc_column: str,
    chunk_fn: Callable[[str], list[str]],
    propagate_columns: list[str],
    chunked_docs_table: str,
    modality: str,
) -> str:
    """
    Compute chunks from a document table and append them to an existing chunked table.

    Each document is split with `chunk_fn` and exploded into one row per chunk. A
    `chunk_id` (MD5 of the chunk text) is added. If the target table already exists,
    only chunks whose `chunk_id` is not yet present are appended; otherwise the table is
    created.

    The `modality` column is taken from the source table when it has one, and is filled
    with the `modality` argument otherwise.

    Args:
        docs_table: Source table containing documents
        doc_column: Column name containing the text to chunk
        chunk_fn: Function to split text into chunks
        propagate_columns: List of columns to propagate from the docs table to chunks table
        chunked_docs_table: Target table for storing chunks
        modality: Type of content (e.g., 'video', 'audio', 'pdf'); used only when the
            source table has no `modality` column
    Returns:
        str: Name of the chunked table
    """
    logger.info(f"Computing chunks for `{docs_table}`...")

    # Initialize Spark session if not already available
    spark = SparkSession.builder.getOrCreate()

    # Read source documents
    raw_docs = spark.read.table(docs_table)

    # Check if modality column exists in source table
    source_has_modality = "modality" in raw_docs.columns

    # Create UDF for chunking
    chunk_udf = func.udf(
        chunk_fn,
        returnType=ArrayType(StringType()),
    )

    # Process documents into chunks
    chunked_array_docs = raw_docs.withColumn(
        "content_chunked", chunk_udf(doc_column)
    ).drop(doc_column)

    # `modality` is handled separately from the other propagated columns so that it is
    # present exactly once, whichever of the two sources it comes from.
    columns_to_propagate = [name for name in propagate_columns if name != "modality"]
    selected_columns = list(columns_to_propagate)
    if source_has_modality:
        # Carry the source's own value; previously the column was filtered out here and
        # never added back, so the output had no `modality` at all.
        selected_columns.append("modality")

    chunked_docs = chunked_array_docs.select(
        *selected_columns, explode("content_chunked").alias("content_chunked")
    )

    # Add chunk_id
    # NOTE(review): the id depends on the chunk text only, so identical text occurring in
    # two documents gets the same id and is treated as already present on append.
    chunks_with_ids = chunked_docs.withColumn(
        "chunk_id", func.md5(func.col("content_chunked"))
    )

    # Add modality column if it doesn't exist in source
    if not source_has_modality:
        chunks_with_ids = chunks_with_ids.withColumn("modality", func.lit(modality))

    # Check if target table exists. Any failure to read it is treated as "does not exist".
    try:
        existing_chunks = spark.read.table(chunked_docs_table)
        table_exists = True
    except Exception:
        existing_chunks = None
        table_exists = False

    # Standardize timestamp format
    chunks_with_ids = standardize_timestamp_format(chunks_with_ids)

    # Reorder columns for better display
    final_columns = ["chunk_id", "content_chunked", "modality", *columns_to_propagate]
    if table_exists and "modality" not in existing_chunks.columns:
        # An existing table without `modality` keeps its schema: appending an extra
        # column would otherwise be rejected as a schema mismatch.
        logger.warning(
            f"{chunked_docs_table} has no `modality` column; appending chunks without it"
        )
        final_columns.remove("modality")

    chunks_with_ids = chunks_with_ids.select(*final_columns)

    if table_exists:
        # Get existing chunk IDs
        existing_ids = existing_chunks.select("chunk_id").distinct()

        # Filter out chunks that already exist
        new_chunks = chunks_with_ids.join(
            existing_ids,
            chunks_with_ids.chunk_id == existing_ids.chunk_id,
            "left_anti"
        )

        # Each count re-runs the chunking UDF, so compute each figure once and reuse it.
        num_total_chunks = chunks_with_ids.count()
        num_new_chunks = new_chunks.count()
        logger.info(f"Found {num_total_chunks} total chunks, {num_new_chunks} new chunks")

        # Append only new chunks
        if num_new_chunks > 0:
            new_chunks.write.mode("append").saveAsTable(chunked_docs_table)
            logger.info(f"Appended {num_new_chunks} new chunks to {chunked_docs_table}")
        else:
            logger.info("No new chunks to append")
    else:
        # Create new table if it doesn't exist
        chunks_with_ids.write.mode("overwrite").option(
            "overwriteSchema", "true"
        ).saveAsTable(chunked_docs_table)
        # Count the rows of the written table rather than re-running the chunking.
        num_written = spark.read.table(chunked_docs_table).count()
        logger.info(f"Created new table {chunked_docs_table} with {num_written} chunks")

    return chunked_docs_table

# Example usage code
def process_video_chunks():
    """
    Example: chunk the transcripts of the audio/video data table.

    Builds a splitter for the configured embedding endpoint (384-token chunks with a
    128-token overlap), chunks the `transcript_text` column of the source table with
    `compute_chunks`, and returns the resulting chunk table as a DataFrame.
    """
    logger.info("Starting video chunk processing...")

    # Configure the chunker
    chunk_fn = get_recursive_character_text_splitter(
        model_serving_endpoint=EMBEDDING_MODEL_ENDPOINT,
        chunk_size_tokens=384,
        chunk_overlap_tokens=128,
    )

    # Get source table schema
    source_table = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{AUDIO_DATA_TABLE_NAME}"
    source_schema = spark.table(source_table).schema
    column_names = [field.name for field in source_schema]

    # Log source table columns
    logger.info(f"Source table columns: {column_names}")

    # Propagate everything except the text being chunked and the chunk counter
    # (so e.g. name and modality are kept).
    excluded_columns = ("transcript_text", "chunk_count")
    propagate_columns = [
        name for name in column_names if name not in excluded_columns
    ]

    logger.info(f"Propagating columns: {propagate_columns}")

    # Process chunks
    chunked_docs_table = compute_chunks(
        docs_table=source_table,
        doc_column="transcript_text",
        chunk_fn=chunk_fn,
        propagate_columns=propagate_columns,
        chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE,
        modality="video",
    )

    # Display results
    result_df = spark.read.table(chunked_docs_table)
    logger.info(f"Chunked table schema: {result_df.schema}")
    logger.info(f"Number of chunks created: {result_df.count()}")

    return result_df

# To run the processing:
# result_df = process_video_chunks()
# display(result_df)

# Example usage:
"""
# Define chunking function
def chunk_text(text: str) -> list[str]:
    # Your chunking logic here
    pass

# Compute chunks for video transcripts
compute_chunks(
    docs_table="ankit_yadav.fluke_schema.video_data_text",
    doc_column="transcript_text",
    chunk_fn=chunk_text,
    propagate_columns=["name", "path", "length"],
    chunked_docs_table="ankit_yadav.fluke_schema.content_chunks",
    modality="video"
)
""" 

In [0]:
# Set up the chunking function used to split each document.
chunk_fn = get_recursive_character_text_splitter(
    model_serving_endpoint=EMBEDDING_MODEL_ENDPOINT,
    chunk_size_tokens=384,
    chunk_overlap_tokens=128,
)

# Fully qualified name of the parsed documents table (used for both the schema lookup and the chunking job).
docs_source_table = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{DOCS_DATA_TABLE_NAME}"

# Propagate every parser output column except the document body and the parser status.
# Adjust `columns_not_propagated` to change which fields flow from the docs table to the chunks table.
columns_not_propagated = ("doc_content", "parser_status")
propagate_columns = []
for field in spark.table(docs_source_table).schema.fields:
    if field.name not in columns_not_propagated:
        propagate_columns.append(field.name)

chunked_docs_table = compute_chunks(
    # The source documents table.
    docs_table=docs_source_table,
    # The column containing the documents to be chunked.
    doc_column="doc_content",
    # The chunking function that takes a string (document) and returns a list of strings (chunks).
    chunk_fn=chunk_fn,
    # Columns to propagate from the docs table to the chunks table. The `doc_uri` column is
    # required so that the original document URL can be propagated to the Agent's web app.
    propagate_columns=propagate_columns,
    # Destination table for the chunks.
    chunked_docs_table=f"{CHUNKED_DOCS_DELTA_TABLE}",
    modality="docs",
)

display(spark.read.table(chunked_docs_table))

In [0]:
from typing import TypedDict, Dict
import io
from typing import List, Dict, Any, Tuple, Optional, TypedDict
import warnings
import pyspark.sql.functions as func
from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType
from mlflow.utils import databricks_utils as du
from functools import partial
import tiktoken
from transformers import AutoTokenizer
from langchain_text_splitters import RecursiveCharacterTextSplitter
from databricks.vector_search.client import VectorSearchClient
import mlflow


def _build_index(
    primary_key: str,
    embedding_source_column: str,
    vector_search_endpoint: str,
    chunked_docs_table_name: str,
    vectorsearch_index_name: str,
    embedding_endpoint_name: str,
    force_delete=False,
):
    """Create, re-create, or sync a Delta Sync vector search index.

    Behavior depends on whether `vectorsearch_index_name` is already listed on
    `vector_search_endpoint`:

    - exists and `force_delete` is False: trigger a sync of the existing index.
    - exists and `force_delete` is True: delete the index, then create it again.
    - does not exist: create it.

    Args:
        primary_key: Primary key column of the source table.
        embedding_source_column: Column whose text is embedded.
        vector_search_endpoint: Name of the vector search endpoint hosting the index.
        chunked_docs_table_name: Source Delta table the index is built from.
        vectorsearch_index_name: Name of the index to sync or create.
        embedding_endpoint_name: Model serving endpoint used to compute embeddings.
        force_delete: When True, an existing index is deleted and rebuilt
            instead of being synced.

    Returns:
        None.
    """
    vsc = VectorSearchClient(disable_notice=True)

    def find_index(endpoint_name, index_name):
        # Use the arguments (not the enclosing scope) so the helper checks
        # exactly the endpoint/index it is asked about.
        all_indexes = vsc.list_indexes(name=endpoint_name).get("vector_indexes", [])
        return any(index.get("name") == index_name for index in all_indexes)

    index_exists = find_index(
        endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
    )

    # Existing index that we are allowed to keep: just sync it and stop.
    if index_exists and not force_delete:
        print(
            f"Syncing index {vectorsearch_index_name}, this can take 15 minutes or much longer if you have a larger number of documents..."
        )
        # Pass the endpoint explicitly; older client versions require it.
        vsc.get_index(
            endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
        ).sync()
        return

    if index_exists:
        # force_delete requested: drop the existing index before rebuilding it.
        vsc.delete_index(
            endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
        )
    else:
        print(
            f'Creating non-existent vector search index for endpoint "{vector_search_endpoint}" and index "{vectorsearch_index_name}"'
        )

    print(
        "Computing document embeddings and Vector Search Index. This can take 15 minutes or much longer if you have a larger number of documents."
    )

    # TRIGGERED pipeline: the index is refreshed only when a sync is requested.
    vsc.create_delta_sync_index_and_wait(
        endpoint_name=vector_search_endpoint,
        index_name=vectorsearch_index_name,
        primary_key=primary_key,
        source_table_name=chunked_docs_table_name,
        pipeline_type="TRIGGERED",
        embedding_source_column=embedding_source_column,
        embedding_model_endpoint_name=embedding_endpoint_name,
    )

In [0]:
from pydantic import BaseModel


class RetrieverIndexResult(BaseModel):
    """Names of the resources involved in building a retriever index.

    Constructed and returned by `build_retriever_index` from the arguments it
    was called with.
    """

    # Name of the vector search endpoint hosting the index.
    vector_search_endpoint: str
    # Name of the vector search index.
    vector_search_index_name: str
    # Name of the embedding model endpoint used for the index.
    embedding_endpoint_name: str
    # Name of the chunked documents table the index is built from.
    chunked_docs_table: str

def build_retriever_index(
    chunked_docs_table: str,
    primary_key: str,
    embedding_source_column: str,
    embedding_endpoint_name: str,
    vector_search_endpoint: str,
    vector_search_index_name: str,
    force_delete_vector_search_endpoint=False,
) -> RetrieverIndexResult:
    """Enable change data feed on the chunk table and build its vector search index.

    Args:
        chunked_docs_table: Table holding the chunked documents.
        primary_key: Primary key column of the chunk table.
        embedding_source_column: Column whose text is embedded.
        embedding_endpoint_name: Embedding model endpoint name.
        vector_search_endpoint: Vector search endpoint name.
        vector_search_index_name: Vector search index name.
        force_delete_vector_search_endpoint: Forwarded to `_build_index` as
            its `force_delete` argument.

    Returns:
        A `RetrieverIndexResult` echoing the endpoint, index, embedding
        endpoint and table names that were passed in.
    """
    # Assemble the summary first so invalid arguments surface before any table is altered.
    summary = RetrieverIndexResult(
        chunked_docs_table=chunked_docs_table,
        embedding_endpoint_name=embedding_endpoint_name,
        vector_search_endpoint=vector_search_endpoint,
        vector_search_index_name=vector_search_index_name,
    )

    # Vector Search Delta Sync relies on the table's change data feed.
    enable_cdf_statement = f"ALTER TABLE {chunked_docs_table} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
    spark.sql(enable_cdf_statement)

    print("Building embedding index...")

    # Hand everything over to the index builder.
    index_arguments = {
        "primary_key": primary_key,
        "embedding_source_column": embedding_source_column,
        "vector_search_endpoint": vector_search_endpoint,
        "chunked_docs_table_name": chunked_docs_table,
        "vectorsearch_index_name": vector_search_index_name,
        "embedding_endpoint_name": embedding_endpoint_name,
        "force_delete": force_delete_vector_search_endpoint,
    }
    _build_index(**index_arguments)

    return summary

In [0]:
# Spark requires `` to escape names with special chars, VS client does not,
# so strip the backticks before handing the table name over.
index_source_table = CHUNKED_DOCS_DELTA_TABLE.replace("`", "")

retriever_index_result = build_retriever_index(
    chunked_docs_table=index_source_table,
    primary_key="chunk_id",
    embedding_source_column="content_chunked",
    # Must match the embedding endpoint you used to chunk your documents
    embedding_endpoint_name=EMBEDDING_MODEL_ENDPOINT,
    vector_search_endpoint=VECTOR_SEARCH_ENDPOINT,
    vector_search_index_name=VECTOR_INDEX_NAME,
    # Set to True to delete and re-create the vector search index when re-running.
    force_delete_vector_search_endpoint=False,
)

# Show the full result object, then a readable summary of each field.
print(retriever_index_result)

print()
print("Vector search index created! This will be used in the next notebook.")
print(f"Vector search endpoint: {retriever_index_result.vector_search_endpoint}")
print(f"Vector search index: {retriever_index_result.vector_search_index_name}")
print(f"Embedding used: {retriever_index_result.embedding_endpoint_name}")
print(f"Chunked docs table: {retriever_index_result.chunked_docs_table}")