In [0]:
from pyspark.sql.functions import col, when, concat, lit, floor, rand, round, expr

# Base seed for every random column in this cell. Each column uses its own
# offset so that e.g. region and zone_type are not drawn from the same stream.
# NOTE: Spark's rand(seed) also depends on the partition index, so runs are
# reproducible as long as spark.range is partitioned the same way.
SITES_SEED = 42

# One row per site, ids 1..50. spark.range yields a LongType (bigint) column.
df = spark.range(1, 51).withColumnRenamed("id", "site_id")

# Brand assignment in blocks of 10 ids: 'Shell' and each competitor brand get 10 of the 50 sites (20% each).
brand_expr = (
    when(col("site_id") <= 10, lit("Shell"))
    .when(col("site_id") <= 20, lit("Aral"))
    .when(col("site_id") <= 30, lit("Esso"))
    .when(col("site_id") <= 40, lit("Total"))
    .otherwise(lit("Jet"))
)
df = df.withColumn("brand", brand_expr)

# Shell sites are "Internal"; every other brand is a "Competitor".
df = df.withColumn(
    "site_type",
    when(col("brand") == "Shell", lit("Internal")).otherwise(lit("Competitor")),
)

# Display name, e.g. "Shell Station 7".
df = df.withColumn("site_name", concat(col("brand"), lit(" Station "), col("site_id")))

# Random pick from a literal array: floor(rand() * n) lies in 0..n-1 and is cast
# to int because array subscripts require an integer index.
df = df.withColumn(
    "region",
    expr(f"array('DE-BE', 'DE-BY', 'DE-NW', 'DE-HE')[cast(floor(rand({SITES_SEED}) * 4) as int)]"),
)
df = df.withColumn(
    "zone_type",
    expr(f"array('Highway', 'Urban', 'Rural')[cast(floor(rand({SITES_SEED + 1}) * 3) as int)]"),
)

# latitude: random float between 47.0 and 55.0, rounded to 6 decimals
df = df.withColumn("latitude", round(lit(47.0) + rand(SITES_SEED + 2) * lit(8.0), 6))

# longitude: random float between 5.0 and 15.0, rounded to 6 decimals
df_sites = df.withColumn("longitude", round(lit(5.0) + rand(SITES_SEED + 3) * lit(10.0), 6))

# Show the schema; site_id comes from spark.range and is therefore bigint (LongType).
df_sites.printSchema()

In [0]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# Column spec for the product catalogue; every column is declared non-nullable.
product_columns = [
    ("product_id", IntegerType()),
    ("category", StringType()),
    ("product_name", StringType()),
    ("grade", StringType()),
]
product_schema = StructType(
    [StructField(column_name, column_type, False) for column_name, column_type in product_columns]
)

# Static catalogue: Shell-branded products (ids 1-4) plus generic
# "Competitor" products (ids 5-6). Tuple order follows product_columns.
product_rows = [
    (1, 'Petrol', 'Shell Super FuelSave 95', 'Standard'),
    (2, 'Petrol', 'Shell V-Power Racing 100', 'Premium'),
    (3, 'Diesel', 'Shell Diesel FuelSave', 'Standard'),
    (4, 'Diesel', 'Shell V-Power Diesel', 'Premium'),
    (5, 'Petrol', 'Competitor Unleaded 95', 'Standard'),
    (6, 'Diesel', 'Competitor Diesel', 'Standard'),
]

df_product_master = spark.createDataFrame(product_rows, product_schema)

In [0]:
from pyspark.sql import Row
from pyspark.sql.window import Window
from pyspark.sql.functions import col, rand, expr, monotonically_increasing_id, row_number

# Seed for the random competitor pick and the strategy draw in this cell
# (two different offsets so the two draws are independent).
MAPPING_SEED = 7

# Step A: product rules - which competitor product each internal product is compared to.
product_rules_data = [
    Row(internal_product_id=1, competitor_product_id=5),
    Row(internal_product_id=2, competitor_product_id=5),
    Row(internal_product_id=3, competitor_product_id=6),
    Row(internal_product_id=4, competitor_product_id=6),
]
df_product_rules = spark.createDataFrame(product_rules_data)

# Step B: link every internal site to exactly one randomly chosen competitor site.
df_internal_sites = df_sites.filter(col("site_type") == "Internal").select(col("site_id").alias("internal_site_id"))
df_competitor_sites = df_sites.filter(col("site_type") == "Competitor").select(col("site_id").alias("competitor_site_id"))

# Randomly order the candidate competitors within each internal site and keep the first.
w = Window.partitionBy("internal_site_id").orderBy(rand(MAPPING_SEED))
df_site_links = (
    df_internal_sites.crossJoin(df_competitor_sites)
    .withColumn("rn", row_number().over(w))
    .filter(col("rn") == 1)
    .drop("rn")
)

# Step C: every linked site pair gets every product rule.
df_full_mapping = df_site_links.crossJoin(df_product_rules)

# Step D: random strategy per mapping row, plus a unique id
# (monotonically_increasing_id is unique but not consecutive).
strategy_expr = expr(
    f"array('MATCH', 'MINUS_1_CENT', 'PLUS_2_CENT')[cast(floor(rand({MAPPING_SEED + 1}) * 3) as int)]"
)
df_competitor_mapping = (
    df_full_mapping
    .withColumn("strategy", strategy_expr)
    .withColumn("mapping_id", monotonically_increasing_id())
    .select(
        "mapping_id",
        "internal_site_id",
        "internal_product_id",
        "competitor_site_id",
        "competitor_product_id",
        "strategy",
    )
)

