In [0]:
from pyspark.sql.types import StructType, StructField, LongType, StringType
from pyspark.sql.functions import col, monotonically_increasing_id, row_number, lit
from pyspark.sql.window import Window

# -------------------- KEY VARIABLES -------------------- #
# Catalog segment interpolated into the /Volumes path below.
catalog = "main"
# Root folder holding the data-generation outputs (third_parties is read from
# here and suppliers is written here).
directory = f"/Volumes/{catalog}/finance_lakehouse/data_gen_outputs"
# Destination folder for the generated suppliers dataset.
output_path = f"{directory}/suppliers"

# How many supplier rows this notebook aims to produce.
num_records = 400

In [0]:
"""
Generate synthetic data for the suppliers table.
This script uses PySpark and maintains referential integrity with the third_parties table.
"""


# Intended layout of the suppliers dataset: three mandatory (non-nullable) columns.
# NOTE(review): nothing in the visible cells applies `schema` to a DataFrame;
# as written it only documents the expected output shape.
schema = StructType([
    StructField(field_name, field_type, nullable=False)
    for field_name, field_type in (
        ("supplier_id", LongType()),
        ("third_party_id", LongType()),
        ("supplier_name", StringType()),
    )
])

# Load the third_parties dataset. Supplier rows are drawn from it, so every
# supplier's third_party_id refers to an existing third party.
third_parties_path = f"{directory}/third_parties"

print(f"Reading third parties data from {third_parties_path}...")
third_parties_df = spark.read.format("json").load(third_parties_path)

# count() is an action: it runs a Spark job over the input to size the pool.
total_third_parties = third_parties_df.count()
print(f"Found {total_third_parties} third parties")

print(f"Generating {num_records} records for suppliers...")

# Some third parties act as both customer and supplier.
# Per the original notes, customers use the first 300 third_party_ids.
# NOTE(review): the customers generator is not visible here -- keep this
# value in sync with it.
customer_count = 300

# Up to 100 suppliers reuse customer third parties; the rest are supplier-only
# (300 with the default num_records of 400). Clamping keeps the unique count
# from going negative when num_records is smaller than 100.
overlap_count = min(100, num_records)
unique_suppliers_count = num_records - overlap_count

# Build the ordered view once and reuse it for every slice below.
# assumes third_party_id is unique, so that orderBy + limit selects the same
# rows each time it is evaluated -- TODO confirm
ordered_third_parties = third_parties_df.orderBy(col("third_party_id"))

# Supplier-only rows are taken from the ranks after the customer range, so the
# source must contain at least customer_count + unique_suppliers_count rows.
# Warn instead of silently producing fewer suppliers than requested.
required_third_parties = customer_count + unique_suppliers_count
if total_third_parties < required_third_parties:
    print(
        f"WARNING: need {required_third_parties} third parties to generate "
        f"{num_records} suppliers but only {total_third_parties} are available; "
        "fewer supplier records will be produced."
    )

# Overlap: the lowest `overlap_count` third_party_ids (inside the customer range).
customers_third_parties = ordered_third_parties.limit(overlap_count)

# Supplier-only: the first `required_third_parties` rows minus the customer range.
# subtract() compares entire rows and also de-duplicates its result.
unique_suppliers_third_parties = ordered_third_parties \
    .limit(required_third_parties) \
    .subtract(ordered_third_parties.limit(customer_count))

# Combine overlapping customers and supplier-only third parties.
sampled_df = customers_third_parties.union(unique_suppliers_third_parties)

# Assign supplier_id values 30000000, 30000001, ... in ascending
# third_party_id order.
#
# The numbering is ordered by third_party_id rather than
# monotonically_increasing_id(): the latter depends on the partition layout
# left by the preceding union/subtract, and `df` is re-evaluated by every
# later action (show, counts, write), so the supplier_id <-> third_party_id
# pairing could differ between the preview and the written file.
# assumes third_party_id is unique within sampled_df -- TODO confirm
#
# A window without partitionBy processes all rows in a single partition,
# which is acceptable for the few hundred rows generated here.
supplier_id_offset = 29999999  # row_number() starts at 1, so the first id is 30000000
window_spec = Window.orderBy(col("third_party_id"))
df = sampled_df.withColumn("row_num", row_number().over(window_spec)) \
    .withColumn(
        "supplier_id",
        # row_number() and the literal are both ints; cast so the column
        # matches the LongType declared for supplier_id in `schema`.
        (col("row_num") + lit(supplier_id_offset)).cast(LongType())
    ) \
    .select(
        col("supplier_id"),
        col("third_party_id"),
        col("third_party_name").alias("supplier_name")
    )

# Preview the first ten rows without truncating long values.
print("\nSample of generated data:")
df.show(n=10, truncate=False)

# Sanity figures: the total row count followed by per-column distinct counts.
# Each count() is a separate Spark action.
record_total = df.count()
print(f"\nTotal records: {record_total}")

supplier_id_cardinality = df.select("supplier_id").distinct().count()
print(f"Unique supplier IDs: {supplier_id_cardinality}")

third_party_id_cardinality = df.select("third_party_id").distinct().count()
print(f"Unique third party IDs: {third_party_id_cardinality}")

supplier_name_cardinality = df.select("supplier_name").distinct().count()
print(f"Unique supplier names: {supplier_name_cardinality}")

# Persist the suppliers as JSON, replacing any earlier output at output_path.
# coalesce(1) reduces the frame to a single partition before writing.
print(f"\nWriting data to {output_path}...")
(
    df.coalesce(1)
    .write
    .json(output_path, mode="overwrite")
)

print("Data generation complete!")
