# Gold Dataset

This notebook creates the gold dataset that will be used for classification and aggregation.


In [0]:
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql import types as T


In [0]:
# Pair every cleaned conversation with its ISID mapping row (inner join on
# conversation_id), then print the resulting schema for inspection.
conversations_clean = spark.table("silver.conversations_clean")
isid_mapping = spark.table("silver.synthetic_isid_mapping")

df = conversations_clean.join(isid_mapping, on="conversation_id")
df.printSchema()

In [0]:
# Result schema declared for the grouped-map UDF below: (column name, Spark
# type) pairs, every field nullable.
_output_fields = [
    ("ISID", T.StringType()),
    ("conversation_id", T.StringType()),
    ("combined_text", T.StringType()),
    ("char_count", T.IntegerType()),
]
schema = T.StructType(
    [T.StructField(name, dtype, True) for name, dtype in _output_fields]
)

@F.pandas_udf(schema, functionType=F.PandasUDFType.GROUPED_MAP)
def combine_conversation_text(pdf):
    """
    Join each row's conversation turns into a single space-separated text.

    Registered as a grouped-map pandas UDF: it receives one group's rows as a
    pandas DataFrame and returns one output row per input row.

    Input columns: ISID, conversation_id, turns (iterable of items indexable
        by "content_truncated")
    Output columns: ISID, conversation_id, combined_text, char_count

    A null ``turns`` value, a null turn, or a null "content_truncated" entry
    contributes no text instead of raising, so a single incomplete record
    does not fail the whole job.
    """
    output_columns = ["ISID", "conversation_id", "combined_text", "char_count"]
    results = []

    # Iterating the three columns directly avoids the per-row Series that
    # DataFrame.iterrows() builds.
    for isid, conversation_id, turns in zip(
        pdf["ISID"], pdf["conversation_id"], pdf["turns"]
    ):
        # Compare against None explicitly: ``turns`` may be an array-like
        # whose truth value is ambiguous.
        if turns is None:
            contents = []
        else:
            contents = [
                turn["content_truncated"]
                for turn in turns
                if turn is not None and turn["content_truncated"] is not None
            ]

        conv_text = " ".join(contents)

        results.append({
            "ISID": isid,
            "conversation_id": conversation_id,
            "combined_text": conv_text,
            "char_count": len(conv_text),
        })

    # Passing ``columns`` keeps the declared output columns present even when
    # ``results`` is empty.
    return pd.DataFrame(results, columns=output_columns)

# Run the grouped-map UDF over each (ISID, conversation_id) group.
grouping_keys = ["ISID", "conversation_id"]
df_text = df.groupBy(*grouping_keys).apply(combine_conversation_text)
# display(df_text)

In [0]:
# Inner-join the conversation texts with gold.synthetic_country_mapping on ISID.
df_isid_country = spark.table("gold.synthetic_country_mapping")
df_text_with_country = df_text.join(df_isid_country, on="ISID")


In [0]:
# Number of randomly chosen conversations written to the "included" table.
INITIAL_DATASET_LIMIT = 5000

# Random sample of conversations. This plan is lazy and ordered by F.rand(),
# which is non-deterministic, so each action executed on it may draw a
# different sample. It is therefore executed exactly once, by the write below.
df_text_sampled = (
    df_text_with_country
    .orderBy(F.rand())
    .limit(INITIAL_DATASET_LIMIT)
)

# IF EXISTS keeps the first run (table not created yet) from failing.
spark.sql("DROP TABLE IF EXISTS gold.conversations_included")
df_text_sampled.write.mode("overwrite").saveAsTable("gold.conversations_included")

# Read the persisted sample back so that everything derived from it (the id
# list, the excluded set, later counts) refers to the rows actually saved
# rather than to a fresh random draw.
df_text_included = spark.table("gold.conversations_included")

included_ids = [
    row["conversation_id"]
    for row in df_text_included.select("conversation_id").collect()
]

# Everything that was not sampled goes to the "excluded" table.
df_text_excluded = df_text_with_country.filter(
    ~F.col("conversation_id").isin(included_ids)
)

spark.sql("DROP TABLE IF EXISTS gold.conversations_excluded")
df_text_excluded.write.mode("overwrite").saveAsTable("gold.conversations_excluded")

In [0]:
# display(df_text_excluded.count())
# display(df_text_included.count())