In [0]:
from pyspark.sql.functions import sequence, explode, to_timestamp, lit, col, monotonically_increasing_id, round, rand, expr, when

# Seed for the per-row price noise so the feed is reproducible across runs
# (given the same Spark partitioning).
PRICING_SEED = 1234

# Time DataFrame: one row per 15-minute slot from '2024-01-01 00:00:00' to
# '2024-01-01 23:45:00' inclusive (sequence() includes its stop value).
df_time = (
    spark.createDataFrame([("2024-01-01 00:00:00", "2024-01-01 23:45:00")], ["start", "end"])
    .withColumn(
        "timestamps",
        sequence(
            to_timestamp(col("start")),
            to_timestamp(col("end")),
            expr("interval 15 minutes"),
        ),
    )
    .select(explode(col("timestamps")).alias("timestamp"))
)

# Every site x every product x every time slot. Built from df_sites, the same
# frame that is exported as the Sites table.
# NOTE(review): this also prices "Competitor ..." products at Shell sites and
# Shell products at competitor sites - confirm that is intended.
df_pricing_feed = df_sites.crossJoin(df_product_master).crossJoin(df_time)

# Price = 1.70 base, +0.10 for Premium grade, +0.05 at Highway sites,
# plus uniform noise in [-0.05, 0.05), rounded to 2 decimals.
premium_markup = when(col("grade") == "Premium", lit(0.10)).otherwise(lit(0.0))
highway_markup = when(col("zone_type") == "Highway", lit(0.05)).otherwise(lit(0.0))
price_noise = rand(PRICING_SEED) * lit(0.10) - lit(0.05)
df_pricing_feed = df_pricing_feed.withColumn(
    "price",
    round(lit(1.70) + premium_markup + highway_markup + price_noise, 2),
)

# price_id: unique but not consecutive; monotonically_increasing_id() is already LongType.
df_pricing_feed = df_pricing_feed.withColumn("price_id", monotonically_increasing_id())

# Keep only the output columns, in output order.
df_pricing_feed = df_pricing_feed.select(
    "price_id",
    "site_id",
    "product_id",
    "timestamp",
    "price",
)

In [0]:
from pyspark.sql.functions import monotonically_increasing_id, round, rand, floor, col

# Seeds for the row sample and the volume draw so transactions are reproducible.
TRANSACTIONS_SEED = 2024

# Transactions are generated only for Internal sites. This frame has its own
# name rather than re-binding df_internal_sites, which the competitor-mapping
# cell defines with a different column (internal_site_id).
df_internal_site_ids = df_sites.filter(col("site_type") == "Internal").select("site_id")
df_transactions = df_pricing_feed.join(df_internal_site_ids, "site_id")

# Keep roughly 10% of the priced rows (Bernoulli sampling, so the exact count varies).
df_transactions = df_transactions.sample(fraction=0.1, seed=TRANSACTIONS_SEED)

# volume: random whole number from 10 to 80 inclusive, stored as double.
df_transactions = df_transactions.withColumn(
    "volume",
    (floor(rand(TRANSACTIONS_SEED + 1) * 71) + 10).cast("double"),
)

# amount = volume * price of that site/product/timestamp, rounded to 2 decimals.
df_transactions = df_transactions.withColumn(
    "amount",
    round(col("volume") * col("price"), 2),
)

# transaction_id: unique but not consecutive; already LongType.
df_transactions = df_transactions.withColumn(
    "transaction_id",
    monotonically_increasing_id(),
)

# Final column order and output types.
df_transactions = df_transactions.select(
    "transaction_id",
    col("site_id").cast("long"),
    col("product_id").cast("long"),
    "timestamp",
    col("volume").cast("double"),
    col("amount").cast("double"),
)

In [0]:
# Optional previews of the generated tables. Off by default so "Run All" does
# not trigger a Spark job per table; set SHOW_PREVIEWS = True to inspect them.
SHOW_PREVIEWS = False

if SHOW_PREVIEWS:
    for preview_df in (df_sites, df_product_master, df_competitor_mapping, df_pricing_feed, df_transactions):
        display(preview_df)

In [0]:
# Optional parquet export of every generated table. Off by default because
# mode="overwrite" replaces existing data; set WRITE_OUTPUTS = True to run it.
WRITE_OUTPUTS = False
OUTPUT_ROOT = "/Volumes/main/default/volume/FuelData"

if WRITE_OUTPUTS:
    # Sub-directory name under OUTPUT_ROOT -> DataFrame to write there.
    output_tables = {
        "Sites": df_sites,
        "ProductMaster": df_product_master,
        "CompetitorMapping": df_competitor_mapping,
        "PricingFeed": df_pricing_feed,
        "Transactions": df_transactions,
    }
    for table_name, table_df in output_tables.items():
        table_df.write.parquet(f"{OUTPUT_ROOT}/{table_name}", mode="overwrite")

In [0]:
# List the parquet target locations, one per generated table.
for output_path in (
    "/Volumes/main/default/volume/FuelData/Sites",
    "/Volumes/main/default/volume/FuelData/ProductMaster",
    "/Volumes/main/default/volume/FuelData/CompetitorMapping",
    "/Volumes/main/default/volume/FuelData/PricingFeed",
    "/Volumes/main/default/volume/FuelData/Transactions",
):
    print(output_path)