# In [None]:
# Notebook 04: validation_and_investigation
from sampling_framework import SamplingFramework
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns

# `spark` is not created in this cell; it is expected to be provided by the
# notebook runtime. The framework object wraps it for the balance checks below.
sf = SamplingFramework(spark)

# Load both arms of the stratified split by table name via spark.table.
test, ctrl = (
    spark.table("stratified_test_group"),
    spark.table("stratified_ctrl_group"),
)

# 1. Standardized Mean Difference for each numeric covariate.
#    Acceptance target: SMD < 0.1 (NOTE(review): confirm whether the report is
#    signed, in which case the target applies to |SMD|).
num_cols = ["balance", "n_web_logins", "n_mobile_logins"]
smd_report = sf.calculate_smd(test, ctrl, num_cols)
# One print call with a newline separator emits exactly the same two lines
# as printing the header and the report separately.
print("--- SMD Report ---", smd_report, sep="\n")

# 2. Population Stability Index on the pre-binned balance column
#    "balance_bin" (acceptance target: PSI < 0.1).
bal_psi = sf.calculate_psi(test, ctrl, "balance_bin")
psi_banner = f"--- PSI (Balance): {bal_psi:.6f} ---"
print(psi_banner)

# 3. Kolmogorov-Smirnov test: compares the distribution of the continuous
#    "balance" column between the two arms. The helper's result is read as a
#    mapping with "statistic" and "pvalue" keys.
ks_result = sf.calculate_ks_test(test, ctrl, "balance")
print("--- KS Test (Balance) ---")
ks_stat = ks_result["statistic"]
ks_pval = ks_result["pvalue"]
print(f"Statistic: {ks_stat:.6f}, p-value: {ks_pval:.6f}")

# 4. Chi-square test: compares the categorical "visa_ind" column between the
#    two arms. The helper's result is read as a mapping with "statistic" and
#    "pvalue" keys.
chi2_result = sf.calculate_chi_square(test, ctrl, "visa_ind")
print("--- Chi-Square Test (visa_ind) ---")
chi2_stat = chi2_result["statistic"]
chi2_pval = chi2_result["pvalue"]
print(f"Statistic: {chi2_stat:.6f}, p-value: {chi2_pval:.6f}")

# 5. T-Test and Visualization (Sampling for local plotting)
#    Both arms are down-sampled by the same fraction before .toPandas() so the
#    data collected to the driver stays small. A fixed seed makes the sample
#    (and therefore the t-statistic, p-value and the plot below) repeatable
#    across re-runs, as long as the underlying data and partitioning do not
#    change. Null/NaN balances are dropped in Spark first: otherwise they would
#    reach ttest_ind as NaN and, under its default nan_policy, make the result nan.
SAMPLE_FRACTION = 0.2
SAMPLE_SEED = 42

pdf_t = (
    test.select("balance")
    .dropna()
    .sample(fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED)
    .toPandas()
)
pdf_c = (
    ctrl.select("balance")
    .dropna()
    .sample(fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED)
    .toPandas()
)

# equal_var=False selects Welch's t-test, which does not assume equal variances.
# That suits arms of very different sizes (the plot labels suggest a 90/10 split).
t_stat, p_val = stats.ttest_ind(pdf_t['balance'], pdf_c['balance'], equal_var=False)
print("--- T-Test (Balance, Sampled) ---")
print(f"Sample sizes: test={len(pdf_t)}, control={len(pdf_c)}")
print(f"T-statistic: {t_stat:.4f}, p-value: {p_val:.4f}")

# Overlay kernel-density estimates of the sampled balances for both arms on
# a single Axes, for visual inspection of covariate balance.
# NOTE(review): the "90%"/"10%" legend labels are hard-coded, not computed from
# the data. Confirm they still match the actual test/control split.
fig, ax = plt.subplots(figsize=(10, 6))
arms = (
    (pdf_t['balance'], 'Test (90%)'),
    (pdf_c['balance'], 'Control (10%)'),
)
for series, legend_label in arms:
    sns.kdeplot(series, label=legend_label, fill=True, ax=ax)
ax.set_title("Covariate Balance Check: Customer Balance Distribution")
ax.legend()
plt.show()