In [0]:
# %skip
# spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch",10)
# %pip install chonkie --break-system-packages

In [0]:
# dbutils.library.restartPython()

In [0]:
%sql
-- Resolve unqualified table names in later cells against dev_appian_poc.02_gold.
USE CATALOG `dev_appian_poc`;
USE SCHEMA `02_gold`;

In [0]:
# Load the document texts; the name is resolved against the catalog/schema
# selected by the preceding %sql USE cell.
df = spark.read.table("document_text_contents")
# bronze_df = spark.table("dev_appian_poc.`00_bronze`.appian_raw_documents_ingest")

In [0]:
# Develop the chunk/embed pipeline against a single document.
test_df = df.limit(num=1)
display(test_df)

In [0]:
from chonkie import RecursiveChunker
import pandas as pd
from typing import Iterator
from pyspark.sql.functions import pandas_udf, explode



@pandas_udf("array<string>")
def read_as_chunk(batch_iter : Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Split each document's text into chunk strings using chonkie's RecursiveChunker.

    Iterator-of-Series pandas UDF: for every input Series pulled from
    ``batch_iter``, one output Series of the same length is yielded, holding a
    list of chunk strings for each row.

    Values that are not a non-empty ``str`` (e.g. None/NaN/pd.NA from nulls, or
    ``""``) map to an empty list, so a downstream ``explode`` emits no row for them.

    Args:
        batch_iter: Iterator over batches of the input text column.

    Yields:
        Object-dtype ``pd.Series`` of ``list[str]``, aligned with each input batch.
    """
    # Build the chunker once for the whole iterator rather than per batch.
    chunker = RecursiveChunker()

    def _chunk_texts(text) -> list:
        # Check the type before truthiness: bool(pd.NA) raises TypeError.
        if not (isinstance(text, str) and text):
            return []
        # Use the explicit `.text` field instead of relying on Chunk.__str__.
        return [chunk.text for chunk in chunker(text)]

    for batch in batch_iter:
        # dtype=object keeps the list values as-is and stops an empty batch
        # from being inferred as float64, which Arrow cannot turn into array<string>.
        yield pd.Series([_chunk_texts(text) for text in batch], dtype=object)


In [0]:
# One output row per chunk; the chunk text replaces the `full_text` column.
# `explode` emits no row for an empty chunk list, so documents without text drop out.
chunk_lists = read_as_chunk("full_text")
df_chunks = test_df.withColumn("full_text", explode(chunk_lists))

In [0]:
# Inspect the per-chunk rows.
display(df_chunks)

In [0]:
from mlflow.deployments import get_deploy_client
import pandas as pd
from pyspark.sql.functions import pandas_udf


@pandas_udf("array<float>")
def get_embedding(contents: pd.Series) -> pd.Series:
    """Embed each text in ``contents`` via the ``databricks-bge-large-en`` endpoint.

    Scalar (Series -> Series) pandas UDF. Non-null texts are sent to the
    Databricks serving endpoint in requests of at most ``batch_size`` inputs.
    Null entries are not sent and receive a null embedding.

    Args:
        contents: Batch of text values to embed.

    Returns:
        Object-dtype Series with the same length as ``contents``. It holds the
        embedding (list of floats) for each non-null input and ``None`` for nulls.

    Raises:
        RuntimeError: If a response contains a different number of embeddings
            than the request had inputs (rows would otherwise misalign).
    """
    deploy_client = get_deploy_client("databricks")

    # Maximum number of inputs per request.
    # NOTE(review): 150 is assumed to be the endpoint's per-request limit — confirm.
    batch_size = 150

    def get_embeddings(batch: list) -> list:
        """Embed one request's texts (assumes the response keeps input order)."""
        response = deploy_client.predict(
            endpoint="databricks-bge-large-en", inputs={"input": batch}
        )
        embeddings = [item["embedding"] for item in response.data]
        if len(embeddings) != len(batch):
            raise RuntimeError(
                f"Embedding endpoint returned {len(embeddings)} embeddings "
                f"for {len(batch)} inputs"
            )
        return embeddings

    # Nulls cannot be embedded: skip them in the requests and emit None for them.
    present_mask = contents.notna()
    texts = contents[present_mask].tolist()

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        all_embeddings += get_embeddings(texts[start : start + batch_size])

    # Re-expand to the input length, putting None back at the null positions.
    # dtype=object also keeps an empty batch convertible to array<float>.
    embedded = iter(all_embeddings)
    return pd.Series(
        [next(embedded) if present else None for present in present_mask],
        dtype=object,
    )


In [0]:
# Embed every chunk, then keep only the identifying columns, the chunk text and its vector.
test_embedded = (
    df_chunks
    .withColumn("embedding", get_embedding("full_text"))
    .select("row_num", "path", "full_text", "embedding")
)

In [0]:
# Inspect each chunk alongside its embedding vector.
display(test_embedded)