# Les Misér-AI-bles: Batching Through the Barricades
## Spark Profiling

This notebook takes our previously ingested table and uses Spark UDFs to run our prompt against three OpenAI endpoints.

In [0]:
# Install the databricks-agents package (quietly) and then restart the Python
# process so the freshly installed packages are the ones later cells import.
# NOTE(review): the version is not pinned, and the `openai` client used further
# down is not installed here explicitly - presumably it is provided by the
# runtime or pulled in transitively; confirm.
%pip install databricks-agents --quiet
%restart_python

In [0]:
import os

from databricks.sdk import WorkspaceClient

# Build the SDK client once and reuse it; the original constructed two
# separate clients, resolving credentials twice for no benefit.
workspace_client = WorkspaceClient()
# `w` is kept as an alias so any cell that refers to the short name still works.
w = workspace_client
workspace_url = workspace_client.config.host

# When running inside a Databricks runtime, take the API token from the
# notebook context; otherwise fall back to the token held by the SDK config.
# NOTE(review): `config.token` is presumably None when the SDK authenticates
# with something other than a personal access token - confirm before relying
# on the fallback branch.
if "DATABRICKS_RUNTIME_VERSION" in os.environ:
    token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
else:
    token = workspace_client.config.token

In [0]:
import json
import time

from pyspark.sql.functions import col, udf
from pyspark.sql.types import FloatType, StringType, StructField, StructType

def _strip_markdown_fences(text):
    """Remove Markdown code-fence markers (```json and ```) from a model reply."""
    return text.replace("```json\n", "").replace("```", "")


def extract_data_from_passage(header2, header3, page_content):
    """Run the structured-extraction prompt for one passage and time the call.

    Args:
        header2: First heading value, interpolated into the prompt.
        header3: Second heading value, interpolated into the prompt.
        page_content: Passage text, interpolated into the prompt.

    Returns:
        The model's reply with Markdown code fences removed, followed by a
        final line ``time: <seconds>`` holding the elapsed time of the call
        formatted to two decimals.

    Any exception raised by the client propagates to the caller.
    """
    # perf_counter is monotonic, so the measurement cannot go negative or jump
    # if the wall clock is adjusted while the request is in flight.
    started = time.perf_counter()

    # NOTE(review): `OpenAI`, `token` and `workspace_url` are read from the
    # notebook's global scope. No import of `OpenAI` is visible in this
    # notebook - presumably `from openai import OpenAI` is expected; confirm.
    client = OpenAI(
        api_key=token,
        base_url=f"{workspace_url}/serving-endpoints",
    )

    response = client.chat.completions.create(
        model='azure-o1',
        messages=[
            {"role": "user", "content": f"""
             Take this passage from Les Miserables and do structured data extraction in JSON. I want you to provide the title of the chapter, a list of characters, a synopsis of the chapter, and the overall sentiment of the chapter - positive, neutral, or negative. Do not make up anything if the passage isn't part of the novel. Also include 'experiment: o1-udf'

             Output Format:
                title: 
                characters: []
                synopsis:
                sentiment:
                experiment:

             {header2}
             {header3}
             {page_content}
             """}
        ],
    )

    elapsed_time = time.perf_counter() - started

    # The reply content can be None; treat that as an empty reply instead of
    # failing on None.replace(...).
    content = response.choices[0].message.content or ""
    cleaned = _strip_markdown_fences(content)
    # Keep the timing suffix on its own line even when the reply does not end
    # with a newline (a fenced reply already does once the fence is removed).
    if cleaned and not cleaned.endswith("\n"):
        cleaned += "\n"

    return cleaned + f"time: {elapsed_time:.2f}"

# Wrap the extraction function as a Spark UDF whose result column is a string.
extract_data_udf = udf(f=extract_data_from_passage, returnType=StringType())

In [0]:
# Redistribute the rows over 24 partitions and add an "extracted_data" column
# produced by the UDF from the two heading columns and the passage text.
# NOTE(review): `les_mis_df` is not defined anywhere in the visible notebook -
# presumably it is loaded from the previously ingested table; confirm.
output = (
    les_mis_df
    .repartition(24)
    .withColumn(
        "extracted_data",
        extract_data_udf(col("header_2"), col("header_3"), col("page_content")),
    )
)

In [0]:
# Render the DataFrame in the notebook.
# NOTE(review): `output` is not cached, so displaying it presumably runs the
# UDF (and therefore the endpoint calls) for the rows shown, and the later
# table write evaluates it again - confirm whether the double cost is intended.
display(output)

In [0]:
# Persist the results as a table, replacing any existing contents.
(
    output.write
    .mode('overwrite')
    .saveAsTable('shm.default.`azure-o1_profiling`')
)