In [0]:
import random

from pyspark.sql.functions import array, col, element_at, lit, monotonically_increasing_id, row_number
from pyspark.sql.types import LongType, StringType, StructField, StructType
from pyspark.sql.window import Window

# -------------------- KEY VARIABLES -------------------- #
# How many customer rows this notebook should produce
num_records = 300

# Storage locations, built up from the catalog name:
#   catalog -> shared data-generation directory -> customers output folder
catalog = 'fin_demo'
directory = f"/Volumes/{catalog}/fin/data_gen_outputs"
output_path = f"{directory}/customers"


In [0]:
"""
Synthetic data generator for the customers table.
Built on PySpark; customers are derived from the third_parties table so that
referential integrity between the two is preserved.
"""

# Fixed seed so that any draws from Python's `random` module are repeatable
random.seed(42)

# Allowed values for customer_industry (taken from the YAML enum)
customer_industries = [
    "Manufacturing", "Technology", "Finance", "Builders",
    "Energy", "Real Estate", "Media & Telecom",
]

# Output schema: every column is declared non-nullable
schema = StructType([
    StructField(field_name, field_type, nullable=False)
    for field_name, field_type in (
        ("customer_id", LongType()),
        ("third_party_id", LongType()),
        ("customer_name", StringType()),
        ("customer_industry", StringType()),
    )
])

# Generate data
print(f"Generating {num_records} records for customers...")

# customer_id values are issued sequentially starting from this number
first_customer_id = 20000000

# Read third_parties data to maintain referential integrity
third_parties_path = f"{directory}/third_parties"

print(f"Reading third parties data from {third_parties_path}...")
third_parties_df = spark.read.json(third_parties_path)

# Each customer must link to a unique third_party_id, so drop repeated IDs up
# front and fail fast if there are not enough distinct third parties. Without
# this check `limit` would silently yield fewer customers than requested.
unique_third_parties_df = third_parties_df.dropDuplicates(["third_party_id"])
available_third_parties = unique_third_parties_df.count()
if available_third_parties < num_records:
    raise ValueError(
        f"Need {num_records} distinct third_party_id values but "
        f"{third_parties_path} only provides {available_third_parties}; "
        "lower num_records or regenerate the third_parties data."
    )

# Take exactly the number needed. Ordering by third_party_id (rather than
# relying on physical read order) makes the selection repeatable across runs
# and across the several Spark actions that re-evaluate this DataFrame.
sampled_df = unique_third_parties_df.orderBy("third_party_id").limit(num_records)

# Number the selected rows 1..num_records in the same deterministic order.
# NOTE: a window without partitionBy pulls all rows into one partition, which
# is acceptable for the small volumes generated here.
window_spec = Window.orderBy("third_party_id")

# Round-robin industry assignment done inside Spark: row 1 gets the first
# industry, row 2 the second, ... wrapping around. element_at is 1-based.
industry_lookup = array(*[lit(industry) for industry in customer_industries])
industry_position = (col("row_num") - lit(1)) % lit(len(customer_industries)) + lit(1)

df = (
    sampled_df
    .withColumn("row_num", row_number().over(window_spec))
    .select(
        (col("row_num") + lit(first_customer_id - 1)).alias("customer_id"),
        col("third_party_id"),
        col("third_party_name").alias("customer_name"),
        element_at(industry_lookup, industry_position).alias("customer_industry"),
    )
    # Cast to the declared column types (row_number yields an int, the schema
    # wants a long) and keep the schema's column order.
    .select([
        col(field.name).cast(field.dataType).alias(field.name)
        for field in schema.fields
    ])
)

# Preview the first rows without truncating long values
print("\nSample of generated data:")
df.show(10, truncate=False)

# Summary counts: the distinct counts can be compared against the total
# to spot repeated IDs or names
total_records = df.count()
print(f"\nTotal records: {total_records}")
unique_customer_ids = df.select('customer_id').distinct().count()
print(f"Unique customer IDs: {unique_customer_ids}")
unique_third_party_ids = df.select('third_party_id').distinct().count()
print(f"Unique third party IDs: {unique_third_party_ids}")
unique_customer_names = df.select('customer_name').distinct().count()
print(f"Unique customer names: {unique_customer_names}")

# Persist as JSON: collapse to a single partition and replace any previous output
print(f"\nWriting data to {output_path}...")
json_writer = df.coalesce(1).write.mode("overwrite")
json_writer.json(output_path)

print("Data generation complete!")
