## Spend Categorization - RAG
This notebook uses vector search to retrieve similar, already-categorized spend examples. This is straightforward retrieval-augmented generation (RAG) by context stuffing: we create a vector index with 7,000 entries, search it for the 5 nearest neighbors of each record, and pass those examples to the model so it can make better decisions.

In [0]:
%pip install databricks-vectorsearch
%restart_python

In [0]:
%sql
SELECT * FROM shm.spend.train
LIMIT 10

Our first 'algorithm' is a simple vector search that retrieves similar examples from the categorized spend. This can be thought of as a semantic k-nearest-neighbour lookup. Because we always look up our live, approved spend table, the model is steered toward existing categories rather than mapping to new ones.
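
To make the k-nearest-neighbour idea concrete, here is a minimal pure-Python sketch. The three-dimensional vectors, example texts, and agency names are made up for illustration. In the notebook, the embeddings come from the model endpoint and the search runs inside the vector index, not in Python.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def nearest_neighbours(query_vec, labelled_rows, k=5):
    """Return the k labelled rows most similar to query_vec, best first."""
    scored = [
        (cosine_similarity(query_vec, vec), text, agency)
        for vec, text, agency in labelled_rows
    ]
    scored.sort(key=lambda item: item[0], reverse=True)
    return scored[:k]

# Toy "index": (embedding, text, label). All values are hypothetical.
labelled_rows = [
    ([1.0, 0.0, 0.0], "example award A", "Agency X"),
    ([0.9, 0.1, 0.0], "example award B", "Agency X"),
    ([0.0, 1.0, 0.0], "example award C", "Agency Y"),
]

top = nearest_neighbours([1.0, 0.05, 0.0], labelled_rows, k=2)
for score, text, agency in top:
    print(f"{score:.3f}  {text}  ->  {agency}")
```

The retrieved labels are not used as the prediction directly; they are passed to the model as context.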

In [0]:
%sql
ALTER TABLE shm.spend.train
SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

GTE (General Text Embeddings) is well suited to invoice text matching because it handles semantic variation and unstructured text. Its multi-stage contrastive training helps it match abbreviated vendor codes (e.g., "ACME_CORP" vs "ACME Corp LLC"), reordered product descriptions (e.g., "WIDGET-A2" vs "A2 Widget"), and localized number formats (1,234.56 vs 1.234,56). It takes free-text sequences directly, so no structured input is required. For exact matching of critical identifiers, the queries below use `query_type => 'hybrid'`, which combines keyword matching with embedding similarity.

Because of the size of the table, the sync will take around 4 hours

In [0]:
from databricks.vector_search.client import VectorSearchClient

client = VectorSearchClient()

try:
  index = client.create_delta_sync_index(
    endpoint_name="one-env-shared-endpoint-3",
    source_table_name="shm.spend.train",
    index_name="shm.spend.vs_index",
    pipeline_type="TRIGGERED",
    primary_key="id",
    embedding_source_column="combined_award_text",
    embedding_model_endpoint_name="databricks-gte-large-en"
  )
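except Exception as e:
  # Most commonly the index already exists; print the error so other failures are not hidden
  print(f'index not created (it may already exist): {e}')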

Let's test out Vector Search

In [0]:
%sql
-- Vector search
SELECT
  FIRST(combined_award_text) as query_row,
  STRING(
    COLLECT_LIST(
      CONCAT(
        combined_award_text, '\n', 
        'agency: ', funding_agency_name, '\n',
        'subagency: ', funding_sub_agency_name, '\n\n'
      )
    )
  ) AS similarity_search
FROM vector_search(
  index => 'shm.spend.vs_index',
  query_text => (
    SELECT combined_award_text 
    FROM shm.spend.test 
    ORDER BY RAND() LIMIT 1
  ),
  num_results => 5,
  query_type => 'hybrid'
)

In [0]:
prompt = """Use the following agency hierarchy and return the agency and subagency for the award information below. Return a json output with only the agencies. You must use the agencies and subagencies from the hierarchy, and pick the best ones. Never predict a subagency without the correct parent agency.

Output format: 
{'agency': AGENCY_NAME, 'subagency': SUBAGENCY_NAME}

Use the following examples for context:
"""

This is our main query. We wrap it in SQL to leverage our batch inference optimizations. We also constrain the output with a JSON schema passed as `responseFormat`; adding an `enum` for each level would enforce the ontology more strictly. This is a somewhat naive way to do it, but it works well enough. I would quickly transition this to a tool call + agent framework to add some reflection and hierarchy testing.
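
One way to enforce the ontology with an `enum` is to generate the `responseFormat` string from the lists of allowed values. The sketch below is an assumption and not part of the original notebook. The helper name `build_response_format` and the example agency names are hypothetical. Check whether the serving endpoint honours `enum` constraints before relying on it.

```python
import json

def build_response_format(agencies, subagencies):
    """Build a JSON-schema response format that restricts both levels to known values."""
    schema = {
        "type": "json_schema",
        "json_schema": {
            "name": "categorization",
            "schema": {
                "type": "object",
                "properties": {
                    # enum limits each field to the allowed ontology values
                    "agency": {"type": "string", "enum": sorted(agencies)},
                    "subagency": {"type": "string", "enum": sorted(subagencies)},
                },
                "required": ["agency", "subagency"],
            },
        },
    }
    return json.dumps(schema)

# Hypothetical values; in practice these would come from the hierarchy file.
response_format = build_response_format(
    agencies={"Department of Example", "Agency for Illustration"},
    subagencies={"Office of Samples", "Division of Sketches"},
)
print(response_format)
```

The resulting string could then be interpolated into the `AI_QUERY` expression in place of the hand-written schema. Note that a flat `enum` per field does not tie a subagency to its parent agency.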

In [0]:
with open("agency_hierarchy.md", "r") as f:
    hierarchy = f.read()
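
The prompt asks the model never to predict a subagency without the correct parent agency. A simple post-hoc check for that rule could look like the sketch below. The format of `agency_hierarchy.md` is not shown here, so the sketch assumes the hierarchy has already been parsed into a dictionary. The dictionary contents and the name `is_valid_pair` are hypothetical.

```python
# Hypothetical parsed hierarchy: parent agency -> set of valid subagencies.
HIERARCHY = {
    "Department of Example": {"Office of Samples", "Bureau of Demos"},
    "Agency for Illustration": {"Division of Sketches"},
}

def is_valid_pair(agency, subagency, hierarchy=HIERARCHY):
    """True only if the subagency is listed under the given parent agency."""
    return subagency in hierarchy.get(agency, set())

predictions = [
    ("Department of Example", "Office of Samples"),       # consistent
    ("Agency for Illustration", "Office of Samples"),     # wrong parent
    ("Unknown Agency", "Division of Sketches"),           # agency not in hierarchy
]

for agency, subagency in predictions:
    status = "ok" if is_valid_pair(agency, subagency) else "inconsistent"
    print(f"{agency} / {subagency}: {status}")
```

Rows flagged as inconsistent are natural candidates for a retry or a reflection step.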

In [0]:
%sql
-- Optimized vector search with pre-retrieval
SELECT
  FIRST(combined_award_text) as query_row,
  STRING(
    COLLECT_LIST(
      CONCAT(
        combined_award_text, '\n', 
        'agency: ', funding_agency_name, '\n',
        'subagency: ', funding_sub_agency_name, '\n\n'
      )
    )
  ) AS similarity_search
FROM vector_search(
  index => 'shm.spend.vs_index',
  query_text => (
    SELECT combined_award_text 
    FROM shm.spend.test 
    ORDER BY RAND() LIMIT 1
  ),
  num_results => 5,
  query_type => 'hybrid'
)

In [0]:
run_generation = False

In [0]:
from pyspark.sql import functions as F

if run_generation:
    pred_vs_df = (
        spark.table("shm.spend.test")
        .filter(F.col("combined_award_text").isNotNull())  # Filter out nulls upfront
        .filter(F.length(F.trim(F.col("combined_award_text"))) > 0)  # Filter out empty strings
        .withColumn(
            "vs_results",
            F.expr("""
                (
                    SELECT
                    STRING(
                        COLLECT_LIST(
                        CONCAT(
                            combined_award_text, '\n',
                            'agency: ', funding_agency_name, '\n',
                            'subagency: ', funding_sub_agency_name, '\n\n'
                        )
                        )
                    ) AS similarity_search
                    FROM VECTOR_SEARCH(
                    index => 'shm.spend.vs_index',
                    query_text => combined_award_text,
                    num_results => 5,
                    query_type => 'hybrid'
                    )
                )
            """)
        )
        .withColumn(
            "full_prompt",
            F.concat(
                F.lit(prompt), 
                F.lit('\n'),
                F.lit(hierarchy), 
                F.lit('\n'),
                F.col("vs_results"),
                F.lit("\n Award to Classify: \n"),
                F.col("combined_award_text")
            )
        )
        .select(
            F.col("id"),
            F.col("full_prompt"),
            F.col("vs_results"),
            F.expr("""
                AI_QUERY(
                    'databricks-meta-llama-3-3-70b-instruct',
                    full_prompt,
                    responseFormat => '{
                        "type": "json_schema",
                        "json_schema": {
                            "name": "categorization",
                            "schema": {
                                "type": "object",
                                "properties": {
                                    "agency": {"type": "string"},
                                    "subagency": {"type": "string"}
                                }
                            }
                        }
                    }'
                )
            """).alias("llm_output")
        )
    )

    pred_vs_df.write.mode("overwrite").saveAsTable("shm.spend.pred_vectorsearch")
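
To preview what the model receives for a single record, the `F.concat` step above can be mirrored in plain Python. This is only an illustrative sketch. The function name `build_full_prompt` and all sample strings are hypothetical placeholders, not real data.

```python
def build_full_prompt(prompt, hierarchy, vs_results, award_text):
    """Concatenate the prompt parts in the same order as the Spark expression."""
    return "".join([
        prompt,
        "\n",
        hierarchy,
        "\n",
        vs_results,
        "\n Award to Classify: \n",
        award_text,
    ])

# Placeholder inputs for illustration only.
example_prompt = build_full_prompt(
    prompt="Use the following agency hierarchy ...",
    hierarchy="# Department of Example\n- Office of Samples",
    vs_results="example award A\nagency: Department of Example\nsubagency: Office of Samples\n\n",
    award_text="example award to classify",
)
print(example_prompt)
```

Printing one assembled prompt is a quick way to confirm that the instructions, hierarchy, retrieved examples, and target award appear in the intended order.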

In [0]:
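# display() returns None, so keep the DataFrame reference separate from the display call
pred_vs_df = spark.sql("SELECT * FROM shm.spend.pred_vectorsearch")
pred_vs_df.display()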

Add our spend and actuals back in

In [0]:
%sql
CREATE OR REPLACE TABLE shm.spend.pred_vectorsearch_comp AS
SELECT
  p.*,
  agency,
  subagency,
  t.funding_agency_name,
  t.funding_sub_agency_name,
  t.federal_action_obligation as spend
FROM 
  shm.spend.pred_vectorsearch p
JOIN
  shm.spend.test t
ON 
  t.id = p.id
LATERAL VIEW 
  JSON_TUPLE(p.llm_output, 'agency', 'subagency') AS agency, subagency
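
The extraction that `JSON_TUPLE` performs can be sketched in Python as well, which is handy for spot-checking individual model outputs. This is a hypothetical helper, not part of the original pipeline. It returns `None` for both fields when the output is missing or not valid JSON, which matches the rows that the later `dropna` would discard.

```python
import json

def parse_llm_output(raw):
    """Extract (agency, subagency) from a model response string, or (None, None)."""
    try:
        parsed = json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        # raw was None or not valid JSON
        return None, None
    if not isinstance(parsed, dict):
        return None, None
    return parsed.get("agency"), parsed.get("subagency")

# Hypothetical outputs for illustration.
good = parse_llm_output('{"agency": "Department of Example", "subagency": "Office of Samples"}')
bad = parse_llm_output("not json")
missing = parse_llm_output(None)
print(good, bad, missing)
```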

## Evaluation
The cells below evaluate the results, weighted by spend
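
As a reference for what "weighted by spend" means, here is a small pure-Python sketch of spend-weighted accuracy on made-up values. Each row counts in proportion to its absolute spend, so one large award that is classified correctly outweighs several small misses. scikit-learn's `accuracy_score` accepts a `sample_weight` argument that computes the same quantity.

```python
def weighted_accuracy(y_true, y_pred, weights):
    """Share of total weight that falls on correctly predicted rows."""
    total = sum(weights)
    if total == 0:
        return 0.0
    correct = sum(w for t, p, w in zip(y_true, y_pred, weights) if t == p)
    return correct / total

# Toy data: labels and spend values are hypothetical.
y_true = ["A", "A", "B"]
y_pred = ["A", "B", "B"]
spend = [100.0, -50.0, 850.0]            # negative values represent de-obligations
weights = [abs(s) for s in spend]        # weight by absolute spend

unweighted = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
weighted = weighted_accuracy(y_true, y_pred, weights)
print(f"unweighted: {unweighted:.3f}, spend-weighted: {weighted:.3f}")
```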

In [0]:
import pandas as pd
pred_vs_comp = spark.table('shm.spend.pred_vectorsearch_comp').dropna(
    subset=['funding_agency_name', 'agency', 'funding_sub_agency_name', 'subagency']
).toPandas()

Even with only 7,000 rows (negligible in most spend categorization cases), we've improved level 1 accuracy from 53% to 96% and level 2 accuracy from 26% to 92%.

In [0]:
from sklearn.metrics import accuracy_score, classification_report

print(f"""Agency Accuracy: {accuracy_score(
  pred_vs_comp['funding_agency_name'], 
  pred_vs_comp['agency']
  ):0.3f}""")

print(f"""Subagency Accuracy: {accuracy_score(
  pred_vs_comp['funding_sub_agency_name'], 
  pred_vs_comp['subagency']
  ):0.3f}""")

In [0]:
import pandas as pd

class_dict = classification_report(
  y_true = pred_vs_comp['funding_agency_name'], 
  y_pred = pred_vs_comp['agency'],
  sample_weight=pred_vs_comp['spend'].abs(),
  output_dict=True
  )

In [0]:
class_dict

In [0]:
display(
    pd.DataFrame(class_dict)
      .transpose()
      .reset_index()
      .drop(columns=['f1-score'], errors='ignore')
      .rename(columns={'support': 'spend'})
      .sort_values('spend', ascending=False)
      .query('index != ["accuracy", "weighted avg", "macro avg"]')
      .assign(spend=lambda df: (df['spend'] / 1000).round().astype(int))
      .round(2)
      .head(12)
)

In [0]:
import seaborn as sns
import matplotlib.pyplot as plt

# Prepare the data
df = pd.DataFrame(class_dict).transpose().reset_index()
df = df.query('index != ["accuracy", "weighted avg", "macro avg"]')
df = df.sort_values('support', ascending=False)

# Add total spend as a label on each bar
df['total_spend'] = df['index'].map(pred_vs_comp.groupby('funding_agency_name')['spend'].sum().round(0))

# Truncate labels to max 20 characters
df['short_label'] = df['index'].str.slice(0, 20)

# Sort by total spend descending
df = df.sort_values('total_spend', ascending=False).reset_index(drop=True).head(9)

# Set the theme to minimal
sns.set_theme(style="whitegrid")

# Create the bar plot for precision with total spend labels inside bars
plt.figure(figsize=(12, 8))  # Increased height for taller bars
barplot = sns.barplot(x='short_label', y='precision', data=df, color='#1B3139')
plt.xlabel('', fontsize=14)
plt.ylabel('Precision (spend-weighted)', fontsize=14)  # the bars plot the 'precision' column
plt.ylim(0.5,1.0)

for spine in barplot.spines.values():
    spine.set_visible(False)

# Rotate x labels for better readability
plt.xticks(rotation=-60, ha='left')

barplot.grid(False)

# Add rotated spend labels inside each bar
for index, row in df.iterrows():
    barplot.text(index, row.precision*0.8, f"${row.total_spend:,.0f}", color='white', ha="center", va='baseline', rotation=-90, fontsize=12, fontweight='bold')

plt.tight_layout()

plt.savefig('vector_search.png', dpi=600)