## Evaluation Functions

In [None]:
import pandas as pd
import numpy as np
from itertools import combinations


def _aligned_proportions(pop_counts, syn_counts, n_pop, n_syn):
    """
    Align population and synthetic count Series on the union of their keys
    and convert the counts to proportions.

    Keys present on only one side get a count of 0 on the other side. A side
    with no observations (``n == 0``) yields an all-zero array, so both
    returned arrays are always 1-D and of equal length.

    Parameters
    ----------
    pop_counts, syn_counts : Series
        Counts indexed by category (or by category pair).
    n_pop, n_syn : int
        Number of observations used as the proportion denominators.

    Returns
    -------
    tuple of ndarray
        (pop_prop, syn_prop) as float arrays.
    """
    # An empty side may carry an index of a different kind (float vs. object,
    # flat vs. MultiIndex); re-base it on the other side's index so that the
    # alignment below cannot fail on incompatible index types.
    if pop_counts.empty:
        pop_counts = pd.Series(0, index=syn_counts.index)
    if syn_counts.empty:
        syn_counts = pd.Series(0, index=pop_counts.index)

    syn_aligned, pop_aligned = syn_counts.align(pop_counts, join='outer', fill_value=0)
    size = len(pop_aligned)
    pop_prop = pop_aligned.to_numpy(dtype=float) / n_pop if n_pop > 0 else np.zeros(size)
    syn_prop = syn_aligned.to_numpy(dtype=float) / n_syn if n_syn > 0 else np.zeros(size)
    return pop_prop, syn_prop


def _pair_counts(pair_df, col1, col2):
    """Return the (col1, col2) cross-tabulation as a Series keyed by value pairs; empty if pair_df has no rows."""
    if pair_df.empty:
        # crosstab on empty input is not needed (and not reliable across pandas versions).
        return pd.Series(dtype=float)
    # crosstab includes zero-count cells for every observed col1 x col2 combination.
    return pd.crosstab(pair_df[col1], pair_df[col2]).stack()


def _srmse(pop_parts, syn_parts):
    """Return RMSE(pop, syn) / mean(pop) over the concatenated parts; NaN if empty or mean is 0."""
    pop = np.concatenate(pop_parts) if pop_parts else np.array([])
    syn = np.concatenate(syn_parts) if syn_parts else np.array([])
    if pop.size == 0:
        return np.nan
    rmse = np.linalg.norm(pop - syn) / np.sqrt(pop.size)
    ybar = pop.mean()
    return rmse / ybar if ybar != 0 else np.nan


def SRMSE(x_population, resamples):
    """
    Function to calculate SRMSE (Standardized Root Mean Square Error)
    - Calculates SRMSE for marginal and bivariate distributions

    For every column (marginal) and every unordered column pair (bivariate),
    category proportions are computed in both datasets after dropping NA
    values, aligned on the union of categories (missing ones count as 0),
    and concatenated. SRMSE is the RMSE between the concatenated population
    and synthetic proportions divided by the mean population proportion.

    Parameters
    ----------
    x_population : DataFrame
        Original population DataFrame
    resamples : DataFrame
        Generated synthetic population DataFrame; must contain every column
        of ``x_population``.

    Returns
    -------
    list
        [srmse_mar, srmse_bi] — SRMSE for marginal and bivariate distributions.
        An entry is NaN when there is nothing to compare or the mean
        population proportion is 0 (e.g. no non-NA population rows).
    """
    columns = x_population.columns

    # ── Marginal distribution ───────────────────────────────
    sam_marg, resam_marg = [], []
    for col in columns:
        pop_series = x_population[col].dropna()
        syn_series = resamples[col].dropna()
        pop_prop, syn_prop = _aligned_proportions(
            pop_series.value_counts(), syn_series.value_counts(),
            len(pop_series), len(syn_series),
        )
        sam_marg.append(pop_prop)
        resam_marg.append(syn_prop)

    # ── Bivariate distribution ──────────────────────────────
    sam_bi, resam_bi = [], []
    for col1, col2 in combinations(columns, 2):
        pop_pair = x_population[[col1, col2]].dropna()
        syn_pair = resamples[[col1, col2]].dropna()
        pop_prop, syn_prop = _aligned_proportions(
            _pair_counts(pop_pair, col1, col2), _pair_counts(syn_pair, col1, col2),
            len(pop_pair), len(syn_pair),
        )
        sam_bi.append(pop_prop)
        resam_bi.append(syn_prop)

    return [_srmse(sam_marg, resam_marg), _srmse(sam_bi, resam_bi)]


