In [0]:
from pyspark.sql.types import StructType, StructField, LongType, StringType, DoubleType, BooleanType
from pyspark.sql.functions import (
    col, sum as _sum, year, month, from_unixtime, count, avg as _avg,
    udf, rand, when, lit, expr, row_number, concat, monotonically_increasing_id
)
from pyspark.sql.window import Window
from datetime import datetime
import random

# -------------------- KEY VARIABLES -------------------- #
# All generator inputs/outputs live under one directory inside the catalog's
# /Volumes tree; this table is written to its own "spend_invoices" subfolder.
catalog = "main"
directory = "/".join(["/Volumes", catalog, "finance_lakehouse", "data_gen_outputs"])
output_path = "/".join([directory, "spend_invoices"])

In [0]:
"""
Generate synthetic data for the spend_invoices table.
This script uses PySpark with distributed compute for high-volume data generation.
Maintains 1:1 relationship with purchase orders.
Invoice dates are 15-90 days after PO date.
"""

# Snapshot "now" once (epoch seconds); every status / date-capping rule below
# compares against this single value so the whole run is internally consistent.
CURRENT_TS = int(datetime.now().timestamp())

print("\n=== READING PURCHASE ORDERS DATA ===")
po_df = (
    spark.read
    .option("header", "true")
    .csv(f"{directory}/purchase_orders/*/*.csv")
)

# CSV columns arrive as strings; cast the ones used in arithmetic later on.
for _po_column, _spark_type in (
    ("purchase_order_id", "long"),
    ("purchase_order_date", "long"),
    ("purchase_order_amount", "double"),
):
    po_df = po_df.withColumn(_po_column, col(_po_column).cast(_spark_type))

po_count = po_df.count()
print(f"Total purchase orders to process: {po_count:,}")

print("\n=== GENERATING INVOICE DATA USING DISTRIBUTED COMPUTE ===")

# Consecutive 1-based row numbers, used only to derive invoice_id.
# NOTE(review): Window.orderBy without partitionBy makes Spark move every row
# into a single partition to number them -- fine for moderate volumes, but a
# bottleneck for very large inputs.
window = Window.orderBy(monotonically_increasing_id())
po_df = po_df.withColumn("row_id", row_number().over(window))

# invoice_id is 60000000 for row 1 and increases by one per row;
# invoice_number is the id prefixed with "INV".
po_df = (
    po_df
    .withColumn("invoice_id", col("row_id") + 60000000 - 1)
    .withColumn("invoice_number", concat(lit("INV"), col("invoice_id").cast("string")))
)

# Day offsets are hash-derived values in 15..90 (abs(hash) % 76 + 15), so they are
# reproducible per purchase_order_id; the "+ 1" gives the second draw a different
# hash input. Offsets are converted to seconds for the epoch-second date columns.
po_df = (
    po_df
    .withColumn("days_after_po", (expr("abs(hash(purchase_order_id))") % 76) + 15)
    .withColumn("invoice_date", col("purchase_order_date") + col("days_after_po") * 86400)
    .withColumn("days_until_due", (expr("abs(hash(purchase_order_id + 1))") % 76) + 15)
    .withColumn("payment_due_date", col("invoice_date") + col("days_until_due") * 86400)
)

# Invoice amount: ~90% of invoices bill the full PO amount, ~10% bill a partial
# 10%-94% of it (hash bucket 0-84 / 100 + 0.10); the result is rounded to cents.
po_df = po_df.withColumn("is_partial", (expr("abs(hash(purchase_order_id + 2))") % 100) < 10)
po_df = po_df.withColumn("partial_pct", 0.1 + ((expr("abs(hash(purchase_order_id + 3))") % 85) / 100.0))
unrounded_invoice_amount = when(
    col("is_partial"), col("purchase_order_amount") * col("partial_pct")
).otherwise(col("purchase_order_amount"))
po_df = po_df.withColumn("invoice_amount", unrounded_invoice_amount)
po_df = po_df.withColumn("invoice_amount", expr("round(invoice_amount, 2)"))

# Tax rate picked by hash bucket 0-4: 0%, 5%, 7%, 8.5%, otherwise 10%.
po_df = po_df.withColumn("tax_idx", expr("abs(hash(purchase_order_id + 4))") % 5)
tax_rate_expr = when(col("tax_idx") == 0, lit(0.0))
for _bucket, _rate in ((1, 0.05), (2, 0.07), (3, 0.085)):
    tax_rate_expr = tax_rate_expr.when(col("tax_idx") == _bucket, lit(_rate))
