# Spend Categorization

## Data Generation
This notebook generates the data included in the `assets` folder.

Tested on Databricks serverless v4.

In [2]:
# Install uv, then use it to install this project (presumably the package defined
# in the current working directory — verify the notebook runs from the repo root).
# NOTE(review): `%uv` requires an IPython version that ships the uv magic, and
# `%restart_python` is Databricks-specific — both fail in a plain local Jupyter
# kernel, as the captured output below shows.
%pip install uv
%uv pip install .
# Restart Python so the freshly installed package is importable in later cells.
%restart_python

/Users/scott.mckean/Repos/spend_categorization/.venv/bin/python: No module named pip
Note: you may need to restart the kernel to use updated packages.
/Users/scott.mckean/Repos/spend_categorization/.venv/bin/python: No module named uv
Note: you may need to restart the kernel to use updated packages.


UsageError: Line magic function `%restart_python` not found.


In [None]:
# Project helpers: Spark session factory and the data-generation config loader.
from src.utils import get_spark
from src.generate import load_generate_config

# Acquire the Spark session first, then read the generation settings.
spark = get_spark()
config = load_generate_config()

In [None]:
import random
from datetime import datetime, timedelta

# Seed the global RNG so every Restart & Run All reproduces the same dataset.
SEED = 42
random.seed(SEED)

In [None]:
# Unpack the generation config (all settings are flattened onto one object)
# into the module-level names used by the cells below.

# Company profile and organisational structure
name = config.company_name
industry = config.industry
plants = config.plants
cost_centres = config.cost_centres
cat_cc_map = config.category_cost_centre_mapping

# Category taxonomy and level-1 sampling weights
categories = config.categories
level_1_dist = config.distribution

# Volume and date window of the synthetic data
n_rows = config.rows
DATE_FORMAT = "%Y-%m-%d"
start_date = datetime.strptime(config.start_date, DATE_FORMAT)
end_date = datetime.strptime(config.end_date, DATE_FORMAT)

# Output layout, destination table and LLM endpoint
fieldnames = config.python_columns
catalog = config.catalog
schema = config.schema_name
table = config.table
llm_endpoint = config.llm_endpoint

In [7]:
def random_date(start_date, end_date):
    """Return a random datetime in [start_date, end_date], stepping in whole days.

    The number of days to add is drawn uniformly from 0..(end_date - start_date).days,
    so the time-of-day of ``start_date`` is carried through unchanged.
    """
    span_days = (end_date - start_date).days
    offset = timedelta(days=random.randint(0, span_days))
    return start_date + offset
  
def sample_level1(level_1_dist):
    """Draw a level-1 category using cumulative thresholds from ``level_1_dist``.

    ``level_1_dist`` must contain "Direct" and "Indirect" probabilities; any
    probability mass left over falls through to "Non-Procureable".
    """
    draw = random.random()
    direct_cutoff = level_1_dist["Direct"]
    indirect_cutoff = direct_cutoff + level_1_dist["Indirect"]
    if draw < direct_cutoff:
        return "Direct"
    if draw < indirect_cutoff:
        return "Indirect"
    return "Non-Procureable"

def sample_category_triplet(categories, level_1_dist):
    """Sample a (level 1, level 2, level 3) category path.

    Level 1 comes from ``sample_level1``; level 2 is a uniformly chosen key of
    ``categories[level_1]``; level 3 is a uniformly chosen entry of the list
    stored under that key.
    """
    level_1 = sample_level1(level_1_dist)
    subtree = categories[level_1]
    level_2 = random.choice(list(subtree.keys()))
    level_3 = random.choice(subtree[level_2])
    return level_1, level_2, level_3

def sample_plant(plants):
    """Pick one plant record uniformly at random from the ``plants`` sequence."""
    chosen = random.choice(plants)
    return chosen

def sample_cost_centre(l2, cat_cc_map):
    """Return the cost centre mapped to level-2 category ``l2``.

    Despite the name this is a deterministic lookup, not a random draw;
    an unmapped ``l2`` raises KeyError.
    """
    cost_centre = cat_cc_map[l2]
    return cost_centre

def generate_order_id(i, n_rows, date):
    """Build a pseudo-random order ID of the form ``ORD-<year>-<seq>-<XX>``.

    ``seq`` is drawn uniformly from 1..n_rows and zero-padded to at least five
    digits; ``XX`` is two random uppercase letters. ``i`` is accepted but not
    used, so the returned IDs are not guaranteed to be unique.
    """
    letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    sequence = random.randint(1, n_rows)
    tag = "".join(random.choices(letters, k=2))
    return "ORD-{}-{:05d}-{}".format(date.year, sequence, tag)

