In [0]:
from chonkie import RecursiveChunker
import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf, explode,col, expr, current_timestamp
from mlflow.deployments import get_deploy_client

@pandas_udf("array<string>")
def read_as_chunk(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Split every document in each incoming batch into a list of chunk strings.

    Iterator-style pandas UDF: the chunker is constructed once (with no
    arguments) and reused for every batch handed to this invocation.

    Args:
        batch_iter: Iterator of pandas Series, each holding document texts.

    Yields:
        One Series per input batch, aligned by position, where each element
        is the list of ``str(chunk)`` values produced for that document.
        Falsy or non-string inputs (e.g. ``None``, ``""``) yield ``[]``.
    """
    splitter = RecursiveChunker()

    def split_document(value):
        # Guard: only non-empty strings are handed to the chunker.
        if not (value and isinstance(value, str)):
            return []
        return [str(piece) for piece in splitter(value)]

    for documents in batch_iter:
        yield pd.Series([split_document(document) for document in documents])


@pandas_udf("array<float>")
def get_embedding(contents: pd.Series) -> pd.Series:
    """Embed each text in ``contents`` via the ``databricks-bge-large-en`` endpoint.

    The texts are sent through the MLflow deployments client for
    ``"databricks"`` in slices of at most 150 inputs per request, and the
    returned vectors are concatenated in request order.

    Args:
        contents: Series of texts to embed.

    Returns:
        Series holding one embedding per input, in input order. Each
        embedding is the ``"embedding"`` entry of an item in the response's
        ``data`` attribute.
    """
    client = get_deploy_client("databricks")
    chunk_size = 150  # maximum number of inputs sent per predict() call

    texts = contents.tolist()
    vectors = []
    for start in range(0, len(texts), chunk_size):
        reply = client.predict(
            endpoint="databricks-bge-large-en",
            inputs={"input": texts[start : start + chunk_size]},
        )
        vectors.extend(item["embedding"] for item in reply.data)

    return pd.Series(vectors)


In [None]:
%sql
-- Target Delta table for the chunk-level embeddings appended by the streaming
-- cell in this notebook. IF NOT EXISTS makes the cell safe to re-run.
CREATE TABLE IF NOT EXISTS dev_appian_poc.`02_gold`.ingestion_text_embeddings (
  -- Surrogate key: generated automatically unless a value is supplied explicitly.
  id BIGINT GENERATED BY DEFAULT AS IDENTITY,
  -- doc_id and path are copied from the source row by the streaming cell.
  doc_id INT,
  path STRING,
  -- One text chunk of the source document.
  content STRING,
  -- Embedding vector computed for content.
  embedding ARRAY<FLOAT>,
  processed_at TIMESTAMP
)
-- Turns on the Delta Change Data Feed for this table.
-- NOTE(review): presumably so a downstream consumer can sync incrementally; confirm.
TBLPROPERTIES (delta.enableChangeDataFeed = true);

In [None]:
# Streaming transform: read new rows from the document text table, split each
# document's full_text into chunks, embed every chunk, and shape the result to
# match the ingestion_text_embeddings table.
# NOTE(review): the DDL cell back-quotes the `02_gold` schema; confirm the
# unquoted identifier below resolves the same way in this runtime.
df = (
    spark.readStream.format("delta")
    .option("ignoreDeletes", "true")
    .table("dev_appian_poc.02_gold.document_text_contents")
    .withColumn("chunked", read_as_chunk(col("full_text")))
    # One output row per chunk; rows whose chunk array is empty or null
    # produce no output rows.
    .withColumn("chunked", explode(col("chunked")))
    .withColumn("embedded", get_embedding(col("chunked")))
    .withColumn("processed_at", current_timestamp())
    .select(
        col("doc_id"),
        col("path"),
        col("chunked").alias("content"),
        col("embedded").alias("embedding"),
        # Must be projected explicitly: without it the computed timestamp is
        # dropped here and the target's processed_at column is written as NULL.
        col("processed_at"),
    )
)


# Append the embedded chunks to the gold embeddings table. The availableNow
# trigger makes this a bounded run, so the awaitTermination() call below
# returns once the stream has stopped.
query = (
    df.writeStream.format("delta")
    .outputMode("append")
    .options(
        checkpointLocation="/Volumes/dev_appian_poc/00_bronze/checkpoints/embedded_document",
        mergeSchema="true",
    )
    .trigger(availableNow=True)
    .toTable("dev_appian_poc.02_gold.ingestion_text_embeddings")
)

# Block the notebook cell until the streaming query terminates.
query.awaitTermination()