po_df = po_df.withColumn("tax_rate", tax_rate_expr.otherwise(lit(0.10)))
po_df = po_df.withColumn("tax_amount", expr("round(invoice_amount * tax_rate, 2)"))

# Discount: ~15% of invoices get 1%, 2%, 2.5% or 5% (hash bucket 0-3); the rest get 0%.
po_df = po_df.withColumn("has_discount", (expr("abs(hash(purchase_order_id + 5))") % 100) < 15)
po_df = po_df.withColumn("discount_idx", expr("abs(hash(purchase_order_id + 6))") % 4)
discount_rate_expr = when(~col("has_discount"), lit(0.0))
for _bucket, _rate in ((0, 0.01), (1, 0.02), (2, 0.025)):
    discount_rate_expr = discount_rate_expr.when(col("discount_idx") == _bucket, lit(_rate))
po_df = po_df.withColumn("discount_rate", discount_rate_expr.otherwise(lit(0.05)))
po_df = po_df.withColumn("discount_amount", expr("round(invoice_amount * discount_rate, 2)"))

# Total = invoice amount + tax - discount, rounded to cents.
po_df = po_df.withColumn("total_invoice_amount", expr("round(invoice_amount + tax_amount - discount_amount, 2)"))

# Days elapsed since the invoice date (fractional; negative for future-dated invoices).
po_df = po_df.withColumn("days_since_invoice", (lit(CURRENT_TS) - col("invoice_date")) / 86400)

# Past due = the payment due date is already behind CURRENT_TS.
# This must be derived from payment_due_date: the status branches that consult
# this flag only run for invoices at least 30 days old, so any test on
# invoice_date being in the future would always be False there and the
# "disputed" status could never be produced.
po_df = po_df.withColumn("is_past_due", col("payment_due_date") < lit(CURRENT_TS))

# Invoice status, bucketed by invoice age (status_rand is a hash-derived 0-99 draw per PO):
#   future-dated invoice: pending_approval
#   0-7 days:   60% pending_approval, 40% approved
#   7-30 days:  20% pending_approval, 50% approved, 30% paid
#   30-60 days: past due     -> 20% approved, 75% paid, 5% disputed
#               not past due -> 30% approved, 70% paid
#   60+ days:   past due     -> 88% paid, 10% disputed, 2% cancelled
#               not past due -> 95% paid, 5% cancelled
po_df = po_df.withColumn("status_rand", expr("abs(hash(purchase_order_id + 7))") % 100)

po_df = po_df.withColumn(
    "invoice_status",
    when(col("invoice_date") > lit(CURRENT_TS), lit("pending_approval"))
    .when(
        col("days_since_invoice") < 7,
        when(col("status_rand") < 60, lit("pending_approval")).otherwise(lit("approved"))
    )
    .when(
        col("days_since_invoice") < 30,
        when(col("status_rand") < 20, lit("pending_approval"))
        .when(col("status_rand") < 70, lit("approved"))
        .otherwise(lit("paid"))
    )
    .when(
        col("days_since_invoice") < 60,
        when(
            col("is_past_due"),
            when(col("status_rand") < 20, lit("approved"))
            .when(col("status_rand") < 95, lit("paid"))
            .otherwise(lit("disputed"))
        ).otherwise(
            when(col("status_rand") < 30, lit("approved")).otherwise(lit("paid"))
        )
    )
    .otherwise(
        when(
            col("is_past_due"),
            when(col("status_rand") < 88, lit("paid"))
            .when(col("status_rand") < 98, lit("disputed"))
            .otherwise(lit("cancelled"))
        ).otherwise(
            when(col("status_rand") < 95, lit("paid")).otherwise(lit("cancelled"))
        )
    )
)

# Amount paid by status: paid -> full total; disputed -> 0-49% of the total
# (hash bucket / 100); cancelled and every other status -> 0. Rounded to cents.
status_col = col("invoice_status")
disputed_payment = col("total_invoice_amount") * (expr("abs(hash(purchase_order_id + 8))") % 50) / 100
po_df = po_df.withColumn(
    "amount_paid",
    when(status_col == "paid", col("total_invoice_amount"))
    .when(status_col == "cancelled", lit(0.0))
    .when(status_col == "disputed", disputed_payment)
    .otherwise(lit(0.0)),
)
po_df = po_df.withColumn("amount_paid", expr("round(amount_paid, 2)"))

