In [0]:
import random
import zlib

from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql import types as T

In [0]:
# Load Bronze and draw a random sample of English conversations.
SAMPLE_SIZE = 10000  # number of conversations kept for downstream processing
SAMPLE_SEED = 42     # fixes the random ordering used to pick the sample

bronze_df = spark.table("bronze.conversations")
bronze_df_sampled = (
    bronze_df
    .filter(F.col("language") == "English")
    # bronze_df_sampled is lazy and is re-evaluated by every action that uses it.
    # An unseeded rand() would draw a different sample each time; a seeded one is
    # repeatable as long as the input table's partition layout does not change.
    .orderBy(F.rand(seed=SAMPLE_SEED))
    .limit(SAMPLE_SIZE)
)

In [0]:
# -------------------------
# 1. EXPLODE TURNS & LIMIT TO 50
# -------------------------
MAX_TURNS = 50  # maximum number of turns kept per conversation

# One row per element of the `conversation` array. posexplode emits the element's
# 0-based position as `turn_idx`; role/content are pulled out of the turn struct.
df_turns = (
    bronze_df_sampled
        .select(
            "conversation_id",
            "language",
            "timestamp",
            F.posexplode("conversation").alias("turn_idx", "turn_struct")
        )
        .withColumn("role", F.col("turn_struct.role"))
        .withColumn("content", F.col("turn_struct.content"))
        .drop("turn_struct")
)

# Keep only the first MAX_TURNS turns of each conversation.
# turn_idx already numbers turns 0..n-1 within each source row, so a plain filter
# is equivalent to row_number() over (PARTITION BY conversation_id ORDER BY turn_idx)
# <= MAX_TURNS for unique conversation_ids. It avoids that window's shuffle and sort.
# It also never interleaves turns of two rows that share a conversation_id.
df_turns_50 = df_turns.filter(F.col("turn_idx") < MAX_TURNS)

In [0]:
# -------------------------
# 2. TRUNCATE EACH TURN
# -------------------------
MAX_CHARS = 250  # per-turn character cap applied to `content`

# `char_count` records the length of the ORIGINAL (pre-truncation) content.
original_len = F.length("content")
# substring(content, 1, MAX_CHARS) returns the whole string when it is already
# within the cap, and null content stays null, so no length check is needed.
clipped_content = F.substring("content", 1, MAX_CHARS)

df_truncated = df_turns_50.select(
    "*",
    original_len.alias("char_count"),
    clipped_content.alias("content_truncated"),
)

In [0]:
# Persist the per-turn silver table, replacing any previous contents.
df_truncated.write.saveAsTable("silver.chat_turns_truncated", mode="overwrite")

In [0]:
# Re-read the persisted per-turn table so this stage depends only on silver data.
df_conv = spark.table("silver.chat_turns_truncated")

# SQL higher-order function that re-orders each conversation's turn structs by
# ascending turn_idx (comparator returns -1 / 1 / 0).
TURN_ORDER_EXPR = "array_sort(turns, (left, right) -> CASE WHEN left.turn_idx < right.turn_idx THEN -1 WHEN left.turn_idx > right.turn_idx THEN 1 ELSE 0 END)"

turn_struct = F.struct("turn_idx", "role", "content_truncated", "char_count")

# One output row per conversation_id:
#   turns                          - array of turn structs
#   language                       - first language value seen in the group
#   timestamp                      - latest timestamp in the group
#   conversation_char_count_total  - sum of the per-turn char_count column
conversation_aggs = [
    F.collect_list(turn_struct).alias("turns"),
    F.first("language").alias("language"),
    F.max("timestamp").alias("timestamp"),
    F.sum("char_count").alias("conversation_char_count_total"),
]

# collect_list's element order depends on row order after the shuffle, so the
# turns are explicitly re-sorted afterwards.
df_conversations = (
    df_conv
    .groupBy("conversation_id")
    .agg(*conversation_aggs)
    .withColumn("turns", F.expr(TURN_ORDER_EXPR))
)

In [0]:
# Persist one row per conversation, replacing any previous contents.
df_conversations.write.saveAsTable("silver.conversations_clean", mode="overwrite")

In [0]:
## SYNTHETIC USER ISID CREATION
# Randomly group conversations into synthetic users of CONV_PER_USER conversations
# each. The last user gets fewer when the count is not a multiple of CONV_PER_USER.
CONV_PER_USER = 5
ISID_SEED = 42  # seed for the random shuffle of conversations

# NOTE: a window with no partitionBy moves every row into a single partition.
# That is acceptable for a sampled data set but will not scale to large inputs.
shuffle_window = Window.orderBy("rand")

df_isid = (
    df_conversations
        .withColumn("rand", F.rand(seed=ISID_SEED))
        .withColumn("row_num", F.row_number().over(shuffle_window))
        # row_number() is 1-based; subtracting 1 makes every bucket, including
        # user_0, hold exactly CONV_PER_USER conversations.
        .withColumn("ISID", ((F.col("row_num") - 1) / CONV_PER_USER).cast("int"))
        .withColumn("ISID", F.concat(F.lit("user_"), F.col("ISID").cast("string")))
        .select("conversation_id", "ISID")
)

df_isid.write.mode("overwrite").saveAsTable("silver.synthetic_isid_mapping")




In [0]:
## SYNTHETIC COUNTRY MAPPING CREATION
# Countries with relative sampling weights (they sum to 100, so each weight is
# roughly the percentage of users assigned to that country).
countries = [
    ("United States", 30),
    ("United Kingdom", 15),
    ("Canada", 10),
    ("Germany", 8),
    ("France", 7),
    ("Australia", 6),
    ("India", 5),
    ("Japan", 4),
    ("Brazil", 4),
    ("Spain", 3),
    ("Italy", 3),
    ("Netherlands", 2),
    ("Singapore", 2),
    ("Mexico", 1)
]

# Precomputed once instead of expanding a weighted list on every UDF call.
_COUNTRY_NAMES = [name for name, _ in countries]
_COUNTRY_WEIGHTS = [weight for _, weight in countries]


@F.udf("string")
def assign_country(isid):
    """Deterministically pick a weighted-random country for a synthetic user ID.

    The RNG is seeded with zlib.crc32 of the ISID, so a given ISID maps to the
    same country on every run and on every executor. The built-in hash() of a
    str is salted per process unless PYTHONHASHSEED is fixed, so it is not a
    reliable seed. A private random.Random instance also leaves the worker's
    global ``random`` state untouched.

    Args:
        isid: synthetic user ID (converted with str() before hashing).

    Returns:
        One country name from ``countries``, chosen proportionally to its weight.
    """
    rng = random.Random(zlib.crc32(str(isid).encode("utf-8")))
    return rng.choices(_COUNTRY_NAMES, weights=_COUNTRY_WEIGHTS, k=1)[0]


# Take ISIDs from the persisted mapping so countries are assigned to exactly the
# IDs that were saved. Re-evaluating the lazy df_isid recomputes rand() and
# row_number() over a shuffled aggregate, which may not reproduce the same IDs.
df_isid_country = (
    spark.table("silver.synthetic_isid_mapping")
        .select("ISID")
        .distinct()
        .withColumn("country", assign_country("ISID"))
)

# Save the mapping
df_isid_country.write.mode("overwrite").saveAsTable("gold.synthetic_country_mapping")