def calculate_precision_recall(population_df, generated_df):
    """
    Function to compute Precision and Recall
    - Excludes rows containing any missing (NA) values across all columns
    - Converts each row to a tuple (in population column order) for comparison

    Parameters
    ----------
    population_df : DataFrame
        Original population DataFrame
    generated_df : DataFrame
        Generated synthetic population DataFrame; must contain every column
        of ``population_df``.

    Returns
    -------
    dict
        'precision' : share of generated rows whose profile occurs in the
            population (rounded to 4 decimals; NaN if no generated rows remain)
        'recall' : share of population rows whose profile occurs in the
            generated data (rounded to 4 decimals; NaN if no population rows remain)
        'f1_score' : harmonic mean of precision and recall (rounded to 4
            decimals; 0 when undefined)
        'unique_combinations' : {'population': int, 'generated': int}
            number of distinct profiles on each side
        'matching_combinations' : {'unique_types': int, 'total_count': int}
            distinct profiles present on both sides, and the number of
            generated rows whose profile is present in the population
    """
    all_cols = population_df.columns.tolist()

    # Drop rows with any NA value, then turn each row into a plain tuple.
    # itertuples avoids building a pandas Series per row, which is what
    # DataFrame.apply(tuple, axis=1) does and dominates runtime on large data.
    p_pop = list(population_df[all_cols].dropna().itertuples(index=False, name=None))
    p_gen = list(generated_df[all_cols].dropna().itertuples(index=False, name=None))

    pop_set = set(p_pop)
    gen_set = set(p_gen)

    # Precision: proportion of generated profiles that appear in the population
    gen_hits = sum(profile in pop_set for profile in p_gen)
    precision_raw = gen_hits / len(p_gen) if p_gen else float('nan')

    # Recall: proportion of population profiles that appear in the generated data
    pop_hits = sum(profile in gen_set for profile in p_pop)
    recall_raw = pop_hits / len(p_pop) if p_pop else float('nan')

    # F1 score from the unrounded values so rounding is applied only once.
    # NaN + x > 0 is False, so an undefined precision/recall gives F1 = 0.
    denom = precision_raw + recall_raw
    f1_raw = 2 * precision_raw * recall_raw / denom if denom > 0 else 0

    matching_unique_combinations = pop_set & gen_set

    return {
        'precision': round(precision_raw, 4),
        'recall': round(recall_raw, 4),
        'f1_score': round(f1_raw, 4),
        'unique_combinations': {
            'population': len(pop_set),
            'generated': len(gen_set),
        },
        'matching_combinations': {
            'unique_types': len(matching_unique_combinations),
            # Every generated row whose profile is in the population is, by
            # definition, a row of a profile in the intersection.
            'total_count': gen_hits,
        },
    }


def evaluate_synthetic_population(population_csv, generated_csv):
    """
    Synthetic population evaluation function
    - Loads both CSV files, prints a data summary, SRMSE, precision/recall
      and combination statistics, and returns the metrics.
    - When reading CSV files, the string "None" is NOT treated as missing;
      only empty strings ('') are considered NA.
    - If the column names differ, only the columns present in both files
      are evaluated, kept in the population file's column order.

    Parameters
    ----------
    population_csv : str
        File path to the original population CSV
    generated_csv : str
        File path to the generated synthetic population CSV

    Returns
    -------
    dict
        'srmse_marginal', 'srmse_bivariate', 'precision', 'recall',
        'f1_score', 'unique_combinations' and 'matching_combinations'
        (the last two as returned by calculate_precision_recall).
    """

    def _pct(part, whole):
        # Percentage that yields NaN (printed as "nan") instead of raising
        # ZeroDivisionError when the denominator is 0, e.g. an empty file.
        return part / whole * 100 if whole else float('nan')

    print(f"Loading data from {population_csv} and {generated_csv}...")

    population_df = pd.read_csv(population_csv, keep_default_na=False, na_values=[''])
    generated_df = pd.read_csv(generated_csv, keep_default_na=False, na_values=[''])

    # Align column names. Only the *set* of names is compared; a differing
    # order alone does not trigger this branch.
    generated_cols = set(generated_df.columns)
    if set(population_df.columns) != generated_cols:
        print("Warning: Column names don't match between datasets")
        print(f"Population columns: {population_df.columns.tolist()}")
        print(f"Generated columns: {generated_df.columns.tolist()}")
        # Preserve population order; a set intersection would order the
        # columns by string hash, which varies from run to run.
        common_cols = [c for c in population_df.columns if c in generated_cols]
        population_df = population_df[common_cols]
        generated_df = generated_df[common_cols]
        print(f"Using common columns: {common_cols}")

    # Data summary
    print("\n=== Data Summary ===")
    print(f"Population data: {len(population_df):,} rows, {len(population_df.columns)} columns")
    print(f"Generated data: {len(generated_df):,} rows, {len(generated_df.columns)} columns")

    # SRMSE
    print("\nCalculating SRMSE...")
    srmse_results = SRMSE(population_df, generated_df)
    print(f"SRMSE for marginal distributions: {srmse_results[0]:.4f}")
    print(f"SRMSE for bivariate distributions: {srmse_results[1]:.4f}")

    # Precision & Recall
    print("\nCalculating precision and recall...")
    pr_metrics = calculate_precision_recall(population_df, generated_df)
    unique = pr_metrics['unique_combinations']
    matching = pr_metrics['matching_combinations']

    print("\n=== Evaluation Metrics ===")
    print(f"Precision: {pr_metrics['precision']:.4f}")
    print(f"Recall: {pr_metrics['recall']:.4f}")
    print(f"F1 Score: {pr_metrics['f1_score']:.4f}")

    print("\n=== Unique Combinations Analysis ===")
    print(f"Population unique combinations: {unique['population']:,}")
    print(f"Generated unique combinations: {unique['generated']:,}")

    print("\n=== Matching Combinations Analysis ===")
    # NOTE(review): the denominator here is all generated rows, including rows
    # with NA values that calculate_precision_recall excludes, so this can
    # differ from the precision figure when the generated data contains NAs.
    print(f"Precision) Total rows in generated data: {len(generated_df):,}")
    print(f"Precision) Rows matching the population: {matching['total_count']:,}")
    print(f"Precision) Percentage of rows matching: {_pct(matching['total_count'], len(generated_df)):.2f}%")
    # NOTE(review): these lines divide by the *generated* unique count, i.e. a
    # unique-profile precision, despite the "Recall" label.
    print(f"Recall   ) Total unique combinations in generated data: {unique['generated']:,}")
    print(f"Recall   ) Unique combinations matching the population: {matching['unique_types']:,}")
    print(f"Recall   ) Percentage of unique combinations matching: {_pct(matching['unique_types'], unique['generated']):.2f}%")

    return {
        'srmse_marginal': srmse_results[0],
        'srmse_bivariate': srmse_results[1],
        'precision': pr_metrics['precision'],
        'recall': pr_metrics['recall'],
        'f1_score': pr_metrics['f1_score'],
        'unique_combinations': unique,
        'matching_combinations': matching,
    }



