# Generating Synthetic Medical Data with GANs
*A Deep Learning Approach for Privacy-Preserving Healthcare Analytics*

## Introduction
This notebook contains the end-to-end implementation for my dissertation, which explores the use of Generative Adversarial Networks (GANs) to generate synthetic medical datasets.  
The aim is to create data that retains the statistical properties and predictive utility of real patient records, while ensuring privacy and reducing the risk of re-identification.


# Installing Dependencies

In [None]:
!pip install -r requirements.txt

## Data Preparation: Merging and Cleaning the UCI Diabetes Dataset

The original UCI Diabetes 130-US Hospitals dataset is split into two files:  
1. **diabetic_data.csv** – Contains patient-level diabetic and hospitalization records.  
2. **IDS_mapping.csv** – Contains mapping tables that convert coded identifiers (e.g., `admission_type_id`) into descriptive categories.


In this section, we:
- Merge both CSV files.
- Handle missing values and remove irrelevant columns.
- Encode categorical features into numerical form for downstream GAN training.
- Save the cleaned dataset for further processing.
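
The encoding step can be illustrated on a toy column. This is a minimal sketch, assuming a made-up `race` column with `"?"` as the missing-value marker; it shows that `cat.codes` assigns `-1` to missing entries, which is why those codes need to be turned back into NaN before imputation.

```python
import numpy as np
import pandas as pd

# Toy frame standing in for the merged data (values are illustrative only);
# "?" marks a missing value, as in the raw UCI file.
toy = pd.DataFrame({"race": ["Caucasian", "?", "Asian", "Caucasian"]})
toy = toy.replace("?", np.nan)

# Categories are sorted alphabetically, so Asian -> 0 and Caucasian -> 1.
codes = toy["race"].astype("category").cat.codes

# cat.codes marks missing entries with -1; convert those back to NaN.
encoded = codes.astype("float64").where(codes >= 0)

print(codes.tolist())    # [1, -1, 0, 1]
print(encoded.tolist())  # [1.0, nan, 0.0, 1.0]
```

Restoring NaN per encoded column, rather than across the whole frame, avoids accidentally blanking a genuine `-1` in a numeric column.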

In [None]:
import pandas as pd
import numpy as np
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Loading the csv files
diabetic_df = pd.read_csv('Dataset/rawData/diabetic_data.csv')
ids_mapping_df = pd.read_csv('Dataset/rawData/IDS_mapping.csv')

# Removing irrelevant or high missing value columns
columns_to_drop = ['weight', 'payer_code', 'medical_specialty']
diabetic_df.drop(columns=columns_to_drop, inplace=True)

# Standardize missing values
diabetic_df.replace("?", pd.NA, inplace=True)

# Ensure that 'admission_type_id' is stored as string in both datasets to avoid merge errors.
diabetic_df['admission_type_id'] = diabetic_df['admission_type_id'].astype(str)
ids_mapping_df['admission_type_id'] = ids_mapping_df['admission_type_id'].astype(str)

# Prevents row duplication during merge by keeping only the first occurrence.
ids_mapping_df = ids_mapping_df.drop_duplicates(subset='admission_type_id', keep='first')

#Merge to bring in descriptive admission type labels alongside patient records.
merged_df = diabetic_df.merge(ids_mapping_df, on='admission_type_id', how='left')

# Convert categorical columns to numerical codes.
# cat.codes marks missing values with -1, so restore those as NaN right away;
# doing this per encoded column avoids touching genuine -1 values elsewhere.
for column in merged_df.columns:
    if merged_df[column].dtype == 'object' or merged_df[column].dtype.name == 'category':
        codes = merged_df[column].astype('category').cat.codes
        merged_df[column] = codes.astype('float64').where(codes >= 0)

#Remove duplicate rows
merged_df.drop_duplicates(inplace=True)

#Save the cleaned dataset
cleaned_output_file_path = 'Dataset/MergedDataset/diabetic_data_cleaned_for_Imputation.csv'
merged_df.to_csv(cleaned_output_file_path, index=False)

logging.info(f"Cleaned dataset saved to: {cleaned_output_file_path}")

 Cleaned dataset saved to: Dataset/MergedDataset/diabetic_data_cleaned_for_Imputation.csv


# Imputation Phase #

## KNN Imputation

To ensure our GAN models receive complete datasets, we perform **K-Nearest Neighbors (KNN) imputation** on the cleaned UCI Diabetes dataset.  
For each row with a missing value, KNN imputation finds the `k` most similar rows (neighbors) that have that feature observed and fills the gap with the average of their values.


### Key Steps:
1. Load the cleaned dataset.
2. Preserve certain identifier columns that should not be imputed.
3. Apply KNN imputation to the remaining features.
4. Round encoded categorical variables to restore valid category integers.
5. Merge imputed data with preserved columns and save.
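
As a small illustration of the imputation step, the sketch below runs `KNNImputer` on a toy 4x3 array (the values are made up and unrelated to the diabetes data). Each missing entry is replaced by the mean of that feature over the two nearest rows in which it is observed.

```python
import numpy as np
from sklearn.impute import KNNImputer

# Toy matrix with two missing entries (illustrative values only).
X = np.array([
    [1.0, 2.0, np.nan],
    [3.0, 4.0, 3.0],
    [np.nan, 6.0, 5.0],
    [8.0, 8.0, 7.0],
])

# Distances are computed on the features both rows have observed.
imputer = KNNImputer(n_neighbors=2)
X_filled = imputer.fit_transform(X)

print(X_filled)
# Row 0, column 2 becomes 4.0 (mean of 3.0 and 5.0);
# row 2, column 0 becomes 5.5 (mean of 3.0 and 8.0).
```

Because the imputed value is an average, it can fall between valid category codes (5.5 here), which is why encoded categorical columns are rounded afterwards.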

In [None]:
import pandas as pd
from sklearn.impute import KNNImputer
import logging


# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

#Loading the cleaned dataset
dataSet_file_path = "Dataset/MergedDataset/diabetic_data_cleaned_for_Imputation.csv"
df = pd.read_csv(dataSet_file_path)

#Preserving identifier columns
preserve_cols = ["encounter_id", "patient_nbr", "examide", "citoglipton"]
preserved_df = df[preserve_cols]
impute_df = df.drop(columns=preserve_cols)

#Apply KNN Imputation
knn_imputer = KNNImputer(n_neighbors=5)
knn_imputed = knn_imputer.fit_transform(impute_df)

# Convert imputed NumPy array back to DataFrame
knn_imputed_df = pd.DataFrame(knn_imputed, columns=impute_df.columns)


#Round categorical columns
categorical_columns_to_round = [
    'race', 'gender', 'max_glu_serum', 'A1Cresult', 'readmitted',
    'change', 'diabetesMed', 'glipizide', 'glimepiride', 'chlorpropamide',
    'repaglinide', 'metformin', 'pioglitazone', 'acarbose', 'miglitol',
    'glyburide', 'insulin', 'glyburide-metformin', 'rosiglitazone', 'nateglinide',
    'tolazamide', 'tolbutamide', 'acetohexamide', 'troglitazone',
    'metformin-rosiglitazone', 'metformin-pioglitazone', 'glipizide-metformin',
    'glimepiride-pioglitazone', 'examide', 'citoglipton'
]

for column in categorical_columns_to_round:
    if column in knn_imputed_df.columns:
        knn_imputed_df[column] = knn_imputed_df[column].round(0).astype(int)

#Merge preserved and imputed data
final_knn_imp_df = pd.concat([preserved_df.reset_index(drop=True), knn_imputed_df], axis=1)

#save the imputed dataset
imp_file_path = "ImputedData/diabetic_data_KNN_imputed.csv"
final_knn_imp_df.to_csv(imp_file_path, index=False)

logging.info(f"KNN imputed dataset saved to: {imp_file_path}")


KNN imputed dataset saved to: ImputedData/diabetic_data_KNN_imputed.csv


## Handling Missing Values: Mean Imputation

In this step, we address missing values using **Mean Imputation**.  
This approach replaces each missing value in a numeric column with the **average value** of that column. It preserves each column's mean and ensures no missing entries remain, although it reduces the column's variance.  

**Process Overview:**
1. Load the cleaned dataset.
2. Preserve identifier columns that must remain unchanged.
3. Apply mean imputation to the remaining features.
4. Round categorical encoded features back to valid integer categories.
5. Save the fully imputed dataset for downstream GAN training.
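
The effect of mean imputation on a single column can be checked with a toy example. This is a minimal sketch on made-up values: the column mean is unchanged after imputation, while the variance shrinks because the filled entry sits exactly at the mean.

```python
import numpy as np
from sklearn.impute import SimpleImputer

# One toy column with a single missing value (illustrative values only).
col = np.array([[2.0], [4.0], [np.nan], [6.0]])
filled = SimpleImputer(strategy="mean").fit_transform(col)

mean_before = np.nanmean(col)   # mean of the observed values
mean_after = filled.mean()
var_before = np.nanvar(col)     # variance of the observed values
var_after = filled.var()

print(filled.ravel())           # [2. 4. 4. 6.]
print(mean_before, mean_after)  # 4.0 4.0
print(var_before > var_after)   # True
```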


In [None]:
import pandas as pd 
from sklearn.impute import SimpleImputer
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Loading the cleaned dataset
dataSet_file_path = "Dataset/MergedDataset/diabetic_data_cleaned_for_Imputation.csv"
df = pd.read_csv(dataSet_file_path)

# Preserving identifier columns
preserve_cols = ["encounter_id", "patient_nbr", "examide", "citoglipton"]
preserved_df = df[preserve_cols]
impute_df = df.drop(columns=preserve_cols)

# Performing mean imputation on remaining features
mean_imputer = SimpleImputer(strategy='mean')
mean_imputed = mean_imputer.fit_transform(impute_df)
mean_imputed_df = pd.DataFrame(mean_imputed, columns=impute_df.columns)

# Rounding categorical columns back to integers
categorical_columns_to_round = [
    'race', 'gender', 'max_glu_serum', 'A1Cresult', 'readmitted',
    'change', 'diabetesMed', 'glipizide', 'glimepiride', 'chlorpropamide',
    'repaglinide', 'metformin', 'pioglitazone', 'acarbose', 'miglitol',
    'glyburide', 'insulin', 'glyburide-metformin', 'rosiglitazone', 'nateglinide',
    'tolazamide', 'tolbutamide', 'acetohexamide', 'troglitazone',
    'metformin-rosiglitazone', 'metformin-pioglitazone', 'glipizide-metformin',
    'glimepiride-pioglitazone', 'examide', 'citoglipton'
]

for column in categorical_columns_to_round:
    if column in mean_imputed_df.columns:
        mean_imputed_df[column] = mean_imputed_df[column].round(0).astype(int)

# Combining the preserved identifier columns with the imputed dataset
final_mean_imp_df = pd.concat([preserved_df.reset_index(drop=True), mean_imputed_df], axis=1)

# Saving the final mean-imputed dataset
imp_file_path = "ImputedData/diabetic_data_MEAN_imputed.csv"
final_mean_imp_df.to_csv(imp_file_path, index=False)

logging.info(f"Mean imputed dataset saved to: {imp_file_path}")

Mean imputed dataset saved to: ImputedData/diabetic_data_MEAN_imputed.csv


## GAIN Imputation

In this step, we apply **Generative Adversarial Imputation Networks (GAIN)**, a GAN-based deep learning method that fills in missing values by learning the underlying data distribution.  
Unlike simpler imputation methods (mean, KNN), GAIN uses an adversarial training process between a **Generator**, which imputes missing data, and a **Discriminator**, which tries to distinguish observed values from imputed ones.

This approach can produce more realistic and statistically consistent imputations.

**Process Overview:**
1. Load the cleaned dataset.
2. Normalize data while retaining missing value masks.
3. Define GAIN’s generator and discriminator networks.
4. Train GAIN using adversarial learning.
5. Reverse normalization and round categorical values.
6. Save the final imputed dataset.
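
Before the full implementation, the masking logic that GAIN relies on can be sketched in plain NumPy. This is an illustration under stated assumptions, not the training code: the toy array, the random stand-in for the generator output, and the variable names are all hypothetical. It shows how the mask, the noise-filled generator input, the hint matrix, and the combined output relate to each other.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy normalized data with missing entries (illustrative values only).
data = np.array([
    [0.2, np.nan, 0.7],
    [np.nan, 0.5, 0.1],
])

# Mask: 1 where a value is observed, 0 where it is missing.
mask = 1 - np.isnan(data).astype(int)

# Generator input: observed values kept, missing slots filled with small noise.
noise = rng.uniform(0, 0.01, size=data.shape)
generator_input = mask * np.nan_to_num(data, nan=0.0) + (1 - mask) * noise

# Hint: reveals a random subset of the observed positions to the discriminator.
hint_rate = 0.5
hint = mask * rng.binomial(1, hint_rate, size=data.shape)

# Stand-in for the generator output (a trained network would produce this).
generated = rng.uniform(0, 1, size=data.shape)

# Combined data: observed values are kept, only missing slots are replaced.
combined = mask * generator_input + (1 - mask) * generated

print(mask)
print(combined)
```

The key property is that the mask must be built from data that still contains NaN; if missing values are filled before this point, every entry looks observed and nothing is imputed.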

In [None]:
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
from sklearn.preprocessing import MinMaxScaler
import logging

# Disable TensorFlow v2 behavior for compatibility
tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.ERROR)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


#Load and Prepare the Dataset

dataset_path = "Dataset/MergedDataset/diabetic_data_cleaned_for_Imputation.csv"
diabetes_df = pd.read_csv(dataset_path)
logging.info(f"Dataset loaded: {diabetes_df.shape} rows and columns")

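# Normalize dataset and create missing value mask
# MinMaxScaler ignores NaNs when fitting and keeps them when transforming, so
# missing entries stay NaN and gain() can still detect them. Filling them with 0
# here would make the mask inside gain() all ones and skip imputation entirely.
scaler = MinMaxScaler()
dataset_array = diabetes_df.to_numpy(dtype=float)
normalized_data = scaler.fit_transform(dataset_array)
missing_value_mask = 1 - np.isnan(dataset_array).astype(int)
logging.info("Data normalized and missing value mask created.")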


#Define GAIN Generator and Discriminator
#Generator for imputing missing values.
def generator(input_data, mask, feature_dim):
    inputs = tf.concat([input_data, mask], axis=1)
    hidden_layer1 = tf.keras.layers.Dense(units=feature_dim, activation=tf.nn.relu)(inputs)
    hidden_layer2 = tf.keras.layers.Dense(units=feature_dim, activation=tf.nn.relu)(hidden_layer1)
    generated_output = tf.keras.layers.Dense(units=feature_dim, activation=tf.nn.sigmoid)(hidden_layer2)
    return generated_output

#Discriminator to distinguish observed vs. imputed values.
def discriminator(imputed_data, hint, feature_dim):
    inputs = tf.concat([imputed_data, hint], axis=1)
    hidden_layer1 = tf.keras.layers.Dense(units=feature_dim, activation=tf.nn.relu)(inputs)
    hidden_layer2 = tf.keras.layers.Dense(units=feature_dim, activation=tf.nn.relu)(hidden_layer1)
    logits = tf.keras.layers.Dense(units=feature_dim)(hidden_layer2)
    prob_output = tf.nn.sigmoid(logits)
    return prob_output


#GAIN training Function
def gain(imperfect_data, gain_params):
    batch_size = gain_params['batch_size']
    hint_rate = gain_params['hint_rate']
    alpha = gain_params['alpha']
    max_iterations = gain_params['iterations']

    num_samples, num_features = imperfect_data.shape
    missing_mask = 1 - np.isnan(imperfect_data).astype(int)

    # Placeholders
    X_placeholder = tf.placeholder(tf.float32, shape=[None, num_features])
    M_placeholder = tf.placeholder(tf.float32, shape=[None, num_features])
    H_placeholder = tf.placeholder(tf.float32, shape=[None, num_features])

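    # Model architecture
    # Count trainable variables before and after building the generator so that
    # each optimizer can be restricted to the variables of its own network.
    num_vars_before = len(tf.trainable_variables())
    generated_samples = generator(X_placeholder, M_placeholder, num_features)
    num_vars_after_generator = len(tf.trainable_variables())
    combined_data = X_placeholder * M_placeholder + generated_samples * (1 - M_placeholder)
    discriminator_probs = discriminator(combined_data, H_placeholder, num_features)

    trainable_vars = tf.trainable_variables()
    generator_vars = trainable_vars[num_vars_before:num_vars_after_generator]
    discriminator_vars = trainable_vars[num_vars_after_generator:]

    # Loss Functions
    discriminator_loss = -tf.reduce_mean(
        M_placeholder * tf.math.log(discriminator_probs + 1e-8) +
        (1 - M_placeholder) * tf.math.log(1. - discriminator_probs + 1e-8)
    )
    generator_loss_unsupervised = -tf.reduce_mean((1 - M_placeholder) * tf.math.log(discriminator_probs + 1e-8))
    mse_loss = tf.reduce_mean((M_placeholder * X_placeholder - M_placeholder * generated_samples) ** 2) / tf.reduce_mean(M_placeholder)
    generator_loss = generator_loss_unsupervised + alpha * mse_loss

    # Optimizers: each one updates only the network it is training.
    # Without var_list, both optimizers would update generator and discriminator weights.
    discriminator_optimizer = tf.train.AdamOptimizer().minimize(discriminator_loss, var_list=discriminator_vars)
    generator_optimizer = tf.train.AdamOptimizer().minimize(generator_loss, var_list=generator_vars)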

    # TensorFlow session
    session = tf.Session()
    session.run(tf.global_variables_initializer())
    logging.info("Starting GAIN training...")

    for iteration in range(max_iterations):
        batch_indices = np.random.choice(num_samples, batch_size, replace=False)
        batch_data = imperfect_data[batch_indices, :]
        batch_mask = missing_mask[batch_indices, :]

        noise = np.random.uniform(0, 0.01, size=[batch_size, num_features])
        batch_data = np.nan_to_num(batch_data, nan=0.0)
        batch_data = batch_mask * batch_data + (1 - batch_mask) * noise

        hint_matrix_temp = np.random.binomial(1, hint_rate, size=[batch_size, num_features])
        hint_matrix = batch_mask * hint_matrix_temp

        session.run(discriminator_optimizer, feed_dict={
            X_placeholder: batch_data, M_placeholder: batch_mask, H_placeholder: hint_matrix
        })
        session.run(generator_optimizer, feed_dict={
            X_placeholder: batch_data, M_placeholder: batch_mask, H_placeholder: hint_matrix
        })

        if iteration % 1000 == 0:
            logging.info(f"Iteration {iteration}/{max_iterations} completed.")

    logging.info("GAIN training completed.")

    # Final Imputation pass
    noise_full = np.random.uniform(0, 0.01, size=[num_samples, num_features])
    filled_data = np.nan_to_num(imperfect_data, nan=0.0)
    filled_data = missing_mask * filled_data + (1 - missing_mask) * noise_full

    imputed_result = session.run(generated_samples, feed_dict={
        X_placeholder: filled_data, M_placeholder: missing_mask
    })
    completed_data = imperfect_data.copy()
    completed_data[np.isnan(imperfect_data)] = imputed_result[np.isnan(imperfect_data)]

    session.close()
    return completed_data


# Run GAIN imputation
gain_hyperparameters = {
    'batch_size': 128,
    'hint_rate': 0.5,
    'alpha': 10,
    'iterations': 20000
}
logging.info("Running GAIN imputation...")
gain_imputed_data = gain(normalized_data.copy(), gain_hyperparameters)

# Reverse normalization
gain_imputed_df = pd.DataFrame(scaler.inverse_transform(gain_imputed_data), columns=diabetes_df.columns)

# Round categorical columns back to integers
categorical_columns = [
    'race', 'gender', 'max_glu_serum', 'A1Cresult', 'readmitted',
    'change', 'diabetesMed', 'glipizide', 'glimepiride', 'chlorpropamide',
    'repaglinide', 'metformin', 'pioglitazone', 'acarbose', 'miglitol',
    'glyburide', 'insulin', 'glyburide-metformin', 'rosiglitazone', 'nateglinide',
    'tolazamide', 'tolbutamide', 'acetohexamide', 'troglitazone',
    'metformin-rosiglitazone', 'metformin-pioglitazone', 'glipizide-metformin',
    'glimepiride-pioglitazone', 'examide', 'citoglipton'
]
for col in categorical_columns:
    if col in gain_imputed_df.columns:
        gain_imputed_df[col] = gain_imputed_df[col].round().astype(int)

# Save the final imputed dataset
output_path = "ImputedData/diabetic_data_GAIN_imputed.csv"
gain_imputed_df.to_csv(output_path, index=False)
logging.info(f"GAIN-imputed dataset saved at: {output_path}")


## Generating Synthetic Medical Data with CTGAN

In this section, we use the **CTGAN (Conditional Tabular GAN)** model from the SDV (Synthetic Data Vault) library to generate synthetic patient records.  
CTGAN is particularly effective for handling **mixed-type tabular data** with both continuous and categorical features, making it suitable for our medical dataset.  

We will:
1. Load the imputed dataset (GAIN, KNN, or Mean).
2. Define the metadata and specify discrete (categorical) columns.
3. Train the CTGAN model on the dataset.
4. Generate synthetic records in batches.
5. Save the synthetic dataset for downstream evaluation.
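
The batched generation in step 4 comes down to simple chunk arithmetic. The sketch below uses a hypothetical helper, `batch_sizes`, to show how many rows each sampling call would request so that the batches add up to the size of the training data.

```python
def batch_sizes(total_rows, rows_per_batch):
    """Return the number of rows to request in each sampling call."""
    return [
        min(rows_per_batch, total_rows - start)
        for start in range(0, total_rows, rows_per_batch)
    ]

# Example with made-up sizes: the last batch holds the remainder.
sizes = batch_sizes(25_000, 10_000)
print(sizes)       # [10000, 10000, 5000]
print(sum(sizes))  # 25000
```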


In [None]:
import pandas as pd
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer
from tqdm import tqdm
import time
import torch
import numpy as np
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,  # Change to DEBUG for more detailed logs
    format="%(asctime)s - %(levelname)s - %(message)s"
)

def generate_synthetic_dataset(imputed_file_path, imputation_method):
    start_time = time.time()

    # Report GPU availability
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
        logging.info(f"CUDA enabled device: {device_name}")
    else:
        logging.warning("CUDA is not available. Falling back to CPU.")
    
    # Load imputed dataset
    logging.info(f"Loading {imputation_method} imputed dataset")
    imputed_df = pd.read_csv(imputed_file_path)

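    # Separate the preserved columns from the features used for training.
    # Note: despite its name, preserved_df holds the training features.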
    preserve_cols = ["encounter_id", "patient_nbr", "examide", "citoglipton"]
    preserved_df = imputed_df.drop(columns=preserve_cols, errors='ignore')
    preserved_identifiers_df = imputed_df[preserve_cols].reset_index(drop=True)

    # Detect metadata
    table_metadata = SingleTableMetadata()
    table_metadata.detect_from_dataframe(preserved_df)

    # Mark categorical columns
    categorical_columns = [
        'admission_type_id', 'discharge_disposition_id', 'admission_source_id',
        'time_in_hospital', 'num_lab_procedures', 'num_procedures',
        'num_medications', 'number_outpatient', 'number_emergency',
        'number_inpatient', 'number_diagnoses', 'max_glu_serum', 'A1Cresult',
        'metformin', 'repaglinide', 'nateglinide', 'chlorpropamide',
        'glimepiride', 'acetohexamide', 'glipizide', 'glyburide',
        'tolbutamide', 'pioglitazone', 'rosiglitazone', 'acarbose',
        'miglitol', 'troglitazone', 'tolazamide', 'insulin',
        'glyburide-metformin', 'glipizide-metformin', 'glimepiride-pioglitazone',
        'metformin-rosiglitazone', 'metformin-pioglitazone',
        'change', 'diabetesMed', 'readmitted', 'description'
    ]
    for column in categorical_columns:
        if column in preserved_df.columns:
            table_metadata.update_column(column_name=column, sdtype='categorical')

    # Initialize CTGAN synthesizer
    synthesizer = CTGANSynthesizer(
        table_metadata,
        epochs=300,
        batch_size=500,
        verbose=True,
        log_frequency=True,  # boolean flag (log-frequency sampling of categories), not an interval
    )

    # Train CTGAN model
    logging.info(f"Training CTGAN on {imputation_method} imputed dataset")
    synthesizer.fit(preserved_df)

    # Generate synthetic data in batches
    logging.info("Generating synthetic dataset")
    rows_per_batch = 10000
    total_rows = len(preserved_df)
    synthetic_data_batches = []

    for batch_start in tqdm(range(0, total_rows, rows_per_batch), desc="Generating Rows"):
        batch_count = min(rows_per_batch, total_rows - batch_start)
        synthetic_data_batches.append(synthesizer.sample(num_rows=batch_count))

    synthetic_data_df = pd.concat(synthetic_data_batches).reset_index(drop=True)

    # Clean categorical columns
    for column in categorical_columns:
        if column in synthetic_data_df.columns:
            synthetic_data_df[column] = (
                synthetic_data_df[column]
                .replace([np.inf, -np.inf], np.nan)
                .fillna(-1)
                .round()
                .astype(int)
            )

    # Merge back identifiers
    final_synthetic_df = pd.concat([preserved_identifiers_df, synthetic_data_df], axis=1)

    # Save synthetic dataset
    output_file_path = f"SyntheticData/synthetic_diabetic_using_{imputation_method}_imp.csv"
    final_synthetic_df.to_csv(output_file_path, index=False)

    elapsed_time = round(time.time() - start_time, 2)
    logging.info(f"Synthetic data saved to: {output_file_path}")
    logging.info(f"Time taken: {elapsed_time} seconds")


# Generate synthetic datasets
generate_synthetic_dataset("ImputedData/diabetic_data_KNN_imputed.csv", "KNN")
generate_synthetic_dataset("ImputedData/diabetic_data_MEAN_imputed.csv", "MEAN")
generate_synthetic_dataset("ImputedData/diabetic_data_GAIN_imputed.csv", "GAIN")


# Evaluation Pipeline for Real vs Synthetic Datasets

This cell compares the real imputed datasets with the synthetic datasets generated
from each imputation method (GAIN, KNN, MEAN).

It evaluates them on:
- Statistical similarity: the average two-sample Kolmogorov-Smirnov (KS) statistic across numeric columns.
- Predictive utility: TSTR (Train on Synthetic, Test on Real) accuracy and macro F1 for the `readmitted` target.

Results are saved as CSVs and bar plots for further analysis.
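
To make the two kinds of metric concrete, here is a minimal sketch on toy data. The arrays, the `make_toy` helper, and the labeling rule are assumptions for illustration only. The KS statistic is 0 for identical samples and 1 for samples that do not overlap, and TSTR trains a classifier on one dataset (standing in for synthetic) and scores it on another (standing in for real).

```python
import numpy as np
from scipy.stats import ks_2samp
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# KS statistic: largest gap between the two empirical CDFs.
a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
b = np.array([6.0, 7.0, 8.0, 9.0, 10.0])
ks_same = ks_2samp(a, a).statistic      # identical samples -> 0.0
ks_disjoint = ks_2samp(a, b).statistic  # non-overlapping samples -> 1.0
print(ks_same, ks_disjoint)

# TSTR: both toy datasets follow the same hypothetical rule, y = 1 if x0 > 0.
def make_toy(seed, n=500):
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, 3))
    y = (X[:, 0] > 0).astype(int)
    return X, y

X_synth, y_synth = make_toy(seed=1)  # plays the role of synthetic data
X_real, y_real = make_toy(seed=2)    # plays the role of real data

model = RandomForestClassifier(n_estimators=50, random_state=42)
model.fit(X_synth, y_synth)
tstr_accuracy = accuracy_score(y_real, model.predict(X_real))
print(f"TSTR accuracy: {tstr_accuracy:.3f}")
```

A lower average KS statistic indicates closer marginal distributions, while a higher TSTR score indicates that the synthetic data carries over the predictive signal of the real data.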

In [None]:
import os
import pandas as pd
import numpy as np
import random
import seaborn as sns
import matplotlib.pyplot as plt
import logging
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from scipy.stats import ks_2samp

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

REAL_DATA_DIR = "ImputedData"
SYNTHETIC_DATA_DIR = "SyntheticData"
RESULTS_DIR = "results"
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")

os.makedirs(PLOTS_DIR, exist_ok=True)

TARGET_COLUMN = "readmitted"

# File mapping
file_pairs = {
    "GAIN": ("diabetic_data_GAIN_imputed.csv", "synthetic_diabetic_using_GAIN_imp.csv"),
    "KNN": ("diabetic_data_KNN_imputed.csv", "synthetic_diabetic_using_KNN_imp.csv"),
    "MEAN": ("diabetic_data_MEAN_imputed.csv", "synthetic_diabetic_using_MEAN_imp.csv")
}

# Store results
ks_results = []
tstr_results = []

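# KS test function
def ks_test(real, synthetic, exclude=("encounter_id", "patient_nbr", "examide", "citoglipton")):
    # Columns in `exclude` are copied verbatim from the real data into the
    # synthetic file, so their KS statistic is trivially 0 and would pull
    # the average down.
    numeric_cols = [
        col for col in real.select_dtypes(include=['number']).columns
        if col not in exclude and col in synthetic.columns
    ]
    stats = {}
    for col in numeric_cols:
        if real[col].isnull().all() or synthetic[col].isnull().all():
            continue
        ks_stat, _ = ks_2samp(real[col].dropna(), synthetic[col].dropna())
        stats[col] = ks_stat
    valid_ks = list(stats.values())
    avg_ks = sum(valid_ks) / len(valid_ks) if valid_ks else None
    return avg_ks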

# TSTR (Train on Synthetic, Test on Real) function
def tstr(real_df, synth_df, target):
    X_train = synth_df.drop(columns=[target])
    y_train = synth_df[target]
    X_test = real_df.drop(columns=[target])
    y_test = real_df[target]

    X_train = X_train[X_test.columns]  

    model = RandomForestClassifier(random_state=SEED)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    return {
        "Accuracy_TSTR": accuracy_score(y_test, y_pred),
        "F1_TSTR": f1_score(y_test, y_pred, average='macro')
    }


#Main Loop
for method, (real_file, synth_file) in file_pairs.items():
    logging.info(f"Processing {method}...")

    real_data = pd.read_csv(os.path.join(REAL_DATA_DIR, real_file))
    synthetic_data = pd.read_csv(os.path.join(SYNTHETIC_DATA_DIR, synth_file))

    # KS Test
    avg_ks = ks_test(real_data, synthetic_data)
    ks_results.append({"Imputation": method, "Avg_KS": avg_ks})

    # TSTR
    tstr_scores = tstr(real_data, synthetic_data, TARGET_COLUMN)
    tstr_scores["Imputation"] = method
    tstr_results.append(tstr_scores)

# Saving Results
ks_df = pd.DataFrame(ks_results)
tstr_df = pd.DataFrame(tstr_results)
ks_df.to_csv(os.path.join(RESULTS_DIR, "ks_scores.csv"), index=False)
tstr_df.to_csv(os.path.join(RESULTS_DIR, "tstr_scores.csv"), index=False)

logging.info("KS scores and TSTR scores saved to CSV.")

# KS Score Visualization
plt.figure(figsize=(6, 4))
sns.barplot(data=ks_df, x="Imputation", y="Avg_KS", palette="Blues_d")
plt.ylabel("Average KS Score")
plt.title("KS Score by Imputation")
plt.ylim(0, max(ks_df["Avg_KS"]) + 0.02)
plt.tight_layout()
plt.savefig(os.path.join(PLOTS_DIR, "ks_score_bar.png"))
plt.close()
logging.info("KS score bar plot saved.")

# TSTR Scores Visualization
tstr_melt = tstr_df.melt(id_vars="Imputation", value_vars=["Accuracy_TSTR", "F1_TSTR"])
plt.figure(figsize=(6, 4))
sns.barplot(data=tstr_melt, x="Imputation", y="value", hue="variable", palette="Set2")
plt.ylabel("Score")
plt.title("TSTR Evaluation Metrics by Imputation")
plt.ylim(0, max(tstr_melt["value"]) + 0.05)
plt.legend(title="Metric")
plt.tight_layout()
plt.savefig(os.path.join(PLOTS_DIR, "tstr_metrics_bar.png"))
plt.close()
logging.info("TSTR metrics bar plot saved.")

# Combined Visualization
combined_df = pd.merge(tstr_df, ks_df, on="Imputation")
combined_melt = combined_df.melt(id_vars="Imputation", var_name="Metric", value_name="Score")

plt.figure(figsize=(8, 5))
sns.barplot(data=combined_melt, x="Imputation", y="Score", hue="Metric", palette="Set2")
plt.ylabel("Score")
plt.title("TSTR Metrics & KS Score by Imputation", fontsize=14)
plt.legend(title="Metric")
plt.tight_layout()
plt.savefig(os.path.join(PLOTS_DIR, "tstr_ks_bar.png"), dpi=300)
plt.close()
logging.info("Combined TSTR and KS metrics plot saved.")

logging.info("Evaluation complete: TSTR and KS results saved, combined plot generated.")