# Audio Data Processing

This notebook demonstrates an end-to-end pipeline for processing audio data using Databricks, focusing on:

1. Loading audio files (.wav) from a source volume
2. Transcribing audio to text using Whisper AI
3. Chunking the transcribed text into manageable segments
4. Creating a vector search index for efficient retrieval

The notebook is structured in sequential steps, from data ingestion through to indexing, making it easy to understand and modify for your specific audio processing needs. Each major section is clearly commented and includes relevant configuration parameters.

Key components used:
- Databricks Vector Search
- Whisper AI for transcription
- BGE embedding model for text vectorization


#### Setup and Configuration
 
This section defines key configuration parameters used throughout the notebook:
 
 - Unity Catalog settings (catalog, schema, volume names)
 - Model endpoints (Whisper AI, BGE embeddings) 
 - Delta table names
 - Vector search configuration


In [0]:
# Packages required by all code.
# Versions of Databricks code are not locked since Databricks ensures changes are backwards compatible.
# Versions of open source packages are locked since package authors often make backwards incompatible changes
%pip install -qqqq -U \
  databricks-vectorsearch databricks-agents pydantic databricks-sdk mlflow mlflow-skinny `# For agent & data pipeline code` \
  transformers==4.41.1 torch==2.3.0 tiktoken==0.7.0 langchain-text-splitters==0.2.0 `# get_recursive_character_text_splitter`
%pip install pyannote.audio torch torchaudio soundfile numpy pydub
# Restart to load the packages into the Python environment
dbutils.library.restartPython()

In [0]:
%run ../global_config

#### Loading Audio Data
This cell reads audio files (.wav) from the specified volume path using Spark's binaryFile format. The audio files are loaded as binary data, preserving their original format for processing.

"""
## Audio Processing Pipeline 

### Overview
This code implements an audio processing pipeline that processes audio files using PySpark and the Whisper AI model. The pipeline handles large audio files by breaking them into manageable chunks, sends audio to a Whisper endpoint for transcription, and stores the results in a Delta table.

### Key Features
- Binary file reading and processing
- Automatic chunking of large audio files
- Integration with Whisper AI for transcription
- Delta table storage of results
- Comprehensive logging and performance monitoring

### Process Flow
1. **File Reading**: Reads audio files as binary files from a specified source path
2. **Chunking**: Splits large audio files into optimal-sized chunks (2MB-10MB)
3. **Processing**: Sends each chunk to a Whisper endpoint for transcription
4. **Storage**: Saves results to a Delta table with the following columns:
   - modality: Type of content (audio)
   - path: Original file path
   - modificationTime: File modification timestamp
   - length: File size
   - binary_content: Base64 encoded audio content
   - transcript_text: Whisper-generated transcription

### Technical Details

#### Chunking Strategy
- Minimum chunk size: 2MB
- Maximum chunk size: 10MB
- Target chunks per file: 30
- Dynamic chunk size calculation based on file size
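
The sizing rule above can be sketched in plain Python. This is a minimal illustration, assuming the same three constants as the notebook; `clamp_chunk_size` and `chunk_count` are hypothetical helper names used only for this example.

```python
import math

MIN_CHUNK_SIZE = 2 * 1024 * 1024   # 2MB lower bound
MAX_CHUNK_SIZE = 10 * 1024 * 1024  # 10MB upper bound
TARGET_CHUNKS = 30                 # desired number of chunks per file


def clamp_chunk_size(file_size: int) -> int:
    """Aim for TARGET_CHUNKS chunks, clamped to [MIN_CHUNK_SIZE, MAX_CHUNK_SIZE]."""
    # Integer division keeps the result usable as a byte offset.
    return max(MIN_CHUNK_SIZE, min(file_size // TARGET_CHUNKS, MAX_CHUNK_SIZE))


def chunk_count(file_size: int) -> int:
    """Number of chunks produced for a file of the given size."""
    return math.ceil(file_size / clamp_chunk_size(file_size))


MB = 1024 * 1024
for size_mb in (5, 150, 600):
    size = size_mb * MB
    # Columns: file size (MB), chunk size (MB), number of chunks
    print(size_mb, clamp_chunk_size(size) // MB, chunk_count(size))
# 5 2 3
# 150 5 30
# 600 10 60
```

Small files hit the 2MB floor and yield fewer than 30 chunks, while very large files hit the 10MB ceiling and yield more than 30.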

#### Performance Considerations
- Uses PySpark for distributed processing
- Implements efficient binary data handling
- Includes performance logging and timing metrics

#### Error Handling
- Comprehensive error logging
- Graceful failure handling for chunking process
- Configurable error handling for Whisper endpoint calls

### Usage Example
```python
process_audio_data(
    spark=spark,
    source_path='/Volumes/catalog/schema/volume/audio_folder/',
    output_table="catalog.schema.audio_data_table",
    whisper_endpoint="whisper-endpoint-name"
)
```

### Notes
- Ensure sufficient memory allocation for processing large audio files
- Monitor Whisper endpoint capacity and performance
- Consider implementing retry logic for failed transcriptions
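
One possible shape for the retry logic mentioned in the last note is sketched below. It is not part of the pipeline: `call_with_retries` is a hypothetical helper, and it only applies if the endpoint is called from Python rather than through `ai_query` inside Spark SQL.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_attempts: int = 3,
    base_delay_s: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on any exception with exponential backoff."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
            # Wait 1x, 2x, 4x, ... the base delay before the next attempt
            sleep(base_delay_s * 2 ** (attempt - 1))
    raise ValueError("max_attempts must be at least 1")
```

A caller would wrap a single transcription request, for example `call_with_retries(lambda: transcribe_chunk(chunk))`, where `transcribe_chunk` stands in for whatever function sends one chunk to the endpoint.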
"""


In [0]:
from typing import Dict, Any, List
import logging
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, expr, udf, unbase64, element_at
from pyspark.sql.types import DoubleType, ArrayType, StructType, StructField, BinaryType, IntegerType, StringType
import os
import psutil
import time
import math
from datetime import datetime
import io
import base64

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Create logger
logger = logging.getLogger('audio_processor')

# Suppress noisy loggers
logging.getLogger('pyspark').setLevel(logging.WARNING)
logging.getLogger('py4j').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('PIL').setLevel(logging.WARNING)

# Constants for chunking
MIN_CHUNK_SIZE = 2 * 1024 * 1024  # 2MB minimum chunk size
MAX_CHUNK_SIZE = 10 * 1024 * 1024  # 10MB maximum chunk size
TARGET_CHUNKS = 30  # Target number of chunks per file

def calculate_optimal_chunk_size(file_size):
    """
    Calculate optimal chunk size based on file size
    Args:
        file_size: Size of the file in bytes
    Returns:
        int: Optimal chunk size in bytes
    """
    # Calculate chunk size to get approximately TARGET_CHUNKS chunks.
    # Integer division keeps the result an int so it can be used as a slice offset.
    chunk_size = file_size // TARGET_CHUNKS
    
    # Ensure chunk size is within bounds
    return max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))

def chunk_binary_data(binary_content):
    """
    Split binary data into chunks
    Args:
        binary_content: Binary content of the file
    Returns:
        list: List of chunk dictionaries
    """
    try:
        file_size = len(binary_content)
        chunk_size = calculate_optimal_chunk_size(file_size)
        num_chunks = math.ceil(file_size / chunk_size)
        
        chunks = []
        for i in range(num_chunks):
            start = i * chunk_size
            end = min(start + chunk_size, file_size)
            chunk = binary_content[start:end]
            
            # Encode the chunk as base64
            encoded_chunk = base64.b64encode(chunk).decode('utf-8')
            
            chunks.append({
                'chunk_index': i,
                'size': len(chunk),
                'binary_content': encoded_chunk  # Store as base64 encoded string
            })
        
        return chunks
    except Exception as e:
        logger.error(f"Error chunking binary data: {str(e)}")
        return []

# Define the schema for the chunked data
chunk_schema = ArrayType(StructType([
    StructField("chunk_index", IntegerType(), True),
    StructField("size", IntegerType(), True),
    StructField("binary_content", StringType(), True)  # Changed to StringType for base64
]))

# Register the UDF
chunk_udf = udf(chunk_binary_data, chunk_schema)

def process_audio_data(
    spark: SparkSession,
    source_path: str,
    output_table: str,
    whisper_endpoint: str
) -> None:
    """
    Process audio files and create a standardized table with transcriptions.
    
    Args:
        spark: SparkSession
        source_path: Path to the source audio files
        output_table: Full name of the output table (catalog.schema.table)
        whisper_endpoint: Name of the Whisper endpoint
    """
    start_time = time.time()
    logger.info(f"Starting audio processing from {source_path}")
    
    # Read the audio files as binary files
    audio_df = spark.read.format("binaryFile").load(source_path)
    file_count = audio_df.count()
    logger.info(f"Found {file_count} audio files to process")
    
    # Add a column with the binary content
    audio_df = audio_df.withColumn("binary_content", col("content"))
    
    # Change the data type for the length column to DoubleType
    audio_df = audio_df.withColumn("length", col("length").cast(DoubleType()))
    
    # Select only the necessary columns
    audio_df = audio_df.select("path", "modificationTime", "length", "binary_content")
    
    # Chunk the binary content
    chunked_df = audio_df.withColumn("chunks", chunk_udf(col("binary_content")))
    logger.info("Successfully chunked binary content")
    
    # Flatten the chunk struct fields into array columns (one array element per chunk).
    # Note: rows are not exploded here, so there is still one row per file.
    exploded_df = chunked_df.select(
        col("path"),
        col("modificationTime"),
        col("length"),
        col("chunks.chunk_index").alias("chunk_index"),
        col("chunks.size").alias("chunk_size"),
        col("chunks.binary_content").alias("binary_content")  # This is now base64 encoded
    )
    chunk_count = exploded_df.count()
    logger.info(f"Created {chunk_count} chunks for processing")
    
    # Add modality column and get transcriptions for each chunk
    processed_df = exploded_df.select(
        lit("audio").alias("modality"),
        col("path"),
        col("modificationTime"),
        col("length"),
        # col("chunk_index"),
        # col("chunk_size"),
        col("binary_content"),
        # Decode the first chunk from base64 back to binary before sending it to the Whisper endpoint
        expr(f"ai_query('{whisper_endpoint}', unbase64(element_at(binary_content, 1)), failOnError => True)").alias("transcript_text")
    )
    
    # Write to Delta table
    processed_df.write.mode("overwrite").option(
        "overwriteSchema", "true"
    ).saveAsTable(output_table)
    
    end_time = time.time()
    duration = end_time - start_time
    logger.info(f"Successfully processed and saved {chunk_count} audio chunks to {output_table}")
    logger.info(f"Processing completed in {duration:.2f} seconds")
    if file_count:  # avoid division by zero when no files were found
        logger.info(f"Average processing time per file: {duration/file_count:.2f} seconds")

In [0]:

process_audio_data(
    spark=spark,
    source_path=f'/Volumes/{UC_CATALOG_NAME}/{UC_SCHEMA_NAME}/{UC_VOLUME_NAME}/{AUDIO_DATA_VOLUME_FOLDER}/',
    output_table=f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{AUDIO_DATA_TABLE_NAME}",
    whisper_endpoint=WHISPER_ENDPOINT_NAME
)

##### `get_recursive_character_text_splitter`

`get_recursive_character_text_splitter` takes an embedding endpoint and returns a callable that chunks text documents. This utility lets you write the core business logic of the chunker without dealing with the details of text splitting. You can write your own, or edit this code if it does not fit your use case.

**Arguments:**

- `model_serving_endpoint`: The name of the Model Serving endpoint with the embedding model.
- `embedding_model_name`: The name of the embedding model, e.g., `gte-large-en-v1.5`. If `model_serving_endpoint` is an OpenAI External Model or FMAPI model, this can be left as `None` and the model is detected automatically.
- `chunk_size_tokens`: An optional size for each chunk in tokens. Defaults to `None`, which uses the model's entire context window.
- `chunk_overlap_tokens`: Tokens that should overlap between chunks. Defaults to `0`.

**Returns:** A callable that takes a document (`str`) and produces a list of chunks (`list[str]`).
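
To make the roles of `chunk_size_tokens` and `chunk_overlap_tokens` concrete, here is a simplified sketch of overlapping windows. It is an assumption-laden illustration, not the splitter used below: `RecursiveCharacterTextSplitter` splits on separators and merges pieces rather than sliding a fixed window, and `split_with_overlap` is a hypothetical helper that treats whitespace-separated words as tokens.

```python
def split_with_overlap(
    tokens: list[str], chunk_size: int, chunk_overlap: int = 0
) -> list[list[str]]:
    """Slide a window of chunk_size tokens, sharing chunk_overlap tokens between neighbours."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the window already reached the end of the text
    return chunks


words = "a b c d e f g h i j".split()
for chunk in split_with_overlap(words, chunk_size=4, chunk_overlap=2):
    print(" ".join(chunk))
# a b c d
# c d e f
# e f g h
# g h i j
```

Each chunk repeats the last two tokens of the previous one, which is what lets a retriever match text that straddles a chunk boundary.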

In [0]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken
from typing import Callable, Tuple, Optional
import os
import re
from databricks.sdk import WorkspaceClient

# Constants
HF_CACHE_DIR = "/tmp/hf_cache/"

# Embedding Models Configuration
EMBEDDING_MODELS = {
    "gte-large-en-v1.5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "Alibaba-NLP/gte-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 8192,
        "type": "SENTENCE_TRANSFORMER",
    },
    "bge-large-en-v1.5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 512,
        "type": "SENTENCE_TRANSFORMER",
    },
    "bge_large_en_v1_5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 512,
        "type": "SENTENCE_TRANSFORMER",
    },
    "text-embedding-ada-002": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-ada-002"),
        "type": "OPENAI",
    },
    "text-embedding-3-small": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-small"),
        "type": "OPENAI",
    },
    "text-embedding-3-large": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-large"),
        "type": "OPENAI",
    },
}


def get_workspace_client() -> WorkspaceClient:
    """Returns a WorkspaceClient instance."""
    return WorkspaceClient()


def get_embedding_model_config(endpoint_type: str) -> Optional[dict]:
    """
    Retrieve embedding model configuration by endpoint type.
    """
    return EMBEDDING_MODELS.get(endpoint_type)


def extract_endpoint_type(llm_endpoint) -> Optional[str]:
    """
    Extract the endpoint type from the given llm_endpoint object.
    """
    try:
        return llm_endpoint.config.served_entities[0].external_model.name
    except AttributeError:
        try:
            return llm_endpoint.config.served_entities[0].foundation_model.name
        except AttributeError:
            return None


def detect_fmapi_embedding_model_type(
    model_serving_endpoint: str,
) -> Tuple[Optional[str], Optional[dict]]:
    """
    Detects the embedding model type and configuration for the given endpoint.
    Returns a tuple of (endpoint_type, embedding_config) or (None, None) if not found.
    """
    client = get_workspace_client()

    try:
        llm_endpoint = client.serving_endpoints.get(name=model_serving_endpoint)
        endpoint_type = extract_endpoint_type(llm_endpoint)
    except Exception as e:
        endpoint_type = None

    embedding_config = (
        get_embedding_model_config(endpoint_type) if endpoint_type else None
    )
    return (endpoint_type, embedding_config)


def validate_chunk_size(chunk_spec: dict):
    """
    Validate the chunk size and overlap settings in chunk_spec.
    Raises ValueError if any condition is violated.
    """
    if (
        chunk_spec["chunk_overlap_tokens"] + chunk_spec["chunk_size_tokens"]
    ) > chunk_spec["context_window"]:
        raise ValueError(
            f'Proposed chunk_size of {chunk_spec["chunk_size_tokens"]} + overlap of {chunk_spec["chunk_overlap_tokens"]} '
            f'is {chunk_spec["chunk_overlap_tokens"] + chunk_spec["chunk_size_tokens"]} which is greater than context '
            f'window of {chunk_spec["context_window"]} tokens.'
        )

    if chunk_spec["chunk_overlap_tokens"] > chunk_spec["chunk_size_tokens"]:
        raise ValueError(
            f'Proposed `chunk_overlap_tokens` of {chunk_spec["chunk_overlap_tokens"]} is greater than the '
            f'`chunk_size_tokens` of {chunk_spec["chunk_size_tokens"]}. Reduce the size of `chunk_size_tokens`.'
        )


def get_recursive_character_text_splitter(
    model_serving_endpoint: str,
    embedding_model_name: str = None,
    chunk_size_tokens: int = None,
    chunk_overlap_tokens: int = 0,
) -> Callable[[str], list[str]]:
    try:
        # Detect the embedding model and its configuration
        detected_model_name, chunk_spec = detect_fmapi_embedding_model_type(
            model_serving_endpoint
        )

        if chunk_spec is None or detected_model_name is None:
            # Fall back to the caller-provided embedding_model_name
            chunk_spec = EMBEDDING_MODELS.get(embedding_model_name)
            if chunk_spec is None:
                raise KeyError
        else:
            embedding_model_name = detected_model_name

        # Copy so the shared EMBEDDING_MODELS entry is not mutated below
        chunk_spec = dict(chunk_spec)

        # Update chunk specification based on provided parameters
        chunk_spec["chunk_size_tokens"] = (
            chunk_size_tokens or chunk_spec["context_window"]
        )
        chunk_spec["chunk_overlap_tokens"] = chunk_overlap_tokens

        # Validate chunk size and overlap
        validate_chunk_size(chunk_spec)

        print(f'Chunk size in tokens: {chunk_spec["chunk_size_tokens"]}')
        print(f'Chunk overlap in tokens: {chunk_spec["chunk_overlap_tokens"]}')
        context_usage = (
            round(
                (chunk_spec["chunk_size_tokens"] + chunk_spec["chunk_overlap_tokens"])
                / chunk_spec["context_window"],
                2,
            )
            * 100
        )
        print(
            f'Using {context_usage}% of the {chunk_spec["context_window"]} token context window.'
        )

    except KeyError:
        raise ValueError(
            f"Embedding model `{embedding_model_name}` not found. Available models: {EMBEDDING_MODELS.keys()}"
        )

    def _recursive_character_text_splitter(text: str) -> list[str]:
        tokenizer = chunk_spec["tokenizer"]()
        if chunk_spec["type"] == "SENTENCE_TRANSFORMER":
            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer,
                chunk_size=chunk_spec["chunk_size_tokens"],
                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
            )
        elif chunk_spec["type"] == "OPENAI":
            splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                tokenizer.name,
                chunk_size=chunk_spec["chunk_size_tokens"],
                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
            )
        else:
            raise ValueError(f"Unsupported model type: {chunk_spec['type']}")
        return splitter.split_text(text)

    return _recursive_character_text_splitter
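
As a worked example of the validation rule, the sketch below checks the settings used later in this notebook (384-token chunks with a 128-token overlap) against the 512-token context window listed for `bge-large-en-v1.5` in `EMBEDDING_MODELS`. `context_usage_percent` is a hypothetical helper written for this illustration only.

```python
def context_usage_percent(
    chunk_size_tokens: int, chunk_overlap_tokens: int, context_window: int
) -> float:
    """Return the share of the context window used, or raise if the budget is exceeded."""
    total = chunk_size_tokens + chunk_overlap_tokens
    if total > context_window:
        raise ValueError(
            f"chunk size + overlap is {total}, which exceeds the "
            f"context window of {context_window} tokens"
        )
    return round(100 * total / context_window, 1)


print(context_usage_percent(384, 128, 512))  # 100.0
print(context_usage_percent(256, 0, 512))    # 50.0
```

With 384 + 128 = 512 the configuration exactly fills the window, so any increase in either value would be rejected.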

#### `compute_chunks`

`compute_chunks` applies a chunking function to each document in a source table and writes the resulting chunks to a Delta table. If the target table already exists, only chunks whose `chunk_id` is not yet present are appended; otherwise the table is created. This utility lets you write the core business logic of the chunker without dealing with the Spark details. You can write your own, or edit this code if it does not fit your use case.

Arguments:
- `docs_table`: The fully qualified delta table name. For example: `my_catalog.my_schema.my_docs`
- `doc_column`: The name of the column in `docs_table` that contains the documents. For example: `doc`.
- `chunk_fn`: A function that takes a document (str) and produces a list of chunks (list[str]).
- `propagate_columns`: Columns that should be propagated to the chunk table. For example: `url` to propagate the source URL.
- `chunked_docs_table`: The fully qualified name of the output table for chunks.
- `modality`: The type of content (for example `audio`, `video`, or `pdf`), written to a `modality` column when the source table does not already have one.

Returns:
The name of the chunked docs table.
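
In the implementation further down, each chunk gets an ID from the MD5 hash of its text, and a left anti join keeps only chunks whose ID is not already in the target table. The Spark-free sketch below illustrates that idea; `chunk_id` and `select_new_chunks` are hypothetical helpers, not functions from this notebook.

```python
import hashlib


def chunk_id(text: str) -> str:
    """32-character hex MD5 digest of the chunk text."""
    return hashlib.md5(text.encode("utf-8")).hexdigest()


def select_new_chunks(
    candidate_chunks: list[str], existing_ids: set[str]
) -> list[tuple[str, str]]:
    """Keep (chunk_id, text) pairs whose ID is not already stored.

    Like a left anti join, this filters against existing_ids only;
    duplicates within candidate_chunks are not collapsed.
    """
    return [
        (chunk_id(text), text)
        for text in candidate_chunks
        if chunk_id(text) not in existing_ids
    ]


existing = {chunk_id("hello")}
print(select_new_chunks(["hello", "world"], existing))
```

Because the ID depends only on the text, two identical chunks from different files share one ID, so the second one is treated as already present.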

##### Examples of creating a `chunk_fn`

###### Option 1: Use a recursive character text splitter.

We provide a `get_recursive_character_text_splitter` util in this cookbook which will determine
the best chunk window given the embedding endpoint that we decide to use for indexing.

```py
chunk_fn = get_recursive_character_text_splitter('databricks-bge-large-en')
```

###### Option 2: Use a custom splitter (e.g. LlamaIndex splitters)

> An example `chunk_fn` using the markdown-aware node parser:

```py
from llama_index.core.node_parser import MarkdownNodeParser, TokenTextSplitter
from llama_index.core import Document
parser = MarkdownNodeParser()

def chunk_fn(doc: str) -> list[str]:
  documents = [Document(text=doc)]
  nodes = parser.get_nodes_from_documents(documents)
  return [node.get_content() for node in nodes]
```

In [0]:
from typing import Literal, Optional, Any, Callable
from databricks.vector_search.client import VectorSearchClient
from pyspark.sql.functions import explode
import pyspark.sql.functions as func
from typing import Callable
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken
from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType
from pyspark.sql import SparkSession
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def standardize_timestamp_format(df):
    """
    Standardize timestamp format in a DataFrame.
    Converts any timestamp column to a consistent format.
    
    Args:
        df: Input DataFrame
    Returns:
        DataFrame with standardized timestamps
    """
    if "modificationTime" in df.columns:
        return df.withColumn(
            "modificationTime",
            func.to_timestamp(func.col("modificationTime"))
        )
    return df

def compute_chunks(
    docs_table: str,
    doc_column: str,
    chunk_fn: Callable[[str], list[str]],
    propagate_columns: list[str],
    chunked_docs_table: str,
    modality: str,
) -> str:
    """
    Compute chunks from a document table and append them to an existing chunked table.
    
    Args:
        docs_table: Source table containing documents
        doc_column: Column name containing the text to chunk
        chunk_fn: Function to split text into chunks
        propagate_columns: List of columns to propagate from the docs table to chunks table
        chunked_docs_table: Target table for storing chunks
        modality: Type of content (e.g., 'video', 'audio', 'pdf')
    Returns:
        str: Name of the chunked table
    """
    logger.info(f"Computing chunks for `{docs_table}`...")
    
    # Initialize Spark session if not already available
    spark = SparkSession.builder.getOrCreate()
    
    # Read source documents
    raw_docs = spark.read.table(docs_table)
    
    # Check if modality column exists in source table
    source_has_modality = "modality" in raw_docs.columns
    
    # Create UDF for chunking
    parser_udf = func.udf(
        chunk_fn,
        returnType=ArrayType(StringType()),
    )
    
    # Process documents into chunks
    chunked_array_docs = raw_docs.withColumn(
        "content_chunked", parser_udf(doc_column)
    ).drop(doc_column)
    
    # Select columns to propagate, excluding modality if it exists
    columns_to_propagate = [col for col in propagate_columns if col != "modality"]
    
    chunked_docs = chunked_array_docs.select(
        *columns_to_propagate, explode("content_chunked").alias("content_chunked")
    )
    
    # Add chunk_id
    chunks_with_ids = chunked_docs.withColumn(
        "chunk_id", func.md5(func.col("content_chunked"))
    )
    
    # Add modality column if it doesn't exist in source
    if not source_has_modality:
        chunks_with_ids = chunks_with_ids.withColumn("modality", func.lit(modality))
    
    # Check if target table exists (public catalog API) and get its schema
    table_exists = spark.catalog.tableExists(chunked_docs_table)
    if table_exists:
        target_schema = spark.read.table(chunked_docs_table).schema
        target_has_modality = "modality" in [field.name for field in target_schema]
        
        # If target has modality but source doesn't, add it
        if target_has_modality and not source_has_modality:
            chunks_with_ids = chunks_with_ids.withColumn("modality", func.lit(modality))
    
    # Standardize timestamp format
    chunks_with_ids = standardize_timestamp_format(chunks_with_ids)
    
    # Reorder columns for better display
    final_columns = ["chunk_id", "content_chunked"]
    if "modality" in chunks_with_ids.columns:
        final_columns.append("modality")
    final_columns.extend(columns_to_propagate)
    
    chunks_with_ids = chunks_with_ids.select(*final_columns)
    
    if table_exists:
        # Read existing chunks
        existing_chunks = spark.read.table(chunked_docs_table)
        
        # Get existing chunk IDs
        existing_ids = existing_chunks.select("chunk_id").distinct()
        
        # Filter out chunks that already exist
        new_chunks = chunks_with_ids.join(
            existing_ids,
            chunks_with_ids.chunk_id == existing_ids.chunk_id,
            "left_anti"
        )
        
        logger.info(f"Found {chunks_with_ids.count()} total chunks, {new_chunks.count()} new chunks")
        
        # Append only new chunks
        if new_chunks.count() > 0:
            new_chunks.write.mode("append").saveAsTable(chunked_docs_table)
            logger.info(f"Appended {new_chunks.count()} new chunks to {chunked_docs_table}")
        else:
            logger.info("No new chunks to append")
    else:
        # Create new table if it doesn't exist
        chunks_with_ids.write.mode("overwrite").option(
            "overwriteSchema", "true"
        ).saveAsTable(chunked_docs_table)
        logger.info(f"Created new table {chunked_docs_table} with {chunks_with_ids.count()} chunks")
    
    return chunked_docs_table

# Example usage code
def process_video_chunks():
    """
    Example function to process video transcripts into chunks
    """
    logger.info("Starting video chunk processing...")
    
    # Configure the chunker
    chunk_fn = get_recursive_character_text_splitter(
        model_serving_endpoint=EMBEDDING_MODEL_ENDPOINT,
        chunk_size_tokens=384,
        chunk_overlap_tokens=128,
    )
    
    # Get source table schema
    source_table = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{AUDIO_DATA_TABLE_NAME}"
    source_schema = spark.table(source_table).schema
    
    # Log source table columns
    logger.info(f"Source table columns: {[field.name for field in source_schema]}")
    
    # Get the columns to propagate
    # Exclude only the columns we definitely don't want
    propagate_columns = [
        field.name
        for field in source_schema
        if field.name not in ["transcript_text", "chunk_count"]  # Keep name and modality
    ]
    
    logger.info(f"Propagating columns: {propagate_columns}")
    
    # Process chunks
    chunked_docs_table = compute_chunks(
        docs_table=source_table,
        doc_column="transcript_text",
        chunk_fn=chunk_fn,
        propagate_columns=propagate_columns,
        chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE,
        modality="video"
    )
    
    # Display results
    result_df = spark.read.table(chunked_docs_table)
    logger.info(f"Chunked table schema: {result_df.schema}")
    logger.info(f"Number of chunks created: {result_df.count()}")
    
    return result_df

# To run the processing:
# result_df = process_video_chunks()
# display(result_df)

# Example usage:
"""
# Define chunking function
def chunk_text(text: str) -> list[str]:
    # Your chunking logic here
    pass

# Compute chunks for video transcripts
compute_chunks(
    docs_table="ankit_yadav.fluke_schema.video_data_text",
    doc_column="transcript_text",
    chunk_fn=chunk_text,
    propagate_columns=["name", "path", "length"],
    chunked_docs_table="ankit_yadav.fluke_schema.content_chunks",
    modality="video"
)
""" 

In [0]:
# Configure the chunker
chunk_fn = get_recursive_character_text_splitter(
    model_serving_endpoint=EMBEDDING_MODEL_ENDPOINT,
    chunk_size_tokens=384,
    chunk_overlap_tokens=128,
)

# Get the columns from the parser except for the doc_content
# You can modify this to adjust which fields are propagated from the docs table to the chunks table.
propagate_columns = [
    field.name
    for field in spark.table(f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{AUDIO_DATA_TABLE_NAME}").schema.fields
    if field.name not in ["transcript_text", "binary_content"]
]

chunked_docs_table = compute_chunks(
    # The source documents table.
    docs_table=f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{AUDIO_DATA_TABLE_NAME}",
    # The column containing the documents to be chunked.
    doc_column="transcript_text",
    # The chunking function that takes a string (document) and returns a list of strings (chunks).
    chunk_fn=chunk_fn,
    # Choose which columns to propagate from the docs table to chunks table. A `doc_uri` column is required if you want to propagate the original document URL to the Agent's web app.
    propagate_columns=propagate_columns,
    # The output table for chunks; new chunks are appended if it already exists.
    chunked_docs_table=f"{CHUNKED_DOCS_DELTA_TABLE}",
    modality="audio"
)

display(spark.read.table(chunked_docs_table))

##### `build_retriever_index`

`build_retriever_index` will build the vector search index which is used by our RAG to retrieve relevant documents.

Arguments:
- `chunked_docs_table`: The chunked documents table. It is expected to contain a primary key column (`chunk_id`) and a text column to embed (`content_chunked` in this notebook).
- `primary_key`: The column to use for the vector index primary key.
- `embedding_source_column`: The column to compute embeddings for in the vector index.
- `vector_search_endpoint`: An optional vector search endpoint name. If not defined, defaults to the `{table_id}_vector_search`.
- `vector_search_index_name`: An optional index name. If not defined, defaults to `{chunked_docs_table}_index`.
- `embedding_endpoint_name`: An embedding endpoint name.
- `force_delete_vector_search_endpoint`: Setting this to true deletes an existing vector search index and rebuilds it.
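
The build step below chooses between three actions depending on whether the index already exists and whether a forced rebuild was requested. That decision can be sketched as a small pure function; `plan_index_action` and its return labels are hypothetical names used only for this illustration.

```python
def plan_index_action(index_exists: bool, force_delete: bool) -> str:
    """Decide what to do with the vector search index."""
    if index_exists and force_delete:
        return "delete_and_create"  # drop the old index, then build a fresh one
    if index_exists:
        return "sync"               # keep the index and trigger a sync
    return "create"                 # no index yet, so build one


for exists in (True, False):
    for force in (True, False):
        print(exists, force, plan_index_action(exists, force))
```

Note that `force_delete` has no effect when the index does not exist yet; the index is simply created.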

In [0]:
# from typing import TypedDict, Dict
# import io
# from typing import List, Dict, Any, Tuple, Optional, TypedDict
# import warnings
# import pyspark.sql.functions as func
# from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType
# from mlflow.utils import databricks_utils as du
# from functools import partial
# import tiktoken
# from transformers import AutoTokenizer
# from langchain_text_splitters import RecursiveCharacterTextSplitter
# from databricks.vector_search.client import VectorSearchClient
# import mlflow


# def _build_index(
#     primary_key: str,
#     embedding_source_column: str,
#     vector_search_endpoint: str,
#     chunked_docs_table_name: str,
#     vectorsearch_index_name: str,
#     embedding_endpoint_name: str,
#     force_delete=False,
# ):

#     # Get the vector search index
#     vsc = VectorSearchClient(disable_notice=True)

#     def find_index(endpoint_name, index_name):
#         all_indexes = vsc.list_indexes(name=vector_search_endpoint).get(
#             "vector_indexes", []
#         )
#         return vectorsearch_index_name in map(lambda i: i.get("name"), all_indexes)

#     if find_index(
#         endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
#     ):
#         if force_delete:
#             vsc.delete_index(
#                 endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
#             )
#             create_index = True
#         else:
#             create_index = False
#             print(
#                 f"Syncing index {vectorsearch_index_name}, this can take 15 minutes or much longer if you have a larger number of documents..."
#             )

#             sync_result = vsc.get_index(index_name=vectorsearch_index_name).sync()

#     else:
#         print(
#             f'Creating non-existent vector search index for endpoint "{vector_search_endpoint}" and index "{vectorsearch_index_name}"'
#         )
#         create_index = True

#     if create_index:
#         print(
#             f"Computing document embeddings and Vector Search Index. This can take 15 minutes or much longer if you have a larger number of documents."
#         )

#         vsc.create_delta_sync_index_and_wait(
#             endpoint_name=vector_search_endpoint,
#             index_name=vectorsearch_index_name,
#             primary_key=primary_key,
#             source_table_name=chunked_docs_table_name,
#             pipeline_type="TRIGGERED",
#             embedding_source_column=embedding_source_column,
#             embedding_model_endpoint_name=embedding_endpoint_name,
#         )

In [0]:
# from pydantic import BaseModel


# class RetrieverIndexResult(BaseModel):
#     vector_search_endpoint: str
#     vector_search_index_name: str
#     embedding_endpoint_name: str
#     chunked_docs_table: str


# def build_retriever_index(
#     chunked_docs_table: str,
#     primary_key: str,
#     embedding_source_column: str,
#     embedding_endpoint_name: str,
#     vector_search_endpoint: str,
#     vector_search_index_name: str,
#     force_delete_vector_search_endpoint=False,
# ) -> RetrieverIndexResult:

#     retriever_index_result = RetrieverIndexResult(
#         vector_search_endpoint=vector_search_endpoint,
#         vector_search_index_name=vector_search_index_name,
#         embedding_endpoint_name=embedding_endpoint_name,
#         chunked_docs_table=chunked_docs_table,
#     )

#     # Enable CDC for Vector Search Delta Sync
#     spark.sql(
#         f"ALTER TABLE {chunked_docs_table} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
#     )

#     print("Building embedding index...")
#     # Building the index.
#     _build_index(
#         primary_key=primary_key,
#         embedding_source_column=embedding_source_column,
#         vector_search_endpoint=vector_search_endpoint,
#         chunked_docs_table_name=chunked_docs_table,
#         vectorsearch_index_name=vector_search_index_name,
#         embedding_endpoint_name=embedding_endpoint_name,
#         force_delete=force_delete_vector_search_endpoint,
#     )

#     return retriever_index_result

In [0]:
# retriever_index_result = build_retriever_index(
#     # Spark requires `` to escape names with special chars, VS client does not.
#     chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE.replace("`", ""),
#     primary_key="chunk_id",
#     embedding_source_column="content_chunked",
#     vector_search_endpoint=VECTOR_SEARCH_ENDPOINT,
#     vector_search_index_name=VECTOR_INDEX_NAME,
#     # Must match the embedding endpoint you used to chunk your documents
#     embedding_endpoint_name=EMBEDDING_MODEL_ENDPOINT,
#     # Set to true to re-create the vector search endpoint when re-running.
#     force_delete_vector_search_endpoint=False,
# )

# print(retriever_index_result)

# print()
# print("Vector search index created! This will be used in the next notebook.")
# print(f"Vector search endpoint: {retriever_index_result.vector_search_endpoint}")
# print(f"Vector search index: {retriever_index_result.vector_search_index_name}")
# print(f"Embedding used: {retriever_index_result.embedding_endpoint_name}")
# print(f"Chunked docs table: {retriever_index_result.chunked_docs_table}")