In [0]:
from pyspark.sql.types import StructType, StructField, LongType, StringType, DoubleType
from pyspark.sql.functions import col, sum as _sum, year, from_unixtime
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import random

# -------------------- KEY VARIABLES -------------------- #
# Unity Catalog name; every path below is rooted in this catalog's volume.
catalog = 'fin_demo'
directory = f"/Volumes/{catalog}/fin/data_gen_outputs"
output_path = f"{directory}/purchase_orders"

# Generation window: January 1st of start_year up to "now".
start_year = 2023
start_date = datetime(start_year, 1, 1)
end_date = datetime.now()

# Number of calendar months in the window, counting both the first and the
# current (possibly partial) month.
num_months = (
    (end_date.year - start_date.year) * 12
    + (end_date.month - start_date.month)
    + 1
)

# Bounds (inclusive) for the number of records generated per month.
records_per_month_min, records_per_month_max = 50_000, 100_000

# Dollar figure used to derive the target average PO amount. The generation
# code scales it by num_months / 12, i.e. it is applied as a per-12-months
# figure when computing the per-PO average.
# NOTE (original author): do not read this as an exact yearly total -- a
# setting of 0.5 billion produces roughly 1.2 billion in practice, because
# PO amounts are drawn from a skewed distribution.
scale_target = 500_000_000

In [0]:
"""
Generate synthetic data for the purchase_orders table.
This script uses PySpark with distributed compute for high-volume data generation.
Data is partitioned by purchase order month with 50K-100K records per partition.
Generates purchase orders starting from 2023.
"""

# Column layout of the purchase_orders table as (name, Spark type, nullable).
# Only contract_id may be null: POs can exist without an associated contract.
_PO_COLUMNS = [
    ("purchase_order_id", LongType, False),
    ("contract_id", LongType, True),
    ("purchase_order_number", StringType, False),
    ("purchase_order_date", LongType, False),
    ("purchase_order_status", StringType, False),
    ("supplier_id", LongType, False),
    ("purchase_order_currency", StringType, False),
    ("purchase_order_amount", DoubleType, False),
    ("total_purchase_order_value", DoubleType, False),
    ("coa_id", LongType, False),
]

# Explicit schema handed to spark.createDataFrame for every generated month.
schema = StructType([
    StructField(column_name, spark_type(), nullable=is_nullable)
    for column_name, spark_type, is_nullable in _PO_COLUMNS
])

# Allowed purchase_order_status values
PO_STATUS = [
    "draft",
    "approved",
    "issued",
    "partially_received",
    "closed",
    "cancelled",
]

def generate_po_amount(target_avg):
    """Draw a random PO amount scaled by ``target_avg``.

    A log-normal multiplier (mu=0, sigma=1.5) is applied to ``target_avg``,
    which yields many small amounts and a long tail of large ones. The result
    is clamped to the range [100.0, target_avg * 100] (the 100.0 floor wins if
    the cap is below it) and rounded to 2 decimals.

    Note: a log-normal with mu=0 has median 1 but mean exp(sigma**2 / 2),
    about 3.08 for sigma=1.5. Before clamping, ``target_avg`` is therefore the
    median of the draws, not their arithmetic mean.

    Args:
        target_avg: Scale of the distribution, in currency units.

    Returns:
        The generated amount as a float.
    """
    draw = random.lognormvariate(0, 1.5) * float(target_avg)
    ceiling = target_avg * 100
    clamped = max(100.0, min(draw, ceiling))
    return float(round(clamped, 2))

def generate_po_status(po_date_ts):
    """Pick a random PO status that is plausible for the PO's age.

    The age is measured in days between ``po_date_ts`` and the current time.
    Older POs are weighted towards later lifecycle stages.

    Args:
        po_date_ts: PO creation time as a Unix timestamp in seconds.

    Returns:
        One of the status strings: "draft", "approved", "issued",
        "partially_received", "closed" or "cancelled".
    """
    seconds_per_day = 24 * 60 * 60
    now_ts = int(datetime.now().timestamp())
    age_in_days = (now_ts - po_date_ts) / seconds_per_day

    # PO dated in the future: it can only be in an early stage (unweighted).
    if age_in_days < 0:
        return random.choice(["draft", "approved"])

    # (exclusive upper age bound in days, candidate statuses, weights)
    brackets = (
        (7, ["draft", "approved", "issued"], [10, 40, 50]),                 # very recent
        (30, ["approved", "issued", "partially_received"], [5, 50, 45]),    # recent
        (90, ["issued", "partially_received", "closed"], [10, 60, 30]),     # not too old
    )
    for upper_bound, candidates, weights in brackets:
        if age_in_days < upper_bound:
            return random.choices(candidates, weights=weights)[0]

    # 90 days or older: mostly closed, occasionally cancelled.
    return random.choices(["partially_received", "closed", "cancelled"], weights=[10, 85, 5])[0]

