In [None]:
import pandas as pd
import numpy as np
from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
from synthcity.metrics import Metrics
import os

In [None]:
def _read_csv_without_index(path):
    """Read a CSV file, dropping the ``'Unnamed: 0'`` index column if present.

    Files written by ``DataFrame.to_csv`` with the default ``index=True`` come
    back with that extra column; it would otherwise be passed on to the data
    loaders as if it were a feature.
    """
    df = pd.read_csv(path)
    if 'Unnamed: 0' in df.columns:
        df = df.drop('Unnamed: 0', axis=1)
    return df


def _summarize_metrics(metrics_list):
    """Combine per-dataset metric Series into one table with ``Mean``/``Std`` columns.

    Args:
        metrics_list: Non-empty list of Series, one per evaluated synthetic
            dataset (each taken from the first column of synthcity's
            ``Metrics.evaluate`` output).

    Returns:
        DataFrame with one column per dataset, followed by ``Mean`` and ``Std``
        (pandas' default sample standard deviation, rounded to 4 decimals). Both
        statistics are computed across the per-dataset columns only.
    """
    result_df = pd.concat(metrics_list, axis=1)
    # Compute both statistics BEFORE appending any summary column. Previously
    # 'Std' was computed after 'Mean' had been added, so the mean itself was
    # included in the standard deviation and biased it.
    row_mean = result_df.mean(axis=1)
    row_std = result_df.std(axis=1)
    result_df['Mean'] = row_mean
    result_df['Std'] = row_std.round(4)
    return result_df


def benchmark_limited_data_experiment(
    data_percentages=(100, 75, 50),
    methods=("ours", "unconditional", "survivalgan"),
    data_paths=None,
    original_data_path=".../aids_final.csv",
    dataset="aids",
    n_datasets=5,
    sample_seed=42,
):
    """
    Benchmark code for Table 4: Performance comparison with reduced training data on AIDS dataset.

    For every (percentage, method) pair, the function looks for the synthetic
    CSV files ``{dataset}_{method}_{i}.csv`` (i = 1..n_datasets) under
    ``data_paths[method]``. Each file is subsampled to ``percentage`` % of its
    rows and evaluated against the full original dataset with synthcity's
    ``Metrics.evaluate``. Missing files and evaluation errors are reported and
    skipped, so one bad dataset does not abort the whole benchmark.

    Every parameter defaults to the original Table 4 setup, so calling the
    function with no arguments behaves as before.

    Args:
        data_percentages: Percentages of each synthetic dataset's rows to keep.
            ``100`` uses the dataset unchanged. Any other value draws a random
            subsample (without replacement) using ``sample_seed``.
        methods: Method names. Each must be a key of ``data_paths``.
        data_paths: Mapping from method name to the folder holding its
            synthetic CSVs. Defaults to the placeholder folders below.
        original_data_path: Path to the original (real) CSV. Rows with
            ``duration == 0`` are dropped before evaluation.
        dataset: Dataset prefix used in the synthetic file names.
        n_datasets: Number of synthetic datasets per method (files 1..n).
        sample_seed: ``random_state`` used for subsampling.

    Returns:
        Nested dict ``{percentage: {method: DataFrame or None}}``. Each
        DataFrame has one column per successfully evaluated dataset, plus
        ``Mean`` and ``Std`` columns. ``None`` means no dataset could be
        evaluated for that combination.
    """
    if data_paths is None:
        # Folders containing the synthetic datasets for each method.
        # Adjust these paths according to your file structure.
        data_paths = {
            "ours": ".../ours_datasets/",
            "unconditional": ".../unconditional_datasets/",
            "survivalgan": ".../survivalgan_datasets/",
        }

    # The original dataset is loaded once; it is the ground truth for every run.
    df_original = _read_csv_without_index(original_data_path)
    df_original = df_original[df_original['duration'] != 0]

    all_results = {}

    for percentage in data_percentages:
        print(f"\n=== Evaluating {percentage}% Training Data ===")
        all_results[percentage] = {}

        for method in methods:
            print(f"Processing method: {method}")

            # Metrics of each successfully evaluated synthetic dataset.
            metrics_list = []

            # A different synthetic dataset per iteration,
            # e.g. aids_ours_1.csv, aids_ours_2.csv, ...
            for iteration in range(1, n_datasets + 1):
                try:
                    synthetic_filename = f"{dataset}_{method}_{iteration}.csv"
                    synthetic_path = os.path.join(data_paths[method], synthetic_filename)

                    if not os.path.exists(synthetic_path):
                        print(f"Warning: File not found: {synthetic_path}")
                        continue

                    df_synthetic_full = _read_csv_without_index(synthetic_path)

                    # Keep the requested share of the synthetic rows.
                    if percentage == 100:
                        df_synthetic = df_synthetic_full.copy()
                    else:
                        sample_size = int(len(df_synthetic_full) * (percentage / 100))
                        df_synthetic = df_synthetic_full.sample(
                            n=sample_size, random_state=sample_seed
                        ).reset_index(drop=True)

                    # Create data loaders for evaluation.
                    loader_original = SurvivalAnalysisDataLoader(
                        df_original,
                        target_column="event",
                        time_to_event_column="duration"
                    )
                    loader_synthetic = SurvivalAnalysisDataLoader(
                        df_synthetic,
                        target_column="event",
                        time_to_event_column="duration"
                    )

                    # Evaluate metrics using synthcity.
                    met_df = Metrics.evaluate(
                        X_gt=loader_original,
                        X_syn=loader_synthetic,
                        task_type='survival_analysis',
                        metrics={
                            'stats': [
                                'jensenshannon_dist', 'chi_squared_test', 'feature_corr',
                                'inv_kl_divergence', 'ks_test', 'max_mean_discrepancy',
                                'wasserstein_dist', 'prdc', 'alpha_precision', 'survival_km_distance'
                            ],
                            'performance': ['linear_model', 'mlp', 'xgb', 'feat_rank_distance']
                        },
                        use_cache=False,
                        random_state=iteration
                    )

                    # Keep only the first column of synthcity's result table
                    # (one value per metric) for this dataset.
                    metrics_list.append(met_df.iloc[:, 0])

                    print(f"  Dataset {iteration} completed (using {synthetic_filename}, sampled {len(df_synthetic)} rows from {len(df_synthetic_full)})")

                except Exception as e:
                    # Best effort: report the failure and move on to the next dataset.
                    print(f"  Error in dataset {iteration}: {str(e)}")

            if metrics_list:
                all_results[percentage][method] = _summarize_metrics(metrics_list)
            else:
                print(f"  No valid data found for {method} at {percentage}%")
                all_results[percentage][method] = None

    return all_results



# Main execution
if __name__ == "__main__":
    print("Starting Table 4 benchmark evaluation...")

    # Run the benchmark
    all_results = benchmark_limited_data_experiment()

    # Print the result table for each percentage/method combination. Iterating
    # over the returned dict, instead of re-listing the percentages and methods
    # here, keeps this report in sync with what the benchmark actually evaluated.
    for percentage, method_results in all_results.items():
        print(f"\n{'='*60}")
        print(f"{percentage}% Training Data Results")
        print(f"{'='*60}")

        for method, result_df in method_results.items():
            if result_df is not None:
                print(f"\n{method.upper()} Method:")
                print(result_df)
            else:
                print(f"\n{method.upper()} Method: No data available")

    print("\nTable 4 benchmark evaluation completed!")