In [0]:
%pip install -qqqq -U pypdf==4.1.0 databricks-vectorsearch transformers==4.41.1 torch==2.3.0 tiktoken==0.7.0 langchain-text-splitters==0.2.2 mlflow mlflow-skinny

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, current_timestamp, md5, explode, udf
from pyspark.sql.types import StringType, ArrayType
import mlflow
from mlflow.utils import databricks_utils as du
# The pipeline definitions below use `dlt.table`, `dlt.read_stream`, etc., which
# need the module name itself bound; the star import alone does not guarantee it.
import dlt
from dlt import *
import yaml
from pypdf import PdfReader
import io
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
import tiktoken

# Load configurations from YAML files
def load_config(config_file):
    """Load and parse a YAML configuration file.

    Args:
        config_file: Path to the YAML file. Relative paths resolve against the
            current working directory.

    Returns:
        The parsed YAML document as plain Python objects (a dict for the
        mapping-style configs used in this notebook), or ``None`` if the file
        is empty.
    """
    # YAML is UTF-8 by convention; don't depend on the platform's locale encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        # safe_load only builds plain Python objects, never arbitrary classes.
        return yaml.safe_load(f)

# Load pipeline and destination-table settings from YAML files.
data_pipeline_config = load_config('data_pipeline_config.yaml')
destination_tables_config = load_config('destination_tables_config.yaml')

# Unity Catalog location and application name. These are hardcoded here;
# they could instead be read from one of the YAML configs above.
UC_CATALOG = "theodore_kop_personal"
UC_SCHEMA = "bcp"
RAG_APP_NAME = "bcp_document_rag_poc"

# UC Volume holding the source documents. Built from the catalog/schema
# constants so the path cannot drift out of sync with them.
SOURCE_PATH = f"/Volumes/{UC_CATALOG}/{UC_SCHEMA}/rag_documentation"

# Define PDF parsing UDF
def parse_pdf_udf(content):
    """Extract the text of a PDF document.

    Wrapped as the Spark UDF ``parse_pdf`` and applied to the ``content``
    column of ``raw_files``.

    Args:
        content: Raw PDF bytes (any bytes-like object), or ``None``.

    Returns:
        The text of all pages joined with newlines, or ``None`` when there is
        no content or the PDF cannot be parsed.
    """
    # A NULL/empty payload cannot be a PDF; skip it without logging a parse error.
    if not content:
        return None
    try:
        reader = PdfReader(io.BytesIO(content))
        # Treat a page with no extractable text as "" so one page can't fail the join.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        # Best effort: an unparseable file yields NULL instead of failing the stream.
        print(f"Error parsing PDF: {e}")
        return None

# Expose parse_pdf_udf to Spark as a string-returning UDF (NULL when parsing fails).
parse_pdf = udf(f=parse_pdf_udf, returnType=StringType())

# Define text chunking UDF
def chunk_text_udf(text):
    """Chunk text using the configured chunking strategy."""
    try:
        embedding_config = data_pipeline_config["embedding_config"]
        chunker_config = data_pipeline_config["pipeline_config"]["chunker"]["config"]
        
        # Select the correct tokenizer based on the embedding model configuration
        if embedding_config["embedding_tokenizer"]["tokenizer_source"] == "hugging_faceXX":
            tokenizer = AutoTokenizer.from_pretrained(
                embedding_config["embedding_tokenizer"]["tokenizer_model_name"]
            )
            text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer,
                chunk_size=chunker_config["chunk_size_tokens"],
                chunk_overlap=chunker_config["chunk_overlap_tokens"],
            )
        elif embedding_config["embedding_tokenizer"]["tokenizer_source"] == "tiktoken":
            tokenizer = tiktoken.encoding_for_model(
                embedding_config["embedding_tokenizer"]["tokenizer_model_name"]
            )
            text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                tokenizer,
                chunk_size=chunker_config["chunk_size_tokens"],
                chunk_overlap=chunker_config["chunk_overlap_tokens"],
            )
        else:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunker_config["chunk_size_tokens"],
                chunk_overlap=chunker_config["chunk_overlap_tokens"],
            )
        
        chunks = text_splitter.split_text(text)
        return chunks
    except Exception as e:
        print(f"Error chunking text: {str(e)}")
        return []

# Expose chunk_text_udf to Spark as a UDF returning array<string>.
chunk_text = udf(f=chunk_text_udf, returnType=ArrayType(StringType()))

