# Bootstrap Classification

LLM-based invoice categorization using the category hierarchy.

Creates:
- **prompts**: Stores prompt versions for classification
- **cat_bootstrap**: LLM classification results (category, confidence)

Tested on Serverless v4.

In [0]:
%pip install uv
%uv pip install .
%restart_python

In [0]:
from src.utils import get_spark
from src.config import load_config

# Session handle and project configuration, both obtained from the project
# helpers imported above; later cells read table paths and endpoints from
# ``config``.
spark = get_spark()
config = load_config()

In [0]:
%sql
-- Preview the first rows of the categories table.
-- NOTE(review): the :catalog, :schema and :categories markers are expected to
-- be bound from notebook widgets with those names; the widgets are only
-- created in the "Run LLM Classification" section, so on a fresh session that
-- cell presumably has to run before this one - confirm.
SELECT * FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :categories) LIMIT 10

In [0]:
config

## 1. Build Category Hierarchy

Create markdown representation for the prompt.

In [0]:
def build_category_string(categories: dict) -> str:
    """Render the nested category mapping as a markdown outline.

    Args:
        categories: Mapping of level-1 name -> mapping of level-2 name ->
            iterable of level-3 names.

    Returns:
        A markdown document with a fixed title, one ``##`` heading per
        level-1 entry, one ``###`` heading per level-2 entry, one bullet per
        level-3 entry, and an empty line closing each level-1 section.
    """

    def render():
        # The title carries its own trailing newline, which produces the
        # empty line between the title and the first section after joining.
        yield "# Spend Categories\n"
        for top_name, subcategories in categories.items():
            yield f"## {top_name}"
            for sub_name, leaves in subcategories.items():
                yield f"### {sub_name}"
                yield from (f"- {leaf}" for leaf in leaves)
            yield ""

    return "\n".join(render())

# Render the configured hierarchy once; it is reused when the prompt version
# is saved. Only the first 500 characters are printed as a preview (the "..."
# suffix is appended unconditionally).
categories_str = build_category_string(config.categories)
print(categories_str[:500] + "...")

## 2. Define Classification Prompt

In [0]:
# Instruction text for the classifier. The classification SQL later
# concatenates the stored prompt, the stored category hierarchy and each
# invoice description, so this text deliberately ends by pointing at
# "the hierarchy below". The JSON fields named here match the properties of
# the responseFormat schema passed to AI_QUERY.
PROMPT = """Classify the invoice into the category hierarchy.

Return JSON with:
- category_level_1: Direct, Indirect, or Non-Procureable
- category_level_2: The level 2 category name
- confidence: 1-10 confidence score
- rationale: Brief explanation

Use only categories from the hierarchy below.
"""
print(PROMPT)

## 3. Save Prompt Version

In [0]:
from src.versioning import save_prompt_version, save_category_summary

# Save prompt version (only if changed)
save_prompt_version(spark, config, PROMPT, categories_str)

# Save category summary (only if categories changed)
save_category_summary(spark, config, config.categories, categories_str)

## 4. Preview Invoices

In [0]:
# Load the invoices table, report its size, and show a five-row sample of the
# columns used for classification (description) and evaluation (the existing
# category labels). Long values are truncated to 40 characters for display.
invoices = spark.table(config.full_invoices_table_path)
print(f"Invoices: {invoices.count()} rows")
invoices.select("order_id", "description", "category_level_1", "category_level_2").show(5, truncate=40)

## 5. Run LLM Classification

Use AI_QUERY for batch inference.

In [0]:
# Widgets for SQL: reset the notebook widgets, then recreate one text widget
# per value the SQL needs, each pre-filled with its default.
# NOTE(review): Databricks documents a limitation around removing and creating
# widgets in the same cell - confirm this works on the target runtime.
dbutils.widgets.removeAll()

widget_defaults = {
    "llm_endpoint": config.large_llm_endpoint,
    "catalog": config.catalog,
    "schema": config.schema_name,
    "invoices": config.invoices,
    "cat_bootstrap": config.cat_bootstrap,
    "prompts": config.prompts,
    "categories": config.categories_table,
    "batch_size": "1000",
    "max_batches": "100",
}
for widget_name, default_value in widget_defaults.items():
    dbutils.widgets.text(widget_name, default_value)

In [None]:
%sql
-- Initialize cat_bootstrap table with schema
-- The table name is assembled from the catalog / schema / cat_bootstrap
-- parameter markers. IF NOT EXISTS keeps any existing table and its rows.
CREATE TABLE IF NOT EXISTS IDENTIFIER(:catalog || '.' || :schema || '.' || :cat_bootstrap) (
  order_id STRING,          -- invoice key; the batch MERGE matches on it
  description STRING,       -- invoice text sent to the LLM
  actual_level_1 STRING,    -- existing label copied from invoices.category_level_1
  actual_level_2 STRING,    -- existing label copied from invoices.category_level_2
  pred_level_1 STRING,      -- LLM-predicted level 1 category
  pred_level_2 STRING,      -- LLM-predicted level 2 category
  confidence DOUBLE,        -- confidence value reported by the LLM
  classified_at TIMESTAMP,  -- set to current_timestamp() by the classification query
  processed_at TIMESTAMP,   -- set to current_timestamp() by the classification query
  source STRING             -- NOTE(review): not populated by the batch MERGE in this notebook - confirm intended use
) USING DELTA

