# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, lit, floor, rand, when, from_unixtime
from pyspark.sql.types import StructType, StructField, LongType, StringType, TimestampType, DoubleType
import time

# Build the Spark session for this job (getOrCreate reuses an active one).
# NOTE: the shuffle-partition count is hard-coded to "400", the same value as
# TARGET_PARTITIONS below -- keep the two in sync when tuning.
spark = (
    SparkSession.builder
    .appName("generate_large_orders")
    .config("spark.sql.shuffle.partitions", "400")
    .getOrCreate()
)

# Parameters you can change
NUM_ROWS = 100_000_000  # 100 million rows in total
TARGET_PARTITIONS = 400  # roughly num_cores * 3; also used as the write fan-out
# NOTE: spark.sql.shuffle.partitions is hard-coded to "400" in the session
# config; keep it in sync with TARGET_PARTITIONS.
OUTPUT_PATH = "/data/test/orders_parquet"

# 1) Create a base range DF - this is very fast and parallel.
# spark.range yields a single long column `id` with the values 0..NUM_ROWS-1,
# split across TARGET_PARTITIONS partitions; it is renamed to `idx` so every
# derived column below can be computed from it.
base = spark.range(0, NUM_ROWS, numPartitions=TARGET_PARTITIONS).toDF('idx')

# 2) Derive columns from idx (deterministic, reproducible).
# Every column except `price` is a pure function of idx, so columns are
# correlated by construction: e.g. product_id == (user_id - 1) % 1_000_000 + 1
# (each user always gets the same product), and 'card' payments (idx % 6 == 0)
# are always 'USD' (idx % 3 == 0). `price` uses a seeded rand(), which is
# reproducible as long as the partitioning of `base` does not change.
orders = (
    base
    .withColumn('order_id', col('idx') + 1)
    .withColumn('user_id', (col('idx') % 10_000_000) + 1)  # 10M unique users
    .withColumn('product_id', (col('idx') % 1_000_000) + 1)  # 1M products
    # 0.00 .. 99.99 in steps of 0.01
    .withColumn('price', (floor(rand(seed=42) * 10000) / 100).cast('double'))
    .withColumn(
        'currency',
        when((col('idx') % 3) == 0, lit('USD'))
        .when((col('idx') % 3) == 1, lit('EUR'))
        .otherwise(lit('INR')),
    )
    # US / IN / DE get 2% of rows each; the remaining 94% are GB.
    .withColumn('country', expr("CASE WHEN idx % 50 = 0 THEN 'US' WHEN idx % 50 = 1 THEN 'IN' WHEN idx % 50 = 2 THEN 'DE' ELSE 'GB' END"))
    # card / paypal / upi get 1/6 of rows each; netbanking gets the other half.
    .withColumn('payment_method', expr("CASE WHEN idx % 6 = 0 THEN 'card' WHEN idx % 6 = 1 THEN 'paypal' WHEN idx % 6 = 2 THEN 'upi' ELSE 'netbanking' END"))
    # Spread over 2021: 1609459200 is 2021-01-01T00:00:00Z and
    # 31536000 == 86400 * 365 seconds. timestamp_seconds (Spark 3.1+) builds the
    # instant straight from epoch seconds; the previous
    # from_unixtime(...).cast('timestamp') formatted each value to a local-time
    # string and re-parsed it, which folds the ambiguous DST fall-back hour of
    # the session time zone onto a single offset (and costs a string round-trip
    # per row).
    .withColumn('event_time', expr('timestamp_seconds(idx % 31536000 + 1609459200)'))
    .select('order_id', 'user_id', 'event_time', 'product_id', 'price', 'currency', 'country', 'payment_method')
)

# 3) Optional: introduce skew by giving a small range of "heavy" users extra
# orders. Disabled by default (SKEW_COPIES = 0), so the written dataset is the
# plain one from step 2. With SKEW_COPIES = k, every order placed by users
# 1..HEAVY_USER_MAX is appended k more times; copy c gets
# order_id + c * NUM_ROWS, so order_id stays unique (originals are 1..NUM_ROWS).
HEAVY_USER_MAX = 1000
SKEW_COPIES = 0

heavy_users = orders.filter((col('user_id') >= 1) & (col('user_id') <= HEAVY_USER_MAX))
if SKEW_COPIES > 0:
    copy_numbers = spark.range(1, SKEW_COPIES + 1).toDF('copy_no')
    skewed_extra = (
        heavy_users
        .crossJoin(copy_numbers)
        .withColumn('order_id', col('order_id') + col('copy_no') * NUM_ROWS)
        .drop('copy_no')
    )
    orders = orders.unionByName(skewed_extra)
# Write the full orders dataset, one sub-directory per country.
# Hash-partition on the unique order_id instead of on country: country has only
# 4 distinct values (~94% 'GB'), so repartition(TARGET_PARTITIONS, col('country'))
# left at most 4 shuffle partitions non-empty and pushed ~94% of all rows
# through a single task and a single output file. Spreading on order_id keeps
# the TARGET_PARTITIONS tasks evenly sized; each task then typically writes one
# file into each country directory.
(
    orders
    .repartition(TARGET_PARTITIONS, col('order_id'))
    .write
    .mode('overwrite')
    .partitionBy('country')
    .parquet(OUTPUT_PATH)
)

print('Done writing to', OUTPUT_PATH)

# Stop Spark