# Bronze layer: incrementally ingest the PDFs in the UC Volume with Auto Loader.
@dlt.table(
    name="raw_files",
    comment="Raw files from UC Volume",
    table_properties={
        "quality": "bronze",
        "pipelines.autoOptimize.zOrderCols": "path",
    },
)
def raw_files():
    """Stream every ``*.pdf`` under SOURCE_PATH as a binary-file row.

    Rows carry the reader's columns (downstream tables rely on ``path`` and
    ``content``) plus a ``processing_timestamp`` set when the row is processed.
    """
    # NOTE(review): DLT normally manages Auto Loader schema locations itself, and
    # inferColumnTypes likely has no effect for binaryFile; confirm both are needed.
    auto_loader_options = {
        "cloudFiles.format": "binaryFile",
        "cloudFiles.schemaLocation": f"/tmp/{RAG_APP_NAME}/schema",
        "cloudFiles.inferColumnTypes": "true",
        "recursiveFileLookup": "true",
        # Only PDFs are picked up; other files in the Volume are ignored.
        "pathGlobFilter": "*.pdf",
    }
    pdf_stream = (
        spark.readStream
        .format("cloudFiles")
        .options(**auto_loader_options)
        .load(SOURCE_PATH)
    )
    return pdf_stream.withColumn("processing_timestamp", current_timestamp())

# Silver layer: one row per PDF with its extracted text.
@dlt.table(
    name="parsed_docs",
    comment="Parsed document contents",
    table_properties={
        "quality": "silver",
        "pipelines.autoOptimize.zOrderCols": "path",
    },
)
@dlt.expect("valid_content", "parsed_content IS NOT NULL")
@dlt.expect("valid_path", "path IS NOT NULL")
def parsed_docs():
    """Parse each raw PDF into text.

    Output columns: ``path``, ``parsed_content`` (NULL when parse_pdf fails)
    and ``processing_timestamp`` carried over from raw_files. The expectations
    above track NULL ``parsed_content``/``path``; ``dlt.expect`` records
    violations rather than dropping rows.
    """
    raw_stream = dlt.read_stream("raw_files")
    return raw_stream.select(
        "path",
        parse_pdf("content").alias("parsed_content"),
        "processing_timestamp",
    )

# Define the chunked documents table (Gold)
@dlt.table(
    name="chunked_docs",
    comment="Chunked document contents for vector search",
    table_properties={
        "quality": "gold",
        "pipelines.autoOptimize.zOrderCols": "chunk_id",
        "pipelines.autoOptimize.optimizeWrite": "true",
        "delta.enableChangeDataFeed" : "true"
    }
)
@dlt.expect("valid_chunk", "chunked_text IS NOT NULL")
@dlt.expect("valid_chunk_id", "chunk_id IS NOT NULL")
def chunked_docs():
    """Split each parsed document into chunks, one row per chunk.

    Output columns: ``path``, ``chunked_text``, ``processing_timestamp`` and
    ``chunk_id``.

    ``chunk_id`` is the MD5 of the source path, the chunk's position in the
    document and the chunk text. That keeps it unique per chunk even when the
    same text (repeated headers/footers, boilerplate) appears in several
    documents or several times in one document. Hashing the text alone would
    collide for such chunks, which is unsafe for a primary key.

    NOTE: changing how chunk_id is derived requires a full refresh of this
    table (and a re-sync of any index built on it) for ids to be consistent.
    """
    return (
        dlt.read_stream("parsed_docs")
        .withColumn("chunks", chunk_text("parsed_content"))
        # posexplode keeps each chunk's position for use in chunk_id;
        # documents with an empty chunk list produce no rows.
        .selectExpr(
            "path",
            "posexplode(chunks) AS (chunk_index, chunked_text)",
            "processing_timestamp",
        )
        .withColumn(
            "chunk_id",
            expr("md5(concat_ws('|', path, CAST(chunk_index AS STRING), chunked_text))"),
        )
        # The position only feeds the id; drop it to keep the table schema unchanged.
        .drop("chunk_index")
    )

# Define the data quality metrics
@dlt.view(
    name="quality_metrics",
    comment="Data quality metrics for the pipeline"
)
def quality_metrics():
    """Row and NULL counts for raw_files, parsed_docs and chunked_docs.

    Returns one row per table with columns ``table_name``, ``total_records``
    and ``null_content_count``. ``null_content_count`` counts NULLs in that
    table's payload column: ``content``, ``parsed_content`` and
    ``chunked_text`` respectively. One column name is used for all three
    tables so the union lines up.

    Datasets are read with ``dlt.read`` (as the tables above use
    ``dlt.read_stream``). That way they resolve to this pipeline's datasets and
    are tracked as dependencies, regardless of whether bare names in SQL need
    the ``LIVE.`` prefix.
    """

    def _summarize(table_name, column):
        # Global aggregate (no GROUP BY): exactly one summary row per table.
        return dlt.read(table_name).selectExpr(
            f"'{table_name}' AS table_name",
            "COUNT(*) AS total_records",
            f"COUNT(CASE WHEN {column} IS NULL THEN 1 END) AS null_content_count",
        )

    raw_summary = _summarize("raw_files", "content")
    parsed_summary = _summarize("parsed_docs", "parsed_content")
    chunked_summary = _summarize("chunked_docs", "chunked_text")
    return raw_summary.unionByName(parsed_summary).unionByName(chunked_summary)
