In [0]:
from pyspark.sql.types import StructType, StructField, LongType, StringType, DoubleType
from pyspark.sql.functions import col, year, month, from_unixtime
from datetime import datetime
import random


# -------------------- KEY VARIABLES -------------------- #
# Catalog name; it forms the second path segment of the volume root below.
catalog = "fin_demo"
# Root directory for the data-generation outputs (read from and written to).
directory = "/".join(["/Volumes", catalog, "fin", "data_gen_outputs"])
# Destination directory for the revenue_billings CSV output.
output_path = "/".join([directory, "revenue_billings"])

In [0]:
"""
Generate synthetic data for the revenue_billings table.
This script creates billing records linked to revenue transactions.
Every billing has a revenue transaction reference.
"""


# Schema of the revenue_billings output: every column is required.
# The three *_date columns are stored as longs; the rest of this script
# treats them as Unix epoch seconds.
schema = (
    StructType()
    .add("billing_id", LongType(), nullable=False)
    .add("rev_trx_id", LongType(), nullable=False)
    .add("invoice_date", LongType(), nullable=False)
    .add("invoice_number", StringType(), nullable=False)
    .add("billed_amount", DoubleType(), nullable=False)
    .add("payment_due_date", LongType(), nullable=False)
    .add("payment_received_date", LongType(), nullable=False)
    .add("retention_amount", DoubleType(), nullable=False)
)

def generate_invoice_number(billing_id, year):
    """Build an invoice number of the form ``INV-<year>-<suffix>``.

    The suffix is the last six characters of ``billing_id`` rendered as a
    string, left-padded with zeros to a width of six.

    Args:
        billing_id: Identifier whose trailing characters form the suffix.
        year: Value placed in the middle segment of the invoice number.

    Returns:
        The formatted invoice number string.
    """
    suffix = str(billing_id)[-6:].zfill(6)
    return "INV-{}-{}".format(year, suffix)

# Announce the start of the run
print("Generating revenue billings data...")

# Load every revenue-transaction CSV found under the partitioned source folders
print("\nReading revenue transactions data...")
rev_trx_df = spark.read.option("header", "true").csv(
    f"{directory}/revenue_transactions/revenue_transactions_*/*.csv"
)

# No schema is supplied to the CSV reader, so cast the three columns used
# downstream from string to the numeric types the rest of the script expects.
for column_name, target_type in (
    ("rev_trx_id", "long"),
    ("transaction_date", "long"),
    ("amount", "double"),
):
    rev_trx_df = rev_trx_df.withColumn(column_name, col(column_name).cast(target_type))

# Pull the usable transactions onto the driver; rows whose amount is null
# or not strictly positive are discarded.
print("Processing revenue transactions...")
has_positive_amount = (col("amount").isNotNull()) & (col("amount") > 0)
rev_trx_data = (
    rev_trx_df.where(has_positive_amount)
    .select("rev_trx_id", "transaction_date", "amount")
    .collect()
)

print(f"Found {len(rev_trx_data)} revenue transactions with valid amounts")

# Build exactly one billing record per collected revenue transaction
billings = []
billing_id = 60000000  # Start with 8-digit ID
seconds_per_day = 24 * 60 * 60  # date columns are handled as epoch seconds

for row in rev_trx_data:
    rev_trx_id, transaction_date, transaction_amount = (
        row.rev_trx_id,
        row.transaction_date,
        row.amount,
    )

    # Invoice is issued on the transaction date or up to 30 days later
    invoice_date = transaction_date + random.randint(0, 30) * seconds_per_day

    # Invoice number embeds the invoice year.
    # NOTE(review): fromtimestamp() converts using the driver's local time
    # zone, while the partition columns written later in this script come from
    # Spark's from_unixtime(); confirm both use the same zone, otherwise the
    # two years can differ for timestamps close to a year boundary.
    invoice_number = generate_invoice_number(
        billing_id, datetime.fromtimestamp(invoice_date).year
    )

    # Billed amount lands between 90% and 110% of the transaction amount,
    # rounded to cents
    billed_amount = float(round(transaction_amount * random.uniform(0.9, 1.1), 2))

    # Due date follows randomly chosen payment terms (Net-30, Net-45 or Net-60)
    payment_due_date = invoice_date + random.choice([30, 45, 60]) * seconds_per_day

    # Payment arrives 5 to 90 days after the invoice date
    payment_received_date = invoice_date + random.randint(5, 90) * seconds_per_day

    # Retention is between 0% and 10% of the billed amount, rounded to cents
    retention_amount = float(round(billed_amount * random.uniform(0, 0.1), 2))

    billings.append(
        dict(
            billing_id=billing_id,
            rev_trx_id=rev_trx_id,
            invoice_date=invoice_date,
            invoice_number=invoice_number,
            billed_amount=billed_amount,
            payment_due_date=payment_due_date,
            payment_received_date=payment_received_date,
            retention_amount=retention_amount,
        )
    )

    billing_id += 1

