In [2]:
import pandas as pd
import numpy as np
import pandas as pd
from scipy.stats import ks_2samp, mannwhitneyu, chi2_contingency

## Load the Data

In [30]:
original_data_splits = "../data/original_training_dataset.csv"
real_test_path = "../data/test_df.csv"

baseline_synthetic = "../data/synthetic_data_baseline_prompt.csv"
synthetic_data_no_grounding_synthetic = "../data/prompt_not_grounded_in_synthetic.csv"
synthetic_data_no_info = "../data/prompt_no_info_orignal_data.csv"

categorical_cols = [
            'numberRating', 'highestRating', 'lowestRating',
            'numberLowRating', 'numberMediumRating', 'numberHighRating',
            'numberMessageRead', 'readAllMessage', 'numberMessageReceived', "medianRating"
        ]

continuous_cols = ['sdRating']

all_columns = categorical_cols + continuous_cols


synthetic_df_no_grounding = pd.read_csv(synthetic_data_no_grounding_synthetic)
synthetic_df_no_data_info = pd.read_csv(synthetic_data_no_info)
synthetic_df = pd.read_csv(baseline_synthetic)
real_df_test = pd.read_csv(real_test_path)
real_original_data = pd.read_csv(original_data_splits)

## Run the hierarchical tests

In [18]:
def hierarchical_test(real_df,synthetic_df,columns,alpha= 0.05,categorical_cols=None):
    """
    Perform hierarchical statistical tests to compare distributions of specified columns
    """
    results = []
    if categorical_cols is None:
        categorical_cols = []

    for col in columns:
        
        real_data = real_df[col].dropna()
        synthetic_data = synthetic_df[col].dropna()
        
        result = {
            'Variable': col, 'KS_statistic': np.nan, 'KS_pvalue': np.nan, 'MW_statistic': np.nan,
            'MW_pvalue': np.nan, 'Chi2_statistic': np.nan, 'Chi2_pvalue': np.nan
        }
      
        # --- Stage 1: Kolmogorov-Smirnov Test ---
        ks_stat, ks_pvalue = ks_2samp(real_data, synthetic_data)
        result.update({'KS_statistic': ks_stat, 'KS_pvalue': ks_pvalue})
        
        if ks_pvalue >= alpha:
            result.update({'Test_Passed': 'KS', 'Final_pvalue': ks_pvalue, 'Significant_Difference': False})
            results.append(result)
            continue

        # --- Stage 2: Mann-Whitney U Test (if KS fails) ---
        mwu_stat, mwu_pvalue = mannwhitneyu(real_data, synthetic_data)
        result.update({'MW_statistic': mwu_stat, 'MW_pvalue': mwu_pvalue})
        
        if mwu_pvalue >= alpha:
            result.update({'Test_Passed': 'Mann-Whitney', 'Final_pvalue': mwu_pvalue, 'Significant_Difference': False})
            results.append(result)
            continue
            
        # --- Stage 3: Chi-Square Test ---
        if col in categorical_cols:
            contingency_table = pd.crosstab(pd.concat([real_data, synthetic_data]),['real'] * len(real_data) + ['synth'] * len(synthetic_data)
            )
            try:
                chi2_stat, pvalue, _, _ = chi2_contingency(contingency_table)
                result.update({'Chi2_statistic': chi2_stat, 'Chi2_pvalue': pvalue})
                if pvalue >= alpha:
                    result.update({'Test_Passed': 'Chi-Square', 'Final_pvalue': pvalue, 'Significant_Difference': False})
                else:
                    result.update({'Test_Passed': 'None', 'Final_pvalue': pvalue, 'Significant_Difference': True})
            except ValueError:
                result.update({'Test_Passed': 'Error', 'Final_pvalue': mwu_pvalue, 'Significant_Difference': True})
        else:
            result.update({'Test_Passed': 'None', 'Final_pvalue': mwu_pvalue, 'Significant_Difference': True})
        
        results.append(result)
        
        result_table = pd.DataFrame(results)
        
    return result_table

In [31]:
results_synthetic = hierarchical_test(real_original_data, synthetic_df, all_columns, categorical_cols=categorical_cols)
results_synthetic_no_grounding = hierarchical_test(real_original_data, synthetic_df_no_grounding, all_columns, categorical_cols=categorical_cols)
results_synthetic_no_info = hierarchical_test(real_original_data, synthetic_df_no_data_info, all_columns, categorical_cols=categorical_cols)

In [33]:
results_synthetic

Unnamed: 0,Variable,KS_statistic,KS_pvalue,MW_statistic,MW_pvalue,Chi2_statistic,Chi2_pvalue,Test_Passed,Final_pvalue,Significant_Difference
0,numberRating,0.426481,1.0600509999999999e-203,2712610.5,8.941286e-283,2453.360838,0.0,,0.0,True
1,highestRating,0.415861,2.576202e-193,3824582.5,1.114542e-165,2714.554817,0.0,,0.0,True
2,lowestRating,0.350983,2.4119469999999998e-136,6266590.5,0.05350722,,,Mann-Whitney,0.05350722,False
3,numberLowRating,0.078287,4.787631e-07,5952785.0,1.473981e-20,87.008517,1.27744e-19,,1.27744e-19,True
4,numberMediumRating,0.501045,1.48164e-285,2355477.0,0.0,3125.941746,0.0,,0.0,True
5,numberHighRating,0.032064,0.1531508,,,,,KS,0.1531508,False
6,numberMessageRead,0.132233,2.505797e-19,5369672.5,2.1551370000000002e-27,183.761807,1.3584470000000001e-39,,1.3584470000000001e-39,True
7,readAllMessage,0.288669,1.314686e-91,4593422.5,0.0,1547.443502,0.0,,0.0,True
8,numberMessageReceived,0.016674,0.8748754,,,,,KS,0.8748754,False
9,medianRating,0.425243,1.785523e-202,5462762.0,4.123329e-25,5337.213921,0.0,,0.0,True


In [35]:
results_synthetic_no_info

Unnamed: 0,Variable,KS_statistic,KS_pvalue,MW_statistic,MW_pvalue,Chi2_statistic,Chi2_pvalue,Test_Passed,Final_pvalue,Significant_Difference
0,numberRating,0.624925,4.007e-321,1423562.0,0.0,7541.313241,0.0,,0.0,True
1,highestRating,0.610453,4.29e-321,1528777.0,0.0,6126.381504,0.0,,0.0,True
2,lowestRating,0.521445,7.977099e-311,4717355.5,1.9274690000000002e-190,7312.020349,0.0,,0.0,True
3,numberLowRating,0.029046,0.2420141,,,,,KS,0.2420141,False
4,numberMediumRating,0.697601,2.58e-321,1754155.0,0.0,7136.576894,0.0,,0.0,True
5,numberHighRating,0.533253,7.045e-321,2924929.5,3.4864529999999998e-276,1556.263601,0.0,,0.0,True
6,numberMessageRead,0.11271,3.766464e-14,6184382.5,0.005831641,1069.916264,1.223645e-231,,1.223645e-231,True
7,readAllMessage,0.062109,0.0001353564,6858570.0,6.323809e-06,20.125224,7.253374e-06,,7.253374e-06,True
8,numberMessageReceived,0.014119,0.9626926,,,,,KS,0.9626926,False
9,medianRating,0.649144,3.533e-321,3156472.5,1.7391099999999998e-280,7465.659882,0.0,,0.0,True