# Payment timing: ~80% pay early (0..days_until_due-1 days after the invoice),
# the rest pay 0-29 days after the due date.
due_days_int = expr("cast(days_until_due as int)")
early_offset = expr("abs(hash(purchase_order_id + 10))") % due_days_int
late_offset = due_days_int + ((expr("abs(hash(purchase_order_id + 11))") % 30))
po_df = (
    po_df
    .withColumn("payment_early", (expr("abs(hash(purchase_order_id + 9))") % 100) < 80)
    .withColumn("payment_days_offset", when(col("payment_early"), early_offset).otherwise(late_offset))
    .withColumn("payment_date_calc", col("invoice_date") + col("payment_days_offset") * 86400)
)

# payment_date exists only for paid invoices and is capped at CURRENT_TS so no
# payment is dated in the future.
capped_payment_date = when(
    col("payment_date_calc") > lit(CURRENT_TS), lit(CURRENT_TS)
).otherwise(col("payment_date_calc"))
po_df = po_df.withColumn(
    "payment_date",
    when(status_col == "paid", capped_payment_date).otherwise(lit(None)),
)

# Payment method (hash bucket 0-2 -> check / ach / wire) and a reference of
# "PAY" + invoice_id + a 4-digit suffix; both only for paid invoices.
po_df = po_df.withColumn("payment_method_idx", expr("abs(hash(purchase_order_id + 12))") % 3)
method_by_bucket = (
    when(col("payment_method_idx") == 0, lit("check"))
    .when(col("payment_method_idx") == 1, lit("ach"))
    .otherwise(lit("wire"))
)
po_df = po_df.withColumn(
    "payment_method",
    when(status_col == "paid", method_by_bucket).otherwise(lit(None)),
)
reference_suffix = (expr("abs(hash(purchase_order_id + 13))") % 9000 + 1000).cast("string")
po_df = po_df.withColumn(
    "payment_reference",
    when(
        status_col == "paid",
        concat(lit("PAY"), col("invoice_id").cast("string"), reference_suffix),
    ).otherwise(lit(None)),
)

# Invoice category by hash bucket 0-4: labor, materials, equipment, subcontractor, else overhead.
po_df = po_df.withColumn("category_idx", expr("abs(hash(purchase_order_id + 14))") % 5)
category_expr = when(col("category_idx") == 0, lit("labor"))
for _bucket, _label in enumerate(("materials", "equipment", "subcontractor"), start=1):
    category_expr = category_expr.when(col("category_idx") == _bucket, lit(_label))
po_df = po_df.withColumn("invoice_category", category_expr.otherwise(lit("overhead")))

paid_or_approved = col("invoice_status").isin("paid", "approved")

# Receipt matched: a 0-99 hash draw below the threshold -> 90 for paid/approved
# invoices, 50 for everything else.
po_df = po_df.withColumn("receipt_rand", expr("abs(hash(purchase_order_id + 15))") % 100)
receipt_threshold = when(paid_or_approved, lit(90)).otherwise(lit(50))
po_df = po_df.withColumn("receipt_matched", col("receipt_rand") < receipt_threshold)

# Three-way match: "matched" when the 0-99 draw is below 95 (paid/approved),
# 20 (disputed) or 70 (any other status); otherwise "unmatched".
po_df = po_df.withColumn("match_rand", expr("abs(hash(purchase_order_id + 16))") % 100)
match_threshold = (
    when(paid_or_approved, lit(95))
    .when(col("invoice_status") == "disputed", lit(20))
    .otherwise(lit(70))
)
po_df = po_df.withColumn(
    "three_way_match_status",
    when(col("match_rand") < match_threshold, lit("matched")).otherwise(lit("unmatched")),
)

# Goods received date lies between the PO date and the invoice date.
# days_after_po is normally 15-90 (set when invoice dates are derived), so the
# first branch -- an offset of 5..days_after_po-6 days -- is the one taken in
# practice; the fallback only covers values of 10 or less.
received_offset = when(
    col("days_after_po") > 10,
    5 + (expr("abs(hash(purchase_order_id + 17))") % (col("days_after_po") - 10)),
).otherwise(
    1 + (expr("abs(hash(purchase_order_id + 18))") % expr("greatest(cast(days_after_po as int), 1)"))
)
po_df = (
    po_df
    .withColumn("goods_received_days_offset", received_offset)
    .withColumn("goods_received_date", col("purchase_order_date") + col("goods_received_days_offset") * 86400)
)

# Partition columns: calendar year and month of the invoice date.
invoice_ts = from_unixtime(col("invoice_date"))
po_df = po_df.withColumn("invoice_year", year(invoice_ts)).withColumn("invoice_month", month(invoice_ts))

