# Spend Categorization

## Data Generation
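This notebook generates the data included in the `assets` folder: it samples synthetic invoice lines from `config.yaml`, saves them as a Delta table, and then uses `AI_QUERY` to add an LLM-generated description, supplier, and supplier country to each line.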

Tested on Databricks serverless, environment version 4.

In [0]:
# Install the project package from the current directory (`.`) via `uv pip`,
# then restart Python so the newly installed packages are importable below.
%uv pip install .
%restart_python

In [0]:
import random
import csv
from datetime import datetime, timedelta
import yaml

# Fixed seed for the module-level `random` generator used by every sampling
# helper below, so each run regenerates the same dataset.
SEED = 42
random.seed(SEED)

In [0]:
CONFIG_PATH = "config.yaml"

# Read the YAML config once. The `company` section describes the profile the
# rows are sampled from; `data_generation` controls sampling and output targets.
with open(CONFIG_PATH, "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

# --- company profile ---
company_cfg = config["company"]
name = company_cfg["name"]
industry = company_cfg["industry"]
plants = company_cfg["plants"]
cost_centres = company_cfg["cost_centres"]
categories = company_cfg["categories"]
cat_cc_map = company_cfg["category_cost_centre_mapping"]

# --- sampling settings and output location (catalog / schema / table) ---
generation_cfg = config["data_generation"]
level_1_dist = generation_cfg["distribution"]
n_rows = generation_cfg["rows"]
start_date = generation_cfg["start"]
end_date = generation_cfg["end"]
fieldnames = generation_cfg["python_columns"]
catalog = generation_cfg["catalog"]
schema = generation_cfg["schema"]
llm_endpoint = generation_cfg["llm_endpoint"]
table = generation_cfg["table"]

In [0]:
def random_date(start_date, end_date):
    """Return a uniformly random day between ``start_date`` and ``end_date`` (inclusive).

    Args:
        start_date: First allowed day, as a ``date``/``datetime`` or an ISO-8601
            string such as ``"2024-01-01"``.
        end_date: Last allowed day; same accepted types as ``start_date``.

    Returns:
        ``start_date`` shifted forward by a random whole number of days in
        ``[0, (end_date - start_date).days]``.

    Raises:
        ValueError: if ``end_date`` is earlier than ``start_date``.
    """
    # yaml.safe_load only produces date objects for *unquoted* ISO dates;
    # quoted values arrive as strings, so parse those here.
    if isinstance(start_date, str):
        start_date = datetime.fromisoformat(start_date).date()
    if isinstance(end_date, str):
        end_date = datetime.fromisoformat(end_date).date()
    span_days = (end_date - start_date).days
    if span_days < 0:
        raise ValueError(f"end_date {end_date} is before start_date {start_date}")
    return start_date + timedelta(days=random.randint(0, span_days))
  
def sample_level1(level_1_dist):
    """Draw a level-1 category according to ``level_1_dist``.

    ``level_1_dist["Direct"]`` and ``level_1_dist["Indirect"]`` are read as
    probabilities; whatever probability mass remains maps to "Non-Procureable".
    """
    draw = random.random()
    direct_share = level_1_dist["Direct"]
    indirect_share = level_1_dist["Indirect"]
    if draw < direct_share:
        return "Direct"
    if draw < direct_share + indirect_share:
        return "Indirect"
    return "Non-Procureable"

def sample_category_triplet(categories, level_1_dist):
    """Sample a (level 1, level 2, level 3) category path.

    Level 1 comes from ``sample_level1``. Level 2 is a uniform pick among the
    keys of ``categories[level 1]``, and level 3 is a uniform pick from the
    list stored under that level-2 key.
    """
    top = sample_level1(level_1_dist)
    subtree = categories[top]
    middle = random.choice(list(subtree))
    leaf = random.choice(subtree[middle])
    return top, middle, leaf

def sample_plant(plants):
    """Pick one entry of ``plants`` uniformly at random."""
    chosen = random.choice(plants)
    return chosen

def sample_cost_centre(l2, cat_cc_map):
    """Look up the cost centre mapped to level-2 category ``l2``.

    Despite the ``sample_`` prefix this draws no random numbers; it raises
    ``KeyError`` when ``l2`` has no entry in ``cat_cc_map``.
    """
    cost_centre = cat_cc_map[l2]
    return cost_centre

def generate_order_id(i, n_rows, date):
    """Build a synthetic order id of the form ``ORD-<year>-<seq>-<XY>``.

    ``seq`` is drawn uniformly from ``1..n_rows`` and zero-padded to at least
    5 digits; ``XY`` is two random uppercase letters. Because both parts are
    independent random draws, ids are not guaranteed to be unique. ``i`` is
    accepted for call-site compatibility but is not used.
    """
    year = date.year
    seq_part = str(random.randint(1, n_rows)).zfill(5)
    letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    suffix_part = "".join(random.choices(letters, k=2))
    return "ORD-{}-{}-{}".format(year, seq_part, suffix_part)

def generate_amount_and_price(
    l1,
    *,
    quantity_range=(1, 100),
    unit_price_range=(5, 5000),
    lump_sum_price_range=(1000, 500000),
):
    """Draw a quantity, unit price and line total for one invoice line.

    "Non-Procureable" spend is modelled as a single lump-sum line (quantity 1)
    priced from ``lump_sum_price_range``. Every other level-1 category gets a
    random integer quantity from ``quantity_range`` (inclusive) and a unit
    price from ``unit_price_range``.

    Args:
        l1: Level-1 category name.
        quantity_range: Inclusive ``(low, high)`` bounds for ``random.randint``.
        unit_price_range: ``(low, high)`` bounds for ``random.uniform``.
        lump_sum_price_range: ``(low, high)`` bounds for the lump-sum price.

    Returns:
        ``(amount, unit_price, total)``. ``unit_price`` is rounded to 2 decimals
        and ``total == round(amount * unit_price, 2)``.
    """
    if l1 == "Non-Procureable":
        amount = 1
        unit_price = round(random.uniform(*lump_sum_price_range), 2)
    else:
        amount = random.randint(*quantity_range)
        unit_price = round(random.uniform(*unit_price_range), 2)
    total = round(amount * unit_price, 2)
    return amount, unit_price, total

In [0]:
import pandas as pd

# Generate `n_rows` synthetic invoice lines. Each row consumes the seeded RNG
# in a fixed order (date, category path, plant, order id, quantity/price), so
# reordering these calls changes the generated dataset.
rows = []
for i in range(1, n_rows + 1):
    date = random_date(start_date, end_date)
    l1, l2, l3 = sample_category_triplet(categories, level_1_dist)
    plant = sample_plant(plants)
    cc = sample_cost_centre(l2, cat_cc_map)
    order_id = generate_order_id(i, n_rows, date)
    amount, unit_price, total = generate_amount_and_price(l1)

    rows.append({
        "date": date.strftime("%Y-%m-%d"),
        "order_id": order_id,
        "category_level_1": l1,
        "category_level_2": l2,
        "category_level_3": l3,
        "cost_centre": cc,
        "plant": plant["name"],
        "plant_id": plant["id"],
        "region": plant["region"],
        "amount": amount,
        "unit_price": unit_price,
        "total": total,
    })

# `pd.DataFrame(..., columns=fieldnames)` keeps only the listed keys and silently
# fills any listed name that was never generated with NaN, so a typo in
# `python_columns` would otherwise go unnoticed. Fail fast instead.
if rows:
    unknown_columns = [col for col in fieldnames if col not in rows[0]]
    if unknown_columns:
        raise ValueError(
            "data_generation.python_columns lists columns that are not generated: "
            f"{unknown_columns}"
        )

invoice_data = pd.DataFrame(rows, columns=fieldnames)

In [0]:
# Sanity-check the generated frame, then preview a few rows instead of rendering all of them.
assert len(invoice_data) == n_rows, "generated row count differs from data_generation.rows"
assert not invoice_data.isna().any().any(), "generated invoice data contains missing values"
invoice_data.head(10)

In [0]:
# Persist the synthetic invoices as a Delta table, replacing existing data and schema.
target_table = f"{catalog}.{schema}.{table}"
writer = (
    spark.createDataFrame(invoice_data)
    .write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
)
writer.saveAsTable(target_table)

In [0]:
# Read the LLM instruction text that is prepended to every row in the generation query.
with open("generation_prompt.md", "r", encoding="utf-8") as prompt_file:
    prompt = prompt_file.read()

In [0]:
# Expose config values as notebook widgets so the %sql cells below can use them:
# `$catalog`, `$schema`, `$table`, `$table_gen` and `$table_enh` for object names,
# and named parameters for the prompt and LLM endpoint.
# NOTE(review): widgets persist between runs; if one already exists, confirm its
# current value is not stale.
dbutils.widgets.text("llm_endpoint", llm_endpoint)
dbutils.widgets.text("catalog", catalog)
dbutils.widgets.text("schema", schema)
dbutils.widgets.text("table", table)
dbutils.widgets.text("table_gen", table + "_gen")
dbutils.widgets.text("table_enh", table + "_enh")
dbutils.widgets.text("prompt", prompt)

In [0]:
%sql
-- Build one prompt per invoice line, then ask the LLM endpoint for a
-- description, supplier and supplier country as structured JSON.
-- The `input` column stores the prompt; the model receives that same text
-- followed by an 'Output: ' cue.
CREATE OR REPLACE TABLE $catalog.$schema.$table_gen AS
WITH prompts AS (
  SELECT
    date,
    order_id,
    CONCAT(
      :prompt, '\n',
      'date: ', date, '\n',
      'order_id: ', order_id, '\n',
      'category_level_1: ', category_level_1, '\n',
      'category_level_2: ', category_level_2, '\n',
      'category_level_3: ', category_level_3, '\n',
      'cost_centre: ', cost_centre, '\n',
      'plant: ', plant, '\n',
      'plant_id: ', plant_id, '\n',
      'region: ', region
    ) AS input
  FROM $catalog.$schema.$table
)
SELECT
  date,
  order_id,
  input,
  AI_QUERY(
    :llm_endpoint,
    CONCAT(input, '\n', 'Output: '),
    responseFormat => '{
      "type": "json_schema",
      "json_schema": {
        "name": "categorization",
        "schema": {
          "type": "object",
          "properties": {
            "description": {"type": "string"},
            "supplier": {"type": "string"},
            "supplier_country": {"type": "string"}
          }
        }
      }
    }'
  ) AS output
FROM prompts

In [0]:
# Inspect the raw LLM responses before parsing them.
generated_table = spark.table(f"{catalog}.{schema}.{table}_gen")
generated_table.display()

In [0]:
%sql
-- Parse the JSON responses and attach supplier, supplier_country and description
-- to the original invoice lines. The join assumes (date, order_id) identifies a
-- single row; repeated pairs would multiply the joined rows.
CREATE OR REPLACE TABLE $catalog.$schema.$table_enh AS
WITH parsed AS (
  SELECT
    g.date,
    g.order_id,
    j.supplier,
    j.supplier_country,
    j.description
  FROM $catalog.$schema.$table_gen g
  LATERAL VIEW json_tuple(g.output, 'description', 'supplier', 'supplier_country')
    j AS description, supplier, supplier_country
)
SELECT
  t.*,
  p.supplier,
  p.supplier_country,
  p.description
FROM $catalog.$schema.$table t
INNER JOIN parsed p
  ON t.date = p.date AND t.order_id = p.order_id

In [0]:
%sql
-- Inspect the enriched invoice table.
SELECT enh.*
FROM $catalog.$schema.$table_enh AS enh