In [1]:
import sys
sys.path.insert(0, "../")

import numpy as np
import pandas as pd
from ctgan.synthesizers.ctgan import CTGANSynthesizer
from ctgan.synthesizers.tvae import TVAESynthesizer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import shuffle

from src.metrics import *
from src.metrics import eval_plugin
from src.synthesizer import fit_ctgan
from src.utils import *

use_trained_model = False

seed = 42
seed_everything(seed)





# Load data

In [4]:
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split 
from src.data_loader import load_support_dataset

seed = 42  # make sure seed is defined

X, y, Data = load_support_dataset()

column_metric = "race"
df = Data  # df includes the column we stratify on

# regions: the stratification labels
regions = df[column_metric]

# First split: train+val vs test
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.65, random_state=seed)
train_val_indices, test_indices = next(splitter.split(df, regions))

# Use iloc (positional indexing) here
regions_train_val = df[column_metric].iloc[train_val_indices]

# Second split: train vs val
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
train_indices, val_indices = next(
    splitter.split(df.iloc[train_val_indices], regions_train_val)
)

# Build splits with iloc
X_train = df.iloc[train_val_indices[train_indices]]
X_test = df.iloc[train_val_indices[val_indices]]
X_oracle = df.iloc[test_indices]

# Optional: further split off HP set
X_train, X_hp = train_test_split(X_train, test_size=0.1, random_state=seed)

# Check all splits have the same set of race values
assert (
    set(X_train[column_metric].unique())
    == set(X_test[column_metric].unique())
    == set(X_oracle[column_metric].unique())
)


# Train Generators

In [6]:
import pickle

ctgans = []
trials = []
params_list = []

# Candidate discrete columns you *wish* to treat as categorical
raw_discrete_columns = [
    "sex",
    "ARF/MOSF w/Sepsis",
    "COPD",
    "CHF",
    "Cirrhosis",
    "Coma",
    "Colon Cancer",
    "Lung Cancer",
    "MOSF w/Malig",
    "ARF/MOSF",
    "Cancer",
    "num.co",
    "hday",
    "diabetes",
    "dementia",
    "hrt",
    "resp",
    "y",
    "salary",
    "race",
]

# Keep only those that actually exist in the data
discrete_columns = [c for c in raw_discrete_columns if c in X_test.columns]

dataset_name = "support"

pd.options.mode.chained_assignment = None
hp_sweep = False
if use_trained_model == False:
    if hp_sweep:
        import optuna
        from optuna.samplers import TPESampler

        # Optimize HPS
        def gan_objective(trial):

            discrete_columns = [
                "sex",
                "ARF/MOSF w/Sepsis",
                "COPD",
                "CHF",
                "Cirrhosis",
                "Coma",
                "Colon Cancer",
                "Lung Cancer",
                "MOSF w/Malig",
                "ARF/MOSF",
                "Cancer",
                "num.co",
                "hday",
                "diabetes",
                "dementia",
                "hrt",
                "resp",
                "y",
                "salary",
                "race",
            ]
            learning_rate = trial.suggest_categorical(
                "learning_rate", [2e-4, 2e-5, 2e-6]
            )
            embedding_dim = trial.suggest_categorical("embedding_dim", [64, 128, 256])
            epochs = trial.suggest_categorical("epochs", [200, 300, 500])

            ctgan = fit_ctgan(
                data=X_test,
                epochs=epochs,
                learning_rate=learning_rate,
                embedding_dim=embedding_dim,
                discrete_columns=discrete_columns,
            )
            D_fake, _ = ctgan.sample(X_hp.shape[0], shift=False)
            metric = MaximumMeanDiscrepancy
            trial_results = eval_plugin(
                metric,
                GenericDataLoader(X_hp.astype(float)),
                GenericDataLoader(D_fake.astype(float)),
            )
            trials.append(trial_results)
            print(
                eval_plugin(
                    WassersteinDistance,
                    GenericDataLoader(X_hp.astype(float)),
                    GenericDataLoader(D_fake.astype(float)),
                )
            )

            print(f"HPS = Params: {trial.params} | Score: {trial_results[0]}")
            return trial_results[0]["joint"]

        gan_study = optuna.create_study(direction="minimize", sampler=TPESampler())
        gan_study.optimize(gan_objective, show_progress_bar=True, n_trials=10)
        print("Best parameters:", gan_study.best_params)

        ctgan = fit_ctgan(
            data=X_test,
            epochs=gan_study.best_params["epochs"],
            learning_rate=gan_study.best_params["learning_rate"],
            embedding_dim=gan_study.best_params["embedding_dim"],
            discrete_columns=discrete_columns,
        )
        ctgans.append(deepcopy(ctgan))

    else:

        from copy import deepcopy
        from tqdm import tqdm

        for i in tqdm(range(10)):

            best_params = {"learning_rate": 0.0002, "embedding_dim": 256, "epochs": 100}
            ctgan = fit_ctgan(
                data=X_test,
                epochs=best_params["epochs"],
                learning_rate=best_params["learning_rate"],
                embedding_dim=best_params["embedding_dim"],
                seed=seed,
                discrete_columns=discrete_columns,
            )
            D_fake, _ = ctgan.sample(X_hp.shape[0], shift=False)

            trial_results = eval_plugin(
                MaximumMeanDiscrepancy,
                GenericDataLoader(X_hp.astype(float)),
                GenericDataLoader(D_fake.astype(float)),
            )
            trials.append(trial_results)
            params_list.append(best_params)
            ctgans.append(deepcopy(ctgan))

        # save each ctgan model in ctgan_list
        for idx, ctgan_save in enumerate(ctgans):
            ctgan_save.save(f"../models/ctgan_{dataset_name}_{idx+1}")

        trials_list = [trials[idx][0]["joint"] for idx in range(len(trials))]
        ctgan_idx = trials_list.index(min(trials_list))

        ctgan = ctgans[ctgan_idx]

        ctgan.save(f"../models/ctgan_{dataset_name}")

        # pickle best params
        with open(f"../models/ctgan_{dataset_name}_params.pkl", "wb") as f:
            pickle.dump(params_list[ctgan_idx], f)