# Output schema of spend_invoices, in the order columns are written.
invoice_columns = [
    "invoice_id",
    "purchase_order_id",
    "invoice_number",
    "invoice_date",
    "invoice_status",
    "invoice_amount",
    "tax_amount",
    "discount_amount",
    "total_invoice_amount",
    "amount_paid",
    "payment_due_date",
    "invoice_category",
    "payment_date",
    "payment_method",
    "payment_reference",
    "receipt_matched",
    "three_way_match_status",
    "goods_received_date",
    "invoice_year",
    "invoice_month",
]
invoices_df = po_df.select(*invoice_columns)

# Overwrite the output as headered CSV, one directory per invoice_year/invoice_month.
print(f"\n=== WRITING DATA TO {output_path} ===")
print("Partitioning by invoice_year and invoice_month...")

(
    invoices_df.write
    .mode("overwrite")
    .partitionBy("invoice_year", "invoice_month")
    .option("header", "true")
    .csv(output_path)
)

print("Data written successfully!")

# Read back for statistics
print("\n=== READING GENERATED DATA FOR STATISTICS ===")
df = spark.read.option("header", "true").csv(f"{output_path}/*/*/*.csv")

# Every column is read back from CSV as a string; cast the ones that later
# statistics use numerically (including tax/discount, which are summed).
df = (
    df.withColumn("invoice_amount", col("invoice_amount").cast("double"))
    .withColumn("tax_amount", col("tax_amount").cast("double"))
    .withColumn("discount_amount", col("discount_amount").cast("double"))
    .withColumn("total_invoice_amount", col("total_invoice_amount").cast("double"))
    .withColumn("amount_paid", col("amount_paid").cast("double"))
    .withColumn("invoice_date", col("invoice_date").cast("long"))
    .withColumn("purchase_order_id", col("purchase_order_id").cast("long"))
)

# Show sample
print("\nSample of generated data:")
df.show(20, truncate=False)

# Statistics. Each count below is a separate Spark job over the CSV output, so
# compute every figure exactly once and reuse it in the checks that follow.
print("\n=== DATA STATISTICS ===")
invoice_count = df.count()
unique_invoice_ids = df.select("invoice_id").distinct().count()
unique_invoice_numbers = df.select("invoice_number").distinct().count()
po_count_in_invoices = df.select("purchase_order_id").distinct().count()

print(f"\nTotal records: {invoice_count:,}")
print(f"Unique invoice IDs: {unique_invoice_ids:,}")
print(f"Unique invoice numbers: {unique_invoice_numbers:,}")
print(f"Unique purchase order IDs: {po_count_in_invoices:,}")

# Verify 1:1 relationship: no PO id repeats among the invoices AND the number of
# invoices equals po_count (the PO rows counted when the source was loaded), so
# every PO produced exactly one invoice.
print("\n=== 1:1 RELATIONSHIP VERIFICATION ===")
print(f"Total invoices: {invoice_count:,}")
print(f"Unique PO IDs in invoices: {po_count_in_invoices:,}")
print(f"Purchase orders in source: {po_count:,}")
if invoice_count == po_count_in_invoices == po_count:
    print("✓ 1:1 relationship maintained: Each PO has exactly one invoice")
else:
    print(
        "⚠ WARNING: Relationship mismatch! "
        f"invoices={invoice_count:,}, unique PO IDs={po_count_in_invoices:,}, source POs={po_count:,}"
    )

# Row counts per invoice_status, alphabetical by status.
print("\n=== INVOICE STATUS DISTRIBUTION ===")
status_counts = df.groupBy("invoice_status").count()
status_counts.orderBy("invoice_status").show()

# Row counts per invoice_category, most frequent first.
print("\n=== INVOICE CATEGORY DISTRIBUTION ===")
category_counts = df.groupBy("invoice_category").count()
category_counts.orderBy(col("count").desc()).show()

# Payment method counts, restricted to paid invoices, most frequent first.
print("\n=== PAYMENT METHOD DISTRIBUTION (Paid Invoices Only) ===")
paid_invoices = df.filter(col("invoice_status") == "paid")
paid_invoices.groupBy("payment_method").count().orderBy(col("count").desc()).show()

# Row counts per three_way_match_status, alphabetical.
print("\n=== THREE-WAY MATCH STATUS ===")
match_counts = df.groupBy("three_way_match_status").count()
match_counts.orderBy("three_way_match_status").show()