def generate_amount_and_price(l1):
    """Draw a synthetic quantity and unit price for a line of category ``l1``.

    "Non-Procureable" spend is modelled as a single lump sum: quantity 1 and a
    unit price drawn from 1,000-500,000. Every other category gets 1-100 units
    at a unit price drawn from 5-5,000.

    Returns:
        tuple: ``(amount, unit_price, total)``. ``unit_price`` and ``total`` are
        rounded to 2 decimal places.
    """
    if l1 == "Non-Procureable":
        # Lump-sum line: the whole value sits in the unit price.
        amount = 1
        unit_price = round(random.uniform(1000, 500000), 2)
    else:
        amount = random.randint(1, 100)
        unit_price = round(random.uniform(5, 5000), 2)
    total = round(amount * unit_price, 2)
    return amount, unit_price, total

In [8]:
import pandas as pd


def make_invoice_row(row_number):
    """Draw one synthetic invoice line as a dict keyed by output column name.

    Random draws happen in a fixed order (date, category path, plant, order id,
    quantity/price) so the seeded global RNG reproduces the same table.
    ``row_number`` is forwarded to ``generate_order_id``.
    """
    order_date = random_date(start_date, end_date)
    level_1, level_2, level_3 = sample_category_triplet(categories, level_1_dist)
    plant = sample_plant(plants)
    cost_centre = sample_cost_centre(level_2, cat_cc_map)
    order_id = generate_order_id(row_number, n_rows, order_date)
    amount, unit_price, total = generate_amount_and_price(level_1)
    return {
        "date": order_date.strftime("%Y-%m-%d"),
        "order_id": order_id,
        "category_level_1": level_1,
        "category_level_2": level_2,
        "category_level_3": level_3,
        "cost_centre": cost_centre,
        "plant": plant["name"],
        "plant_id": plant["id"],
        "region": plant["region"],
        "amount": amount,
        "unit_price": unit_price,
        "total": total,
    }


# Build every row up front, then construct the frame once with the configured column order.
invoice_rows = [make_invoice_row(row_number) for row_number in range(1, n_rows + 1)]
invoice_data = pd.DataFrame(invoice_rows, columns=fieldnames)

In [9]:
invoice_data

Unnamed: 0,date,order_id,category_level_1,category_level_2,category_level_3,cost_centre,plant,plant_id,region,amount,unit_price,total
0,2025-10-16,ORD-2025-02287-TR,Direct,Blades & Hub Parts,Blade shear web,CC-300-Engineering,US-West Plant,PLANT-US-W,North America,70,439.26,30748.20
1,2025-03-08,ORD-2025-08280-PO,Direct,Components,Control PCB,CC-100-Production,US-West Plant,PLANT-US-W,North America,92,3251.17,299107.64
2,2025-07-12,ORD-2025-00107-TE,Direct,Bearings & Seals,Oil seal,CC-100-Production,Germany-North,PLANT-DE-N,Europe,55,1704.55,93750.25
3,2024-06-08,ORD-2024-06225-CW,Direct,Electrical Assemblies,Control cabinet,CC-100-Production,US-East Plant,PLANT-US-E,North America,78,1326.28,103449.84
4,2024-02-14,ORD-2024-01292-OV,Indirect,Events & Conferences,Conference registration,CC-500-Sales,Germany-South,PLANT-DE-S,Europe,80,4427.83,354226.40
...,...,...,...,...,...,...,...,...,...,...,...,...
9995,2024-10-01,ORD-2024-07585-LY,Direct,Blades & Hub Parts,Pitch bearing,CC-300-Engineering,Brazil Plant,PLANT-BR,South America,69,663.55,45784.95
9996,2024-12-17,ORD-2024-01192-HL,Direct,Packaging Materials,Export crate,CC-700-Logistics,US-West Plant,PLANT-US-W,North America,62,2379.93,147555.66
9997,2024-03-18,ORD-2024-03705-QJ,Indirect,Safety & PPE,Hearing protection,CC-200-Maintenance,Vietnam Plant,PLANT-VN,Asia,83,2206.96,183177.68
9998,2025-07-14,ORD-2025-05719-AH,Indirect,Temporary Labor / Contracting,Crane rental crew,CC-200-Maintenance,Germany-South,PLANT-DE-S,Europe,49,2481.28,121582.72