In [None]:
# Get batch configuration from the notebook widgets.
batch_size = int(dbutils.widgets.get("batch_size"))
max_batches = int(dbutils.widgets.get("max_batches"))

# Fail early with a clear message instead of a ZeroDivisionError (or a
# negative batch count) further down.
if batch_size <= 0:
    raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
if max_batches < 0:
    raise ValueError(f"max_batches must not be negative, got {max_batches}")

# Get total rows to process
total_rows = spark.table(config.full_invoices_table_path).count()

# Ceiling division gives the batches needed to cover every row; max_batches
# caps the amount of work done in one run.
batches_needed = (total_rows + batch_size - 1) // batch_size
num_batches = min(max_batches, batches_needed)

print(f"Processing {total_rows} rows in {num_batches} batches of {batch_size} rows each")
print(f"Target table: {config.full_cat_bootstrap_table_path}")

# Make the cap visible: without this the message above suggests that every
# row is classified even when max_batches cuts the run short.
if num_batches < batches_needed:
    skipped_rows = total_rows - num_batches * batch_size
    print(
        f"WARNING: max_batches={max_batches} limits this run; "
        f"{skipped_rows} rows will not be processed"
    )

In [None]:
# Process in batches with MERGE INTO.
#
# spark.sql() does not read notebook widgets by itself, so every named
# parameter marker used in the statement is bound explicitly through ``args``.
sql_args = {
    name: dbutils.widgets.get(name)
    for name in ("catalog", "schema", "invoices", "cat_bootstrap", "prompts", "llm_endpoint")
}

# One statement per batch. Inside the MERGE source:
#   * latest_prompt     - newest row of the prompts table
#   * numbered_invoices - invoices with a zero-based row number, ordered by
#                         order_id (NOTE(review): assumes order_id is unique;
#                         duplicates would make batch windows unstable and can
#                         make MERGE fail on multiple matches - confirm)
#   * classified        - rows of the requested window, with exactly ONE
#                         AI_QUERY call per invoice; all predicted fields are
#                         extracted from that single response, so they are
#                         consistent and the endpoint is not called three times
# The numbering is a CTE rather than a temp view because parameter markers are
# not permitted inside a view definition, and no per-batch views are left over.
# Columns are listed explicitly in the MERGE actions because the target table
# has a ``source`` column that this query does not produce; ``SET *`` /
# ``INSERT *`` require the source to provide every target column.
merge_batch_sql = """
    MERGE INTO IDENTIFIER(:catalog || '.' || :schema || '.' || :cat_bootstrap) AS target
    USING (
        WITH latest_prompt AS (
            SELECT prompt, categories
            FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :prompts)
            ORDER BY created_at DESC
            LIMIT 1
        ),
        numbered_invoices AS (
            SELECT
                ROW_NUMBER() OVER (ORDER BY order_id) - 1 AS row_num,
                *
            FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :invoices)
        ),
        classified AS (
            SELECT
                i.order_id,
                i.description,
                i.category_level_1 AS actual_level_1,
                i.category_level_2 AS actual_level_2,
                AI_QUERY(
                    :llm_endpoint,
                    CONCAT(p.prompt, '\\n', p.categories, '\\n\\nInvoice: ', i.description),
                    responseFormat => '{"type":"json_schema","json_schema":{"name":"classification","schema":{"type":"object","properties":{"category_level_1":{"type":"string"},"category_level_2":{"type":"string"},"confidence":{"type":"number"},"rationale":{"type":"string"}}}}}'
                ) AS response
            FROM numbered_invoices i
            CROSS JOIN latest_prompt p
            WHERE i.row_num >= :batch_start AND i.row_num < :batch_end
        )
        SELECT
            order_id,
            description,
            actual_level_1,
            actual_level_2,
            response:category_level_1 AS pred_level_1,
            response:category_level_2 AS pred_level_2,
            CAST(response:confidence AS DOUBLE) AS confidence,
            current_timestamp() AS classified_at,
            current_timestamp() AS processed_at
        FROM classified
    ) AS src
    ON target.order_id = src.order_id
    WHEN MATCHED THEN UPDATE SET
        target.description = src.description,
        target.actual_level_1 = src.actual_level_1,
        target.actual_level_2 = src.actual_level_2,
        target.pred_level_1 = src.pred_level_1,
        target.pred_level_2 = src.pred_level_2,
        target.confidence = src.confidence,
        target.classified_at = src.classified_at,
        target.processed_at = src.processed_at
    WHEN NOT MATCHED THEN INSERT (
        order_id, description, actual_level_1, actual_level_2,
        pred_level_1, pred_level_2, confidence, classified_at, processed_at
    ) VALUES (
        src.order_id, src.description, src.actual_level_1, src.actual_level_2,
        src.pred_level_1, src.pred_level_2, src.confidence, src.classified_at, src.processed_at
    )
"""

