# Imports

In [1]:
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from syntheval import SynthEval
# NOTE: private SynthEval helper (leading underscore) — its location/signature may change between releases.
from syntheval.metrics.utility.metric_mutual_information import _pairwise_attributes_mutual_information

In [2]:
# When True, the heatmap cells save their figures as LZW-compressed TIFFs into OUTPUT_DIR_FIG.
SAVE_FIG = True

# Input CSV files, resolved relative to DATA_DIR by load_data().
REAL_FILE = "20250301_data_20250510_122405_final_100_train.csv"
SYNTH_FILE = "20250301_data_20250510_122405_final_100_synth.csv"
HOLDOUT_FILE = "20250301_data_20250510_122405_final_100_holdout.csv"

DATA_DIR = Path("../../data")
OUTPUT_DIR_FIG = Path("figures")
OUTPUT_DIR_FIG.mkdir(parents=True, exist_ok=True)  # create the output folder if it is missing

FONT_SIZE_PT = 7  # single size applied to every text element below

sns.set_theme(
    style="white",
    context="paper",
    palette="colorblind",
    rc={
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial"],
        **{
            key: FONT_SIZE_PT
            for key in (
                "font.size",
                "axes.titlesize",
                "axes.labelsize",
                "xtick.labelsize",
                "ytick.labelsize",
                "legend.fontsize",
            )
        },
    },
)

# Utility Functions