else:
    # Load best_params
    with open(f"../models/ctgan_{dataset_name}_params.pkl", "rb") as f:
        best_params = pickle.load(f)

    ctgan = CTGANSynthesizer(
        embedding_dim=best_params["embedding_dim"],
        generator_dim=(256, 256),
        discriminator_dim=(256, 256),
        generator_lr=best_params["learning_rate"],
        generator_decay=1e-6,
        discriminator_lr=best_params["learning_rate"],
        discriminator_decay=1e-6,
        batch_size=500,
        discriminator_steps=1,
        log_frequency=True,
        verbose=False,
        epochs=best_params["epochs"],
        pac=10,
        cuda=True,
    )

    ctgan = ctgan.load(f"../models/ctgan_{dataset_name}")


100%|██████████| 10/10 [01:25<00:00,  8.54s/it]


# Train the downstream predictors

In [7]:
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

model_dict = {
    "mlp": MLPClassifier(random_state=seed),
    "knn": KNeighborsClassifier(),
    "dt": DecisionTreeClassifier(random_state=seed),
    "rf": RandomForestClassifier(random_state=seed),
    "gbc": GradientBoostingClassifier(random_state=seed),
    "bag": BaggingClassifier(random_state=seed),
    "ada": AdaBoostClassifier(random_state=seed),
    "svm": SVC(random_state=seed),
    "lr": LogisticRegression(random_state=seed),
}

print("training baseline models")

trained_model_dict = train_models(X_train, model_dict)


training baseline models


# Helpers

In [9]:
from copy import deepcopy

import numpy as np
from fairlearn.metrics import demographic_parity_ratio as dp_ratio
from fairlearn.metrics import equalized_odds_ratio as eo_ratio
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm


def evaluate_models(models, data, column_metric, sensitive="Sex_male"):
    performance = {"acc": {}, "f1": {}, "eo": {}, "dp": {}}

    for model_name, model in models.items():
        if data.shape[0] == 0:
            performance["acc"][model_name] = 0
            performance["f1"][model_name] = 0
            performance["eo"][model_name] = 0
            performance["dp"][model_name] = 0
            continue

        y_pred = model.predict(data.drop("y", axis=1))
        accuracy = accuracy_score(data["y"], y_pred)
        F1 = f1_score(data["y"], y_pred)

        try:
            eo_score = eo_ratio(
                data["y"], y_pred, sensitive_features=data[sensitive].values
            )
        except:
            eo_score = 0

        try:
            dp_score = dp_ratio(
                data["y"], y_pred, sensitive_features=data[sensitive].values
            )
        except:
            dp_score = 0

        performance["acc"][model_name] = accuracy
        performance["f1"][model_name] = F1
        performance["eo"][model_name] = eo_score
        performance["dp"][model_name] = dp_score

    return performance


def run_analysis(
    column_metric,
    ctgan,
    Data,
    X_train,
    X_test,
    X_oracle,
    trained_model_dict,
    random_state=0,
    n_samples=1000,
):
    # Initialization
    groups = list(np.unique(Data[column_metric]))

    X_train, X_test = train_test_split(
        X_train, test_size=0.2, random_state=random_state
    )
    test_dataset = deepcopy(X_test)

    (
        test_results,
        test_samples,
        oracle_results,
        oracle_samples,
        synth_results,
        synth_samples,
        aug_results,
        aug_samples,
    ) = ({}, {}, {}, {}, {}, {}, {}, {})

    for group in tqdm(groups):
        total_samples = (
            X_test[X_test[column_metric] == 0].shape[0] if group == 0 else n_samples
        )
        total_samples = n_samples
        test_data = X_test[X_test[column_metric] == group]
        oracle_data = X_oracle[X_oracle[column_metric] == group]

        # Dtest
        test_results[group], test_samples[group] = (
            evaluate_models(trained_model_dict, test_data, column_metric),
            test_data.shape[0],
        )
        oracle_results[group], oracle_samples[group] = (
            evaluate_models(trained_model_dict, oracle_data, column_metric),
            oracle_data.shape[0],
        )

        # Synthetic data sampling
        synth_data, _ = ctgan.sample(
            1, shift=False, condition_column=column_metric, condition_value=group
        )
        count = 0
        while synth_data.shape[0] <= total_samples:
            tmp_df = ctgan.sample(
                n_samples,
                shift=False,
                condition_column=column_metric,
                condition_value=group,
            )[0]
            synth_data = synth_data.append(tmp_df[tmp_df[column_metric] == group])
            count += 1
        synth_results[group], synth_samples[group] = (
            evaluate_models(trained_model_dict, synth_data, column_metric),
            synth_data.shape[0],
        )

        # Augmented data evaluation
        aug_data = pd.concat([test_data, synth_data])
        aug_results[group], aug_samples[group] = (
            evaluate_models(trained_model_dict, aug_data, column_metric),
            aug_data.shape[0],
        )

    return (
        test_results,
        test_samples,
        oracle_results,
        oracle_samples,
        synth_results,
        synth_samples,
        aug_results,
        aug_samples,
        test_dataset,
    )