# Announce the run parameters
print(f"Generating purchase orders from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
print(f"Target: {records_per_month_min:,}-{records_per_month_max:,} records per month")
print(f"Annual PO target: ${scale_target:,.2f}")

# Load the previously generated reference tables (JSON) that POs point at
print("\nReading reference data...")
contracts_df = spark.read.json(f"{directory}/inbound_contracts")
suppliers_df = spark.read.json(f"{directory}/suppliers")
coa_df = spark.read.json(f"{directory}/coa_hierarchy")

# Pull the reference data onto the driver as plain Python structures.
# One dict per contract; age_days is the time elapsed since the contract's
# start date, in days, relative to "now".
contract_columns = [
    "contract_id",
    "supplier_id",
    "contract_start_date",
    "estimated_completion_date",
    "total_contract_value",
]
contracts_data = [
    dict(
        contract_id=row.contract_id,
        supplier_id=row.supplier_id,
        start_date=row.contract_start_date,
        end_date=row.estimated_completion_date,
        total_value=row.total_contract_value,
        age_days=(datetime.now().timestamp() - row.contract_start_date) / 86400,
    )
    for row in contracts_df.select(*contract_columns).collect()
]

supplier_ids = [r["supplier_id"] for r in suppliers_df.select("supplier_id").collect()]
coa_ids = [r["coa_id"] for r in coa_df.select("coa_id").collect()]

print(f"Loaded {len(contracts_data)} contracts")
print(f"Loaded {len(supplier_ids)} suppliers")
print(f"Loaded {len(coa_ids)} COA entries")

# Target average PO amount: the target for the whole window
# (scale_target pro-rated by num_months / 12) divided by the expected number
# of POs (midpoint of the per-month record range, times the number of months).
avg_records_per_month = (records_per_month_min + records_per_month_max) / 2
estimated_total_pos = num_months * avg_records_per_month
target_avg_amount = (scale_target * (num_months / 12)) / estimated_total_pos

print(f"\nTarget average PO amount: ${target_avg_amount:.2f}")

# Per-contract running totals, all starting at zero:
#   contract_tracker  -> PO dollars assigned so far
#   contract_po_count -> PO counter
contract_tracker = dict.fromkeys((c['contract_id'] for c in contracts_data), 0)
contract_po_count = dict.fromkeys((c['contract_id'] for c in contracts_data), 0)

print(f"\nAll {len(contracts_data)} contracts available for PO assignment")

po_id = 50_000_000  # Start with 8-digit ID

# Generate and write POs month by month.
#
# For each calendar month in the window:
#   1. pick a random record count,
#   2. select the contracts that overlap the month and are below 95% utilization,
#   3. try to create ~60% of the POs against those contracts (never exceeding
#      a contract's total value), and fill the remainder with contract-less POs,
#   4. write the month's POs as CSV into its own directory.
current_date = start_date
total_records_generated = 0

# Age of the oldest contract, used to normalize each contract's age into
# [0, 1]. It does not change during generation, so compute it once here
# instead of once per contract / per PO inside the loops.
max_age = max((c['age_days'] for c in contracts_data), default=1)

for month_idx in range(num_months):
    # Month boundaries as inclusive epoch-second timestamps
    month_start = current_date
    month_end = month_start + relativedelta(months=1) - timedelta(seconds=1)

    month_start_ts = int(month_start.timestamp())
    month_end_ts = int(month_end.timestamp())

    # Random number of records for this month
    num_records = random.randint(records_per_month_min, records_per_month_max)

    print(f"Month {month_idx + 1}/{num_months}: {month_start.strftime('%Y-%m')} - {num_records:,} records")

    # Contracts whose period overlaps this month and that are still below
    # 95% of their contract value. The list is fixed for the whole month.
    active_contracts = []
    for c in contracts_data:
        if c['start_date'] > month_end_ts or c['end_date'] < month_start_ts:
            continue

        total_value = c['total_value']
        current_utilization = contract_tracker[c['contract_id']] / total_value if total_value > 0 else 0

        if current_utilization < 0.95:
            active_contracts.append(c)

    # Generate POs
    month_pos = []

    # 60% of POs should have associated contract (from spec)
    num_with_contract = int(num_records * 0.6)

    # POs with contracts. Each iteration is one *attempt*: attempts that are
    # skipped below are made up for later by contract-less POs.
    pos_created_with_contract = 0
    for i in range(num_with_contract):
        if not active_contracts:
            # No contracts available, rest will be without contract
            print(f"  Note: Ran out of available contracts, creating {num_with_contract - i} additional POs without contracts")
            break

        # Select contract
        contract = random.choice(active_contracts)
        contract_id = contract['contract_id']
        supplier_id = contract['supplier_id']

        # PO date must fall within both the contract period AND the month
        contract_start = max(contract['start_date'], month_start_ts)
        contract_end = min(contract['end_date'], month_end_ts)

        if contract_end <= contract_start:
            continue

        contract_duration = contract_end - contract_start
        random_offset = random.randint(0, max(1, int(contract_duration)))
        # purchase_order_date is declared LongType in the schema, so make sure
        # the value is an int even if the contract dates were loaded as floats.
        po_date = int(contract_start + random_offset)

        # Remaining headroom on this contract
        current_total = contract_tracker[contract_id]
        remaining = contract['total_value'] - current_total

        if remaining <= 0:
            # Contract fully utilized, skip
            continue

        # Age-based utilization target: grows linearly from 50% for the
        # newest contracts to 95% for the oldest one.
        contract_age_days = contract['age_days']

        if max_age > 0 and contract_age_days >= 0:
            age_factor = min(contract_age_days / max_age, 1.0)
            utilization_target = 0.5 + (age_factor * 0.45)
        else:
            utilization_target = 0.75

        max_allowed = contract['total_value'] * utilization_target

        # Respect the utilization target, but allow a 20% chance of going
        # beyond it (never beyond the contract value, see min() below).
        if not (current_total < max_allowed or random.random() < 0.2):
            # This contract has reached its target utilization for now
            continue

        # Generate an amount that doesn't exceed the remaining headroom
        base_amount = generate_po_amount(target_avg_amount)
        po_amount = float(min(base_amount, remaining))
        contract_tracker[contract_id] += po_amount

        # Generate unique PO number (po_id is already unique)
        po_number = f"PO{po_id}"

        # PO status based on date
        po_status = generate_po_status(po_date)

        coa_id = random.choice(coa_ids)

        month_pos.append({
            "purchase_order_id": po_id,
            "contract_id": contract_id,
            "purchase_order_number": po_number,
            "purchase_order_date": po_date,
            "purchase_order_status": po_status,
            "supplier_id": supplier_id,
            "purchase_order_currency": "USD",
            "purchase_order_amount": po_amount,
            "total_purchase_order_value": po_amount,  # Same as amount
            "coa_id": coa_id
        })

        po_id += 1
        pos_created_with_contract += 1

    # Everything not created against a contract (the planned 40% plus any
    # shortfall from skipped attempts) becomes a contract-less PO.
    actual_without_contract = num_records - pos_created_with_contract

    # POs without contracts
    for i in range(actual_without_contract):
        random_offset = random.randint(0, int(month_end_ts - month_start_ts))
        po_date = month_start_ts + random_offset

        supplier_id = random.choice(supplier_ids)
        coa_id = random.choice(coa_ids)
        po_amount = generate_po_amount(target_avg_amount)

        # Generate unique PO number (po_id is already unique)
        po_number = f"PO{po_id}"

        po_status = generate_po_status(po_date)

        month_pos.append({
            "purchase_order_id": po_id,
            "contract_id": None,
            "purchase_order_number": po_number,
            "purchase_order_date": po_date,
            "purchase_order_status": po_status,
            "supplier_id": supplier_id,
            "purchase_order_currency": "USD",
            "purchase_order_amount": po_amount,
            "total_purchase_order_value": po_amount,
            "coa_id": coa_id
        })

        po_id += 1

    # Create DataFrame for this month
    if month_pos:
        month_df = spark.createDataFrame(month_pos, schema=schema)

        # Write this month's data into its own directory
        year_str = month_start.strftime('%Y')
        month_str = month_start.strftime('%m')
        month_filename = f"purchase_orders_{year_str}_{month_str}"
        month_output_path = f"{output_path}/{month_filename}"

        print(f"  Writing to {month_filename}/ ({len(month_pos):,} records)")

        # coalesce(1) so each month directory contains a single CSV part file
        month_df.coalesce(1).write.mode("overwrite").option("header", "true").csv(month_output_path)

        total_records_generated += len(month_pos)

    # Move to next month
    current_date = current_date + relativedelta(months=1)

print(f"\n\nTotal purchase orders generated: {total_records_generated:,}")

# Read back the data for statistics.
# The CSV reader returns every column as a string, so numeric columns are
# cast explicitly wherever they are used below.
print("\nReading generated data for statistics...")
df = spark.read.option("header", "true").csv(f"{output_path}/*/*.csv")

# Convert columns to proper types
df = df.withColumn("purchase_order_amount", col("purchase_order_amount").cast("double"))

# Show sample
print("\nSample of generated data:")
df.show(20, truncate=False)

# Statistics (count once and reuse; each count() is a full scan of the CSVs)
total_count = df.count()
print(f"\nTotal records: {total_count:,}")
print(f"Unique PO IDs: {df.select('purchase_order_id').distinct().count():,}")

# Contract statistics
contract_count = df.filter(col('contract_id').isNotNull()).count()
if total_count > 0:
    contract_percentage = (contract_count / total_count) * 100
    print(f"\nRecords with contract: {contract_count:,} ({contract_percentage:.1f}%)")
    print(f"Records without contract: {total_count - contract_count:,} ({100 - contract_percentage:.1f}%)")

# Status distribution
print("\nPO status distribution:")
df.groupBy("purchase_order_status").count().orderBy("purchase_order_status").show()

# PO value by year
print("\nPO value by year:")
df_with_year = df.withColumn(
    "year",
    year(from_unixtime(col("purchase_order_date").cast("long")))
)
df_with_year.groupBy("year").agg(_sum("purchase_order_amount").alias("total_po_value")).orderBy("year").show()

# Validate contract constraints
print("\nValidating contract constraints (PO amounts should not exceed contract value)...")

contracts_df_ref = spark.read.json(f"{directory}/inbound_contracts") \
    .select(
        col("contract_id"),
        col("total_contract_value"),
        col("contract_start_date")
    )

# Aggregate PO amounts by contract
contract_totals = df.filter(col('contract_id').isNotNull()) \
    .withColumn("contract_id_long", col("contract_id").cast("long")) \
    .groupBy('contract_id_long') \
    .agg(_sum('purchase_order_amount').alias('total_po_amount'))

# Join with contract data and derive utilization, limit check and contract age
validation = contract_totals.join(
    contracts_df_ref,
    contract_totals.contract_id_long == contracts_df_ref.contract_id,
    "inner"
).withColumn('utilization_pct', (col('total_po_amount') / col('total_contract_value')) * 100) \
 .withColumn('exceeds_limit', col('total_po_amount') > col('total_contract_value')) \
 .withColumn('contract_age_days', (datetime.now().timestamp() - col('contract_start_date')) / 86400)

print("\nContract utilization (top 10 by PO amount):")
validation.select(
    'contract_id',
    'total_po_amount',
    'total_contract_value',
    'utilization_pct',
    'contract_age_days',
    'exceeds_limit'
).orderBy(col('total_po_amount').desc()).show(10)

# Count violations
violations = validation.filter(col('exceeds_limit')).count()
total_contracts_with_pos = validation.count()

if violations > 0:
    print(f"\n⚠ WARNING: {violations} out of {total_contracts_with_pos} contracts exceeded their contract value!")
else:
    print(f"\n✓ All {total_contracts_with_pos} contracts are within their contract value limits")

# Show utilization distribution
print("\nContract utilization distribution:")
validation.selectExpr(
    "CASE " +
    "WHEN utilization_pct < 25 THEN '0-25%' " +
    "WHEN utilization_pct >= 25 AND utilization_pct < 50 THEN '25-50%' " +
    "WHEN utilization_pct >= 50 AND utilization_pct < 75 THEN '50-75%' " +
    "WHEN utilization_pct >= 75 AND utilization_pct < 100 THEN '75-100%' " +
    "ELSE 'Above 100%' END as utilization_range"
).groupBy("utilization_range").count().orderBy("utilization_range").show()

# Check age-based utilization.
# The aggregates are passed as explicit columns: a dict literal such as
# {"utilization_pct": "avg", "utilization_pct": "min", "utilization_pct": "max"}
# keeps only the last duplicate key, so only the max would be computed.
# Aliased with a leading underscore (like _sum) to avoid shadowing builtins.
from pyspark.sql.functions import avg as _avg, max as _max, min as _min

print("\nAge-based utilization check (newer contracts should have lower utilization):")
validation.selectExpr(
    "CASE " +
    "WHEN contract_age_days < 180 THEN 'Very New (0-6mo)' " +
    "WHEN contract_age_days >= 180 AND contract_age_days < 365 THEN 'New (6-12mo)' " +
    "WHEN contract_age_days >= 365 AND contract_age_days < 730 THEN 'Older (1-2yr)' " +
    "ELSE 'Very Old (2yr+)' END as age_range",
    "utilization_pct"
).groupBy("age_range").agg(
    _avg("utilization_pct").alias("avg_utilization_pct"),
    _min("utilization_pct").alias("min_utilization_pct"),
    _max("utilization_pct").alias("max_utilization_pct"),
).orderBy("age_range").show()

print("\nData generation complete!")
print(f"\nFiles are located in: {output_path}/")
print("Each month has its own directory with a CSV file inside")