In [3]:
def convert_dtypes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with the expected column dtypes applied.

    The input frame itself is not modified.

    * ``gender``, ``ethnicity``, ``chief_complaint``, ``icd_block``
      -> unordered ``category``
    * ``age``, ``systolic_bp``, ``diastolic_bp``, ``heart_rate``,
      ``respiratory_rate``, ``oxygen_saturation`` -> nullable integer ``Int64``
    * ``consciousness_level`` -> ordered categorical ``A < C < V < P < U``
    * ``news_score`` -> ordered categorical ``0 < 1 < ... < 18``

    Values of the two ordered columns that are not among the listed
    categories become missing (NaN). A missing column raises ``KeyError``.
    """
    converted = df.copy()

    nominal = ('gender', 'ethnicity', 'chief_complaint', 'icd_block')
    integer = ('age', 'systolic_bp', 'diastolic_bp',
               'heart_rate', 'respiratory_rate', 'oxygen_saturation')
    ordinal = {
        'consciousness_level': ['A', 'C', 'V', 'P', 'U'],
        'news_score': list(range(19)),
    }

    for name in nominal:
        converted[name] = converted[name].astype('category')

    for name in integer:
        # 'Int64' (capital I) is pandas' nullable integer: missing values stay <NA>
        converted[name] = converted[name].astype('Int64')

    for name, levels in ordinal.items():
        converted[name] = pd.Categorical(converted[name], categories=levels, ordered=True)

    return converted

def load_data(
    real_filename: Union[str, Path],
    synth_filename: Optional[Union[str, Path]] = None,
    holdout_filename: Optional[Union[str, Path]] = None,
    data_dir: Union[str, Path] = DATA_DIR
) -> Tuple[pd.DataFrame, Optional[pd.DataFrame], Optional[pd.DataFrame]]:
    """
    Load the real, and optionally the synthetic and holdout, CSV files from
    ``data_dir`` and normalise their dtypes with ``convert_dtypes``.

    Args:
        real_filename: CSV file name of the real (training) data, relative to ``data_dir``.
        synth_filename: CSV file name of the synthetic data; skipped when falsy.
        holdout_filename: CSV file name of the holdout data; skipped when falsy.
        data_dir: Directory containing the files; a ``str`` or ``Path`` is accepted.

    Returns:
        ``(df_real, df_synth, df_holdout)``. ``df_synth`` and ``df_holdout`` are
        ``None`` when the corresponding file name was not given.
    """
    # Accept plain strings too: `str / str` would raise TypeError below.
    base_dir = Path(data_dir)

    def _read_and_convert(fn: Union[str, Path]) -> pd.DataFrame:
        # low_memory=False: pandas infers each column's dtype from the whole file
        # instead of chunk by chunk, avoiding mixed-type columns.
        return pd.read_csv(base_dir / fn, low_memory=False).pipe(convert_dtypes)

    df_real = _read_and_convert(real_filename)
    df_synth = _read_and_convert(synth_filename) if synth_filename else None
    df_holdout = _read_and_convert(holdout_filename) if holdout_filename else None

    return df_real, df_synth, df_holdout

# Main Routine

In [4]:
# Load all three splits with the dtypes from convert_dtypes applied.
df_real, df_synth, df_holdout = load_data(
    real_filename=REAL_FILE,
    synth_filename=SYNTH_FILE,
    holdout_filename=HOLDOUT_FILE,
    data_dir=DATA_DIR,
)

In [5]:
# Split column names by dtype: numeric (incl. nullable Int64) vs. everything else.
# Ordered categoricals such as news_score land in cat_attrs, which is passed to SynthEval below.
num_attrs, cat_attrs = (
    df_real.select_dtypes(include=['number']).columns,
    df_real.select_dtypes(exclude=['number']).columns,
)

## Correlation Matrix

In [None]:
# Give both frames the same alphabetical column order so downstream matrices
# list the attributes consistently. Re-running this cell is harmless.
df_real = df_real.loc[:, sorted(df_real.columns)]
df_synth = df_synth.loc[:, sorted(df_synth.columns)]

       age chief_complaint consciousness_level  diastolic_bp ethnicity gender  \
0       71           Other                   V            77     White      M   
1       91  abdominal pain                   A            47     White      M   
2       87      food bolus                   A            90     White      F   
3       87           Other                   V            80     White      F   
4       83           fever                   A            87     White      M   
...    ...             ...                 ...           ...       ...    ...   
98234   62           Other                   A            98     White      F   
98235   61           Other                   A            91     Other      M   
98236   45      flank pain                   A            68     White      F   
98237   48           Other                   A           108     White      F   
98238   65           Other                   A            56     White      M   

       heart_rate icd_block

In [23]:
# Runtime: ~5 s
ts = datetime.now().strftime("%Y%m%d_%H%M%S")  # timestamp reused by the export cells below
se_cm = SynthEval(df_real, cat_cols=cat_attrs, verbose=True)

# Use an 8 pt base font (font.size) for this figure only. Calling
# plt.rcParams.update() after the plot is drawn does not restyle text that
# already exists, and it leaks into every figure created afterwards.
with plt.rc_context({'font.size': 8}):
    result = se_cm.evaluate(
        df_synth,
        corr_diff={
            'mixed_corr': True,
            'return_mats': True,
            "axs_scale": "RdBu_r"
        }
    )

    if SAVE_FIG:
        # NOTE(review): assumes SynthEval leaves its heatmap as the current figure — confirm.
        plt.gcf().set_size_inches(5.4, 3.5)
        plt.savefig(
            OUTPUT_DIR_FIG / f"heatmap_{ts}.tiff",
            dpi=300,
            format='tiff',
            bbox_inches='tight',
            pil_kwargs={'compression': 'tiff_lzw'},
        )
        plt.close()


SynthEval: synthetic data read successfully


Syntheval: corr_diff: 100%|██████████| 1/1 [00:03<00:00,  3.25s/it]



SynthEval results

Utility metric description                    value   error                                 
+---------------------------------------------------------------+
| Mixed correlation matrix difference      :   1.3344           |
+---------------------------------------------------------------+
    


### Export correlation matrices

In [42]:
# NOTE: _raw_results is a private SynthEval attribute and may change between releases.
raw_cm = se_cm._raw_results['corr_diff']
real_corr = raw_cm['real_cor_mat']
synth_corr = raw_cm['synt_cor_mat']
diff_corr = raw_cm['diff_cor_mat']

# Semicolon separator with decimal comma. NOTE(review): index=False drops the row
# labels; rows presumably follow the column order — confirm this is intended.
csv_options = dict(index=False, header=True, sep=';', decimal=',', encoding='utf-8')
for stem, matrix in (
    ("real_corr", real_corr),
    ("synth_corr", synth_corr),
    ("diff_corr", diff_corr),
):
    matrix.to_csv(OUTPUT_DIR_FIG / f"{stem}_{ts}.csv", **csv_options)

## Mutual Information Matrix

In [41]:
# Runtime: several minutes (the last mi_diff evaluation took ~6.5 min).
se_mi = SynthEval(df_real, cat_cols=cat_attrs, verbose=True)

# 8 pt base font (font.size) scoped to this figure only. Setting plt.rcParams
# after plotting has no effect on already-drawn text and leaks into later figures.
with plt.rc_context({'font.size': 8}):
    results_mi = se_mi.evaluate(
        df_synth,
        mi_diff={
            "axs_scale": "viridis"
        }
    )

    if SAVE_FIG:
        # NOTE(review): assumes SynthEval leaves its heatmap as the current figure — confirm.
        plt.gcf().set_size_inches(5.4, 4.5)
        plt.savefig(
            OUTPUT_DIR_FIG / f"mi_heatmap_{ts}.tiff",
            dpi=300,
            format='tiff',
            bbox_inches='tight',
            pil_kwargs={'compression': 'tiff_lzw'},
        )
        plt.close()

SynthEval: synthetic data read successfully


Syntheval: mi_diff: 100%|██████████| 1/1 [06:31<00:00, 391.64s/it]



SynthEval results

Utility metric description                    value   error                                 
+---------------------------------------------------------------+
| Pairwise mutual information difference   :   0.3348           |
+---------------------------------------------------------------+
    


### Export mutual information matrices

In [None]:
# Pairwise mutual information between all attributes of the real data (private
# SynthEval helper). Can be slow: the mi_diff evaluation above took ~6.5 min.
mi_matrix_real = _pairwise_attributes_mutual_information(
    df_real
)


In [None]:
# Export the real-data MI matrix (semicolon separator, decimal comma).
mi_matrix_real.to_csv(
    OUTPUT_DIR_FIG / f"mi_matrix_real_{ts}.csv",
    index=False, header=True, sep=';', decimal=',', encoding='utf-8',
)
# Last expression -> rich (HTML) display instead of print()'s plain-text dump.
mi_matrix_real

In [39]:
# Same pairwise MI computation for the synthetic data, then export it.
mi_matrix_synth = _pairwise_attributes_mutual_information(df_synth)
mi_synth_path = OUTPUT_DIR_FIG / f"mi_matrix_synth_{ts}.csv"
mi_matrix_synth.to_csv(
    mi_synth_path,
    sep=';',
    decimal=',',
    index=False,
    header=True,
    encoding='utf-8',
)

In [40]:
# Element-wise real minus synthetic MI; DataFrame arithmetic aligns on row and column labels.
mi_diff = mi_matrix_real.sub(mi_matrix_synth)
mi_diff.to_csv(
    path_or_buf=OUTPUT_DIR_FIG / f"mi_diff_{ts}.csv",
    index=False,
    header=True,
    sep=';',
    decimal=',',
    encoding='utf-8',
)