# Bootstrap Classification

LLM-based invoice categorization using the category hierarchy.

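Creates:
- **prompts**: Stores prompt versions for classification
- **categories_str**: Stores versions of the markdown category-hierarchy string inserted into the prompt
- **cat_bootstrap**: LLM classification results (predicted Level 1 / Level 2 category, confidence, rationale)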

Tested on Serverless v4.

In [0]:
%pip install uv
%uv pip install .
%restart_python

In [0]:
# Project helpers: a Spark session factory and the notebook configuration object
# (catalog/schema, table names, LLM endpoint) used throughout this notebook.
from src.utils import get_spark
from src.config import load_config
import pandas as pd
import pyspark.sql.functions as F

spark = get_spark()
config = load_config()

In [0]:
# Notebook widgets. Defaults come from the project config and can be overridden per run;
# the %sql cells reference these same widgets as :name parameter markers.
dbutils.widgets.removeAll()

_widget_defaults = {
    "llm_endpoint": config.large_llm_endpoint,
    "catalog": config.catalog,
    "schema": config.schema_name,
    "invoices_table": config.invoices,
    "cat_bootstrap_table": config.cat_bootstrap,
    "prompts_table": config.prompts,
    "categories_table": config.categories_table,
    "categories_str_table": config.categories_str,
    "batch_size": "1000",
    "max_batches": "100",
}
for _widget_name, _widget_default in _widget_defaults.items():
    dbutils.widgets.text(_widget_name, _widget_default)

# Pull the (possibly overridden) values back into Python for the f-string SQL below.
llm_endpoint = dbutils.widgets.get("llm_endpoint")
catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("schema")
invoices_table = dbutils.widgets.get("invoices_table")
cat_bootstrap_table = dbutils.widgets.get("cat_bootstrap_table")
prompts_table = dbutils.widgets.get("prompts_table")
categories_str_table = dbutils.widgets.get("categories_str_table")
# Batch controls are numeric; int() raises ValueError on non-integer widget input.
batch_size, max_batches = (
    int(dbutils.widgets.get(_key)) for _key in ("batch_size", "max_batches")
)

In [0]:
%sql
-- Preview the first 10 rows of the categories table selected by the widgets.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :categories_table) LIMIT 10

In [0]:
%sql
-- Preview the first 10 rows of the invoices table selected by the widgets.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :invoices_table) LIMIT 10

## 1. Build Category Hierarchy Strings

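Create a markdown representation of the category hierarchy for the prompt. Two variants are provided: a simple list of Level 3 categories with their descriptions, and the full Level 1 → Level 2 → Level 3 hierarchy built from the categories table. The full hierarchy (built second) is the one stored and used below.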

In [0]:
# Simple variant: one "<Level 3>: <description>" entry per line.
# NOTE: `categories_str` is reassigned by the full-hierarchy builder in the next cell,
# so this variant only takes effect if that cell is skipped.
cat_df = spark.table(config.full_categories_table_path).toPandas()


def _simple_category_entry(cat3, desc):
    """Format one category as "<cat3>: <desc>" (or just "<cat3>" without a description).

    Returns None when the category name itself is missing so the entry can be dropped.
    """
    if pd.isnull(cat3):
        return None
    return f"{cat3}: {desc}" if pd.notnull(desc) else f"{cat3}"


cat_df['cat_w_description'] = [
    _simple_category_entry(cat3, desc)
    for cat3, desc in zip(cat_df['category_level_3'], cat_df['category_level_3_description'])
]
# Previously a missing name/description rendered as the literal text "nan", and joining
# with a single space made entry boundaries ambiguous (descriptions contain spaces).
categories_str = '\n'.join(cat_df['cat_w_description'].dropna())

In [0]:
def build_category_string_from_table(categories_table: str) -> str:
    """Render a categories table as a markdown hierarchy for the classification prompt.

    Reads the distinct (level 1, level 2, level 3, level 3 description) combinations
    from ``categories_table`` and renders them as::

        # Spend Categories

        ## <level 1>
        ### <level 2>
        - <level 3>: <level 3 description>

    Args:
        categories_table: Fully qualified table name readable via ``spark.table``.

    Returns:
        The markdown string. Level 1 and level 2 headings, and the level 3 entries under
        each heading, are sorted so the output is identical across runs for the same table
        contents. Rows with a missing level 1, level 2 or level 3 value are omitted; a
        missing description renders the level 3 name alone.
    """
    columns = [
        "category_level_1",
        "category_level_2",
        "category_level_3",
        "category_level_3_description",
    ]
    pdf = spark.table(categories_table).select(*columns).distinct().toPandas()
    # `distinct()` gives no ordering guarantee; sort so the rendered string (which is
    # versioned in a table and registered as a prompt) does not drift between runs.
    pdf = pdf.sort_values(columns, na_position="last")

    lines = ["# Spend Categories\n"]
    for level1 in sorted(pdf["category_level_1"].dropna().unique()):
        lines.append(f"## {level1}")
        level1_rows = pdf[pdf["category_level_1"] == level1]
        for level2 in sorted(level1_rows["category_level_2"].dropna().unique()):
            lines.append(f"### {level2}")
            level2_rows = level1_rows[level1_rows["category_level_2"] == level2]
            for cat3, desc in zip(
                level2_rows["category_level_3"], level2_rows["category_level_3_description"]
            ):
                if pd.isnull(cat3):
                    continue
                lines.append(f"- {cat3}: {desc}" if pd.notnull(desc) else f"- {cat3}")
            lines.append("")
    return "\n".join(lines)

# Build the full-hierarchy string (replacing the simple variant) and preview its start.
categories_str = build_category_string_from_table(config.full_categories_table_path)
preview = categories_str[:500]
print(f"{preview}...")

In [0]:
%sql
-- Versioned store for the rendered category-hierarchy string: each run appends a row,
-- and the classification step reads the most recent row (ORDER BY datetime DESC LIMIT 1).
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.' || :categories_str_table) (
  datetime TIMESTAMP,
  file_name STRING,
  categories_str STRING
)

In [0]:
# Append a new version of the category-hierarchy string. The target is built from the
# widget values so it is the same table the %sql cells and the classification query read
# (the config path would diverge from them whenever a widget is overridden).
(
  spark.createDataFrame([(categories_str,)], ["categories_str"])
  .withColumn("file_name", F.lit(config.categories_file))
  .withColumn("datetime", F.current_timestamp())
  .select('datetime','file_name','categories_str')
  .write.format('delta')
  .mode('append')
  .saveAsTable(f"{catalog}.{schema}.{categories_str_table}")
)

In [0]:
%sql
-- Show up to 5 stored category strings (no ORDER BY, so not necessarily the latest ones).
SELECT *
FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :categories_str_table)
LIMIT 5

## 2. Build Prompt Table

In [0]:
prompt = """
Use the following categories and accompanying description and classify the invoice below.

<Rules>
- If you cannot identify the category, return "Unknown" and rank the confidence as 1
- A confidence of 2 means you can classify it, but it could easily be multiple categories
- A confidence of 3 means it is likely a single category, but the rationale isn't clear
- A confidence of 4 means you are very confident in your classification with a clear rationale
- A confidence of 5 means you have an exact example in the context for the classification
</Rules>

Return a json output with the category, confidence of classification between 1 and 5, and rationale. You must use the categories provided.
"""
print(prompt)

In [0]:
%sql
-- Versioned prompt store: each run appends a row, and the classification step reads the
-- most recent row (ORDER BY datetime DESC LIMIT 1).
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.' || :prompts_table) (
  datetime TIMESTAMP,
  prompt STRING
)

In [0]:
# Append a new prompt version. The target is built from the widget values so it is the
# same table the %sql cells and the classification query read (the config path would
# diverge from them whenever a widget is overridden).
(
  spark.createDataFrame([(prompt,)], ["prompt"])
  .withColumn("datetime", F.current_timestamp())
  .select('datetime','prompt')
  .write.format('delta')
  .mode('append')
  .saveAsTable(f"{catalog}.{schema}.{prompts_table}")
)

In [0]:
%sql
-- Show up to 5 stored prompts (no ORDER BY, so not necessarily the latest ones).
SELECT *
FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :prompts_table)
LIMIT 5

## 5. Run LLM Classification

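Use `AI_QUERY` for batch inference: unclassified invoices are processed in batches of `batch_size` rows, for at most `max_batches` batches per run.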

In [0]:
%sql
-- Initialize cat_bootstrap table with schema
-- Holds classification results keyed by order_id. The classification MERGE only inserts
-- order_ids that are not already present, so re-running the notebook resumes where it left off.
-- NOTE(review): confidence is declared STRING while the AI_QUERY response format requests
-- a DOUBLE, so values are presumably cast to string on insert; consider DOUBLE for new
-- tables if nothing downstream depends on STRING.
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.' || :cat_bootstrap_table) (
  order_id STRING,
  description STRING,
  actual_level_1 STRING,
  actual_level_2 STRING,
  pred_level_1 STRING,
  pred_level_2 STRING,
  confidence STRING,
  rationale STRING,
  classified_at TIMESTAMP,
  source STRING
) USING DELTA

In [0]:
# Report the batch configuration. Tables are resolved from the widget values, which are
# the same names the classification MERGE uses, so an overridden widget is reflected here
# rather than silently falling back to the config paths.
invoices_path = f"{catalog}.{schema}.{invoices_table}"
cat_bootstrap_path = f"{catalog}.{schema}.{cat_bootstrap_table}"

total_rows = spark.table(invoices_path).count()
print(f"Total rows: {total_rows}")
print(f"Batch size: {batch_size}")
print(f"Max batches: {max_batches}")
print(f"Target table: {cat_bootstrap_path}")

In [0]:
# Build the combined prompt: the latest stored prompt followed by the latest category
# hierarchy string. This is the prefix the classification MERGE sends to AI_QUERY (which
# additionally appends the invoice description); it is registered in MLflow next.
combined_prompt_row = spark.sql(f"""
    SELECT
        CONCAT(prompt, '\\nCategories:\\n', categories_str) AS combined_prompt
    FROM (
        SELECT prompt
        FROM {catalog}.{schema}.{prompts_table}
        ORDER BY datetime DESC
        LIMIT 1
    )
    CROSS JOIN (
        SELECT categories_str
        FROM {catalog}.{schema}.{categories_str_table}
        ORDER BY datetime DESC
        LIMIT 1
    )
""").first()

# `first()` returns None when either source table is empty (the CROSS JOIN yields no rows),
# and CONCAT yields NULL if either value is NULL. Fail with a clear message instead of an
# opaque TypeError here or a confusing failure at registration time.
combined_prompt = None if combined_prompt_row is None else combined_prompt_row['combined_prompt']
if combined_prompt is None:
    raise ValueError(
        f"No combined prompt could be built from {catalog}.{schema}.{prompts_table} and "
        f"{catalog}.{schema}.{categories_str_table}; run the cells that populate them first."
    )

combined_prompt


In [0]:

import mlflow
# `datetime.now()` is used for the timestamp tag below; this import was missing (NameError).
from datetime import datetime

# Register the combined prompt (latest prompt + category hierarchy) in the MLflow Prompt
# Registry under <catalog>.<schema>.combined_invoice_prompt.
prompt_registry_name = f"{catalog}.{schema}.combined_invoice_prompt"
registered_prompt = mlflow.genai.register_prompt(
    name=prompt_registry_name,
    template=combined_prompt,
    commit_message="Register combined invoice classification prompt",
    tags={
        "source": "bootstrap_classification",
        "type": "combined_prompt",
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
)
print(f"Registered prompt: {registered_prompt.name} (version {registered_prompt.version})")

In [0]:
# Fully qualified table names, resolved from the widget values so the progress counts
# below read exactly the tables the MERGE reads from and writes to.
invoices_path = f"{catalog}.{schema}.{invoices_table}"
cat_bootstrap_path = f"{catalog}.{schema}.{cat_bootstrap_table}"
prompts_path = f"{catalog}.{schema}.{prompts_table}"
categories_str_path = f"{catalog}.{schema}.{categories_str_table}"

# One MERGE classifies one batch. The `pending` CTE picks up to `batch_size` invoices whose
# order_id is not yet in the bootstrap table, and AI_QUERY is applied only to that batch
# (the LIMIT sits below the AI_QUERY projection instead of above it). The query text does
# not change between batches, so it is built once.
merge_sql = f"""
    MERGE INTO {cat_bootstrap_path} AS target
    USING (
        WITH latest_prompt AS (
            SELECT prompt
            FROM {prompts_path}
            ORDER BY datetime DESC
            LIMIT 1
        ),
        latest_category AS (
            SELECT categories_str
            FROM {categories_str_path}
            ORDER BY datetime DESC
            LIMIT 1
        ),
        pending AS (
            SELECT i.*
            FROM {invoices_path} i
            LEFT ANTI JOIN {cat_bootstrap_path} done
                ON i.order_id = done.order_id
            LIMIT {batch_size}
        )
        SELECT
            order_id,
            description,
            category_level_1 AS actual_level_1,
            category_level_2 AS actual_level_2,
            ai_result.category_level_1 AS pred_level_1,
            ai_result.category_level_2 AS pred_level_2,
            ai_result.confidence AS confidence,
            ai_result.rationale AS rationale,
            current_timestamp() AS classified_at,
            'bootstrap' AS source
        FROM (
            SELECT
                b.*,
                from_json(AI_QUERY(
                    '{llm_endpoint}',
                    CONCAT(
                        p.prompt,
                        '\\nCategories:\\n',
                        c.categories_str,
                        '\\n\\nInvoice: ',
                        b.description
                    ),
                    responseFormat => 'STRUCT<result:STRUCT<category_level_1:STRING, category_level_2:STRING, confidence:DOUBLE, rationale:STRING>>'
                ), 'STRUCT<category_level_1:STRING, category_level_2:STRING, confidence:DOUBLE, rationale:STRING>'
                ) AS ai_result
            FROM pending b
            CROSS JOIN latest_prompt p
            CROSS JOIN latest_category c
        )
    ) AS source
    ON target.order_id = source.order_id
    WHEN NOT MATCHED THEN INSERT *
"""


def _count_pending() -> int:
    """Return how many invoices have no matching order_id in the bootstrap table yet.

    Uses the same anti-join condition as the MERGE's `pending` CTE, so stale or duplicate
    bootstrap rows cannot make the loop believe it is finished.
    """
    classified_ids = spark.table(cat_bootstrap_path).select("order_id")
    return (
        spark.table(invoices_path)
        .join(classified_ids, on="order_id", how="left_anti")
        .count()
    )


total_count = spark.table(invoices_path).count()
pending_count = _count_pending()
# Initialised before the loop so the final summary also works when max_batches is 0.
processed_count = total_count - pending_count

for batch_num in range(max_batches):
    if pending_count == 0:
        print(f"🎉 All {total_count} rows already processed!")
        break

    spark.sql(merge_sql)

    previous_pending = pending_count
    pending_count = _count_pending()
    processed_count = total_count - pending_count
    # total_count > 0 here: pending_count was non-zero before this batch.
    print(
        f"✅ Batch {batch_num + 1} complete. Processed: {processed_count} / {total_count} rows "
        f"({processed_count / total_count * 100:.1f}%)"
    )

    if pending_count == 0:
        print("\n🎉 All rows processed!")
        break
    if pending_count >= previous_pending:
        # No progress: the remaining invoices can never be matched. For example, rows with
        # a NULL order_id never satisfy the join conditions, so they would be re-classified
        # on every batch. Stop instead of spending the remaining batches on them.
        print("⚠️ Batch made no progress on pending invoices; stopping early.")
        break

print(f"\nFinal: {processed_count} / {total_count} rows processed")

In [0]:
%sql
-- Inspect all classification results (actual vs. predicted levels, confidence, rationale).
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :cat_bootstrap_table)

## 6. Evaluate Accuracy
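Ultimately this is a classification problem, so we compare the predicted Level 1 and Level 2 categories against the actual labels and report accuracy (plus a per-class report for Level 1).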

In [0]:
from sklearn.metrics import accuracy_score, classification_report

# Read the same bootstrap table the classification loop writes to (resolved from the widgets).
results = (
    spark.table(f"{catalog}.{schema}.{cat_bootstrap_table}")
    .where("description IS NOT NULL")
    .toPandas()
)

# sklearn raises when a label column mixes None with strings, and predictions can be NULL
# (e.g. when the LLM response could not be parsed). Map missing labels to an explicit class
# so they count as errors instead of crashing the evaluation.
MISSING_LABEL = "<missing>"
label_columns = ["actual_level_1", "actual_level_2", "pred_level_1", "pred_level_2"]
results[label_columns] = results[label_columns].fillna(MISSING_LABEL)

if results.empty:
    print("No classified rows to evaluate yet.")
else:
    level1_acc = accuracy_score(results["actual_level_1"], results["pred_level_1"])
    print(f"Level 1 Accuracy: {level1_acc:.3f}")

    level2_acc = accuracy_score(results["actual_level_2"], results["pred_level_2"])
    print(f"Level 2 Accuracy: {level2_acc:.3f}")

    print("Level 1 Classification Report:")
    print(classification_report(results["actual_level_1"], results["pred_level_1"]))