In [None]:
# Notebook 03: strata_design_and_stratified_sampling
from sampling_framework import SamplingFramework
from pyspark.sql import functions as F

sf = SamplingFramework(spark)
df = spark.table("processed_customers")

# Design parameters. They are hoisted so the values used below and the values
# described in comments cannot drift apart.
BALANCE_BUCKETS = 10    # quantile bins for `balance`
WEB_LOGIN_BUCKETS = 3   # fewer bins for `n_web_logins` (see step 1)
STRATA_FLOOR = 30       # target_floor passed to merge_sparse_strata
TEST_FRACTION = 0.9     # test_fraction passed to stratified_sample
                        # NOTE(review): 90% test / 10% control is assumed to be
                        # intentional (holdout-style control) — confirm.

# 1. Apply Binning
# Quantile binning for balance.
df_binned = sf.apply_quantile_binning(df, "balance", n_buckets=BALANCE_BUCKETS)
# n_web_logins is a discrete feature with many 1s. Quantile split points on such
# heavily tied data tend to coincide, so a smaller bin count is requested.
# NOTE(review): confirm how apply_quantile_binning handles duplicate split points.
df_binned = sf.apply_quantile_binning(
    df_binned, "n_web_logins", n_buckets=WEB_LOGIN_BUCKETS
)

# 2. Define Strata Key
# concat_ws silently skips NULL arguments. A NULL component would shift the
# remaining parts left and collide with an unrelated stratum (e.g. 1/NULL/2 and
# 1/2/NULL would both become "1_2"). Each component is therefore cast to string
# and NULL is replaced by an explicit "NA" token, so every part keeps its position.
strata_components = [
    F.coalesce(F.col(col_name).cast("string"), F.lit("NA"))
    for col_name in ("visa_ind", "balance_bin", "n_web_logins_bin")
]
df_binned = df_binned.withColumn("strata_id", F.concat_ws("_", *strata_components))

# 3. Merge Sparse Strata (floor = STRATA_FLOOR)
df_final = sf.merge_sparse_strata(df_binned, "strata_id", target_floor=STRATA_FLOOR)
# df_final is consumed by several actions below (two table writes and the audit).
# Cache it so the binning/merge lineage is computed once rather than once per action.
df_final = df_final.cache()

# 4. Stratified Sampling
test_strat, ctrl_strat = sf.stratified_sample(
    df_final,
    id_col="cust_ID",
    # presumably the merged-stratum column added by merge_sparse_strata — confirm
    strata_col="strata_id_final",
    test_fraction=TEST_FRACTION,
)

# 5. Save and Audit
test_strat.write.format("delta").mode("overwrite").saveAsTable("stratified_test_group")
ctrl_strat.write.format("delta").mode("overwrite").saveAsTable("stratified_ctrl_group")

print("Strata count audit (smallest first):")
display(df_final.groupBy("strata_id_final").count().orderBy("count"))

# Release the cached population now that every consumer has run.
df_final.unpersist()