In [11]:
# Persist the synthetic invoices as a Delta table, replacing any previous
# contents and schema.
target_table = f"{catalog}.{schema}.{table}"
invoice_sdf = spark.createDataFrame(invoice_data)
(
    invoice_sdf.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(target_table)
)

In [0]:
# Expose the Python settings to the %sql cells below as notebook widgets.
# Those cells reference $catalog, $schema, $table, $table_gen and $table_enh,
# plus the :llm_endpoint and :prompt parameter markers, so each needs a widget.
dbutils.widgets.text("llm_endpoint", llm_endpoint)
dbutils.widgets.text("catalog", catalog)
dbutils.widgets.text("schema", schema)
# Source table written above; read by the SQL cells as $table.
dbutils.widgets.text("table", table)
# LLM output table; the Python preview cell reads it back as f"{table}_gen".
dbutils.widgets.text("table_gen", table + "_gen")
# Enriched (joined) table read by the SQL cells as $table_enh.
# NOTE(review): the "_enh" suffix mirrors "_gen" — confirm the intended name.
dbutils.widgets.text("table_enh", table + "_enh")
# NOTE(review): `prompt` is not defined in any visible cell (the config cell does
# not set it); define it (e.g. from the generation config) before this line or
# this raises NameError on a fresh kernel.
dbutils.widgets.text("prompt", prompt)

In [0]:
%sql
-- Ask the LLM to invent a description, supplier and supplier country for every
-- synthetic invoice line. The prompt text is built once in the CTE, so the
-- stored `input` column and the text sent to AI_QUERY cannot drift apart.
CREATE OR REPLACE TABLE $catalog.$schema.$table_gen AS
WITH prompts AS (
  SELECT
    date,
    order_id,
    CONCAT(
      :prompt, '\n',
      'date: ', date, '\n',
      'order_id: ', order_id, '\n',
      'category_level_1: ', category_level_1, '\n',
      'category_level_2: ', category_level_2, '\n',
      'category_level_3: ', category_level_3, '\n',
      'cost_centre: ', cost_centre, '\n',
      'plant: ', plant, '\n',
      'plant_id: ', plant_id, '\n',
      'region: ', region
    ) AS prompt_text
  FROM $catalog.$schema.$table
)
SELECT
  date,
  order_id,
  prompt_text AS input,
  AI_QUERY(
    :llm_endpoint,
    -- Same text as `input`, followed by an "Output: " cue for the model.
    CONCAT(prompt_text, '\n', 'Output: '),
    responseFormat => '{
      "type": "json_schema",
      "json_schema": {
        "name": "categorization",
        "schema": {
          "type": "object",
          "properties": {
            "description": {"type": "string"},
            "supplier": {"type": "string"},
            "supplier_country": {"type": "string"}
          }
        }
      }
    }'
  ) AS output
FROM prompts

In [13]:
gen_table_name = f"{catalog}.{schema}.{table}_gen"
spark.table(gen_table_name).limit(1).toPandas()  # peek at one LLM-generated row

Unnamed: 0,date,order_id,input,output
0,2024-04-14,ORD-2024-06913-AC,Generate these fields based on the provided in...,"{""description"":""Annual sales conference regist..."


In [0]:
%sql
-- Join the LLM-generated fields back onto the original invoice lines.
-- `output` holds a JSON string; json_tuple splits it into three string columns.
-- NOTE(review): order_id is randomly generated and not guaranteed unique, so a
-- (date, order_id) collision would duplicate rows in this join.
CREATE OR REPLACE TABLE $catalog.$schema.$table_enh AS
SELECT
  t.*,
  o.supplier,
  o.supplier_country,
  o.description
FROM $catalog.$schema.$table t
JOIN (
  SELECT
    g.date,
    g.order_id,
    j.supplier,
    j.supplier_country,
    j.description
  FROM $catalog.$schema.$table_gen g
  -- Alias the generator `j` (not `output`) so it does not shadow the `output`
  -- column it parses.
  LATERAL VIEW json_tuple(g.output, 'description', 'supplier', 'supplier_country')
    j AS description, supplier, supplier_country
) o
ON t.date = o.date AND t.order_id = o.order_id

In [0]:
%sql
-- Preview the enriched invoice table.
SELECT *
FROM $catalog.$schema.$table_enh