# Defined before the loop so the final summary also works when num_batches is 0.
processed_count = spark.table(config.full_cat_bootstrap_table_path).count()

for batch_num in range(num_batches):
    offset = batch_num * batch_size

    print(f"\n{'='*60}")
    print(f"Processing batch {batch_num + 1}/{num_batches} (rows {offset} to {offset + batch_size})")
    print(f"{'='*60}")

    # Classify and upsert the half-open row window [offset, offset + batch_size).
    spark.sql(
        merge_batch_sql,
        args={**sql_args, "batch_start": offset, "batch_end": offset + batch_size},
    )

    processed_count = spark.table(config.full_cat_bootstrap_table_path).count()
    print(f"✅ Batch {batch_num + 1} complete. Total processed: {processed_count} rows")

print(f"\n{'='*60}")
print(f"🎉 All batches complete! Total rows: {processed_count}")
print(f"{'='*60}")

In [0]:
%sql
-- LLM classification using latest prompt (single-pass alternative to the
-- batched MERGE above).
-- WARNING: CREATE OR REPLACE rebuilds the whole cat_bootstrap table and sends
-- every invoice to the LLM endpoint, regardless of batch_size / max_batches.
-- Table names come from the catalog / schema / cat_bootstrap / prompts /
-- invoices widgets via parameter markers. AI_QUERY is called once per invoice
-- and all predicted fields are read from that single JSON response. The
-- output columns match the schema declared for cat_bootstrap earlier in this
-- notebook; "source" is left NULL because nothing here defines a value for it.
CREATE OR REPLACE TABLE IDENTIFIER(:catalog || '.' || :schema || '.' || :cat_bootstrap) AS
WITH latest_prompt AS (
  SELECT prompt, categories
  FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :prompts)
  ORDER BY created_at DESC
  LIMIT 1
),
classified AS (
  SELECT
    i.order_id,
    i.description,
    i.category_level_1 AS actual_level_1,
    i.category_level_2 AS actual_level_2,
    AI_QUERY(
      :llm_endpoint,
      CONCAT(p.prompt, '\n', p.categories, '\n\nInvoice: ', i.description),
      responseFormat => '{"type":"json_schema","json_schema":{"name":"classification","schema":{"type":"object","properties":{"category_level_1":{"type":"string"},"category_level_2":{"type":"string"},"confidence":{"type":"number"},"rationale":{"type":"string"}}}}}'
    ) AS response
  FROM IDENTIFIER(:catalog || '.' || :schema || '.' || :invoices) i
  CROSS JOIN latest_prompt p
)
SELECT
  order_id,
  description,
  actual_level_1,
  actual_level_2,
  response:category_level_1 AS pred_level_1,
  response:category_level_2 AS pred_level_2,
  CAST(response:confidence AS DOUBLE) AS confidence,
  current_timestamp() AS classified_at,
  current_timestamp() AS processed_at,
  CAST(NULL AS STRING) AS source
FROM classified

## 6. Evaluate Accuracy

In [0]:
from sklearn.metrics import accuracy_score, classification_report

# Pull the classification results to the driver for scoring with scikit-learn.
results = spark.table(config.full_cat_bootstrap_table_path).toPandas()

# Labels and predictions can be NULL (for example when a field is missing
# from the LLM's JSON response). scikit-learn cannot score columns that mix
# None with strings, so missing values are replaced with placeholders. The two
# placeholders differ on purpose: a missing prediction never counts as a hit,
# not even against a missing label.
for level in ("level_1", "level_2"):
    results[f"actual_{level}"] = results[f"actual_{level}"].fillna("<unlabeled>")
    results[f"pred_{level}"] = results[f"pred_{level}"].fillna("<no prediction>")

level1_acc = accuracy_score(results["actual_level_1"], results["pred_level_1"])
print(f"Level 1 Accuracy: {level1_acc:.3f}")

level2_acc = accuracy_score(results["actual_level_2"], results["pred_level_2"])
print(f"Level 2 Accuracy: {level2_acc:.3f}")

In [0]:
# Per-class precision / recall / F1 for the level 1 predictions, computed on
# the ``results`` frame loaded by the accuracy cell.
print("Level 1 Classification Report:")
level1_report = classification_report(
    results["actual_level_1"],
    results["pred_level_1"],
)
print(level1_report)

## 7. Summary

In [0]:
# Final recap: the tables this notebook writes and the headline accuracy
# figures computed in the evaluation section.
summary_lines = [
    "Tables created:",
    f"  {config.full_prompts_table_path}",
    f"  {config.full_cat_bootstrap_table_path}",
    f"\nResults: {len(results)} invoices classified",
    f"Level 1 Accuracy: {level1_acc:.1%}",
    f"Level 2 Accuracy: {level2_acc:.1%}",
]
print("\n".join(summary_lines))