# Run analysis

In [10]:
from tqdm import tqdm

test_acc_list = []
test_samples_list = []
oracle_acc_list = []
oracle_samples_list = []
synth_acc_list = []
synth_samples_list = []
aug_acc_list = []
aug_samples_list = []
final_test_dataset = []


n_runs = 5
ns = 5000

for i in tqdm(range(n_runs)):

    done = False
    tries = 0

    while done == False:
        try:
            (
                test_final,
                test_samples,
                oracle_final,
                oracle_samples,
                synth_final,
                synth_samples,
                aug_final,
                aug_samples,
                test_dataset,
            ) = run_analysis(
                column_metric=column_metric,
                Data=Data,
                X_train=X_train,
                X_test=X_test,
                X_oracle=X_oracle,
                ctgan=ctgan,
                trained_model_dict=trained_model_dict,
                random_state=i * 100,
                n_samples=ns,
            )
            done = True
            final_test_dataset.append(test_dataset)
        except Exception:
            import traceback

            print(traceback.format_exc())
            tries += 1

            if tries > 5:
                done = True
            continue

    test_acc_list.append(test_final)
    test_samples_list.append(test_samples)
    oracle_acc_list.append(oracle_final)
    oracle_samples_list.append(oracle_samples)
    synth_acc_list.append(synth_final)
    synth_samples_list.append(synth_samples)
    aug_acc_list.append(aug_final)
    aug_samples_list.append(aug_samples)


100%|██████████| 4/4 [00:21<00:00,  5.26s/it]
100%|██████████| 4/4 [00:21<00:00,  5.39s/it]
100%|██████████| 4/4 [00:21<00:00,  5.34s/it]
100%|██████████| 4/4 [00:29<00:00,  7.37s/it]
100%|██████████| 4/4 [00:30<00:00,  7.51s/it]
100%|██████████| 5/5 [02:03<00:00, 24.70s/it]


# Process results & save

In [11]:
import pandas as pd
from brmp import brm
from brmp.numpyro_backend import backend as numpyro
from brmp.pyro_backend import backend as pyro_backend
from copy import deepcopy

metrics = ["acc"]
models = list(trained_model_dict.keys())

models = ['rf']


X_test_new = deepcopy(final_test_dataset[0])

for metric in metrics:

    for model in models:

        try:
            blm = []
            for i in range(5):
                seed = i * 100
                xt = deepcopy(X_train)
                _, X_test = train_test_split(xt, test_size=0.2, random_state=seed)
                X_test["S"] = np.max(
                    trained_model_dict[model].predict_proba(X_test.drop("y", axis=1)),
                    axis=1,
                )
                model3 = brm(
                    "S ~ race + salary + sex + age + ph + glucose + sod + crea + bili + alb + wblc + y",
                    X_test,
                )
                fit3 = model3.fit(
                    backend=pyro_backend, seed=seed, iter=1000, warmup=100
                )
                scores3 = fit3.fitted(what="sample", data=None, seed=seed)
                blm.append(scores3)
            blm = np.array(blm)

            props = [
                count / np.sum(np.unique(Data[column_metric], return_counts=True)[1])
                for count in np.unique(Data[column_metric], return_counts=True)[1]
            ]

            groups = np.sort(list(X_test[column_metric].unique()))

            for model_name in [model]:
                print(model_name)
                data_list = []
                idx = 0

                df = pd.DataFrame(columns=["Group", "3S", "3S+", "BLM", "Dtest"])
                df_std = pd.DataFrame(columns=["Group", "3S", "3S+", "BLM", "Dtest"])

                for group in np.argsort(props)[::-1]:
           
                    group = group + 1
                    idx += 1
                    mylist = oracle_acc_list
                    oracle_res = np.array(
                        [
                            mylist[i][group][metric][model_name]
                            for i in range(len(mylist))
                        ]
                    )

                    mylist = synth_acc_list
                    synth_res = np.array(
                        [
                            mylist[i][group][metric][model_name]
                            for i in range(len(mylist))
                        ]
                    )

                    mylist = test_acc_list
                    test_res = np.array(
                        [
                            mylist[i][group][metric][model_name]
                            for i in range(len(mylist))
                        ]
                    )

                    mylist = aug_acc_list
                    aug_res = np.array(
                        [
                            mylist[i][group][metric][model_name]
                            for i in range(len(mylist))
                        ]
                    )

                    try:
                        # blm
                        blm_res = []
                        for i in range(5):
                            y_pred = (np.mean(blm[i], axis=0) > 0.75).astype(int) # confident predictions
                            group_ids = np.argwhere(
                                np.array(X_test_new[column_metric] == group).astype(int)
                                == 1
                            )
                            y_pred_group = y_pred[group_ids]
                            y_true_group = X_test_new[
                                X_test_new[column_metric] == group
                            ]["y"]
                            blm_res.append(accuracy_score(y_true_group, y_pred_group))

                        blm_res = np.array(blm_res)


                        mydict = {
                            "Group": f"{idx} ({int(round(props[group-1]*100,0))}%)",
                            "3S": round(
                                np.mean(np.abs(oracle_res - synth_res)) * 100, 2
                            ),
                            "3S+": round(
                                np.mean(np.abs(oracle_res - aug_res)) * 100, 2
                            ),
                            "BLM": round(
                                np.mean(np.abs(oracle_res - blm_res)) * 100, 2
                            ),
                            "Dtest": round(
                                np.mean(np.abs(oracle_res - test_res)) * 100, 2
                            ),
                        }
                        df = df.append(mydict, ignore_index=True)

                        mydict = {
                            "Group": f"{idx} ({int(round(props[group-1]*100,0))}%)",
                            "3S": round(
                                np.std(np.abs(oracle_res - synth_res)) * 100, 2
                            ),
                            "3S+": round(np.std(np.abs(oracle_res - aug_res)) * 100, 2),
                            "BLM": round(np.std(np.abs(oracle_res - blm_res)) * 100, 2),
                            "Dtest": round(
                                np.std(np.abs(oracle_res - test_res)) * 100, 2
                            ),
                        }
                        df_std = df_std.append(mydict, ignore_index=True)

                    except Exception:
                        import traceback

                        print(traceback.format_exc())
                        continue

            df.to_csv(f"../results/{dataset_name}_{model}_{metric}.csv")
            df_std.to_csv(f'../results/{dataset_name}_std_{model}_{metric}.csv')

        except Exception as e:
            import traceback

            print(traceback.format_exc())
            print("IM EXCEPTING", e)


