In [0]:
from pyspark.sql.types import StructType, StructField, LongType, StringType, DoubleType
from pyspark.sql.functions import col, sum as _sum, year, from_unixtime
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import random


# -------------------- KEY VARIABLES -------------------- #
# Base location under /Volumes: the reference inputs are read from `directory`
# and the generated files are written below `output_path`.
catalog = 'main'
directory = f"/Volumes/{catalog}/finance_lakehouse/data_gen_outputs"
output_path = f"{directory}/revenue_transactions"

# Seed for Python's `random` module. All record counts, dates, categories and
# amounts are drawn from `random`, so seeding it here makes a
# "Restart & Run All" produce the same random draws on every run.
random_seed = 42
random.seed(random_seed)

# Number of months to generate (from 2023 to present)
start_year = 2023
start_date = datetime(start_year, 1, 1)
end_date = datetime.now()
# Inclusive count of calendar months from start_date's month to end_date's month.
num_months = (end_date.year - start_year) * 12 + end_date.month - start_date.month + 1

# Records per month range
records_per_month_min = 50000
records_per_month_max = 100000

# Annual aggregate target: 2 Billion dollars
annual_target = 2_000_000_000


In [0]:
"""
Generate synthetic data for the revenue_transactions table.
Rows are built month by month in plain Python on the driver; PySpark is used to
read the reference data and to write each month's rows to its own CSV directory.
Output is organised by transaction month, targeting 50K-100K records per month.
Generates revenue transactions starting from 2023.
"""


# Define schema
# Column layout handed to spark.createDataFrame for every generated month.
# Only contract_id is nullable: rows without a contract store None there.
schema = (
    StructType()
    .add("rev_trx_id", LongType(), nullable=False)
    .add("contract_id", LongType(), nullable=True)
    .add("customer_id", LongType(), nullable=False)
    .add("transaction_date", LongType(), nullable=False)
    .add("coa_id", LongType(), nullable=False)
    .add("cost_category", StringType(), nullable=False)
    .add("amount", DoubleType(), nullable=False)
)

# Enum values from schema
COST_CATEGORIES = [
    "labor",
    "compute",
    "materials",
    "subcontractor",
    "equipment",
    "overhead",
]

def generate_transaction_amount(cost_category, target_avg):
    """Draw one random transaction amount for the given cost category.

    A lognormal multiplier (mu=0, sigma=1.5) is applied to ``target_avg`` and
    the product is clamped to the category's allowed range, then rounded to
    two decimals.

    Args:
        cost_category: Category name; names not listed below fall back to the
            range (100.0, 10000.0).
        target_avg: Scale applied to the lognormal draw (converted with
            ``float``).

    Returns:
        The clamped amount as a ``float`` rounded to 2 decimal places.
    """
    bounds_by_category = {
        "labor": (50.0, 5000.0),
        "compute": (50.0, 5000.0),
        "materials": (100.0, 50000.0),
        "subcontractor": (1000.0, 100000.0),
        "equipment": (500.0, 25000.0),
        "overhead": (100.0, 10000.0),
    }
    fallback_bounds = (100.0, 10000.0)
    lower, upper = bounds_by_category.get(cost_category, fallback_bounds)

    # Lognormal draw: lots of variety, with most of the mass at smaller values.
    raw_amount = random.lognormvariate(0, 1.5) * float(target_avg)

    # Clamp into [lower, upper]: cap from above first, then lift to the floor.
    capped = min(raw_amount, upper)
    clamped = max(lower, capped)
    return float(round(clamped, 2))