# Invoice value, amount paid and invoice count per calendar year of invoice_date.
print("\n=== INVOICE VALUE BY YEAR ===")
df_with_year = df.withColumn("year", year(from_unixtime(col("invoice_date"))))
yearly_totals = df_with_year.groupBy("year").agg(
    _sum("invoice_amount").alias("total_invoice_amount"),
    _sum("amount_paid").alias("total_paid"),
    count("*").alias("invoice_count"),
)
yearly_totals.orderBy("year").show()

# Financial summary
print("\n=== FINANCIAL SUMMARY ===")


def _amount(name):
    """Return column *name* cast to double.

    The summary is computed from CSV read-back data, so amounts are cast
    explicitly instead of relying on implicit string-to-number coercion.
    """
    return col(name).cast("double")


financial_summary = df.agg(
    _sum(_amount("invoice_amount")).alias("total_invoice_amount"),
    _sum(_amount("tax_amount")).alias("total_tax"),
    _sum(_amount("discount_amount")).alias("total_discount"),
    _sum(_amount("total_invoice_amount")).alias("grand_total"),
    _sum(_amount("amount_paid")).alias("total_paid"),
    _avg(_amount("invoice_amount")).alias("avg_invoice_amount"),
    # Cancelled invoices are not owed, so they contribute nothing to the open balance.
    _sum(
        when(
            col("invoice_status") != "cancelled",
            _amount("total_invoice_amount") - _amount("amount_paid"),
        ).otherwise(lit(0.0))
    ).alias("outstanding"),
).collect()[0]

# Aggregates over an empty DataFrame come back as None; report those as 0.00
# instead of failing in the numeric format spec.
summary = {
    name: (value if value is not None else 0.0)
    for name, value in financial_summary.asDict().items()
}

print(f"Total invoice amount: ${summary['total_invoice_amount']:,.2f}")
print(f"Total tax: ${summary['total_tax']:,.2f}")
print(f"Total discount: ${summary['total_discount']:,.2f}")
print(f"Grand total (with tax, less discount): ${summary['grand_total']:,.2f}")
print(f"Total amount paid: ${summary['total_paid']:,.2f}")
print(f"Average invoice amount: ${summary['avg_invoice_amount']:,.2f}")

outstanding = summary["outstanding"]
print(f"Outstanding balance (excl. cancelled): ${outstanding:,.2f}")

# Verify invoice amounts don't exceed PO amounts
print("\n=== VALIDATING INVOICE AMOUNTS vs PO AMOUNTS ===")

# Re-read the PO source, keeping only the columns needed for the comparison.
po_df_for_validation = (
    spark.read.option("header", "true").csv(f"{directory}/purchase_orders/*/*.csv")
    .withColumn("purchase_order_id", col("purchase_order_id").cast("long"))
    .withColumn("purchase_order_amount", col("purchase_order_amount").cast("double"))
    .select("purchase_order_id", "purchase_order_amount")
)

# Invoices whose PO cannot be found would silently drop out of the inner join
# below (and so out of the "all within limits" message); report them explicitly.
orphan_invoices = df.join(po_df_for_validation, "purchase_order_id", "left_anti").count()
if orphan_invoices > 0:
    print(f"⚠ WARNING: {orphan_invoices:,} invoices have no matching purchase order and were not validated")

validation_df = df.join(
    po_df_for_validation,
    "purchase_order_id",
    "inner"
).withColumn(
    "exceeds_po",
    col("invoice_amount") > col("purchase_order_amount")
)

# exceeds_po is already a boolean column; filter on it directly (NULL rows are excluded).
violations_df = validation_df.filter(col("exceeds_po"))
violations = violations_df.count()
total_validated = validation_df.count()

if violations > 0:
    print(f"⚠ WARNING: {violations:,} out of {total_validated:,} invoices exceed their PO amount!")
    print("\nSample violations:")
    violations_df \
        .select("invoice_id", "purchase_order_id", "invoice_amount", "purchase_order_amount") \
        .show(10)
else:
    print(f"✓ All {total_validated:,} invoices are within their PO amount limits")

# Count invoices by receipt_matched, grouping on a string copy of the flag.
print("\n=== RECEIPT MATCHED STATISTICS ===")
receipt_summary_df = df.withColumn("receipt_matched_str", col("receipt_matched").cast("string"))
receipt_summary_df.groupBy("receipt_matched_str").count().show()

# Final pointers to where the output landed and how it is laid out.
print("\n=== DATA GENERATION COMPLETE ===")
print(f"Files are located in: {output_path}/")
print("Data is partitioned by invoice_year and invoice_month")

