# Video Data Processing Job

This notebook demonstrates an end-to-end pipeline for processing video data using Databricks, focusing on:

1. Loading video files (.mp4, .mov) from a source volume
2. Transcribing each video's audio to text using Whisper AI
3. Chunking the transcribed text into manageable segments
4. Creating a vector search index for efficient retrieval

The notebook is structured in sequential steps, from data ingestion through to indexing, making it easy to understand and modify for your specific video processing needs. Each major section is clearly commented and includes relevant configuration parameters.

Key components used:
- Databricks Vector Search
- FFmpeg (`ffmpeg` and `ffprobe`) for audio extraction
- Whisper AI for transcription
- BGE embedding model for text vectorization


#### Setup and Configuration
 
This section installs the required packages and loads (via `%run ../global_config`) the configuration parameters used throughout the notebook:

- Unity Catalog settings (catalog, schema, volume names)
- Model endpoints (Whisper AI, BGE embeddings)
- Delta table names
- Vector search configuration
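The configuration values themselves live in `../global_config`, which is not shown here. As a rough sketch using hypothetical placeholder values, the names this notebook relies on combine into the source volume path and the output table name like this:

```python
# Hypothetical placeholder values; the real ones are defined in ../global_config.
UC_CATALOG_NAME = "my_catalog"
UC_SCHEMA_NAME = "my_schema"
UC_VOLUME_NAME = "my_volume"
VIDEO_DATA_VOLUME_FOLDER = "videos"
VIDEO_DATA_TABLE_NAME = "video_data_text"
WHISPER_ENDPOINT_NAME = "my-whisper-endpoint"

# Unity Catalog volume folder that main() scans for video files
video_dir = f"/Volumes/{UC_CATALOG_NAME}/{UC_SCHEMA_NAME}/{UC_VOLUME_NAME}/{VIDEO_DATA_VOLUME_FOLDER}/"

# Fully qualified Delta table that receives the combined transcripts
video_table = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{VIDEO_DATA_TABLE_NAME}"

print(video_dir)
print(video_table)
```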


In [0]:
# Packages required by all code.
# Versions of Databricks code are not locked since Databricks ensures changes are backwards compatible.
# Versions of open source packages are locked since package authors often make backwards-incompatible changes.
%pip install -qqqq -U \
  databricks-vectorsearch databricks-agents pydantic databricks-sdk mlflow mlflow-skinny `# For agent & data pipeline code` \
  pypdf==4.1.0  `# PDF parsing` \
  markdownify==0.12.1  `# HTML parsing` \
  pypandoc_binary==1.13  `# DOCX parsing` \
  transformers==4.41.1 torch==2.3.0 tiktoken==0.7.0 langchain-text-splitters==0.2.0 `# get_recursive_character_text_splitter`

# Restart to load the packages into the Python environment
dbutils.library.restartPython()

In [0]:
%run ../global_config

# Video Processing and Transcription Pipeline Documentation

## Overview
This code implements a video processing and transcription pipeline that discovers video files, extracts their audio, splits the audio into chunks, and transcribes each chunk using Whisper. It is designed to handle large video files while keeping memory usage in check.

## Key Features
- **Video Format Support**: Handles multiple video formats including MP4, AVI, MOV, MKV, WebM, FLV, WMV, MPEG, 3GP
- **Resource Management**: Checks process memory usage before each chunk and pauses when it exceeds a threshold (CPU usage is sampled but not used as a limit)
- **Chunked Processing**: Splits large videos into manageable chunks for efficient processing
- **Batched Processing**: Groups chunks into fixed-size batches; chunks within a batch are extracted one after another, not in parallel
- **Error Handling**: Comprehensive error handling and logging throughout the pipeline
- **Delta Lake Integration**: Stores results in Delta Lake tables for efficient querying and management

## Process Flow
1. **Video File Discovery**
   - Scans specified directory for supported video files
   - Validates file formats and existence

2. **Metadata Extraction**
   - Extracts video duration, format, and other metadata
   - Calculates optimal chunk size based on video duration

3. **Audio Extraction and Chunking**
   - Extracts audio from video files
   - Splits audio into optimal-sized chunks
   - Processes chunks in batches to manage memory usage

4. **Transcription Processing**
   - Processes each audio chunk through Whisper
   - Combines chunk transcripts in correct order
   - Stores results in Delta Lake table
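The chunking and batching steps above can be sketched in plain Python. This is a simplified illustration rather than the notebook's code: `plan_chunks` is a hypothetical helper, and clipping the last chunk to the end of the video is a choice made in this sketch.

```python
import math

def plan_chunks(total_duration, chunk_duration, batch_size=7):
    """Split a video timeline into (start_time, duration) chunks, grouped into batches."""
    num_chunks = math.ceil(total_duration / chunk_duration)
    chunks = []
    for i in range(num_chunks):
        start = i * chunk_duration
        # Clip the final chunk so it does not extend past the end of the video
        chunks.append((start, min(chunk_duration, total_duration - start)))
    # Group consecutive chunks into batches of at most batch_size
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]

# A 45-minute video with 60-second chunks
for n, batch in enumerate(plan_chunks(45 * 60, 60), start=1):
    print(f"batch {n}: {len(batch)} chunks, first starts at {batch[0][0]}s")
```

For a 45-minute video this yields 45 chunks in 7 batches, the last batch holding 3 chunks.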

## Technical Details

### Resource Management
```python
MAX_MEMORY_PERCENT = 80  # Maximum memory usage percentage
BATCH_SIZE = 7  # Number of chunks to process in each batch
MIN_CHUNK_DURATION = 30  # Minimum chunk duration in seconds
MAX_CHUNK_DURATION = 60  # Maximum chunk duration in seconds
```
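The memory check behind these constants can be sketched as a small, testable function. `wait_for_memory` is a hypothetical helper that assumes the same policy as the notebook: check once, wait 5 seconds if usage is over the limit, then raise if it is still over. The memory reading and the sleep function are injected so the logic can be tested without `psutil`.

```python
import time
from typing import Callable

MAX_MEMORY_PERCENT = 80

def wait_for_memory(
    get_memory_percent: Callable[[], float],
    limit: float = MAX_MEMORY_PERCENT,
    wait_seconds: float = 5,
    sleep: Callable[[float], None] = time.sleep,
) -> None:
    """Pause once if memory is over the limit; raise if it is still over afterwards."""
    if get_memory_percent() <= limit:
        return
    # Give the process a chance to release memory before re-checking
    sleep(wait_seconds)
    if get_memory_percent() > limit:
        raise RuntimeError("Resource limits exceeded after waiting")
```

In the notebook, the reading would come from `psutil.Process().memory_percent()`, for example `wait_for_memory(lambda: psutil.Process().memory_percent())`.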

### Audio Processing Parameters
- Sample Rate: 16000 Hz
- Channels: Mono
- Codec: libmp3lame (MP3)
- Bitrate: 64 kbps
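These settings make chunk sizes fairly predictable. A back-of-the-envelope estimate, assuming constant bitrate and ignoring MP3 frame and header overhead (so real files will differ slightly):

```python
def estimated_chunk_bytes(duration_seconds, bitrate_kbps=64):
    """Approximate size of a constant-bitrate audio chunk in bytes.

    ffmpeg's "64k" means 64,000 bits per second; container and frame
    overhead is ignored, so real files will differ slightly.
    """
    return bitrate_kbps * 1000 * duration_seconds // 8

print(estimated_chunk_bytes(60))      # one 60-second chunk
print(estimated_chunk_bytes(60) * 7)  # one full batch of 7 chunks
```

At 64 kbps, a 60-second chunk is about 480,000 bytes, so a full batch of 7 chunks holds roughly 3.4 MB of audio in memory.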

### Data Structures
The pipeline uses several key data structures:
1. **Video Metadata**
   - Duration
   - Name
   - Path
   - Format (container format name reported by `ffprobe`)
   - File Format (label derived from the file extension, e.g. MP4)

2. **Chunk Data**
   - Start Time
   - End Time
   - Size (bytes)
   - Binary Content

3. **Output DataFrame Schema**
   - Modality
   - Name
   - Path
   - Length (video duration in seconds)
   - Modification Time (when the transcript row was created)
   - Transcript Text
   - Chunk Count
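If you want these structures typed in Python, a sketch using dataclasses might look like the following. The notebook itself passes plain dicts; the class names here are illustrative assumptions, while the field names follow the dict keys and DataFrame columns used in the code.

```python
from dataclasses import dataclass

@dataclass
class VideoMetadata:
    duration: float    # seconds, from ffprobe
    name: str          # file name
    path: str          # full path in the volume
    format: str        # container format from ffprobe, e.g. "mov,mp4,m4a,3gp,3g2,mj2"
    file_format: str   # label derived from the file extension, e.g. "MP4"

@dataclass
class ChunkData:
    start_time: float  # seconds from the start of the video
    end_time: float    # seconds from the start of the video
    size: int          # length of binary_content in bytes
    binary_content: bytes

# Column order of the combined transcript DataFrame
OUTPUT_COLUMNS = [
    "modality", "name", "path", "length",
    "modificationTime", "transcript_text", "chunk_count",
]
```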

### Key Functions

#### `extract_and_chunk_audio(video_path, chunk_duration_seconds=None)`
- Extracts audio from video
- Splits into optimal-sized chunks
- Returns chunk data and video metadata

#### `process_chunk_batch(video_path, start_times, durations)`
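- Extracts a batch of chunks one after another (not in parallel)
- Checks memory before each chunk, waits 5 seconds if over the limit, and raises if it is still exceeded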
- Returns processed chunk data

#### `combine_transcripts(spark, transcript_df)`
- Combines individual chunk transcripts
- Maintains proper ordering
- Creates final transcript
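The ordering logic in `combine_transcripts` can be illustrated without Spark. The hypothetical `combine_transcripts_local` below assumes the chunk rows are plain dicts and mirrors the Spark version's behavior: chunks are sorted by `chunk_index`, null texts are skipped (as `concat_ws` does), and `chunk_count` counts every chunk row.

```python
from collections import defaultdict

def combine_transcripts_local(rows):
    """rows: iterable of dicts with 'name', 'chunk_index', and 'transcript_text'."""
    by_video = defaultdict(list)
    for row in rows:
        by_video[row["name"]].append((row["chunk_index"], row["transcript_text"]))

    combined = {}
    for name, chunks in by_video.items():
        # Order chunks by index only, so texts are never compared
        chunks.sort(key=lambda c: c[0])
        combined[name] = {
            # Skip missing texts, like Spark's concat_ws skips nulls
            "transcript_text": " ".join(t for _, t in chunks if t is not None),
            "chunk_count": len(chunks),
        }
    return combined
```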

## Usage Notes

### Prerequisites
- FFmpeg (`ffmpeg` and `ffprobe`) installed and available on the system `PATH`
- PySpark environment configured
- Access to a Whisper model serving endpoint (called through `ai_query`)
- Sufficient system resources

### Performance Considerations
- Chunk size affects memory usage and processing speed
- Batch size can be adjusted based on available resources
- Memory monitoring helps avoid out-of-memory failures by pausing, then failing fast, when usage is too high

### Error Handling
- Comprehensive logging throughout the pipeline
- Graceful handling of unsupported formats
- Resource limit monitoring and pausing

### Output
- Results stored in Delta Lake table
- Includes full transcript and metadata
- Maintains original video information

## Limitations
- Maximum chunk duration of 60 seconds
- Minimum chunk duration of 30 seconds
- Processing pauses, then fails, if the process uses more than 80% of system memory
- Each batch contains at most 7 chunks
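The 30- and 60-second limits interact with `calculate_optimal_chunk_duration` in a way worth noting: for videos longer than 30 minutes, `total_duration / 20` is always above 90 seconds, so the "aim for ~20 chunks" branch always returns the 60-second maximum and the number of chunks keeps growing with video length. The snippet below reuses the same thresholds as the notebook's helper to check this:

```python
MIN_CHUNK_DURATION = 30
MAX_CHUNK_DURATION = 60

def calculate_optimal_chunk_duration(total_duration):
    """Same thresholds as the notebook's helper."""
    if total_duration <= 300:       # up to 5 minutes
        return MIN_CHUNK_DURATION
    elif total_duration <= 1800:    # up to 30 minutes
        return MAX_CHUNK_DURATION
    else:                           # longer than 30 minutes
        return min(MAX_CHUNK_DURATION, total_duration / 20)

for seconds in (120, 300, 301, 1800, 3600, 7200):
    print(f"{seconds:>5}s video -> {calculate_optimal_chunk_duration(seconds)}s chunks")
```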

## Future Improvements
- Dynamic batch size adjustment
- Support for additional video formats
- Enhanced error recovery mechanisms
- Parallel processing optimization

In [0]:
import os
import subprocess
import psutil
import time
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, IntegerType, StructType, StructField, DoubleType, BinaryType
from pyspark.sql.functions import collect_list, struct, col, concat_ws, count, lit
from pyspark.sql import functions as F
import math
from datetime import datetime
import uuid
from pyspark.sql.window import Window
import tempfile
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants for resource monitoring
MAX_MEMORY_PERCENT = 80  # Maximum memory usage percentage
BATCH_SIZE = 7  # Number of chunks to process in each batch
MIN_CHUNK_DURATION = 30  # Minimum chunk duration in seconds
MAX_CHUNK_DURATION = 60  # Maximum chunk duration in seconds

# Supported video formats
SUPPORTED_VIDEO_FORMATS = {
    '.mp4': 'MP4',
    '.avi': 'AVI',
    '.mov': 'MOV',
    '.mkv': 'MKV',
    '.webm': 'WebM',
    '.flv': 'FLV',
    '.wmv': 'WMV',
    '.mpg': 'MPEG',
    '.mpeg': 'MPEG',
    '.3gp': '3GP'
}

def get_resource_usage():
    """Get current resource usage"""
    process = psutil.Process()
    memory_info = process.memory_info()
    return {
        'memory_percent': process.memory_percent(),
        'memory_used': memory_info.rss / (1024 * 1024),  # MB
        'cpu_percent': process.cpu_percent(interval=1)
    }

def ensure_directory_exists(directory):
    """Ensure the directory exists, create if it doesn't"""
    if not os.path.exists(directory):
        os.makedirs(directory)

def get_video_format(video_path):
    """
    Get the format of the video file
    Args:
        video_path: Path to the video file
    Returns:
        str: Video format or None if not supported
    """
    file_ext = os.path.splitext(video_path)[1].lower()
    return SUPPORTED_VIDEO_FORMATS.get(file_ext)

def get_video_metadata(video_path):
    """
    Get metadata about the video file
    Args:
        video_path: Path to the video file
    Returns:
        dict: Video metadata including duration, name, path, and format
    """
    try:
        # Get format information
        format_cmd = [
            'ffprobe',
            '-i', video_path,
            '-show_entries', 'format=format_name',
            '-v', 'quiet',
            '-of', 'csv=p=0'
        ]
        format_name = subprocess.check_output(format_cmd).decode().strip()
        
        # Get duration
        duration_cmd = [
            'ffprobe',
            '-i', video_path,
            '-show_entries', 'format=duration',
            '-v', 'quiet',
            '-of', 'csv=p=0'
        ]
        duration = float(subprocess.check_output(duration_cmd).decode().strip())
        
        return {
            'duration': duration,
            'name': os.path.basename(video_path),
            'path': video_path,
            'format': format_name,
            'file_format': get_video_format(video_path)
        }
    except Exception as e:
        logger.error(f"Error getting video metadata: {str(e)}")
        return None

def process_chunk(video_path, start_time, duration):
    """
    Process a single chunk and return its binary content
    Args:
        video_path: Path to the video file
        start_time: Start time in seconds
        duration: Duration in seconds
    Returns:
        bytes: Binary content of the audio chunk
    """
    try:
        # Create a temporary file to store the chunk
        with tempfile.NamedTemporaryFile(suffix='.mp3', delete=True) as temp_file:
            cmd = [
                'ffmpeg',
                '-y',  # Overwrite output files
                '-i', video_path,
                '-ss', str(start_time),
                '-t', str(duration),
                '-vn',                 # No video
                '-acodec', 'libmp3lame',  # MP3 codec
                '-ar', '16000',       # Sample rate
                '-ac', '1',           # Mono channel
                '-b:a', '64k',        # Lower bitrate
                temp_file.name
            ]
            
            logger.info(f"Processing chunk at {start_time}s")
            subprocess.run(cmd, check=True)
            
            # Read the binary content
            with open(temp_file.name, 'rb') as f:
                binary_content = f.read()
            
            return binary_content
            
    except subprocess.CalledProcessError as e:
        logger.error(f"Error processing chunk: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return None

def check_resources():
    """Check if system resources are within limits"""
    usage = get_resource_usage()
    if usage['memory_percent'] > MAX_MEMORY_PERCENT:
        logger.warning(f"High memory usage: {usage['memory_percent']}%")
        return False
    return True

def calculate_optimal_chunk_duration(total_duration):
    """
    Calculate optimal chunk duration based on video length
    Args:
        total_duration: Total video duration in seconds
    Returns:
        int: Optimal chunk duration in seconds
    """
    if total_duration <= 300:  # 5 minutes
        return MIN_CHUNK_DURATION
    elif total_duration <= 1800:  # 30 minutes
        return MAX_CHUNK_DURATION
    else:  # > 30 minutes
        return min(MAX_CHUNK_DURATION, total_duration / 20)  # Aim for ~20 chunks

def process_chunk_batch(video_path, start_times, durations):
    """
    Process a batch of chunks sequentially, checking memory before each chunk
    Args:
        video_path: Path to the video file
        start_times: List of start times for chunks
        durations: List of durations for chunks
    Returns:
        list: List of binary contents for each chunk
    """
    results = []
    for start_time, duration in zip(start_times, durations):
        if not check_resources():
            logger.warning("Resource limits exceeded, pausing processing")
            time.sleep(5)  # Wait for resources to free up
            if not check_resources():
                raise Exception("Resource limits exceeded after waiting")
        
        binary_content = process_chunk(video_path, start_time, duration)
        if binary_content:
            results.append({
                'start_time': start_time,
                'end_time': start_time + duration,
                'size': len(binary_content),
                'binary_content': binary_content
            })
        else:
            logger.error(f"Failed to process chunk starting at {start_time}s")
    
    return results

def extract_and_chunk_audio(video_path, chunk_duration_seconds=None):
    """
    Extract audio from video and split into chunks in memory
    Args:
        video_path: Path to the video file in Databricks volume
        chunk_duration_seconds: Duration of each chunk in seconds (optional)
    Returns:
        tuple: (success_status, list of chunk data or error_message, video_metadata)
    """
    try:
        # Verify video file exists before probing it
        if not os.path.exists(video_path):
            return False, f"Video file not found at: {video_path}", None

        # Get video metadata
        video_metadata = get_video_metadata(video_path)
        if not video_metadata:
            return False, "Failed to get video metadata", None
        
        # Calculate optimal chunk duration if not provided
        if chunk_duration_seconds is None:
            chunk_duration_seconds = calculate_optimal_chunk_duration(video_metadata['duration'])
        
        # Calculate number of chunks
        total_duration = video_metadata['duration']
        num_chunks = math.ceil(total_duration / chunk_duration_seconds)
        logger.info(f"Will create {num_chunks} chunks of {chunk_duration_seconds} seconds each")
        
        chunk_data = []
        failed_chunks = []
        
        # Process chunks in batches
        for batch_start in range(0, num_chunks, BATCH_SIZE):
            batch_end = min(batch_start + BATCH_SIZE, num_chunks)
            logger.info(f"Processing batch {batch_start//BATCH_SIZE + 1} of {math.ceil(num_chunks/BATCH_SIZE)}")
            
            start_times = [i * chunk_duration_seconds for i in range(batch_start, batch_end)]
            # Clip the final chunk so its duration and end time do not run past the video
            durations = [min(chunk_duration_seconds, total_duration - s) for s in start_times]
            
            try:
                batch_results = process_chunk_batch(video_path, start_times, durations)
                chunk_data.extend(batch_results)
            except Exception as e:
                logger.error(f"Error processing batch: {str(e)}")
                failed_chunks.extend(range(batch_start, batch_end))
                continue
        
        # Report results
        logger.info("\nProcessing Summary:")
        logger.info(f"Total chunks attempted: {num_chunks}")
        logger.info(f"Successfully created: {len(chunk_data)}")
        logger.info(f"Failed chunks: {failed_chunks}")
        
        if not chunk_data:
            return False, "Failed to create any chunks", None
            
        return True, chunk_data, video_metadata
        
    except Exception as e:
        logger.error(f"Error processing video: {str(e)}")
        return False, f"Error processing video: {str(e)}", None

def create_audio_dataframe(spark, video_metadata, chunk_data):
    """
    Create a DataFrame with the audio chunks in memory
    Args:
        spark: SparkSession
        video_metadata: Dictionary containing video metadata
        chunk_data: List of dictionaries containing chunk information
    Returns:
        DataFrame: Spark DataFrame with audio chunks in memory
    """
    # Define schema
    schema = StructType([
        StructField("name", StringType(), True),
        StructField("path", StringType(), True),
        StructField("length", DoubleType(), True),
        StructField("modificationTime", StringType(), True),
        StructField("chunk_index", IntegerType(), True),
        StructField("chunk_start_time", DoubleType(), True),
        StructField("chunk_end_time", DoubleType(), True),
        StructField("chunk_size_bytes", IntegerType(), True),
        StructField("audio_binary", BinaryType(), True)
    ])
    
    # Create DataFrame with chunks
    from pyspark.sql import Row
    # Note: this records when the rows are created (processing time), not the file's modification time
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    rows = [
        Row(
            name=video_metadata['name'],
            path=video_metadata['path'],
            length=float(video_metadata['duration']),
            modificationTime=timestamp,
            chunk_index=i,
            chunk_start_time=float(chunk['start_time']),
            chunk_end_time=float(chunk['end_time']),
            chunk_size_bytes=chunk['size'],
            audio_binary=chunk['binary_content']
        )
        for i, chunk in enumerate(chunk_data)
    ]
    
    return spark.createDataFrame(rows, schema=schema)

def combine_transcripts(spark, transcript_df):
    """
    Combine transcript chunks into a single transcript per video
    Args:
        spark: SparkSession
        transcript_df: DataFrame containing transcript chunks
    Returns:
        DataFrame: Combined transcripts with original video metadata
    """
    # First, sort the chunks by their index
    sorted_df = transcript_df.orderBy("chunk_index")
    
    # Create a window specification for ordering
    window_spec = Window.partitionBy("name").orderBy("chunk_index")
    
    # Add a row number to ensure proper ordering
    numbered_df = sorted_df.withColumn("row_num", F.row_number().over(window_spec))
    
    # Group by video and collect transcripts in order
    combined_df = numbered_df.groupBy(
        "name",
        "path",
        "length",
        "modificationTime"
    ).agg(
        F.collect_list(
            F.struct(
                F.col("chunk_index"),
                F.col("transcript_text")
            )
        ).alias("chunks"),
        F.count("*").alias("chunk_count")
    ).select(
        # Add modality as first column
        F.lit("video").alias("modality"),
        "name",
        "path",
        "length",
        "modificationTime",
        # Sort chunks by index and combine text
        F.concat_ws(
            " ",
            F.expr("transform(array_sort(chunks, (left, right) -> case when left.chunk_index < right.chunk_index then -1 else 1 end), x -> x.transcript_text)")
        ).alias("transcript_text"),
        "chunk_count"
    )
    
    return combined_df

def process_video_file(spark, video_path):
    """
    Process a single video file
    Args:
        spark: SparkSession
        video_path: Path to the video file
    Returns:
        DataFrame: Combined transcripts for the video
    """
    logger.info(f"\nProcessing video: {video_path}")
    
    # Extract and chunk audio directly in memory
    success, result, video_metadata = extract_and_chunk_audio(video_path)
    
    if success:
        # Create DataFrame with audio chunks in memory
        audio_df = create_audio_dataframe(spark, video_metadata, result)
        
        # Process audio chunks with Whisper
        logger.info("\nProcessing audio chunks with Whisper...")
        transcript_df = audio_df.withColumn("modality", lit("video")) \
            .withColumn("transcript_text", F.expr(f"ai_query('{WHISPER_ENDPOINT_NAME}', audio_binary, failOnError => True)"))
        
        # Combine transcripts
        logger.info("\nCombining transcripts...")
        combined_transcripts = combine_transcripts(spark, transcript_df)
        
        # Save combined transcripts
        logger.info("\nSaving transcripts to video_data_text table...")
        combined_transcripts.write.format("delta").mode("append").saveAsTable(f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{VIDEO_DATA_TABLE_NAME}")
        
        # Show the combined transcripts
        logger.info("\nCombined Transcripts:")
        combined_transcripts.show(truncate=False)
        
        return combined_transcripts
    else:
        logger.error(f"Error processing video {video_path}: {result}")
        return None

def main():
    # Initialize Spark session
    spark = SparkSession.builder.getOrCreate()
    
    # Video files directory
    video_dir = f"/Volumes/{UC_CATALOG_NAME}/{UC_SCHEMA_NAME}/{UC_VOLUME_NAME}/{VIDEO_DATA_VOLUME_FOLDER}/"
    
    # Get all supported video files in the directory
    video_files = []
    for f in os.listdir(video_dir):
        if any(f.lower().endswith(ext) for ext in SUPPORTED_VIDEO_FORMATS.keys()):
            video_files.append(os.path.join(video_dir, f))
    
    if not video_files:
        logger.warning(f"No supported video files found in {video_dir}")
        logger.info(f"Supported formats: {', '.join(SUPPORTED_VIDEO_FORMATS.values())}")
        return None
    
    logger.info(f"Found {len(video_files)} video files to process")
    
    # Process each video file
    all_transcripts = []
    for video_path in video_files:
        try:
            # Check if video format is supported
            video_format = get_video_format(video_path)
            if not video_format:
                logger.warning(f"Skipping unsupported video format: {video_path}")
                continue
                
            logger.info(f"\nProcessing {video_format} video: {video_path}")
            transcripts = process_video_file(spark, video_path)
            if transcripts is not None:
                all_transcripts.append(transcripts)
        except Exception as e:
            logger.error(f"Error processing {video_path}: {str(e)}")
            continue
    
    if all_transcripts:
        # Combine all transcripts if needed
        final_df = all_transcripts[0]
        for df in all_transcripts[1:]:
            final_df = final_df.union(df)
        
        logger.info("\nFinal combined transcripts for all videos:")
        final_df.show(truncate=False)
        return final_df
    else:
        logger.warning("No videos were successfully processed")
        return None

if __name__ == "__main__":
    main() 

##### `get_recursive_character_text_splitter`

`get_recursive_character_text_splitter` returns a callable that splits text documents into chunks sized for the embedding model behind a given Model Serving endpoint. It lets you focus on the chunker's core logic without handling the details of text splitting. You can write your own, or edit this code if it does not fit your use case.

**Arguments:**

- `model_serving_endpoint`: The name of the Model Serving endpoint with the embedding model.
- `embedding_model_name`: The name of the embedding model, e.g., `gte-large-en-v1.5`. If `model_serving_endpoint` is an OpenAI External Model or FMAPI endpoint, you can leave this as `None` and the model is detected automatically.
- `chunk_size_tokens`: An optional size for each chunk in tokens. Defaults to `None`, which uses the model's entire context window.
- `chunk_overlap_tokens`: Tokens that should overlap between chunks. Defaults to `0`.

**Returns:** A callable that takes a document (`str`) and produces a list of chunks (`list[str]`).
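To build intuition for `chunk_size_tokens` and `chunk_overlap_tokens`, here is a simplified sliding-window splitter. It is an illustration only, assuming whitespace-separated words stand in for tokens; `RecursiveCharacterTextSplitter` counts tokens with the model's tokenizer and also prefers to break at separators such as blank lines and newlines. The notebook's `validate_chunk_size` rejects an overlap larger than the chunk size; this sketch also rejects an equal overlap, which would otherwise never advance the window.

```python
def split_with_overlap(tokens, chunk_size, chunk_overlap=0):
    """Sliding-window split: each chunk has up to chunk_size tokens and repeats
    the last chunk_overlap tokens of the previous chunk."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        # Stop once this window reaches the end, so no trailing chunk is pure overlap
        if start + chunk_size >= len(tokens):
            break
    return chunks

words = "the quick brown fox jumps over the lazy dog".split()
for chunk in split_with_overlap(words, chunk_size=4, chunk_overlap=1):
    print(" ".join(chunk))
```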

In [0]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken
from typing import Callable, Tuple, Optional
import os
import re
from databricks.sdk import WorkspaceClient

# Constants
HF_CACHE_DIR = "/tmp/hf_cache/"

# Embedding Models Configuration
EMBEDDING_MODELS = {
    "gte-large-en-v1.5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "Alibaba-NLP/gte-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 8192,
        "type": "SENTENCE_TRANSFORMER",
    },
    "bge-large-en-v1.5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 512,
        "type": "SENTENCE_TRANSFORMER",
    },
    "bge_large_en_v1_5": {
        "tokenizer": lambda: AutoTokenizer.from_pretrained(
            "BAAI/bge-large-en-v1.5", cache_dir=HF_CACHE_DIR
        ),
        "context_window": 512,
        "type": "SENTENCE_TRANSFORMER",
    },
    "text-embedding-ada-002": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-ada-002"),
        "type": "OPENAI",
    },
    "text-embedding-3-small": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-small"),
        "type": "OPENAI",
    },
    "text-embedding-3-large": {
        "context_window": 8192,
        "tokenizer": lambda: tiktoken.encoding_for_model("text-embedding-3-large"),
        "type": "OPENAI",
    },
}


def get_workspace_client() -> WorkspaceClient:
    """Returns a WorkspaceClient instance."""
    return WorkspaceClient()


def get_embedding_model_config(endpoint_type: str) -> Optional[dict]:
    """
    Retrieve embedding model configuration by endpoint type.
    """
    return EMBEDDING_MODELS.get(endpoint_type)


def extract_endpoint_type(llm_endpoint) -> Optional[str]:
    """
    Extract the endpoint type from the given llm_endpoint object.
    """
    try:
        return llm_endpoint.config.served_entities[0].external_model.name
    except AttributeError:
        try:
            return llm_endpoint.config.served_entities[0].foundation_model.name
        except AttributeError:
            return None


def detect_fmapi_embedding_model_type(
    model_serving_endpoint: str,
) -> Tuple[Optional[str], Optional[dict]]:
    """
    Detects the embedding model type and configuration for the given endpoint.
    Returns a tuple of (endpoint_type, embedding_config) or (None, None) if not found.
    """
    client = get_workspace_client()

    try:
        llm_endpoint = client.serving_endpoints.get(name=model_serving_endpoint)
        endpoint_type = extract_endpoint_type(llm_endpoint)
    except Exception as e:
        endpoint_type = None

    embedding_config = (
        get_embedding_model_config(endpoint_type) if endpoint_type else None
    )
    return (endpoint_type, embedding_config)


def validate_chunk_size(chunk_spec: dict):
    """
    Validate the chunk size and overlap settings in chunk_spec.
    Raises ValueError if any condition is violated.
    """
    if (
        chunk_spec["chunk_overlap_tokens"] + chunk_spec["chunk_size_tokens"]
    ) > chunk_spec["context_window"]:
        raise ValueError(
            f'Proposed chunk_size of {chunk_spec["chunk_size_tokens"]} + overlap of {chunk_spec["chunk_overlap_tokens"]} '
            f'is {chunk_spec["chunk_overlap_tokens"] + chunk_spec["chunk_size_tokens"]} which is greater than context '
            f'window of {chunk_spec["context_window"]} tokens.'
        )

    if chunk_spec["chunk_overlap_tokens"] > chunk_spec["chunk_size_tokens"]:
        raise ValueError(
            f'Proposed `chunk_overlap_tokens` of {chunk_spec["chunk_overlap_tokens"]} is greater than the '
            f'`chunk_size_tokens` of {chunk_spec["chunk_size_tokens"]}. Reduce the size of `chunk_size_tokens`.'
        )


def get_recursive_character_text_splitter(
    model_serving_endpoint: str,
    embedding_model_name: str = None,
    chunk_size_tokens: int = None,
    chunk_overlap_tokens: int = 0,
) -> Callable[[str], list[str]]:
    try:
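        # Detect the embedding model and its configuration from the serving endpoint.
        # Use a separate name so the caller's embedding_model_name is not overwritten.
        detected_model_name, chunk_spec = detect_fmapi_embedding_model_type(
            model_serving_endpoint
        )

        if chunk_spec is None or detected_model_name is None:
            # Fall back to using the provided embedding_model_name
            chunk_spec = EMBEDDING_MODELS.get(embedding_model_name)
            if chunk_spec is None:
                raise KeyError(embedding_model_name)
        else:
            embedding_model_name = detected_model_name

        # Copy the config so the shared EMBEDDING_MODELS entry is not mutated below
        chunk_spec = dict(chunk_spec)
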
        # Update chunk specification based on provided parameters
        chunk_spec["chunk_size_tokens"] = (
            chunk_size_tokens or chunk_spec["context_window"]
        )
        chunk_spec["chunk_overlap_tokens"] = chunk_overlap_tokens

        # Validate chunk size and overlap
        validate_chunk_size(chunk_spec)

        print(f'Chunk size in tokens: {chunk_spec["chunk_size_tokens"]}')
        print(f'Chunk overlap in tokens: {chunk_spec["chunk_overlap_tokens"]}')
        # Percentage of the context window used by one chunk plus its overlap
        context_usage = round(
            100
            * (chunk_spec["chunk_size_tokens"] + chunk_spec["chunk_overlap_tokens"])
            / chunk_spec["context_window"],
            1,
        )
        print(
            f'Using {context_usage}% of the {chunk_spec["context_window"]} token context window.'
        )

    except KeyError:
        raise ValueError(
            f"Embedding model `{embedding_model_name}` not found. Available models: {EMBEDDING_MODELS.keys()}"
        )

    def _recursive_character_text_splitter(text: str) -> list[str]:
        tokenizer = chunk_spec["tokenizer"]()
        if chunk_spec["type"] == "SENTENCE_TRANSFORMER":
            splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer,
                chunk_size=chunk_spec["chunk_size_tokens"],
                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
            )
        elif chunk_spec["type"] == "OPENAI":
            splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                tokenizer.name,
                chunk_size=chunk_spec["chunk_size_tokens"],
                chunk_overlap=chunk_spec["chunk_overlap_tokens"],
            )
        else:
            raise ValueError(f"Unsupported model type: {chunk_spec['type']}")
        return splitter.split_text(text)

    return _recursive_character_text_splitter

#### `chunk_docs`

`chunk_docs` applies a chunk function to every document in a Delta table and writes the resulting chunks to a chunked-documents table. It lets you write the chunker's core logic without dealing with Spark details. You can write your own, or edit this code if it does not fit your use case.

**Arguments:**

- `docs_table`: The fully qualified Delta table name. For example: `my_catalog.my_schema.my_docs`
- `doc_column`: The name of the column in `docs_table` that contains the documents. For example: `doc`.
- `chunk_fn`: A function that takes a document (`str`) and produces a list of chunks (`list[str]`).
- `propagate_columns`: Columns that should be propagated to the chunk table. For example: `url` to propagate the source URL.
- `chunked_docs_table`: An optional output table name for chunks. Defaults to `{docs_table}_chunked`.

**Returns:** The name of the chunked docs table.
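The core of this step (apply `chunk_fn`, emit one row per chunk, propagate columns, and hash the chunk text into `chunk_id`) can be sketched in plain Python. `compute_chunks_local` is a hypothetical, Spark-free analogue for illustration, not part of the cookbook.

```python
import hashlib
from typing import Callable, Dict, Iterable, List

def compute_chunks_local(
    docs: Iterable[Dict],
    doc_column: str,
    chunk_fn: Callable[[str], List[str]],
    propagate_columns: List[str],
) -> List[Dict]:
    """Spark-free analogue of the chunking step: one output row per chunk."""
    rows = []
    for doc in docs:
        for chunk in chunk_fn(doc[doc_column]):
            # Copy the columns to propagate, then attach the chunk text
            row = {c: doc[c] for c in propagate_columns}
            row["content_chunked"] = chunk
            # Content hash as the ID, like Spark's md5() (32-char lowercase hex)
            row["chunk_id"] = hashlib.md5(chunk.encode("utf-8")).hexdigest()
            rows.append(row)
    return rows
```

Because `chunk_id` hashes only the chunk text, identical chunks from different documents receive the same ID.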

##### Examples of creating a `chunk_fn`

###### Option 1: Use a recursive character text splitter.

This cookbook provides a `get_recursive_character_text_splitter` utility that sets the chunk window based on the embedding endpoint used for indexing.

```py
chunk_fn = get_recursive_character_text_splitter('databricks-bge-large-en')
```

###### Option 2: Use a custom splitter (e.g., LlamaIndex splitters)

> An example `chunk_fn` using the markdown-aware node parser:

```py
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core import Document

parser = MarkdownNodeParser()

def chunk_fn(doc: str) -> list[str]:
    documents = [Document(text=doc)]
    nodes = parser.get_nodes_from_documents(documents)
    return [node.get_content() for node in nodes]
```
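If the document column can contain nulls or blank strings (for example, a video with no speech), a chunk function may fail or emit empty chunks. A small wrapper, sketched here as a hypothetical `null_safe` helper, returns no chunks for such documents:

```python
from typing import Callable, List, Optional

def null_safe(chunk_fn: Callable[[str], List[str]]) -> Callable[[Optional[str]], List[str]]:
    """Wrap a chunk_fn so null or blank documents yield no chunks instead of raising."""
    def wrapped(doc: Optional[str]) -> List[str]:
        # Spark passes SQL NULL to a Python UDF as None
        if doc is None or not doc.strip():
            return []
        return chunk_fn(doc)
    return wrapped
```

For example, `chunk_fn = null_safe(get_recursive_character_text_splitter('databricks-bge-large-en'))`. In Spark, `explode` produces no rows for an empty array, so such documents simply contribute no chunks.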

In [0]:
from typing import Literal, Optional, Any, Callable
from databricks.vector_search.client import VectorSearchClient
from pyspark.sql.functions import explode
import pyspark.sql.functions as func
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken
from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType
from pyspark.sql import SparkSession
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def standardize_timestamp_format(df):
    """
    Standardize timestamp format in a DataFrame.
    Converts any timestamp column to a consistent format.
    
    Args:
        df: Input DataFrame
    Returns:
        DataFrame with standardized timestamps
    """
    if "modificationTime" in df.columns:
        return df.withColumn(
            "modificationTime",
            func.to_timestamp(func.col("modificationTime"))
        )
    return df

def compute_chunks(
    docs_table: str,
    doc_column: str,
    chunk_fn: Callable[[str], list[str]],
    propagate_columns: list[str],
    chunked_docs_table: str,
    modality: str,
) -> str:
    """
    Compute chunks from a document table and append new ones to the chunked table, creating it if it does not exist.
    
    Args:
        docs_table: Source table containing documents
        doc_column: Column name containing the text to chunk
        chunk_fn: Function to split text into chunks
        propagate_columns: List of columns to propagate from the docs table to chunks table
        chunked_docs_table: Target table for storing chunks
        modality: Type of content (e.g., 'video', 'audio', 'pdf')
    Returns:
        str: Name of the chunked table
    """
    logger.info(f"Computing chunks for `{docs_table}`...")
    
    # Initialize Spark session if not already available
    spark = SparkSession.builder.getOrCreate()
    
    # Read source documents
    raw_docs = spark.read.table(docs_table)
    
    # Check if modality column exists in source table
    source_has_modality = "modality" in raw_docs.columns
    
    # Create UDF for chunking
    parser_udf = func.udf(
        chunk_fn,
        returnType=ArrayType(StringType()),
    )
    
    # Process documents into chunks
    chunked_array_docs = raw_docs.withColumn(
        "content_chunked", parser_udf(doc_column)
    ).drop(doc_column)
    
    # Propagate the requested columns. `modality` is handled separately so it is
    # never selected twice: it is carried over from the source table when present,
    # and filled with the `modality` argument otherwise.
    columns_to_propagate = [col for col in propagate_columns if col != "modality"]
    source_modality = ["modality"] if source_has_modality else []

    chunked_docs = chunked_array_docs.select(
        *columns_to_propagate,
        *source_modality,
        explode("content_chunked").alias("content_chunked"),
    )

    # Add chunk_id (MD5 of the chunk text). Identical text yields the same ID, so
    # duplicates are dropped to keep chunk_id usable as the index primary key.
    chunks_with_ids = chunked_docs.withColumn(
        "chunk_id", func.md5(func.col("content_chunked"))
    ).dropDuplicates(["chunk_id"])

    # Fill modality from the function argument if the source table has no such column
    if not source_has_modality:
        chunks_with_ids = chunks_with_ids.withColumn("modality", func.lit(modality))

    # Check whether the target table already exists (public Catalog API)
    table_exists = spark.catalog.tableExists(chunked_docs_table)
    
    # Standardize timestamp format
    chunks_with_ids = standardize_timestamp_format(chunks_with_ids)
    
    # Reorder columns for better display
    final_columns = ["chunk_id", "content_chunked"]
    if "modality" in chunks_with_ids.columns:
        final_columns.append("modality")
    final_columns.extend(columns_to_propagate)
    
    chunks_with_ids = chunks_with_ids.select(*final_columns)
    
    if table_exists:
        # Collect the chunk IDs already stored in the target table
        existing_ids = spark.read.table(chunked_docs_table).select("chunk_id").distinct()

        # Keep only chunks whose IDs are not yet in the target table
        new_chunks = chunks_with_ids.join(existing_ids, on="chunk_id", how="left_anti")

        # Count once: every count() re-runs the chunking UDF over the source table
        total_count = chunks_with_ids.count()
        new_count = new_chunks.count()
        logger.info(f"Found {total_count} total chunks, {new_count} new chunks")

        # Append only new chunks
        if new_count > 0:
            new_chunks.write.mode("append").saveAsTable(chunked_docs_table)
            logger.info(f"Appended {new_count} new chunks to {chunked_docs_table}")
        else:
            logger.info("No new chunks to append")
    else:
        # Create the table on the first run
        chunks_with_ids.write.mode("overwrite").option(
            "overwriteSchema", "true"
        ).saveAsTable(chunked_docs_table)
        logger.info(f"Created new table {chunked_docs_table} with {chunks_with_ids.count()} chunks")
    
    return chunked_docs_table

# Example usage code
def process_video_chunks():
    """
    Example function to process video transcripts into chunks
    """
    logger.info("Starting video chunk processing...")
    
    # Configure the chunker
    chunk_fn = get_recursive_character_text_splitter(
        model_serving_endpoint=EMBEDDING_MODEL_ENDPOINT,
        chunk_size_tokens=384,
        chunk_overlap_tokens=128,
    )
    
    # Get source table schema
    source_table = f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{VIDEO_DATA_TABLE_NAME}"
    source_schema = spark.table(source_table).schema
    
    # Log source table columns
    logger.info(f"Source table columns: {[field.name for field in source_schema]}")
    
    # Get the columns to propagate
    # Exclude only the columns we definitely don't want
    propagate_columns = [
        field.name
        for field in source_schema
        if field.name not in ["transcript_text", "chunk_count"]  # Keep name and modality
    ]
    
    logger.info(f"Propagating columns: {propagate_columns}")
    
    # Process chunks
    chunked_docs_table = compute_chunks(
        docs_table=source_table,
        doc_column="transcript_text",
        chunk_fn=chunk_fn,
        propagate_columns=propagate_columns,
        chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE,
        modality="video"
    )
    
    # Display results
    result_df = spark.read.table(chunked_docs_table)
    logger.info(f"Chunked table schema: {result_df.schema}")
    logger.info(f"Number of chunks created: {result_df.count()}")
    
    return result_df

# To run the processing:
# result_df = process_video_chunks()
# display(result_df)

# Example usage with a custom chunking function (not executed):
#
# def chunk_text(text: str) -> list[str]:
#     # Your chunking logic here; it must return a list of chunk strings.
#     return [text]
#
# compute_chunks(
#     docs_table=f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{VIDEO_DATA_TABLE_NAME}",
#     doc_column="transcript_text",
#     chunk_fn=chunk_text,
#     propagate_columns=["name", "path", "length"],
#     chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE,
#     modality="video",
# )

In [0]:
# Configure the chunker
chunk_fn = get_recursive_character_text_splitter(
    model_serving_endpoint=EMBEDDING_MODEL_ENDPOINT,
    chunk_size_tokens=384,
    chunk_overlap_tokens=128,
)

# Propagate every source column except the transcript text (which is chunked), the chunk count, and the name.
# Adjust this list to change which fields are carried from the docs table to the chunks table.
propagate_columns = [
    field.name
    for field in spark.table(f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{VIDEO_DATA_TABLE_NAME}").schema.fields
    if field.name not in ["transcript_text", "chunk_count", "name"] #TODO Pass Name to the chunk table
]

chunked_docs_table = compute_chunks(
    # The source documents table.
    docs_table=f"{UC_CATALOG_NAME}.{UC_SCHEMA_NAME}.{VIDEO_DATA_TABLE_NAME}",
    # The column containing the documents to be chunked.
    doc_column="transcript_text",
    # The chunking function that takes a string (document) and returns a list of strings (chunks).
    chunk_fn=chunk_fn,
    # Columns to propagate from the docs table to the chunks table. A `doc_uri` column is required so the original document URL can be propagated to the Agent's web app.
    propagate_columns=propagate_columns,
    # The output table for chunks, taken from the `CHUNKED_DOCS_DELTA_TABLE` configuration value.
    chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE,
    # Value written to the `modality` column when the source table has none.
    modality="video"
)

display(spark.read.table(chunked_docs_table))

##### `build_retriever_index`

`build_retriever_index` builds the vector search index that our RAG application uses to retrieve relevant documents.

Arguments:
- `chunked_docs_table`: The chunked documents table. It is expected to contain a `chunk_id` column, a `content_chunked` column, and any propagated metadata columns such as `url`.
- `primary_key`: The column to use as the vector index primary key.
- `embedding_source_column`: The column to compute embeddings for in the vector index.
- `vector_search_endpoint`: The name of the vector search endpoint that hosts the index.
- `vector_search_index_name`: The name of the vector search index.
- `embedding_endpoint_name`: The embedding model serving endpoint. It should match the endpoint used when chunking the documents.
- `force_delete_vector_search_endpoint`: If `True`, deletes and re-creates the vector search index when it already exists. The endpoint itself is not deleted.
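The index-building helper `_build_index` (commented out in the next cell) takes one of three paths depending on whether the index already exists and on `force_delete_vector_search_endpoint`. The pure-Python sketch below restates only that branching so it can be checked without a workspace; `plan_index_action` is a hypothetical name, and in the notebook each outcome corresponds to `VectorSearchClient` calls (`delete_index`, `create_delta_sync_index_and_wait`, or `get_index(...).sync()`).

```py
from typing import Literal


def plan_index_action(
    existing_index_names: list[str], index_name: str, force_delete: bool
) -> Literal["create", "recreate", "sync"]:
    """Hypothetical helper mirroring the branching in _build_index."""
    if index_name not in existing_index_names:
        return "create"  # no index yet: create a Delta Sync index from the chunk table
    if force_delete:
        return "recreate"  # delete the existing index, then create it again
    return "sync"  # keep the index and trigger a sync to pick up new rows
```

Re-running the notebook with the default `force_delete_vector_search_endpoint=False` therefore only syncs the existing index, which picks up chunks appended by `compute_chunks` without re-embedding the whole table.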

In [0]:
# from typing import Any, Dict, List, Optional, Tuple, TypedDict
# import io
# import warnings
# import pyspark.sql.functions as func
# from pyspark.sql.types import StructType, StringType, StructField, MapType, ArrayType
# from mlflow.utils import databricks_utils as du
# from functools import partial
# import tiktoken
# from transformers import AutoTokenizer
# from langchain_text_splitters import RecursiveCharacterTextSplitter
# from databricks.vector_search.client import VectorSearchClient
# import mlflow


# def _build_index(
#     primary_key: str,
#     embedding_source_column: str,
#     vector_search_endpoint: str,
#     chunked_docs_table_name: str,
#     vectorsearch_index_name: str,
#     embedding_endpoint_name: str,
#     force_delete=False,
# ):

#     # Create the Vector Search client
#     vsc = VectorSearchClient(disable_notice=True)

#     def find_index(endpoint_name, index_name):
#         # Use the arguments rather than the enclosing variables
#         all_indexes = vsc.list_indexes(name=endpoint_name).get(
#             "vector_indexes", []
#         )
#         return index_name in map(lambda i: i.get("name"), all_indexes)

#     if find_index(
#         endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
#     ):
#         if force_delete:
#             vsc.delete_index(
#                 endpoint_name=vector_search_endpoint, index_name=vectorsearch_index_name
#             )
#             create_index = True
#         else:
#             create_index = False
#             print(
#                 f"Syncing index {vectorsearch_index_name}, this can take 15 minutes or much longer if you have a larger number of documents..."
#             )

#             sync_result = vsc.get_index(index_name=vectorsearch_index_name).sync()

#     else:
#         print(
#             f'Creating non-existent vector search index for endpoint "{vector_search_endpoint}" and index "{vectorsearch_index_name}"'
#         )
#         create_index = True

#     if create_index:
#         print(
#             f"Computing document embeddings and Vector Search Index. This can take 15 minutes or much longer if you have a larger number of documents."
#         )

#         vsc.create_delta_sync_index_and_wait(
#             endpoint_name=vector_search_endpoint,
#             index_name=vectorsearch_index_name,
#             primary_key=primary_key,
#             source_table_name=chunked_docs_table_name,
#             pipeline_type="TRIGGERED",
#             embedding_source_column=embedding_source_column,
#             embedding_model_endpoint_name=embedding_endpoint_name,
#         )

In [0]:
# from pydantic import BaseModel


# class RetrieverIndexResult(BaseModel):
#     vector_search_endpoint: str
#     vector_search_index_name: str
#     embedding_endpoint_name: str
#     chunked_docs_table: str


# def build_retriever_index(
#     chunked_docs_table: str,
#     primary_key: str,
#     embedding_source_column: str,
#     embedding_endpoint_name: str,
#     vector_search_endpoint: str,
#     vector_search_index_name: str,
#     force_delete_vector_search_endpoint=False,
# ) -> RetrieverIndexResult:

#     retriever_index_result = RetrieverIndexResult(
#         vector_search_endpoint=vector_search_endpoint,
#         vector_search_index_name=vector_search_index_name,
#         embedding_endpoint_name=embedding_endpoint_name,
#         chunked_docs_table=chunked_docs_table,
#     )

#     # Enable CDC for Vector Search Delta Sync
#     spark.sql(
#         f"ALTER TABLE {chunked_docs_table} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
#     )

#     print("Building embedding index...")
#     # Building the index.
#     _build_index(
#         primary_key=primary_key,
#         embedding_source_column=embedding_source_column,
#         vector_search_endpoint=vector_search_endpoint,
#         chunked_docs_table_name=chunked_docs_table,
#         vectorsearch_index_name=vector_search_index_name,
#         embedding_endpoint_name=embedding_endpoint_name,
#         force_delete=force_delete_vector_search_endpoint,
#     )

#     return retriever_index_result

In [0]:
# retriever_index_result = build_retriever_index(
#     # Spark requires backticks (`) to escape names with special characters; the Vector Search client does not.
#     chunked_docs_table=CHUNKED_DOCS_DELTA_TABLE.replace("`", ""),
#     primary_key="chunk_id",
#     embedding_source_column="content_chunked",
#     vector_search_endpoint=VECTOR_SEARCH_ENDPOINT,
#     vector_search_index_name=VECTOR_INDEX_NAME,
#     # Must match the embedding endpoint you used to chunk your documents
#     embedding_endpoint_name=EMBEDDING_MODEL_ENDPOINT,
#     # Set to True to re-create the vector search index when re-running.
#     force_delete_vector_search_endpoint=False,
# )

# print(retriever_index_result)

# print()
# print("Vector search index created! This will be used in the next notebook.")
# print(f"Vector search endpoint: {retriever_index_result.vector_search_endpoint}")
# print(f"Vector search index: {retriever_index_result.vector_search_index_name}")
# print(f"Embedding used: {retriever_index_result.embedding_endpoint_name}")
# print(f"Chunked docs table: {retriever_index_result.chunked_docs_table}")