In [None]:
from __future__ import annotations

import pymupdf
import pandas as pd
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pyspark.sql.functions import udf, explode, col, pandas_udf
from pyspark.sql.types import (
    ArrayType,
    StructType,
    StructField,
    StringType,
    IntegerType,
    FloatType,
)

In [None]:
%%configure -f
{
  "executorCores": 1,
  "conf": {
    "spark.sql.execution.arrow.maxRecordsPerBatch": "10"
  }
}

In [None]:
def extract_text_from_parsed_pdf(pdf_bytes: bytes, max_pages: int = 100):
    """Extract the plain text of every page of an in-memory PDF.

    Intended to be wrapped as a row-at-a-time Spark UDF, so failures are
    reported by returning ``None`` instead of raising.

    Args:
        pdf_bytes: Raw bytes of a PDF file. Empty or ``None`` input yields ``None``.
        max_pages: Documents with more than this many pages are skipped
            (``None`` is returned). Defaults to 100.

    Returns:
        A list of ``{"text": str, "page_number": int}`` dicts, one per page, in
        document order. ``page_number`` is PyMuPDF's ``page.number`` (0-based
        per the PyMuPDF docs). Returns ``None`` if the document is too long or
        could not be opened/parsed.
    """
    if not pdf_bytes:
        return None
    try:
        # The context manager closes the document on every exit path,
        # including the early return for over-long documents.
        with pymupdf.Document(stream=pdf_bytes, filetype="pdf") as doc:
            if len(doc) > max_pages:
                return None
            return [
                {"text": page.get_text(), "page_number": page.number}
                for page in doc
            ]
    except Exception:
        # Deliberate best-effort: one unreadable PDF (e.g. corrupt bytes) must
        # not fail the whole Spark job; the row is dropped downstream instead.
        return None


# Spark return type of extract_text_from_parsed_pdf: an array with one
# {text, page_number} struct per PDF page (both fields nullable).
extract_schema = ArrayType(
    StructType()
    .add("text", StringType(), nullable=True)
    .add("page_number", IntegerType(), nullable=True)
)
# Row-at-a-time Python UDF that turns raw PDF bytes into that page array.
extract_udf = udf(f=extract_text_from_parsed_pdf, returnType=extract_schema)


# Text splitters keyed by (chunk_size, chunk_overlap). Each Python process
# builds one splitter per key instead of one per UDF call (same module-level
# cache pattern as _model_cache below).
_splitter_cache = {}


def _get_splitter(chunk_size: int, chunk_overlap: int):
    """Return a cached RecursiveCharacterTextSplitter for the given sizes."""
    key = (chunk_size, chunk_overlap)
    splitter = _splitter_cache.get(key)
    if splitter is None:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        _splitter_cache[key] = splitter
    return splitter


def chunk(text: str, chunk_size: int = 2048, chunk_overlap: int = 200):
    """Split one page of text into overlapping chunks.

    Args:
        text: The page text. ``None`` or ``""`` yields no chunks.
        chunk_size: Maximum chunk length passed to the splitter (default 2048).
        chunk_overlap: Overlap between consecutive chunks (default 200).

    Returns:
        A list of ``{"text": str, "chunk_id": int}`` dicts, where ``chunk_id``
        is the 0-based position of the chunk within this page.
    """
    if not text:
        return []
    pieces = _get_splitter(chunk_size, chunk_overlap).split_text(text)
    return [{"text": piece, "chunk_id": idx} for idx, piece in enumerate(pieces)]


# Spark return type of chunk: an array with one {text, chunk_id} struct per
# chunk of a page (both fields nullable).
chunk_schema = ArrayType(
    StructType()
    .add("text", StringType(), nullable=True)
    .add("chunk_id", IntegerType(), nullable=True)
)
# Row-at-a-time Python UDF that splits a page's text into that chunk array.
chunk_udf = udf(f=chunk, returnType=chunk_schema)

_model_cache = {"model": None}


def get_model():
    import os

    os.environ["TORCH_HOME"] = "/tmp/torch"
    os.environ["XDG_CACHE_HOME"] = "/tmp"
    os.environ["HF_HOME"] = "/tmp/huggingface"
    os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"

    if _model_cache["model"] is None:
        from sentence_transformers import SentenceTransformer

        device = "cuda"
        model = SentenceTransformer(
            "sentence-transformers/all-MiniLM-L6-v2", device=device
        )
        model.compile()
        _model_cache["model"] = model
    return _model_cache["model"]


@pandas_udf(ArrayType(FloatType()))
def embed_udf(texts: pd.Series) -> pd.Series:
    """Embed a batch of chunk texts with the cached SentenceTransformer.

    Spark calls this vectorized UDF once per Arrow batch; the notebook's
    ``%%configure`` cell caps batches at 10 records via
    ``spark.sql.execution.arrow.maxRecordsPerBatch``.

    Args:
        texts: One chunk of text per row.

    Returns:
        A Series aligned with ``texts`` whose elements are lists of Python
        floats (one embedding vector per input row).
    """
    if texts.empty:
        # Nothing to encode; don't pay the model-load cost for an empty batch.
        return pd.Series([[]] * len(texts))

    model = get_model()
    embeddings = model.encode(
        texts.tolist(),
        convert_to_tensor=True,
        # NOTE(review): confirm SentenceTransformer.encode actually honors
        # torch_dtype for the installed sentence-transformers version.
        torch_dtype=torch.bfloat16,
    )
    # NumPy has no bfloat16 dtype, so move to CPU as float32 before .numpy();
    # this is a plain device copy if encode already returned float32.
    embeddings = embeddings.to(device="cpu", dtype=torch.float32)
    return pd.Series([row.tolist() for row in embeddings.numpy()])

In [None]:
# Gather the S3 paths of every PDF listed in the digitalcorpora metadata table.
# Only the path column is projected before collect() so the driver does not
# pull every metadata column of every row.
paths_df = spark.read.parquet(
    "s3://daft-public-datasets/digitalcorpora_metadata"
).filter(col("file_name").endswith(".pdf"))
paths = [
    row.uploaded_pdf_path
    for row in paths_df.select("uploaded_pdf_path").collect()
]

# PDF bytes -> pages -> chunks -> embeddings. explode() emits no rows for a
# null array, so PDFs that extract_udf rejected (returned None) are dropped at
# the first explode.
df = (
    spark.read.format("binaryFile")
    .load(paths)
    .withColumnRenamed("path", "uploaded_pdf_path")
    # One row per PDF page.
    .withColumn("pages", extract_udf(col("content")))
    .withColumn("page", explode("pages"))
    .withColumn("page_text", col("page.text"))
    .withColumn("page_number", col("page.page_number"))
    .filter(col("page_text").isNotNull())
    # One row per text chunk within a page.
    .withColumn("chunks", chunk_udf(col("page_text")))
    .withColumn("chunk", explode("chunks"))
    .withColumn("chunk_text", col("chunk.text"))
    .withColumn("chunk_id", col("chunk.chunk_id"))
    .filter(col("chunk_text").isNotNull())
    # Arrow-batched embedding of each chunk.
    .withColumn("embedding", embed_udf(col("chunk_text")))
    .select(
        "uploaded_pdf_path", "page_number", "chunk_id", "chunk_text", "embedding"
    )
)

# NOTE: mode("append") means re-running this cell writes a second copy of
# every row to the output location.
df.write.mode("append").parquet(
    "s3://eventual-dev-benchmarking-results/ai-benchmark-results/document-embedding-results"
)