Loading data from h_population.csv and generated_synthetic_data_WGAN.csv...

=== Data Summary ===
Population data: 1,066,319 rows, 13 columns
Generated data: 1,066,319 rows, 13 columns

Calculating SRMSE...
SRMSE for marginal distributions: 0.0319
SRMSE for bivariate distributions: 0.0944

Calculating precision and recall...

=== Evaluation Metrics ===
Precision: 0.8139
Recall: 0.8079
F1 Score: 0.8109

=== Unique Combinations Analysis ===
Population unique combinations: 264,005
Generated unique combinations: 263,925

=== Matching Combinations Analysis ===
Precision) Total rows in generated data: 1,066,319
Precision) Rows matching the population: 867,849
Precision) Percentage of rows matching: 81.39%
Recall   ) Total unique combinations in generated data: 263,925
Recall   ) Unique combinations matching the population: 111,562
Recall   ) Percentage of unique combinations matching: 42.27%


In [None]:
subset_csv = "generated_synthetic_data_WGAN.csv"
population_csv = "h_population.csv"
# Evaluate the WGAN-generated population against the original population.
metrics = evaluate_synthetic_population(
    population_csv=population_csv,
    generated_csv=subset_csv,
)


Loading data from h_population.csv and generated_synthetic_data_WGAN.csv...

=== Data Summary ===
Population data: 1,066,319 rows, 13 columns
Generated data: 1,066,319 rows, 13 columns

Calculating SRMSE...
SRMSE for marginal distributions: 0.0319
SRMSE for bivariate distributions: 0.0944

Calculating precision and recall...

=== Evaluation Metrics ===
Precision: 0.8139
Recall: 0.8079
F1 Score: 0.8109

=== Unique Combinations Analysis ===
Population unique combinations: 264,005
Generated unique combinations: 263,925

=== Matching Combinations Analysis ===
Precision) Total rows in generated data: 1,066,319
Precision) Rows matching the population: 867,849
Precision) Percentage of rows matching: 81.39%
Recall   ) Total unique combinations in generated data: 263,925
Recall   ) Unique combinations matching the population: 111,562
Recall   ) Percentage of unique combinations matching: 42.27%


In [None]:
subset_csv = "generated_synthetic_data_VAE.csv"
population_csv = "h_population.csv"
# Evaluate the VAE-generated population against the original population.
metrics = evaluate_synthetic_population(
    population_csv=population_csv,
    generated_csv=subset_csv,
)

Loading data from h_population.csv and generated_synthetic_data_VAE.csv...

=== Data Summary ===
Population data: 1,066,319 rows, 13 columns
Generated data: 1,066,319 rows, 13 columns

Calculating SRMSE...
SRMSE for marginal distributions: 0.0366
SRMSE for bivariate distributions: 0.0891

Calculating precision and recall...

=== Evaluation Metrics ===
Precision: 0.7359
Recall: 0.8147
F1 Score: 0.7733

=== Unique Combinations Analysis ===
Population unique combinations: 264,005
Generated unique combinations: 331,747

=== Matching Combinations Analysis ===
Precision) Total rows in generated data: 1,066,319
Precision) Rows matching the population: 784,745
Precision) Percentage of rows matching: 73.59%
Recall   ) Total unique combinations in generated data: 331,747
Recall   ) Unique combinations matching the population: 114,126
Recall   ) Percentage of unique combinations matching: 34.40%
