# Spend Categorization

## Naive LLM Inference
This notebook runs a naive version of our spend categorization using only prompt engineering. It provides a baseline that solves the problem with naive batch inference. It also shows how you'd bootstrap a categorization problem if you didn't have labels.

In [0]:
df = spark.sql("SELECT * FROM shm.spend.test").toPandas()

In [0]:
df.iloc[0].combined

We are going to load the hierarchy from the config, so that it is relatively straightforward to a) version it or b) change the categorization. A Delta table is also a great idea here.

We are going to focus on level 1 and level 2 categorization first.

In [0]:
import yaml
CONFIG_PATH = "config.yaml" 
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
   
categories = config['company']['categories']
catalog = config['data_generation']["catalog"]
schema = config['data_generation']["schema"]
llm_endpoint = config['data_generation']["llm_endpoint"]
table = config['data_generation']["table"]

In [0]:
categories_str = "# Categories\n\n"
for agency, subagencies in categories.items():
    categories_str += f"## {agency}\n"
    for subagency in sorted(subagencies):
        categories_str += f"- {subagency}\n"
    categories_str += "\n"

categories_str

In [0]:
from pyspark.sql.functions import current_date

(
  spark.createDataFrame([(categories_str,)], ["categories_str"])
  .withColumn("date", current_date())
  .select('date','categories_str')
  .write.format('delta')
  .mode('append')
  .saveAsTable('shm.spend.categories')
)

In [0]:
%sql
SELECT * FROM shm.spend.categories

The cell that builds `categories_str` creates a markdown string with all the values from the agency and its subagencies. This is convenient because it always uses the awards table, so it is always up to date. It could also be pointed at a category tree for hierarchical spend classification.

This could also be a function used as an agent tool.
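
As a minimal sketch of that idea, the string-building loop can be wrapped in a plain function that an agent framework could register as a tool. The function name `render_categories` and the toy hierarchy are assumptions for illustration; the real hierarchy comes from the config.

```python
from typing import Dict, List


def render_categories(categories: Dict[str, List[str]]) -> str:
    """Render a two-level category hierarchy as a markdown string."""
    lines = ["# Categories", ""]
    for level_1, level_2_values in categories.items():
        lines.append(f"## {level_1}")
        # Sort the second level so the rendered prompt is stable across runs.
        lines.extend(f"- {value}" for value in sorted(level_2_values))
        lines.append("")
    return "\n".join(lines) + "\n"


# Toy hierarchy for illustration only; not the real category tree.
example = {
    "Direct": ["Raw Materials", "Electrical Assemblies"],
    "Indirect": ["IT Services"],
}
rendered = render_categories(example)
print(rendered)
```

The output has the same layout as `categories_str` above, so the function could replace the loop without changing the prompt.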

Next, we write a short prompt for our model. It could definitely be improved, but prompt changes alone would be nowhere near enough to reach acceptable accuracy!

In [0]:
prompt = """Use the following hierarchy and return the level 1 and level 2 categories. Return a json output with only the agencies. You must use the categories from the hierarchy, pick the best ones.

For example:
'date: 2025-10-07\norder_id: ORD-2025-01340-QC\ncategory_level_1: Direct\ncategory_level_2: Electrical Assemblies\ncategory_level_3: Control cabinet\ncost_centre: CC-100-Production\nplant: US-East Plant\nplant_id: PLANT-US-E\nregion: North America\nsupplier: Eaton Corporation\nsupplier_country: US\ndescription: Industrial control cabinet with electrical assemblies for production line\n'

{"lvl_1_cat": "Direct", "lvl_2_cat": "Electrical Assemblies"}
"""

In [0]:
from pyspark.sql.functions import current_date

(
  spark.createDataFrame([(prompt,)], ["prompt"])
  .withColumn("date", current_date())
  .select('date','prompt')
  .write.format('delta')
  .mode('append')
  .saveAsTable('shm.spend.prompts')
)

We set up widgets so that we can reference the category tree and prompt in our batch inference as parameters.

We use `AI_QUERY` to run batch inference, all in SQL. I repeat the `CONCAT` call twice here just so I can inspect the combined prompt that went into the model. We also use `responseFormat` to enforce structured outputs. This is critical for the consistency and maintainability of Generative AI solutions; I wouldn't leave POC without it. It's worth pointing out that, because of the optimizations done in `AI_QUERY`, 500 calls to the LLM only take 20 seconds.
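
The structured-output call might look roughly like the sketch below, which builds a JSON schema and embeds it in an `ai_query` statement through the `responseFormat` argument. This is not the notebook's actual query. The schema name, the helper `build_inference_sql`, the endpoint and table names, and the output keys `agency` and `subagency` are all assumptions; the keys were chosen to match the fields extracted later with `JSON_TUPLE`.

```python
import json

# Assumed output keys, chosen to match the later JSON_TUPLE extraction.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "spend_category",  # hypothetical schema name
        "schema": {
            "type": "object",
            "properties": {
                "agency": {"type": "string"},
                "subagency": {"type": "string"},
            },
            "required": ["agency", "subagency"],
        },
        "strict": True,
    },
}


def build_inference_sql(endpoint: str, prompt_table: str, fmt: dict) -> str:
    """Build a SQL statement that calls ai_query with a structured response format.

    Assumes `prompt_table` has columns `id` and `prompt`.
    """
    schema_json = json.dumps(fmt)
    # The schema is embedded in a single-quoted SQL literal, so reject
    # characters that would need escaping there.
    if "'" in schema_json or "\\" in schema_json:
        raise ValueError("response format needs escaping before embedding in SQL")
    return (
        "SELECT\n"
        "  id,\n"
        f"  ai_query('{endpoint}', prompt, responseFormat => '{schema_json}') AS llm_output\n"
        f"FROM {prompt_table}"
    )


# Hypothetical endpoint and table names, for illustration only.
sql = build_inference_sql("my-llm-endpoint", "my_catalog.my_schema.prompts_view", response_format)
print(sql)
```

In a notebook, the resulting string could be passed to `spark.sql(sql)`. Keeping the schema in one Python dict makes it easy to version alongside the prompt.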

In [0]:
dbutils.widgets.text("llm_endpoint", llm_endpoint)
dbutils.widgets.text("catalog", catalog)
dbutils.widgets.text("schema", schema)
dbutils.widgets.text("table", table+"_naive")

In [0]:
%sql
SELECT * FROM shm.spend.transactions_comb LIMIT 5

In [0]:
%sql
WITH latest_prompt AS (
  SELECT
    prompt,
    date
  FROM shm.spend.prompts
  ORDER BY date DESC
  LIMIT 1
),
latest_category AS (
  SELECT *
  FROM shm.spend.categories
  ORDER BY date DESC
  LIMIT 1
)
SELECT
  t.id,
  CONCAT(
    lp.prompt, '\n',
    lc.categories_str, '\n',
    t.combined
  ) AS prompt
FROM $catalog.$schema.test t
LEFT JOIN latest_prompt lp ON 1=1
LEFT JOIN latest_category lc ON 1=1

In [0]:
%sql
SELECT * FROM $catalog.$schema.$table

The second step overwrites the comparison table: it deconstructs the JSON output and pulls in the actual labels for evaluation. This could be done in the first SQL call, but it was getting long.

In [0]:
%sql
SELECT * FROM shm.spend.pred_naive LIMIT 5

In [0]:
%sql
CREATE OR REPLACE TABLE shm.spend.pred_naive_comp AS
SELECT
  p.*,
  agency,
  subagency,
  t.funding_agency_name,
  t.funding_sub_agency_name
FROM 
  shm.spend.pred_naive p
JOIN
  shm.spend.test t
ON 
  t.id = p.id
LATERAL VIEW 
  JSON_TUPLE(p.llm_output, 'agency', 'subagency') AS agency, subagency
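
For readers who prefer to do this step in Python, the sketch below roughly mirrors what `JSON_TUPLE` does for one row: it pulls the `agency` and `subagency` keys out of the model output and returns `None` for anything missing or malformed. The helper name `extract_labels` is an assumption, not part of the notebook.

```python
import json
from typing import Optional, Tuple


def extract_labels(llm_output: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Return (agency, subagency), or (None, None) if the output is not a JSON object."""
    try:
        parsed = json.loads(llm_output)
    except (TypeError, ValueError):
        # TypeError covers a null output; ValueError covers malformed JSON.
        return None, None
    if not isinstance(parsed, dict):
        return None, None
    # Missing keys come back as None, as they would be NULL in SQL.
    return parsed.get("agency"), parsed.get("subagency")


print(extract_labels('{"agency": "A", "subagency": "B"}'))
print(extract_labels("not json"))
```

Rows that come back as `None` are the ones the later `dropna` call would remove before scoring.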

Now let's move into sklearn to get a classification report from our LLM-based analysis for comparison.

In [0]:
pred_naive = spark.table('shm.spend.pred_naive_comp').dropna(
    subset=['funding_agency_name', 'agency', 'funding_sub_agency_name', 'subagency']
).toPandas()

In [0]:
from sklearn.metrics import accuracy_score, classification_report

print(f"""Agency Accuracy: {accuracy_score(
  pred_naive['funding_agency_name'], 
  pred_naive['agency']
  ):0.3f}""")

print(f"""Subagency Accuracy: {accuracy_score(
  pred_naive['funding_sub_agency_name'], 
  pred_naive['subagency']
  ):0.3f}""")

With naive inference we have relatively poor accuracy, but more than zero.
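
To make the metric concrete: accuracy here is the share of rows where the predicted label exactly matches the true label. A per-label breakdown shows which categories pull the score down. The sketch below uses only the standard library and made-up toy labels; the helper names are assumptions, and the numbers are not results from this notebook.

```python
from collections import Counter
from typing import Dict, Sequence


def accuracy(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Fraction of positions where the prediction equals the true label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def recall_by_label(y_true: Sequence[str], y_pred: Sequence[str]) -> Dict[str, float]:
    """For each true label, the fraction of its rows that were predicted correctly."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: hits[label] / totals[label] for label in totals}


# Toy labels for illustration only.
y_true = ["Direct", "Direct", "Indirect", "Indirect"]
y_pred = ["Direct", "Indirect", "Indirect", "Indirect"]

print(f"Accuracy: {accuracy(y_true, y_pred):0.3f}")
print(recall_by_label(y_true, y_pred))
```

Exact string matching is strict, so differences in casing or whitespace between the model output and the label count as errors.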

In [0]:
%sql
SELECT * FROM shm.spend.test