# Generate data
print(f"Generating revenue transactions data from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
print(f"Target: {records_per_month_min:,}-{records_per_month_max:,} records per month")
print(f"Annual revenue target: ${annual_target:,.2f}")

# Read reference data
print("\nReading reference data...")
contracts_df = spark.read.json(f"{directory}/outbound_contracts")
customers_df = spark.read.json(f"{directory}/customers")
coa_df = spark.read.json(f"{directory}/coa_hierarchy")

# Sample "now" once so that every contract's age is measured against the same
# instant instead of a clock that advances while the rows are being converted.
now_ts = datetime.now().timestamp()

# Collect reference data
# The contract dates are later compared with epoch-second month boundaries,
# so dividing the difference by 86400 expresses the age in days.
contracts_data = [
    {
        'contract_id': row.contract_id,
        'customer_id': row.customer_id,
        'start_date': row.contract_start_date,
        'end_date': row.estimated_completion_date,
        'total_value': row.total_contract_value,
        'age_days': (now_ts - row.contract_start_date) / 86400
    }
    for row in contracts_df.select(
        "contract_id",
        "customer_id",
        "contract_start_date",
        "estimated_completion_date",
        "total_contract_value"
    ).collect()
]

customer_ids = [row.customer_id for row in customers_df.select("customer_id").collect()]
coa_ids = [row.coa_id for row in coa_df.select("coa_id").collect()]

print(f"Loaded {len(contracts_data)} contracts")
print(f"Loaded {len(customer_ids)} customers")
print(f"Loaded {len(coa_ids)} COA entries")

# Every generated row needs a customer and a COA id picked with random.choice,
# which raises an opaque IndexError on an empty list. Fail here, before the
# generation loop starts, with a message that names the missing input.
# (An empty contract list is fine: only contract-free rows are produced then.)
if not customer_ids:
    raise ValueError(f"No customers found in {directory}/customers")
if not coa_ids:
    raise ValueError(f"No COA entries found in {directory}/coa_hierarchy")

# Calculate target average transaction amount:
# revenue expected over the generated period divided by the expected row count
# (the midpoint of the per-month record range times the number of months).
avg_records_per_month = (records_per_month_min + records_per_month_max) / 2
estimated_total_transactions = num_months * avg_records_per_month
revenue_over_period = annual_target * (num_months / 12)
target_avg_amount = revenue_over_period / estimated_total_transactions

print(f"\nTarget average transaction amount: ${target_avg_amount:.2f}")

# Running revenue total per contract id, shared across all months; starts at 0
contract_tracker = dict.fromkeys((c['contract_id'] for c in contracts_data), 0)

# Only a random 60% of the contracts receive revenue transactions
revenue_contract_count = int(len(contracts_data) * 0.6)
contracts_with_revenue = random.sample(contracts_data, revenue_contract_count)
print(f"60% of contracts ({len(contracts_with_revenue)}) will have revenue transactions")

# First transaction id handed out (8 digits); incremented once per emitted row
rev_trx_id = 50000000

# State carried through the month-by-month generation loop
current_date = start_date
total_records_generated = 0

# Age of the oldest revenue-bearing contract. It is constant for the whole run,
# so it is computed once here instead of once per generated transaction.
# default=0 covers the case where no contract was selected for revenue.
max_age = max((c['age_days'] for c in contracts_with_revenue), default=0)

for month_idx in range(num_months):
    # Calculate month boundaries (inclusive, one-second resolution)
    month_start = current_date
    month_end = month_start + relativedelta(months=1) - timedelta(seconds=1)
    # The last month is normally still in progress: stop at end_date so that
    # no transaction is dated later than the end of the generated period.
    month_end = min(month_end, end_date)

    month_start_ts = int(month_start.timestamp())
    month_end_ts = int(month_end.timestamp())

    # Random number of records for this month
    num_records = random.randint(records_per_month_min, records_per_month_max)

    print(f"Month {month_idx + 1}/{num_months}: {month_start.strftime('%Y-%m')} - {num_records:,} records")

    # Filter contracts active in this month (contract period overlaps the month)
    active_contracts = [
        c for c in contracts_with_revenue
        if c['start_date'] <= month_end_ts and c['end_date'] >= month_start_ts
    ]

    # Generate transactions
    month_transactions = []

    # 80% of records should have associated contract. This is an upper bound:
    # draws that hit an exhausted contract or an empty date window are skipped,
    # and nothing is generated here when no contract is active.
    num_with_contract = int(num_records * 0.8)
    num_without_contract = num_records - num_with_contract

    # Transactions with contracts
    for _ in range(num_with_contract):
        if not active_contracts:
            break

        contract = random.choice(active_contracts)
        contract_id = contract['contract_id']
        customer_id = contract['customer_id']

        # Transaction date within contract period AND month
        window_start = max(contract['start_date'], month_start_ts)
        window_end = min(contract['end_date'], month_end_ts)

        if window_end <= window_start:
            continue

        transaction_date = int(window_start) + random.randint(0, int(window_end - window_start))

        # Utilization cap based on contract age:
        # newest contracts may bill up to 50% of their value, the oldest 100%.
        if max_age > 0:
            # Clamp so a contract with a negative age cannot get a cap below 50%.
            age_factor = min(1.0, max(0.0, contract['age_days'] / max_age))
            utilization_target = 0.5 + (age_factor * 0.5)  # 50% to 100% utilization
        else:
            utilization_target = 0.75

        max_allowed = contract['total_value'] * utilization_target

        # Room left under the cap; the tracker holds what was already emitted.
        remaining = float(max_allowed - contract_tracker[contract_id])
        if remaining <= 0:
            continue

        # Generate amount
        coa_id = random.choice(coa_ids)
        cost_category = random.choice(COST_CATEGORIES)
        base_amount = generate_transaction_amount(cost_category, target_avg_amount)

        # Floor the remaining room to whole cents, so that the stored (2-decimal)
        # amount can never push the contract above its cap, then track exactly
        # the value that is written to the row.
        remaining_floor = int(remaining * 100) / 100.0
        amount = float(round(min(base_amount, remaining_floor), 2))

        if amount <= 0:
            continue

        contract_tracker[contract_id] += amount

        month_transactions.append({
            "rev_trx_id": rev_trx_id,
            "contract_id": contract_id,
            "customer_id": customer_id,
            "transaction_date": transaction_date,
            "coa_id": coa_id,
            "cost_category": cost_category,
            "amount": amount
        })

        rev_trx_id += 1

    # Transactions without contracts
    for _ in range(num_without_contract):
        random_offset = random.randint(0, int(month_end_ts - month_start_ts))
        transaction_date = month_start_ts + random_offset

        customer_id = random.choice(customer_ids)
        coa_id = random.choice(coa_ids)
        cost_category = random.choice(COST_CATEGORIES)
        amount = generate_transaction_amount(cost_category, target_avg_amount)

        month_transactions.append({
            "rev_trx_id": rev_trx_id,
            "contract_id": None,
            "customer_id": customer_id,
            "transaction_date": transaction_date,
            "coa_id": coa_id,
            "cost_category": cost_category,
            "amount": float(round(amount, 2))
        })

        rev_trx_id += 1

    # Create DataFrame for this month
    if month_transactions:
        month_df = spark.createDataFrame(month_transactions, schema=schema)

        # Write this month's data
        year_str = month_start.strftime('%Y')
        month_str = month_start.strftime('%m')
        month_filename = f"revenue_transactions_{year_str}_{month_str}"
        month_output_path = f"{output_path}/{month_filename}"

        print(f"  Writing to {month_filename}/ ({len(month_transactions):,} records)")

        # coalesce(1) -> a single CSV part file inside the month's directory
        month_df.coalesce(1).write.mode("overwrite").option("header", "true").csv(month_output_path)

        total_records_generated += len(month_transactions)

    # Move to next month
    current_date = current_date + relativedelta(months=1)

print(f"\n\nTotal transactions generated: {total_records_generated:,}")

# Read back the data for statistics
print("\nReading generated data for statistics...")
# Apply the schema the files were written with. Without it every CSV column is
# read as a string, and the sums, from_unixtime and the contract join further
# down would depend on implicit string-to-number casts.
df = spark.read.option("header", "true").schema(schema).csv(f"{output_path}/*/*.csv")

# Show sample
print("\nSample of generated data:")
df.show(20, truncate=False)

# Statistics
# Count once and reuse: each count() is a separate pass over the files.
total_count = df.count()
print(f"\nTotal records: {total_count:,}")
print(f"Unique transaction IDs: {df.select('rev_trx_id').distinct().count():,}")

# Contract statistics
contract_count = df.filter(col('contract_id').isNotNull()).count()
if total_count > 0:
    contract_percentage = (contract_count / total_count) * 100
    print(f"\nRecords with contract: {contract_count:,} ({contract_percentage:.1f}%)")
    print(f"Records without contract: {total_count - contract_count:,} ({100 - contract_percentage:.1f}%)")

# Category distribution
print("\nCost category distribution:")
df.groupBy("cost_category").count().orderBy("cost_category").show()

# Revenue by year (transaction_date holds epoch seconds)
print("\nRevenue by year:")
df_with_year = df.withColumn("year", year(from_unixtime(col("transaction_date"))))
df_with_year.groupBy("year").agg(_sum("amount").alias("total_revenue")).orderBy("year").show()

# Validate contract constraints
print("\nValidating contract constraints (revenue should not exceed contract value)...")

# Contract values are re-read from the source files so the check does not
# depend on the in-memory state used during generation.
contracts_df_ref = spark.read.json(f"{directory}/outbound_contracts") \
    .select(
        col("contract_id"),
        col("total_contract_value")
    )

# Aggregate revenue by contract (rows without a contract are excluded)
contract_totals = df.filter(col('contract_id').isNotNull()) \
    .groupBy('contract_id') \
    .agg(_sum('amount').alias('total_revenue'))

# Join with contract data and calculate percentage
validation = contract_totals.join(contracts_df_ref, "contract_id", "inner") \
    .withColumn('utilization_pct', (col('total_revenue') / col('total_contract_value')) * 100) \
    .withColumn('exceeds_limit', col('total_revenue') > col('total_contract_value'))

print("\nContract utilization (top 10 by revenue):")
validation.select(
    'contract_id',
    'total_revenue',
    'total_contract_value',
    'utilization_pct',
    'exceeds_limit'
).orderBy(col('total_revenue').desc()).show(10)

# Count violations (exceeds_limit is already a boolean column)
violations = validation.filter(col('exceeds_limit')).count()
total_contracts_with_revenue = validation.count()

if violations > 0:
    print(f"\n⚠ WARNING: {violations} out of {total_contracts_with_revenue} contracts exceeded their contract value!")
else:
    print(f"\n✓ All {total_contracts_with_revenue} contracts are within their contract value limits")

# Show utilization distribution
# The top regular bucket includes exactly 100%: a contract billed up to its
# full value has not exceeded it and must not be reported as 'Above 100%'.
print("\nContract utilization distribution:")
validation.selectExpr(
    "CASE " +
    "WHEN utilization_pct < 25 THEN '0-25%' " +
    "WHEN utilization_pct >= 25 AND utilization_pct < 50 THEN '25-50%' " +
    "WHEN utilization_pct >= 50 AND utilization_pct < 75 THEN '50-75%' " +
    "WHEN utilization_pct >= 75 AND utilization_pct <= 100 THEN '75-100%' " +
    "ELSE 'Above 100%' END as utilization_range"
).groupBy("utilization_range").count().orderBy("utilization_range").show()

print("\nData generation complete!")
print(f"\nFiles are located in: {output_path}/")
print("Each month has its own directory with a CSV file inside")
