In [None]:
# Notebook 05: cube_sampling
from sampling_framework import SamplingFramework
from pyspark.sql import functions as F
import time
import matplotlib.pyplot as plt

# Build the project sampling helper around the Spark session and load the
# customer table that every later step in this notebook reads from.
# `spark` is neither imported nor defined in this cell; presumably it is the
# session object injected by the notebook runtime -- confirm when running
# this outside a notebook.
sf = SamplingFramework(spark)
df = spark.table("processed_customers")

# 1. Scaling for balancing variables
# Columns handed to the sampler as balancing variables; the same list is
# passed again to every distributed_cube_sampling call further down.
balance_vars = ["balance", "n_web_logins", "n_mobile_logins"]
# NOTE(review): per the heading above this step scales the balancing
# variables, but the actual transformation lives in sampling_framework --
# verify there. The raw `df` is kept separately because the final joins use
# it rather than the prepared frame.
df_cube_prep = sf.prepare_for_cube(df, balance_vars)

# 2. Runtime Experiment
# Time the cube sampler on progressively larger subsets of the prepared table.
# `times[i]` is the elapsed seconds for `test_sizes[i]`; both lists are read
# by the plotting step that follows.
test_sizes = [1000, 5000, 10000, 50000, 100000]  # Adjust based on your dev cluster
times = []

for n in test_sizes:
    sub_df = df_cube_prep.limit(n).cache()
    try:
        # Force materialization so the cost of reading/caching the subset is
        # not attributed to the sampler.
        actual_rows = sub_df.count()
        if actual_rows < n:
            # limit(n) silently returns fewer rows when the table is smaller
            # than n; flag it so the data point is not misread.
            print(
                f"Warning: requested {n} rows but only {actual_rows} are available; "
                f"this timing reflects {actual_rows} rows."
            )

        # perf_counter is monotonic, so the measurement cannot be distorted by
        # system clock adjustments the way time.time() can.
        # NOTE(review): this only measures real work if
        # distributed_cube_sampling executes eagerly. If it returns a lazy
        # DataFrame, an action on the result is needed inside the timed
        # region -- confirm against sampling_framework.
        start = time.perf_counter()
        _ = sf.distributed_cube_sampling(sub_df, balance_vars, id_col="cust_ID", test_fraction=0.9)
        times.append(time.perf_counter() - start)
    finally:
        # Release the cached subset even when sampling fails, so a broken run
        # does not leave cached blocks behind on the executors.
        sub_df.unpersist()

# Plot measured runtime against subset size; the curve is the basis for
# extrapolating the cost of a full-table run.
runtime_axes = plt.gca()
runtime_axes.plot(test_sizes, times, "o-")
runtime_axes.set_title("Cube Sampling Runtime Complexity")
runtime_axes.set_xlabel("N rows")
runtime_axes.set_ylabel("Time (seconds)")
plt.show()

# 3. Full Production Run (Example on full table)
# Note: If your ID column is named differently (e.g., "cost_ID"), change
# id_col here; it is used both for the sampler and for the joins below.
id_col = "cust_ID"

# Spark DataFrames are lazy: without persisting, each of the two writes below
# would re-run the whole sampling lineage. Besides doubling the cost, a
# randomized sampler could then produce a different draw per write, letting
# the test and control tables overlap or drop customers. Cache the assignment
# and materialize it once so both groups are split from the same result.
# NOTE(review): a cache can still be recomputed after executor loss; if the
# sampler is randomized and unseeded, persisting the assignment to a table is
# the stricter guarantee.
assignment = sf.distributed_cube_sampling(
    df_cube_prep, balance_vars, id_col=id_col, test_fraction=0.9
).cache()
try:
    assignment.count()  # cache() is lazy; this action fills it

    test_ids = assignment.filter(F.col("is_test") == 1)
    ctrl_ids = assignment.filter(F.col("is_test") == 0)

    # Join back onto the raw table so the output carries the original
    # columns of `df` plus whatever columns the assignment frame holds.
    test_cube = df.join(test_ids, on=id_col, how="inner")
    ctrl_cube = df.join(ctrl_ids, on=id_col, how="inner")

    # Default save mode is kept on purpose: the write fails if a table
    # already exists instead of silently overwriting an earlier split.
    test_cube.write.saveAsTable("cube_test_group")
    ctrl_cube.write.saveAsTable("cube_ctrl_group")
finally:
    # Free the cached assignment whether or not the writes succeeded.
    assignment.unpersist()