## Spend Categorization - Vector Search RAG

This notebook uses vector search for Borealis Wind Systems spend categorization. We create a vector index from labeled invoices and retrieve similar examples to provide context for LLM classification.

**Approach:**
1. Build a vector index from `invoices_enh` (enhanced with LLM descriptions)
2. For each transaction, retrieve 5 similar examples with known categories
3. Use these examples as few-shot context for classification

Tested on Serverless v4.

In [0]:
%pip install uv
%uv pip install .
%restart_python

In [0]:
from src.utils import get_spark
from src.config import load_config

spark = get_spark()
config = load_config()

In [0]:
# Widgets for SQL cells
dbutils.widgets.removeAll()
dbutils.widgets.text("catalog", config.catalog)
dbutils.widgets.text("schema", config.schema_name)
dbutils.widgets.text("invoices", config.invoices)
dbutils.widgets.text("cat_vectorsearch", config.cat_vectorsearch_table)

# Preview the enhanced invoices table
display(spark.table(config.full_invoices_table_path).limit(10))

## Create Combined Text Column

First, we create a combined text column for embedding that captures the key information about each transaction.
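The SQL cell below builds this string with `CONCAT` and `COALESCE`. The plain-Python sketch here mirrors that logic for a single row, which makes it easy to check what one embedded document looks like. `build_combined_text` and the sample row are hypothetical illustrations, not part of the project code.

```python
# Hypothetical helper mirroring the SQL CONCAT/COALESCE logic for one invoice row.
FIELDS = [
    ("Supplier", "supplier"),
    ("Country", "supplier_country"),
    ("Description", "description"),
    ("Plant", "plant"),
    ("Region", "region"),
    ("Cost Centre", "cost_centre"),
]

def build_combined_text(row: dict) -> str:
    """Join 'Label: value' pairs with ' | ', treating missing/None values as ''."""
    parts = []
    for label, col in FIELDS:
        value = row.get(col)
        # COALESCE(col, '') in SQL: NULL becomes an empty string
        parts.append(f"{label}: {'' if value is None else value}")
    return " | ".join(parts)

# Hypothetical sample row; 'plant' is missing to show the COALESCE behaviour
example = {
    "supplier": "ACME Corp LLC",
    "supplier_country": "DE",
    "description": "Spherical roller bearing for main shaft",
    "plant": None,
    "region": "EMEA",
    "cost_centre": "CC-100",
}
print(build_combined_text(example))
# Supplier: ACME Corp LLC | Country: DE | Description: Spherical roller bearing for main shaft | Plant:  | Region: EMEA | Cost Centre: CC-100
```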

In [0]:
# Create a table with combined text for vector search
spark.sql(f"""
CREATE OR REPLACE TABLE {config.catalog}.{config.schema_name}.invoices_vs AS
SELECT 
    *,
    CONCAT(
        'Supplier: ', COALESCE(supplier, ''), ' | ',
        'Country: ', COALESCE(supplier_country, ''), ' | ',
        'Description: ', COALESCE(description, ''), ' | ',
        'Plant: ', COALESCE(plant, ''), ' | ',
        'Region: ', COALESCE(region, ''), ' | ',
        'Cost Centre: ', COALESCE(cost_centre, '')
    ) AS combined_text
FROM {config.full_invoices_table_path}
""")

# Enable Change Data Feed for vector search sync
spark.sql(f"""
ALTER TABLE {config.catalog}.{config.schema_name}.invoices_vs
SET TBLPROPERTIES (delta.enableChangeDataFeed = true)
""")

## Create Vector Search Index

GTE (General Text Embeddings) models map semantically similar text to nearby vectors. This should help match variants such as "ACME_CORP" and "ACME Corp LLC", or descriptions written in different formats. We use the `databricks-gte-large-en` endpoint to embed `combined_text`.
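Dense retrieval ranks indexed rows by how close their embeddings are to the query's embedding. The sketch below shows exhaustive cosine-similarity top-k over a tiny in-memory corpus. Keep these caveats in mind:

- The three-dimensional vectors are hypothetical stand-ins for real GTE embeddings.
- `cosine` and `top_k` are illustrative helpers, not part of the Vector Search API.
- The managed index uses approximate nearest-neighbour search, and its distance metric may differ from the one used here.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def top_k(query_vec, corpus, k=5):
    """Return the k (id, score) pairs most similar to query_vec, highest score first."""
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in corpus.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]

# Hypothetical toy embeddings; real vectors come from the embedding endpoint
corpus = {
    "inv_1": [0.90, 0.10, 0.00],  # e.g. "ACME_CORP bearing"
    "inv_2": [0.85, 0.15, 0.05],  # e.g. "ACME Corp LLC roller bearing"
    "inv_3": [0.00, 0.20, 0.95],  # e.g. "Laptop lease"
}
print([doc_id for doc_id, _ in top_k([0.88, 0.12, 0.02], corpus, k=2)])
# ['inv_1', 'inv_2']
```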

In [0]:
from databricks.vector_search.client import VectorSearchClient

client = VectorSearchClient()

# Create the Delta Sync index (this raises if the index already exists; an existing index is not updated)
try:
    index = client.create_delta_sync_index(
        endpoint_name="one-env-shared-endpoint-3",  # Update with your endpoint
        source_table_name=f"{config.catalog}.{config.schema_name}.invoices_vs",
        index_name=f"{config.catalog}.{config.schema_name}.vs_index",
        pipeline_type="TRIGGERED",
        primary_key="order_id",
        embedding_source_column="combined_text",
        embedding_model_endpoint_name="databricks-gte-large-en"
    )
    print("✅ Vector search index created")
except Exception as e:
    print(f"Index exists or error: {e}")

## Test Vector Search

Query the index to find similar invoices.

In [0]:
# Test vector search: query the index with one randomly sampled transaction
test_query = spark.sql(f"""
SELECT
  -- The sampled row is itself in the index, so the top match is usually the query row
  FIRST(combined_text) AS top_match,
  ARRAY_JOIN(
    COLLECT_LIST(
      CONCAT(
        combined_text, '\n', 
        'Level 1: ', category_level_1, '\n',
        'Level 2: ', category_level_2, '\n',
        'Level 3: ', category_level_3, '\n\n'
      )
    ),
    ''
  ) AS similar_examples
FROM vector_search(
  index => '{config.catalog}.{config.schema_name}.vs_index',
  query_text => (
    SELECT combined_text 
    FROM {config.catalog}.{config.schema_name}.invoices_vs 
    ORDER BY RAND() LIMIT 1
  ),
  num_results => 5,
  query_type => 'hybrid'
)
""")
display(test_query)
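
The query above uses `query_type => 'hybrid'`, which combines keyword matching with embedding similarity. Reciprocal rank fusion (RRF) is a common way to merge two such ranked lists into one. The sketch below illustrates the idea with hypothetical invoice ids. It is not the service's exact scoring, and `k=60` is simply the constant conventionally used with RRF.

```python
from collections import defaultdict

def reciprocal_rank_fusion(rankings, k=60, top_n=5):
    """Fuse ranked lists of ids: score(id) = sum over lists of 1 / (k + rank), rank starting at 1."""
    scores = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by id for determinism
    return sorted(scores, key=lambda d: (-scores[d], d))[:top_n]

keyword_hits = ["inv_7", "inv_2", "inv_9"]  # hypothetical lexical ranking
vector_hits = ["inv_2", "inv_4", "inv_7"]   # hypothetical embedding ranking
print(reciprocal_rank_fusion([keyword_hits, vector_hits], top_n=3))
# ['inv_2', 'inv_7', 'inv_4']
```

Ids that rank well in both lists (here `inv_2` and `inv_7`) rise above ids found by only one retriever.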

In [0]:
prompt = """You are a spend categorization expert for Borealis Wind Systems, a wind turbine manufacturer.

Classify the transaction into the correct 3-level category hierarchy:
- Level 1: Direct, Indirect, or Non-Procureable
- Level 2: The category (e.g., "Bearings & Seals", "MRO", "IT & Software")
- Level 3: The specific item type (e.g., "Spherical roller bearing", "Tool rental")

Use ONLY categories from the hierarchy provided. Match the transaction to the most specific and accurate category.

Use the category hierarchy and the similar labeled examples below as context.
"""

## Load Category Hierarchy

Load the category hierarchy from config for use in the prompt.

In [0]:
# Load Level 2 category descriptions from config.yaml (the hierarchy itself is config.categories)
category_descriptions = config.get_category_descriptions()

# Format hierarchy as text for the prompt
def format_hierarchy(categories):
    lines = []
    for level1, level2_dict in categories.items():
        lines.append(f"\n## {level1}")
        for level2, level3_list in level2_dict.items():
            desc = category_descriptions.get(level2, "")
            lines.append(f"  - {level2}: {desc}")
            for level3 in level3_list:
                lines.append(f"      - {level3}")
    return "\n".join(lines)

hierarchy = format_hierarchy(config.categories)
print(hierarchy[:1000] + "...")

In [0]:
# Create train/test split for evaluation
from sklearn.model_selection import train_test_split

invoices_df = spark.table(f'{config.catalog}.{config.schema_name}.invoices_vs').toPandas()
train_df, test_df = train_test_split(invoices_df, test_size=0.2, random_state=42)

print(f"Train: {len(train_df)}, Test: {len(test_df)}")

# Save test set for evaluation
spark.createDataFrame(test_df).write.mode("overwrite").saveAsTable(f"{config.catalog}.{config.schema_name}.test_vs")
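
The vector index was built from the full `invoices_vs` table, so it also contains the test rows. When a test transaction is used as a query, it can retrieve itself, true labels included, as a "similar example", which would inflate the measured accuracy. The cleanest fix is to build the index from the training rows only. Alternatively, you can over-fetch neighbours and drop the query row and any held-out rows, as in the minimal sketch below. The function name and the shape of `hits` (dicts with an `order_id` key, ordered by relevance) are assumptions for illustration.

```python
def filter_neighbours(query_id, hits, excluded_ids, k=5):
    """Keep the first k hits that are neither the query row nor a held-out test row.

    hits: relevance-ordered list of dicts with at least an 'order_id' key (hypothetical shape).
    """
    kept = []
    for hit in hits:
        oid = hit["order_id"]
        if oid == query_id or oid in excluded_ids:
            continue
        kept.append(hit)
        if len(kept) == k:
            break
    return kept

test_ids = {101, 102}  # in the notebook this could be set(test_df["order_id"])
hits = [
    {"order_id": 101, "category_level_2": "Bearings & Seals"},  # the query itself
    {"order_id": 205, "category_level_2": "Bearings & Seals"},
    {"order_id": 102, "category_level_2": "MRO"},               # another test row
    {"order_id": 330, "category_level_2": "MRO"},
]
print([h["order_id"] for h in filter_neighbours(101, hits, test_ids, k=2)])
# [205, 330]
```

To still end up with five examples after filtering, request more than five results from the search.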

In [0]:
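run_generation = False  # Set to True to regenerate predictions (one AI_QUERY call per test row); otherwise the saved table is reused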

In [0]:
from pyspark.sql import functions as F

if run_generation:
    pred_vs_df = (
        spark.table(f"{config.catalog}.{config.schema_name}.test_vs")
        .filter(F.col("combined_text").isNotNull())
        .filter(F.length(F.trim(F.col("combined_text"))) > 0)
        .withColumn(
            "vs_results",
            F.expr(f"""
                (
                    SELECT
                    ARRAY_JOIN(
                        COLLECT_LIST(
                            CONCAT(
                                combined_text, '\n',
                                'Level 1: ', category_level_1, '\n',
                                'Level 2: ', category_level_2, '\n',
                                'Level 3: ', category_level_3, '\n\n'
                            )
                        ),
                        ''
                    ) AS similar_examples
                    FROM VECTOR_SEARCH(
                    index => '{config.catalog}.{config.schema_name}.vs_index',
                    query_text => combined_text,
                    num_results => 5,
                    query_type => 'hybrid'
                    )
                )
            """)
        )
        .withColumn(
            "full_prompt",
            F.concat(
                F.lit(prompt), 
                F.lit('\n\nCategory Hierarchy:\n'),
                F.lit(hierarchy), 
                F.lit('\n\nSimilar Examples:\n'),
                F.col("vs_results"),
                F.lit("\n\nTransaction to Classify:\n"),
                F.col("combined_text")
            )
        )
        .select(
            F.col("order_id"),
            F.col("category_level_1").alias("actual_level_1"),
            F.col("category_level_2").alias("actual_level_2"),
            F.col("category_level_3").alias("actual_level_3"),
            F.col("total"),
            F.col("full_prompt"),
            F.col("vs_results"),
            F.expr(f"""
                AI_QUERY(
                    '{config.large_llm_endpoint}',
                    full_prompt,
                    responseFormat => '{{
                        "type": "json_schema",
                        "json_schema": {{
                            "name": "categorization",
                            "schema": {{
                                "type": "object",
                                "properties": {{
                                    "level_1": {{"type": "string"}},
                                    "level_2": {{"type": "string"}},
                                    "level_3": {{"type": "string"}}
                                }}
                            }}
                        }}
                    }}'
                )
            """).alias("llm_output")
        )
    )

    pred_vs_df.write.mode("overwrite").saveAsTable(config.full_cat_vectorsearch_table_path)
    print(f"✅ Predictions saved to {config.full_cat_vectorsearch_table_path}")

In [0]:
pred_vs_df = spark.table(config.full_cat_vectorsearch_table_path)
display(pred_vs_df)

## Parse LLM Output and Prepare for Evaluation

In [0]:
# Parse LLM output and create comparison table
spark.sql(f"""
CREATE OR REPLACE TABLE {config.catalog}.{config.schema_name}.cat_vectorsearch_comp AS
SELECT
  p.order_id,
  p.actual_level_1,
  p.actual_level_2,
  p.actual_level_3,
  p.total,
  pred_level_1,
  pred_level_2,
  pred_level_3
FROM 
  {config.full_cat_vectorsearch_table_path} p
LATERAL VIEW 
  JSON_TUPLE(p.llm_output, 'level_1', 'level_2', 'level_3') AS pred_level_1, pred_level_2, pred_level_3
""")
print("✅ Comparison table created")

## Evaluation
The cells below evaluate the predictions. The per-category metrics are weighted by absolute spend.

In [0]:
import pandas as pd

pred_vs_comp = spark.table(f'{config.catalog}.{config.schema_name}.cat_vectorsearch_comp').dropna(
    subset=['actual_level_1', 'pred_level_1', 'actual_level_2', 'pred_level_2']
).toPandas()

print(f"Evaluation dataset: {len(pred_vs_comp)} rows")

## Accuracy Results

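Compare vector search RAG predictions against actuals. The Level 1 and Level 2 accuracies below count each transaction equally (unweighted). The classification report that follows weights each transaction by its absolute spend.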
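
`accuracy_score` also accepts `sample_weight`. Passing absolute spend gives a spend-weighted accuracy consistent with the weighting in the classification report. The small example below uses hypothetical labels and totals to show how the two measures can diverge when a large invoice is misclassified.

```python
from sklearn.metrics import accuracy_score

# Hypothetical labels and invoice totals; the negative total is a credit note counted by magnitude
y_true = ["MRO", "Bearings & Seals", "MRO"]
y_pred = ["MRO", "MRO", "MRO"]
spend = [100.0, 300.0, -100.0]

unweighted = accuracy_score(y_true, y_pred)
weighted = accuracy_score(y_true, y_pred, sample_weight=[abs(s) for s in spend])
print(f"unweighted={unweighted:.3f} spend-weighted={weighted:.3f}")
# unweighted=0.667 spend-weighted=0.400
```

In the notebook, the equivalent call would pass `sample_weight=pred_vs_comp['total'].abs()` alongside the `actual_level_2` and `pred_level_2` columns.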

In [0]:
from sklearn.metrics import accuracy_score, classification_report

print(f"Level 1 Accuracy (Direct/Indirect/Non-Procureable): {accuracy_score(pred_vs_comp['actual_level_1'], pred_vs_comp['pred_level_1']):0.3f}")
print(f"Level 2 Accuracy (Category): {accuracy_score(pred_vs_comp['actual_level_2'], pred_vs_comp['pred_level_2']):0.3f}")

In [0]:
import pandas as pd

# Classification report for Level 2 categories, weighted by spend
class_dict = classification_report(
    y_true=pred_vs_comp['actual_level_2'], 
    y_pred=pred_vs_comp['pred_level_2'],
    sample_weight=pred_vs_comp['total'].abs(),
    output_dict=True,
    zero_division=0
)

In [0]:
class_dict

In [0]:
# Per-category metrics for the 15 categories with the highest spend ('spend' = absolute spend in thousands)
metrics_df = (
    pd.DataFrame(class_dict)
    .transpose()
    .reset_index()
    .rename(columns={'index': 'category', 'support': 'spend'})
    .query('category not in ["accuracy", "weighted avg", "macro avg"]')
    .sort_values('spend', ascending=False)
    .assign(spend=lambda df: (df['spend'] / 1000).round().astype(int))
    .round(2)
    .head(15)
)
metrics_df

In [0]:
import seaborn as sns
import matplotlib.pyplot as plt

# Prepare visualization data
df = pd.DataFrame(class_dict).transpose().reset_index()
df = df.query('index not in ["accuracy", "weighted avg", "macro avg"]')
df = df.sort_values('support', ascending=False)

# Add total spend by category
spend_by_cat = pred_vs_comp.groupby('actual_level_2')['total'].sum()
df['total_spend'] = df['index'].map(spend_by_cat).fillna(0)

# Truncate labels to max 20 characters
df['short_label'] = df['index'].str.slice(0, 20)

# Sort by total spend descending
df = df.sort_values('total_spend', ascending=False).reset_index(drop=True).head(10)

# Create visualization
sns.set_theme(style="whitegrid")
plt.figure(figsize=(12, 8))
barplot = sns.barplot(x='short_label', y='precision', data=df, color='#2E86AB')
plt.xlabel('', fontsize=14)
plt.ylabel('Precision', fontsize=14)
plt.title('Vector Search RAG Precision by Category (Borealis Wind)', fontsize=16)
plt.ylim(0.5, 1.0)

for spine in barplot.spines.values():
    spine.set_visible(False)

plt.xticks(rotation=-45, ha='left')
barplot.grid(False)

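# Add spend labels inside bars (df already has a 0..n-1 index matching the bar positions;
# calling df.reset_index() here would fail because df already has an 'index' column)
for i, row in df.iterrows():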
    if row.total_spend > 0:
        barplot.text(i, row.precision * 0.85, f"${row.total_spend/1e6:.1f}M", 
                     color='white', ha="center", va='baseline', rotation=-90, 
                     fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('vector_search.png', dpi=300)
plt.show()