# Bootstrap Classification

LLM-based invoice categorization using the category hierarchy.

Creates:
- **prompts**: Stores prompt versions for classification
- **cat_bootstrap**: LLM classification results (category, confidence)

In [None]:
from src.utils import get_spark, load_categorize_config
from src.generate import load_generate_config

# Shared Spark session plus the two config objects used throughout this notebook:
# gen_config provides the category hierarchy, the invoices table and the LLM endpoint;
# cat_config provides the catalog/schema and the prompts / cat_bootstrap table names.
spark = get_spark()
gen_config = load_generate_config()
cat_config = load_categorize_config()

## 1. Build Category Hierarchy

Render the category hierarchy as markdown so it can be included in the prompt.

In [None]:
def build_category_string(categories: dict) -> str:
    """Render a three-level category hierarchy as a markdown outline.

    Args:
        categories: Mapping of level-1 name -> mapping of level-2 name ->
            iterable of level-3 names.

    Returns:
        A markdown string beginning with a "# Spend Categories" title. Each
        level-1 name is a ``##`` heading, each level-2 name a ``###`` heading,
        and each level-3 name a ``-`` bullet. A blank line closes every
        level-1 section.
    """
    out = ["# Spend Categories\n"]
    for top_name, sub_map in categories.items():
        out.append(f"## {top_name}")
        for mid_name, leaves in sub_map.items():
            out.append(f"### {mid_name}")
            out.extend(f"- {leaf}" for leaf in leaves)
        # Blank separator after each level-1 section.
        out.append("")
    return "\n".join(out)

categories_str = build_category_string(gen_config.categories)
# Sanity-check the rendered hierarchy by showing only its first 500 characters.
print(f"{categories_str[:500]}...")

## 2. Define Classification Prompt

In [None]:
PROMPT = """Classify the invoice into the category hierarchy.

Return JSON with:
- category_level_1: Direct, Indirect, or Non-Procureable
- category_level_2: The level 2 category name
- confidence: 1-10 confidence score
- rationale: Brief explanation

Use only categories from the hierarchy below.
"""
print(PROMPT)

## 3. Save Prompt Version

In [None]:
from pyspark.sql.functions import current_timestamp

# Persist this prompt version (instructions + rendered hierarchy) with a
# creation timestamp. Rows are appended, so earlier versions are kept.
prompt_df = (
    spark.createDataFrame([(PROMPT, categories_str)], ["prompt", "categories"])
    .withColumn("created_at", current_timestamp())
)
(
    prompt_df.write
    .format("delta")
    .mode("append")
    .saveAsTable(cat_config.full_prompts)
)
print(f"Saved prompt to {cat_config.full_prompts}")

## 4. Preview Invoices

In [None]:
# Load the invoices table, report its size, and show a few labelled rows.
invoices = spark.table(gen_config.full_invoices)
print(f"Invoices: {invoices.count()} rows")
(
    invoices
    .select(["order_id", "description", "category_level_1", "category_level_2"])
    .show(n=5, truncate=40)
)

## 5. Run LLM Classification

Classify every invoice with `AI_QUERY` batch inference against the configured LLM endpoint, using the latest saved prompt.

In [None]:
# Widgets for SQL
dbutils.widgets.text("llm_endpoint", gen_config.llm_endpoint)
dbutils.widgets.text("catalog", cat_config.catalog)
dbutils.widgets.text("schema", cat_config.schema_name)
dbutils.widgets.text("invoices_table", gen_config.invoices_table)
dbutils.widgets.text("cat_bootstrap_table", cat_config.cat_bootstrap_table)
dbutils.widgets.text("prompts_table", cat_config.prompts_table)

In [None]:
%sql
-- LLM classification of every invoice using the most recently saved prompt.
-- AI_QUERY is issued once per invoice (in the `responses` CTE), and every
-- output field is parsed from that single JSON response. This avoids paying
-- for three endpoint calls per row. It also guarantees that pred_level_1,
-- pred_level_2, confidence and rationale all come from the same generation.
CREATE OR REPLACE TABLE $catalog.$schema.$cat_bootstrap_table AS
WITH latest_prompt AS (
  -- Newest prompt version only.
  SELECT prompt, categories
  FROM $catalog.$schema.$prompts_table
  ORDER BY created_at DESC
  LIMIT 1
),
responses AS (
  SELECT
    i.order_id,
    i.description,
    i.category_level_1 AS actual_level_1,
    i.category_level_2 AS actual_level_2,
    AI_QUERY(
      :llm_endpoint,
      CONCAT(p.prompt, '\n', p.categories, '\n\nInvoice: ', i.description),
      responseFormat => '{"type":"json_schema","json_schema":{"name":"classification","schema":{"type":"object","properties":{"category_level_1":{"type":"string"},"category_level_2":{"type":"string"},"confidence":{"type":"number"},"rationale":{"type":"string"}}}}}'
    ) AS response
  FROM $catalog.$schema.$invoices_table i
  CROSS JOIN latest_prompt p
)
SELECT
  order_id,
  description,
  actual_level_1,
  actual_level_2,
  -- JSON-path extraction yields NULL when a field is absent from the response.
  response:category_level_1 AS pred_level_1,
  response:category_level_2 AS pred_level_2,
  response:confidence AS confidence,
  -- The prompt asks for a rationale; keep it for error analysis.
  response:rationale AS rationale,
  current_timestamp() AS classified_at
FROM responses

## 6. Evaluate Accuracy

In [None]:
from sklearn.metrics import accuracy_score, classification_report

results = spark.table(cat_config.full_cat_bootstrap).toPandas()

# Predictions can be NULL, e.g. when a field is absent from the model's JSON
# response. sklearn cannot score None alongside string labels (it raises while
# inferring the label set). Replace missing predictions with an explicit
# sentinel that counts as a wrong answer instead of aborting the evaluation.
MISSING_LABEL = "<missing>"
PRED_COLUMNS = ["pred_level_1", "pred_level_2"]
n_missing = int(results[PRED_COLUMNS].isna().any(axis=1).sum())
if n_missing:
    print(f"Warning: {n_missing} invoice(s) have missing predictions; scored as incorrect.")
results[PRED_COLUMNS] = results[PRED_COLUMNS].fillna(MISSING_LABEL)

level1_acc = accuracy_score(results["actual_level_1"], results["pred_level_1"])
print(f"Level 1 Accuracy: {level1_acc:.3f}")

level2_acc = accuracy_score(results["actual_level_2"], results["pred_level_2"])
print(f"Level 2 Accuracy: {level2_acc:.3f}")

In [None]:
print("Level 1 Classification Report:")
# Per-class precision / recall / F1 for the level-1 predictions.
level1_report = classification_report(
    y_true=results["actual_level_1"],
    y_pred=results["pred_level_1"],
)
print(level1_report)

## 7. Summary

In [None]:
# One print call, one output line per item (same output as separate prints).
print(
    "Tables created:",
    f"  {cat_config.full_prompts}",
    f"  {cat_config.full_cat_bootstrap}",
    f"\nResults: {len(results)} invoices classified",
    f"Level 1 Accuracy: {level1_acc:.1%}",
    f"Level 2 Accuracy: {level2_acc:.1%}",
    sep="\n",
)