Sample: 100%|██████████| 1100/1100 [12:05,  1.52it/s, step size=2.02e-04, acc. prob=0.958]
Sample: 100%|██████████| 1100/1100 [10:46,  1.70it/s, step size=5.02e-04, acc. prob=0.905]
Sample: 100%|██████████| 1100/1100 [06:41,  2.74it/s, step size=5.28e-04, acc. prob=0.950]
Sample: 100%|██████████| 1100/1100 [19:12,  1.05s/it, step size=1.32e-04, acc. prob=0.933]
Sample: 100%|██████████| 1100/1100 [10:11,  1.80it/s, step size=3.80e-04, acc. prob=0.907]


rf
Traceback (most recent call last):
  File "/tmp/ipykernel_51767/1918830799.py", line 104, in <module>
    blm_res.append(accuracy_score(y_true_group, y_pred_group))
  File "/home/dhanush/anaconda3/envs/3s/lib/python3.9/site-packages/sklearn/utils/_param_validation.py", line 214, in wrapper
    return func(*args, **kwargs)
  File "/home/dhanush/anaconda3/envs/3s/lib/python3.9/site-packages/sklearn/metrics/_classification.py", line 220, in accuracy_score
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/home/dhanush/anaconda3/envs/3s/lib/python3.9/site-packages/sklearn/metrics/_classification.py", line 93, in _check_targets
    raise ValueError(
ValueError: Classification metrics can't handle a mix of binary and unknown targets

Traceback (most recent call last):
  File "/tmp/ipykernel_51767/1918830799.py", line 61, in <module>
    [
  File "/tmp/ipykernel_51767/1918830799.py", line 62, in <listcomp>
    mylist[i][group][metric][model_name]
KeyError: 4

IM EXCEPTING

# Coverage helper

In [12]:
from copy import deepcopy

import pandas as pd
from brmp import brm
from brmp.numpyro_backend import backend as numpyro
from brmp.pyro_backend import backend as pyro_backend


def get_group(X_test, column_metric, groups, test_group="small"):
    smallest_group = 0
    gsizes = []
    for idx, group in enumerate(groups):
        gsizes.append(X_test[X_test[column_metric] == group].shape[0])

    print(gsizes)
    ns = 10
    for idx, group in enumerate(groups):

        if test_group == "small":
            if idx == 0:
                size_smallest_group = X_test[X_test[column_metric] == 0].shape[0]
                smallest_group = group
                if size_smallest_group < ns:
                    size_smallest_group = 99999
            else:
                if (
                    X_test[X_test[column_metric] == group].shape[0]
                    < size_smallest_group
                ):
                    if X_test[X_test[column_metric] == group].shape[0] > ns:
                        size_smallest_group = X_test[
                            X_test[column_metric] == group
                        ].shape[0]
                        smallest_group = group

        else:
            if idx == 0:
                size_smallest_group = X_test[X_test[column_metric] == 0].shape[0]
                smallest_group = group
            else:
                if (
                    X_test[X_test[column_metric] == group].shape[0]
                    > size_smallest_group
                ):
                    size_smallest_group = X_test[X_test[column_metric] == group].shape[
                        0
                    ]
                    smallest_group = group

    return smallest_group


def uncertainty(
    column_metric,
    ctgans,
    X_train,
    X_test,
    X_oracle,
    trained_model_dict,
    random_state=0,
    n_samples=5000,
    test_group="small",
):
    from sklearn.metrics import f1_score
    from fairlearn.metrics import equalized_odds_ratio as eo_ratio
    from fairlearn.metrics import demographic_parity_ratio as dp_ratio
    from tqdm import tqdm

    score = "acc"

    groups = list(np.unique(Data[column_metric]))
    orig_random_state = deepcopy(random_state)

    coverage_dict = {}
    width_dict = {}
    excess_dict = {}
    deficet_dict = {}

    for model in ["rf"]:

        random_state = orig_random_state


        try:
            trained_model_dict[model].predict_proba(X_train.drop("y", axis=1))
        except:
            continue

        test_coverage = []
        synth_coverage = []
        synth_coverage5 = []
        synth_coverage10 = []
        blm_coverage = []

        test_width = []
        synth_width = []
        synth_width5 = []
        synth_width10 = []
        blm_width = []

        test_excess = []
        synth_excess = []
        synth_excess5 = []
        synth_excess10 = []
        blm_excess = []

        test_deficet = []
        synth_deficet = []
        synth_deficet5 = []
        synth_deficet10 = []
        blm_deficet = []

        error_test = []
        error_synth = []

        for random_state in tqdm(range(20)):

         
            _, X_test = train_test_split(
                X_train, test_size=0.2, random_state=random_state
            )

            deepcopy(X_test)

          
            blm_accs = {}

        
            smallest_group = get_group(
                X_test, column_metric, groups, test_group="small"
            )
            print(f"Smallest group: {smallest_group}, random_state: {random_state}")

            for group in groups:

                if group != smallest_group:
                    continue

                # n_samples = 1000
                if group == 0:
                    total_samples = X_test[X_test[column_metric] == 0].shape[0]
                else:
                    total_samples = n_samples

                test_data = X_test[X_test[column_metric] == group]

                oracle_data = X_oracle[X_oracle[column_metric] == group]

                print(f"TEST SIZE {group} - {test_data.shape[0]}")

                if test_data.shape[0] == 0:
                    accuracy = 0

                else:
                    slack = 0.025  # slack for the confidence interval
                    clf = model_dict[model]

                    ############################################################
                    # Oracle
                    ############################################################
                    y_pred = clf.predict(oracle_data.drop("y", axis=1))
                    accuracy_oracle = accuracy_score(oracle_data["y"], y_pred)
                    print(oracle_data.shape, accuracy_oracle)

                    ############################################################
                    # Bootstrap Dtest
                    ############################################################
                    y_pred = clf.predict(test_data.drop("y", axis=1))
                    accuracy = accuracy_score(test_data["y"], y_pred)
                    R = 1000
                    confidence_level = 0.95
                    metric = accuracy_score
                    scores = bootstrap(
                        test_data["y"].values.reshape(-1, 1),
                        y_pred.reshape(-1, 1),
                        metric,
                        R,
                    )
                    bottom_test, top_test = confidence_intervals(
                        scores, confidence_level
                    )
                    score = 0
                    if (
                        accuracy_oracle >= bottom_test.values - slack
                        and accuracy_oracle <= top_test.values + slack
                    ):
                        score = 1
                    test_coverage.append(score)
                    error_test.append(accuracy_oracle - accuracy)
                    excess, deficet, width = compute_interval_metrics(
                        lb=bottom_test.values - slack,
                        ub=top_test.values + slack,
                        true=accuracy_oracle,
                    )
                    test_width.append(width)
                    if excess != -9999:
                        test_excess.append(excess)
                    if deficet != -9999:
                        test_deficet.append(deficet)

                    ############################################################
                    # MBM Interval
                    ############################################################
                    X_test_new = deepcopy(X_test)
                    X_test_new["S"] = trained_model_dict[model].predict_proba(
                        X_test_new.drop("y", axis=1)
                    )[:, 1]

                    model3 = brm(
                        "S ~ race + salary + sex + age + ph + glucose + sod + crea + bili + alb + wblc + y",
                        X_test_new,
                    )
                    fit3 = model3.fit(
                        backend=pyro_backend, seed=0, iter=1000, warmup=100
                    )
                    myscores = fit3.fitted(what="sample", data=None, seed=0)

                    blm_accs = []
                    for j in range(100):
                        scores3 = np.array(
                            [
                                myscores[np.random.randint(1000, size=1), i]
                                for i in range(myscores.shape[1])
                            ]
                        )
                        y_pred = (scores3 > 0.75).astype(int).flatten()
                        group_ids = np.argwhere(
                            np.array(X_test_new[column_metric] == group).astype(int)
                            == 1
                        )
                        y_pred_group = y_pred[group_ids]
                        acc = accuracy_score(test_data["y"], y_pred_group)
                        blm_accs.append(acc)

                    mean_blm_acc = np.mean(blm_accs)
                    std_blm_acc = np.std(blm_accs)
                    bottom_blm = mean_blm_acc - 1.96 * std_blm_acc
                    top_blm = mean_blm_acc + 1.96 * std_blm_acc

                    score = 0
                    if (
                        accuracy_oracle >= bottom_blm - slack
                        and accuracy_oracle <= top_blm + slack
                    ):
                        score = 1
                    blm_coverage.append(score)
                    excess, deficet, width = compute_interval_metrics(
                        lb=bottom_blm - slack, ub=top_blm + slack, true=accuracy_oracle
                    )
                    blm_width.append(width)
                    if excess != -9999:
                        blm_excess.append(excess)
                    if deficet != -9999:
                        blm_deficet.append(deficet)

                    ############################################################
                    # 3S (k=1)
                    ############################################################
                    syn_accs = []
                    for i in range(100):
                        shift_df, _ = ctgans[0].sample(
                            1,
                            shift=False,
                            condition_column=column_metric,
                            condition_value=group,
                        )
                        count = 0
                        while shift_df.shape[0] <= total_samples:
                            generated_tmp, _ = ctgans[0].sample(
                                n_samples,
                                shift=False,
                                condition_column=column_metric,
                                condition_value=group,
                            )
                            tmp_df = generated_tmp[
                                generated_tmp[column_metric] == group
                            ]
                            shift_df = shift_df.append(tmp_df)
                            count += 1

                        syn_data = shift_df[shift_df[column_metric] == group]
                        y_pred = clf.predict(syn_data.drop("y", axis=1))
                        accuracy = accuracy_score(syn_data["y"], y_pred)
                        syn_accs.append(accuracy)

                    mean_syn_acc = np.mean(syn_accs)
                    std_syn_acc = np.std(syn_accs)
                    bottom_syn = mean_syn_acc - 1.96 * std_syn_acc
                    top_syn = mean_syn_acc + 1.96 * std_syn_acc

                    ############################################################
                    # 3S (k=5)
                    ############################################################
                    syn_accs5 = []
                    for ctgan in ctgans[0:5]:
                        shift_df, _ = ctgan.sample(
                            1,
                            shift=False,
                            condition_column=column_metric,
                            condition_value=group,
                        )

                        count = 0
                        while shift_df.shape[0] <= total_samples:
                            generated_tmp, _ = ctgan.sample(
                                n_samples,
                                shift=False,
                                condition_column=column_metric,
                                condition_value=group,
                            )
                            tmp_df = generated_tmp[
                                generated_tmp[column_metric] == group
                            ]
                            shift_df = shift_df.append(tmp_df)
                            count += 1

                        syn_data = shift_df[shift_df[column_metric] == group]
                        y_pred = clf.predict(syn_data.drop("y", axis=1))
                        accuracy = accuracy_score(syn_data["y"], y_pred)
                        syn_accs5.append(accuracy)

                    mean_syn_acc = np.mean(syn_accs5)
                    std_syn_acc = np.std(syn_accs5)
                    bottom_syn5 = mean_syn_acc - 1.96 * std_syn_acc
                    top_syn5 = mean_syn_acc + 1.96 * std_syn_acc

                    ############################################################
                    # 3S (k=10)
                    ############################################################
                    syn_accs10 = []
                    for ctgan in ctgans:
                        shift_df, _ = ctgan.sample(
                            1,
                            shift=False,
                            condition_column=column_metric,
                            condition_value=group,
                        )

                        count = 0
                        while shift_df.shape[0] <= total_samples:
                            generated_tmp, _ = ctgan.sample(
                                n_samples,
                                shift=False,
                                condition_column=column_metric,
                                condition_value=group,
                            )
                            tmp_df = generated_tmp[
                                generated_tmp[column_metric] == group
                            ]
                            shift_df = shift_df.append(tmp_df)
                            count += 1

                        syn_data = shift_df[shift_df[column_metric] == group]
                        y_pred = clf.predict(syn_data.drop("y", axis=1))
                        accuracy = accuracy_score(syn_data["y"], y_pred)
                        syn_accs10.append(accuracy)

                    mean_syn_acc = np.mean(syn_accs10)
                    std_syn_acc = np.std(syn_accs10)
                    error_synth.append(accuracy_oracle - mean_syn_acc)
                    bottom_syn10 = mean_syn_acc - 1.96 * std_syn_acc
                    top_syn10 = mean_syn_acc + 1.96 * std_syn_acc

                    # compute if the accuracy_oracle is within the confidence interval of the accuracy
                    score = 0
             
                    if (
                        accuracy_oracle >= bottom_syn - slack
                        and accuracy_oracle <= top_syn + slack
                    ):
                        score = 1

                    synth_coverage.append(score)
                    excess, deficet, width = compute_interval_metrics(
                        lb=bottom_syn - slack, ub=top_syn + slack, true=accuracy_oracle
                    )
                    synth_width.append(width)
                    if excess != -9999:
                        synth_excess.append(excess)
                    if deficet != -9999:
                        synth_deficet.append(deficet)

                    score = 0
                    
                    if (
                        accuracy_oracle >= bottom_syn5 - slack
                        and accuracy_oracle <= top_syn5 + slack
                    ):
                        score = 1
                    synth_coverage5.append(score)
                    excess, deficet, width = compute_interval_metrics(
                        lb=bottom_syn5 - slack,
                        ub=top_syn5 + slack,
                        true=accuracy_oracle,
                    )
                    synth_width5.append(width)
                    if excess != -9999:
                        synth_excess5.append(excess)
                    if deficet != -9999:
                        synth_deficet5.append(deficet)

                    score = 0
                    
                    if (
                        accuracy_oracle >= bottom_syn10 - slack
                        and accuracy_oracle <= top_syn10 + slack
                    ):
                        score = 1
                    synth_coverage10.append(score)
                    excess, deficet, width = compute_interval_metrics(
                        lb=bottom_syn10 - slack,
                        ub=top_syn10 + slack,
                        true=accuracy_oracle,
                    )
                    synth_width10.append(width)
                    if excess != -9999:
                        synth_excess10.append(excess)
                    if deficet != -9999:
                        synth_deficet10.append(deficet)

        mean_test_coverage = np.mean(test_coverage)
        mean_blm_coverage = np.mean(blm_coverage)
        mean_synth_coverage = np.mean(synth_coverage)
        mean_synth_coverage5 = np.mean(synth_coverage5)
        mean_synth_coverage10 = np.mean(synth_coverage10)
        np.mean(error_test)
        np.mean(error_synth)

        # mean for width for synth, synth2, synth3, test, mbm
        mean_synth_width = np.mean(synth_width)
        mean_synth_width5 = np.mean(synth_width5)
        mean_synth_width10 = np.mean(synth_width10)
        mean_test_width = np.mean(test_width)
        mean_blm_width = np.mean(blm_width)

        # excess for width for synth, synth2, synth3, test, mbm
        mean_synth_excess = np.mean(synth_excess)
        mean_synth_excess5 = np.mean(synth_excess5)
        mean_synth_excess10 = np.mean(synth_excess10)
        mean_test_excess = np.mean(test_excess)
        mean_blm_excess = np.mean(blm_excess)

        # mean for excess for synth, synth2, synth3, test, mbm
        mean_synth_deficet = np.mean(synth_deficet)
        mean_synth_deficet5 = np.mean(synth_deficet5)
        mean_synth_deficet10 = np.mean(synth_deficet10)
        mean_test_deficet = np.mean(test_deficet)
        mean_blm_deficet = np.mean(blm_deficet)

        # coverage dict
        coverage_dict[model] = {
            "test": mean_test_coverage,
            "synth": mean_synth_coverage,
            "synth2": mean_synth_coverage5,
            "synth3": mean_synth_coverage10,
            "MBM": mean_blm_coverage,
        }

        # width dict
        width_dict[model] = {
            "test": mean_test_width,
            "synth": mean_synth_width,
            "synth2": mean_synth_width5,
            "synth3": mean_synth_width10,
            "MBM": mean_blm_width,
        }

        # excess dict
        excess_dict[model] = {
            "test": mean_test_excess,
            "synth": mean_synth_excess,
            "synth2": mean_synth_excess5,
            "synth3": mean_synth_excess10,
            "MBM": mean_blm_excess,
        }

        # deficet dict
        deficet_dict[model] = {
            "test": mean_test_deficet,
            "synth": mean_synth_deficet,
            "synth2": mean_synth_deficet5,
            "synth3": mean_synth_deficet10,
            "MBM": mean_blm_deficet,
        }

    return coverage_dict, width_dict, excess_dict, deficet_dict


# Coverage helper

In [13]:
from tqdm import tqdm

try:
    coverage, width, excess, deficet = uncertainty(
        column_metric=column_metric,
        ctgans=ctgans[0:10],
        X_train=X_train,
        X_test=X_test,
        X_oracle=X_oracle,
        trained_model_dict=trained_model_dict,
        random_state=i * 100,
        test_group="small",
    )
    done = True

except Exception:
    import traceback

    print(traceback.format_exc())


  0%|          | 0/20 [00:00<?, ?it/s]

[236, 56, 0, 14]
Smallest group: 3.0, random_state: 0
TEST SIZE 3.0 - 14
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [06:57,  2.63it/s, step size=4.42e-04, acc. prob=0.952]
  5%|▌         | 1/20 [14:15<4:31:00, 855.84s/it]

[247, 48, 4, 7]
Smallest group: 1.0, random_state: 1
TEST SIZE 1.0 - 48
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [05:03,  3.63it/s, step size=8.02e-04, acc. prob=0.883]
 10%|█         | 2/20 [21:37<3:03:39, 612.20s/it]

[242, 52, 4, 8]
Smallest group: 1.0, random_state: 2
TEST SIZE 1.0 - 52
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [02:15,  8.10it/s, step size=9.58e-04, acc. prob=0.306]
 15%|█▌        | 3/20 [26:12<2:09:46, 458.03s/it]

[252, 40, 1, 13]
Smallest group: 3.0, random_state: 3
TEST SIZE 3.0 - 13
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [04:13,  4.34it/s, step size=9.93e-04, acc. prob=0.879]
 20%|██        | 4/20 [38:10<2:29:33, 560.86s/it]

[242, 54, 5, 5]
Smallest group: 1.0, random_state: 4
TEST SIZE 1.0 - 54
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [08:55,  2.05it/s, step size=2.50e-04, acc. prob=0.952]
 25%|██▌       | 5/20 [49:34<2:31:16, 605.12s/it]

[243, 50, 1, 12]
Smallest group: 3.0, random_state: 5
TEST SIZE 3.0 - 12
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [07:11,  2.55it/s, step size=3.89e-04, acc. prob=0.962]
 30%|███       | 6/20 [1:04:40<2:45:04, 707.49s/it]

[232, 57, 2, 15]
Smallest group: 3.0, random_state: 6
TEST SIZE 3.0 - 15
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [04:19,  4.23it/s, step size=7.28e-04, acc. prob=0.795]
 35%|███▌      | 7/20 [1:17:09<2:36:14, 721.13s/it]

[251, 45, 2, 8]
Smallest group: 1.0, random_state: 7
TEST SIZE 1.0 - 45
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [04:28,  4.09it/s, step size=8.13e-04, acc. prob=0.889]
 40%|████      | 8/20 [1:24:05<2:04:47, 623.97s/it]

[251, 44, 2, 9]
Smallest group: 1.0, random_state: 8
TEST SIZE 1.0 - 44
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [06:35,  2.78it/s, step size=4.82e-04, acc. prob=0.965]
 45%|████▌     | 9/20 [1:33:10<1:49:51, 599.23s/it]

[253, 39, 4, 10]
Smallest group: 1.0, random_state: 9
TEST SIZE 1.0 - 39
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [05:26,  3.37it/s, step size=7.69e-04, acc. prob=0.926]
 50%|█████     | 10/20 [1:41:06<1:33:33, 561.30s/it]

[240, 54, 3, 9]
Smallest group: 1.0, random_state: 10
TEST SIZE 1.0 - 54
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [07:33,  2.43it/s, step size=3.90e-04, acc. prob=0.958]
 55%|█████▌    | 11/20 [1:51:07<1:26:00, 573.42s/it]

[245, 48, 2, 11]
Smallest group: 3.0, random_state: 11
TEST SIZE 3.0 - 11
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [05:11,  3.53it/s, step size=7.86e-04, acc. prob=0.892]
 60%|██████    | 12/20 [2:04:31<1:25:48, 643.58s/it]

[244, 50, 5, 7]
Smallest group: 1.0, random_state: 12
TEST SIZE 1.0 - 50
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [10:09,  1.80it/s, step size=3.11e-04, acc. prob=0.969]
 65%|██████▌   | 13/20 [2:17:09<1:19:08, 678.29s/it]

[261, 39, 0, 6]
Smallest group: 1.0, random_state: 13
TEST SIZE 1.0 - 39
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [09:03,  2.02it/s, step size=2.78e-04, acc. prob=0.936]
 70%|███████   | 14/20 [2:28:46<1:08:22, 683.81s/it]

[241, 49, 3, 13]
Smallest group: 3.0, random_state: 14
TEST SIZE 3.0 - 13
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [06:27,  2.84it/s, step size=6.64e-04, acc. prob=0.931]
 75%|███████▌  | 15/20 [2:43:24<1:01:51, 742.28s/it]

[237, 54, 3, 12]
Smallest group: 3.0, random_state: 15
TEST SIZE 3.0 - 12
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [06:01,  3.04it/s, step size=5.71e-04, acc. prob=0.957]
 80%|████████  | 16/20 [2:57:34<51:38, 774.69s/it]  

[250, 49, 0, 7]
Smallest group: 1.0, random_state: 16
TEST SIZE 1.0 - 49
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [04:40,  3.92it/s, step size=8.52e-04, acc. prob=0.936]
 85%|████████▌ | 17/20 [3:04:45<33:34, 671.48s/it]

[250, 44, 1, 11]
Smallest group: 3.0, random_state: 17
TEST SIZE 3.0 - 11
(140, 41) 0.7928571428571428


Sample: 100%|██████████| 1100/1100 [05:34,  3.29it/s, step size=6.74e-04, acc. prob=0.945]
 90%|█████████ | 18/20 [3:18:34<23:57, 718.80s/it]

[245, 50, 1, 10]
Smallest group: 1.0, random_state: 18
TEST SIZE 1.0 - 50
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [07:37,  2.41it/s, step size=4.47e-04, acc. prob=0.950]
 95%|█████████▌| 19/20 [3:28:41<11:25, 685.35s/it]

[244, 52, 2, 8]
Smallest group: 1.0, random_state: 19
TEST SIZE 1.0 - 52
(638, 41) 0.7633228840125392


Sample: 100%|██████████| 1100/1100 [06:21,  2.88it/s, step size=6.52e-04, acc. prob=0.960]
100%|██████████| 20/20 [3:37:34<00:00, 652.72s/it]


In [14]:
dname = "support"
folder = "../results"
for model in list(coverage.keys()):
    mdf_coverage = pd.DataFrame.from_dict(
        coverage[model], orient="index", columns=["Value"]
    )
    mdf_coverage.to_csv(f"{folder}/coverage_{dname}_{model}.csv")

for model in list(width.keys()):
    mdf_width = pd.DataFrame.from_dict(width[model], orient="index", columns=["Value"])
    mdf_width.to_csv(f"{folder}/width_{dname}_{model}.csv")

for model in list(excess.keys()):
    mdf_excess = pd.DataFrame.from_dict(
        excess[model], orient="index", columns=["Value"]
    )
    mdf_excess.to_csv(f"{folder}/excess_{dname}_{model}.csv")

for model in list(deficet.keys()):
    mdf_deficit = pd.DataFrame.from_dict(
        deficet[model], orient="index", columns=["Value"]
    )
    mdf_deficit.to_csv(f"{folder}/deficet_{dname}_{model}.csv")