print(f"\nGenerated {len(billings)} billing records")

# Turn the generated records into a Spark DataFrame with the declared schema
print("Creating DataFrame...")
df = spark.createDataFrame(billings, schema=schema)

# Preview the first 20 rows without truncating column values
print("\nSample of generated data:")
df.show(20, truncate=False)

# Row count and key cardinalities
total_records = df.count()
print(f"\nTotal records: {total_records:,}")
unique_billing_ids = df.select('billing_id').distinct().count()
print(f"Unique billing IDs: {unique_billing_ids:,}")
unique_rev_trx_ids = df.select('rev_trx_id').distinct().count()
print(f"Unique revenue transaction IDs: {unique_rev_trx_ids:,}")

# Descriptive statistics for the two monetary columns
summary_stats = ("count", "min", "max", "mean", "stddev")
for heading, amount_column in (
    ("\nBilled amount statistics:", "billed_amount"),
    ("\nRetention amount statistics:", "retention_amount"),
):
    print(heading)
    df.select(amount_column).summary(*summary_stats).show()

# Check the generated amounts against the intended percentage ranges
print("\nValidating constraints...")

# Source transactions with a usable amount, renamed so the joined frame has
# distinct column names for the transaction side.
valid_trx = rev_trx_df.filter(
    (col("amount").isNotNull()) & (col("amount") > 0)
).select(
    col("rev_trx_id").alias("trx_id"),
    col("amount").alias("transaction_amount"),
)

# Pair each billing with its transaction and keep rows where both amounts are
# positive, so neither percentage below divides by zero.
joined = df.join(valid_trx, df.rev_trx_id == col("trx_id"), "inner")
positive_only = joined.filter(
    (col("transaction_amount") > 0) & (col("billed_amount") > 0)
)
with_billed_pct = positive_only.withColumn(
    "billed_percentage",
    (col("billed_amount") / col("transaction_amount")) * 100,
)
validation_df = with_billed_pct.withColumn(
    "retention_percentage",
    (col("retention_amount") / col("billed_amount")) * 100,
)

# Billed amount must be 90-110% of the transaction amount
print("\nBilled amount validation (should be 90-110% of transaction amount):")
validation_df.select("billed_percentage").summary("min", "max", "mean").show()

billed_out_of_range = (col("billed_percentage") < 90) | (col("billed_percentage") > 110)
violations_billed = validation_df.filter(billed_out_of_range).count()

if violations_billed == 0:
    print("✓ All billings within 90-110% range")
else:
    print(f"⚠ WARNING: {violations_billed} billings outside 90-110% range!")

# Retention must be 0-10% of the billed amount
print("\nRetention amount validation (should be 0-10% of billed amount):")
validation_df.select("retention_percentage").summary("min", "max", "mean").show()

retention_out_of_range = (col("retention_percentage") < 0) | (col("retention_percentage") > 10)
violations_retention = validation_df.filter(retention_out_of_range).count()

if violations_retention == 0:
    print("✓ All retention amounts within 0-10% range")
else:
    print(f"⚠ WARNING: {violations_retention} retention amounts outside 0-10% range!")

# Persist the billings as CSV, one directory per invoice year/month
print(f"\nWriting data to {output_path}...")
print("Partitioning by invoice year and month...")

# Derive both partition columns from the same timestamp expression
invoice_timestamp = from_unixtime(col("invoice_date"))
df_with_partitions = df.withColumn(
    "invoice_year", year(invoice_timestamp)
).withColumn(
    "invoice_month", month(invoice_timestamp)
)

# Data columns followed by the two partition columns
output_columns = [
    "billing_id",
    "rev_trx_id",
    "invoice_date",
    "invoice_number",
    "billed_amount",
    "payment_due_date",
    "payment_received_date",
    "retention_amount",
    "invoice_year",
    "invoice_month",
]

# Overwrite any previous output at the destination and include a header row
(
    df_with_partitions.select(*output_columns)
    .write.mode("overwrite")
    .partitionBy("invoice_year", "invoice_month")
    .option("header", "true")
    .csv(output_path, encoding="utf-8")
)

print("\nData generation complete!")
print(f"Files are located in: {output_path}/")
print("Files are organized by invoice year and month (e.g., invoice_year=2024/invoice_month=1/)")